# Research Project

## Import required libraries

In [1]:
!pip install datasets
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
Collec

In [2]:
from datasets import list_datasets, load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, get_polynomial_decay_schedule_with_warmup
from torch.nn import functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from itertools import chain
import torch
import math
import numpy as np
import random

import os, sys

os.environ['CURL_CA_BUNDLE'] = ''

import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


## Data Loading

In [3]:
# train_dataset = load_dataset('daily_dialog', split='train')
# validation_dataset = load_dataset('daily_dialog', split='validation')
# test_dataset = load_dataset('daily_dialog', split='test')

### GPT 2

In [3]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [4]:
space = 'Ġ'
pre_quote = '’'
end_marks = ['.', ',', '?', '!', '...']
quotes = ['"', '\'']
abbreviations = ['s', 'd', 't', 'm', 're', 'll', 've', 'S', 'D', 'T', 'M', 'Re', 'Ll', 'Ve']

def process_token_list(token_list):
    token_list[0] = token_list[0].capitalize()
    
    quote_count = 0
    for i, token in enumerate(token_list):
        if space in token:
            if token[1:] in end_marks or token[1:] in abbreviations:
                token_list[i] = token[1:]
                
            if token[1:] == quotes[1]:
                if i<len(token_list)-1:
                    if token_list[i+1] in abbreviations or (token_list[i+1][0] == space and token_list[i+1][1:] in abbreviations):
                        token_list[i] = token[1:]
                        
        if token[0] == space and token[1:] in quotes:
            if quote_count % 2 == 1:
                token_list[i] = token[1:]
                quote_count = 0
            else:
                if i<len(token_list)-1 and token_list[i+1][0] == space:
                    token_list[i+1] = token_list[i+1][1:]
                quote_count += 1
                
        if token in end_marks or token[1:] in end_marks:
            if i<len(token_list)-1:
                if token_list[i+1][0] != space:
                    token_list[i+1] = space + token_list[i+1].capitalize()
                else:
                    token_list[i+1] = space + token_list[i+1][1:].capitalize()
                
    new_token_list = [token for token in token_list if token != space and len(token)>0]
    if new_token_list[-1] not in end_marks:
        new_token_list.append(end_marks[0])
        
    return new_token_list

In [5]:
dataset = load_dataset('daily_dialog')
train_dialogues = dataset['train']['dialog']
valid_dialogues = dataset['validation']['dialog']
test_dialogues = dataset['test']['dialog']

Downloading builder script:   0%|          | 0.00/4.85k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.49k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.27k [00:00<?, ?B/s]

Downloading and preparing dataset daily_dialog/default to /root/.cache/huggingface/datasets/daily_dialog/default/1.0.0/1d0a58c7f2a4dab5ed9d01dbde8e55e0058e589ab81fce5c2df929ea810eabcd...


Downloading data:   0%|          | 0.00/4.48M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/11118 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset daily_dialog downloaded and prepared to /root/.cache/huggingface/datasets/daily_dialog/default/1.0.0/1d0a58c7f2a4dab5ed9d01dbde8e55e0058e589ab81fce5c2df929ea810eabcd. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
def load_daily(dataset, tokenizer):
        
    for i, dialogue in enumerate(tqdm(dataset)):
        new_dialogue = []
        for utter in dialogue:
            token_list = tokenizer.tokenize(utter.strip().replace(pre_quote, quotes[1]))
            token_list = process_token_list(token_list)
            text = tokenizer.convert_tokens_to_string(token_list)
            new_dialogue.append(text)
            
        dataset[i] = new_dialogue
    
    utter_num = 0

    for dialogue in dataset:
        utter_num += len(dialogue)
    
    return dataset, utter_num

In [7]:
train_dialogues, num_train = load_daily(train_dialogues, tokenizer)
valid_dialogues, num_valid = load_daily(valid_dialogues, tokenizer)
test_dialogues, num_test = load_daily(test_dialogues, tokenizer)

100%|██████████| 11118/11118 [00:29<00:00, 377.29it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1007.61it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1029.30it/s]


In [8]:
print(f"The number of train dialogues: {len(train_dialogues)}")
print(f"The number of valid dialogues: {len(valid_dialogues)}")    
print(f"The number of test dialogues: {len(test_dialogues)}")    

