In [None]:
# Imports
from time import time
from os.path import join as join_path
import numpy as np
import pandas as pd

import multiprocessing
cores = multiprocessing.cpu_count()

from gensim.models import Word2Vec, Doc2Vec
from gensim.models.callbacks import CallbackAny2Vec
from gensim.models.doc2vec import TaggedDocument
import logging # Setting up the loggings to monitor gensim
logging.basicConfig(format="%(levelname)s - %(asctime)s: %(message)s", datefmt= '%H:%M:%S', level=logging.INFO)

import nltk
nltk.download('punkt')

from utils import clean_text
from tqdm.notebook import tqdm

from sklearn.metrics.pairwise import cosine_similarity

## Load and prepare data

In [None]:
# Constants: where the CORD-19 table is read from and where Doc2Vec
# checkpoints are written (consumed by the epoch-saving callback).
cord_data_dir = 'data'
d2v_saved_models_dir = 'models-doc2vec'
saved_models_prefix = 'model'
cord_data_path = join_path(cord_data_dir, 'cord-19-data.csv')

In [None]:
# Load the CORD-19 table and keep English papers that actually have body text.
cord_data = pd.read_csv(cord_data_path)
# Rows with a missing body (NaN after read_csv) would make
# nltk.tokenize.sent_tokenize raise a TypeError downstream, so drop them here.
cord_data_eng = cord_data[cord_data['language'] == 'en'].dropna(subset=['body_text'])
# Array of (cord_uid, body_text) rows consumed by the sentence counter and
# the Doc2Vec corpus iterator.
eng_texts = cord_data_eng[['cord_uid', 'body_text']].to_numpy()

In [None]:
# Count sentences up front so the vocabulary-building progress bar and
# progress logging below know the corpus size.
cord_num_sentences = sum(
    len(nltk.tokenize.sent_tokenize(body)) for _uid, body in tqdm(eng_texts)
)
print(f'Total number of CORD-19 sentences: {cord_num_sentences}')

In [None]:
class CORDDataIteratorDoc2Vec:
    """Restartable stream of sentence-level ``TaggedDocument``s for gensim.

    Each paper's body text is split into sentences with NLTK, every sentence
    is passed through ``clean_text``, and each result is yielded tagged with
    the paper's ``cord_uid``. All sentences of one paper therefore share a
    single document tag. Because ``__iter__`` is a generator method, every new
    iteration starts again from the first paper.

    NOTE(review): gensim expects ``TaggedDocument.words`` to be a list of
    tokens; this presumes ``clean_text`` returns a token list rather than a
    plain string. Confirm in ``utils``.
    """

    def __init__(self, texts: np.ndarray):
        # Rows unpacked as (cord_uid, body_text) pairs, e.g. `eng_texts`.
        self.texts = texts

    def __iter__(self):
        for doc_uid, body in self.texts:
            # Clean all sentences of this paper eagerly, then emit them.
            cleaned = list(map(clean_text, nltk.tokenize.sent_tokenize(body)))
            # A fresh single-element tag list per sentence, carrying the paper id.
            yield from (TaggedDocument(words, [doc_uid]) for words in cleaned)

In [None]:
# Wrap the texts in a re-iterable corpus: every `for` over it re-enters the
# generator `__iter__` from the first paper, so it can be passed to both
# build_vocab and train, which each iterate over it.
cord_sentences = CORDDataIteratorDoc2Vec(eng_texts)

## Learn document embeddings using Doc2Vec

In [None]:
class DocEpochSaver(CallbackAny2Vec):
    '''Callback to save model after each epoch.

    At the end of every training epoch the model is written to
    ``<output_dir>/<prefix>_epoch_<N>.model``. N starts at ``start_epoch`` and
    grows by one per epoch.

    Args:
        output_dir: Directory for the checkpoints; created if it is missing.
        prefix: Filename prefix for every checkpoint.
        start_epoch: Label given to the first epoch this callback observes.
            Use 1 for a fresh run. When resuming after k completed epochs,
            pass k + 1 so existing checkpoints are not overwritten.
    '''

    def __init__(self, output_dir: str, prefix: str, start_epoch: int = 1):
        self.output_dir = output_dir
        self.prefix = prefix
        # Label of the next epoch to be saved.
        self.epoch = start_epoch

    def on_epoch_end(self, model):
        # Local import keeps this cell self-contained.
        import os
        # Create the directory lazily so a missing folder does not throw away a
        # whole epoch of training at save time.
        os.makedirs(self.output_dir, exist_ok=True)
        # Label the checkpoint with the epoch that just finished. Adding 1 here
        # would name the first epoch of a fresh run "epoch_2".
        output_path = join_path(self.output_dir, f'{self.prefix}_epoch_{self.epoch}.model')
        model.save(output_path)
        self.epoch += 1

