# Research Project

## Import required libraries

In [None]:
from datasets import list_datasets, load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, get_polynomial_decay_schedule_with_warmup
from torch.nn import functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from itertools import chain
import torch
import math
import numpy as np
import random

import os, sys

# NOTE(review): clearing CURL_CA_BUNDLE looks like the common workaround for
# switching off TLS certificate verification on HTTP downloads (e.g. behind a
# corporate proxy) — confirm it is still required, since it weakens the
# security of every download made by this kernel.
os.environ['CURL_CA_BUNDLE'] = ''

import urllib3
# Silence urllib3's InsecureRequestWarning messages.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


## Data Loading

In [2]:
# Load the three DailyDialog splits through the local `daily_dialog.py` script.
train_dataset, validation_dataset, test_dataset = (
    load_dataset('daily_dialog.py', split=split_name)
    for split_name in ('train', 'validation', 'test')
)

Found cached dataset daily_dialog (C:/Users/aramosvela/.cache/huggingface/datasets/daily_dialog/default/1.0.0/1d0a58c7f2a4dab5ed9d01dbde8e55e0058e589ab81fce5c2df929ea810eabcd)
Found cached dataset daily_dialog (C:/Users/aramosvela/.cache/huggingface/datasets/daily_dialog/default/1.0.0/1d0a58c7f2a4dab5ed9d01dbde8e55e0058e589ab81fce5c2df929ea810eabcd)
Found cached dataset daily_dialog (C:/Users/aramosvela/.cache/huggingface/datasets/daily_dialog/default/1.0.0/1d0a58c7f2a4dab5ed9d01dbde8e55e0058e589ab81fce5c2df929ea810eabcd)


### GPT-2

In [4]:
# Pretrained GPT-2 tokenizer; the cells below use it to tokenize every utterance.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [5]:
# Prefix character that marks a leading space in the tokens handled below.
space = 'Ġ'
# Typographic apostrophe; callers replace it with quotes[1] before tokenizing.
pre_quote = '’'
# Tokens after which the following token gets an upper-cased first letter.
end_marks = ['.', ',', '?', '!', '...']
quotes = ['"', '\'']
# Contraction suffixes that are glued to the preceding apostrophe.
abbreviations = ['s', 'd', 't', 'm', 're', 'll', 've', 'S', 'D', 'T', 'M', 'Re', 'Ll', 'Ve']


def _upper_first(text):
    """Upper-case only the first character of `text`, leaving the rest untouched."""
    return text[:1].upper() + text[1:]


def process_token_list(token_list):
    """Normalise spacing and capitalisation of a tokenized utterance.

    The following clean-ups are applied to a copy of `token_list`:

    * the first token gets an upper-cased first letter;
    * the space prefix is removed from end marks and contraction suffixes,
      and from an apostrophe that is followed by a contraction suffix;
    * space-prefixed quote tokens are handled in pairs: after an opening quote
      the next token loses its space prefix, and a closing quote loses its own;
    * the token following an end mark is space-prefixed and gets an
      upper-cased first letter (the rest of the token keeps its casing, so
      acronyms such as "USA" are preserved);
    * empty tokens and bare space tokens are dropped, and `end_marks[0]` is
      appended when the result does not already end with an end mark.

    Args:
        token_list: list of token strings for one utterance.

    Returns:
        A new list of tokens. An empty list is returned when the input is
        empty or contains nothing but empty/space tokens.
    """
    if not token_list:
        return []

    # Work on a copy so the caller's list is never modified.
    token_list = list(token_list)
    token_list[0] = _upper_first(token_list[0])

    last_index = len(token_list) - 1
    quote_count = 0
    for i, token in enumerate(token_list):
        if not token:
            # Emptied by the opening-quote step below (or empty on input);
            # such tokens are filtered out at the end.
            continue

        has_next = i < last_index
        body = token[1:]

        if space in token:
            if body in end_marks or body in abbreviations:
                token_list[i] = body

            # Apostrophe directly followed by a contraction suffix: glue it on.
            if body == quotes[1] and has_next:
                following = token_list[i+1]
                if following in abbreviations or (
                    following.startswith(space) and following[1:] in abbreviations
                ):
                    token_list[i] = body

        if token.startswith(space) and body in quotes:
            if quote_count % 2 == 1:
                # Closing quote: attach it to the preceding token.
                token_list[i] = body
                quote_count = 0
            else:
                # Opening quote: attach the following token to it.
                if has_next and token_list[i+1].startswith(space):
                    token_list[i+1] = token_list[i+1][1:]
                quote_count += 1

        if (token in end_marks or body in end_marks) and has_next:
            following = token_list[i+1]
            if following:
                word = following[1:] if following.startswith(space) else following
                token_list[i+1] = space + _upper_first(word)

    new_token_list = [token for token in token_list if token and token != space]
    if new_token_list and new_token_list[-1] not in end_marks:
        new_token_list.append(end_marks[0])

    return new_token_list