print(f"The number of train utterances: {num_train}")    
print(f"The number of valid utterances: {num_valid}")
print(f"The number of test utterances: {num_test}")



The number of train dialogues: 11118
The number of valid dialogues: 1000
The number of test dialogues: 1000
The number of train utterances: 87170
The number of valid utterances: 8069
The number of test utterances: 7740


In [9]:
# Extrac ids (input_ids, token_ids) from processed text
def extract_ids(dialogues):

    ids = []
    for dialogue in tqdm(dialogues):
            dialogue_ids = []
            for utter in dialogue:
                tokens = tokenizer.tokenize(utter)
                token_ids = tokenizer.convert_tokens_to_ids(tokens)
                dialogue_ids.append(token_ids)
            ids.append(dialogue_ids)
            
    assert len(ids) == len(dialogues)

    return ids  

In [10]:
train_ids = extract_ids(train_dialogues)
valid_ids = extract_ids(valid_dialogues) 
test_ids = extract_ids(test_dialogues) 

100%|██████████| 11118/11118 [00:10<00:00, 1029.85it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1248.14it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1294.52it/s]


In [11]:
#Parameters

sp1_token = '<sp1>'
sp2_token = '<sp2>'
bos_token = '<bos>'
max_turns = 5
max_len = 1024
seed = 0
gpu = 0

#Tokeniser
special_tokens = {'bos_token': bos_token,
                'additional_special_tokens': [sp1_token, sp2_token]}

eos_token = tokenizer.eos_token
num_new_tokens = tokenizer.add_special_tokens(special_tokens)

vocab = tokenizer.get_vocab()
vocab_size = len(vocab)
bos_id = vocab[bos_token]
eos_id = vocab[eos_token]
sp1_id = vocab[sp1_token]
sp2_id = vocab[sp2_token]

lr = 2e-5
batch_size = 8
num_workers = 0
num_epochs = 10
warmup_ratio = 0.1
last_epoch = 0
end_command = 'Quit!'
top_p = 0.8


In [12]:
!mkdir 'saved_models'
ckpt_dir = 'saved_models'

In [13]:
class CustomDataset(Dataset):
    def __init__(self, dials):

        self.input_ids = []  # (N, L)
        self.token_type_ids = []  # (N, L)
        self.labels = []  # (N, L)
            
        for dial in tqdm(dials):
            hists = []
            for u, utter in enumerate(dial):
                if u % 2 == 0:
                    hists.append([sp1_id] + utter)
                else:
                    hists.append([sp2_id] + utter)
                    
            for h in range(len(hists)):
                if hists[h][0] == sp2_id:
                    start = max(0, h - max_turns+1)
                    for s in range(start, h):
                        contexts = hists[s:h+1]
                        input_ids = [bos_id] + list(chain.from_iterable(contexts)) + [eos_id]
                        if len(input_ids) <= max_len:
                            start_sp_id, next_sp_id = contexts[0][0], contexts[1][0]
                            token_type_ids = [[start_sp_id] * len(ctx) if c % 2 == 0 else [next_sp_id] * len(ctx) for c, ctx in enumerate(contexts)]
                            assert token_type_ids[-1][0] == sp2_id
                            token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [sp2_id]
                            assert len(input_ids) == len(token_type_ids)
                            
                            labels = [[-100] * len(ctx) if c < len(contexts)-1 else [-100] + ctx[1:] for c, ctx in enumerate(contexts)]
                            assert labels[-1][1:] == contexts[-1][1:]
                            labels = [-100] + list(chain.from_iterable(labels)) + [eos_id]
                            assert len(input_ids) == len(labels)
                            
                            self.input_ids.append(input_ids)
                            self.token_type_ids.append(token_type_ids)
                            self.labels.append(labels)
                            
                            break
    
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return self.input_ids[idx], self.token_type_ids[idx], self.labels[idx]
    
    
class PadCollate():
    def __init__(self, eos_id):
        self.eos_id = eos_id
        
    def pad_collate(self, batch):
        input_ids, token_type_ids, labels =[], [], []
        for idx, seqs in enumerate(batch):
            input_ids.append(torch.LongTensor(seqs[0]))
            token_type_ids.append(torch.LongTensor(seqs[1]))
            labels.append(torch.LongTensor(seqs[2]))
            
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.eos_id)
        token_type_ids = torch.nn.utils.rnn.pad_sequence(token_type_ids, batch_first=True, padding_value=self.eos_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
    
        return input_ids, token_type_ids, labels

In [14]:
def fix_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)

