In [None]:
import os
import torch
import random
from tqdm import tqdm
from IPython.display import clear_output

In [None]:
if not (os.path.exists('/content/ijcnlp_dailydialog.zip') and os.path.exists('/content/ijcnlp_dailydialog')):
  print('dowloading...')
  !wget http://yanran.li/files/ijcnlp_dailydialog.zip
  !unzip /content/ijcnlp_dailydialog.zip
else:
  print('files already exists')

if not (os.path.exists('/content/data/test') and os.path.exists('/content/data/train') and os.path.exists('/content/data/validation')):
  !mkdir data
  !unzip /content/ijcnlp_dailydialog/validation.zip -d /content/data
  !unzip /content/ijcnlp_dailydialog/train.zip -d /content/data
  !unzip /content/ijcnlp_dailydialog/test.zip -d /content/data
else:
  print('files already exists')

!rm /content/ijcnlp_dailydialog.zip
!rm -r /content/ijcnlp_dailydialog

clear_output(wait=False)

In [None]:
# Install a pinned release of transformers; the RoBERTa classes used in the
# next cell are imported from it.
!pip install transformers==4.38.2
# Drop the installer log from the cell output.
clear_output(wait=False)

In [None]:
from transformers import RobertaModel, RobertaTokenizer
# Load the pretrained 'roberta-base' encoder and its matching tokenizer.
# `encoder` is a single module object; every Model(...) built from it later
# wraps this same instance.
encoder = RobertaModel.from_pretrained('roberta-base')
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# Clear the loading output.
clear_output(wait=False)
# Show the padding id; the CrossEntropyLoss objects created later pass it as
# ignore_index.
print(tokenizer.pad_token_id)

In [None]:
from google.colab import drive
# Mount Google Drive at the relative path "Mydrive". Later cells read
# checkpoints from /content/Mydrive/MyDrive/..., so this presumably relies on
# the working directory being /content — TODO confirm.
drive.mount("Mydrive")

In [None]:
class CustomDataset():
    """Line-oriented dialogue corpus served as tokenised mini-batches.

    Each line of the file at ``data_path`` becomes one sample after every
    ``__eou__`` marker has been removed from it.

    Args:
        data_path: path of the text file, one sample per line.
        tokenizer: callable invoked as ``tokenizer(text, truncation=True,
            padding="max_length", max_length=maxlength, return_tensors="pt")``.
        batchsize: number of samples per batch; must be positive.
        maxlength: length passed to the tokenizer as ``max_length``.

    Attributes:
        data: list of cleaned sample strings.
        length: number of batches ``load_batch`` yields per pass, including a
            trailing partial batch.

    Raises:
        ValueError: if ``batchsize`` is not positive.
    """

    def __init__(self, data_path, tokenizer, batchsize, maxlength) -> None:
        if batchsize <= 0:
            raise ValueError("batchsize must be a positive integer")
        self.tokenizer = tokenizer
        self.batchsize = batchsize
        self.maxlength = maxlength
        # Read all lines up front inside a context manager so the file handle
        # is closed instead of staying open for the lifetime of the object.
        with open(data_path, encoding="utf-8") as handle:
            self.data = handle.readlines()
        self.data = self.custom_dataset()

    def custom_dataset(self):
        """Return a new list of samples with ``__eou__`` removed; refresh ``self.length``.

        Operates on whatever ``self.data`` currently holds: raw lines during
        construction, already cleaned samples afterwards.
        """
        data = [line.replace('__eou__', '') for line in self.data]
        # Ceiling division: load_batch also yields the final, shorter batch,
        # so it has to be counted for progress bars and loss averages.
        self.length = (len(data) + self.batchsize - 1) // self.batchsize
        return data

    def __len__(self):
        """Number of batches per pass."""
        return self.length

    def batch_tokenize(self, texts):
        """Tokenise each text separately and return the list of encodings.

        Every sample is padded / truncated to ``self.maxlength`` and encoded
        with ``return_tensors="pt"``; callers concatenate the per-sample
        results themselves.
        """
        return [
            self.tokenizer(text,
                           truncation=True,
                           padding="max_length",
                           max_length=self.maxlength,
                           return_tensors="pt")
            for text in texts
        ]

    def load_batch(self, shuffle=True):
        """Yield one list of per-sample encodings per batch.

        Args:
            shuffle: when true, the sample order is shuffled with the global
                ``random`` generator before batching.
        """
        # custom_dataset() hands back a fresh list, so shuffling it leaves
        # self.data in file order.
        data = self.custom_dataset()
        if shuffle:
            random.shuffle(data)

        for start in range(0, len(data), self.batchsize):
            yield self.batch_tokenize(data[start:start + self.batchsize])

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Module-level device string: "cuda" when a GPU is visible, otherwise "cpu".
# The model construction, validation() and both training loops read this global.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class Model(nn.Module):
    """Encoder followed by a linear projection to per-position vocabulary logits.

    Args:
        encoder: module called as ``encoder(input_ids=..., attention_mask=...)``
            whose result exposes ``last_hidden_state``.
        vocab_size: number of output classes produced for every position.
    """
    def __init__(self, encoder, vocab_size) -> None:
        super().__init__()
        self.encoder = encoder
        # Take the projection width from the encoder's config when it has
        # one; fall back to 768 (the previously hard-coded value) otherwise.
        encoder_config = getattr(encoder, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)
        self.decoder_layer = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, mask):
        """Return ``decoder_layer(encoder(x, mask).last_hidden_state)``.

        Args:
            x: token ids forwarded to the encoder as ``input_ids``.
            mask: forwarded to the encoder as ``attention_mask``.
        """
        encoded = self.encoder(input_ids=x, attention_mask=mask)
        return self.decoder_layer(encoded.last_hidden_state)

