-
Notifications
You must be signed in to change notification settings - Fork 0
/
es.py
executable file
·180 lines (147 loc) · 6.95 KB
/
es.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import pandas as pd
from aioelasticsearch import Elasticsearch
from utils.parser import CovidParserNew as CovidParser
import plac
from pathlib import Path
from tqdm import tqdm
from bert_serving.client import BertClient
import numpy as np
import asyncio
import logging
import traceback
import os
import spacy
# Module-level setup: configure root logging and the shared clients.
logging.basicConfig()
logging.getLogger().setLevel(logging.WARN)
# Long timeout: indexing requests carry large embedding payloads.
es_client = Elasticsearch(timeout=600)
# Populated in main(): bert-as-service client used for sentence embeddings.
bc = None
# Populated in main(): path of the resume log listing already-indexed ids.
parsed_fp = None
async def index_documents(parsed_ids, df, index_name, valid_ids):
    """Parse, embed and index each valid metadata row into Elasticsearch.

    Args:
        parsed_ids: iterable of cord_uids indexed by a previous run; skipped.
        df: pandas DataFrame of the CORD-19 metadata CSV.
        index_name: target Elasticsearch index.
        valid_ids: dict mapping cord_uid -> True for ids eligible for indexing.

    Each successfully indexed id is appended to the module-level resume log
    (``parsed_fp``) so a crashed run can restart without re-embedding.
    """
    nlp = spacy.load("en_core_sci_sm", disable=['ner', 'tagger'])
    nlp.max_length = 2000000
    # Set for O(1) membership tests; the caller passes a list of logged ids.
    parsed_ids = set(parsed_ids)
    with open(parsed_fp, 'a+') as logfile:
        for _, row in tqdm(df.iterrows(), total=len(valid_ids)):
            doc = None
            _id = row['cord_uid'].strip()
            try:
                # Membership test (not valid_ids[_id]) so ids absent from the
                # valid set are skipped instead of raising KeyError into the
                # broad except below.
                if _id not in valid_ids or _id in parsed_ids:
                    logging.warning(f'Skipping {_id} as it\'s already indexed')
                    continue
                # Parse the document fields
                doc = CovidParser.parse(row)
                await asyncio.sleep(0.1)
                title_exists = True
                if pd.isnull(doc['title']):  # Check if a title exists
                    logging.warning(f'{_id} has no title')
                    doc['title'] = ''
                    title_exists = False
                # Check if abstract exists before performing code that requires it to be non-empty
                abstract_exists = True
                # Compare lowercased text against a lowercase literal; the
                # original compared .lower() to 'Unknown', which never matched.
                if doc['abstract'].lower() == 'unknown' or len(doc['abstract'].strip()) == 0:
                    logging.warning(f'{_id} has no abstract')
                    abstract_exists = False
                sentences = []
                if abstract_exists:
                    sentences = [sent.text.strip() for sent in nlp(doc['abstract']).sents]
                if title_exists:
                    # Batch to save memory and requests: title rides along as
                    # the first element of the abstract-sentence batch.
                    sentences.insert(0, doc['title'])
                if sentences:
                    await asyncio.sleep(0.1)  # prepare the next batch before we submit our request
                    encodings = bc.encode(sentences)
                    if abstract_exists:
                        # if abstract, then title is the first element/vector
                        doc['title_embedding'] = encodings[0].tolist() if title_exists else [0]*768
                    else:
                        # If not abstract, the title is the only vector
                        doc['title_embedding'] = encodings.tolist()[0] if title_exists else [0]*768
                    if abstract_exists:
                        if title_exists:
                            # Drop the title vector (row 0) before averaging.
                            arr = np.delete(encodings, 0, 0)
                            doc['abstract_embedding'] = np.mean(arr, axis=0).tolist()
                        else:
                            doc['abstract_embedding'] = np.mean(encodings, axis=0).tolist()
                    else:
                        # Default to zero vector if no abstract
                        doc['abstract_embedding'] = [0]*768
                if 'fulltext' in doc.keys() and doc['fulltext']:
                    fulltexts = []
                    await asyncio.sleep(0.1)
                    # Truncate to spaCy's configured max_length before parsing.
                    if len(doc['fulltext']) > 2000000:
                        doc['fulltext'] = doc['fulltext'][:2000000]
                    for sent in nlp(doc['fulltext']).sents:
                        text = sent.text.strip()
                        if text:
                            fulltexts.append(text)
                    if fulltexts:
                        full_embedding = bc.encode(fulltexts)
                        # Cap the number of sentence vectors averaged over.
                        if len(full_embedding) > 5000:
                            full_embedding = full_embedding[:5000]
                        doc['fulltext_embedding'] = np.mean(np.array(full_embedding), axis=0).tolist()
                assert len(doc['title_embedding']) == 768
                assert len(doc['abstract_embedding']) == 768
                # Embedding is the bottleneck, so we can perform multiple requests before indexing
                await es_client.index(index=index_name, id=_id, body=doc)
                # Record success immediately so a crash can resume from here.
                logfile.write(f'{_id}\n')
                logfile.flush()
            except Exception:
                # Best-effort batch job: log the failure and move on to the
                # next document rather than aborting the whole run.
                print(traceback.format_exc())
                logging.critical(f"Cannot process doc {_id}")
