# Research Project

## Import required libraries

In [None]:
!pip install datasets
!pip install transformers
!pip install evaluate
!pip install rouge_score bert_score sacrebleu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from datasets import list_datasets, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, GPT2LMHeadModel, get_polynomial_decay_schedule_with_warmup
from torch.nn import functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from itertools import chain
import torch
import math
import numpy as np
import random
import datasets

# Networking workaround.
# NOTE(review): an empty CURL_CA_BUNDLE looks like the usual trick for making HTTPS
# downloads skip certificate verification — confirm this is intended and only used on
# trusted networks.
import os, sys
os.environ['CURL_CA_BUNDLE'] = ''

import urllib3
# Silence urllib3's InsecureRequestWarning messages.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


In [None]:
# Experiment configuration: which pretrained model and which dialogue corpus to use.
selected_model = 'dialoGPT'  # alternative: 'gpt2'
dataset_name = 'empathetic_dialogues'  # alternative: 'daily_dialog'

## Tokenizer

In [None]:
# Load the tokenizer that belongs to the selected model.
if selected_model == 'dialoGPT':
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
elif selected_model == 'gpt2':
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
else:
    # Fail here rather than with a NameError on `tokenizer` in a later cell.
    raise ValueError(f"No tokenizer available for selected_model={selected_model!r}")

In [None]:
space = 'Ġ'  # marker the byte-level BPE tokenizer puts in front of a token that follows a space
pre_quote = '’'
end_marks = ['.', ',', '?', '!', '...']
quotes = ['"', '\'']
abbreviations = ['s', 'd', 't', 'm', 're', 'll', 've', 'S', 'D', 'T', 'M', 'Re', 'Ll', 'Ve']

# For empathetic dialogues
exclude_symbol = "_conv"
comma_symbol = "_comma_"

def process_token_list(token_list):
    """Normalise spacing, quoting and capitalisation of a tokenized utterance.

    Rules applied, in order, to every token:

    * The first token is passed through ``str.capitalize``.
    * A token containing the space marker loses its first character when the
      remainder is an end mark or an abbreviation, or when it is an apostrophe
      that is followed by an abbreviation.
    * Marker-prefixed quote tokens alternate between "opening" and "closing":
      an opening quote strips the marker of the following token, a closing
      quote strips its own marker.
    * The token after an end mark is capitalised and given a leading marker.

    Finally, empty tokens and bare markers are dropped and ``'.'`` is appended
    if the last remaining token is not an end mark.

    Args:
        token_list: list of token strings. It is not modified.

    Returns:
        A new list of tokens. An empty list is returned when the input is
        empty or when no token survives the final filtering.
    """
    # Work on a copy so the caller's list is left untouched.
    tokens = list(token_list)
    if not tokens:
        return []

    tokens[0] = tokens[0].capitalize()
    last_index = len(tokens) - 1
    quote_open = False

    # Note: `token` is the value at position i when the iteration reaches it, so it
    # already reflects edits made to tokens[i] while processing position i - 1.
    for i, token in enumerate(tokens):
        stripped = token[1:]
        has_next = i < last_index

        if space in token:
            if stripped in end_marks or stripped in abbreviations:
                tokens[i] = stripped

            if stripped == quotes[1] and has_next:
                nxt = tokens[i + 1]
                # startswith() instead of [0] so an empty next token cannot raise.
                if nxt in abbreviations or (nxt.startswith(space) and nxt[1:] in abbreviations):
                    tokens[i] = stripped

        if token.startswith(space) and stripped in quotes:
            if quote_open:
                # Closing quote: attach it to the preceding token.
                tokens[i] = stripped
                quote_open = False
            else:
                # Opening quote: attach the following token to it.
                if has_next and tokens[i + 1].startswith(space):
                    tokens[i + 1] = tokens[i + 1][1:]
                quote_open = True

        if (token in end_marks or stripped in end_marks) and has_next:
            nxt = tokens[i + 1]
            if nxt.startswith(space):
                nxt = nxt[1:]
            tokens[i + 1] = space + nxt.capitalize()

    cleaned = [tok for tok in tokens if tok and tok != space]
    if cleaned and cleaned[-1] not in end_marks:
        cleaned.append(end_marks[0])

    return cleaned

## Load dataset

In [None]:
# Load the selected corpus and pick out its train / validation / test splits.
if dataset_name not in ('daily_dialog', 'empathetic_dialogues'):
    # Fail here rather than with a NameError on the split variables in a later cell.
    raise ValueError(f"No dataset selected: unsupported dataset_name={dataset_name!r}")

print('Loading ', dataset_name)
dataset = load_dataset(dataset_name)

if dataset_name == 'daily_dialog':
    # Only the 'dialog' column of each split is used.
    train_dialogues = dataset['train']['dialog']
    valid_dialogues = dataset['validation']['dialog']
    test_dialogues = dataset['test']['dialog']