In [6]:
def load_daily(tokenizer, train_frac=0.85):
    """Load DailyDialog, clean every utterance and re-split it.

    The train, validation and test dialogues of the `daily_dialog` dataset are
    concatenated (in that order). Every utterance is stripped, has `pre_quote`
    replaced by `quotes[1]`, is tokenized, passed through
    `process_token_list`, and converted back to a string. The first
    `train_frac` share of the dialogues becomes the training part and the
    remainder the validation part.

    Args:
        tokenizer: object providing `tokenize` and `convert_tokens_to_string`.
        train_frac: fraction of dialogues assigned to the training part.

    Returns:
        Tuple `(train_dialogues, valid_dialogues, train_utter_num,
        valid_utter_num)`, where the last two are the utterance counts of the
        two parts.
    """
    dataset = load_dataset('daily_dialog')
    total_dialogues = (
        dataset['train']['dialog']
        + dataset['validation']['dialog']
        + dataset['test']['dialog']
    )

    for idx, dialogue in enumerate(tqdm(total_dialogues)):
        cleaned = []
        for utter in dialogue:
            raw_tokens = tokenizer.tokenize(utter.strip().replace(pre_quote, quotes[1]))
            cleaned.append(tokenizer.convert_tokens_to_string(process_token_list(raw_tokens)))
        total_dialogues[idx] = cleaned

    split_at = int(len(total_dialogues) * train_frac)
    train_dialogues = total_dialogues[:split_at]
    valid_dialogues = total_dialogues[split_at:]

    train_utter_num = sum(len(dialogue) for dialogue in train_dialogues)
    valid_utter_num = sum(len(dialogue) for dialogue in valid_dialogues)

    return train_dialogues, valid_dialogues, train_utter_num, valid_utter_num

In [7]:
# Load, clean and split the corpus.
part_train_dialogues, part_valid_dialogues, part_num_train, part_num_valid = load_daily(tokenizer)

# Only one corpus is loaded, so the running totals simply mirror its parts.
train_dialogues = list(part_train_dialogues)
valid_dialogues = list(part_valid_dialogues)
num_train = part_num_train
num_valid = part_num_valid

print(f"The number of train dialogues: {len(part_train_dialogues)}")
print(f"The number of valid dialogues: {len(part_valid_dialogues)}")    
print(f"The number of train utterances: {part_num_train}")    
print(f"The number of valid utterances: {part_num_valid}")

Found cached dataset daily_dialog (C:/Users/aramosvela/.cache/huggingface/datasets/daily_dialog/default/1.0.0/1d0a58c7f2a4dab5ed9d01dbde8e55e0058e589ab81fce5c2df929ea810eabcd)
100%|██████████| 3/3 [00:00<00:00, 428.53it/s]
100%|██████████| 13118/13118 [00:08<00:00, 1540.94it/s]

The number of train dialogues: 11150
The number of valid dialogues: 1968
The number of train utterances: 87467
The number of valid utterances: 15512





In [8]:
# Extract ids (input_ids, token_ids) from the processed text

def extract_ids(dialogues):
    """Convert every utterance of every dialogue into a list of token ids.

    Uses the module-level `tokenizer` to tokenize each utterance and map the
    tokens to ids.

    Args:
        dialogues: iterable of dialogues, each a list of utterance strings.

    Returns:
        A list with one entry per dialogue; each entry is a list holding the
        token-id list of each utterance.
    """
    ids = [
        [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(utter)) for utter in dialogue]
        for dialogue in tqdm(dialogues)
    ]

    assert len(ids) == len(dialogues)

    return ids

