In [1]:
import os
import sys
import re
import time
import random
import warnings
import collections
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns

# sys.path.append('../src')
# import cb_utils

# Apply seaborn's "darkgrid" style to the figures drawn in this notebook.
sns.set(style="darkgrid")
# Raise pandas' column display limit to 500 so wide frames are not truncated.
pd.options.display.max_columns = 500

%load_ext autoreload
%autoreload 2

### Modelling

In [2]:
from fastai.text.all import *

In [3]:
# Base directory: the parent of the current working directory. The `icds_and_target`
# data folder is resolved relative to it in the cells below.
path = Path('../')

In [4]:
# Collect the text files under `path`, restricted to the `icds_and_target` folder.
# These file names are what the raw targets are parsed from further down.
files = get_text_files(path, folders = ['icds_and_target'])

In [5]:
# Sanity check: how many text files were found.
len(files)

9664

In [6]:
# NOTE(review): absolute, user-specific path. The file name reads like `models` and
# `lml_...` joined without a path separator; presumably the file on disk really has
# this name — confirm before "correcting" it.
vocab_path = '/home/bp/data-analytics/modelslml_epoch_20_vocab_20220918.pkl'
# SECURITY: pickle.load can execute arbitrary code; only load files you produced yourself.
# NOTE(review): `pickle` is not imported explicitly in this notebook; presumably it
# arrives through `from fastai.text.all import *` — confirm.
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

In [7]:
# Text input block for the files under `icds_and_target`, using the pre-loaded `vocab`
# instead of building a new one.
# `is_lm=False`: presumably each file is one labelled sample rather than part of a
# language-model stream — see the fastai `TextBlock.from_folder` docs to confirm.
text_block = TextBlock.from_folder(
    path / 'icds_and_target', 
    vocab=vocab,
    is_lm=False,
    seq_len=72,
    backwards=False
)

In [8]:
# Raw (un-normalised) target parsed from each file name of the form
# `<digits>_<digits>_<target>.txt`. The pattern is a raw string so that `\d` and `\.`
# reach the regex engine untouched instead of relying on Python's deprecated handling
# of unrecognised string escapes (a warning on current interpreters).
targets = [float(re.match(r'\d*_\d*_(.*)\.txt$', f.name).groups()[0]) for f in files]


In [9]:
# Peek at the first ten raw targets.
targets[:10]

[1552.16, 27232.44, 309.84, 72.55, 0.0, 0.0, 8.68, 0.0, 0.0, 1536.82]

In [10]:
# Extremes of the raw targets and their spread; `normalize_target` reads `mn` and
# `diff` to min-max scale individual values.
mn = min(targets)
mx = max(targets)
diff = mx - mn

In [11]:
def normalize_target(t):
    """Min-max scale a raw target `t` using the notebook-level `mn` and `diff`.

    Returns `(t - mn) / diff`, i.e. 0 for the smallest observed target and 1 for the
    largest.
    """
    shifted = t - mn
    return shifted / diff

In [12]:
def splitter(file_names, *args):
    """Split item indices by parent folder name.

    Indices of files whose immediate parent directory is named `train` go into the
    first list; every other index goes into the second (validation) list. Extra
    positional arguments are accepted and ignored.

    Returns:
        A `(train, valid)` pair of `L` index lists.
    """
    train, valid = L(), L()
    for idx, fname in enumerate(file_names):
        bucket = train if fname.parent.name == 'train' else valid
        bucket.append(idx)

    return train, valid
        

In [13]:
def label_func(fname):
    """Return the normalised regression target encoded in a file name.

    The name is expected to look like `<digits>_<digits>_<target>.txt`; the captured
    `<target>` text is converted to float and passed through `normalize_target`.

    Args:
        fname: path-like object exposing a `.name` attribute (e.g. `pathlib.Path`).

    Returns:
        The value produced by `normalize_target` for the parsed target.

    Raises:
        ValueError: if the name does not match the pattern, or the captured text is
            not a valid float.
    """
    # Raw string: `\d` and `\.` must reach the regex engine, not Python's escape handling.
    match = re.match(r'\d*_\d*_(.*)\.txt$', fname.name)
    if match is None:
        # The old fallback called `.groups()` on the failed (None) match, which hid the
        # offending file name behind an opaque AttributeError.
        raise ValueError(
            f'failed on {fname.name}: expected a name like <digits>_<digits>_<target>.txt'
        )
    return normalize_target(float(match.group(1)))