# Wrap the shared pretrained `encoder` with a vocabulary-sized output layer and
# move it to `device`. The only later reference to `model` in this notebook is
# inside a commented-out training cell.
model = Model(encoder, tokenizer.vocab_size).to(device)

In [None]:
# Optimiser, loss and data for `model`. These three names are rebound by the
# student-training cells further down, so the values set here only matter to
# code that runs before those cells.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# Targets equal to the tokenizer's padding id are ignored by the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
# Positional arguments are (data_path, tokenizer, batchsize=4, maxlength=512).
dataset = CustomDataset("/content/data/train/dialogues_train.txt", tokenizer, 4, 512)

In [None]:
import gc
# Release cached CUDA allocator blocks, then run a Python garbage-collection pass.
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Build the teacher around the SAME `encoder` object that `model` wraps, so
# loading weights below also changes the encoder seen through `model`
# (presumably the checkpoint contains encoder weights — TODO confirm).
teacher_model = Model(encoder, tokenizer.vocab_size).to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source. No map_location is given, so tensors are restored to the
# device they were saved from.
pretrained_state_dict = torch.load("/content/Mydrive/MyDrive/18786 Project/Results/big_model")
teacher_model.load_state_dict(pretrained_state_dict)
# Put the teacher in evaluation mode.
teacher_model.eval()
clear_output(wait=False)

In [None]:
import torch
import torch.nn as nn

import torch.nn.functional as F

