In [1]:
import shutil
import os
import queue
import argparse
import random
import time
import logging
logger = logging.getLogger(__name__)
import multiprocessing

import numpy as np
from tqdm.notebook import tqdm
import wandb
from sklearn.metrics import roc_auc_score

import torch
from torch import optim
from torch.utils.data import DataLoader
from transformers import BlenderbotTokenizer, BertTokenizerFast, RobertaTokenizerFast, GPT2TokenizerFast
from transformers import get_linear_schedule_with_warmup
# from rezero.transformer import RZTXEncoderLayer

from models import SMI, is_ddp_module, WrappedSMI
from utils import GEN_UNIQ_RUN_ID, pprint_args
from datautils import DialogData, RMaxData
from nltk.translate.bleu_score import sentence_bleu

In [2]:
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
tokenizer.add_tokens('__eou__')

1

In [4]:
%%capture
data_path = 'data/dailydialog'
train_data = DialogData(data_path=data_path + '/dialogues_train.txt', tokenizer=tokenizer)
valid_data = DialogData(data_path=data_path + '/dialogues_valid.txt', tokenizer=tokenizer)
test_data = DialogData(data_path=data_path + '/dialogues_test.txt', tokenizer=tokenizer)

# train_data = add_sep(train_data)
# valid_data = add_sep(valid_data)
# test_data  = add_sep(test_data)

In [5]:
%%capture
import torch.nn as nn
MAX_CTX_LEN = 300
MAX_RESP_LEN = 301
def proc_data(data):
    mod_data = []
    for i, j in enumerate(data):
        ctx = torch.cat((j[0][-MAX_CTX_LEN:], torch.tensor([102])))
        resp = torch.cat((j[1][1:][-MAX_RESP_LEN:], torch.tensor([102])))
        label = j[1][-MAX_RESP_LEN:]
        ctx_att = torch.ones(len(ctx))
        resp_att = torch.ones(len(resp))
        ctx_ids = nn.functional.pad(ctx, (0,MAX_CTX_LEN-len(ctx)+1))
        resp_ids = nn.functional.pad(resp, (0,MAX_RESP_LEN-len(resp)))
        label = nn.functional.pad(label, (0,MAX_RESP_LEN-len(resp)))
        ctx_att_ids = nn.functional.pad(ctx_att, (0,MAX_CTX_LEN-len(ctx)))
        resp_att_ids = nn.functional.pad(resp_att, (0,MAX_RESP_LEN-len(resp)))
        mod_data.append({'ctx_ids': ctx_ids,
                       'ctx_att': ctx_att_ids,
                       'resp_ids': resp_ids,
                       'resp_att': resp_att_ids,
                        'label':label})
    return mod_data        

In [6]:
train_dataloader = DataLoader(proc_data(train_data), batch_size=8, shuffle=False, num_workers=0)
valid_dataloader = DataLoader(proc_data(valid_data), batch_size=8, shuffle=False, num_workers=0)
test_dataloader = DataLoader(proc_data(test_data), batch_size=8, shuffle=False, num_workers=0)

Token indices sequence length is longer than the specified maximum sequence length for this model (521 > 512). Running this sequence through the model will result in indexing errors


## Base Model

In [7]:
!gpustat