In [14]:
# Data pipeline: text files in, one float (the normalised target from `label_func`) out.
# `seed` pins the random train/validation split so that validation metrics are
# comparable across kernel restarts; unseeded, every run validated on a different subset.
# NOTE(review): the custom `splitter` function defined in this notebook is not used
# here — confirm the random split (rather than a folder-based one) is intended.
dblocks = DataBlock(blocks=(text_block, RegressionBlock),
                 get_items=get_text_files,
                 get_y=label_func,
                 splitter=RandomSplitter(seed=42))

In [15]:
# Build the DataLoaders from the files under `icds_and_target`, with batch size 32.
dls = dblocks.dataloaders(path / 'icds_and_target', bs=32)

In [16]:
# Preview up to three samples: token text next to the normalised target.
dls.show_batch(max_n=3)

Unnamed: 0,text,text_
0,xxbos xxbos e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 e7800 i10 k219 xxunk z7689 f330 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e1142 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119 z7689 i10 f330 k219 j3089 e782 ttlc_1 e119,0.0008105357992462
1,xxbos xxbos n186 i770 z992 n189 k219 m1990 ttlc_1 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_2 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_2 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_5 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_2 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_2 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_2 n189 z992 n186 ttlc_1 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_2 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_2 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_2 z7401 z79899 m5126 m4606 xxunk m549 z9114 ttlc_1 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_2 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_2 n189 z992 n186 d689 d631 m899 n2581 e8350 e8330 ttlc_5 n186 n189 d689 z992 d631 m899 n2581 e8350 e8330,0.0001275233662454
2,xxbos xxbos z48815 k50912 k50919 e1143 z432 k3184 i10 f319 f419 ttlc_1 z48815 e1143 k3184 i10 f319 f419 ttlc_1 k50912 k50919 z432 z48815 e1143 k3184 i10 f319 f419 ttlc_2 k50912 k50919 z432 ttlc_5 z48815 e1143 k3184 i10 f319 f419 ttlc_2 k50912 z48815 e1143 k3184 i10 f319 f419 ttlc_5 z48815 e1143 k3184 i10 f319 f419 ttlc_1 z48815 e1143 k3184 i10 f319 f419 ttlc_2 k50912 z48815 e1143 k3184 i10 f319 f419 ttlc_5 z48815 e1143 k3184 i10 f319 f419 ttlc_1 k3184 r99 z9884 z48815 e1143 i10 f319 f419 ttlc_1 k3184 ttlc_1 k50912 z48815 e1143 k3184 i10 f319 f419 ttlc_5 r1031 m25461 r1032 m170 e119 f419 i10 j45909 e039 ttlc_1 r99 z0120 z48815 e1143 k3184 i10 f319 f419 ttlc_1 z48815 e1143 k3184 i10 f319 f419 ttlc_1 k50912 k50919 z432 ttlc_1 z48815 z7901 e1143 z969 k3184 z86718 i10 f319 f419 ttlc_5 r99 k3184 z48815 e1143 i10 f319 f419 ttlc_1 g8929 k50912 k50919 z432 ttlc_2 m1711,0.3420833647251129


In [17]:
# AWD-LSTM text learner. Despite the helper's name, the target block of `dls` is a
# `RegressionBlock`, so this model predicts a single float per document.
learn = text_classifier_learner(
    dls,
    AWD_LSTM,
    # Keep predictions inside the [0, 1] span of the min-max normalised targets.
    # NOTE(review): if fastai enforces y_range with a scaled sigmoid, the end points 0
    # and 1 are only reached asymptotically; a slightly wider range may fit the many
    # near-zero targets better — confirm.
    y_range=(0,1),
#     drop_mult=0.5,
    metrics=[mse, R2Score()]
).to_fp16()
# Load previously saved encoder weights from disk into this learner.
# NOTE(review): absolute, user-specific path; presumably this is the encoder of the
# language model that produced `vocab` — confirm.
learn = learn.load_encoder('/home/bp/data-analytics/modelslml_epoch_20_encoder_20220918')

In [18]:
# Display the model architecture (encoder wrapper plus pooling/linear head).
learn.model

