In [None]:
from Model.NQModel import NQModel
from Model.LossFn import LossFn
import torch
import time
import sklearn
import datetime
import Model.datasetutils as datasetutils
import Model.tensorboardutils as boardutils
import torch.utils.tensorboard as tensorboard
from tqdm import tqdm_notebook as tqdm
import transformers
from transformers import BertModel

## Configuration

In [None]:
# TensorBoard log directory for this run. The suffix is the number of whole minutes elapsed
# since the fixed Unix timestamp 1583988084, so runs started at least a minute apart get distinct directories.
TensorBoardLocation = 'runs/NQ_TIME:%d' % int((time.time() - 1583988084) / 60)
TensorBoardLocation

In [None]:
epochs = 1  # not read by train(), which makes exactly one pass over traingen
use_cuda = torch.cuda.is_available()
# Encoder and task head share one device; train_step/validate_step rely on this and never move the encoding.
model_device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
bert_device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
writer = tensorboard.SummaryWriter(log_dir=TensorBoardLocation)

## Dataset

In [None]:
# Training / validation batch iterators. num_workers is forwarded to datasetutils.get_dataset (presumably DataLoader worker processes — confirm).
traingen, validgen = datasetutils.get_dataset(num_workers=4)

In [None]:
# Batches per pass over each loader; train() uses num_steps for its remaining-time estimate.
num_steps, val_steps = len(traingen), len(validgen)
num_steps, val_steps

## Model

In [None]:
# Task head applied to the BERT encodings (see train_step / validate_step), placed on model_device.
model = NQModel()
model = model.to(model_device)

In [None]:
# Plain SGD over the task head's parameters only; BERT's weights are never handed to the optimizer.
optim = torch.optim.SGD(model.parameters(), lr=5e-6)
# Optional one-cycle schedule (its step() call in train_step is likewise disabled):
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optim, max_lr=0.0005, epochs=1, steps_per_epoch=len(traingen))

## Confusion Matrix

In [None]:
# Tick labels for the confusion-matrix figures; position i labels class index i
# (update_confusion_matrix indexes the matrices by argmax / thresholded class index).
AnswerTypes = [
    'Wrong Ans',
    'Short Ans',
    'Yes No',
]
YesNoLabels = ['No', 'Yes']

In [None]:
def _pair_counts(rows, cols, size):
    """Return a (size, size) int64 tensor whose [i, j] entry counts positions where rows == i and cols == j.

    Both inputs are flattened and cast to int64 (so bool predictions and float 0/1 labels are accepted);
    every value must lie in [0, size).
    """
    rows = rows.flatten().long()
    cols = cols.flatten().long()
    return torch.bincount(rows * size + cols, minlength=size * size).view(size, size)


def update_confusion_matrix(ATMatrix, YNMatrix, StartM, EndM, output, target):
    """Accumulate one batch of predictions into the four running confusion matrices, in place.

    Args:
        ATMatrix: answer-type counts, indexed [predicted class, true class].
        YNMatrix: yes/no counts, indexed [predicted label, true label].
        StartM: start-token counts, indexed [true label, predicted label].
        EndM: end-token counts, indexed [true label, predicted label].
        output: (answer-type logits, start logits, end logits, yes/no logits). Callers pass
            detached CPU tensors; the answer-type logits are 2-D (class scores along dim 1).
        target: (answer-type targets, start labels, end labels, yes/no labels); answer types are
            reduced with argmax(dim=1), the others are assumed to hold 0/1 values — confirm with the dataset.

    NOTE(review): answer-type/yes-no matrices are [pred, true] while start/end are [true, pred]
    (the orientation sklearn.metrics.confusion_matrix produced). Orientations are kept unchanged so
    logged figures stay comparable; confirm which one boardutils.confusion_matrix_image expects.
    """
    # Answer type: predicted class = argmax of logits, true class = argmax of the target row.
    predicted_type = output[0].argmax(dim=1)
    true_type = target[0].argmax(dim=1)
    ATMatrix += _pair_counts(predicted_type, true_type, ATMatrix.shape[0]).to(ATMatrix.dtype)

    # Yes/no: threshold the sigmoid at 0.5. The label is cast to int64 inside _pair_counts,
    # because a float label tensor cannot be used as an index (the old loop would raise on it).
    predicted_yn = torch.sigmoid(output[3].flatten()) >= 0.5
    YNMatrix += _pair_counts(predicted_yn, target[3], YNMatrix.shape[0]).to(YNMatrix.dtype)

    # Start/end: per-token binary decisions. Counting against an explicit label range keeps the
    # result 2x2 even when a batch holds a single label; sklearn's confusion_matrix without
    # `labels=` returned 1x1 there, and `+=` then broadcast that count into all four cells.
    predicted_start = torch.sigmoid(output[1].flatten()) >= 0.5
    predicted_end = torch.sigmoid(output[2].flatten()) >= 0.5
    StartM += _pair_counts(target[1], predicted_start, StartM.shape[0]).to(StartM.dtype)
    EndM += _pair_counts(target[2], predicted_end, EndM.shape[0]).to(EndM.dtype)