# Token ids for the training part first, then for the validation part.
train_ids, valid_ids = map(extract_ids, (train_dialogues, valid_dialogues))
    

100%|██████████| 11150/11150 [00:05<00:00, 1909.58it/s]
100%|██████████| 1968/1968 [00:01<00:00, 1960.16it/s]


In [50]:
# Parameters

# Special tokens: one per speaker plus a beginning-of-sequence marker.
sp1_token, sp2_token, bos_token = '<sp1>', '<sp2>', '<bos>'

# Context limits used when building samples and during inference.
max_turns = 5
max_len = 1024

# Reproducibility and device selection.
seed = 0
gpu = 0

# Optimisation settings.
lr = 2e-5
batch_size = 8
num_workers = 0
num_epochs = 10
warmup_ratio = 0.1
last_epoch = 0

# Interactive inference settings.
end_command = 'Abort!'
top_p = 0.8

# Directory that receives the checkpoints.
ckpt_dir = 'saved_models'

# Tokeniser: register the special tokens, then look up the ids that are needed.
eos_token = tokenizer.eos_token
special_tokens = dict(
    bos_token=bos_token,
    additional_special_tokens=[sp1_token, sp2_token],
)
num_new_tokens = tokenizer.add_special_tokens(special_tokens)

vocab = tokenizer.get_vocab()
vocab_size = len(vocab)
bos_id, eos_id, sp1_id, sp2_id = (
    vocab[tok] for tok in (bos_token, eos_token, sp1_token, sp2_token)
)

In [51]:
class CustomDataset(Dataset):
    """Dataset of (input_ids, token_type_ids, labels) samples built from dialogues.

    Every utterance is prefixed with a speaker id (`sp1_id` for even turns,
    `sp2_id` for odd turns). For each `sp2_id` turn, at most one sample is
    created: the longest window of up to `max_turns` consecutive turns ending
    at that turn whose encoded length does not exceed `max_len`. Labels are
    -100 everywhere except on the tokens of the final turn (without its
    speaker id) and the closing `eos_id`.
    """

    def __init__(self, dials):
        """Build all samples.

        Args:
            dials: iterable of dialogues, each a list of token-id lists.
        """
        self.input_ids = []  # (N, L)
        self.token_type_ids = []  # (N, L)
        self.labels = []  # (N, L)

        for dial in tqdm(dials):
            # Prefix each utterance with the id of the speaker of that turn.
            hists = [
                [sp1_id if turn % 2 == 0 else sp2_id] + utter
                for turn, utter in enumerate(dial)
            ]

            for h, target in enumerate(hists):
                if target[0] != sp2_id:
                    continue

                # Try the widest window first and shrink until it fits.
                first = max(0, h - max_turns+1)
                for s in range(first, h):
                    contexts = hists[s:h+1]
                    input_ids = [bos_id, *chain.from_iterable(contexts), eos_id]
                    if len(input_ids) > max_len:
                        continue

                    start_sp_id, next_sp_id = contexts[0][0], contexts[1][0]
                    segments = [
                        [start_sp_id if c % 2 == 0 else next_sp_id] * len(ctx)
                        for c, ctx in enumerate(contexts)
                    ]
                    assert segments[-1][0] == sp2_id
                    token_type_ids = [start_sp_id, *chain.from_iterable(segments), sp2_id]
                    assert len(input_ids) == len(token_type_ids)

                    # Mask every history turn; keep only the reply tokens.
                    final = len(contexts) - 1
                    per_turn_labels = [
                        [-100] + ctx[1:] if c == final else [-100] * len(ctx)
                        for c, ctx in enumerate(contexts)
                    ]
                    assert per_turn_labels[-1][1:] == contexts[-1][1:]
                    labels = [-100, *chain.from_iterable(per_turn_labels), eos_id]
                    assert len(input_ids) == len(labels)

                    self.input_ids.append(input_ids)
                    self.token_type_ids.append(token_type_ids)
                    self.labels.append(labels)

                    break

    def __len__(self):
        """Return the number of samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the `(input_ids, token_type_ids, labels)` lists of sample `idx`."""
        return self.input_ids[idx], self.token_type_ids[idx], self.labels[idx]
    
    