else:
    # load_empathetic() reads several columns, so the whole split is kept.
    train_dialogues = dataset['train']
    valid_dialogues = dataset['validation']
    test_dialogues = dataset['test']

Loading  empathetic_dialogues




  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
def load_empathetic(dataset, tokenizer):
    """Group the utterance rows of an empathetic_dialogues split into dialogues.

    Each row supplies an utterance, a conversation id and a speaker index.
    Utterances containing ``exclude_symbol`` are skipped. In the remaining ones
    ``comma_symbol`` is replaced by ``','`` and the text is normalised with
    ``process_token_list``. Consecutive utterances by the same speaker within a
    conversation are merged into one turn.

    Args:
        dataset: mapping-like split exposing the 'utterance', 'conv_id' and
            'speaker_idx' columns.
        tokenizer: tokenizer providing ``tokenize`` and ``convert_tokens_to_string``.

    Returns:
        (dialogues, utter_num): a list of dialogues (each a list of turn
        strings, in first-seen conversation order) and the total number of turns.
    """
    total_utters = dataset['utterance']
    total_conv_ids = dataset['conv_id']
    total_speaker_ids = dataset['speaker_idx']

    assert len(total_utters) == len(total_conv_ids) and len(total_conv_ids) == len(total_speaker_ids)

    conv_dict = {}
    cur_speaker_idx = -1
    rows = zip(total_utters, total_conv_ids, total_speaker_ids)
    for utter, conv_id, speaker_idx in tqdm(rows, total=len(total_utters)):
        # Check the exclusion first so skipped utterances are never tokenized.
        if exclude_symbol in utter:
            continue

        utter_modified = utter.strip().replace(comma_symbol, ',')
        new_token_list = process_token_list(tokenizer.tokenize(utter_modified))
        text = tokenizer.convert_tokens_to_string(new_token_list)

        if conv_id not in conv_dict:
            conv_dict[conv_id] = []
            cur_speaker_idx = -1  # a new conversation always starts a new turn

        if cur_speaker_idx != speaker_idx:
            conv_dict[conv_id].append(text)
            cur_speaker_idx = speaker_idx
        else:
            # Same speaker again: extend the current turn.
            conv_dict[conv_id][-1] += f" {text}"

    dialogues = list(conv_dict.values())
    utter_num = sum(len(utter_list) for utter_list in dialogues)

    return dialogues, utter_num

def load_daily(dataset, tokenizer):
    """Normalise every utterance of a daily_dialog split.

    Each utterance is stripped, ``pre_quote`` is replaced by a plain
    apostrophe, and the text is tokenized, normalised with
    ``process_token_list`` and converted back to a string.

    Args:
        dataset: iterable of dialogues, each an iterable of utterance strings.
            It is not modified.
        tokenizer: tokenizer providing ``tokenize`` and ``convert_tokens_to_string``.

    Returns:
        (dialogues, utter_num): a new list of processed dialogues and the
        total number of utterances.
    """
    processed = []
    for dialogue in tqdm(dataset):
        new_dialogue = []
        for utter in dialogue:
            token_list = tokenizer.tokenize(utter.strip().replace(pre_quote, quotes[1]))
            token_list = process_token_list(token_list)
            new_dialogue.append(tokenizer.convert_tokens_to_string(token_list))
        # Collect into a new list instead of assigning back into `dataset`, which
        # mutated the input while iterating it and needs item assignment support.
        processed.append(new_dialogue)

    utter_num = sum(len(dialogue) for dialogue in processed)

    return processed, utter_num

