In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
import os
import pandas as pd
from googleapiclient import discovery
import json
import numpy as np
import pytorch_lightning as pl
from tqdm.notebook import tqdm
import time

# Single source of truth for the Hugging Face cache directory.
tf_cache = '/deep/u/andleerew/final_proj/cache'
# NOTE(review): `transformers` is imported above this line and likely reads
# TRANSFORMERS_CACHE at import time -- confirm the cache location is actually
# honored, or pass `cache_dir=tf_cache` to `from_pretrained` explicitly.
os.environ['TRANSFORMERS_CACHE'] = tf_cache

# Credentials must never be committed with the notebook: the key is read from
# the environment instead (export GOOGLE_API_KEY before starting the kernel).
# The key that was previously hardcoded here should be treated as leaked and
# rotated. API_KEY is None when the variable is not set.
API_KEY = os.environ.get('GOOGLE_API_KEY')

In [2]:
# Building Datasets - Real Toxicity Prompts
# Fixed seed so the shuffle (and therefore any `[:TRAIN_SIZE]` slice taken
# from it later) is identical across kernel restarts.
SHUFFLE_SEED = 42

prompts_df = pd.read_json("/deep/u/andleerew/final_proj/real-toxicity-prompts/realtoxicityprompts-data/prompts.jsonl", lines=True)
prompts_shuffled = prompts_df.sample(frac=1, random_state=SHUFFLE_SEED)

# Rows flagged "challenging" form the test split; all other rows are training data.
# Each entry of the "prompt" column is a mapping whose "text" field holds the prompt string.
test_prompts = [
    prompt["text"]
    for prompt in prompts_shuffled.loc[prompts_shuffled["challenging"].eq(True), "prompt"]
]
train_prompts = [
    prompt["text"]
    for prompt in prompts_shuffled.loc[prompts_shuffled["challenging"].eq(False), "prompt"]
]

# Context strings prepended to a prompt before tokenization.
teacher_context = "Do not use any toxic, insulting, profane, or sexually explicit language while completing the following prompt: "
student_context = "Complete the following prompt: "
empty_context = ""

In [3]:
# Setting Up Labels for Distillation
# Checkpoint used to generate the responses that later serve as training labels.
TEACHER_CHECKPOINT = "google/t5-v1_1-base"

# Filled by the generation loop with one decoded response per training prompt.
teacher_train_response = []

tokenizer = AutoTokenizer.from_pretrained(TEACHER_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(
    TEACHER_CHECKPOINT,
    device_map="auto",
)

In [24]:
TRAIN_SIZE = 1024
BATCH_SIZE = 16
tr_prompts = train_prompts[:TRAIN_SIZE]
# Ceiling division so a trailing partial batch is still generated; otherwise
# the label list ends up shorter than the prompt list.
TRAIN_BATCH = (len(tr_prompts) + BATCH_SIZE - 1) // BATCH_SIZE

# Start from an empty list every time this cell runs: appending to a list
# created in another cell would duplicate responses on re-execution and
# silently misalign labels with prompts.
teacher_train_response = []

# NOTE(review): prompts are fed to the model without `teacher_context`
# prepended -- confirm this is intended for the distillation labels.
progress = tqdm(range(TRAIN_BATCH))
for i in progress:
    input_text = tr_prompts[BATCH_SIZE * i:BATCH_SIZE * (i + 1)]
    # Send inputs to wherever the model was placed instead of assuming "cuda".
    encoded = tokenizer(input_text, return_tensors="pt", padding=True).to(model.device)
    # The batch is padded, so hand the attention mask to generate() explicitly.
    outputs = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
    )
    teacher_train_response += tokenizer.batch_decode(outputs, skip_special_tokens=True)

  0%|          | 0/64 [00:00<?, ?it/s]

