In [None]:
from datasets import load_dataset
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import os
import torch.nn.functional as F
from easy_transformer import EasyTransformer
from torch.utils.data import Dataset, DataLoader


In [None]:
# Load pretrained GPT-2 small with the weight-processing options switched off
# (LayerNorm not folded, unembed / writing weights not centered).
gpt2_load_kwargs = dict(
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", **gpt2_load_kwargs)
reference_gpt2


In [None]:
# Total element count across every parameter tensor of the reference model.
num_params = sum(map(torch.numel, reference_gpt2.parameters()))
print("Number of parameters in GPT-2 Small model:", num_params)


In [None]:
# Load the three GLUE QNLI splits from the Hugging Face Hub.
# NOTE(review): GLUE test splits are expected to ship without gold labels
# (stored as -1), so `test_dataset` cannot be scored directly —
# confirm with set(test_dataset['label']).
train_dataset = load_dataset('glue', 'qnli', split='train')
validation_dataset = load_dataset('glue', 'qnli', split='validation')
test_dataset = load_dataset('glue', 'qnli', split='test')


In [None]:
# Number of examples in (train, validation, test).
tuple(len(split) for split in (train_dataset, validation_dataset, test_dataset))

In [None]:
num_samples = 3000
test_samples = 500
valid_samples = 500
SUBSET_SEED = 0  # fixed so the sampled subsets are reproducible across runs

generator = torch.Generator().manual_seed(SUBSET_SEED)

# Train and test are taken from DISJOINT slices of one permutation of the
# train split, so no test example is ever seen during training.
# The GLUE test split is not used here because it is expected to be unlabelled.
train_perm = torch.randperm(len(train_dataset), generator=generator).tolist()
subset_dataset_train = train_dataset.select(train_perm[:num_samples])
subset_dataset_test = train_dataset.select(train_perm[num_samples:num_samples + test_samples])

# Validation examples come from the labelled validation split.
valid_perm = torch.randperm(len(validation_dataset), generator=generator).tolist()
subset_dataset_valid = validation_dataset.select(valid_perm[:valid_samples])

len(subset_dataset_train), len(subset_dataset_valid), len(subset_dataset_test)


In [None]:
# Class balance check: print "<#label 0> <#label 1>" for train, valid and test subsets.
for split in (subset_dataset_train, subset_dataset_valid, subset_dataset_test):
    split_labels = [example['label'] for example in split]
    c0 = split_labels.count(0)
    c1 = split_labels.count(1)
    print(c0, c1)

In [None]:
# Peek at one raw training example to see its fields.
first_train_example = train_dataset[0]
first_train_example

In [None]:
# Distinct label values present in the validation split.
values = [example['label'] for example in validation_dataset]
set(values)

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset turning question/sentence pairs into fixed-length token tensors.

    Each item is ``(tokens, label)``:
      * ``tokens``: question tokens followed by sentence tokens (each tokenized
        with ``prepend_bos=True``), truncated or right-padded with
        ``token_to_add`` to exactly ``max_length`` tokens along dim 1.
      * ``label``: ``torch.tensor(example['label'])``.

    Args:
        dataset: indexable collection of dicts with 'question', 'sentence' and 'label' keys.
        model: object exposing ``to_tokens(text, prepend_bos=...)``; assumed to return a
            (1, seq) integer tensor, as EasyTransformer does — TODO confirm for other models.
        max_length: target sequence length (dim 1 of the returned tokens).
        token_to_add: padding token id (50256 presumably GPT-2's <|endoftext|>).
    """

    def __init__(self, dataset, model, max_length = 1024, token_to_add = 50256):
        self.dataset = dataset
        self.max_length = max_length
        self.token_to_add = token_to_add
        self.model = model

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the example once instead of re-indexing the dataset per field.
        example = self.dataset[idx]
        question_tokens = self.model.to_tokens(example['question'], prepend_bos = True)
        sentence_tokens = self.model.to_tokens(example['sentence'], prepend_bos = True)
        concatenated_tokens = torch.cat((question_tokens, sentence_tokens), dim=1)
        labels = torch.tensor(example['label'])

        # Truncate over-long pairs: every item must have the same length or the
        # DataLoader cannot stack a batch (and a fixed-size head would break).
        concatenated_tokens = concatenated_tokens[:, :self.max_length]

        remaining_length = self.max_length - concatenated_tokens.size(1)
        if remaining_length > 0:
            # Single allocation on the tokens' own device/dtype, instead of one
            # torch.cat per pad token with a CPU tensor.
            padding = torch.full(
                (concatenated_tokens.size(0), remaining_length),
                self.token_to_add,
                dtype=concatenated_tokens.dtype,
                device=concatenated_tokens.device,
            )
            concatenated_tokens = torch.cat((concatenated_tokens, padding), dim=1)

        return concatenated_tokens, labels



In [None]:
# Training loader: items are padded to 1024 tokens (presumably GPT-2's n_ctx)
# with id 50256; batches are reshuffled every epoch.
dataset = CustomDataset(
    subset_dataset_train,
    reference_gpt2,
    max_length=1024,
    token_to_add=50256,
)
data_loader_train = DataLoader(dataset, shuffle=True, batch_size=32)
len(data_loader_train)

In [None]:
# Validation loader: order does not affect loss/accuracy, so keep it fixed
# (shuffle=False) for deterministic, reproducible evaluation passes.
dataset = CustomDataset(subset_dataset_valid, reference_gpt2, max_length = 1024, token_to_add = 50256)
data_loader_valid = DataLoader(dataset, batch_size=32, shuffle=False)
len(data_loader_valid)

In [None]:
# Test loader: order does not affect loss/accuracy, so keep it fixed
# (shuffle=False) for deterministic, reproducible evaluation passes.
dataset = CustomDataset(subset_dataset_test, reference_gpt2, max_length = 1024, token_to_add = 50256)
data_loader_test = DataLoader(dataset, batch_size=32, shuffle=False)
len(data_loader_test)

In [None]:
# Pull a single training batch (if any) and show the collated shapes.
first_batch = next(iter(data_loader_train), None)
if first_batch is not None:
    tokens, labels = first_batch
    print("Tokens:", tokens.shape)
    print("Labels:", labels.shape)

In [None]:
# Inspect the reference model's config (d_model, n_ctx, ...).
gpt2_cfg = reference_gpt2.cfg
gpt2_cfg

In [None]:


class CustomGPT2ForSequenceClassification(EasyTransformer):
    """EasyTransformer GPT-2 backbone with a linear classification head.

    The unembedding is removed. Instead, the final-LayerNorm residual stream of
    every position is flattened and fed to one linear layer, so inputs must be
    exactly ``config.n_ctx`` tokens long.

    Args:
        config: EasyTransformer config (e.g. ``reference_gpt2.cfg``).
        num_labels: number of output classes (default 2).
    """

    def __init__(self, config, num_labels=2):
        super().__init__(config)
        self.num_labels = num_labels
        # The language-model unembedding is not used for classification.
        self.unembed = None
        # Flattened (n_ctx * d_model) residual stream -> class logits.
        self.classification_head = torch.nn.Linear(config.d_model * config.n_ctx, num_labels)

    def forward(self, input_ids):
        """Return class logits of shape (batch, num_labels).

        Args:
            input_ids: token ids; assumed (batch, 1, n_ctx) as produced by
                ``CustomDataset`` + ``DataLoader`` — the singleton dim 1 is
                squeezed out of the embedding below. TODO confirm for other callers.
        """
        embed = self.embed(tokens=input_ids).squeeze(1)
        pos_embed = self.pos_embed(input_ids)
        residual = embed + pos_embed
        for block in self.blocks:
            # Pre-LN attention sublayer with a residual connection.
            resid_mid = residual + block.attn(block.ln1(residual))
            # Pre-LN MLP sublayer; its output is the next block's input.
            residual = resid_mid + block.mlp(block.ln2(resid_mid))
        normalized_resid_final = self.ln_final(residual)
        # Flatten positions and features into one vector per example.
        flattened = normalized_resid_final.reshape(normalized_resid_final.shape[0], -1)
        return self.classification_head(flattened)

# Example usage:
# Build the classifier from GPT-2 small's config and copy in the pretrained weights.
config = reference_gpt2.cfg
num_labels = 2
model = CustomGPT2ForSequenceClassification(config)

# strict=False because the classifier drops the unembedding and adds a fresh
# head, so some keys mismatch. Show them so any unexpected gap is visible.
load_result = model.load_state_dict(reference_gpt2.state_dict(), strict=False)
print("Missing keys:", load_result.missing_keys)
print("Unexpected keys:", load_result.unexpected_keys)

# Prefer CUDA, then Apple MPS, then CPU, instead of hard-coding "mps"
# (which fails on any non-Apple machine).
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model.to(device)
model


In [None]:
# Fine-tuning hyperparameters.
learning_rate = 1e-3
num_epochs = 3
checkpoint_path = '../trained_models/easy_transformer_gpt2small_qnli.pth'

# Convert to float32 before building the optimizer so it holds the final parameters.
model.float()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.CrossEntropyLoss()


def evaluate(model, loader, loss_fn, device, desc):
    """Return (mean per-sample loss, accuracy) of `model` over `loader`.

    Runs in eval mode under torch.no_grad(), so no autograd graph is kept.
    Raises ValueError if `loader` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for input_ids, labels in tqdm(loader, desc=desc):
            input_ids = input_ids.to(device).long()
            labels = labels.to(device).long()
            logits = model(input_ids)
            batch_size = input_ids.size(0)
            # Weight by batch size so a smaller final batch is not over-counted.
            total_loss += loss_fn(logits, labels).item() * batch_size
            total_samples += batch_size
            total_correct += (logits.argmax(dim=-1) == labels).sum().item()
    if total_samples == 0:
        raise ValueError(f"{desc}: loader produced no samples")
    return total_loss / total_samples, total_correct / total_samples


# torch.save does not create missing directories.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for input_ids, labels in tqdm(data_loader_train, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        input_ids = input_ids.to(device).long()
        labels = labels.to(device).long()

        optimizer.zero_grad()
        logits = model(input_ids)
        loss = loss_fn(logits, labels)

        batch_size = input_ids.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        total_correct += (logits.argmax(dim=-1) == labels).sum().item()

        loss.backward()
        optimizer.step()

    # Training metrics, averaged per training sample.
    accuracy = total_correct / total_samples
    average_loss = total_loss / total_samples
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}, Accuracy: {accuracy}")
    torch.save(model.state_dict(), checkpoint_path)

valid_loss, valid_accuracy = evaluate(model, data_loader_valid, loss_fn, device, 'Validation')
print(f"Validation, Loss: {valid_loss}, Accuracy: {valid_accuracy}")

test_loss, test_accuracy = evaluate(model, data_loader_test, loss_fn, device, 'Test')
print(f"Test, Loss: {test_loss}, Accuracy: {test_accuracy}")