In [None]:
# Normalise all three splits with the loader that matches the selected corpus.
if dataset_name in ('daily_dialog', 'empathetic_dialogues'):
    load_split = load_daily if dataset_name == 'daily_dialog' else load_empathetic
    train_dialogues, num_train = load_split(train_dialogues, tokenizer)
    valid_dialogues, num_valid = load_split(valid_dialogues, tokenizer)
    test_dialogues, num_test = load_split(test_dialogues, tokenizer)

  2%|▏         | 1443/76673 [00:00<00:10, 7247.66it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (6234 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 76673/76673 [00:11<00:00, 6916.31it/s]
100%|██████████| 12030/12030 [00:02<00:00, 5088.97it/s]
100%|██████████| 10943/10943 [00:02<00:00, 4032.50it/s]


In [None]:
# Report the number of dialogues, then the number of utterances, per split.
for split_name, split_dialogues in (("train", train_dialogues),
                                    ("valid", valid_dialogues),
                                    ("test", test_dialogues)):
    print(f"The number of {split_name} dialogues: {len(split_dialogues)}")

for split_name, split_utter_num in (("train", num_train),
                                    ("valid", num_valid),
                                    ("test", num_test)):
    print(f"The number of {split_name} utterances: {split_utter_num}")



The number of train dialogues: 17793
The number of valid dialogues: 2759
The number of test dialogues: 2540
The number of train utterances: 76622
The number of valid utterances: 12025
The number of test utterances: 10939


In [None]:
# Extract ids (input_ids, token_ids) from processed text
def extract_ids(dialogues):
    """Convert dialogues of text into dialogues of token-id lists.

    Uses the module-level ``tokenizer``: every utterance is tokenized and the
    tokens are mapped to ids.

    Args:
        dialogues: sized iterable of dialogues, each an iterable of strings.

    Returns:
        A list with one entry per dialogue; each entry is a list holding one
        list of token ids per utterance.
    """
    ids = [
        [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(utter)) for utter in dialogue]
        for dialogue in tqdm(dialogues)
    ]

    assert len(ids) == len(dialogues)

    return ids

In [None]:
# Convert the processed train / validation dialogues to token ids
# (the test split is not converted here).
train_ids = extract_ids(train_dialogues)
valid_ids = extract_ids(valid_dialogues) 
# test_ids = extract_ids(test_dialogues) 

100%|██████████| 17793/17793 [00:09<00:00, 1817.94it/s]
100%|██████████| 2759/2759 [00:02<00:00, 1313.96it/s]


In [None]:
#Parameters

# Special tokens and sequence limits
sp1_token, sp2_token, bos_token = '<sp1>', '<sp2>', '<bos>'
max_turns = 5
max_len = 1024
seed = 0
gpu = 0

#Tokeniser
# Register the beginning-of-sequence token and the two speaker tokens.
special_tokens = {
    'bos_token': bos_token,
    'additional_special_tokens': [sp1_token, sp2_token],
}

eos_token = tokenizer.eos_token
num_new_tokens = tokenizer.add_special_tokens(special_tokens)

# Look up the ids of the special tokens in the (now extended) vocabulary.
vocab = tokenizer.get_vocab()
vocab_size = len(vocab)
bos_id, eos_id, sp1_id, sp2_id = (
    vocab[tok] for tok in (bos_token, eos_token, sp1_token, sp2_token)
)

# Optimisation, data loading and generation settings
lr = 2e-5
batch_size = 8
num_workers = 0
num_epochs = 8
warmup_ratio = 0.1
last_epoch = 0
end_command = 'Quit!'
top_p = 0.8


In [None]:
# Directory that receives the checkpoints written by train().
# os.makedirs(..., exist_ok=True) can be re-run safely, unlike `!mkdir`.
ckpt_dir = 'saved_models'
os.makedirs(ckpt_dir, exist_ok=True)

In [None]:
class CustomDataset(Dataset):
    """Training examples built from token-id dialogues.

    Every turn is prefixed with a speaker token (even positions ``sp1_id``, odd
    positions ``sp2_id``). For each speaker-2 turn one example is created from
    the longest run of preceding turns (at most ``max_turns`` turns including
    the target) whose encoding fits into ``max_len`` tokens. If even the
    shortest context does not fit, no example is created for that turn.

    Relies on the module-level ``sp1_id``, ``sp2_id``, ``bos_id``, ``eos_id``,
    ``max_turns`` and ``max_len``.

    Args:
        dials: iterable of dialogues; each dialogue is a list of token-id lists.
    """

    def __init__(self, dials):
        self.input_ids = []  # (N, L)
        self.token_type_ids = []  # (N, L)
        self.labels = []  # (N, L)

        for dial in tqdm(dials):
            hists = [
                [sp1_id if turn % 2 == 0 else sp2_id] + utter
                for turn, utter in enumerate(dial)
            ]

            for h, target in enumerate(hists):
                if target[0] != sp2_id:
                    continue  # only speaker-2 turns are prediction targets

                earliest = max(0, h - max_turns + 1)
                for s in range(earliest, h):
                    contexts = hists[s:h + 1]
                    input_ids = [bos_id] + list(chain.from_iterable(contexts)) + [eos_id]
                    if len(input_ids) > max_len:
                        continue  # too long: retry with one context turn fewer

                    # Token types alternate between the first two speakers of the window.
                    start_sp_id, next_sp_id = contexts[0][0], contexts[1][0]
                    segment_types = [
                        [start_sp_id if c % 2 == 0 else next_sp_id] * len(ctx)
                        for c, ctx in enumerate(contexts)
                    ]
                    assert segment_types[-1][0] == sp2_id
                    token_type_ids = [start_sp_id] + list(chain.from_iterable(segment_types)) + [sp2_id]
                    assert len(input_ids) == len(token_type_ids)

                    # Mask (-100) the bos token, every context turn and the target's
                    # speaker token; keep the target's tokens and the final eos.
                    masked = sum(len(ctx) for ctx in contexts[:-1]) + 2
                    labels = [-100] * masked + contexts[-1][1:] + [eos_id]
                    assert len(input_ids) == len(labels)

                    self.input_ids.append(input_ids)
                    self.token_type_ids.append(token_type_ids)
                    self.labels.append(labels)

                    break  # the longest fitting context has been used

    def __len__(self):
        """Return the number of examples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the (input_ids, token_type_ids, labels) lists of example ``idx``."""
        return self.input_ids[idx], self.token_type_ids[idx], self.labels[idx]
    
    
class PadCollate():
    """Collate function object that right-pads variable-length examples.

    Args:
        eos_id: id used to pad ``input_ids`` and ``token_type_ids``.
    """

    def __init__(self, eos_id):
        self.eos_id = eos_id

    def pad_collate(self, batch):
        """Turn a list of (input_ids, token_type_ids, labels) into padded tensors.

        Returns:
            Three LongTensors with the batch dimension first. ``input_ids`` and
            ``token_type_ids`` are padded with ``eos_id``, ``labels`` with -100.
        """
        def _pad(field, fill):
            # Convert one field of every example, then pad to the longest one.
            tensors = [torch.LongTensor(example[field]) for example in batch]
            return torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=fill)

        input_ids = _pad(0, self.eos_id)
        token_type_ids = _pad(1, self.eos_id)
        labels = _pad(2, -100)

        return input_ids, token_type_ids, labels