In [25]:
class Tokenizer:
    """Wrap a pretrained tokenizer so every sentence gets a fixed context prefix.

    Encoded sentences are truncated and padded to `truncation` tokens and
    returned as PyTorch tensors.
    """

    def __init__(self, model_name, context_token, truncation=512):
        """Load the tokenizer for `model_name` and remember the prefix and length.

        Args:
            model_name: checkpoint name handed to `AutoTokenizer.from_pretrained`.
            context_token: string prepended to every sentence before encoding.
            truncation: maximum (and padded) sequence length.
        """
        self.__tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.__context = context_token
        self.__truncation = truncation

    def encode(self, sentence):
        """Tokenize `context + sentence`, padded/truncated to the fixed length."""
        prefixed = self.__context + sentence
        options = dict(
            truncation=True,
            max_length=self.__truncation,
            padding='max_length',
            return_tensors='pt',
        )
        return self.__tokenizer(prefixed, **options)

    def decode(self, word_ids, *args, **kwargs):
        """Forward to the wrapped tokenizer's `decode`, passing extra arguments through."""
        wrapped = self.__tokenizer
        return wrapped.decode(word_ids, *args, **kwargs)

    def __call__(self, sentence):
        """Calling the instance is shorthand for `encode`."""
        return self.encode(sentence)

    @property
    def pad_token_id(self):
        """Padding token id of the wrapped tokenizer."""
        wrapped = self.__tokenizer
        return wrapped.pad_token_id
    
# One wrapper per context string: teacher prompts, student prompts, and
# responses (which get no prefix). All three load the 't5-base' tokenizer.
teacher_tokenizer, student_tokenizer, response_tokenizer = (
    Tokenizer('t5-base', context)
    for context in (teacher_context, student_context, empty_context)
)

In [26]:
class PromptDataset:
    """Map-style dataset pairing each prompt with its label text.

    Items are tokenized on access: prompts with `pr_tokenizer`, labels with
    `re_tokenizer`.
    """

    def __init__(self, data, label, pr_tokenizer, re_tokenizer):
        """Store the prompts, their labels, and the two tokenizers."""
        self.data, self.label = data, label
        self.prompt_tokenizer, self.response_tokenizer = pr_tokenizer, re_tokenizer

    def __len__(self):
        """Number of prompts in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return (prompt ids, prompt mask, label ids, label mask) for item `idx`.

        Each element is the tokenizer output with `.squeeze()` applied.
        """
        prompt_text, label_text = self.data[idx], self.label[idx]

        encodings = (
            self.prompt_tokenizer(prompt_text),
            self.response_tokenizer(label_text),
        )

        return tuple(
            encoding[field].squeeze()
            for encoding in encodings
            for field in ('input_ids', 'attention_mask')
        )
    
# Both datasets share the same prompts and teacher-generated labels; they
# differ only in the context string their prompt tokenizer prepends.
train_subset = train_prompts[:TRAIN_SIZE]

# Labels are matched to prompts by position, so the two lists must line up.
assert len(teacher_train_response) == len(train_subset), (
    f"expected {len(train_subset)} teacher responses, got {len(teacher_train_response)}; "
    "regenerate the teacher responses before building the datasets"
)

teacher_train_ds = PromptDataset(train_subset, teacher_train_response, teacher_tokenizer, response_tokenizer)
student_train_ds = PromptDataset(train_subset, teacher_train_response, student_tokenizer, response_tokenizer)


# Both are training sets; the second line used to be mislabelled as "test examples".
print('Number of teacher training examples:', len(teacher_train_ds))
print('Number of student training examples:', len(student_train_ds))

Number of training examples: 1024
Number of test examples: 1024


In [27]:
from torch.utils.data import DataLoader

# One example per optimization step.
batch_size = 1

# Student-context prompts paired with the teacher responses; iteration order
# follows the dataset order (no shuffling), loaded by 8 worker processes.
train_dl = DataLoader(
    dataset=student_train_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
)