In [None]:
def log_confusion_matrix(matrix, labels, name, step):
    """Render a confusion-matrix tensor as a figure and write it to TensorBoard.

    Args:
        matrix: tensor of counts; converted with ``.numpy()``, so it must live on the CPU.
        labels: tick labels forwarded to ``boardutils.confusion_matrix_image``.
        name: TensorBoard tag for the figure.
        step: global step recorded with the figure.
    """
    counts = matrix.numpy()
    writer.add_figure(name, boardutils.confusion_matrix_image(counts, labels), step)

def log_matrices(AnsTypeM, YNM, StM, EndM, call_type, steps):
    """Log the four confusion matrices as TensorBoard figures.

    Each tag is a fixed title with ``call_type`` appended (callers pass " train" or " eval"),
    and every figure is recorded at global step ``steps``.
    """
    panels = (
        (AnsTypeM, AnswerTypes, "Answer type confusion matrix"),
        (YNM, YesNoLabels, "Yes No confusion matrix"),
        (StM, YesNoLabels, "Start confusion matrix"),
        (EndM, YesNoLabels, "End confusion matrix"),
    )
    for matrix, tick_labels, title in panels:
        log_confusion_matrix(matrix, tick_labels, title + call_type, steps)

## Run

In [None]:
# Loss helpers for the four heads (ans_type / start / end / yes_no), built for model_device.
loss = LossFn(model_device)
# Pretrained uncased BERT-base encoder; train_step runs it under torch.no_grad(), so it is not fine-tuned.
bert_encoder = BertModel.from_pretrained('bert-base-uncased')
bert_encoder = bert_encoder.to(bert_device)

In [None]:
# Running confusion-matrix counts, updated in place by update_confusion_matrix.
# Training set: answer type (3 classes), then yes/no, start token and end token (2x2 each).
AnswerTypeMatrix = torch.zeros(3, 3)
YesNoMatrix, StartMatrix, EndMatrix = (torch.zeros(2, 2) for _ in range(3))

# Validation counterparts of the four matrices above.
ValAnswerTypeMatrix = torch.zeros(3, 3)
ValYesNoMatrix, ValStartMatrix, ValEndMatrix = (torch.zeros(2, 2) for _ in range(3))

In [None]:
# Validation loss accumulators (summed .item() values); validate() resets them at the start of every run.
at_l, start_l, end_l, yn_l = 0, 0, 0, 0


def validate_step(inp_ids, mask, token_types, ans_type, start, end, yes_no):
    """Evaluate one validation batch.

    Adds the batch's four losses to the global accumulators and updates the validation confusion
    matrices. Expected to run under ``torch.no_grad()``; ``validate`` provides it.
    """
    global at_l, start_l, end_l, yn_l
    encoding, _ = bert_encoder(inp_ids.to(bert_device), mask.to(bert_device), token_types.to(bert_device))
    # bert_device and model_device are configured identically, so no transfer of `encoding` is needed.
    output = model(encoding)

    ## Calculate Loss
    at_l += loss.ans_type(output[0], ans_type.argmax(1).to(model_device)).item()
    start_l += loss.start(output[1], start.type(torch.FloatTensor).to(model_device)).item()
    end_l += loss.end(output[2], end.type(torch.FloatTensor).to(model_device)).item()
    yn_l += loss.yes_no(output[3], yes_no.to(model_device)).item()

    detached_output = (output[0].detach().cpu(), output[1].detach().cpu(), output[2].detach().cpu(), output[3].detach().cpu())
    update_confusion_matrix(ValAnswerTypeMatrix, ValYesNoMatrix, ValStartMatrix, ValEndMatrix, detached_output, (ans_type.detach(), start.detach(), end.detach(), yes_no.detach()))


def validate(val_num, max_batches=1):
    """Evaluate on ``validgen`` and log the losses and confusion matrices to TensorBoard.

    Args:
        val_num: TensorBoard global step for the logged values.
        max_batches: stop after this many batches; ``None`` evaluates the whole loader.
            Defaults to 1, the limit the previous code hard-wired.

    Loss accumulators and validation matrices are reset first, so each call reports only its own
    batches. The model's previous train/eval mode is restored on exit, so calling this from
    inside the training loop does not leave the model stuck in eval mode.
    """
    global at_l, start_l, end_l, yn_l
    at_l, start_l, end_l, yn_l = 0, 0, 0, 0
    for matrix in (ValAnswerTypeMatrix, ValYesNoMatrix, ValStartMatrix, ValEndMatrix):
        matrix.zero_()

    was_training = model.training
    model.eval()
    start_time = time.time()
    ctr = 0
    try:
        with torch.no_grad():
            for inp_ids, mask, token_types, ans_type, start, end, yes_no in tqdm(validgen):
                ctr += 1
                inp_ids, mask, token_types, ans_type, start, end, yes_no = inp_ids.squeeze(), mask.squeeze(), token_types.squeeze(), ans_type.squeeze(), start.squeeze(), end.squeeze(), yes_no.squeeze()

                validate_step(inp_ids, mask, token_types, ans_type, start, end, yes_no)
                if max_batches is not None and ctr >= max_batches:
                    break
    finally:
        model.train(was_training)

    print("time : " + str(time.time() - start_time) + " steps : " + str(ctr))
    ## Save loss values: mean per batch, so runs with different batch counts are comparable
    n_batches = max(ctr, 1)
    writer.add_scalars('Loss values Validation',
        {"AT_loss_val": at_l / n_batches, "Start_loss_val": start_l / n_batches,
         "End_loss_val": end_l / n_batches, "Yes_no_loss_val": yn_l / n_batches},
        val_num, time.time())

    log_matrices(ValAnswerTypeMatrix, ValYesNoMatrix, ValStartMatrix, ValEndMatrix, " eval", val_num)

