## Imports

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelWithLMHead
import gzip
from typing import List
import json
from sklearn.metrics import mean_squared_error

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

## Load the movie title data and filter US movies to sample 20000 for training

In [2]:
fname = 'title.akas.tsv.gz'

with gzip.open(fname, 'rb') as f:
    movie_df = pd.read_table(f, sep='\t', na_values=["\\N","nan"])

# print(movie_df.head())

movie_df_sampled = movie_df[movie_df['region']=="US"].sample(20000)
print(movie_df_sampled.head())

  movie_df = pd.read_table(f, sep='\t', na_values=["\\N","nan"])


             titleId  ordering                     title region language  \
16065241   tt1601958         2  The Seven Masks of Volto     US      NaN   
14795078  tt15255036         2          Full Time Pimpin     US      NaN   
36522495   tt8768044         4      Desire (Chapter Two)     US      NaN   
10680532  tt13292924         2        Roadblock and Play     US      NaN   
33232      tt0010233         3        Hearts and Flowers     US      NaN   

                types attributes  isOriginalTitle  
16065241  imdbDisplay        NaN              0.0  
14795078  imdbDisplay        NaN              0.0  
36522495          dvd        NaN              0.0  
10680532  imdbDisplay        NaN              0.0  
33232     imdbDisplay        NaN              0.0  


## Statistics about the movie titles

In [5]:
movie_titles = movie_df_sampled['title'].tolist()
title_lengths = [len(title.split()) for title in movie_titles]

mean_length = np.mean(title_lengths)
std_length = np.std(title_lengths)
print("Mean length:",mean_length)
print("Std length:",std_length)
print("Max length",max(title_lengths))

max_len = int(mean_length + 3*std_length)
print("Max length for model:",max_len)

Mean length: 3.4815
Std length: 2.4839600942849303
Max length 94
Max length for model: 10


## Dataloader class

- For training, it encodes `<len> ## <word> asfgads <text> asd jjksd lksda`
- For testing, it encodes `<len> ## <word> asfgads <text> `

In [3]:
class MovieDataset(Dataset):  
    def __init__(self, tokenizer, movie_titles: List, max_len: int, dataset_type: str,max_seq_len: int=30) -> None:
        self.max_len = max_len
        self.max_seq_len = max_seq_len
        self.tokenizer = tokenizer
        self.eos = self.tokenizer.eos_token
        self.eos_id = self.tokenizer.eos_token_id
        self.movies = movie_titles
        self.dataset_type = dataset_type
        self.result = []
        self.populate()


    def __len__(self) -> int:
        return len(self.result)


    def __getitem__(self, item: int) -> torch.Tensor:
        return self.result[item]
    
    def populate(self) -> None:
        for movie in self.movies:
            movie_words = movie.split()
            movie_len = len(movie_words)
            if movie_len > 1:
                prefix = f"<len> {movie_len-1} <word> {movie_words[0]} <text> "
                movie = (" ").join(movie_words[1:])
            else:
                prefix = f"<len> {movie_len} <word> movie <text> "
                movie = (" ").join(movie_words[:])

            encoded_prefix = self.tokenizer.encode(prefix)
            if self.dataset_type=="train":
                encoded_movie = self.tokenizer.encode(movie)
                if len(encoded_movie)>self.max_len:
                    encoded_movie = encoded_movie[:self.max_len]
                encoded_input = encoded_prefix + encoded_movie
                if len(encoded_input)>self.max_seq_len:
                    encoded_input = encoded_input[:self.max_seq_len-1]
                padded = encoded_input + [self.eos_id]*(self.max_seq_len-len(encoded_input))
            elif self.dataset_type=="test":
                padded = encoded_prefix
            # print(len(padded))
            self.result.append(torch.tensor(padded))


## Model Class