In [None]:
# Sanity check: show the ids of the eos, bos and speaker tokens.
print(eos_id, bos_id, sp1_id, sp2_id)

50256 50257 50258 50259


In [None]:
# Sanity check: show the first processed training dialogue.
print(train_dialogues[0])

['I remember going to see the fireworks with my best friend. It was the first time we ever spent time alone together. Although there was a lot of people, We felt like the only people in the world.', 'Was this a friend you were in love with, Or just a best friend?', 'This was a best friend. I miss her.', 'Where has she gone?', 'We no longer talk.', 'Oh was this something that happened because of an argument?']


In [None]:
# Debugging: build a dataset from only the first training dialogue.
debug_dialog = CustomDataset([train_ids[0]])

100%|██████████| 1/1 [00:00<00:00, 4194.30it/s]


In [None]:
def fix_seed(seed):
    """Seed NumPy, PyTorch (CPU and all CUDA devices) and Python's ``random`` with ``seed``."""
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all, random.seed):
        seeder(seed)

In [None]:
#Load Model

# Use the configured GPU when CUDA is available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device(f"cuda:{gpu}" if use_cuda else "cpu")
print('Using GPU' if use_cuda else 'Using CPU')

Using GPU


In [None]:
# Load the pretrained language model and adapt it to the extended vocabulary.
print("Loading the model: ", selected_model)
fix_seed(seed)

if selected_model == 'dialoGPT':
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small").to(device)
elif selected_model == 'gpt2':
    model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
else:
    # Fail here rather than with a NameError on `model` two lines below.
    raise ValueError(f"No model available for selected_model={selected_model!r}")

# The tokenizer gained special tokens, so the embedding matrix has to grow with it.
model.resize_token_embeddings(vocab_size)
# Never exceed the model's context size; fall back to n_positions when the config
# has no n_ctx attribute.
max_len = min(max_len, getattr(model.config, 'n_ctx', model.config.n_positions))

Loading the model:  dialoGPT


In [None]:
#Load from checkpoint
# The file is one of the checkpoints train() writes into ckpt_dir. Guard on its
# existence so a fresh run, where training has not produced it yet, does not crash.
ckpt_path = os.path.join(ckpt_dir, "best_ckpt_epoch=4_valid_loss=2.5344.ckpt")
if os.path.exists(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location=device)
    print(model.load_state_dict(ckpt['model_state_dict']))
else:
    print(f"Checkpoint not found at {ckpt_path}; keeping the pretrained weights.")

<All keys matched successfully>

In [None]:
# Load optimizer
# AdamW over all model parameters, using the learning rate from the parameter cell.
print("Loading the optimizer...")
optim = torch.optim.AdamW(model.parameters(), lr=lr)

Loading the optimizer...


In [None]:
# Load train & valid dataset
print("Loading train & valid data...")

train_set = CustomDataset(train_ids)
valid_set = CustomDataset(valid_ids)

ppd = PadCollate(eos_id=eos_id)

# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(
    collate_fn=ppd.pad_collate,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_set, **loader_kwargs)

# Calculate total training steps
num_batches = len(train_loader)
total_train_steps = num_epochs * num_batches
warmup_steps = int(warmup_ratio * total_train_steps)

# Warm up for the first warmup_ratio share of the steps, then decay polynomially.
sched = get_polynomial_decay_schedule_with_warmup(
    optim,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_train_steps,
    power=2,
)

# TensorBoard writer used by train().
writer = SummaryWriter()

Loading train & valid data...


100%|██████████| 17793/17793 [00:00<00:00, 23742.94it/s]
100%|██████████| 2759/2759 [00:00<00:00, 40248.91it/s]