[1m[37mdevi                   [m  Wed Feb  9 02:17:42 2022  [1m[30m470.63.01[m
[36m[0][m [34mTesla P100-PCIE-12GB[m |[1m[31m 61'C[m, [1m[32m 96 %[m | [36m[1m[33m 5939[m / [33m12198[m MB | [1m[30mmithundas[m([33m1671M[m) [1m[30mprasanta-am[m([33m1745M[m) [1m[30mprasanta-am[m([33m2517M[m) [1m[30mgdm[m([33m4M[m)
[36m[1][m [34mTesla P100-PCIE-16GB[m |[31m 49'C[m, [32m  0 %[m | [36m[1m[33m    6[m / [33m16280[m MB | [1m[30mgdm[m([33m4M[m)


In [8]:
from torch.nn import Transformer, Softmax
from torch.optim import AdamW
class EnDModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vocab_size = len(tokenizer)
        self.softmax = nn.Softmax(dim=-1)
        self.d_model = 512
        self.in_model = Transformer(self.d_model,num_encoder_layers=4, num_decoder_layers=4)
        self.output_linear = nn.Linear(self.d_model, self.vocab_size)
            
    def forward(self, src_input, trg_input, e_mask=None, d_mask=None):
        self.in_output = self.in_model(src_input, trg_input)
        output = self.softmax(self.output_linear(self.in_output))
        
        return output
        

In [9]:
embedding = nn.Embedding(len(tokenizer), 512)
model = EnDModel()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

criterion = nn.CrossEntropyLoss(ignore_index = 0)
# Define Optimizer
optimizer = AdamW(model.parameters(), lr=2e-5)

In [10]:
use_loss = 'baseline'
train_loss_epoch = []
dev_loss_epoch = []
max_epochs = 10
write_step = 1
best_metric = 0
vocab_len = len(tokenizer)
train_losses = []
dev_losses = []
best_train_loss = 0
best_dev_loss = 0

for epoch in tqdm(range(max_epochs)):
    train_loss = 0
    dev_loss = 0
    train_loss_set = []
    dev_loss_set = []
    train_bleu1 = 0
    train_bleu2 = 0
    train_bleu3 = 0
    dev_bleu1 = 0
    dev_bleu2 = 0
    dev_bleu3 = 0
    train_gold_resp = []
    train_gen_resp = []
    dev_gold_resp = []
    dev_gen_resp = []
    
    model.train()
    for i, batch in tqdm(enumerate(train_dataloader),total=len(train_dataloader)):
        ctx_ids = embedding(batch['ctx_ids']).to(device)
        resp_ids = embedding(batch['resp_ids']).to(device)
        
        optimizer.zero_grad()
        output = model(ctx_ids,resp_ids)
        
        output_ids = torch.argmax(output,dim=2)
        resp_shape = batch['resp_ids'].shape
        loss = criterion(output.view(-1, vocab_len), batch['label'].view(resp_shape[0]*resp_shape[1]).to(device))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        
        train_gold_resp.extend(tokenizer.batch_decode(batch['resp_ids'].cpu().tolist(), skip_special_tokens=True))
        train_gen_resp.extend(tokenizer.batch_decode(output_ids.cpu().tolist(), skip_special_token = True))
        
        
    train_losses.append(train_loss/len(train_dataloader))
    print("Train Loss", train_loss/len(train_dataloader))
        
        
    
    model.eval()
    with torch.no_grad():
        for i, batch in tqdm(enumerate(valid_dataloader),total=len(valid_dataloader)):
            ctx_ids = embedding(batch['ctx_ids']).to(device)
            resp_ids = embedding(batch['resp_ids']).to(device)

            output = model(ctx_ids,resp_ids)
        
            output_ids = torch.argmax(output,dim=2)
            resp_shape = batch['resp_ids'].shape
            loss = criterion(output.view(-1, vocab_len), batch['label'].view(resp_shape[0]*resp_shape[1]).to(device))
            dev_loss += loss.item()
        
            dev_gold_resp.extend(tokenizer.batch_decode(batch['resp_ids'].cpu().tolist(), skip_special_tokens=True))
            dev_gen_resp.extend(tokenizer.batch_decode(output_ids.cpu().tolist(), skip_special_token = True))
            
        dev_losses.append(dev_loss/len(valid_dataloader))
        print("Dev Loss", dev_loss/len(valid_dataloader))
        path = os.getcwd() + '/baseline_s2s.pth'
        torch.save(model, path)
        if best_dev_loss < dev_loss/len(valid_dataloader):
            torch.save(model, os.getcwd() + '/baseline_s2s_best_loss.pth')
            best_dev_loss = dev_loss/len(valid_dataloader)
        
        
        

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/9507 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Train Loss 10.174951376794102


  0%|          | 0/884 [00:00<?, ?it/s]

Dev Loss 10.174220947118906


  0%|          | 0/9507 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

Train Loss 10.174237710814928


  0%|          | 0/884 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

Train Loss 10.174237710814928


  0%|          | 0/884 [00:00<?, ?it/s]

Dev Loss 10.174220103483934


  0%|          | 0/9507 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

Train Loss 10.174237710814928


  0%|          | 0/884 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [15]:
dev_gold_resp[70]

'go right ahead.'

In [16]:
dev_gen_resp[60]

'[CLS]............................................................................................................................................................................................................................................................................................................'

In [17]:
train_losses

[10.174951376794102,
 10.17423791214283,
 10.174237710814928,
 10.174237710814928,
 10.174237710814928,
 10.174237710814928,
 10.174237710814928,
 10.174237710814928,
 10.174237710814928,
 10.174237710814928]

In [18]:
dev_losses

[10.174220947118906,
 10.174220103483934,
 10.174220103483934,
 10.174220103483934,
 10.174220103483934,
 10.174220103483934,
 10.174220103483934,
 10.174220103483934,
 10.174220103483934,
 10.174220103483934]

In [19]:
batch

{'ctx_ids': tensor([[ 101, 6160, 1010,  ...,    0,    0,    0],
         [ 101, 6160, 1010,  ...,    0,    0,    0],
         [ 101, 6160, 1010,  ...,    0,    0,    0],
         [ 101, 6160, 1010,  ...,    0,    0,    0],
         [ 101, 6160, 1010,  ...,    0,    0,    0]]),
 'ctx_att': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.]]),
 'resp_ids': tensor([[2821, 1010, 2009,  ...,    0,    0,    0],
         [2053, 1010, 2065,  ...,    0,    0,    0],
         [2129, 2172, 2003,  ...,    0,    0,    0],
         [2048, 4595, 1012,  ...,    0,    0,    0],
         [2821, 1010, 2009,  ...,    0,    0,    0]]),
 'resp_att': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.]]),
 'labe

In [None]:
def metrics(gold,gen):
    bleu1, bleu2, bleu3 = 0,0,0
    for i, j in tqdm(enumerate(gold), total = len(gold)):
        bleu1 += sentence_bleu(j.split(" "), gen[i].split(" "), weights=(1, 0, 0, 0))
        bleu2 += sentence_bleu(j.split(" "), gen[i].split(" "), weights=(0.5, 0.5, 0, 0))
        bleu3 += sentence_bleu(j.split(" "), gen[i].split(" "), weights=(0.34, 0.33, 0.33, 0))
    return bleu1/len(gold), bleu2/len(gold), bleu3/len(gold)