In [None]:
import json
import sys
import collections
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import copy

import torch
import torch.nn as nn
import torch.optim as optim
sys.path.append('../src')
import cb_utils

In [None]:
# Load the complete CUI / NDC / HCC dataset through the project's SQL helper.
# NOTE(review): presumably returns a pandas DataFrame (later cells call
# .head() and .iterrows() on it) -- confirm against cb_utils.
raw = cb_utils.sql_query_to_df('SELECT * FROM junk.cui_ndc_hcc_dataset;')

In [None]:
# Row count of the raw dataset; used later to size the X / Y matrices.
n_samples = len(raw)
n_samples 

In [None]:
raw.head()

### Create lookups for CUIs and HCCs

In [None]:
# Count how often each CUI occurs across all rows of `raw.rscuis`.
cuis = collections.Counter(
    cui
    for row_cuis in raw.rscuis
    for cui in row_cuis
)

In [None]:
# Vocabulary size: number of distinct CUIs counted above.
n_cuis = len(cuis)
n_cuis  

In [None]:
# raw.categories.apply(lambda x: len(x) if x is not None else 0).describe()

In [None]:
# raw.rscuis.apply(lambda x: len(x)).describe()

In [None]:
# list(reversed(cuis.most_common(1000)))[:10]

In [None]:
# Map a 1-based frequency rank to its CUI (rank 1 = most common CUI).
# Index 0 is deliberately left unused so it can act as the padding value.
cui_lookup = {
    rank: cui
    for rank, (cui, _count) in enumerate(cuis.most_common(), start=1)
}

In [None]:
# Inverse of `cui_lookup`: CUI -> 1-based frequency rank.
cui_idx_lookup = dict((cui, idx) for idx, cui in cui_lookup.items())

In [None]:
# Count how often each HCC category occurs across all rows.
hccs = collections.Counter()
for row_categories in raw.categories:
    # Rows without categories (None) contribute nothing to the counts.
    if row_categories is None:
        continue
    hccs.update(row_categories)

In [None]:
# Number of distinct HCC categories; this is the width of the label matrix.
n_hccs = len(hccs)
n_hccs

In [None]:
# hccs.most_common()

In [None]:
# Column index (0-based, ordered by descending frequency) -> HCC category.
hcc_lookup = {
    column: hcc
    for column, (hcc, _count) in enumerate(hccs.most_common())
}
# Inverse mapping: HCC category -> column index.
hcc_idx_lookup = dict((hcc, column) for column, hcc in hcc_lookup.items())

In [None]:
# Maximum number of CUI indices kept per sample; longer rows are truncated.
# NOTE(review): 104 is a magic number -- presumably derived from the
# distribution of rscuis lengths; confirm and consider naming it as a constant.
max_cuis = 104
# Per-sample CUI indices, zero-filled: unused slots stay 0 (the padding value).
X = np.zeros((n_samples, max_cuis))
# Multi-hot label matrix with one column per HCC category.
Y = np.zeros((n_samples, n_hccs))

In [None]:
def get_cui_idx(a):
    """Return the 1-based frequency rank of CUI `a` (raises KeyError if unknown)."""
    return cui_idx_lookup[a]


# Fill X with each row's CUI indices (most frequent first, truncated to
# `max_cuis`) and Y with the row's multi-hot HCC labels.
#
# Rows are addressed by their *position* via enumerate() rather than by the
# DataFrame index label that iterrows() yields, so the encoding stays correct
# even when `raw` does not carry a default 0..n-1 index.
for row_pos, (row_cuis, row_hccs) in enumerate(zip(raw.rscuis, raw.categories)):
    # Sorting the integer ranks is equivalent to sorting the CUIs by rank.
    cui_indices = sorted(get_cui_idx(cui) for cui in row_cuis)[:max_cuis]
    if cui_indices:
        X[row_pos, :len(cui_indices)] = cui_indices

    # Rows without categories keep an all-zero label vector.
    if row_hccs is not None:
        for hcc in row_hccs:
            Y[row_pos, hcc_idx_lookup[hcc]] = 1

### Train / validation / test split

In [None]:
# Hold out 20% of the samples for testing, then carve 10% of the remaining
# 80% off as a validation set (roughly 72% / 8% / 20% train / val / test).
# Both calls pin random_state so the partition is reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

In [None]:
# Feature matrices hold embedding indices, so they become integer tensors.
X_train, X_test, X_val = (
    torch.tensor(split, dtype=torch.int, device=device)
    for split in (X_train, X_test, X_val)
)
# Label matrices are multi-hot targets, so they become float tensors.
y_train, y_test, y_val = (
    torch.tensor(split, dtype=torch.float, device=device)
    for split in (y_train, y_test, y_val)
)

### MLP