In [15]:
#Load Model

if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu}")
else:
    device = torch.device("cpu")

print("Loading the model...")
fix_seed(seed)
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
model.resize_token_embeddings(vocab_size)

max_len = min(max_len, model.config.n_ctx)

Loading the model...


Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [16]:
# Load optimizer
print("Loading the optimizer...")
optim = torch.optim.AdamW(model.parameters(), lr=lr)

# Load train & valid dataset
print("Loading train & valid data...")

train_set = CustomDataset(train_ids)
valid_set = CustomDataset(valid_ids)
test_set = CustomDataset(test_ids)

ppd = PadCollate(eos_id=eos_id)

train_loader = DataLoader(train_set, 
                            collate_fn=ppd.pad_collate, 
                            shuffle=True, 
                            batch_size=batch_size, 
                            num_workers=num_workers, 
                            pin_memory=True)

valid_loader = DataLoader(valid_set, 
                            collate_fn=ppd.pad_collate,
                            batch_size=batch_size, 
                            num_workers=num_workers, 
                            pin_memory=True)

test_loader = DataLoader(test_set, 
                            collate_fn=ppd.pad_collate,
                            batch_size=batch_size, 
                            num_workers=num_workers, 
                            pin_memory=True)
    
# Calculate total training steps
num_batches = len(train_loader)
total_train_steps = num_epochs * num_batches
warmup_steps = int(warmup_ratio * total_train_steps)

sched = get_polynomial_decay_schedule_with_warmup(
    optim,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_train_steps,
    power=2
)



writer = SummaryWriter()

Loading the optimizer...
Loading train & valid data...


100%|██████████| 11118/11118 [00:01<00:00, 10563.55it/s]
100%|██████████| 1000/1000 [00:00<00:00, 16374.85it/s]
100%|██████████| 1000/1000 [00:00<00:00, 17791.70it/s]


In [17]:
def validation():

    print("Validation processing...")
    model.eval()
            
    valid_losses = []
    valid_ppls = []
    with torch.no_grad():
        for i, batch in enumerate(tqdm(valid_loader)):
            input_ids, token_type_ids, labels = batch
            input_ids, token_type_ids, labels = \
                input_ids.to(device), token_type_ids.to(device), labels.to(device)
            
            outputs = model(
                input_ids=input_ids,
                token_type_ids = token_type_ids,
                labels = labels
            )
            
            loss, logits = outputs[0], outputs[1]
            
            valid_losses.append(loss.detach())
            ppl = torch.exp(loss.detach())
            valid_ppls.append(ppl)
        
        valid_losses = [loss.item() for loss in valid_losses]
        valid_ppls = [ppl.item() if not math.isinf(ppl.item()) else 1e+8 for ppl in valid_ppls]
        valid_loss = np.mean(valid_losses)
        valid_ppl = np.mean(valid_ppls)
        
        if math.isnan(valid_ppl):
            valid_ppl = 1e+8
            
    return valid_loss, valid_ppl