async def create_es_index(index_file, index_name, delete=False):
    """Create the Elasticsearch index from a JSON mapping file.

    Args:
        index_file: path to the JSON index settings/mappings.
        index_name: name of the index to create.
        delete: if True and the index already exists, drop and recreate it,
            also resetting the resume log (``parsed_fp``); if False and the
            index exists, do nothing.
    """
    with open(index_file) as idx_file:
        index_exists = await es_client.indices.exists(index_name)
        if index_exists:
            if delete:
                # Log before acting so the message is accurate if delete fails.
                logging.warning('Deleting old index')
                await es_client.indices.delete(index_name)
                # The previously indexed ids are gone; reset the resume log.
                os.remove(parsed_fp)
                open(parsed_fp, 'w+').close()  # create empty file
            else:
                return
        source = idx_file.read().strip()
        await es_client.indices.create(index=index_name, body=source)
@plac.annotations(
    metafile=('Path to metadata','option', None, Path),
    index_config=('Mappings for ES Index', 'option', None, Path),
    delete_index=('Delete past index', 'flag', None),
    data_path=("path to the dataset", 'option', None, Path),
    index_name=('Index Name', 'option', None, str),
    bert_inport=('BC port in', 'option', None, int),
    bert_outport=('BC port', 'option', None, int),
    valid_id_path=('Path to valid ids', 'option', None, str),
)
def main(metafile: Path = Path('covid-april-10/metadata.csv'),
         index_config: Path = Path('assets/es_config.json'),
         delete_index: bool = False,
         index_name: str = 'covid-april-10',
         bert_inport: int = 51235,
         bert_outport: int = None,
         data_path: Path = Path("datasets/covid-april-10/"),
         valid_id_path: str = 'covid-april-10/docids-rnd1.txt'):
    """CLI entry point: (re)create the index and embed/index all documents."""
    assert metafile.exists()
    assert index_config.exists()
    assert data_path.exists()
    loop = asyncio.get_event_loop()
    CovidParser.data_path = str(data_path) + "/"
    # Keep a list of parsed documents that we have processed in the event of a crash
    global parsed_fp
    parsed_fp = f'parsed_docs_{index_name}.txt'
    if not os.path.exists(parsed_fp):
        # Close immediately; the original leaked this file handle.
        open(parsed_fp, 'w+').close()
    loop.run_until_complete(create_es_index(index_config, index_name, delete=delete_index))
    df = pd.read_csv(metafile, index_col=None)
    global bc
    if bert_outport is None:
        # bert-serving convention: out-port defaults to in-port + 1.
        bert_outport = bert_inport + 1
    bc = BertClient(port=bert_inport, port_out=bert_outport)
    # Read both bookkeeping files with context managers (the original leaked
    # the handles and read valid_id_path twice, discarding the first read).
    with open(parsed_fp) as parsed_file:
        parsed_ids = [line.strip() for line in parsed_file]
    with open(valid_id_path) as valid_file:
        valid_ids = {line.strip(): True for line in valid_file}
    loop.run_until_complete(index_documents(parsed_ids, df, index_name, valid_ids))
    loop.run_until_complete(es_client.transport.close())
    loop.close()
# Script entry point: let plac map command-line options onto main()'s
# annotated parameters.
if __name__ == '__main__':
    plac.call(main)