In [None]:
def train_step(inp_ids, mask, token_types, ans_type, start, end, yes_no, steps, cm_every=20):
    """Run one optimisation step on a single batch and log its losses to TensorBoard.

    BERT is evaluated under ``torch.no_grad()``, so only ``model`` (the task head) receives
    gradients; ``optim`` was built from ``model.parameters()``.

    Args:
        inp_ids, mask, token_types: BERT inputs for the batch.
        ans_type: answer-type targets, reduced with ``argmax(1)`` for the loss.
        start, end: per-token start/end targets, cast to float for the loss.
        yes_no: yes/no targets.
        steps: global step used as the TensorBoard x-axis.
        cm_every: update and log the training confusion matrices every this many steps.
    """
    with torch.no_grad():
        encoding, _ = bert_encoder(inp_ids.to(bert_device), mask.to(bert_device), token_types.to(bert_device))

    # bert_device and model_device are configured identically, so no transfer of `encoding` is needed.
    output = model(encoding)

    ## Calculate Loss
    AT_loss = loss.ans_type(output[0], ans_type.argmax(1).to(model_device))
    Start_loss = loss.start(output[1], start.type(torch.FloatTensor).to(model_device))
    End_loss = loss.end(output[2], end.type(torch.FloatTensor).to(model_device))
    Yes_no_loss = loss.yes_no(output[3], yes_no.to(model_device))

    ## Update model params (the OneCycleLR scheduler is currently disabled)
    (AT_loss + Start_loss + End_loss + Yes_no_loss).backward()
    optim.step()
    optim.zero_grad()

    if steps % cm_every == 0:
        ## Calculate Confusion Matrix (sampled every `cm_every` steps; counts accumulate over the run)
        detached_output = (output[0].detach().cpu(), output[1].detach().cpu(), output[2].detach().cpu(), output[3].detach().cpu())
        update_confusion_matrix(AnswerTypeMatrix, YesNoMatrix, StartMatrix, EndMatrix, detached_output, (ans_type.detach(), start.detach(), end.detach(), yes_no.detach()))
        log_matrices(AnswerTypeMatrix, YesNoMatrix, StartMatrix, EndMatrix, " train", steps)

    writer.add_scalars('Loss values',
        {"AT_loss": AT_loss.item(), "Start_loss": Start_loss.item(), "End_loss": End_loss.item(), "Yes_no_loss": Yes_no_loss.item()},
        steps, time.time())


def train(log_every=1000, validate_every=10000):
    """Make one pass over ``traingen``, taking one optimiser step per batch.

    Args:
        log_every: print elapsed / projected total time every this many steps.
        validate_every: run ``validate`` every this many steps (including step 0), using
            ``steps // validate_every`` as its integer TensorBoard step. The previous code
            validated when ``steps % 10000`` was *non-zero* and passed a float step that
            TensorBoard truncates, so validations collapsed onto the same x-axis value.
    """
    start_time = time.time()
    model.train()

    for steps, (inp_ids, mask, token_types, ans_type, start, end, yes_no) in enumerate(tqdm(traingen)):
        inp_ids, mask, token_types, ans_type, start, end, yes_no = inp_ids.squeeze(), mask.squeeze(), token_types.squeeze(), ans_type.squeeze(), start.squeeze(), end.squeeze(), yes_no.squeeze()

        train_step(inp_ids, mask, token_types, ans_type, start, end, yes_no, steps)

        if steps % log_every == 0:
            cur_time = time.time() - start_time
            # Linear projection of total run time from the average time per step so far.
            expected_time = (cur_time * num_steps) / (steps + 1)
            print("elapsed time : " + str(cur_time) + " : expected time : " + str(expected_time))

        if steps % validate_every == 0:
            validate(steps // validate_every)
            # validate() switches the model to eval mode; make sure training resumes in train mode.
            model.train()

In [None]:
# Launch training: a single pass over traingen; validate() is invoked periodically from inside train().
train()