In [None]:
from datasets import load_dataset
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import os
import torch.nn.functional as F
from easy_transformer import EasyTransformer
from torch.utils.data import Dataset, DataLoader
import random


In [None]:
# Load the pretrained "gpt2-small" reference model with fold_ln, center_unembed
# and center_writing_weights all disabled.
pretrained_options = dict(
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", **pretrained_options)
reference_gpt2


In [None]:


class CustomGPT2ForSequenceClassification(EasyTransformer):
    """EasyTransformer body with the unembedding replaced by a linear classification head.

    The head consumes the final-LayerNorm output flattened over every position,
    so its input width is ``config.d_model * config.n_ctx``; inputs therefore
    have to contain exactly ``config.n_ctx`` tokens.
    """

    def __init__(self, config, num_labels=2):
        """Build the transformer and attach the classification head.

        Args:
            config: EasyTransformer configuration; ``d_model`` and ``n_ctx``
                determine the input width of the head.
            num_labels: Number of output classes of the head. Passed explicitly
                so the class does not rely on a module-level variable that is
                only defined after the class body.
        """
        super().__init__(config)
        # The language-model unembedding is not used for classification.
        self.unembed = None
        self.classification_head1 = torch.nn.Linear(config.d_model * config.n_ctx, num_labels)

    def forward(self, input_ids):
        """Return classification logits for ``input_ids``.

        Args:
            input_ids: Token ids; after flattening, the sequence must supply
                ``d_model * n_ctx`` features per batch row for the head.

        Returns:
            Tensor of shape ``(batch, num_labels)``.
        """
        embed = self.embed(tokens=input_ids)
        # Drops dimension 1 when it has size 1 (presumably to absorb inputs
        # shaped (batch, 1, seq) — TODO confirm against callers).
        embed = embed.squeeze(1)
        pos_embed = self.pos_embed(input_ids)
        residual = embed + pos_embed

        for block in self.blocks:
            attn_out = block.attn(block.ln1(residual))
            resid_mid = residual + attn_out
            mlp_out = block.mlp(block.ln2(resid_mid))
            # Carry this block's output into the next block. Without this
            # reassignment every block would read the raw embeddings and only
            # the last block would influence the logits.
            # NOTE(review): checkpoints trained without this propagation should
            # be re-validated or re-trained.
            residual = resid_mid + mlp_out

        normalized_resid_final = self.ln_final(residual)
        # Flatten (batch, seq, d_model) -> (batch, seq * d_model) for the linear head.
        flattened = normalized_resid_final.view(normalized_resid_final.shape[0], -1)
        logits = self.classification_head1(flattened)
        return logits
        
config = reference_gpt2.cfg
num_labels = 2
model = CustomGPT2ForSequenceClassification(config)
model_path = '../trained_models/easy_transformer_gpt2small_qqp_try.pth'

# Load onto the CPU first so the checkpoint opens regardless of the device it
# was saved from; the model is moved to the target device afterwards.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load(model_path, map_location="cpu")

# strict=False tolerates key mismatches; report them instead of hiding them so
# a partially loaded model does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Missing keys:", load_result.missing_keys)
    print("Unexpected keys:", load_result.unexpected_keys)

# Prefer Apple MPS (the original target), then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
model

In [None]:
# Put the model into evaluation mode before running inference.
model.eval()

In [None]:
# Total element count across all parameter tensors of the classifier.
num_params = sum(map(torch.numel, model.parameters()))
print("Number of parameters in GPT-2 Small model:", num_params)


In [None]:
# Load the GLUE QQP question-pair dataset.
# NOTE(review): the variable is named "validation" but split='train' is what is
# loaded — confirm which split is intended.
validation_dataset = load_dataset('glue', 'qqp', split='train')


In [None]:
# Peek at the first example together with the dataset size.
first_example = validation_dataset[0]
dataset_size = len(validation_dataset)
first_example, dataset_size

In [None]:
# Index of the example that is analysed in the rest of the notebook.
random_integer = 5310
selected_example = validation_dataset[random_integer]
selected_example['question1']


In [None]:
def tokenize(datapoint, max_length = 1024, token_to_add = 50256):
    """Encode a question pair as a single fixed-length token sequence.

    Layout of the result: ``question1 tokens, separator, question2 tokens,
    padding``. Separator and padding both use the hard-coded id 50264.

    Args:
        datapoint: Mapping with ``'question1'``, ``'question2'`` and ``'label'``.
        max_length: Length to right-pad to. Sequences that are already this
            long or longer are returned unpadded and untruncated.
        token_to_add: Currently ignored — the id 50264 is always used.
            NOTE(review): decide whether this parameter should drive the
            separator/padding id; it is left untouched here so the produced
            tokens stay identical to what existing checkpoints were fed.

    Returns:
        Tuple ``(tokens, label, real_length, sep_place)`` where ``tokens`` is a
        ``(1, length)`` long tensor, ``label`` is a tensor built from
        ``datapoint['label']``, ``real_length`` is the token count before
        padding, and ``sep_place`` is ``[0, n1 + 1]`` with ``n1`` the number of
        question1 tokens (the start offsets of the two questions).
    """
    # NOTE(review): 50264 is >= the standard GPT-2 vocabulary size (50257) —
    # confirm the embedding matrix actually covers this id.
    fill_token_id = 50264

    sentence1_tokens = reference_gpt2.to_tokens(datapoint['question1'], prepend_bos = False)
    sentence2_tokens = reference_gpt2.to_tokens(datapoint['question2'], prepend_bos = False)

    # question1 starts at 0; question2 starts right after the separator.
    sep_place = [0, sentence1_tokens.size(1) + 1]

    separator = torch.tensor([[fill_token_id]], dtype=torch.long, device=sentence1_tokens.device)
    concatenated_tokens = torch.cat((sentence1_tokens, separator, sentence2_tokens), dim=1)

    labels = torch.tensor(datapoint['label'])
    real_length = concatenated_tokens.size(1)
    remaining_length = max_length - real_length
    if remaining_length > 0:
        # Build the whole padding in one allocation instead of concatenating
        # one token per iteration.
        padding = torch.full(
            (1, remaining_length),
            fill_token_id,
            dtype=torch.long,
            device=concatenated_tokens.device,
        )
        concatenated_tokens = torch.cat((concatenated_tokens, padding), dim=1)
    return concatenated_tokens, labels, real_length, sep_place


In [None]:
# Tokenize the selected example and display the pieces that come back.
datapoint = validation_dataset[random_integer]
tokens, label, length, seperators = tokenize(datapoint)
(tokens, label, length, seperators)

In [None]:
def register_hooks(module):
    """Attach a forward hook to ``module`` that prints the shape of its output.

    Args:
        module: The ``torch.nn.Module`` to instrument.

    Returns:
        The handle returned by ``register_forward_hook``; call ``.remove()`` on
        it to detach the hook again.
    """
    def hook(module, input, output):
        print("Output shape:", output.shape)

    # Hand the handle back to the caller; dropping it would make the hook
    # impossible to remove.
    return module.register_forward_hook(hook)

# model.modules() yields the model itself first, so only the top-level module
# gets the shape-printing hook.
root_module = next(model.modules())
print(root_module)
register_hooks(root_module)

# Inference only: skip autograd bookkeeping and feed the tokens on the same
# device as the model.
with torch.no_grad():
    outputs = model(tokens.to(device))



In [None]:
attention_scores_list = []

def register_attention_hooks(module):
    """Attach a recording hook to ``block.attn.hook_attn`` of every block.

    Each hook appends ``output[0]`` (the first batch element of the hooked
    output) to the module-level ``attention_scores_list`` on every forward pass.

    Args:
        module: Model to instrument; anything that is not an
            ``EasyTransformer`` is left untouched.

    Returns:
        List of hook handles (empty when nothing was registered). Call
        ``.remove()`` on each to detach the hooks.
    """
    def hook(module, input, output):
        # Detach so the stored tensors do not keep an autograd graph alive.
        attention_scores_list.append(output[0].detach())

    if not isinstance(module, EasyTransformer):
        return []
    return [block.attn.hook_attn.register_forward_hook(hook) for block in module.blocks]

attention_hook_handles = register_attention_hooks(model)
try:
    with torch.no_grad():
        outputs = model(tokens.to(device))
finally:
    # Remove the hooks again so re-running this cell (or any later forward
    # pass) does not append duplicate entries to attention_scores_list.
    for handle in attention_hook_handles:
        handle.remove()


In [None]:
# Shape of the first recorded entry and how many entries were recorded.
first_entry_shape = attention_scores_list[0].shape
first_entry_shape, len(attention_scores_list)

In [None]:
folder = '../gpt2_small/attention_arrays/qqp_5310_try/'
# np.save does not create missing directories, so make sure the target exists.
os.makedirs(folder, exist_ok=True)
for attention, layer_scores in enumerate(attention_scores_list):
    # Convert once per recorded entry instead of once per head.
    layer_array = layer_scores.cpu().detach().numpy()
    for head in range(layer_array.shape[0]):
        print(attention, head)
        file_name = folder + 'attn_' + str(attention) + '_' + str(head) + '.npy'
        np.save(file_name, layer_array[head, :, :])