class PadCollate():
    """Collate helper that right-pads variable-length samples into batch tensors."""

    def __init__(self, eos_id):
        """Store the id used to pad `input_ids` and `token_type_ids`."""
        self.eos_id = eos_id

    def pad_collate(self, batch):
        """Pad a list of samples into three batch-first LongTensors.

        Args:
            batch: sequence of `(input_ids, token_type_ids, labels)` triples
                of integer lists.

        Returns:
            Tuple `(input_ids, token_type_ids, labels)`; the first two are
            padded with `self.eos_id`, the labels with -100.
        """
        pad = torch.nn.utils.rnn.pad_sequence

        input_ids = [torch.LongTensor(sample[0]) for sample in batch]
        token_type_ids = [torch.LongTensor(sample[1]) for sample in batch]
        labels = [torch.LongTensor(sample[2]) for sample in batch]

        return (
            pad(input_ids, batch_first=True, padding_value=self.eos_id),
            pad(token_type_ids, batch_first=True, padding_value=self.eos_id),
            pad(labels, batch_first=True, padding_value=-100),
        )

In [52]:
def fix_seed(seed):
    """Seed NumPy, PyTorch (CPU and all CUDA devices) and `random` with `seed`."""
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

In [53]:
# Load model

# Use the configured GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")

print("Loading the model...")
fix_seed(seed)
model = GPT2LMHeadModel.from_pretrained('gpt2')
model = model.to(device)
# Grow the embedding matrix to cover the special tokens added to the tokenizer.
model.resize_token_embeddings(vocab_size)

# Never allow sequences longer than the model's context size.
max_len = min(max_len, model.config.n_ctx)

Loading the model...


In [54]:
# Optimizer and best-validation-loss tracker
print("Loading the optimizer...")
optim = torch.optim.AdamW(model.parameters(), lr=lr)
best_loss = sys.float_info.max

# Train & valid datasets
print("Loading train & valid data...")

train_set, valid_set = CustomDataset(train_ids), CustomDataset(valid_ids)

ppd = PadCollate(eos_id=eos_id)

# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(
    collate_fn=ppd.pad_collate,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_set, **loader_kwargs)

# Total number of optimizer steps and the warm-up share of them
num_batches = len(train_loader)
total_train_steps = num_epochs * num_batches
warmup_steps = int(warmup_ratio * total_train_steps)

sched = get_polynomial_decay_schedule_with_warmup(
    optim,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_train_steps,
    power=2,
)

# TensorBoard writer used by the training loop
writer = SummaryWriter()

Loading the optimizer...
Loading train & valid data...


100%|██████████| 11150/11150 [00:00<00:00, 15379.39it/s]
100%|██████████| 1968/1968 [00:00<00:00, 23428.87it/s]


In [55]:
def validation():
    """Evaluate the module-level `model` on `valid_loader`.

    Runs every validation batch without gradient tracking and collects the
    loss and its exponential (perplexity) per batch.

    Returns:
        Tuple `(valid_loss, valid_ppl)`: the means of the per-batch losses and
        per-batch perplexities. Infinite batch perplexities are replaced by
        1e+8, and so is a NaN mean perplexity.
    """
    print("Validation processing...")
    model.eval()

    batch_losses = []
    batch_ppls = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            input_ids, token_type_ids, labels = (tensor.to(device) for tensor in batch)

            outputs = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                labels=labels,
            )
            batch_loss = outputs[0].detach()

            batch_losses.append(batch_loss)
            batch_ppls.append(torch.exp(batch_loss))

        valid_losses = [batch_loss.item() for batch_loss in batch_losses]
        valid_ppls = [
            1e+8 if math.isinf(batch_ppl.item()) else batch_ppl.item()
            for batch_ppl in batch_ppls
        ]
        valid_loss = np.mean(valid_losses)
        valid_ppl = np.mean(valid_ppls)

        if math.isnan(valid_ppl):
            valid_ppl = 1e+8

    return valid_loss, valid_ppl

