In [31]:
import csv
import os
import pickle
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
from torch.autograd import Variable

# Root directory holding the embedding TSV files and the pickled notes dump.
# Set the NOTES_BASE_PATH environment variable to point at another location
# instead of editing the notebook; the default is the original hard-coded path.
base_path = os.environ.get('NOTES_BASE_PATH', '/media/disk3/disk3')

In [48]:
import sklearn
#ndarray -> normalize -> map2word -> tensor    
def read_embeddings(vecidx_path, vec_path, max_words=1002, dim=400):
    '''
    Read a tab-separated vocabulary file and its matching vector file and
    return a dict mapping each word to its L2-normalised embedding.

    vecidx_path: TSV file whose first column holds one word per row.
    vec_path:    TSV file whose i-th row holds the vector of the i-th word.
    max_words:   maximum number of rows to read; None reads the whole file.
                 The default (1002) is the number of rows the previously
                 hard-coded ``i > 1000`` cut-off actually read.
    dim:         number of leading columns of each vector row to keep.

    Returns {word: 1-D float ndarray of length dim}. If a word occurs more
    than once, the vector of its last occurrence is kept.

    Raises ValueError if the vector file has fewer rows than the vocabulary
    (the two files would be misaligned) or if a row has fewer than ``dim``
    values.
    '''
    words = []
    with open(vecidx_path, 'r') as f:
        for row in csv.reader(f, delimiter='\t'):
            if max_words is not None and len(words) >= max_words:
                break
            words.append(row[0])

    # Zero-initialised so no row can ever contain uninitialised memory.
    vecs = np.zeros((len(words), dim))
    rows_read = 0
    with open(vec_path, 'r') as f:
        for row in csv.reader(f, delimiter='\t'):
            if rows_read >= len(words):
                break
            vecs[rows_read, :] = [float(value) for value in row[:dim]]
            rows_read += 1

    if rows_read != len(words):
        raise ValueError('%s has %d vector rows but %s lists %d words'
                         % (vec_path, rows_read, vecidx_path, len(words)))

    # Row-wise L2 normalisation. All-zero rows are left untouched, which is
    # what sklearn.preprocessing.normalize does with its default arguments;
    # doing it with numpy avoids relying on `import sklearn` having loaded
    # the `preprocessing` submodule.
    norms = np.sqrt((vecs ** 2).sum(axis=1))
    norms[norms == 0] = 1.0
    vecs = vecs / norms[:, np.newaxis]

    return {word: vecs[i] for i, word in enumerate(words)}

In [49]:
def read_data_dump(data_path):
    '''
    Load and return the pickled object stored at ``data_path``.

    The file is opened in binary mode: pickle streams are bytes, and text
    mode fails on Python 3 (and can corrupt binary protocols elsewhere).

    NOTE(security): unpickling can execute arbitrary code, so only load
    dumps that come from a trusted source.
    '''
    with open(data_path, 'rb') as f:
        return pickle.load(f)

In [89]:
def get_labels(data):
    '''
    Build a mapping from each distinct primary ICD label to a class index.

    data: dict keyed by admission id; every value must provide
          ``['labels']['icd']``, whose first element is used as the label.

    Returns {label: index} with indices 0 .. n_labels-1 assigned in sorted
    label order. Indices are contiguous so they are valid targets for a
    classifier with ``len(mapping)`` outputs. (Enumerating the raw,
    duplicate-containing label list instead yields ids as large as the
    number of admissions.)
    '''
    primary_labels = set(record['labels']['icd'][0] for record in data.values())
    return {label: index for index, label in enumerate(sorted(primary_labels))}

In [84]:
from sklearn.feature_extraction import stop_words
import re
import random

