In [None]:
import faiss
import functools as ft
import gc
import nltk
import numpy as np
import os
import pandas as pd
import random
import re
import string

from gensim.models import Word2Vec
from gensim.models.word2vec import LineSentence
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.tokenize import sent_tokenize
from nltk import WordNetLemmatizer
from sqlalchemy import create_engine

#run mode switch:
#if True, query the notes, clean them and save tokens to token_filename
#if False (or otherwise), train model using the existing token file
get_tokens = True
#token file written when get_tokens is True and read back when training
token_filename = "word2vec/miltoks.txt"
#path the trained Word2Vec model is saved to
model_filename = "word2vec/covid_milnotes.model"

if get_tokens:
    #database connection string; blank in this file, fill in before running
    conn = ""
    cnx = create_engine(conn)
    #site names to pull notes for (one query per site); empty in this file
    sites = []

    #given a site name (str), return query to get patients who have had covid from that site
    def create_query(site):
        """Build the SQL that fetches all notes for a site's COVID patients.

        The query first collects, per person, the earliest observation date
        whose value_as_concept_name is 'COVID-19 specific diagnosis'
        (covid_index_date), then joins that site's notes and note metadata
        to those patients.

        site: site name (str); used both as the literal value of the
            "site" output column and as the prefix of the notes tables
        returns: the SQL query text (str)
        """
        #NOTE(review): assumes the notes schema names its tables
        #<site>_notes / <site>_metadata - confirm against the database
        #NOTE: site is spliced directly into the SQL text (table names cannot
        #be bound parameters), so only pass trusted, hard-coded site names
        return f"""with post_covid_pts as (
     SELECT person_id, min(observation_date) as covid_index_date
     FROM peds_recover.observation_derivation_recover odr
     WHERE value_as_concept_name = 'COVID-19 specific diagnosis'
     group by person_id
    )
    SELECT q.person_id new_person_id, q.birth_date, '{site}' as site, covid_index_date, m.*, n.notes
    FROM notes.{site}_notes n
    INNER JOIN notes.{site}_metadata m ON n.notes_id = m.notes_id
    LEFT JOIN peds_recover.person q on m.person_id = q.site_id
    INNER JOIN post_covid_pts pcp on q.person_id = pcp.person_id"""

    #run the query once per site and stack the results into one frame
    dfs = [pd.read_sql_query(create_query(site), cnx) for site in sites]
    df_covid = pd.concat(dfs)
    del dfs

    #separate out pre/post
    #convert times
    df_covid["covid_index_date"] = pd.to_datetime(df_covid["covid_index_date"],
                                                  format="%Y-%m-%d")
    #errors="coerce" turns unparseable note dates into NaT instead of raising
    df_covid["note_date"] = pd.to_datetime(df_covid["note_date"],
                                           format="%Y-%m-%d", errors="coerce")
    #remove erroneous dates
    #(notna() is required here: comparing the column against the string
    #"NaT" does not identify missing timestamps)
    df_covid = df_covid[df_covid["note_date"].notna()]
    #blackout period: keep only notes dated more than 28 days away from the
    #covid index date, in either direction
    df_covid["timediff"] = abs((df_covid["note_date"] -
                                df_covid["covid_index_date"]).dt.days)
    df_covid = df_covid[df_covid["timediff"] > 28]
    print(len(df_covid))
    gc.collect()

#take a df series of text
#return text cleaned of unicode code points and bad characters that we have found in the notes
def clean_text(text):
    """Strip encoding debris and stray symbols from a Series of note text.

    text: pandas Series of strings
    returns: a new Series of the same length with the replacements applied
    """
    #(pattern, replacement, is_regex), applied in order. Order matters: the
    #specific <82>/<85>/<92>/<93>/<94> markers are translated to punctuation
    #before the generic "<..>" pattern deletes whatever markers are left.
    #regex= is always passed explicitly because its default differs between
    #pandas versions.
    replacements = (
        ("<82>|<85>", ",", True),
        ("<92>", "'", False),
        ("<93>|<94>", '"', True),
        #these two seem to mostly (though not always) appear in contexts that should have spaces
        ("u00b7u00b0u00b7u00b0", " ", False),
        ("u00b7u00b0", " ", False),
        #remove raw unicode
        ("u0.b.|u2...|u0.a.|<..>", "", True),
        #remove other misc. symbols + the strange $& & $!
        #(raw string so the backslashes reach the regex engine unchanged)
        (r"\$\&|\$!|¼|½|·|`|°|•|§|—", "", True),
    )
    for pattern, replacement, is_regex in replacements:
        text = text.str.replace(pattern, replacement, regex=is_regex)
    return text