In [None]:
class EmbMLP(nn.Module):
    """Bag-of-CUIs classifier: summed CUI embeddings fed through an MLP.

    The embedding table has ``n_cuis + 1`` rows; row 0 is the padding index.
    The MLP ends in a linear layer with ``n_hccs`` outputs (raw logits, no
    final activation).

    Args:
        embedding_dim: width of each CUI embedding vector.
        dropout: dropout probability applied after every hidden block.
        device: when equal to ``'cuda'`` the module is moved to the GPU.
    """

    def __init__(self, embedding_dim=600, dropout=0.01, device=device):
        super().__init__()
        self.emb_dim = embedding_dim
        self.emb = nn.Embedding(n_cuis + 1, embedding_dim=embedding_dim, padding_idx=0)

        # Hidden blocks are Linear -> ReLU -> BatchNorm -> Dropout, created in
        # order so the Sequential indices match the layer order.
        hidden_widths = (32, 128, 256, 128)
        layers = []
        in_features = embedding_dim
        for out_features in hidden_widths:
            layers += [
                nn.Linear(in_features, out_features),
                nn.ReLU(),
                nn.BatchNorm1d(out_features),
                nn.Dropout(dropout),
            ]
            in_features = out_features
        # Output layer: one logit per HCC category.
        layers.append(nn.Linear(in_features, n_hccs))
        self.mlp_model = nn.Sequential(*layers)

        if device == 'cuda':
            self.cuda()

    def forward(self, inputs):
        """Embed the index matrix, sum over dim 1, and return the MLP logits."""
        pooled = self.emb(inputs).sum(dim=1)
        return self.mlp_model(pooled)

In [None]:
class AverageMeter:
    """Tracks the latest value together with a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out the latest value, running mean, weighted sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as seen `n` times and refresh the running mean."""
        self.val = val
        self.sum += n * val
        self.count += n
        self.avg = self.sum / self.count