In [None]:
def validation():
    """Evaluate the module-level ``model`` on ``valid_loader``.

    Returns:
        (valid_loss, valid_ppl): the mean of the per-batch losses and the mean
        of the per-batch ``exp(loss)`` values. Infinite per-batch perplexities
        are replaced by 1e+8, as is a NaN mean perplexity.
    """
    print("Validation processing...")
    model.eval()

    batch_losses, batch_ppls = [], []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            input_ids, token_type_ids, labels = (tensor.to(device) for tensor in batch)

            outputs = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                labels=labels,
            )

            # The first output is the loss computed from `labels`.
            loss = outputs[0].detach()
            batch_losses.append(loss)
            batch_ppls.append(torch.exp(loss))

    # Convert to Python floats only after the loop.
    loss_values = [loss.item() for loss in batch_losses]
    ppl_values = [1e+8 if math.isinf(value) else value
                  for value in (ppl.item() for ppl in batch_ppls)]
    valid_loss = np.mean(loss_values)
    valid_ppl = np.mean(ppl_values)

    if math.isnan(valid_ppl):
        valid_ppl = 1e+8

    return valid_loss, valid_ppl

In [None]:
def train():
    """Fine-tune the module-level ``model`` for ``num_epochs`` epochs.

    After every epoch the model is validated with ``validation()``. Whenever
    the validation loss improves, a checkpoint with the model, optimizer and
    scheduler state is written to ``ckpt_dir``. Train and validation loss and
    perplexity are logged to the TensorBoard ``writer``.

    Perplexities are the mean of the per-batch ``exp(loss)`` values, with
    infinite values replaced by 1e+8.
    """
    fix_seed(seed)  # Fix seed before training
    print("Training starts.")

    # Make sure the checkpoint directory exists even if its setup cell was skipped.
    os.makedirs(ckpt_dir, exist_ok=True)

    best_loss = sys.float_info.max
    start_epoch = 1
    for epoch in range(start_epoch, start_epoch + num_epochs):
        model.train()  # validation() switches to eval mode at the end of every epoch

        print("#"*50 + f"Epoch: {epoch}" + "#"*50)
        train_losses = []
        train_ppls = []
        for batch in tqdm(train_loader):
            input_ids, token_type_ids, labels = (tensor.to(device) for tensor in batch)

            outputs = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                labels=labels,
            )

            loss = outputs[0]

            optim.zero_grad()
            loss.backward()
            optim.step()
            sched.step()

            train_losses.append(loss.detach())
            train_ppls.append(torch.exp(loss.detach()))

        train_losses = [loss.item() for loss in train_losses]
        train_ppls = [ppl.item() if not math.isinf(ppl.item()) else 1e+8 for ppl in train_ppls]
        train_loss = np.mean(train_losses)
        train_ppl = np.mean(train_ppls)
        print(f"Train loss: {train_loss} || Train perplexity: {train_ppl}")

        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("PPL/train", train_ppl, epoch)

        valid_loss, valid_ppl = validation()

        if valid_loss < best_loss:
            best_loss = valid_loss
            ckpt_path = f"{ckpt_dir}/best_ckpt_epoch={epoch}_valid_loss={round(best_loss, 4)}.ckpt"
            state_dict = {
                'model_state_dict': model.state_dict(),
                'optim_state_dict': optim.state_dict(),
                'sched_state_dict': sched.state_dict(),
                # Store a built-in float: a NumPy scalar would be pickled as a NumPy
                # object, which torch.load's restricted (weights_only) unpickler
                # does not accept by default.
                'loss': float(best_loss),
                'epoch': epoch,
            }

            torch.save(state_dict, ckpt_path)
            print("*"*10 + "Current best checkpoint is saved." + "*"*10)
            print(ckpt_path)

        print(f"Best valid loss: {best_loss}")
        print(f"Valid loss: {valid_loss} || Valid perplexity: {valid_ppl}")

        writer.add_scalar("Loss/valid", valid_loss, epoch)
        writer.add_scalar("PPL/valid", valid_ppl, epoch)

        writer.add_scalars("Losses", {
            'train': train_loss,
            'valid': valid_loss,
        }, epoch)
        writer.add_scalars("PPLs", {
            'train': train_ppl,
            'valid': valid_ppl,
        }, epoch)

    # Push any buffered TensorBoard events to disk.
    writer.flush()
    print("Training finished!")

In [None]:
# Run the fine-tuning loop defined above.
train()

Training starts.
##################################################Epoch: 1##################################################


100%|██████████| 4579/4579 [19:40<00:00,  3.88it/s]


Train loss: 3.4754341883960316 || Train perplexity: 136.05326380107985
Validation processing...