if get_tokens:
    #drop notes whose raw text is an exact repeat
    df_covid = df_covid.drop_duplicates(subset="notes")
    df_covid["cleaned_notes"] = clean_text(df_covid["notes"])
    #drop again on the cleaned text, then report how many notes are left
    df_covid = df_covid.drop_duplicates(subset="cleaned_notes")
    print(df_covid.shape[0])
    #from here on df_covid is just the Series of cleaned note text
    df_covid = df_covid["cleaned_notes"]
    gc.collect()

if get_tokens:
    seed = 129081
    #sample notes (at most 1 million) to reduce tokenization time and memory;
    #the size is capped at the number of notes available because sample()
    #raises ValueError when asked for more rows than exist
    df_covid = df_covid.sample(n=min(1000000, len(df_covid)), random_state=seed)
    gc.collect()

#set random seeds for repeatability
seed = 129081
#NOTE(review): PYTHONHASHSEED is normally read when the interpreter starts, so
#setting it here presumably only affects child processes - confirm
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)

#NLTK's English stopwords plus extra tokens to drop from the notes
#(presumably boilerplate words, dose units and the ordinal suffixes left
#behind once digits are stripped)
stops = set(stopwords.words("english"))
stops.update(("history", "mg", "ml", "md", "date", "nd", "th", "st", "rd"))

#characters treated as word separators: each one becomes a space
_SEPARATOR_RE = re.compile("[-/:=]")
#punctuation, digits and stray symbols that are deleted outright;
#re.escape keeps "\", "]", "^" and "-" literal inside the character class
#(unescaped, the "\" in string.punctuation would only escape the "]" after
#it and backslashes would never be removed)
_REMOVE_RE = re.compile(
    "[" + re.escape(string.punctuation) + string.digits + "¿¼½Â·°©®" + "]")
_MULTISPACE_RE = re.compile(" +")

#takes a sentence (str) and a tokenizer
#returns a list of tokens
def process_text_word2vec(text, tokenizer):
    """Normalize one sentence and split it into tokens for Word2Vec.

    text: the sentence (str)
    tokenizer: callable taking the normalized string and returning an
        iterable of tokens (e.g. nltk's word_tokenize)
    returns: list of tokens with the entries of the module-level "stops"
        set removed
    """
    #spell out "504" so it survives the digit stripping below;
    #str.replace returns a new string, so the result must be assigned
    text = text.replace("504", "Fiveohfour")
    text = _SEPARATOR_RE.sub(" ", text)
    text = _REMOVE_RE.sub("", text)
    #remove multiple spaces
    text = _MULTISPACE_RE.sub(" ", text)

    tokens = tokenizer(text)
    #remove common words
    return [x for x in tokens if x not in stops]

#takes note Series and filepath
#saves list of tokens to filepath
def tokenize(text, filepath):
    """Tokenize every note and append one sentence per line to filepath.

    Each note is lower-cased, split into sentences with sent_tokenize,
    split again on double spaces, and passed through
    process_text_word2vec with word_tokenize. Sentences left with fewer
    than two tokens are not written.

    text: iterable of note strings (e.g. a pandas Series); entries that
        are not strings, such as missing values, are skipped
    filepath: output file; it is opened in append mode, so running this
        again adds to the existing content instead of replacing it
    returns: None, the file is the only output
    """
    gc.collect()
    #open the file once for the whole run rather than once per sentence
    with open(filepath, encoding="utf8", mode="a") as f:
        for note in text:
            if not isinstance(note, str):
                continue
            for sen in sent_tokenize(note.lower()):
                #first split again, because many sentences are separated by
                #double spaces rather than punctuation
                for fragment in sen.split("  "):
                    pro = process_text_word2vec(fragment, word_tokenize)
                    if len(pro) > 1:
                        #every token, including the last, is followed by a
                        #space so the file format matches earlier runs
                        f.write(" ".join(pro) + " \n")

#either build the token file or train on the one already saved
if get_tokens:
    #create list of tokens
    tokenize(df_covid, token_filename)

else:
    #train model
    gc.collect()
    sentences = LineSentence(token_filename)
    #NOTE(review): with workers=4 the result is presumably not exactly
    #repeatable even with a fixed seed - confirm against the gensim docs
    model = Word2Vec(
        sentences=sentences,
        vector_size=1000,
        workers=4,
        seed=seed,
    )
    model.save(model_filename)