-
Notifications
You must be signed in to change notification settings - Fork 4
/
datahandler.py
150 lines (113 loc) · 4.3 KB
/
datahandler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""""An interface to work with the database
Attributes:
    arxivid_docid (dict): dict mapping arxivids to the id of the document in the db. It is used to find papers by their arxivid
    db (TinyDB): The TinyDB database with all papers
    DB_PATH (str): File to store the db in
    serialization (SerializationMiddleware): storage middleware with a custom DateTimeSerializer registered to support datetime objects
"""
import numpy as np
from tinydb import TinyDB, Query
from tinydb_serialization import SerializationMiddleware
from mydate import DateTimeSerializer
import sys
##############################INITIALIZE THE DATABASE##############################
# Fallback location of the TinyDB JSON file; replaced below while readPath is True.
DB_PATH = 'data/db.json'
# When True, the database path is taken from the first line of PATHS.txt.
readPath = True
if readPath:
    # Read inside a context manager so the file handle is closed right away
    # instead of being left for the garbage collector.
    with open("PATHS.txt") as _paths_file:
        DB_PATH = _paths_file.readline().strip()
    sys.stderr.write("Warning! DB_PATH is: %s\n" % (DB_PATH))

# TinyDB database with a custom serializer to support datetime objects.
serialization = SerializationMiddleware()
serialization.register_serializer(DateTimeSerializer(), 'TinyDate')
db = TinyDB(DB_PATH, storage=serialization)
def get_arxivid_docid_dict():
    """Build the arxivid -> document id lookup by scanning the whole db.

    Returns:
        dict: maps the arxivid of each paper to the id of its document in the db
    """
    return {entry['arxivid']: entry.doc_id for entry in db}
# Module-level lookup from arxivid to document id, built once when the module
# is imported; add_paper records newly inserted papers in it.
arxivid_docid = get_arxivid_docid_dict()
##############################FUNCTIONS TO WORK WITH THE DATABASE##############################
def add_paper(paper):
    """Store a single paper, inserting or updating as appropriate.

    A paper whose arxiv id is already known is updated through update_paper;
    otherwise it is inserted and its new document id is recorded in
    arxivid_docid.

    Args:
        paper (dict): a dict containing the keys arxivid, created, citations,
            authors, title, abstract
    """
    key = paper['arxivid']
    if key in arxivid_docid:
        update_paper(key, paper)
        return
    arxivid_docid[key] = db.insert(paper)
def update_paper(arxivid, fields):
    """Apply field changes to the paper with the given arxivid.

    Args:
        arxivid (str): the id the paper has on arxiv
        fields (dict): the fields of the paper to update, with their new values

    Raises:
        KeyError: if arxivid is not present in arxivid_docid
    """
    doc_id = arxivid_docid[arxivid]
    db.update(fields, doc_ids=[doc_id])
def remove_paper(arxivid):
    """Remove one paper from the db and from the arxivid lookup.

    Args:
        arxivid (str): the id the paper has on arxiv

    Raises:
        KeyError: if arxivid is not present in arxivid_docid
    """
    doc_id = arxivid_docid[arxivid]
    db.remove(doc_ids=[doc_id])
    # Drop the lookup entry too; otherwise a later add_paper for the same
    # arxivid would take the update path with a document id that is gone.
    del arxivid_docid[arxivid]
def get_paper(arxivid):
    """Look up a paper in the db by its arxivid.

    Args:
        arxivid (str): the id the paper has on arxiv

    Returns:
        dict: The paper as a dict

    Raises:
        KeyError: if arxivid is not present in arxivid_docid
    """
    doc_id = arxivid_docid[arxivid]
    return db.get(doc_id=doc_id)
def find_paper():
    """Placeholder for a paper search; currently does nothing.

    TODO: implement
    """
    return None
def get_papers_after(startdate):
    """Fetch every paper whose creation date is on or after startdate.

    Args:
        startdate (datetime.datetime): earliest creation date to include

    Returns:
        list: all papers created on or after startdate
    """
    created_since = Query().created >= startdate
    return db.search(created_since)
def get_papers_in_timewindow(startdate, enddate):
    """Fetch every paper created between two dates, both ends included.

    Args:
        startdate (datetime.datetime): earliest creation date to include
        enddate (datetime.datetime): latest creation date to include

    Returns:
        list: all papers created on or after startdate and on or before enddate
    """
    def in_window(created):
        # Inclusive on both ends of the window.
        return startdate <= created <= enddate

    return db.search(Query().created.test(in_window))
def get_all_papers_iterator():
    """Provide an iterator that walks over every paper in the db.

    Returns:
        iterator: Iterator over all papers in the db
    """
    papers = iter(db)
    return papers
def get_arxivid_embedding(embedding='infersent', textpart='abstract'):
    """Load the stored embeddings of one kind for one part of the papers.

    Reads data/<embedding>_<textpart>.csv, where every non-blank line holds an
    arxiv id followed by the whitespace-separated components of its embedding.
    Only papers present in that file appear in the result.

    Args:
        embedding (str): infersent or unisent
        textpart (str, optional): abstract or title

    Returns:
        dict: dict with arxiv ids and their embedding as a numpy array

    Raises:
        FileNotFoundError: if the embedding file does not exist
        ValueError: if an embedding component cannot be parsed as a float
    """
    arxivid_embedding = {}
    with open('data/%s_%s.csv' % (embedding, textpart)) as f:
        for line in f:
            fields = line.split()
            # Skip blank lines instead of failing with an IndexError.
            if not fields:
                continue
            arxivid = fields[0]
            arxivid_embedding[arxivid] = np.array([float(v) for v in fields[1:]])
    return arxivid_embedding