# Fine-tuning a GPT-2 model with the Google AirDialogue dataset

Here we fine-tune the pre-trained GPT-2 model on an open dataset so that it generates replies in conversations between a client and an agent in a ticket-booking setting. The goal is to estimate how much work it takes to obtain a customized conversational model.

In [1]:
import time
import datetime
import random

import torch


from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer

import numpy as np



## Build the dataset representation from given files

Note: This code should be moved to a library as it will be useful for other training systems.

In [2]:
import json
import os
import gc

from torch.utils.data import Dataset
from torch.utils.data import random_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from transformers import GPT2Tokenizer

DATADIR = "/home/martin/data/airdialogue/airdialogue/"

SPECIAL_TOKENS_DICT = { "additional_special_tokens": ["[USR]", "[SYS]", "[SEP]"] }
# Get the pretrained tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(
    'gpt2',
    add_special_tokens=True,
    bos_token='[BOS]',
    eos_token='[EOS]',
    pad_token='[PAD]'
)
tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)


class AirdialogDataset(Dataset):
    """Google AirDialogue dataset as chunks of customer/agent utterance pairs.
    
    As we intend to train a conversational model, each sample is a concatenation
    of alternating customer and agent utterances.
    """
    
    def __init__(self, path, tokenizer, gpt2_type="gpt2", max_size=-1):
    
        self.tokenizer = tokenizer 
        self.input_ids = []
        self.attn_masks = []
        self.max_size = max_size
        
        # Collect all formatted chunks first; max_length is derived from them below
        sentences = list(self.get_sentences(path, max_size))

        self.max_length = max(len(tokenizer.encode(s)) for s in sentences)
        print(self.max_length)

        for sentence in sentences:  
            encodings = tokenizer(sentence, max_length=self.max_length, truncation=True, padding="max_length")

            self.input_ids.append(torch.tensor(encodings['input_ids']))
            self.attn_masks.append(torch.tensor(encodings['attention_mask']))
    
    @classmethod
    def get_sentences(cls, path, max_size=-1):
        """Yield formatted training strings, one per chunk of a stored dialogue.
        
        Version 1: concatenate all utterances of a chunk into one string and see how well this works.
        """
        def _parse_dialogue_chunk(utterances:list) -> str:
            for i, utt in enumerate(utterances):
                utterances[i] = utt.replace("customer:", "[USR]")\
                .replace("agent:", "[SYS]")\
                .replace(". ", ". [SEP] ")\
                .replace("? ", "? [SEP] ")\
                .replace("! ", "! [SEP] ")
            return " [EOS] ".join(utterances)
        
        def _split_utts(utterances, size):
            out = []
            count = 0
            for usr, sys in zip(utterances[::2], utterances[1::2]):
                add = len(usr.split(" ")) + len(sys.split(" "))
                if count + add < size:
                    out += [usr, sys]
                    count += add
                else:
                    if out:
                        # Avoid yielding an empty chunk when the first pair alone exceeds `size`
                        yield out
                    out = [usr, sys]
                    count = add
            
            if len(out) > 1:
                # At least two utterances on the end, otherwise no value
                yield out
        
        with open(path, "r") as f:
            count = 0
            for line in f:
                content = json.loads(line)
                dialogue = content["dialogue"]
                for chunk in _split_utts(dialogue, 80):
                    res = "[BOS] " + _parse_dialogue_chunk(chunk) + " [EOS]"
                    if count < 10:
                        print(res)
                    yield res
                    count += 1
                if max_size > 0 and count > max_size:
                    break
                
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx]   


dataset = AirdialogDataset(os.path.join(DATADIR, "train_data.json"), tokenizer, max_size=100000)
testset = AirdialogDataset(os.path.join(DATADIR, "dev_data.json"), tokenizer, max_size=2000)

# Split into training and validation sets
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

train_set, val_set = random_split(dataset, [train_size, val_size])
print("train_size :",train_size)
print("val_size   :",val_size)

#define dataloaders
train_dataloader = DataLoader(train_set,  sampler = RandomSampler(train_set), batch_size = 20)
validation_dataloader = DataLoader(val_set, sampler = SequentialSampler(val_set), batch_size = 20)

gc.collect()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


[BOS] [USR] Hello. [EOS] [SYS] Hello. [SEP] How may I help you? [EOS] [USR] Can you help me to change my recent reservation because my trip dates are got postponed? [EOS] [SYS] I will help you with that please share your name to proceed further? [EOS] [USR] Edward hall here. [EOS] [SYS] Please wait for a while. [EOS] [USR] Sure, take your own time. [EOS] [SYS] There is no active reservation found under your name to amend it. [EOS]
[BOS] [USR] That's ok, thank you for checking. [EOS] [SYS] Thank you for choosing us. [EOS]
[BOS] [USR] HI. [EOS] [SYS] Hello. [SEP] How may I be of your address? [EOS] [USR] I want to book a flight ticket to Las Vegas with price under 1000. [SEP] Can you please help me with it? [EOS] [SYS] Sure, can I know your connection limit? [EOS] [USR] I need a single connection. [EOS] [SYS] Please let me know your boarding and landing points. [EOS] [USR] Airport codes are HOU-LAS. [EOS] [SYS] Kindly share your planned travelling dates to proceed further. [EOS]
[BOS] [U

0
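To make the chunking easier to follow, here is a small standalone sketch of what `get_sentences` does to a single dialogue. The raw utterances are made up for illustration (shaped after the printed samples above), and `format_chunk` and `split_pairs` are hypothetical helper names, not part of the class. Unlike the original helper, this version never emits an empty chunk.

```python
def format_chunk(utterances):
    """Turn raw 'customer:'/'agent:' utterances into one training string."""
    formatted = [
        utt.replace("customer:", "[USR]")
        .replace("agent:", "[SYS]")
        .replace(". ", ". [SEP] ")
        .replace("? ", "? [SEP] ")
        .replace("! ", "! [SEP] ")
        for utt in utterances
    ]
    return "[BOS] " + " [EOS] ".join(formatted) + " [EOS]"


def split_pairs(utterances, size):
    """Group (customer, agent) pairs into chunks, starting a new chunk
    once the word budget `size` would be reached."""
    chunk, count = [], 0
    for usr, agent in zip(utterances[::2], utterances[1::2]):
        add = len(usr.split(" ")) + len(agent.split(" "))
        if chunk and count + add >= size:
            yield chunk
            chunk, count = [], 0
        chunk += [usr, agent]
        count += add
    if chunk:
        yield chunk


# Made-up dialogue in the assumed raw format (hypothetical data)
dialogue = [
    "customer: Hello.",
    "agent: Hello. How may I help you?",
    "customer: I want to book a flight.",
    "agent: Sure, please share your name.",
]

# A tiny budget of 10 words forces a split after the first pair
samples = [format_chunk(chunk) for chunk in split_pairs(dialogue, 10)]
for sample in samples:
    print(sample)
```

With the budget of 80 words used in the notebook, this short dialogue would stay in a single chunk.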

## Declare the model

Here we resize the token embedding matrix so that it also covers the special tokens added to the tokenizer.

In [3]:
# Create default config
configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)
# Load pretrained gpt2
model = GPT2LMHeadModel.from_pretrained("gpt2", config=configuration)
model.resize_token_embeddings(len(tokenizer))

# Create the device and move the model to it before building the optimizer
device = torch.device("cuda")
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

In [4]:
### TODO Add some simple test based on actual data to see the status

In [5]:
# Call the model with one batch of inputs
def process_one_batch(model, batch):
    b_input_ids = batch[0].to(device)
    # Labels equal the inputs; the model shifts them internally for next-token prediction
    b_labels = batch[0].to(device)
    b_masks = batch[1].to(device)
    outputs = model(b_input_ids, attention_mask=b_masks, labels=b_labels)
    return outputs

#do one epoch for training
def train_epoch(model, dataloader):
    t0 = time.monotonic()
    total_train_loss = 0
    model.train()
    for step, batch in enumerate(dataloader):
        
        model.zero_grad()        
        outputs = process_one_batch(model, batch)
        loss = outputs[0]
        batch_loss = loss.item()
        total_train_loss += batch_loss

        loss.backward()
        optimizer.step()
        
    avg_train_loss = total_train_loss / len(dataloader)  
    print("avg_train_loss", avg_train_loss)
    elapsed_time = time.monotonic() - t0
    print("elapsed time for 1 training epoch : ",elapsed_time)

# Run one evaluation epoch
def eval_epoch(model, dataloader):
    t0 = time.monotonic()
    total_eval_loss = 0
    # Switch off dropout while evaluating
    model.eval()
    # Evaluate data for one epoch
    for batch in dataloader:

        with torch.no_grad():
            outputs = process_one_batch(model, batch)
            loss = outputs[0]
            batch_loss = loss.item()
            total_eval_loss += batch_loss

    # Average over the dataloader passed in, not the global validation_dataloader
    avg_val_loss = total_eval_loss / len(dataloader)
    print("avg_val_loss",avg_val_loss) 
    elapsed_time = time.monotonic() - t0
    print("elapsed time for 1 eval epoch : ",elapsed_time)
    
# After every epoch, check whether the generated responses are getting better.
def show_sentences(model):
    model.eval()

    eos = tokenizer.convert_tokens_to_ids(["[EOS]"])[0]
    input_sentences = [
        "[BOS] [USR] Hi. [EOS]",
        "[USR] Could you please book a flight ticket from DEN to DTW as I want to attend a pioneer festival? [EOS]",
        "[USR] My travel dates are June 24 and June 26. [EOS]",
        "[USR] I need a single connecting flight. [EOS]",
        "[USR] No. [EOS]"
    ]
    
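    # NOTE: context is never extended, so each prompt is answered without the earlier turns
    context = ""
    for i, input_seq in enumerate(input_sentences):
        print(f"Input: {input_seq}")
        input_seq = " ".join([context, input_seq])
        input_tkn = tokenizer.encode(input_seq)
        
        # Reuse the token ids computed above instead of encoding twice
        inputs = torch.tensor(input_tkn).unsqueeze(0)
        inputs = inputs.to(device)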
        
        sample_output = model.generate(
            inputs, 
            do_sample=True,   
            top_k=30, 
            max_length = 100,
            top_p=0.90, 
            num_return_sequences=1
        )[0]
        pruned = sample_output[len(input_tkn):]
        # Cut the response after the first [EOS] token
        for pos, tkn in enumerate(pruned):
            if tkn == eos:
                pruned = pruned[:pos+1]
                break
        print(f"Resp: {tokenizer.decode(pruned)}")
        #print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=False)))
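`process_one_batch` uses the padded `input_ids` as labels, so the loss also covers the `[PAD]` positions. One possible refinement, which was not used for the numbers in this notebook, is to set the labels of padded positions to -100, the value that PyTorch's cross-entropy loss ignores by default. The sketch below shows the idea on plain Python lists; `mask_padding_labels` and the token ids are hypothetical.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def mask_padding_labels(input_ids, attention_mask, ignore_index=IGNORE_INDEX):
    """Copy input_ids as labels, replacing padded positions with ignore_index."""
    return [
        [tok if keep == 1 else ignore_index for tok, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, attention_mask)
    ]


# Hypothetical batch of two sequences; the second one ends with two pad tokens
input_ids = [
    [50257, 15496, 13, 50258],
    [50257, 17250, 50259, 50259],
]
attention_mask = [
    [1, 1, 1, 1],
    [1, 1, 0, 0],
]

labels = mask_padding_labels(input_ids, attention_mask)
print(labels)
```

On tensors the same effect could be had with `b_labels = b_input_ids.clone()` followed by `b_labels[b_masks == 0] = -100`. Losses computed this way are not directly comparable to the ones logged in this notebook, because those also count the padded positions.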

In [6]:
show_sentences(model)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Input: [BOS] [USR] Hi. [EOS]


Resp: [SYS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [SYS] [BOS] [SYS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [USR] [EOS]
Input: [USR] Could you please book a flight ticket from DEN to DTW as I want to attend a pioneer festival? [EOS]


Resp: [BOS] [BOS] [BOS] [BOS] [BOS] [SYS] [USR] [SYS] [SYS] [BOS] [BOS] [BOS] [SYS] [SYS] [USR] [SYS] [USR] [BOS] [SYS] [SYS] [SYS] [USR] [SYS] [BOS] [SYS] [SYS] [USR] [BOS] [SYS] [EOS]
Input: [USR] My travel dates are June 24 and June 26. [EOS]


Resp: [BOS] [BOS] [SYS] [SYS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [SYS] [USR] [SYS] [USR] [USR] [BOS] [BOS] [SYS] [SYS] [EOS]
Input: [USR] I need a single connecting flight. [EOS]


Resp: [SYS] [BOS] [BOS] [BOS] [SYS] [BOS] [SYS] [BOS] [BOS] [SYS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [USR] [BOS] [BOS] [BOS] [SYS] [SYS] [BOS] [BOS] [USR] [BOS] [USR] [BOS] [BOS] [USR] [EOS]
Input: [USR] No. [EOS]
Resp: [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [SYS] [BOS] [BOS] [BOS] [BOS] [BOS] [SYS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [BOS] [EOS]
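Note that `show_sentences` answers every prompt in isolation, because `context` stays empty for the whole loop. A multi-turn variant could carry earlier turns forward, roughly as sketched below. This is only an illustration under assumptions: `run_dialogue` is a hypothetical helper, `respond` is a stand-in for the tokenize, `model.generate` and decode steps, and the 80-word budget simply mirrors the chunk size used when building the dataset.

```python
def run_dialogue(user_turns, respond, max_context_words=80):
    """Feed user turns one by one, carrying earlier turns forward as context.

    `respond` is a hypothetical callable that takes the prompt string and
    returns the reply string.
    """
    history = []
    replies = []
    for turn in user_turns:
        history.append(turn)
        # Drop the oldest entries until the prompt fits the word budget
        while len(history) > 1 and sum(len(h.split()) for h in history) > max_context_words:
            history.pop(0)
        prompt = "[BOS] " + " ".join(history)
        reply = respond(prompt)
        history.append(reply)
        replies.append(reply)
    return replies


def echo_respond(prompt):
    # Stub standing in for the model: reports how many user turns it was shown
    return f"[SYS] Seen {prompt.count('[USR]')} user turn(s). [EOS]"


turns = [
    "[USR] Hi. [EOS]",
    "[USR] I need a single connecting flight. [EOS]",
    "[USR] No. [EOS]",
]
replies = run_dialogue(turns, echo_respond)
for reply in replies:
    print(reply)
```

With the stub, each reply reports one more user turn than the previous one, which shows that the history is actually passed along.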


In [7]:
for i in range(20):
    # Run one train/eval cycle, then print sample responses
    train_epoch(model, train_dataloader)
    eval_epoch(model, validation_dataloader)
    show_sentences(model)

avg_train_loss 0.7092560136848026
elapsed time for 1 training epoch :  1866.385577672


avg_val_loss 0.5170431025251895
elapsed time for 1 eval epoch :  69.78717359499979
Input: [BOS] [USR] Hi. [EOS]


Resp: [SYS] Hello, how can I aid you?  [EOS]
Input: [USR] Could you please book a flight ticket from DEN to DTW as I want to attend a pioneer festival? [EOS]


Resp: [SYS] Please wait for a moment.  [EOS]
Input: [USR] My travel dates are June 24 and June 26. [EOS]


Resp: [SYS] Is there any specific timings to depart and return?  [EOS]
Input: [USR] I need a single connecting flight. [EOS]


Resp: [SYS] May I know your airline preference?  [EOS]
Input: [USR] No. [EOS]
Resp: [SYS] We found 1001 Frontier Airlines with single connection with economy class and flight number is 1026 with economy class. [SEP] Shall I proceed for booking?  [EOS]
avg_train_loss 0.49315162052710854
elapsed time for 1 training epoch :  1871.693765811


avg_val_loss 0.47952845704531716
elapsed time for 1 eval epoch :  69.81591926899955
Input: [BOS] [USR] Hi. [EOS]


Resp: [SYS] Hello. [SEP] How can I help you today?  [EOS]
Input: [USR] Could you please book a flight ticket from DEN to DTW as I want to attend a pioneer festival? [EOS]


Resp: [SYS] May I know your trip dates?  [EOS]
Input: [USR] My travel dates are June 24 and June 26. [EOS]


Resp: [SYS] May I know your preferred boarding and landing airport codes?  [EOS]
Input: [USR] I need a single connecting flight. [EOS]


Resp: [SYS] May I know your good name please?  [EOS]
Input: [USR] No. [EOS]
Resp: [SYS] Thank you for waiting. [SEP] But I am sorry to say that there is no flight available in this route as per your travelling details.  [EOS]
avg_train_loss 0.4569603205720584
elapsed time for 1 training epoch :  1874.9408861679995


avg_val_loss 0.4568063116478111
elapsed time for 1 eval epoch :  69.74603020399991
Input: [BOS] [USR] Hi. [EOS]


Resp: [SYS] Hello. [SEP] How may I help you?  [EOS]
Input: [USR] Could you please book a flight ticket from DEN to DTW as I want to attend a pioneer festival? [EOS]


Resp: [SYS] Ok, may I know your planned travel dates please?  [EOS]
Input: [USR] My travel dates are June 24 and June 26. [EOS]


Resp: [SYS] Do you have any preferences for your journey?  [EOS]
Input: [USR] I need a single connecting flight. [EOS]


Resp: [SYS] May I know your name for booking your flight ticket?  [EOS]
Input: [USR] No. [EOS]
Resp: [SYS] Ok, wait for some time I will check and get back to you.  [EOS]
avg_train_loss 0.42884386374553046
elapsed time for 1 training epoch :  1875.3040625780004


avg_val_loss 0.44419189347716387
elapsed time for 1 eval epoch :  69.77165642900036
Input: [BOS] [USR] Hi. [EOS]


Resp: [SYS] Hello. [SEP] How can I assist you?  [EOS]
Input: [USR] Could you please book a flight ticket from DEN to DTW as I want to attend a pioneer festival? [EOS]


Resp: [SYS] Sure, I am happy to help you. [SEP] Please let me know your travelling dates.  [EOS]
Input: [USR] My travel dates are June 24 and June 26. [EOS]


Resp: [SYS] Do you have any other specifications?  [EOS]
Input: [USR] I need a single connecting flight. [EOS]


Resp: [SYS] Do you have any other specifications for this trip?  [EOS]
Input: [USR] No. [EOS]
Resp: [SYS] May I know your good name?  [EOS]
avg_train_loss 0.40657353472709656
elapsed time for 1 training epoch :  1875.5570044809992


avg_val_loss 0.4377269617573706
elapsed time for 1 eval epoch :  69.68527582800016
Input: [BOS] [USR] Hi. [EOS]


Resp: [SYS] Hello, how may I help you?  [EOS]
Input: [USR] Could you please book a flight ticket from DEN to DTW as I want to attend a pioneer festival? [EOS]


Resp: [SYS] Please share your planned travel dates.  [EOS]
Input: [USR] My travel dates are June 24 and June 26. [EOS]


Resp: [SYS] Would you like to travel in a connecting flight or a direct flight?  [EOS]
Input: [USR] I need a single connecting flight. [EOS]


Resp: [SYS] May I know your class preference?  [EOS]
Input: [USR] No. [EOS]
Resp: [SYS] May I know your name please?  [EOS]
avg_train_loss 0.3880218263367812
elapsed time for 1 training epoch :  1875.5898526769997


avg_val_loss 0.43470200664269
elapsed time for 1 eval epoch :  69.78501310899992
Input: [BOS] [USR] Hi. [EOS]


Resp: [SYS] Hello. [SEP] How can I aid you?  [EOS]
Input: [USR] Could you please book a flight ticket from DEN to DTW as I want to attend a pioneer festival? [EOS]


Resp: [SYS] Sure, please specify your connection limit and price limit.  [EOS]
Input: [USR] My travel dates are June 24 and June 26. [EOS]


Resp: [SYS] May I know your airport codes?  [EOS]
Input: [USR] I need a single connecting flight. [EOS]


Resp: [SYS] May I know your name to proceed?  [EOS]
Input: [USR] No. [EOS]
Resp: [SYS] May I know your name?  [EOS]
avg_train_loss 0.37130008898178735
elapsed time for 1 training epoch :  1875.2125434900008


avg_val_loss 0.4365191345443269
elapsed time for 1 eval epoch :  69.69980028700047
Input: [BOS] [USR] Hi. [EOS]


Resp: [SYS] Hello. [SEP] How can I aid you today?  [EOS]
Input: [USR] Could you please book a flight ticket from DEN to DTW as I want to attend a pioneer festival? [EOS]


Resp: [SYS] Sure, may I know your journey dates?  [EOS]
Input: [USR] My travel dates are June 24 and June 26. [EOS]


Resp: [SYS] Do you have any connection limit preference?  [EOS]
Input: [USR] I need a single connecting flight. [EOS]


Resp: [SYS] Please wait for a mean while.  [EOS]
Input: [USR] No. [EOS]
Resp: [SYS] Can I have your good name please?  [EOS]
avg_train_loss 0.35573135691550045
elapsed time for 1 training epoch :  1875.6051464559987


avg_val_loss 0.43829745429004735
elapsed time for 1 eval epoch :  69.72021361599946
Input: [BOS] [USR] Hi. [EOS]


Resp: [SYS] Hello. [SEP] How may I guide you?  [EOS]
Input: [USR] Could you please book a flight ticket from DEN to DTW as I want to attend a pioneer festival? [EOS]


Resp: [SYS] Sure, let me know your journey dates.  [EOS]
Input: [USR] My travel dates are June 24 and June 26. [EOS]


Resp: [SYS] Ok, please wait for a while.  [EOS]
Input: [USR] I need a single connecting flight. [EOS]


Resp: [SYS] Do you have any other specifications?  [EOS]
Input: [USR] No. [EOS]
Resp: [SYS] There is a connecting flight 1002 of UA airlines, can I proceed?  [EOS]
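The log above covers eight epochs. The training loss keeps falling, while the validation loss reaches its minimum at epoch 6 and then rises slightly, which hints at the start of overfitting. A simple patience-based stopping rule could catch this; the sketch below applies one to the logged values (rounded to four decimals). `best_epoch` and `should_stop` are hypothetical helpers and the patience of 2 is an arbitrary choice.

```python
# Average losses per epoch, copied from the training log and rounded
train_loss = [0.7093, 0.4932, 0.4570, 0.4288, 0.4066, 0.3880, 0.3713, 0.3557]
val_loss = [0.5170, 0.4795, 0.4568, 0.4442, 0.4377, 0.4347, 0.4365, 0.4383]


def best_epoch(losses):
    """Return the 1-based epoch with the lowest loss."""
    return min(range(len(losses)), key=losses.__getitem__) + 1


def should_stop(losses, patience=2):
    """True when none of the last `patience` epochs improved on the earlier best."""
    if len(losses) <= patience:
        return False
    return min(losses[-patience:]) >= min(losses[:-patience])


print("best validation epoch:", best_epoch(val_loss))
for epoch in range(1, len(val_loss) + 1):
    if should_stop(val_loss[:epoch]):
        print("stopping rule triggers after epoch", epoch)
        break
```

In a real loop one would also save a checkpoint whenever the validation loss improves, so that the best epoch can be restored afterwards.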


In [8]:
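# Make sure the output directory exists before saving the weights
os.makedirs("output", exist_ok=True)
torch.save(model.state_dict(), "output/GPT2-airdialog.ckpt")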