In [None]:
from datasets import load_dataset
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import os
import torch.nn.functional as F
from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate
from easy_transformer import EasyTransformer
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import random

In [None]:
# Load GPT-2 small without LayerNorm folding or weight centering
# (fold_ln / center_unembed / center_writing_weights all disabled).
reference_gpt2 = EasyTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)
reference_gpt2


In [None]:
# Total element count across every parameter tensor of the reference model.
num_params = sum(map(torch.numel, reference_gpt2.parameters()))
print("Number of parameters in GPT-2 Small model:", num_params)


In [None]:
# The rest of the notebook relies on the Hugging Face "quora" schema:
# 'questions' -> {'text': [question1, question2], ...} and a boolean 'is_duplicate'.
# (GLUE/QQP instead exposes 'question1' / 'question2' / 'label'.)
train_qqp = load_dataset("quora", split="train")


In [None]:
# Number of rows in the loaded train split.
train_qqp.num_rows

In [None]:
# Pull the whole label column at once and show which distinct values occur.
values = list(train_qqp['is_duplicate'])
set(values)

In [None]:
# Sanity check: both columns should have the same number of rows.
tuple(len(train_qqp[column]) for column in ('questions', 'is_duplicate'))

In [None]:
# Hold out 10% for validation (fixed seed), then subsample both sides to keep
# experiments fast: 3000 training pairs and 500 validation pairs.
(train_qqp_text, validation_qqp_text,
 train_qqp_label, validation_qqp_label) = train_test_split(
    train_qqp['questions'],
    train_qqp['is_duplicate'],
    test_size=0.1,
    random_state=42,
)
train_qqp_text, train_qqp_label = train_qqp_text[:3000], train_qqp_label[:3000]
validation_qqp_text, validation_qqp_label = validation_qqp_text[:500], validation_qqp_label[:500]

print("Train set size:", len(train_qqp_text))
print("Validation set size:", len(validation_qqp_text))


In [None]:
def count_false_true(labels):
    """Return ``(n_false, n_true)`` for an iterable of boolean-like labels.

    Uses ``==`` comparison so 0/1 integer labels are counted as well; values
    equal to neither False nor True are counted in neither bucket.
    """
    n_false = n_true = 0
    for label in labels:
        if label == False:  # noqa: E712 -- equality on purpose (accepts 0/1 too)
            n_false += 1
        elif label == True:  # noqa: E712
            n_true += 1
    return n_false, n_true


# Class balance of the training subset, then of the validation subset.
cf, ct = count_false_true(train_qqp_label)
print(cf, ct)

cf, ct = count_false_true(validation_qqp_label)
print(cf, ct)

In [None]:
# Peek at the first training record (its 'text' field holds the question pair).
train_qqp_text[0]

In [None]:
class CustomDataset(Dataset):
    """Fixed-length token sequences for question-pair classification.

    Each item tokenizes the two questions of a pair with ``model.to_tokens``
    (no BOS), joins them with a single ``token_to_add`` separator, and then
    truncates or right-pads with ``token_to_add`` to exactly ``max_length``.

    Args:
        questions: Indexable of records where ``questions[idx]['text'][0]`` and
            ``[1]`` are the two question strings.
        labels: Indexable aligned with ``questions``; a label equal to ``True``
            becomes class 1, anything else class 0.
        model: Object with ``to_tokens(text, prepend_bos=...)`` returning a
            ``(1, seq_len)`` tensor of token ids.
        max_length: Length of every returned token sequence.
        token_to_add: Token id used both as the separator between the two
            questions and as the padding token (50256 is GPT-2's end-of-text id).
    """

    def __init__(self, questions, labels, model, max_length = 1024, token_to_add = 50256):
        self.questions = questions
        self.labels = labels
        self.max_length = max_length
        self.token_to_add = token_to_add
        self.model = model

    def __len__(self):
        return len(self.questions)

    def _filler(self, length, like):
        """Return a ``(1, length)`` long tensor of ``token_to_add`` on ``like``'s device."""
        return torch.full((1, length), self.token_to_add, dtype=torch.long, device=like.device)

    def __getitem__(self, idx):
        """Return ``(tokens, label)`` for pair ``idx``.

        ``tokens`` has shape ``(1, max_length)``; ``label`` is a 0-dim tensor
        holding 1 when ``labels[idx] == True`` and 0 otherwise.
        """
        pair_text = self.questions[idx]['text']
        sentence1_tokens = self.model.to_tokens(pair_text[0], prepend_bos = False)
        sentence2_tokens = self.model.to_tokens(pair_text[1], prepend_bos = False)

        # Separator honours the configured token (it used to be hard-coded to 50256).
        separator = self._filler(1, sentence1_tokens)
        concatenated_tokens = torch.cat((sentence1_tokens, separator, sentence2_tokens), dim=1)

        # Truncate long pairs so every item has the same length; otherwise the
        # DataLoader cannot stack a batch.
        concatenated_tokens = concatenated_tokens[:, :self.max_length]

        # Right-pad in a single allocation rather than one torch.cat per token.
        remaining_length = self.max_length - concatenated_tokens.size(1)
        if remaining_length > 0:
            padding = self._filler(remaining_length, concatenated_tokens)
            concatenated_tokens = torch.cat((concatenated_tokens, padding), dim=1)

        labels = torch.tensor(1 if self.labels[idx] == True else 0)  # noqa: E712
        return concatenated_tokens, labels



In [None]:
# Training pairs -> fixed 1024-token sequences, shuffled mini-batches of 32.
dataset = CustomDataset(
    train_qqp_text,
    train_qqp_label,
    reference_gpt2,
    max_length=1024,
    token_to_add=50256,
)
data_loader_train = DataLoader(dataset, shuffle=True, batch_size=32)
len(data_loader_train)