100%|██████████| 714/714 [00:55<00:00, 12.84it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=1_valid_loss=2.6207.ckpt
Best valid loss: 2.6207386190149964
Valid loss: 2.6207386190149964 || Valid perplexity: 14.652000994909377
##################################################Epoch: 2##################################################


100%|██████████| 4579/4579 [19:47<00:00,  3.86it/s]


Train loss: 2.6777115453406437 || Train perplexity: 15.279803808507609
Validation processing...


100%|██████████| 714/714 [00:55<00:00, 12.88it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=2_valid_loss=2.5557.ckpt
Best valid loss: 2.5557342567363706
Valid loss: 2.5557342567363706 || Valid perplexity: 13.708324942602163
##################################################Epoch: 3##################################################


100%|██████████| 4579/4579 [19:46<00:00,  3.86it/s]


Train loss: 2.524160280111029 || Train perplexity: 13.014129212652916
Validation processing...


100%|██████████| 714/714 [00:55<00:00, 12.86it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=3_valid_loss=2.5367.ckpt
Best valid loss: 2.5366925942797622
Valid loss: 2.5366925942797622 || Valid perplexity: 13.476058622702164
##################################################Epoch: 4##################################################


100%|██████████| 4579/4579 [19:44<00:00,  3.86it/s]


Train loss: 2.4330872132492316 || Train perplexity: 11.846613157521848
Validation processing...


100%|██████████| 714/714 [00:55<00:00, 12.88it/s]


**********Current best checkpoint is saved.**********
saved_models/best_ckpt_epoch=4_valid_loss=2.5344.ckpt
Best valid loss: 2.5344023479133093
Valid loss: 2.5344023479133093 || Valid perplexity: 13.472681053546296
##################################################Epoch: 5##################################################


100%|██████████| 4579/4579 [19:45<00:00,  3.86it/s]


Train loss: 2.3745743261723833 || Train perplexity: 11.147166979211365
Validation processing...


100%|██████████| 714/714 [00:55<00:00, 12.86it/s]


Best valid loss: 2.5344023479133093
Valid loss: 2.5389193782285484 || Valid perplexity: 13.566308946502643
##################################################Epoch: 6##################################################


100%|██████████| 4579/4579 [19:49<00:00,  3.85it/s]


Train loss: 2.3399209027780823 || Train perplexity: 10.764859713433376
Validation processing...


100%|██████████| 714/714 [00:55<00:00, 12.81it/s]


Best valid loss: 2.5344023479133093
Valid loss: 2.543360740363765 || Valid perplexity: 13.648788542974563
##################################################Epoch: 7##################################################


100%|██████████| 4579/4579 [19:47<00:00,  3.86it/s]


Train loss: 2.3189223773303995 || Train perplexity: 10.523445773057007
Validation processing...


100%|██████████| 714/714 [00:55<00:00, 12.86it/s]


Best valid loss: 2.5344023479133093
Valid loss: 2.5462398779492417 || Valid perplexity: 13.700725489303846
##################################################Epoch: 8##################################################


100%|██████████| 4579/4579 [19:47<00:00,  3.86it/s]


Train loss: 2.309843450999463 || Train perplexity: 10.428232942457464
Validation processing...


100%|██████████| 714/714 [00:55<00:00, 12.85it/s]


Best valid loss: 2.5344023479133093
Valid loss: 2.546788294275268 || Valid perplexity: 13.710986280307717
Training finished!


In [None]:
window = 5  # number of most recent turns given to the model as context

def infer():
    """Generate a response for every test context that has a reference reply.

    Walks through each dialogue in ``test_dialogues`` turn by turn. As soon as
    ``window`` turns of history are available and a following turn exists, the
    last ``window`` turns are encoded, a reply is sampled with
    ``nucleus_sampling`` and paired with the actual next turn. Dialogues with
    fewer than ``window + 1`` turns therefore contribute nothing.

    For every 200th dialogue the context, the generated reply and the actual
    reply are printed.

    Returns:
        (generated_responses, actual_responses): two equally long lists of strings.
    """
    model.eval()
    fix_seed(seed)

    generated_responses = []
    actual_responses = []

    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_dialogues)):
            verbose = i % 200 == 0
            input_hists = []
            context = []  # plain-text history, kept only for the debug output

            for j, utter in enumerate(batch):
                # Even turns belong to speaker 1, odd turns to speaker 2.
                sp_id = sp1_id if j % 2 == 0 else sp2_id
                input_hists.append([sp_id] + tokenizer.encode(utter))
                context.append(utter)

                if len(input_hists) < window:
                    continue
                input_hists = input_hists[-window:]
                context = context[-window:]

                # No following turn means no reference to compare with, so stop
                # before spending time on sampling a reply that would be discarded.
                if j == len(batch) - 1:
                    break

                if verbose:
                    print()
                    print('Context:')
                    for c in context:
                        print(c)

                # The reply is spoken by whoever did not open the context window.
                start_sp_id = input_hists[0][0]
                next_sp_id = sp1_id if start_sp_id == sp2_id else sp2_id
                assert start_sp_id != next_sp_id

                input_ids = [bos_id] + list(chain.from_iterable(input_hists)) + [next_sp_id]

                token_type_ids = [
                    [start_sp_id] * len(hist) if h % 2 == 0 else [next_sp_id] * len(hist)
                    for h, hist in enumerate(input_hists)
                ]
                assert len(token_type_ids) == len(input_hists)
                token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [next_sp_id]
                assert len(input_ids) == len(token_type_ids)
                input_len = len(input_ids)

                input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(device)
                token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(device)

                output_ids = nucleus_sampling(input_ids, token_type_ids, input_len, next_sp_id)
                res = tokenizer.decode(output_ids, skip_special_tokens=True)

                actual_res = batch[j + 1]

                if verbose:
                    print()
                    print(f"Bot response: {res}")
                    print(f"Actual response: {actual_res}")

                generated_responses.append(res)
                actual_responses.append(actual_res)

    return generated_responses, actual_responses

            
                
@torch.no_grad()
def nucleus_sampling(input_ids, token_type_ids, input_len, next_sp_id):
    """Sample a reply token by token with top-p (nucleus) sampling.

    At each step the module-level ``model`` is run on the whole sequence, the
    distribution at position ``pos - 1`` is reduced to the most probable tokens
    whose cumulative probability first exceeds ``top_p``, renormalised, and one
    token is drawn. Sampling stops after ``eos_id`` is drawn or when the
    sequence reaches ``max_len`` tokens.

    Args:
        input_ids: LongTensor holding the encoded context (batch size 1).
        token_type_ids: LongTensor of the same shape as ``input_ids``.
        input_len: number of tokens in the context.
        next_sp_id: speaker token id used as token type of every generated token.

    Returns:
        The sampled token ids as a list of ints, including ``eos_id`` if it
        was drawn.
    """
    output_ids = []
    # Token type of every generated token; built once instead of on every step.
    next_type_id = torch.LongTensor([[next_sp_id]]).to(device)

    for pos in range(input_len, max_len):
        output = model(input_ids=input_ids, token_type_ids=token_type_ids)[0][:, pos-1]  # (1, V)
        output = F.softmax(output, dim=-1)  # (1, V)

        sorted_probs, sorted_idxs = torch.sort(output, descending=True)
        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)  # (1, V)
        idx_remove = cumsum_probs > top_p
        # Shift the mask right by one so the token that crosses top_p is kept,
        # and always keep the most probable token.
        idx_remove[:, 1:] = idx_remove[:, :-1].clone()
        idx_remove[:, 0] = False
        sorted_probs[idx_remove] = 0.0
        sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)  # (1, V)

        # Put the filtered probabilities back at their vocabulary positions.
        probs = torch.zeros(output.shape, device=device).scatter_(-1, sorted_idxs, sorted_probs)  # (1, V)
        idx = torch.multinomial(probs, 1)  # (1, 1)

        idx_item = idx.squeeze(-1).squeeze(-1).item()
        output_ids.append(idx_item)

        if idx_item == eos_id:
            break

        input_ids = torch.cat((input_ids, idx), dim=-1)
        token_type_ids = torch.cat((token_type_ids, next_type_id), dim=-1)
        assert input_ids.shape == token_type_ids.shape

    return output_ids