def clean_str_no_stopwords(s, stopwords=None):
    '''
    Tokenise a note: drop ``[** ... **]`` placeholders, lowercase, split on
    whitespace and remove stop words.

    s:         raw note text.
    stopwords: container of lowercase tokens to discard; defaults to
               ``stop_words.ENGLISH_STOP_WORDS``.

    Returns the list of remaining tokens in their original order.
    '''
    if stopwords is None:
        stopwords = stop_words.ENGLISH_STOP_WORDS
    # Non-greedy `.*?` so each placeholder is removed on its own; a greedy
    # match would also swallow the real text between two placeholders that
    # share a line. split() already collapses newlines and repeated blanks.
    tokens = re.sub(r'\[\*\*.*?\*\*\]', ' ', s).lower().split()
    return [token for token in tokens if token not in stopwords]

class Datum(object):
    '''
    One record: its notes in chronological order, each reduced to the mean
    of its token embeddings, plus the integer class label.

    data:       dict with ``['labels']['icd']`` (first entry is the label)
                and ``['notes']``, a list of dicts with a ``'date'``
                ('%Y-%m-%d') and a ``'note'`` (raw text).
    label_map:  {label: class index}.
    embeddings: {token: vector}; must contain an ``'unknown'`` entry, which
                is used for every out-of-vocabulary token.
    log_unk:    dict shared between Datum objects and updated in place:
                log_unk[token] = {'count': total occurrences,
                                  'notes': number of notes containing it}.
    tokenizer:  optional callable turning raw note text into a token list;
                defaults to ``clean_str_no_stopwords``.

    After construction ``embedded_notes`` (and ``notes``, kept as an alias)
    hold one averaged vector per non-empty note.
    '''
    def __init__(self, data, label_map, embeddings, log_unk, tokenizer=None):
        self.label = label_map[data['labels']['icd'][0]]
        # Raw notes for now; replaced by their embeddings in preprocess_notes.
        self.notes = data['notes']
        self.embeddings = embeddings
        #Logs missing vocabs in notes & in total
        self.log_unk = log_unk
        # True once at least one out-of-vocabulary token has been logged.
        self.logged = False
        self.tokenizer = tokenizer
        self.preprocess_notes()

    def preprocess_notes(self):
        '''Sort the notes by date and replace each by its mean embedding.'''
        tokenize = getattr(self, 'tokenizer', None)
        if tokenize is None:
            tokenize = clean_str_no_stopwords

        ordered = sorted(self.notes, key=lambda note: datetime.strptime(note['date'], '%Y-%m-%d'))

        embedded_notes = []
        for note in ordered:
            tokens = tokenize(note['note'])
            if not tokens:
                # Nothing left to average; np.mean([]) would inject NaN.
                continue
            embedded_notes.append(self._embed_tokens(tokens))

        # padding of the note sequence happens in the dataloader collation
        self.embedded_notes = embedded_notes
        self.notes = embedded_notes

    def _embed_tokens(self, tokens):
        '''Return the mean embedding of one note and log its unknown tokens.'''
        vectors = []
        unknown_in_note = set()
        for token in tokens:
            vector = self.embeddings.get(token)
            if vector is None:
                entry = self.log_unk.setdefault(token, {'count': 0, 'notes': 0})
                entry['count'] += 1
                # 'notes' grows once per note that contains the token.
                if token not in unknown_in_note:
                    unknown_in_note.add(token)
                    entry['notes'] += 1
                self.logged = True
                vector = self.embeddings['unknown']
            vectors.append(vector)
        return np.mean(vectors, 0)
        

In [None]:
#Getting only priority #1 labels now
# The loads are active (not commented out) so that this cell works on a
# freshly restarted kernel instead of relying on leftover variables.
pretrained = read_embeddings(os.path.join(base_path, 'ri-3gram-400-tsv/vocab.tsv'),
                             os.path.join(base_path, 'ri-3gram-400-tsv/vectors.tsv'))
data = read_data_dump(os.path.join(base_path, 'notes_dump.pkl'))

#TODO Map labels to CCS
label_map = get_labels(data)
log_unk = {}
# Records without a 'notes' entry are skipped.
dataset = [Datum(data[adm], label_map, pretrained, log_unk) for adm in data if 'notes' in data[adm]]