from torch.nn import Transformer
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class EmbeddingLayer(nn.Module):
    """Token embedding combined with a learned positional table.

    ``forward`` layer-normalises the token embeddings, adds the first
    ``x.size(1)`` rows of the positional table and applies dropout.

    Args:
        d_model: embedding width.
        vocab_size: number of rows in the token embedding.
        max_lens: number of positions in the positional table.
        device: device the positional table is allocated on.
        dropout: dropout probability applied to the summed embedding.
    """
    def __init__(self, d_model, vocab_size, max_lens, device="cuda", dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout_layer = nn.Dropout(dropout)
        self.embedding_layer = nn.Embedding(vocab_size, d_model)
        # Allocate on `device` first and wrap in nn.Parameter afterwards.
        # Calling .to(device) on a Parameter returns a plain tensor whenever a
        # move actually happens, which left the table unregistered: absent
        # from parameters()/state_dict() and never updated by the optimiser.
        # State dicts saved from that earlier form therefore lack a
        # `positional_enoc` entry and need strict=False to load.
        self.positional_enoc = nn.Parameter(torch.zeros(1, max_lens, d_model, device=device))

    def forward(self, x):
        """Embed token ids ``x``; the second dimension of ``x`` selects the positions used."""
        embedding = self.layer_norm(self.embedding_layer(x))
        return self.dropout_layer(embedding + self.positional_enoc[:, :x.size(1), :])


class Vanilla_Transformer(nn.Module):
    """Causally masked Transformer-encoder stack with a vocabulary output layer.

    Args:
        vocab_size: size of the embedding table and of the output layer.
        d_model: model width.
        n_head: attention heads per encoder layer.
        dim_feedforward: hidden width of each encoder layer's feed-forward block.
        num_layers: number of stacked encoder layers.
        max_lens: number of positions supported by the embedding layer.
        device: device handed to the embedding layer; kept on ``self.device``.
        dropout: dropout probability used by the embedding layer, the encoder
            layers and the final dropout.
    """
    def __init__(self, vocab_size, d_model, n_head, dim_feedforward, num_layers, max_lens, device="cuda", dropout=0.1) -> None:
        super().__init__()
        # Forward `dropout` to the sub-modules; before, they silently used
        # their own defaults (also 0.1), so the default behaviour is unchanged.
        self.embedding_layer = EmbeddingLayer(d_model, vocab_size, max_lens, device, dropout)
        encoder_layer = TransformerEncoderLayer(d_model, n_head, dim_feedforward,
                                                dropout=dropout, batch_first=True)
        self.encoder = TransformerEncoder(encoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.dropout_layer = nn.Dropout(dropout)
        self.device = device

    def get_causal_mask(self, seq_len, device):
        """Return the square subsequent (causal) mask of size ``seq_len`` on ``device``."""
        return Transformer.generate_square_subsequent_mask(sz=seq_len, device=device)

    def forward(self, x, src_padding_mask=None):
        """Return per-position outputs of the output layer for token ids ``x``.

        Args:
            x: token ids; the second dimension is the sequence length.
            src_padding_mask: optional mask forwarded to the encoder as
                ``src_key_padding_mask``.
        """
        embedding = self.embedding_layer(x)
        # Build the mask on the input's own device rather than the string
        # captured at construction, so a moved model cannot hit a mismatch.
        mask = self.get_causal_mask(x.size(1), device=x.device)
        encoded_embedding = self.encoder(src=embedding, mask=mask,
                          src_key_padding_mask=src_padding_mask, is_causal=True)
        output = self.output_layer(encoded_embedding)
        # NOTE(review): dropout is applied to the output-layer result itself,
        # i.e. during training some outputs are zeroed — confirm this is intended.
        return self.dropout_layer(output)


# Student hyperparameters: the tokenizer's vocabulary size, width 512, 8 heads,
# feed-forward width 2048, 4 encoder layers, 512 positions.
vocab_size, d_model, n_head, dim_feedforward, num_layers, max_lens = tokenizer.vocab_size, 512, 8, 2048, 4, 512
# The constructor's `device` argument is left at its default ("cuda"); the
# module is then moved to the notebook-level `device`.
student_model = Vanilla_Transformer(vocab_size, d_model, n_head, dim_feedforward, num_layers, max_lens).to(device)

In [None]:
def validation(model, dataset, criterion):
    """Return the mean next-token loss of ``model`` over one pass of ``dataset``.

    Each batch is assembled by concatenating the per-sample encodings along
    dimension 0. The model sees every token but the last and is scored on
    every token but the first.

    Args:
        model: called as ``model(src, padding_mask)``; its output is passed to
            ``criterion`` after ``transpose(1, 2)``.
        dataset: object providing ``load_batch()`` (lists of encodings with
            ``input_ids`` / ``attention_mask``) and a ``length`` attribute
            used as the progress-bar total.
        criterion: loss called as ``criterion(logits, targets)``.

    Returns:
        Sum of the per-batch losses divided by the number of batches processed.

    Raises:
        ValueError: if ``dataset.load_batch()`` yields no batches.

    Note:
        The model is switched to eval mode and left there. Tensors are moved
        to the module-level ``device``. ``load_batch()`` is called with its
        default arguments, as before.
    """
    model.eval()
    running_loss = 0.0
    batches_seen = 0
    with torch.no_grad():
        for data_point in tqdm(dataset.load_batch(), total=dataset.length, leave=True):
            x = torch.cat([enc['input_ids'] for enc in data_point], dim=0).to(device)
            mask = torch.cat([enc["attention_mask"] for enc in data_point], dim=0).to(device)

            src = x[:, :-1]
            tgt = x[:, 1:]
            # Additive float mask: -inf where attention_mask is zero, 0.0 elsewhere.
            bool_mask = ~mask.to(torch.bool)
            float_mask = bool_mask.to(torch.float).masked_fill(bool_mask, float('-inf'))
            output = model(src, float_mask[:, :-1])
            loss = criterion(output.transpose(1, 2), tgt)
            running_loss += loss.item()
            batches_seen += 1

    # Divide by the batches actually processed. dataset.length may not match
    # (floor division drops a trailing partial batch, and can even be zero).
    if batches_seen == 0:
        raise ValueError("validation dataset produced no batches")
    return running_loss / batches_seen

In [None]:
import gc
# Release cached CUDA allocator blocks, then run a Python garbage-collection pass.
torch.cuda.empty_cache()
gc.collect()

In [None]:
def distillation_loss(student_logits, teacher_logits, temperature=1):
    """KL divergence between temperature-scaled student and teacher distributions.

    Both logit tensors are divided by ``temperature`` and normalised over
    their last dimension. The teacher's log-probabilities are passed as
    ``input`` and the student's probabilities as ``target`` of ``kl_div``, so
    the value is KL(student || teacher), reduced with ``'batchmean'``.

    NOTE(review): the opposite direction (student log-probs as input, teacher
    probs as target) is the more common distillation objective and was
    previously present here as a commented-out alternative — confirm which one
    is intended.

    Args:
        student_logits: unnormalised student scores.
        teacher_logits: unnormalised teacher scores, same shape.
        temperature: divisor applied to both logit tensors before softmax.

    Returns:
        Scalar tensor with the reduced divergence.
    """
    student_probs = torch.nn.functional.softmax(student_logits / temperature, dim=-1)
    # log_softmax instead of softmax(...).log(): when a teacher probability
    # underflows to zero the latter produces -inf and an infinite loss.
    teacher_log_probs = torch.nn.functional.log_softmax(teacher_logits / temperature, dim=-1)
    return torch.nn.functional.kl_div(teacher_log_probs, student_probs, reduction='batchmean')

# Train the student with plain next-token cross-entropy (no distillation term:
# the teacher forward pass and distillation loss below are commented out).
epoch = 80
# Targets equal to the tokenizer's padding id are ignored by the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)
# Training batches of 24 samples, validation batches of 1; both padded/truncated to 512.
dataset = CustomDataset("/content/data/train/dialogues_train.txt", tokenizer, 24, 512)
validationset = CustomDataset("/content/data/validation/dialogues_validation.txt", tokenizer, 1, 512)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.5, last_epoch=-1)
# Multiply the learning rate by 0.2 when the value passed to scheduler.step()
# (the validation loss, see the end of the loop) stops improving for 5 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5, threshold=0.001)
# Please explore these parameters: epoch, lr, scheduler (step_size, gamma), (num_layer, d_model, n_head, dim_feedforward), batch_size.
# Size of (num_layer=6, d_model=512, n_head=8, dim_feedforward=2048) -> 153 MB; please don't test a model larger than this one.
# Per-epoch histories, read by the plotting cell.
train_loss = []
valid_loss = []
for i in range(epoch):
    e_loss = 0
    # validation() leaves the model in eval mode, so training mode is restored every epoch.
    student_model.train()
    # teacher_model.eval()
    for data_point in tqdm(dataset.load_batch(), total=dataset.length, leave=True):
        optimizer.zero_grad()
        # Stack the per-sample encodings of this batch along dimension 0.
        for j in range(len(data_point)):
            if j == 0:
                x, mask = data_point[j]['input_ids'].to(device), data_point[j]["attention_mask"].to(device)
            else:
                x, mask = torch.cat([x, data_point[j]['input_ids'].to(device)], dim=0), torch.cat([mask, data_point[j]["attention_mask"].to(device)], dim=0)

        # Next-token setup: inputs drop the last token, targets drop the first.
        src = x[:, :-1]
        tgt = x[:, 1: ]
        # Additive float mask: -inf where attention_mask is zero, 0.0 elsewhere.
        bool_mask = ~ mask.to(torch.bool)
        float_mask = bool_mask.to(torch.float)
        float_mask = float_mask.masked_fill(bool_mask, float('-inf'))
        student_logits = student_model(src, float_mask[:, :-1])
        # with torch.no_grad():
        #     teacher_logits = teacher_model(src, mask[:, :-1])

        # The logits are transposed so that dimension 1 is the class dimension the loss expects.
        primary_loss = criterion(student_logits.transpose(1, 2), tgt)
        # distillation_losses = distillation_loss(student_logits, teacher_logits)
        loss = primary_loss
        loss.backward()
        optimizer.step()
        e_loss += loss.item()

    # NOTE(review): dataset.length is len(data)//batchsize, while load_batch
    # also yields a trailing partial batch, so this can divide the sum of
    # N+1 batch losses by N.
    e_loss = e_loss/(dataset.length)
    v_loss = validation(student_model, validationset, criterion)
    print(e_loss)
    print(v_loss)
    # `tgt` and `student_logits` still hold the LAST training batch of the
    # epoch; its first sample is decoded as a quick qualitative check.
    print('tgts: {}'.format(tokenizer.decode(tgt[0].tolist())))
    print('pred: {}'.format(tokenizer.decode(torch.argmax(student_logits, dim=-1)[0].tolist())))
    # This print is zero-based, the next one is one-based.
    print("epoch: {}".format(i))
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {i + 1}, Current Learning Rate: {current_lr}")
    train_loss.append(e_loss)
    valid_loss.append(v_loss)
    scheduler.step(v_loss)

In [None]:
# epochs = 10
# for i in range(epochs):
#     epoch_loss = 0
#     for data_point in tqdm(dataset.load_batch(), total=dataset.length, leave=True):
#         optimizer.zero_grad()
#         for j in range(len(data_point)):
#             if j == 0:
#                 x, mask = data_point[j]['input_ids'].to(device), data_point[j]["attention_mask"].to(device)
#             else:
#                 x, mask = torch.cat([x, data_point[j]['input_ids'].to(device)], dim=0), torch.cat([mask, data_point[j]["attention_mask"].to(device)], dim=0)

#         src = x[:, :-1]
#         tgt = x[:, 1: ]

#         output = model(src, mask[:, :-1])
#         loss = criterion(output.transpose(1,2), tgt)
#         loss.backward()
#         optimizer.step()
#         epoch_loss += loss.item()
#     print(epoch_loss/dataset.length)

In [None]:
import matplotlib.pyplot as plt

# Draw the per-epoch training and validation loss curves on one set of axes,
# using matplotlib's 'classic' style sheet.
plt.style.use('classic')

# One entry per curve: (values, legend label, line colour, point marker).
loss_curves = (
    (train_loss, "Train Loss", 'blue', 'o'),
    (valid_loss, "Validation Loss", 'red', 'x'),
)
for values, curve_label, curve_color, curve_marker in loss_curves:
    plt.plot(values, label=curve_label, color=curve_color,
             linewidth=2, marker=curve_marker, markersize=5)

# Axis labels and title.
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')

# Legend and grid, then render.
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Display the matplotlib style names that can be passed to plt.style.use();
# exploratory only, nothing else in the notebook uses the result.
plt.style.available

In [None]:
# Epoch count read by the distillation loop in the cell further below.
epoch = 30
# Replace the in-memory student with an object loaded from Drive.
# NOTE(review): torch.load unpickles the file (arbitrary code can run on load),
# so only use checkpoints from a trusted source. The file presumably holds a
# whole pickled model rather than a state dict — TODO confirm; if so, the
# class definitions above must be run before this cell.
student_model = torch.load('/content/Mydrive/MyDrive/18786 Project/Results/student_model.pth')

In [None]:
# Release cached CUDA allocator blocks, then run a Python garbage-collection
# pass. `gc` is not imported in this cell; it comes from an earlier cell's
# `import gc`.
torch.cuda.empty_cache()
gc.collect()

In [None]:
def distillation_loss(student_logits, teacher_logits, temperature=1):
    """Element-wise KL terms between temperature-scaled student and teacher distributions.

    Both logit tensors are divided by ``temperature`` and normalised over
    their last dimension. The teacher's log-probabilities are passed as
    ``input`` and the student's probabilities as ``target`` of ``kl_div``, so
    summing the result over the last dimension gives KL(student || teacher)
    per position. ``reduction='none'`` leaves the reduction to the caller.

    NOTE(review): the opposite direction (student log-probs as input, teacher
    probs as target) is the more common distillation objective and was
    previously present here as a commented-out alternative — confirm which one
    is intended. This definition also rebinds the name of an earlier
    ``distillation_loss`` in this notebook that used ``'batchmean'``.

    Args:
        student_logits: unnormalised student scores.
        teacher_logits: unnormalised teacher scores, same shape.
        temperature: divisor applied to both logit tensors before softmax.

    Returns:
        Tensor of unreduced divergence terms with the shape of the inputs.
    """
    student_probs = torch.nn.functional.softmax(student_logits / temperature, dim=-1)
    # log_softmax instead of softmax(...).log(): when a teacher probability
    # underflows to zero the latter produces -inf and infinite loss terms.
    teacher_log_probs = torch.nn.functional.log_softmax(teacher_logits / temperature, dim=-1)
    return torch.nn.functional.kl_div(teacher_log_probs, student_probs, reduction='none')


# Continue training the student using only the distillation loss against the
# frozen teacher; cross-entropy is computed solely inside validation().
# `epoch` is not set in this cell — it comes from an earlier cell.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)
# Training batches of 4 samples, validation batches of 1; both padded/truncated to 512.
dataset = CustomDataset("/content/data/train/dialogues_train.txt", tokenizer, 4, 512)
validationset = CustomDataset("/content/data/validation/dialogues_validation.txt", tokenizer, 1, 512)
# Multiply the learning rate by 0.2 when the validation loss passed to
# scheduler.step() stops improving for 5 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5, threshold=0.001)
# Histories: `trains_loss` is never appended to (its append is commented out
# below); `trains_loss_d` holds the distillation loss per epoch.
trains_loss = []
valids_loss = []
trains_loss_d = []
for i in range(epoch):
    # e_loss_1 is only used by commented-out code; e_loss_2 accumulates the distillation loss.
    e_loss_1 = 0
    e_loss_2 = 0
    # validation() leaves the student in eval mode, so training mode is restored every epoch.
    student_model.train()
    teacher_model.eval()
    for data_point in tqdm(dataset.load_batch(), total=dataset.length, leave=True):
        optimizer.zero_grad()
        # Stack the per-sample encodings of this batch along dimension 0.
        for j in range(len(data_point)):
            if j == 0:
                x, mask = data_point[j]['input_ids'].to(device), data_point[j]["attention_mask"].to(device)
            else:
                x, mask = torch.cat([x, data_point[j]['input_ids'].to(device)], dim=0), torch.cat([mask, data_point[j]["attention_mask"].to(device)], dim=0)

        # Next-token setup: inputs drop the last token, targets drop the first.
        src = x[:, :-1]
        tgt = x[:, 1: ]
        # Additive float mask for the student: -inf where attention_mask is zero, 0.0 elsewhere.
        bool_mask = ~ mask.to(torch.bool)
        float_mask = bool_mask.to(torch.float)
        float_mask = float_mask.masked_fill(bool_mask, float('-inf'))
        student_logits = student_model(src, float_mask[:, :-1])
        # The teacher receives the original attention_mask (not the additive
        # float mask) and runs without gradient tracking.
        with torch.no_grad():
            teacher_logits = teacher_model(src, mask[:, :-1])

        # Rebinds float_mask to 1.0 where attention_mask is zero, shifted by
        # one position; it is only referenced by the commented-out lines below.
        float_mask = bool_mask.to(torch.float)[:, 1: ]
        # student_logits = student_logits * float_mask.unsqueeze(2)
        # teacher_logits = teacher_logits * float_mask.unsqueeze(2)
        distillation_losses = distillation_loss(student_logits, teacher_logits)
        # loss = distillation_losses * float_mask.unsqueeze(2)
        # loss = loss.sum() / mask.sum()
        # NOTE(review): with the masking above commented out, the mean below
        # runs over every element of the unreduced loss tensor, so padded
        # positions contribute as well — confirm this is intended.
        loss = distillation_losses
        loss = loss.mean()
        loss.backward()
        optimizer.step()
        # e_loss_1 += primary_loss.item()
        e_loss_2 += loss.item()

    # e_loss_1 = e_loss_1/(dataset.length)
    # NOTE(review): dataset.length is len(data)//batchsize, while load_batch
    # also yields a trailing partial batch, so this can divide the sum of
    # N+1 batch losses by N.
    e_loss_2 = e_loss_2/(dataset.length)
    v_loss = validation(student_model, validationset, criterion)
    print(e_loss_2)
    print(v_loss)
    # `tgt` and `student_logits` still hold the LAST training batch of the
    # epoch; its first sample is decoded as a quick qualitative check.
    print('tgts: {}'.format(tokenizer.decode(tgt[0].tolist())))
    print('pred: {}'.format(tokenizer.decode(torch.argmax(student_logits, dim=-1)[0].tolist())))
    # This print is zero-based, the next one is one-based.
    print("epoch: {}".format(i))
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {i + 1}, Current Learning Rate: {current_lr}")
    trains_loss_d.append(e_loss_2)
    # trains_loss.append(e_loss_1)
    valids_loss.append(v_loss)
    scheduler.step(v_loss)