In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import csv
import re
import random
from collections import Counter
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as data
import numpy as np
import math
use_cuda = torch.cuda.is_available()




class WordModel(nn.Module):
    def __init__(self, embed_dim, vocab_size, hidden_dim, batch_size):
        super(WordModel, self).__init__()
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim

        self.word_embed = nn.Embedding(vocab_size, embed_dim)
        self.word_rnn = nn.GRU(embed_dim, hidden_dim,  bidirectional=True)

    def forward(self, x, _hidden):
        true_x_size = x.size()
        x = x.view(self.batch_size, -1)
        #print("before embedding", x.size())
        x = self.word_embed(x)
        #print("after embedding", x.size())
        x = torch.transpose(x, 1, 0)
        return self.word_rnn(x, _hidden)


    def init_hidden(self):
        hidden1 = Variable(torch.zeros(2, self.batch_size,  self.hidden_dim))
        #hidden2 = Variable(torch.zeros(1, self.batch_size, self.hidden_dim))
        return hidden1#, hidden2)


class Attend(nn.Module):
    def __init__(self, batch_size, hidden_dim):
        super(Attend, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.lin = nn.Linear(hidden_dim, hidden_dim)

        self.context = Variable(torch.FloatTensor(hidden_dim))
        stdv = 1. / math.sqrt(self.context.size(0))
        self.context.data.uniform_(-stdv, stdv)

        self.sm = nn.Softmax()
    def forward(self, x, sentence_size):
        attends = []
        for i in range(x.size(0)):
            #print(x[i,:,:].size())
            attends.append(F.tanh(self.lin(x[i,:,:])).unsqueeze(0))
        #print ("single attend:", attends[0].size())
        attends = torch.cat(attends)
        #print("cat attention:", attends.size())
        attn_combine = torch.mul(attends, self.context)
        #print("attention_combine:", attn_combine.size())
        alpha = self.sm(attn_combine.contiguous().view(-1, self.hidden_dim))
        #print("sm size:", alpha.size())
        #print(x.size())
        attended = torch.mul(x, alpha).contiguous().view(self.batch_size, sentence_size, -1, self.hidden_dim)
        #print("x.alpha prod:", attended.size())
        attended = torch.sum(attended, 2)
        #print("attended sum:", attended.size())
        return attended

class SentModel(nn.Module):
    def __init__(self, batch_size, hidden_dim):
        super(SentModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.sent_rnn = nn.GRU(hidden_dim, hidden_dim,  bidirectional=True)
    def forward(self, x, _hidden):
        x = torch.transpose(x, 1, 0)
        return self.sent_rnn(x, _hidden)
    def init_hidden(self):
        hidden1 = Variable(torch.zeros(2, self.batch_size,  self.hidden_dim))
        #hidden2 = Variable(torch.zeros(1, self.batch_size, self.hidden_dim))
        return hidden1#, hidden2)

class Classifer(nn.Module):
    def __init__(self, hidden_dim, op_dim):
        super(Classifer, self).__init__()
        self.lin = nn.Linear(hidden_dim, op_dim)
    def forward(self, x):
        return self.lin(x)

class Ensemble(nn.Module):
    def __init__(self, embed_dim, vocabulary_size, hidden_dim, batch_size, label_map):
        super(Ensemble, self).__init__()
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.word_rnn = WordModel(embed_dim, vocabulary_size, hidden_dim, batch_size)
        self.wordattention = Attend(batch_size, 2*hidden_dim)
        self.sent_rnn = SentModel(batch_size, 2*hidden_dim)
        self.sentattention = Attend(batch_size, 4*hidden_dim)
        self.clf = Classifer(4*hidden_dim, len(label_map.keys()))

    def forward(self, batch_x, word_hidden, sent_hidden):
        #print("raw size:", batch_x.size())
        x, hidden = self.word_rnn(batch_x, word_hidden)
        #print("word rnn op size:", x.size())
        #print("word rnn hidden size:", hidden.size())
        x = x.contiguous().view(batch_x.size(2), batch_x.size(0)*batch_x.size(1), -1) # sent_size x batch_size x 2*hidd
        #print("============")
        #print("word attention ip size:", x.size())
        sentence_reprs = self.wordattention(x, batch_x.size(1)) # batch_size x sent_size x 2*hidden
        #print(sentence_reprs.size())
        #print("============")
        sent_op, sent_hidden = self.sent_rnn(sentence_reprs, sent_hidden)
        #print("sent rnn op size:", sent_op.size())
        sent_op = sent_op.contiguous().view(batch_x.size(1), self.batch_size, -1) # sent_size x batch_size x 2*hidden
        sent_att = self.sentattention(sent_op, 1)
        sent_att = sent_att.contiguous().view(self.batch_size, 4*self.hidden_dim)
        pred_prob = self.clf(sent_att)
        return pred_prob
    

In [None]:
import csv
import re
import random
from collections import Counter
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as data
import numpy as np
import math
import tensorboard_logger
import pickle
import time
import sys
sys.path.append('/home/ag4508/Medical-Diagnosis-Learning/src')
from attention_databuilder import *
#from attention_models import *

In [None]:
label_path = '../data/top50_labels1.csv'
log_path = '/home/ag4508/nlp_log'
traindata_path = '/misc/vlgscratch2/LecunGroup/anant/nlp/processed_data/50codesL3_UNK_content_4_train_data.pkl'
valdata_path = '/misc/vlgscratch2/LecunGroup/anant/nlp/processed_data/50codesL3_UNK_content_4_valid_data.pkl'

PADDING = "<PAD>"
UNKNOWN = "UNK"
batch_size = 4
num_workers = 4
embed_dim = 50
hidden_dim = 100
lr = 1e-2
num_epochs = 10
log_interval = 1
gpu_id = 3
_t = time.time()

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(gpu_id)

In [None]:
# data reader
# tensorboard_logger.configure(log_path)
# traindata = pickle.load(open(traindata_path, 'r'))
# valdata = pickle.load(open(valdata_path, 'r'))


label_map = {i:_ for _,i in enumerate(get_labels(traindata))}
##TODO confirm if padding gets 0th index
vocabulary, token2idx  = build_vocab(traindata, PADDING)

traindata = traindata[:1000]
valdata = valdata[:100]

trainset = NotesData(traindata, token2idx, UNKNOWN, label_map)
valset = NotesData(valdata, token2idx, UNKNOWN, label_map)
print("Data Loaded in %.2f mns."%((time.time()-_t)/60))

train_loader = torch.utils.data.DataLoader(dataset = trainset, batch_size=batch_size, shuffle=True,
                                                           num_workers=num_workers, collate_fn=sent_batch_collate)
val_loader = torch.utils.data.DataLoader(dataset = valset, batch_size=batch_size, shuffle=True,
                                                           num_workers=num_workers, collate_fn=sent_batch_collate)
print("data loader done")

In [None]:
model = Ensemble(embed_dim, len(vocabulary), hidden_dim, batch_size, label_map)
crit = nn.CrossEntropyLoss()
opti = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))

