In [1]:
!pip install datasets
!pip install transformers
!pip install evaluate
!pip install rouge_score bert_score sacrebleu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW
from datasets import load_dataset
import numpy as np
import random
from tqdm import tqdm



In [3]:
dataset_name = 'daily_dialog'

In [4]:
if dataset_name == 'daily_dialog':
    print('Loading ', dataset_name)
    dataset = load_dataset('daily_dialog')
    train_dialogues = dataset['train']['dialog']
    valid_dialogues = dataset['validation']['dialog']
    test_dialogues = dataset['test']['dialog']
elif dataset_name == 'empathetic_dialogues':
    print('Loading ', dataset_name)
    dataset = load_dataset('empathetic_dialogues')
    train_dialogues = dataset['train']
    valid_dialogues = dataset['validation']
    test_dialogues = dataset['test']
else:
    print('No dataset selected')

Loading  daily_dialog




  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

sp1_token = '<sp1>'
sp2_token = '<sp2>'
bos_token = '<bos>'
eos_token = '<eos>'
seed = 0
gpu = 0

#Tokeniser
special_tokens = {'bos_token': bos_token,
                'additional_special_tokens': [sp1_token, sp2_token]}

eos_token = tokenizer.eos_token
num_new_tokens = tokenizer.add_special_tokens(special_tokens)

vocab = tokenizer.get_vocab()
vocab_size = len(vocab)
bos_id = vocab[bos_token]
eos_id = vocab[eos_token]
sp1_id = vocab[sp1_token]
sp2_id = vocab[sp2_token]

lr = 2e-5
batch_size = 8
num_workers = 0
num_epochs = 8
warmup_ratio = 0.1
last_epoch = 0
end_command = 'Quit!'
top_p = 0.8


In [6]:
def preprocess_dialog(dialog, window_size=5):
    instances = []
    
    # response = dialog["dialog"][-1]  # Last utterance as the response

    for i in range(0, len(dialog) - window_size - 1, 2):
        
        window = dialog[i:i+window_size]
        window_context = []
        for j, utterance in enumerate(window):
            speaker = "<sp1>" if j % 2 == 0 else '<sp2>'
            window_context.append(speaker + " " + utterance)

        # Add special tokens for bos, eos
        # window_context.insert(0, "<s>")
        # window_context.append("</s>")

        window_context = ' '.join(window_context)
        response = '<sp2>' + dialog[i+window_size] #+ " </s>"

        # print('window_context: ', type(window_context), window_context)
        # print('response: ', type(response), response)
        


        # print()
        # print('window_context: ', window_context)
        # print('response: ', response)


        # Tokenize the context and response
        input_ids = tokenizer.encode_plus(window_context, padding='max_length', max_length=512, truncation=True , return_tensors="pt")

        decoder_input_ids = tokenizer.encode_plus(response, padding='max_length', max_length=512, truncation=True, return_tensors="pt")

        instance = {
            "input_ids": input_ids["input_ids"].squeeze(0),
            "attention_mask": input_ids["attention_mask"].squeeze(0),
            "decoder_input_ids": decoder_input_ids["input_ids"].squeeze(0),
            "decoder_attention_mask": decoder_input_ids["attention_mask"].squeeze(0)
            
        }

        # print('input_ids: ', instance['input_ids'].shape)
        # print('attention_mask: ', instance['attention_mask'].shape)
        # print('decoder_input_ids: ', instance['decoder_input_ids'].shape)
        # print('decoder_attention_mask: ', instance['decoder_attention_mask'].shape)
        instances.append(instance)
        



    return instances

In [7]:
train_dialogues[0]

['Say , Jim , how about going for a few beers after dinner ? ',
 ' You know that is tempting but is really not good for our fitness . ',
 ' What do you mean ? It will help us to relax . ',
 " Do you really think so ? I don't . It will just make us fat and act silly . Remember last time ? ",
 " I guess you are right.But what shall we do ? I don't feel like sitting at home . ",
 ' I suggest a walk over to the gym where we can play singsong and meet some of our friends . ',
 " That's a good idea . I hear Mary and Sally often go there to play pingpong.Perhaps we can make a foursome with them . ",
 ' Sounds great to me ! If they are willing , we could ask them to go dancing with us.That is excellent exercise and fun , too . ',
 " Good.Let ' s go now . ",
 ' All right . ']