In [None]:
# Setup initial model (vocabulary and training happen in later cells).
d2v_model = Doc2Vec(
    min_count=20,       # ignore words with total frequency below 20
    window=2,           # max distance between the current and predicted word
    vector_size=300,    # dimensionality of the learned vectors
    negative=5,         # number of noise words drawn for negative sampling
    # Leave one core for the notebook, but never go below one worker:
    # on a single-core machine `cores - 1` would be 0 and gensim would
    # start no worker threads at all.
    workers=max(1, cores - 1),
    # NOTE(review): gensim's train() presumably replaces these callbacks when
    # it is called with its own `callbacks=` argument. Confirm for the
    # installed gensim version.
    callbacks=[DocEpochSaver(d2v_saved_models_dir, saved_models_prefix)]
)

In [None]:
# Build vocabulary
t = time()
# Log vocab-scan progress roughly every 1% of sentences. Clamp to >= 1:
# gensim evaluates `document_no % progress_per`, so a value of 0 (any corpus
# with fewer than 100 sentences) would raise ZeroDivisionError.
progress_step = max(1, cord_num_sentences // 100)
d2v_model.build_vocab(tqdm(cord_sentences, total=cord_num_sentences), progress_per=progress_step)
print(f'Time to build vocab: {round((time() - t) / 60, 2)} mins')

In [None]:
# d2v_model = Doc2Vec.load('models-doc2vec/model_epoch_2.model')

In [None]:
# Train model for 10 epochs, saving a checkpoint at the end of each epoch.
t = time()
epoch_saver = DocEpochSaver(d2v_saved_models_dir, saved_models_prefix, 10)
d2v_model.train(
    cord_sentences,
    total_examples=d2v_model.corpus_count,
    epochs=10,
    report_delay=30,
    callbacks=[epoch_saver],
)
elapsed_mins = round((time() - t) / 60, 2)
print(f'Time to train the model: {elapsed_mins} mins')

In [None]:
# d2v_model.save('models-doc2vec/model_epoch_2.model')

In [None]:
# Prototype search pipeline below. It is kept inside a string literal, so
# nothing in it executes.
# NOTE(review): the lookup at the end parses tags as "<cord_uid>_<sentence_idx>",
# but CORDDataIteratorDoc2Vec tags every sentence with the bare cord_uid.
# `top_cord_uid.split('_')[1]` would therefore presumably raise IndexError for
# ids without an underscore. Align the tag scheme before reviving this code.
# NOTE(review): `docvecs` / `index2entity` look like gensim 3.x attribute names
# (gensim 4 renamed them to `dv` / `index_to_key`). Confirm against the
# installed version.
# NOTE(review): the hard-coded 300 must match the model's `vector_size`.
# sklearn's `cosine_similarity` is already imported in the first cell and
# could replace the hand-written `cosine_sim`.
'''
len(d2v_model.docvecs.index2entity)

query = clean_text('The patient (Fo, ) was a 58 year old mentally retarded white woman, born in a rural area of southwestern Virginia.')
query_vec = d2v_model.infer_vector(query, epochs=100)

eng_texts[0][0]

doc_weight_mat = np.zeros((len(d2v_model.docvecs.index2entity), 300))
for i, cord_uid in enumerate(tqdm(d2v_model.docvecs.index2entity)):
    doc_weight_mat[i] = d2v_model.docvecs[cord_uid]

def cosine_sim(vec: np.ndarray, mat: np.ndarray):
    return vec @ mat.T / (np.linalg.norm(vec) * np.linalg.norm(mat, axis=1))

query_vec.shape, doc_weight_mat.shape

# Find closest document
#keys = d2v_model.docvecs.index2entity
similarities = cosine_sim(query_vec, doc_weight_mat)

top_n = 10
sorted_indicies = similarities.argsort()[::-1]
top_sim = list(zip(np.array(d2v_model.docvecs.index2entity)[sorted_indicies][:top_n], similarities[sorted_indicies][:top_n]))
top_sim

top_cord_uid = top_sim[0][0]
best_text = cord_data[cord_data['cord_uid'] == top_cord_uid.split('_')[0]].body_text.values[0]
best_text_sentences = nltk.tokenize.sent_tokenize(best_text)
best_text_sentences[int(top_cord_uid.split('_')[1])]
'''