In [18]:
def train():
    
    fix_seed(seed)  # Fix seed before training
    print("Training starts.")

    best_loss = sys.float_info.max
    last_epoch= 0
    
    start_epoch = last_epoch +1
    for epoch in range(start_epoch, start_epoch+num_epochs):
        model.train()
        
        print(f"#"*50 + f"Epoch: {epoch}" + "#"*50)
        train_losses = []
        train_ppls = []
        for i, batch in enumerate(tqdm(train_loader)):
            input_ids, token_type_ids, labels = batch
            input_ids, token_type_ids, labels = \
                input_ids.to(device), token_type_ids.to(device), labels.to(device)
            
            outputs = model(
                input_ids=input_ids,
                token_type_ids = token_type_ids,
                labels = labels
            )
            
            loss, logits = outputs[0], outputs[1]
            
            optim.zero_grad()
            loss.backward()
            optim.step()
            sched.step()
            
            train_losses.append(loss.detach())
            ppl = torch.exp(loss.detach())
            train_ppls.append(ppl)
        
        train_losses = [loss.item() for loss in train_losses]
        train_ppls = [ppl.item() if not math.isinf(ppl.item()) else 1e+8 for ppl in train_ppls]
        train_loss = np.mean(train_losses)
        train_ppl = np.mean(train_ppls)
        print(f"Train loss: {train_loss} || Train perplexity: {train_ppl}")
        
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("PPL/train", train_ppl, epoch)
        
        last_epoch += 1
        
        valid_loss, valid_ppl = validation()
            
        if valid_loss < best_loss:
            best_loss = valid_loss
            state_dict = {
                'model_state_dict': model.state_dict(),
                'optim_state_dict': optim.state_dict(),
                'sched_state_dict': sched.state_dict(),
                'loss': best_loss,
                'epoch': last_epoch
            }
            
            torch.save(state_dict, f"{ckpt_dir}/best_ckpt_epoch={epoch}_valid_loss={round(best_loss, 4)}.ckpt")
            print("*"*10 + "Current best checkpoint is saved." + "*"*10)
            print(f"{ckpt_dir}/best_ckpt_epoch={epoch}_valid_loss={round(best_loss, 4)}.ckpt")
            
        print(f"Best valid loss: {best_loss}")
        print(f"Valid loss: {valid_loss} || Valid perplexity: {valid_ppl}")
        
        writer.add_scalar("Loss/valid", valid_loss, epoch)
        writer.add_scalar("PPL/valid", valid_ppl, epoch)
        
        writer.add_scalars("Losses", {
            'train': train_loss, 
            'valid': valid_loss,
        }, epoch)
        writer.add_scalars("PPLs", {
            'train': train_ppl,
            'valid': valid_ppl,
        }, epoch)
            
    print("Training finished!")

In [19]:
train()

Training starts.
##################################################Epoch: 1##################################################


100%|██████████| 5205/5205 [08:42<00:00,  9.97it/s]


Train loss: 6.538847804275645 || Train perplexity: 3.231633850593542e+35
Validation processing...


