# Code embeddings
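Fit FastText embeddings for diagnosis (dx) and medication (rx) codes. Each stay is treated as one "sentence" whose "words" are the codes recorded during that stay. Because the codes within a stay have no natural order, the corpus contains several random permutations of each stay's codes.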

In [1]:
import os
import sys
import datetime
import pandas as pd
import pickle as pkl
import numpy as np
import scipy
import gc

from multiprocessing import Pool
from gensim.models import FastText

%load_ext autoreload
%autoreload 2

In [2]:
sys.path.append('/code')
from edge import data
from edge import patient_stays
from edge import diagnosis
from edge import meds
from edge import vitals
from edge import utils

In [3]:
# Read both parquet tables in one pass (order matters: ptimes first, then
# stays), then show their shapes as a quick sanity check.
combined_ptimes, combined_stays = map(
    pd.read_parquet,
    (
        '/data/raw/combined_ptimes.parquet',
        '/data/raw/combined_stays.parquet',
    ),
)
combined_ptimes.shape, combined_stays.shape

((14694483, 4), (118127, 19))

In [4]:
# Load the dx and rx feature matrices (stored as .npz sparse matrices).
combined_dx, combined_rx = (
    scipy.sparse.load_npz(npz_path)
    for npz_path in (
        '/data/processed/combined_dx_features_csr.npz',
        '/data/processed/combined_rx_features_csr.npz',
    )
)

# Stack column-wise (dx columns first, then rx columns) and ask for CSR
# directly so that row slicing stays efficient.
combined_codes = scipy.sparse.hstack([combined_dx, combined_rx], format='csr')
combined_codes

<14694483x5070 sparse matrix of type '<class 'numpy.float32'>'
	with 294388924 stored elements in Compressed Sparse Row format>

In [5]:
# Column names for the dx and rx feature matrices.
# NOTE: pickle.load executes arbitrary code from the file -- only use it on
# files produced by this project's own pipeline.
with open('/data/processed/combined_dx_colnames.pkl', 'rb') as dx_file:
    dx_colnames = pkl.load(dx_file)
with open('/data/processed/combined_rx_colnames.pkl', 'rb') as rx_file:
    rx_colnames = pkl.load(rx_file)

# The corpus is whitespace-tokenised, so spaces inside an rx name would split
# it into several tokens; replace them with underscores.
rx_colnames = np.array(
    [rx_name.replace(' ', '_') for rx_name in rx_colnames]
)

# Full vocabulary: dx codes first, then rx codes.
code_vocab = np.concatenate((dx_colnames, rx_colnames))

In [68]:
# Map every code in the vocabulary to a 1-based integer index
# (index 0 is left unused -- presumably reserved for padding; confirm with
# the consumers of this map).
code_to_index = {code: idx for idx, code in enumerate(code_vocab, start=1)}

with open('/data/raw/code_to_index_map.pkl', 'wb') as map_file:
    pkl.dump(code_to_index, map_file)

In [6]:
# One group per (MasterPatientID, StayRowIndex) pair; each group is later
# turned into one "sentence" of codes.
grouped_ptimes = combined_ptimes.groupby(
    by=['MasterPatientID', 'StayRowIndex'],
)

In [75]:
# Materialise the groups into a plain list so they can be handed to
# Pool.map as independent jobs (the group keys themselves are not needed).
print('Constructing jobs data')
groups_list = [stay_rows for _, stay_rows in grouped_ptimes]
print(f'Got {len(groups_list)} jobs...')

Constructing jobs data
Got 116667 jobs...


In [76]:
# Number of shuffled "sentences" emitted per stay.
num_permutations = 4
# Base seed for the per-stay shuffles; change it to draw a different corpus.
permutation_seed = 0


def stayWorker(group):
    """Build the shuffled code "sentences" for a single stay.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows belonging to one stay. Its index values are used as row
        positions into the module-level ``combined_dx`` / ``combined_rx``
        sparse matrices.

    Returns
    -------
    list of str
        ``num_permutations`` space-separated strings, each containing every
        dx and rx code whose column sum over the stay's rows is positive,
        in a random order. A stay with no such codes yields empty strings.

    Notes
    -----
    * Column sums are taken on the sparse row slice directly instead of
      densifying it first, which avoids allocating a dense
      (rows x columns) array for every stay.
    * Each stay shuffles with its own ``RandomState`` seeded from
      ``permutation_seed`` and the stay's first row index. The global NumPy
      RNG is not used: it is unseeded here, and worker processes created by
      fork start with identical copies of its state. With a per-stay seed
      the corpus is reproducible regardless of how jobs are scheduled.
    """
    indices = group.index.values

    # .sum(axis=0) on a sparse matrix returns a 1 x n_cols matrix; flatten it.
    dx_col_sums = np.asarray(combined_dx[indices].sum(axis=0)).ravel()
    rx_col_sums = np.asarray(combined_rx[indices].sum(axis=0)).ravel()

    # Keep only the codes that occur at least once during the stay.
    dx_codes = dx_colnames[np.nonzero(dx_col_sums > 0)[0]]
    rx_codes = rx_colnames[np.nonzero(rx_col_sums > 0)[0]]
    codes_for_stay = np.concatenate([dx_codes, rx_codes])

    # RandomState seeds must fit in 32 bits.
    first_row = int(indices[0]) if len(indices) else 0
    rng = np.random.RandomState((permutation_seed + first_row) % (2 ** 32))

    return [
        " ".join(rng.permutation(codes_for_stay))
        for _ in range(num_permutations)
    ]


In [None]:
print('Staring jobs...')
# Leave four cores free for the rest of the machine, but never ask for fewer
# than one worker: os.cpu_count() can return None, and on hosts with four or
# fewer cores the plain subtraction would give Pool a non-positive size.
with Pool(max(1, (os.cpu_count() or 1) - 4)) as pool:
    code_sentence_lists = pool.map(stayWorker, groups_list)

print('Collating code sentences...')
# Flatten the per-stay lists into one corpus, preserving job order.
code_sentences = [
    code_sentence
    for code_sentence_list in code_sentence_lists
    for code_sentence in code_sentence_list
]

print('Done...')

Staring jobs...


In [79]:
# Sanity check: stayWorker returns num_permutations sentences per stay group,
# so this should equal num_permutations * len(groups_list).
len(code_sentences)

466668

In [80]:
# Save "corpus": one sentence per line, the layout FastText's corpus_file
# argument reads below.
with open('/data/raw/code_sentence_corpus.txt', 'w') as corpus_out:
    corpus_out.writelines(f'{sentence}\n' for sentence in code_sentences)
print('Done...')

Done...


In [85]:
# Fit model.
# Leave four cores free, but always keep at least one worker thread:
# os.cpu_count() can return None, and on hosts with four or fewer cores the
# plain subtraction would yield a non-positive worker count.
num_threads = max(1, (os.cpu_count() or 1) - 4)

# Fit fasttext embeddings to these...
embed_dim = 200   # embedding dimensionality
num_iter = 3      # passes over the corpus
print("Fitting fasttext for codes")
# NOTE: `size` / `iter` are the gensim 3.x keyword names; gensim 4 renamed
# them to `vector_size` / `epochs`. Keep them in sync with the installed
# gensim version.
code_model = FastText(corpus_file='/data/raw/code_sentence_corpus.txt',
                      size=embed_dim,
                      sg=1,           # skip-gram rather than CBOW
                      iter=num_iter,
                      negative=10,    # negative samples per positive example
                      workers=num_threads)
code_model.save(f"/data/model/ft_combined_codes_d{embed_dim}.model")
print('Done...')

Fitting fasttext for pn phrases
Done...