In [56]:
# Display the current module-level value of `last_epoch`.
last_epoch

0

In [59]:
def train():
    """Fine-tune the module-level `model` for `num_epochs` epochs.

    Each epoch runs one pass over `train_loader` (optimizer and scheduler are
    stepped once per batch), logs the mean training loss/perplexity to the
    TensorBoard `writer`, runs `validation()`, and saves a checkpoint to
    `ckpt_dir` whenever the validation loss improves on the module-level
    `best_loss`.

    Side effects:
        Updates the module-level `best_loss`, creates `ckpt_dir` if it does
        not exist, writes checkpoint files and TensorBoard scalars.
    """
    # `best_loss` lives at module level. Without this declaration the
    # assignment further down makes the name local to the function and the
    # comparison `valid_loss < best_loss` raises UnboundLocalError.
    global best_loss

    fix_seed(seed)  # Fix seed before training
    print("Training starts.")

    # torch.save does not create missing parent directories.
    os.makedirs(ckpt_dir, exist_ok=True)

    last_epoch = 0

    start_epoch = last_epoch + 1
    for epoch in range(start_epoch, start_epoch+num_epochs):
        model.train()

        print("#"*50 + f"Epoch: {epoch}" + "#"*50)
        train_losses = []
        train_ppls = []
        for batch in tqdm(train_loader):
            input_ids, token_type_ids, labels = batch
            input_ids, token_type_ids, labels = \
                input_ids.to(device), token_type_ids.to(device), labels.to(device)

            outputs = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                labels=labels
            )

            # With labels supplied, the first output is the loss.
            loss = outputs[0]

            optim.zero_grad()
            loss.backward()
            optim.step()
            sched.step()

            train_losses.append(loss.detach())
            ppl = torch.exp(loss.detach())
            train_ppls.append(ppl)

        train_losses = [loss.item() for loss in train_losses]
        # Cap overflowing perplexities so the mean stays finite.
        train_ppls = [ppl.item() if not math.isinf(ppl.item()) else 1e+8 for ppl in train_ppls]
        train_loss = np.mean(train_losses)
        train_ppl = np.mean(train_ppls)
        print(f"Train loss: {train_loss} || Train perplexity: {train_ppl}")

        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("PPL/train", train_ppl, epoch)

        last_epoch += 1

        valid_loss, valid_ppl = validation()

        if valid_loss < best_loss:
            best_loss = valid_loss
            state_dict = {
                'model_state_dict': model.state_dict(),
                'optim_state_dict': optim.state_dict(),
                'sched_state_dict': sched.state_dict(),
                'loss': best_loss,
                'epoch': last_epoch
            }

            ckpt_path = f"{ckpt_dir}/best_ckpt_epoch={epoch}_valid_loss={round(best_loss, 4)}.ckpt"
            torch.save(state_dict, ckpt_path)
            print("*"*10 + "Current best checkpoint is saved." + "*"*10)
            print(ckpt_path)

        print(f"Best valid loss: {best_loss}")
        print(f"Valid loss: {valid_loss} || Valid perplexity: {valid_ppl}")

        writer.add_scalar("Loss/valid", valid_loss, epoch)
        writer.add_scalar("PPL/valid", valid_ppl, epoch)

        writer.add_scalars("Losses", {
            'train': train_loss, 
            'valid': valid_loss,
        }, epoch)
        writer.add_scalars("PPLs", {
            'train': train_ppl,
            'valid': valid_ppl,
        }, epoch)

    print("Training finished!")

In [60]:
train()

Training starts.
##################################################Epoch: 1##################################################


 13%|█▎        | 665/5223 [1:19:47<9:06:53,  7.20s/it] 


KeyboardInterrupt: 