100%|██████████| 482/482 [00:13<00:00, 35.83it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=1_valid_loss=2.7095.ckpt
Best valid loss: 2.709511728454922
Valid loss: 2.709511728454922 || Valid perplexity: 15.906743613516147
##################################################Epoch: 2##################################################


100%|██████████| 5205/5205 [08:33<00:00, 10.13it/s]


Train loss: 2.719154739059243 || Train perplexity: 15.732304244540725
Validation processing...


100%|██████████| 482/482 [00:13<00:00, 37.07it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=2_valid_loss=2.5834.ckpt
Best valid loss: 2.58340733451962
Valid loss: 2.58340733451962 || Valid perplexity: 14.028150634646911
##################################################Epoch: 3##################################################


100%|██████████| 5205/5205 [08:30<00:00, 10.20it/s]


Train loss: 2.544509691624087 || Train perplexity: 13.199640189124116
Validation processing...


100%|██████████| 482/482 [00:12<00:00, 37.28it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=3_valid_loss=2.5255.ckpt
Best valid loss: 2.5255081200995386
Valid loss: 2.5255081200995386 || Valid perplexity: 13.317135094607025
##################################################Epoch: 4##################################################


100%|██████████| 5205/5205 [08:32<00:00, 10.16it/s]


Train loss: 2.4147201054606957 || Train perplexity: 11.56191546790996
Validation processing...


100%|██████████| 482/482 [00:13<00:00, 36.19it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=4_valid_loss=2.4981.ckpt
Best valid loss: 2.4981133752344062
Valid loss: 2.4981133752344062 || Valid perplexity: 13.043940874550847
##################################################Epoch: 5##################################################


100%|██████████| 5205/5205 [08:33<00:00, 10.14it/s]


Train loss: 2.3180094953458203 || Train perplexity: 10.490810293056551
Validation processing...


100%|██████████| 482/482 [00:13<00:00, 36.60it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=5_valid_loss=2.4785.ckpt
Best valid loss: 2.478514183853672
Valid loss: 2.478514183853672 || Valid perplexity: 12.847803089133931
##################################################Epoch: 6##################################################


100%|██████████| 5205/5205 [08:34<00:00, 10.11it/s]


Train loss: 2.250378006839844 || Train perplexity: 9.795134301510828
Validation processing...


100%|██████████| 482/482 [00:13<00:00, 35.92it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=6_valid_loss=2.4765.ckpt
Best valid loss: 2.4765184872872603
Valid loss: 2.4765184872872603 || Valid perplexity: 12.900187119408761
##################################################Epoch: 7##################################################


100%|██████████| 5205/5205 [08:34<00:00, 10.11it/s]


Train loss: 2.2025948225822045 || Train perplexity: 9.331543687532776
Validation processing...


100%|██████████| 482/482 [00:13<00:00, 36.37it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=7_valid_loss=2.4742.ckpt
Best valid loss: 2.474211246393528
Valid loss: 2.474211246393528 || Valid perplexity: 12.919125017783454
##################################################Epoch: 8##################################################


100%|██████████| 5205/5205 [08:34<00:00, 10.11it/s]


Train loss: 2.17496813992602 || Train perplexity: 9.080786132171221
Validation processing...


100%|██████████| 482/482 [00:13<00:00, 36.03it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=8_valid_loss=2.4735.ckpt
Best valid loss: 2.473535792214247
Valid loss: 2.473535792214247 || Valid perplexity: 12.939094784843476
##################################################Epoch: 9##################################################


100%|██████████| 5205/5205 [08:34<00:00, 10.12it/s]


Train loss: 2.1568723358636173 || Train perplexity: 8.909436005611585
Validation processing...


100%|██████████| 482/482 [00:13<00:00, 36.07it/s]


Best valid loss: 2.473535792214247
Valid loss: 2.4758523999408073 || Valid perplexity: 12.989477071524655
##################################################Epoch: 10##################################################


100%|██████████| 5205/5205 [08:35<00:00, 10.11it/s]


Train loss: 2.150291860916658 || Train perplexity: 8.848216480617907
Validation processing...


100%|██████████| 482/482 [00:13<00:00, 36.08it/s]

Best valid loss: 2.473535792214247
Valid loss: 2.476038559096483 || Valid perplexity: 12.995482023326193
Training finished!





In [52]:
def infer():
    model.eval()
    fix_seed(seed)

    with torch.no_grad():
        
        for i, batch in enumerate(tqdm(test_dialogues)):
            print()
            input_hists = []
            for j in range(0, len(batch), 2):

                utter = batch[j]
                print('Utterance: ', utter)

                # utter = input("You: ")
                if utter == end_command:
                    print("Bot: Good bye.")
                    break
            
                input_ids = [sp1_id] + tokenizer.encode(utter)
                input_hists.append(input_ids)
            
                if len(input_hists) >= max_turns:
                    num_exceeded = len(input_hists) - max_turns + 1
                    input_hists = input_hists[num_exceeded:]
                
                input_ids = [bos_id] + list(chain.from_iterable(input_hists)) + [sp2_id]
                start_sp_id = input_hists[0][0]
                next_sp_id = sp1_id if start_sp_id == sp2_id else sp2_id

                assert start_sp_id != next_sp_id
                token_type_ids = [[start_sp_id] * len(hist) if h % 2 == 0 else [next_sp_id] * len(hist) for h, hist in enumerate(input_hists)]
                assert len(token_type_ids) == len(input_hists)
                token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [sp2_id]
                assert len(input_ids) == len(token_type_ids)
                input_len = len(input_ids)
            
                input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(device)
                token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(device)
                
                output_ids = nucleus_sampling(input_ids, token_type_ids, input_len)                
            # output_ids = self.model.generate(
            #     input_ids=input_ids, token_type_ids=token_type_ids, pad_token_id=self.args.eos_id,
            #     do_sample=True, top_p=self.args.top_p, max_length=self.args.max_len,
            #     output_hidden_states=True, output_scores=True, return_dict_in_generate=True,
            # ).sequences
            # output_ids = output_ids[0].tolist()[input_len:]
                res = tokenizer.decode(output_ids, skip_special_tokens=True)

                actual_res = batch[j+1]
                print(f"Bot response: {res}")
                print(f"Actual response: {actual_res}")

                input_hists.append([sp2_id] + tokenizer.encode(actual_res))

            if i ==2:
                break
                
def nucleus_sampling(input_ids, token_type_ids, input_len):
    output_ids = []
    for pos in range(input_len, max_len):
        output = model(input_ids=input_ids, token_type_ids=token_type_ids)[0][:, pos-1]  # (1, V)
        output = F.softmax(output, dim=-1)  # (1, V)
        
        sorted_probs, sorted_idxs = torch.sort(output, descending=True)
        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)  # (1, V)
        idx_remove = cumsum_probs > top_p
        idx_remove[:, 1:] = idx_remove[:, :-1].clone()
        idx_remove[:, 0] = False
        sorted_probs[idx_remove] = 0.0
        sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)  # (1, V)
        
        probs = torch.zeros(output.shape, device=device).scatter_(-1, sorted_idxs, sorted_probs)  # (1, V)
        idx = torch.multinomial(probs, 1)  # (1, 1)
        
        idx_item = idx.squeeze(-1).squeeze(-1).item()
        output_ids.append(idx_item)
        
        if idx_item == eos_id:
            break
            
        input_ids = torch.cat((input_ids, idx), dim=-1)
        next_type_id = torch.LongTensor([[sp2_id]]).to(device)
        token_type_ids = torch.cat((token_type_ids, next_type_id), dim=-1)
        assert input_ids.shape == token_type_ids.shape
        
    return output_ids

In [53]:
infer()


  0%|          | 0/1000 [00:00<?, ?it/s]


Utterance:  Hey man, You wanna buy some weed?
Bot response: Yeah, I do. I can't stop smoking, I don't want to.
Actual response: Some what?
Utterance:  Weed! You know? Pot, Ganja, Mary Jane some chronic!
Bot response: So what?
Actual response: Oh, Umm, No thanks.
Utterance:  I also have blow if you prefer to do a few lines.
Bot response: No thanks. No thanks at all.
Actual response: No, I am ok, Really.
Utterance:  Come on man! I even got dope and acid! Try some!
Bot response: Alright, Thanks for helping out.
Actual response: Do you really have all of these drugs? Where do you get them from?
Utterance:  I got my connections! Just tell me what you want and I'll even give you one ounce for free.
Bot response: Ok, I'm gonna give you a try.
Actual response: Sounds good! Let's see, I want.
Utterance:  Yeah?


  0%|          | 1/1000 [00:01<18:56,  1.14s/it]

Bot response: Here you go. Just give me two ounces, Please.
Actual response: I want you to put your hands behind your head! You are under arrest!

Utterance:  The taxi drivers are on strike again.
Bot response: Why are they going to strike?
Actual response: What for?
Utterance:  They want the government to reduce the price of the gasoline.


  0%|          | 2/1000 [00:01<12:02,  1.38it/s]

Bot response: Oh, You mean the price of unleaded petrol?
Actual response: It is really a hot potato.

Utterance:  We've managed to reduce our energy consumption in our factory by about 15 per cent in the last two years.
Bot response: I can't agree with you more.
Actual response: That's excellent. How have you managed that?
Utterance:  Mainly because we've invested in a heat recovery system.
Bot response: Cooling systems are something that we're very keen on. We don't want any extra waste from outside the factory.
Actual response: What does that mean exactly?
Utterance:  Well, We use the exhaust gases from our printing presses to provide energy to heat our dryers.
Bot response: Cooling my printer? Ok, Let's get started.
Actual response: What other sources of energy do you use?
Utterance:  We don't use any fossil fuels. Most of our power comes from hydro-electric plants. We're hoping to use even more energy from alternative sources in the future - perhaps even wind power.


  0%|          | 2/1000 [00:02<22:57,  1.38s/it]


IndexError: ignored

In [56]:
for batch in test_loader:
    print(batch)
    break

[tensor([[50257, 50258, 10814,   582,    11,   921, 18869,  2822,   617, 20349,
            30, 50259,  4366,   644,    30, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256],
        [50257, 50258, 10814,   582,    11,   921, 18869,  2822,   617, 20349,
            30, 50259,  4366,   644,    30, 50258,  1135,   276,     0,   921,
           760,    30,  6902,    11, 23207,  6592,    11,  5335, 12091,   617,
         10726,     0, 50259,  5812,    11,   471,  3020,    11,  1400,  5176,
            13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,

In [None]:
from datasets import load_metric
bleu_metric = load_metric('bleu')