In [None]:
# Generate replies for the test dialogues and collect the matching reference replies.
generated_responses, actual_responses = infer()

  0%|          | 0/2540 [00:00<?, ?it/s]


Context:
Yeah about 10 years ago I had a horrifying experience. It was 100% their fault but they hit the water barrels and survived. They had no injuries but they almost ran me off the road.
Did you suffer any injuries?
No I wasn't hit. It turned out they were drunk. I felt guilty but realized it was his fault.
Why did you feel guilty? People really shouldn't drive drunk.
I don't know I was new to driving and hadn't experienced anything like that. I felt like my horn made him swerve into the water barrels.


 39%|███▉      | 999/2540 [01:40<01:21, 18.81it/s]


Context:
I hate that people hit on others they know are married.
Ugh, Why do people do that? It's annoying.
Right? Happens to my husband all the time and it's just frustrating.
Do you smack the women who do it?
No, Since it's his associates at work!!.


 87%|████████▋ | 2200/2540 [03:04<00:40,  8.47it/s]


Context:
I am going to play dominos tonight.
That sounds like fun! Is it with friends?
Well with some locals at the local brewpub.
So cool, I hope it goes well.
Yes I think I will do well even though I just learned last week how to play!


100%|██████████| 2540/2540 [03:30<00:00, 12.06it/s]


In [None]:
# Every generated reply must have exactly one reference reply.
assert len(generated_responses) == len(actual_responses)
print(len(generated_responses))
print(len(actual_responses))

243
243


### Store responses

In [None]:
import pickle
from google.colab import files

# Pickle the generated and the reference replies and download both files.
# The file names are derived from the run configuration so that they cannot drift
# from the settings actually used (the batch size was previously hard-coded).
file_generated = f"{selected_model}_batch{batch_size}_generated_responses_{dataset_name}"
file_actual = f"{selected_model}_batch{batch_size}_actual_responses_{dataset_name}"

with open(file_generated, "wb") as fp:
    pickle.dump(generated_responses, fp)

with open(file_actual, "wb") as fp:
    pickle.dump(actual_responses, fp)