In [8]:
train_instances = []
val_instances = []

for dialog in tqdm(train_dialogues):
    train_instances.extend(preprocess_dialog(dialog))
    # break

for dialog in tqdm(valid_dialogues):
    val_instances.extend(preprocess_dialog(dialog))

100%|██████████| 11118/11118 [00:23<00:00, 473.81it/s]
100%|██████████| 1000/1000 [00:01<00:00, 548.14it/s]


In [9]:
train_instances[0]['input_ids'].shape

torch.Size([512])

In [10]:
train_instances[1]['input_ids'].shape

torch.Size([512])

In [11]:
class DialogueDataset(Dataset):
    def __init__(self, instances):
        self.instances = instances

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        return self.instances[idx]



In [12]:
#Create data loaders
train_dataset = DialogueDataset(train_instances)
val_dataset = DialogueDataset(val_instances)

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [13]:
print(train_dataset.__getitem__(0)['input_ids'].shape)
print(train_dataset.__getitem__(0)['attention_mask'].shape)

print(train_dataset.__getitem__(0)['decoder_input_ids'].shape)
print(train_dataset.__getitem__(0)['decoder_attention_mask'].shape)


torch.Size([512])
torch.Size([512])
torch.Size([512])
torch.Size([512])


In [14]:
print(train_dataset.__getitem__(1)['input_ids'].shape)
print(train_dataset.__getitem__(1)['attention_mask'].shape)

print(train_dataset.__getitem__(1)['decoder_input_ids'].shape)
print(train_dataset.__getitem__(1)['decoder_attention_mask'].shape)

torch.Size([512])
torch.Size([512])
torch.Size([512])
torch.Size([512])


In [15]:
def fix_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)

In [16]:
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu}")
    print('Using GPU')
else:
    device = torch.device("cpu")
    print('Using CPU')

Using GPU


In [17]:
#Define the BART model and optimizer

fix_seed(seed)
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").to(device)
model.resize_token_embeddings(vocab_size)

Embedding(50268, 768)

In [18]:
print("Loading the optimizer...")
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)


Loading the optimizer...


In [26]:
# input_ids[0]
# attention_mask[0]
# decoder_input_ids[0]
# decoder_attention_mask[0]
# decoder_input_ids[0]


tensor([    0, 50267, 33295,  9856,     5,   864,   114,    47,   218,    75,
         2916,    24,  2563,   479,  1437,     2,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1, 

In [None]:
num_epochs=1
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch in tqdm(train_dataloader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        decoder_input_ids = batch["decoder_input_ids"].to(device)
        decoder_attention_mask = batch["decoder_attention_mask"].to(device)

        optimizer.zero_grad()

        
        # print('input_ids: ', input_ids.shape)
        # print('attention_mask: ', attention_mask.shape)
        # print('decoder_input_ids: ', decoder_input_ids.shape)
        # print('decoder_attention_mask: ', decoder_attention_mask.shape)

        # decoder_input_ids = decoder_input_ids[:, None, :]
        # decoder_attention_mask = decoder_attention_mask[:, None, :]
        # print(decoder_input_ids.shape)
        # print(decoder_attention_mask.shape)




        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels = decoder_input_ids
            # use_cache=False
        )

        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    average_loss = total_loss / len(train_dataloader)

    print(f"Epoch {epoch + 1} - Average Loss: {average_loss:.4f}")

 12%|█▏        | 231/1939 [04:54<36:51,  1.29s/it]

In [None]:
def generate_response(context):
    model.eval()
    input_ids = tokenizer.encode(context, padding='max_length', max_length=512, truncation=True , return_tensors="pt").to(device)
    output_ids = model.generate(input_ids=input_ids, max_length=512)
    response = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)

    return response

context = "<sp1> How are you? <sp2> I'm doing well. How about you? <sp1> I'm good too. Thanks for asking."

response = generate_response(context)
print("Generated Response:", response)