# 80/20 train/validation split. A dedicated, seeded generator keeps the split
# identical between runs without touching the global random state.
random.Random(0).shuffle(dataset)
margin = int(len(dataset) * 0.8)
training_set = dataset[:margin]
val_set = dataset[margin:]

In [None]:
#TODO Print log of missing vocab

In [None]:
# Mini-batch size; passed to both DataLoaders below and to LSTMModel.
batch_size = 64
# Number of worker processes each DataLoader uses to assemble batches.
num_workers = 2
# Size of the LSTM hidden state handed to LSTMModel.
hidden_dim = 100

class Dataloader(torch.utils.data.Dataset):
    '''
    Map-style dataset that exposes an indexable collection by position.

    The base class is spelled out as ``torch.utils.data.Dataset``: the bare
    name ``data`` is used elsewhere in this notebook for the notes dict, so
    ``data.Dataset`` does not resolve to the torch class.

    data: indexable collection of items (e.g. a list of Datum objects).
    '''
    def __init__(self, data):
        super(Dataloader, self).__init__()
        self.data = data

    def __getitem__(self, index):
        # The whole item is returned; padding_collation unpacks it.
        return self.data[index]

    def __len__(self):
        return len(self.data)
    
def _note_vectors(datum):
    '''Return a datum's list of note vectors (``embedded_notes``, else ``notes``).'''
    vectors = getattr(datum, 'embedded_notes', None)
    if vectors is None:
        vectors = datum.notes
    return vectors


def padding_collation(batch):
    '''
    Collate a list of datum objects into padded tensors.

    Each datum must provide ``label`` (int) and a list of equally sized 1-D
    note vectors in ``embedded_notes`` (or ``notes``). Shorter sequences are
    left-padded with zero vectors so that the last time step of every
    sequence is a real note.

    Returns [inputs, labels]:
        inputs: FloatTensor of shape (batch, max_seq_len, embed_dim)
        labels: LongTensor of shape (batch,)

    Raises ValueError if no datum in the batch has any note vector, because
    the embedding width cannot be determined then.
    '''
    sequences = [_note_vectors(datum) for datum in batch]
    max_seq_len = max(len(sequence) for sequence in sequences)
    embed_dim = next((len(sequence[0]) for sequence in sequences if len(sequence) > 0), None)
    if embed_dim is None:
        raise ValueError('cannot collate a batch in which no datum has any note')

    padded = np.zeros((len(sequences), max_seq_len, embed_dim), dtype=np.float32)
    for row, sequence in enumerate(sequences):
        if len(sequence) > 0:
            # Real notes occupy the tail of the row; the head stays zero.
            padded[row, max_seq_len - len(sequence):, :] = np.asarray(sequence)

    labels = [datum.label for datum in batch]
    return [torch.from_numpy(padded), torch.LongTensor(labels)]

# Training batches are drawn in a new random order on every pass.
train_loader = torch.utils.data.DataLoader(dataset=Dataloader(training_set), batch_size=batch_size, shuffle=True,
                                           num_workers=num_workers, collate_fn=padding_collation)
# Validation order is fixed: evaluate() skips batches that are not full, so
# shuffling would drop a different subset of samples on every call and make
# successive validation accuracies not comparable.
val_loader = torch.utils.data.DataLoader(dataset=Dataloader(val_set), batch_size=batch_size, shuffle=False,
                                         num_workers=num_workers, collate_fn=padding_collation)