SequentialRNN(
  (0): SentenceEncoder(
    (module): AWD_LSTM(
      (encoder): Embedding(14008, 400, padding_idx=1)
      (encoder_dp): EmbeddingDropout(
        (emb): Embedding(14008, 400, padding_idx=1)
      )
      (rnns): ModuleList(
        (0): WeightDropout(
          (module): LSTM(400, 1152, batch_first=True)
        )
        (1): WeightDropout(
          (module): LSTM(1152, 1152, batch_first=True)
        )
        (2): WeightDropout(
          (module): LSTM(1152, 400, batch_first=True)
        )
      )
      (input_dp): RNNDropout()
      (hidden_dps): ModuleList(
        (0): RNNDropout()
        (1): RNNDropout()
        (2): RNNDropout()
      )
    )
  )
  (1): PoolingLinearClassifier(
    (layers): Sequential(
      (0): LinBnDrop(
        (0): BatchNorm1d(1200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): Dropout(p=0.2, inplace=False)
        (2): Linear(in_features=1200, out_features=50, bias=False)
        (3): ReLU(inplace=True

In [19]:
# Fine-tune: 5 epochs in the first (frozen) phase, then 50 requested for the second.
# NOTE(review): the recorded output lists only epochs 0-9 of the second phase —
# presumably the run was interrupted or the output truncated; confirm.
learn.fine_tune(50, freeze_epochs=5)

epoch,train_loss,valid_loss,mse,r2_score,time
0,0.233348,0.220325,0.220325,-141.096143,00:41
1,0.119356,0.071207,0.071207,-44.924021,00:41
2,0.010373,0.003346,0.003346,-1.157962,00:41
3,0.003077,0.001636,0.001636,-0.055322,00:41
4,0.001037,0.001359,0.001359,0.123438,00:41


epoch,train_loss,valid_loss,mse,r2_score,time
0,0.001091,0.001324,0.001324,0.14584,01:27
1,0.001173,0.00135,0.00135,0.129553,01:28
2,0.000795,0.001374,0.001374,0.113625,01:28
3,0.001,0.00126,0.00126,0.187541,01:28
4,0.001469,0.001483,0.001483,0.04387,01:28
5,0.001063,0.001417,0.001417,0.085988,01:28
6,0.000792,0.00138,0.00138,0.109743,01:27
7,0.000846,0.001297,0.001297,0.163293,01:27
8,0.001085,0.001551,0.001551,-0.000255,01:28
9,0.000691,0.001372,0.001372,0.114864,01:27


In [None]:
# Ten further epochs with the one-cycle schedule at the default learning rate (cell has
# no recorded execution).
learn.fit_one_cycle(10)

In [None]:
# Display `awd_lstm_clas_config` for reference (presumably fastai's default AWD-LSTM
# classifier configuration, provided by the fastai star import — confirm).
awd_lstm_clas_config

In [None]:
# Twenty further one-cycle epochs with 2e-4 passed as the learning rate (cell has no
# recorded execution).
learn.fit_one_cycle(20, 2e-4)

In [None]:
# Directory that holds saved models (must already exist).
models_path = '/home/bp/data-analytics/models'
# Build the checkpoint path with os.path.join: plain `+` dropped the separator and
# yielded '/home/bp/data-analytics/modelslml_epoch_20_20220918.pkl', i.e. a file
# *next to* the models directory instead of inside it.
file = os.path.join(models_path, 'lml_epoch_20_20220918.pkl')
file

In [None]:
# Persist the learner to `file`, including optimizer state (`with_opt=True`).
# NOTE(review): per the fastai docs `Learner.save` appends `.pth` to the given name, so
# the `.pkl` inside `file` is presumably not the real extension — confirm.
learn.save(file, with_opt=True, pickle_protocol=2)

In [None]:
# NOTE(review): `models_path` has no trailing separator, so this concatenation yields
# '/home/bp/data-analytics/modelslml_epoch_20_encoder_20220918' — the exact path the
# encoder was loaded from when `learn` was created. Running this cell therefore
# presumably overwrites that pretrained encoder with the fine-tuned one. Confirm this
# is intended, or save under a distinct name.
file = models_path + 'lml_epoch_20_encoder_20220918'
learn.save_encoder(file)

### Visualize Results

In [None]:
# Display the model again, as a reference for the embedding lookups below.
learn.model

In [None]:
from torch.nn import functional as F

def get_normalized_embeddings():
    """Return the token-embedding matrix of `learn`, L2-normalised with `F.normalize`.

    Reads the notebook-level `learn`. With normalised embeddings, the dot product of
    two rows is their cosine similarity, which is what `most_similar` relies on.

    Returns:
        The normalised `encoder.weight` tensor of the AWD_LSTM inside `learn.model[0]`.
    """
    body = learn.model[0]
    # In the text-classifier model printed above, `model[0]` is a SentenceEncoder that
    # holds the AWD_LSTM under `.module`, so `model[0].encoder` does not exist there.
    # Fall back to `body` itself for models whose first element is the bare AWD_LSTM.
    awd_lstm = getattr(body, 'module', body)
    return F.normalize(awd_lstm.encoder.weight)

def most_similar(token, embs):
    """Print the ten vocabulary tokens whose embeddings are most similar to `token`.

    Similarity is the dot product between rows of `embs`; the top hit is skipped on
    the assumption that it is `token` itself (true for unit-normalised rows).

    Reads the notebook-level `num.vocab` (token list) and `icd_lookup` (token ->
    description) at call time.
    NOTE(review): `num` is not defined in the cells visible in this notebook —
    confirm where it comes from before running on a fresh kernel.

    Args:
        token: vocabulary token to look up; tokens starting with `xx` are skipped.
        embs: 2-D tensor of embeddings, one row per vocabulary entry.

    Returns:
        None. Results are printed.

    Raises:
        ValueError: if `token` is not in `num.vocab`.
    """
    if token[:2] == 'xx':
        return

    def _labelled(tok):
        # Special `xx*` tokens are printed bare, as before. Any other token without a
        # lookup entry (presumably e.g. the `ttlc_*` markers seen in the batch preview)
        # is now printed bare as well instead of aborting the listing with a KeyError.
        if tok[:2] != 'xx' and tok in icd_lookup:
            return f'{tok}: {icd_lookup[tok]}'
        return tok

    idx = num.vocab.index(token)
    sims = (embs[idx] @ embs.t()).cpu().detach().numpy()

    print(f'Similar to: {_labelled(token)}')
    # Descending similarity; position 0 is skipped (see docstring).
    for sim_idx in np.argsort(sims)[::-1][1:11]:
        print(f'{sims[sim_idx]:.02f}: {_labelled(num.vocab[sim_idx])}')
            

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import random
import json


# Load the ICD description lookup from the working directory.
# NOTE(review): another cell in this notebook reshapes this object with
# `icd_lookup[0]['jsonb_object_agg']` and writes the result back to this same file, so
# the structure loaded here depends on whether that has already been done — confirm.
with open('icd_descr_lookup.json', 'r') as f:
    icd_lookup = json.load(f)

def tsne_plot(model):
    """Project every vocabulary vector of `model` to 2-D with t-SNE and plot it.

    Each word in `model.wv.vocab` is looked up as `model[word]`, the vectors are
    reduced with a fixed-seed TSNE, and every point is drawn and annotated with its
    word on a 16x16 figure.

    NOTE(review): `model.wv.vocab` / `model[word]` look like the pre-4.0 gensim
    word-vector interface rather than the fastai learner used elsewhere in this
    notebook — confirm what is expected to be passed in.
    """
    labels = list(model.wv.vocab)
    tokens = [model[word] for word in labels]

    reducer = TSNE(perplexity=40, n_components=2, init='pca', n_iter=2500, random_state=23)
    coords = reducer.fit_transform(tokens)
    points = [(value[0], value[1]) for value in coords]

    plt.figure(figsize=(16, 16))
    # One scatter call per point, so every point is drawn as its own series.
    for label, (px, py) in zip(labels, points):
        plt.scatter(px, py)
        plt.annotate(label,
                     xy=(px, py),
                     xytext=(5, 2),
                     textcoords='offset points',
                     ha='right',
                     va='bottom')
    plt.show()

In [None]:
# Normalise the embedding matrix once, then list the nearest neighbours of code `e11`.
# NOTE(review): `most_similar` indexes `icd_lookup` by token, so this presumably needs
# the already-flattened lookup (see the `jsonb_object_agg` cell in this notebook) — confirm.
embeddings = get_normalized_embeddings()
most_similar('e11', embeddings)

In [None]:
# Spot-check ten randomly drawn vocabulary tokens; `most_similar` silently skips
# tokens starting with `xx`. No random seed is set, so the sample differs between runs.
# NOTE(review): `num` is not defined in the cells visible in this notebook — confirm.
for code in random.sample(num.vocab, 10):
    most_similar(code, embeddings)
    print('')

In [None]:
# Scratch check of the two-character prefix slice that `most_similar` compares to 'xx'.
'xxfake'[:2]

In [None]:
# One-off reshaping: keep only the `jsonb_object_agg` value of the first element of the
# loaded JSON.
# NOTE(review): not idempotent — once the flattened lookup has been written back to
# `icd_descr_lookup.json` (see the json.dump cell in this notebook) and reloaded,
# `icd_lookup[0]` presumably fails. Confirm and consider removing after the one-time fix.
icd_lookup = icd_lookup[0]['jsonb_object_agg']

In [None]:
# Spot-check the lookup entry for code `e11`.
icd_lookup['e11']

In [None]:
# Write the current `icd_lookup` to `icd_descr_lookup.json`.
# NOTE(review): this is the same file name the lookup is read from in this notebook,
# so the original file contents are overwritten — confirm a backup exists.
with open('icd_descr_lookup.json', 'w') as f:
    json.dump(icd_lookup, f)