In [12]:
class GPT2DistillMovie(torch.nn.Module):
    def __init__(self, device: str, teacher_model: str=None, student_model: str=None):
        super().__init__()
        self.teacher_model = AutoModelWithLMHead.from_pretrained(teacher_model)
        if student_model:
            self.student_model = AutoModelWithLMHead.from_pretrained(student_model)
        else:
            self.student_model = AutoModelWithLMHead.from_pretrained("distilgpt2")

        self.teacher_model = self.teacher_model.to(device)
        self.student_model = self.student_model.to(device)

        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.optimizer = optim.AdamW(self.student_model.parameters(), lr=5e-4)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return self.student_model(tensor)

    def train(self,train_dataloader, epochs: int, temperature: float=2.0):    
        for epoch in range(epochs):
            self.student_model.train()
            total_loss = 0.0
            for idx, batch in enumerate(train_dataloader):
                self.optimizer.zero_grad()
                batch = batch.to(device)
                with torch.no_grad():
                    teacher_outputs = self.teacher_model(batch)
                    logits_teacher = teacher_outputs.logits

                outputs = self.student_model(batch, labels=batch)
                logits_student = outputs.logits

                loss = outputs.loss + self.distillation_loss(logits_student, logits_teacher, temperature=temperature)
                loss.backward()
                self.optimizer.step()
                # if idx % 100 == 0:
                #         print("loss: %f, %d"%(loss, idx))

                total_loss += loss.item()
            
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_dataloader)}")


    def distillation_loss(self,logits_student: torch.Tensor, logits_teacher: torch.Tensor, temperature: float=2.0) -> torch.Tensor:

        p_student = torch.nn.functional.log_softmax(logits_student / temperature, dim=-1)
        p_teacher = torch.nn.functional.softmax(logits_teacher / temperature, dim=-1)
        loss = torch.nn.functional.kl_div(p_student, p_teacher, reduction='batchmean') * (temperature** 2)

        return loss


    def save(self, filepath: str="distilled_model/") -> None:
        self.student_model.save_pretrained(save_directory=filepath)
        self.tokenizer.save_vocabulary(save_directory=filepath)
        
    def topk(self,probs: torch.Tensor, k: int=5) -> int:
        probs = torch.softmax(probs, dim= -1)

        token_probs, topIx = torch.topk(probs, k=k)
        token_probs = token_probs / torch.sum(token_probs)
        token_probs = token_probs.cpu().detach().numpy()

        choice = np.random.choice(k, 1, p = token_probs)
        token_id = topIx[choice][0]

        return int(token_id)
    
    def inference(self, init_token: torch.Tensor, max_length: int=10) -> str:

        sequence = init_token.numpy().tolist()
        init_input = init_token.unsqueeze(0).to(device)

        with torch.set_grad_enabled(False):
            output = self.student_model(init_input)
            logits = output.logits[0,-1]

            sequence.append(self.topk(logits))

            for i in range(max_length):
                inp = torch.tensor(sequence).unsqueeze(0).to(device)
                output = self.student_model(inp)
                logits = output.logits[0,-1]
                res_id = self.topk(logits)

                if res_id == self.tokenizer.eos_token_id:
                    return self.tokenizer.decode(sequence)
                else: 
                    sequence.append(res_id)

        return self.tokenizer.decode(sequence)

    def eval(self, test_dataset) -> None:
        results = []
        within_max_len = 0
        within_req_len = 0
        equal_req_len = 0
        req_len = []
        gen_len = []
        for inp in test_dataset:
            ret_seq = self.inference(inp).strip()
            results.append(ret_seq)
            true_len = int(ret_seq.split("<text>")[0].split(" ")[1])
            output = ret_seq.split("<text>")[1].split(" ")[1:]
            # print(req_len,len(output),output)
            if len(output)<=max_len:
                within_max_len+=1
            if len(output)<=true_len:
                within_req_len+=1
                if len(output)==true_len:
                    equal_req_len+=1
            req_len.append(true_len)
            gen_len.append(len(output))
            
        
        result_json = {"within_max_len":within_max_len/len(test_dataset),
                        "within_req_len": within_req_len/len(test_dataset),
                        "equal_req_len":equal_req_len/len(test_dataset),
                        "MSE_genvreq":mean_squared_error(req_len,gen_len),
                        "gen_results":results}
                        
        json_file_path = "eval_student_results.json"

        with open(json_file_path, "w") as json_file:
            json.dump(result_json, json_file, indent=4)
        
        print(f"Output within max seq length: {within_max_len/len(test_dataset)}")
        print(f"Output within req seq length: {within_req_len/len(test_dataset)}")
        print(f"Output equal req seq length: {equal_req_len/len(test_dataset)}")
        print(f"MSE req vs gen seq length: {mean_squared_error(req_len,gen_len)}")
        print("-"*20)
        print(results[:10])
    

## Load the model

In [13]:
gpt2distill = GPT2DistillMovie(device,teacher_model="model/")



## Load the tokenizer and dataset

In [10]:
dataset = MovieDataset(gpt2distill.tokenizer, movie_titles, max_len, dataset_type="train")
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)
print(len(dataset))

20000


## Train

In [14]:
gpt2distill.train(train_dataloader=dataloader, epochs=20,temperature=2.0)

Epoch 1/20, Loss: 36.21026574707031
Epoch 2/20, Loss: 21.448064233398437


KeyboardInterrupt: 

## Save the model

In [None]:
gpt2distill.save("distilled_model/")

## Load test set

In [17]:
gpt2distill = GPT2DistillMovie(device,teacher_model="model/",student_model="distilled_model/")

movie_test = movie_df[movie_df['region']=="US"].sample(1000)
movie_test = movie_test['title'].tolist()
test_dataset = MovieDataset(gpt2distill.tokenizer, movie_test, max_len, dataset_type="test")

## Evaluate on test set

In [18]:
gpt2distill.eval(test_dataset)


Output within max seq length: 1.0
Output within req seq length: 0.945
Output equal req seq length: 0.789
MSE req vs gen seq length: 0.974
--------------------
['<len> 3 <word> Rachael <text> Rayal/Alex Eden/CyHi', '<len> 3 <word> Melanie <text> Tree/John Pugh/Regina Deme', '<len> 2 <word> So <text> Long, Long', '<len> 1 <word> L.A. <text> Muse', '<len> 3 <word> Scott <text> Free: Words', "<len> 3 <word> Didn't <text> I Tweet", '<len> 5 <word> Broken <text> Sharts and Broken Dreams', '<len> 5 <word> Zeta <text> No. 13: The First Edition', '<len> 2 <word> Lil <text> Y-Lane/Jean', '<len> 3 <word> Mystery <text> on the Trail']