if use_cuda:
    model.cuda()
    crit.cuda()
    model.wordattention.context = model.wordattention.context.cuda()
    model.sentattention.context = model.sentattention.context.cuda()


In [None]:
print("Starting training...")
step = 0
train_loss_mean = []
for n_e in range(num_epochs):
    word_hidden = model.word_rnn.init_hidden()
    sent_hidden = model.sent_rnn.init_hidden()
    if use_cuda:
        word_hidden, sent_hidden = word_hidden.cuda(), sent_hidden.cuda()

    for batch in train_loader:
        if batch[0].size(0) != batch_size:
            continue

        model.zero_grad()
        batch_x = Variable(batch[0])
        batch_y = Variable(batch[1])

        if use_cuda:
            batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

        pred_prob = model(batch_x, word_hidden, sent_hidden)
        loss = crit(pred_prob, batch_y)
        loss.backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), 1.0)
        opti.step()

        train_loss_mean.append(loss.data[0])
        if step % log_interval ==0:
            val_loss_mean = 0
            word_hidden = model.word_rnn.init_hidden()
            sent_hidden = model.sent_rnn.init_hidden()
            if use_cuda:
                word_hidden, sent_hidden = word_hidden.cuda(), sent_hidden.cuda()

            correct = 0
            for val_batch in val_loader:
                if batch[0].size(0) != batch_size:
                    continue

                batch_x, batch_y = Variable(batch[0], volatile=True), Variable(batch[1])
                if use_cuda:
                    batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

                outputs = model(batch_x, word_hidden, sent_hidden)
                val_loss = crit(outputs, batch_y)
                val_loss_mean += val_loss.data[0]

                _, predicted = torch.max(outputs.data, 1)
                correct += predicted.eq(batch_y.data).cpu().sum()

            train_loss_mean = np.mean(train_loss_mean)
            correct /= float(len(val_loader.dataset))

            val_loss_mean /= float(len(val_loader.dataset))
            print("Epoch: %d, Step: %d, Train Loss: %.2f, Val Loss: %.2f, Val acc: %.2f"%(n_e, step, train_loss_mean, val_loss_mean, correct))

            param1, grad1 = calc_grad_norm(model.parameters(), 1)
            param2, grad2 = calc_grad_norm(model.parameters(), 2)
            print("Param Norm1: %.2f, grad Norm1: %.2f, Param Norm12: %.2f, grad Norm2: %.2f"%(param1, grad1, param2, grad2))

            tensorboard_logger.log_value('train_loss', train_loss_mean, step)
            tensorboard_logger.log_value('val_loss', val_loss_mean, step)
            tensorboard_logger.log_value('val_acc', correct, step)
            tensorboard_logger.log_value('param norm1', param1, step)
            tensorboard_logger.log_value('grad norm1', grad1, step)
            train_loss_mean = []            
        step += 1

In [None]:
def init_w(m):
    torch.nn.init.xavier_uniform(m.weight)
model.apply(init_w)

In [37]:
for i in model.parameters():
    if len(i.size()):
        print(i[0, :10])
    els


2
2
2
1
1
2
2
1
1
2
1
2
2
1
1
2
2
1
1
2
1
2
1