files.download(file_generated)
files.download(file_actual)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
#Loading example
# with open("gpt2_batch8_generated_responses", "rb") as fp:   # Unpickling
#     dummy = pickle.load(fp)

### Compute metrics

In [None]:
import evaluate

# Load the four metrics used to compare generated and reference replies.
sacrebleu = evaluate.load("sacrebleu")
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")
chrf = evaluate.load("chrf")

In [None]:
# Refs must be in a list of list of str.
# Only wrap entries that are not lists yet, so re-running this cell does not nest
# the references one level deeper each time.
actual_responses = [res if isinstance(res, list) else [res] for res in actual_responses]

print(generated_responses[:5])
print(actual_responses[:5])

["It's hard to believe the average person would've done that. I think it would've been pretty difficult to track down what's going on.", 'Work and work with your boss for a change. I wish you the best in your future.', 'You must be so proud of yourself. I have a friend who has won his first, So I know it must have felt great when he did it!', 'I will, Thanks.', 'Oh. Well I hope you enjoyed it.']
[["Wow, So your going to take being a bad person to the grave. Maybe you'll see her in the next life?"], ["Well I've been in the business all my life and have worked for some great people. So I pull from what I learned from them."], ['Oh, Wow. Not only to be able to do all the running but to view the scenery!'], ['I know.'], ['Sounds interesting!']]


In [None]:
# Score the generated replies against the references with every metric.
metric_inputs = {'predictions': generated_responses, 'references': actual_responses}

bleu_score = sacrebleu.compute(**metric_inputs)

rouge_score = rouge.compute(**metric_inputs)

bert_score = bertscore.compute(lang='en', **metric_inputs)
# The BERTScore result holds lists of values; reduce each list to its mean.
precision, recall, f1 = (bert_score[key] for key in ('precision', 'recall', 'f1'))
avg_precision_bert = sum(precision) / len(precision)
avg_recall_bert = sum(recall) / len(recall)
avg_f1_bert = sum(f1) / len(f1)

chrf_score = chrf.compute(**metric_inputs)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Print every metric result with its label.
for metric_label, metric_value in (
    ('Bleu score: \n', bleu_score),  # Range from 0 to 100
    ('Rouge score: \n', rouge_score),
    ('Bert score: \n', bert_score),
    ('Avg precision Bert score: ', avg_precision_bert),
    ('Avg recall Bert score: ', avg_recall_bert),
    ('Avg f1 Bert score: ', avg_f1_bert),
    ('chrf score: \n', chrf_score),
):
    print(metric_label, metric_value)


Bleu score: 
 {'score': 0.8775860537020593, 'counts': [579, 39, 8, 4], 'totals': [3440, 3197, 2954, 2711], 'precisions': [16.83139534883721, 1.2198936502971536, 0.2708192281651997, 0.14754703061600885], 'bp': 0.9220971878521749, 'sys_len': 3440, 'ref_len': 3719}
Rouge score: 
 {'rouge1': 0.12720051415115013, 'rouge2': 0.012832279511119581, 'rougeL': 0.11058318450259344, 'rougeLsum': 0.11041933689210065}
Bert score: 
 {'precision': [0.8519506454467773, 0.8683867454528809, 0.837348222732544, 0.8608677983283997, 0.8518811464309692, 0.8643646836280823, 0.8481199741363525, 0.858927845954895, 0.8822838664054871, 0.8653178811073303, 0.8559589385986328, 0.8775131106376648, 0.8663605451583862, 0.8884487748146057, 0.9010105133056641, 0.8656351566314697, 0.8619168996810913, 0.872879147529602, 0.8998854756355286, 0.8512505292892456, 0.8912308216094971, 0.9395709037780762, 0.8470564484596252, 0.8360403180122375, 0.835598349571228, 0.8487783670425415, 0.8631832003593445, 0.8428822755813599, 0.857396

In [None]:
# Play an audio beep. Any audio URL will do.
# The sound is played by evaluating JavaScript through google.colab's output module.
from google.colab import output
output.eval_js('new Audio("https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg").play()')

In [None]:
# predictions = ["hello there general kenobi", "foo bar foobar"]
# references = [["hello there general kenobi", "hello there !"],
#                  ["foo bar foobar", "foo bar foobar"]]
# sacrebleu = evaluate.load("sacrebleu")
# results = sacrebleu.compute(predictions=predictions, 
#                              references=references)
# print(results)

# results = rouge.compute(predictions=predictions, 
#                              references=references)
# print(results)

# results = bertscore.compute(predictions=predictions, 
#                              references=references, lang='eng')
# print(results)

# results = chrf.compute(predictions=predictions, 
#                              references=references)
# print(results)

In [None]:
# NOTE(review): this path names a GPT-2 / daily_dialog checkpoint, while train() in this
# notebook writes files called best_ckpt_epoch=..._valid_loss=....ckpt into ckpt_dir —
# confirm the file exists before running this cell.
files.download('/content/gpt2_epoch6_daily_dialog.ckpt') 


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>