In [28]:
# Defining Loss for Context Distillation
class DistillationLoss:
    """Soft-target distillation loss, optionally blended with a hard-label loss.

    The soft term is the cross-entropy between the temperature-scaled teacher
    distribution and the temperature-scaled student distribution, averaged
    over positions.

    Args:
        temp: temperature dividing both sets of logits before (log-)softmax.
        alpha: weight of the hard-label cross-entropy term; the soft term gets
            ``1 - alpha``. Only used when ``labels`` is passed to the call.
            The default of 0 means pure distillation.
    """

    def __init__(self, temp=1, alpha=0):
        self.temp = temp
        self.alpha = alpha

    def __call__(self, teacher_logits, student_logits, labels=None):
        """Compute the loss.

        Args:
            teacher_logits: teacher scores; the last dimension is the vocabulary.
            student_logits: student scores, same shape as ``teacher_logits``.
            labels: optional target ids shaped like the logits without the
                last dimension. Positions equal to -100 are ignored.

        Returns:
            A scalar tensor. Without ``labels`` this is the mean soft loss over
            all positions. With ``labels`` the soft loss is averaged over
            non-ignored positions only and, when ``alpha`` is non-zero,
            combined as ``alpha * hard + (1 - alpha) * soft``.
        """
        t = self.temp

        soft_targets = (teacher_logits / t).softmax(dim=-1)
        student_log_probs = (student_logits / t).log_softmax(dim=-1)
        per_position = -(soft_targets * student_log_probs).sum(dim=-1)

        if labels is None:
            return per_position.mean()

        # Positions marked -100 (padding) carry no signal; leave them out of
        # both averages. clamp() avoids dividing by zero if everything is ignored.
        keep = labels != -100
        n_kept = keep.sum().clamp(min=1)
        soft_loss = (per_position * keep).sum() / n_kept

        if not self.alpha:
            return soft_loss

        # Hard-label cross-entropy on the un-tempered student logits.
        # Ignored labels are clamped to a valid index for gather() and then
        # zeroed out by the mask.
        log_probs = student_logits.log_softmax(dim=-1)
        picked = log_probs.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
        hard_loss = -(picked * keep).sum() / n_kept

        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss
    
def test_distillation_loss():
    """Smoke-test DistillationLoss on the first teacher and student examples.

    Runs one forward pass per example through a pretrained 't5-base' and
    prints the resulting loss.
    """
    # Both passes use identical pretrained weights, so the checkpoint is
    # loaded once and reused rather than loaded twice.
    t5 = T5ForConditionalGeneration.from_pretrained('t5-base')

    def logits_for(example):
        # Dataset items are unbatched; add a leading batch dimension of 1.
        input_ids, attention_mask, labels, decoder_attention_mask = example
        return t5(
            input_ids=input_ids.unsqueeze(dim=0),
            attention_mask=attention_mask.unsqueeze(dim=0),
            labels=labels.unsqueeze(dim=0),
            decoder_attention_mask=decoder_attention_mask.unsqueeze(dim=0)
        ).logits

    teacher_logits = logits_for(teacher_train_ds[0])
    student_logits = logits_for(student_train_ds[0])

    criterion = DistillationLoss(temp=1)
    loss = criterion(teacher_logits, student_logits)

    print(loss)

test_distillation_loss()

tensor(3.8450, grad_fn=<NegBackward0>)


In [9]:
import torch

from torch.optim import AdamW
from transformers import T5ForConditionalGeneration

class LightningT5(pl.LightningModule):
    """Lightning wrapper around a T5 conditional-generation model.

    Args:
        model_name: checkpoint name or path for `T5ForConditionalGeneration`.
        tokenizer: object exposing `pad_token_id`, used to mask padded labels.
    """

    def __init__(self, model_name, tokenizer):
        super().__init__()

        self.load_model(model_name)
        self.tokenizer = tokenizer

    def load_model(self, model_name):
        """(Re)load the wrapped T5 weights from `model_name`."""
        self.t5 = T5ForConditionalGeneration.from_pretrained(model_name)

    def forward(self, x):
        """Generate output token ids for the input ids `x`."""
        return self.t5.generate(x)

    def configure_optimizers(self):
        """Optimize only the wrapped T5's parameters with AdamW (lr=1e-5)."""
        return AdamW(self.t5.parameters(), lr=1e-5)

    def _step(self, batch):
        """Return the model's loss for one batch.

        `batch` is (input_ids, attention_mask, labels, decoder_attention_mask).
        """
        input_ids, attention_mask, labels, decoder_attention_mask = batch
        # Build a masked copy instead of writing -100 into the caller's batch
        # tensor; padded label positions are then ignored by the loss.
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)

        return self.t5(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask
        ).loss

    def training_step(self, batch, idx):
        """Lightning hook: loss for one training batch."""
        loss = self._step(batch)
        return loss

    def save_pretrained(self, path):
        """Save the wrapped T5 model to `path`."""
        self.t5.save_pretrained(path)

    # Validation and test reuse the training step; it dispatches through
    # self._step, so subclasses overriding _step are honored here as well.
    validation_step = training_step
    test_step = training_step