def train_loop(model, X_train, y_train, batch_size):
    """Run one training epoch over full-size mini-batches.

    Uses the notebook-level ``criterion`` and ``optimizer`` globals.

    Args:
        model: module to optimise; it is switched to training mode first.
        X_train: input tensor, sliced along dim 0 into consecutive batches.
        y_train: target tensor aligned with ``X_train`` along dim 0.
        batch_size: number of rows per batch. A trailing partial batch is
            skipped, as in the original loop.

    Returns:
        The sample-weighted mean of the batch losses, or 0 when the data
        holds fewer rows than one batch.
    """
    # Guarantee dropout / batch-norm are in training mode even if an
    # evaluation step left the model in eval mode.
    model.train()

    losses = AverageMeter()
    # Stays NaN when no batch runs, instead of leaving the printed name unbound.
    last_loss = float('nan')

    for i in range(0, X_train.shape[0] // batch_size):
        batch_start = i * batch_size
        batch_end = (i + 1) * batch_size

        x = X_train[batch_start: batch_end]
        y = y_train[batch_start: batch_end]

        output = model(x)
        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        last_loss = loss.item()
        losses.update(last_loss, output.shape[0])

    # Reports the loss of the final batch of the epoch.
    print('Loss: {:.3f}'.format(last_loss))

    return losses.avg
        
def validation_loop(model, X_val, y_val, batch_size):
    """Compute the mean loss of `model` over the whole validation set.

    Uses the notebook-level ``criterion`` global. The model is evaluated in
    eval mode (dropout disabled, batch-norm using running statistics) with
    autograd switched off, and its previous train/eval mode is restored
    afterwards so a following training step is unaffected.

    Args:
        model: module to evaluate.
        X_val: input tensor, sliced along dim 0 into consecutive batches.
        y_val: target tensor aligned with ``X_val`` along dim 0.
        batch_size: maximum number of rows per batch; the trailing partial
            batch is included so every sample contributes.

    Returns:
        The sample-weighted mean loss (0 only when ``X_val`` is empty).
    """
    losses = AverageMeter()
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Step by batch_size so the final, possibly shorter, batch is scored too.
            for batch_start in range(0, X_val.shape[0], batch_size):
                batch_end = batch_start + batch_size

                x = X_val[batch_start: batch_end]
                y = y_val[batch_start: batch_end]

                output = model(x)
                loss = criterion(output, y)
                losses.update(loss.item(), output.shape[0])
    finally:
        # Put the model back into whichever mode the caller had it in.
        model.train(was_training)

    print('val Loss: {:.3f}'.format(losses.avg))
    return losses.avg

In [None]:
# Model, loss and optimiser. `criterion` and `optimizer` are read as globals
# by train_loop / validation_loop.
mlp_model = EmbMLP(device=device, embedding_dim=600, dropout=0.1)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(mlp_model.parameters())

# Per-epoch loss history and best-checkpoint bookkeeping.
train_losses, val_losses = [], []
best_val_loss, best_val_epoch = np.inf, 0
best_model = copy.deepcopy(mlp_model)

# Train for at most 100 epochs; stop early once more than 10 epochs pass
# without a new best validation loss.
for epoch in range(1, 101):
    print('EPOCH: ', epoch)
    train_loss = train_loop(mlp_model, X_train, y_train, 256)
    val_loss = validation_loop(mlp_model, X_val, y_val, 256)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    if val_loss < best_val_loss:
        # Snapshot the model whenever validation loss improves.
        best_val_loss, best_val_epoch = val_loss, epoch
        best_model = copy.deepcopy(mlp_model)
        print('new best val Loss: {:.3f}'.format(best_val_loss))
        continue

    if epoch - best_val_epoch > 10:
        print('Stopping early no improvement since epoch', best_val_epoch)
        break

In [None]:
# Plot training and validation loss per epoch on a shared axes.
fig, ax = plt.subplots()
for loss_history, series_label in ((train_losses, 'Train'), (val_losses, 'Val')):
    ax.plot(loss_history, label=series_label)
ax.legend()
# .111 best 

In [None]:
def build_results(model, X, labels, thresholds=None):
    """Summarise multi-label prediction quality over a range of thresholds.

    The model is run once, in eval mode and without autograd, and its
    previous train/eval mode is restored afterwards. For every threshold the
    sigmoid probabilities are binarised and per-sample TP / TN / FP / FN
    label counts are computed, then aggregated across samples.

    Args:
        model: callable module returning logits for `X`.
        X: model input for all samples to evaluate.
        labels: 0/1 tensor with the same shape as the model output.
        thresholds: iterable of probability cut-offs; defaults to
            ``np.arange(0.1, 1, .1)``.

    Returns:
        A DataFrame with one row per threshold holding the threshold, the
        mean per-sample recall (samples without positive labels are
        excluded), the mean per-sample accuracy, and avg / max / median / std
        of the per-sample TP, FP, TN and FN counts.
    """
    if thresholds is None:
        thresholds = np.arange(0.1, 1, .1)

    # The forward pass does not depend on the threshold, so do it once, with
    # dropout disabled, batch-norm on running statistics and no autograd graph.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            preds = torch.sigmoid(model(X))
    finally:
        model.train(was_training)

    results = []

    for pos_threshold in thresholds:
        pred_labels = torch.zeros_like(preds)
        pred_labels[preds > pos_threshold] = 1

        # Per-sample counts: sum/difference of 0/1 predictions and labels
        # identifies each confusion-matrix cell.
        tp = torch.sum(pred_labels + labels == 2, axis=1, dtype=torch.float)
        tn = torch.sum(pred_labels + labels == 0, axis=1, dtype=torch.float)
        fp = torch.sum(pred_labels - labels == 1, axis=1, dtype=torch.float)
        fn = torch.sum(pred_labels - labels == -1, axis=1, dtype=torch.float)

        acc = (tp + tn) / (tp + tn + fp + fn)

        # NaN for samples with no positive labels; those are dropped below.
        recall = tp / (tp + fn)

        row = {
            'threshold': pos_threshold,
            'avg_recall': recall[~recall.isnan()].mean().item(),
            'avg_acc': acc.mean().item(),
        }
        for prefix, counts in (('tp', tp), ('fp', fp), ('tn', tn), ('fn', fn)):
            row[prefix + '_avg'] = counts.mean().item()
            row[prefix + '_max'] = counts.max().item()
            row[prefix + '_median'] = counts.median().item()
            row[prefix + '_std'] = counts.std().item()
        results.append(row)

    return pd.DataFrame(results)

In [None]:
# Score the checkpoint with the lowest validation loss on the validation set,
# using build_results' default threshold grid.
results = build_results(best_model, X_val, y_val)

In [None]:
results.head(10)

In [None]:
results.columns.to_list()

In [None]:
# Reshape the wide per-threshold metrics into long form (one row per
# threshold / metric pair) so the plots below can colour lines by metric.
metric_columns = [
    'avg_recall', 'avg_acc',
    'tp_avg', 'tp_max', 'tp_median', 'tp_std',
    'fp_avg', 'fp_max', 'fp_median', 'fp_std',
    'tn_avg', 'tn_max', 'tn_median', 'tn_std',
    'fn_avg', 'fn_max', 'fn_median', 'fn_std',
]
tall = pd.melt(results, id_vars=['threshold'], value_vars=metric_columns)

In [None]:
tall.head()

In [None]:
# Tag every metric with the part of its name before the first underscore
# (e.g. 'tp_avg' -> 'tp', 'avg_recall' -> 'avg').
tall = tall.assign(grp=tall['variable'].map(lambda name: name.split('_')[0]))


In [None]:
# Mean per-sample TP / FP / FN counts as a function of the threshold.
df = tall[tall['variable'].isin(['tp_avg', 'fp_avg', 'fn_avg'])]
sns.relplot(x='threshold', y='value', hue='variable', kind='line', data=df)

In [None]:
# Median per-sample TP / FP / FN counts as a function of the threshold.
df = tall[tall['variable'].isin(['tp_median', 'fp_median', 'fn_median'])]
sns.relplot(x='threshold', y='value', hue='variable', kind='line', data=df)

In [None]:
# Mean per-sample accuracy as a function of the threshold.
df = tall[tall['variable'].isin(['avg_acc'])]
sns.relplot(x='threshold', y='value', hue='variable', kind='line', data=df)

In [None]:
# Mean per-sample recall as a function of the threshold.
df = tall[tall['variable'].isin(['avg_recall'])]
sns.relplot(x='threshold', y='value', hue='variable', kind='line', data=df)