In [None]:
class LSTMModel(nn.Module):
    '''
    Single-layer LSTM over a sequence of note vectors, followed by a linear
    layer applied to the output of the last time step.

    embed_dim:  size of each input vector.
    hidden_dim: size of the LSTM hidden state.
    labels:     sized container of classes; its length is the number of
                output units.
    batch_size: default batch size used by ``init_hidden()``.
    '''
    def __init__(self, embed_dim, hidden_dim, labels, batch_size):
        super(LSTMModel, self).__init__()
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(embed_dim, hidden_dim)
        self.lin = nn.Linear(hidden_dim, len(labels))

    def init_hidden(self, batch_size=None):
        '''
        Return a zero (h_0, c_0) pair, each of shape (1, batch, hidden_dim).

        batch_size defaults to the value given to the constructor.
        '''
        if batch_size is None:
            batch_size = self.batch_size
        shape = (1, batch_size, self.hidden_dim)
        return (Variable(torch.zeros(*shape)), Variable(torch.zeros(*shape)))

    def forward(self, x, hidden=None):
        '''
        x:      input of shape (batch, seq_len, embed_dim).
        hidden: (h_0, c_0) as returned by ``init_hidden``; when omitted, a
                zero state matching the batch size of ``x`` is created.

        Returns class scores of shape (batch, n_labels).
        '''
        # nn.LSTM expects seqlen x batch x emb_dim
        x = torch.transpose(x, 1, 0)
        if hidden is None:
            hidden = self.init_hidden(x.size(1))
        x, _hidden = self.lstm(x, hidden)
        # x[-1] is already (batch, hidden_dim); taking it directly instead of
        # reshaping to the configured batch size lets smaller batches pass.
        return self.lin(x[-1])

In [None]:
# Size of each note vector fed to the LSTM. read_embeddings builds
# 400-dimensional vectors by default; keep the two values in sync.
embed_dim = 400
# The output layer is sized from label_map, the same mapping that is used to
# assign a class index to every Datum.
model = LSTMModel(embed_dim, hidden_dim, label_map, batch_size)
opti = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.999))
crit = nn.CrossEntropyLoss()

In [None]:
def evaluate(model, loader, batch_size):
    '''
    Return the classification accuracy of ``model`` over ``loader``.

    model:      module providing ``init_hidden()`` and called as
                ``model(x, hidden)``; must return per-class scores of shape
                (batch, n_classes).
    loader:     iterable of ``[inputs, labels]`` tensor pairs.
    batch_size: only batches with exactly this many rows are scored; other
                batches are skipped because ``init_hidden()`` builds a state
                for this size.

    Returns a float in [0, 1], or 0.0 when no batch was scored. The model is
    switched to eval mode for the duration of the call and its previous
    train/eval mode is restored afterwards.
    '''
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        for batch in loader:
            inputs, labels = batch[0], batch[1]
            if inputs.size(0) != batch_size:
                continue
            hidden = model.init_hidden()
            scores = model(Variable(inputs), hidden)
            _, predicted = torch.max(scores.data, 1)
            total += labels.size(0)
            # int() so the running count is a plain number, not a tensor.
            correct += int((predicted == labels).sum())
    finally:
        model.train(was_training)

    if total == 0:
        return 0.0
    return correct / float(total)

In [None]:
# Stop after this many optimisation steps.
max_steps = 20
# Run validation and record the loss every `log_every` steps.
log_every = 100

step = 0
step_log = []
loss_log = []
val_acc_log = []
for batch in train_loader:
    # init_hidden() builds a state for exactly batch_size rows, so a trailing
    # smaller batch is skipped.
    if batch[0].size(0) != batch_size:
        continue
    # Make sure we are in training mode even if an evaluation ran in between.
    model.train()
    model.zero_grad()
    inputs = Variable(batch[0])
    hidden = model.init_hidden()
    scores = model(inputs, hidden)
    loss = crit(scores, Variable(batch[1].view(-1)))
    loss.backward()
    opti.step()

    if step % log_every == 0:
        val_acc = evaluate(model, val_loader, batch_size)
        # .item() extracts the Python number from the 0-dim loss tensor;
        # indexing it with [0] raises on current PyTorch.
        loss_value = loss.item()
        print("Step: %d, Loss: %.4f, Validation Acc: %.2f"%(step, loss_value, val_acc))
        step_log.append(step)
        loss_log.append(loss_value)
        val_acc_log.append(val_acc)
    step += 1
    if step == max_steps:
        break


In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Training loss at every logged step; labelled so the figure stands alone.
fig, ax = plt.subplots()
ax.plot(step_log, loss_log)
ax.set(xlabel='Step', ylabel='Training loss', title='Training loss at logged steps');