In [None]:
# Validation pairs -> fixed 1024-token sequences. Order does not affect the
# aggregate loss/accuracy, so iterate deterministically instead of shuffling.
dataset = CustomDataset(validation_qqp_text, validation_qqp_label, reference_gpt2, max_length = 1024, token_to_add = 50256)
data_loader_valid = DataLoader(dataset, batch_size=32, shuffle=False)
len(data_loader_valid)

In [None]:
# Inspect the tensor shapes of a single training batch (if there is one).
first_batch = next(iter(data_loader_train), None)
if first_batch is not None:
    tokens, labels = first_batch
    print("Tokens:", tokens.shape)
    print("Labels:", labels.shape)

In [None]:
# Show the model config; its d_model / n_ctx size the classification head below.
reference_gpt2.cfg

In [None]:


class CustomGPT2ForSequenceClassification(EasyTransformer):
    """EasyTransformer backbone with a linear sequence-classification head.

    The unembedding is dropped. Instead, the final layer-normed residual stream
    for every position is flattened and fed to one ``Linear`` layer. Inputs must
    therefore be exactly ``config.n_ctx`` tokens long.
    """

    def __init__(self, config, num_labels: int = 2):
        """Build the backbone from ``config`` plus a ``num_labels``-way head.

        Args:
            config: EasyTransformer config; ``d_model * n_ctx`` is the head's
                input width.
            num_labels: Number of output classes (previously read from a
                notebook-level global; defaults to 2 for binary QQP).
        """
        super().__init__(config)
        self.unembed = None  # language-model head is unused for classification
        self.classification_head = torch.nn.Linear(config.d_model * config.n_ctx, num_labels)

    def forward(self, input_ids):
        """Return classification logits of shape ``(batch, num_labels)``.

        Args:
            input_ids: Token ids shaped ``(batch, n_ctx)`` or, as collated from
                ``(1, n_ctx)`` dataset items, ``(batch, 1, n_ctx)``.
        """
        if input_ids.dim() == 3:
            # Drop the singleton axis that collated (1, n_ctx) items carry.
            input_ids = input_ids.squeeze(1)

        residual = self.embed(tokens=input_ids) + self.pos_embed(input_ids)
        for block in self.blocks:
            # Pre-LN block: LN -> attention -> add, then LN -> MLP -> add.
            # `residual` must be carried forward so each block consumes the
            # previous block's output (it was previously never updated, so
            # every block saw only the raw embeddings).
            attn_out = block.attn(block.ln1(residual))
            resid_mid = residual + attn_out
            mlp_out = block.mlp(block.ln2(resid_mid))
            residual = resid_mid + mlp_out

        normalized_resid_final = self.ln_final(residual)
        # (batch, n_ctx, d_model) -> (batch, n_ctx * d_model) for the linear head.
        flattened = normalized_resid_final.reshape(normalized_resid_final.shape[0], -1)
        return self.classification_head(flattened)

# Example usage:
config = reference_gpt2.cfg
num_labels = 2
model = CustomGPT2ForSequenceClassification(config)

# strict=False: the classifier has no `unembed` weights and its
# `classification_head` is new, so some keys are expected to mismatch.
# Report them so that any *other* mismatch is not silently ignored.
load_result = model.load_state_dict(reference_gpt2.state_dict(), strict=False)
print("Missing keys:", load_result.missing_keys)
print("Unexpected keys:", load_result.unexpected_keys)

# Prefer Apple MPS (as originally targeted), then CUDA, else CPU, so the
# notebook also runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
model


In [None]:
# NOTE(review): lr=1e-3 is high for updating every GPT-2 weight with AdamW;
# ~1e-5 to 5e-5 is more typical -- consider lowering if the loss diverges.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
model.float()

num_epochs = 3
checkpoint_path = '../trained_models/easy_transformer_gpt2small_qqp_try.pth'


def run_epoch(model, loader, loss_fn, device, desc, optimizer=None):
    """Make one pass over ``loader`` and return ``(average_loss, accuracy)``.

    Args:
        model: Classifier mapping ``input_ids`` to ``(batch, num_classes)`` logits.
        loader: DataLoader yielding ``(input_ids, labels)`` batches.
        loss_fn: Loss called as ``loss_fn(logits, labels)``.
        device: Device each batch is moved to.
        desc: Label for the tqdm progress bar.
        optimizer: If given, the pass trains (zero_grad/backward/step per batch).
            If ``None``, the model is put in eval mode and gradients are disabled.

    Returns:
        ``(average_loss, accuracy)``: the mean per-batch loss and the fraction of
        correctly classified samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    # Without this, evaluation would build an autograd graph for every batch.
    with torch.set_grad_enabled(training):
        for input_ids, labels in tqdm(loader, desc=desc):
            input_ids = input_ids.to(device).long()
            labels = labels.to(device).long()

            logits = model(input_ids)
            loss = loss_fn(logits, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            total_samples += input_ids.size(0)
            total_correct += (logits.argmax(dim=-1) == labels).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    average_loss = total_loss / max(len(loader), 1)
    accuracy = total_correct / max(total_samples, 1)
    return average_loss, accuracy


# Training loop
for epoch in range(num_epochs):
    average_loss, accuracy = run_epoch(
        model, data_loader_train, loss_fn, device,
        desc=f'Epoch {epoch + 1}/{num_epochs}', optimizer=optimizer,
    )
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}, Accuracy: {accuracy}")
    # Checkpoint after every epoch; make sure the target directory exists.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)

# Validation (no parameter updates, no gradient tracking)
average_loss, accuracy = run_epoch(
    model, data_loader_valid, loss_fn, device, desc='Validation',
)
print(f"Validation Loss: {average_loss}, Accuracy: {accuracy}")