In [None]:
def infer():
    """Run an interactive chat loop with the module-level `model`.

    Reads user utterances from standard input until `end_command` is typed.
    Each user turn is encoded with the `sp1_id` prefix; the history is trimmed
    so that it holds fewer than `max_turns` turns; a reply is sampled with
    `nucleus_sampling`, printed, and stored in the history with the `sp2_id`
    prefix.
    """
    print("Let's start!")
    print(f"If you want to quit the conversation, please type \"{end_command}\".")
    model.eval()
    fix_seed(seed)

    with torch.no_grad():
        input_hists = []

        while True:
            utter = input("You: ")
            if utter == end_command:
                print("Bot: Good bye.")
                break

            input_hists.append([sp1_id] + tokenizer.encode(utter))

            # Drop the oldest turns so there is room for the upcoming reply.
            if len(input_hists) >= max_turns:
                num_exceeded = len(input_hists) - max_turns + 1
                input_hists = input_hists[num_exceeded:]

            # The sequence ends with sp2_id to prompt the bot's turn.
            flat_ids = [bos_id, *chain.from_iterable(input_hists), sp2_id]

            start_sp_id = input_hists[0][0]
            next_sp_id = sp1_id if start_sp_id == sp2_id else sp2_id
            assert start_sp_id != next_sp_id

            # Speaker ids alternate turn by turn, starting with the oldest kept turn.
            per_turn_types = [
                [start_sp_id if h % 2 == 0 else next_sp_id] * len(hist)
                for h, hist in enumerate(input_hists)
            ]
            assert len(per_turn_types) == len(input_hists)
            flat_types = [start_sp_id, *chain.from_iterable(per_turn_types), sp2_id]
            assert len(flat_ids) == len(flat_types)
            input_len = len(flat_ids)

            input_ids = torch.LongTensor(flat_ids).unsqueeze(0).to(device)
            token_type_ids = torch.LongTensor(flat_types).unsqueeze(0).to(device)

            output_ids = nucleus_sampling(input_ids, token_type_ids, input_len)
            res = tokenizer.decode(output_ids, skip_special_tokens=True)

            print(f"Bot: {res}")
            input_hists.append([sp2_id] + tokenizer.encode(res))
                
def nucleus_sampling(input_ids, token_type_ids, input_len):
    """Sample a reply token by token with top-p (nucleus) filtering.

    At every position from `input_len` up to `max_len`, the model is run on
    the whole sequence so far, the distribution at the last position is cut
    down to the smallest set of most-probable tokens whose cumulative
    probability exceeds `top_p`, renormalised, and sampled. Generation stops
    once `eos_id` is drawn or `max_len` is reached.

    Args:
        input_ids: LongTensor holding the prompt token ids, batch size 1.
        token_type_ids: LongTensor of speaker ids, same shape as `input_ids`.
        input_len: number of prompt tokens.

    Returns:
        List of sampled token ids (including `eos_id` when it was drawn).
    """
    output_ids = []
    for pos in range(input_len, max_len):
        logits = model(input_ids=input_ids, token_type_ids=token_type_ids)[0]
        next_probs = F.softmax(logits[:, pos-1], dim=-1)  # (1, V)

        sorted_probs, sorted_idxs = torch.sort(next_probs, descending=True)
        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)  # (1, V)

        # Shift the mask right by one so the token that crosses the top_p
        # threshold is kept, and always keep the most probable token.
        idx_remove = cumsum_probs > top_p
        idx_remove[:, 1:] = idx_remove[:, :-1].clone()
        idx_remove[:, 0] = False
        sorted_probs[idx_remove] = 0.0
        sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)  # (1, V)

        # Put the filtered probabilities back into vocabulary order.
        filtered_probs = torch.zeros(next_probs.shape, device=device).scatter_(-1, sorted_idxs, sorted_probs)  # (1, V)
        idx = torch.multinomial(filtered_probs, 1)  # (1, 1)

        idx_item = idx.squeeze(-1).squeeze(-1).item()
        output_ids.append(idx_item)

        if idx_item == eos_id:
            break

        # Extend the sequence; generated tokens belong to speaker 2.
        input_ids = torch.cat((input_ids, idx), dim=-1)
        sp2_type = torch.LongTensor([[sp2_id]]).to(device)
        token_type_ids = torch.cat((token_type_ids, sp2_type), dim=-1)
        assert input_ids.shape == token_type_ids.shape

    return output_ids