In [10]:
class LightningDistilledT5(LightningT5):
    """Student T5 trained to match a teacher model's output distribution.

    Args:
        student_model_name: checkpoint name or path for the student T5.
        tokenizer: object exposing `pad_token_id`.
        teacher_model: model providing the target logits.
        freeze_teacher: when true, teacher parameters are frozen and the
            teacher runs in eval mode without gradient tracking.
        alpha: forwarded to DistillationLoss.
        temp: forwarded to DistillationLoss.
    """

    def __init__(self, student_model_name, tokenizer, teacher_model, freeze_teacher=True, alpha=0, temp=1):
        super().__init__(student_model_name, tokenizer)
        self.teacher_model = teacher_model
        self.freeze_teacher = freeze_teacher

        if freeze_teacher:
            for p in self.teacher_model.parameters():
                p.requires_grad = False

        # DistillationLoss must accept `alpha` here and a `labels` argument
        # when called in _step.
        self.criterion = DistillationLoss(temp=temp, alpha=alpha)

    def _teacher_logits(self, **model_inputs):
        """Run the teacher on `model_inputs` and return its logits."""
        if not self.freeze_teacher:
            return self.teacher_model(**model_inputs).logits

        # The teacher is a registered submodule, so it follows the student
        # into train mode and its dropout would perturb the targets. Force
        # eval mode and skip autograd bookkeeping for the frozen teacher.
        self.teacher_model.eval()
        with torch.no_grad():
            return self.teacher_model(**model_inputs).logits

    def _step(self, batch):
        """Return the distillation loss for one batch.

        `batch` is (input_ids, attention_mask, labels, decoder_attention_mask).
        """
        input_ids, attention_mask, labels, decoder_attention_mask = batch
        # Masked copy rather than an in-place write into the caller's batch.
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)

        # Student and teacher are fed identical inputs; build them once.
        # NOTE(review): for context distillation the teacher would normally see
        # the teacher-context prompt; here it receives the same input_ids as
        # the student -- confirm this is intended.
        model_inputs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=self.t5._shift_right(labels),
            decoder_attention_mask=decoder_attention_mask
        )

        student_logits = self.t5(**model_inputs).logits
        teacher_logits = self._teacher_logits(**model_inputs)

        return self.criterion(teacher_logits, student_logits, labels)

In [23]:
import gc

class Trainer:
    """Train a Lightning model and save it, or reload previously saved weights.

    Args:
        model: module exposing `save_pretrained(path)` and `load_model(path)`.
        max_epochs: epoch limit handed to `pl.Trainer`.
        filename: path the trained weights are saved to / loaded from.
    """

    def __init__(self, model, max_epochs, filename):
        self.model = model
        self.max_epochs = max_epochs
        self.filename = filename

    def train(self, train_dl, max_steps=None):
        """Fit the model on `train_dl`, save it to `self.filename`, and return it.

        Args:
            train_dl: training dataloader.
            max_steps: optional step limit. It is forwarded to `pl.Trainer`
                only when given, so the library default applies otherwise.
        """
        gc.collect()

        model = self.model

        # NOTE(review): `gpus=` is kept as-is; newer Lightning releases use
        # `accelerator`/`devices` instead -- confirm against the installed version.
        trainer_kwargs = {'gpus': 1, 'max_epochs': self.max_epochs}
        if max_steps is not None:
            # Previously accepted but silently ignored.
            trainer_kwargs['max_steps'] = max_steps

        trainer = pl.Trainer(**trainer_kwargs)

        trainer.fit(model, train_dl)
        model.save_pretrained(self.filename)

        return model

    def train_or_load_pretrained(self, train_dl, force=False, **kwargs):
        """Load weights from `self.filename` if it exists, otherwise train.

        Args:
            train_dl: training dataloader, used only when training happens.
            force: train even when `self.filename` already exists.
            **kwargs: forwarded to `train`.
        """
        if os.path.exists(self.filename) and not force:
            self.model.load_model(self.filename)
            return self.model

        return self.train(train_dl, **kwargs)

In [None]:
# Distil into a fresh 't5-base' student for 2 epochs, or reload the weights
# if 'distillation_1_0' already exists on disk.
# NOTE(review): `finetuned_t5_base` is not defined in the cells shown here --
# confirm it is created earlier so this cell survives Restart & Run All.
distilled_t5_base_1_0 = Trainer(
    model=LightningDistilledT5(
        student_model_name='t5-base',
        tokenizer=Tokenizer('t5-base', student_context),
        teacher_model=finetuned_t5_base.t5,
        temp=1,
        alpha=0.5,
    ),
    max_epochs=2,
    filename='distillation_1_0',
).train_or_load_pretrained(train_dl)