In [1]:
from datasets import load_dataset
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import os
import torch.nn.functional as F
from easy_transformer import EasyTransformer
from torch.utils.data import Dataset, DataLoader
import random
from jaxtyping import Float
from easy_transformer.hook_points import (
    HookPoint,
) 


In [2]:
# GPT-2 small loaded with LayerNorm folding and weight centering disabled.
reference_gpt2 = EasyTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)
reference_gpt2


Moving model to device:  cpu
Finished loading pretrained model gpt2-small into EasyTransformer!


EasyTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_attn): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_mid): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (ln_final): LayerNorm(
    (hook_scale): HookPoint()
    (hook_norm

In [None]:


class CustomGPT2ForSequenceClassification(EasyTransformer):
    """GPT-2 backbone whose LM unembedding is replaced by a linear classification head.

    The final-LayerNorm residual stream of every position is flattened and fed to a
    single linear layer, so inputs must be exactly ``config.n_ctx`` tokens long.

    Args:
        config: EasyTransformer config; ``d_model`` and ``n_ctx`` size the head.
        num_labels: number of output classes (default 2).
    """

    def __init__(self, config, num_labels=2):
        super().__init__(config)
        # The language-model head is unused; classification_head1 takes its place.
        self.unembed = None
        self.classification_head1 = torch.nn.Linear(config.d_model * config.n_ctx, num_labels)

    def forward(self, input_ids):
        """Return classification logits of shape (batch, num_labels).

        Args:
            input_ids: token ids. NOTE(review): presumably (batch, n_ctx) or
                (batch, 1, n_ctx) — ``squeeze(1)`` drops a singleton second axis
                from the embedding when present; confirm against the training loader.
        """
        embed = self.embed(tokens=input_ids)
        embed = embed.squeeze(1)
        pos_embed = self.pos_embed(input_ids)
        residual = embed + pos_embed
        for block in self.blocks:
            normalized_resid_pre = block.ln1(residual)
            resid_mid = residual + block.attn(normalized_resid_pre)
            normalized_resid_mid = block.ln2(resid_mid)
            # Carry the residual stream into the next block. Without this
            # reassignment every block reads the raw embeddings and only the last
            # block's output reaches the head. NOTE(review): checkpoints trained
            # with that behaviour will give different outputs and may need retraining.
            residual = resid_mid + block.mlp(normalized_resid_mid)
        normalized_resid_final = self.ln_final(residual)
        # Flatten (batch, pos, d_model) -> (batch, pos * d_model) for the linear head.
        flattened = normalized_resid_final.reshape(normalized_resid_final.shape[0], -1)
        return self.classification_head1(flattened)
        


In [None]:
# Held-out GLUE QQP examples. The fine-tuned classifier was presumably trained on
# the 'train' split, so interpreting it on those examples would conflate behaviour
# learned on memorised data with behaviour that generalises.
validation_dataset = load_dataset('glue', 'qqp', split='validation')


In [None]:
# Peek at one QQP example together with the size of the loaded split.
(validation_dataset[120], len(validation_dataset))

In [None]:
# Seeded local RNG: the shuffled order in `integer_list` and the drawn index are
# reproducible across kernel restarts, and the global `random` state is untouched.
RANDOM_SEED = 0
example_rng = random.Random(RANDOM_SEED)
integer_list = list(range(len(validation_dataset)))
example_rng.shuffle(integer_list)
random_integer = example_rng.choice(integer_list)
print("Random Integer:", random_integer)


In [None]:
# Pin the example under study so later cells always inspect the same QQP pair.
random_integer = 5310
validation_dataset[random_integer]['question1']


In [None]:
def tokenize(datapoint, max_length = 1024, token_to_add = 50256, tokenizer_model=None):
    """Turn a QQP example into a fixed-length token row for the classifier.

    question1 and question2 are tokenized without a BOS token, concatenated with no
    separator, truncated to ``max_length`` and right-padded with ``token_to_add``.

    Args:
        datapoint: mapping with 'question1', 'question2' and 'label' keys.
        max_length: output length. The classifier flattens exactly n_ctx positions,
            so this should match the model's n_ctx (1024 here).
        token_to_add: padding token id (50256, presumably GPT-2's end-of-text id).
        tokenizer_model: object exposing ``to_tokens(text, prepend_bos=...)``;
            defaults to the global ``reference_gpt2``.

    Returns:
        (tokens, labels, real_length, sep_place):
        tokens — (1, max_length) id tensor for a single-string ``to_tokens`` result;
        labels — tensor built from ``datapoint['label']``;
        real_length — number of non-padding tokens (after truncation);
        sep_place — one-element list holding len(question1 tokens) + 1.
        NOTE(review): no separator token is inserted, so question2 actually starts
        at index len(question1 tokens); the +1 would only be right if a separator
        were added between the questions. Confirm before using it as a boundary.
    """
    if tokenizer_model is None:
        tokenizer_model = reference_gpt2
    question1_tokens = tokenizer_model.to_tokens(datapoint['question1'], prepend_bos = False)
    question2_tokens = tokenizer_model.to_tokens(datapoint['question2'], prepend_bos = False)
    sep_place = [question1_tokens.size(1) + 1]

    concatenated_tokens = torch.cat((question1_tokens, question2_tokens), dim=1)
    # Truncate over-long pairs: the classifier cannot take more than n_ctx positions.
    concatenated_tokens = concatenated_tokens[:, :max_length]
    real_length = concatenated_tokens.size(1)

    # Pad in one allocation with the same dtype/device as the tokens, instead of
    # concatenating one CPU column per loop iteration.
    padding = concatenated_tokens.new_full(
        (concatenated_tokens.size(0), max(max_length - real_length, 0)), token_to_add
    )
    concatenated_tokens = torch.cat((concatenated_tokens, padding), dim=1)

    labels = torch.tensor(datapoint['label'])
    return concatenated_tokens, labels, real_length, sep_place


In [None]:
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
config = reference_gpt2.cfg
num_labels = 2
model = CustomGPT2ForSequenceClassification(config)

model_path = '../trained_models/easy_transformer_gpt2small_qqp.pth'
device = torch.device("cpu")
# map_location remaps any CUDA tensors in the checkpoint onto `device`, so a
# GPU-trained checkpoint still loads here. torch.load unpickles: trusted files only.
state_dict = torch.load(model_path, map_location=device)
# strict=False silently skips keys that do not match; report them so a missing or
# misnamed classification head does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Missing keys:", load_result.missing_keys)
    print("Unexpected keys:", load_result.unexpected_keys)
model.to(device)
model.eval()

In [None]:
# tokenize() yields (padded ids, label tensor, unpadded length, boundary list).
tokens, label, length, seperators = tokenize(validation_dataset[random_integer])
(tokens, label, length, seperators)

In [3]:
# Inspect the first block's attention module. Needs `model` from a checkpoint-loading
# cell to have run first (the saved NameError output came from running this too early).
model.blocks[0].attn

NameError: name 'model' is not defined

In [None]:
def register_hooks(module):
    """Attach a forward hook to ``module`` that prints the shape of its output.

    Returns:
        The ``RemovableHandle`` from ``register_forward_hook``. Call ``.remove()`` on
        it to detach the hook; otherwise re-running a registering cell stacks
        duplicate printers on the same module.
    """
    def print_output_shape(hooked_module, inputs, output):
        print("Output shape:", output.shape)

    return module.register_forward_hook(print_output_shape)


In [None]:
attention_scores_list = []

def register_attention_hooks(module):
    """Record each block's attention pattern into ``attention_scores_list``.

    A forward hook is attached to ``block.attn.hook_attn`` for every block of an
    EasyTransformer; other modules are left untouched.

    Returns:
        List of ``RemovableHandle`` objects, one per block (empty for a
        non-EasyTransformer module). Remove them before re-registering, or each
        forward pass appends duplicate entries.
    """
    handles = []
    if not isinstance(module, EasyTransformer):
        return handles
    for block in module.blocks:
        def hook(hooked_module, inputs, output):
            # NOTE(review): a HookPoint appears to pass its input through unchanged,
            # so `output` is presumably the (batch, head, query, key) pattern tensor,
            # not a tuple; [0] then keeps only the first batch element. TODO confirm.
            # detach() stops the stored tensors from keeping the autograd graph alive.
            attention_scores_list.append(output[0].detach())

        handles.append(block.attn.hook_attn.register_forward_hook(hook))
    return handles



In [None]:
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
config = reference_gpt2.cfg
num_labels = 2
model = CustomGPT2ForSequenceClassification(config)

model_path = '../trained_models/easy_transformer_gpt2small_qqp.pth'
device = torch.device("cpu")
# map_location remaps any CUDA tensors in the checkpoint onto `device`, so a
# GPU-trained checkpoint still loads here. torch.load unpickles: trusted files only.
state_dict = torch.load(model_path, map_location=device)
# strict=False silently skips keys that do not match; report them so a missing or
# misnamed classification head does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Missing keys:", load_result.missing_keys)
    print("Unexpected keys:", load_result.unexpected_keys)
model.to(device)
model.eval()

# Counts every parameter, including the classification head.
num_params = sum(p.numel() for p in model.parameters())
print("Number of parameters in GPT-2 Small model:", num_params)

In [None]:
# tokenize() returns (padded ids, label, unpadded length, boundary list); gpt2_tokens
# keeps the whole tuple, and element 0 is the id tensor fed to the classifier.
gpt2_tokens = tokenize(
    validation_dataset[1],
    max_length=1024,
    token_to_add=50256,
)
gpt2_tokens[0].shape

In [None]:
layer_to_ablate = 5
# Ablate a contiguous run of heads starting at head_index_to_ablate (heads 2, 3, 4).
head_index_to_ablate = 2
num_heads_to_ablate = 3

def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Zero the value vectors of the selected heads in place and return them."""
    print(f"Shape of the value tensor: {value.shape}")
    value[:, :, head_index_to_ablate:head_index_to_ablate + num_heads_to_ablate, :] = 0
    return value

# Value-activation hook for the chosen layer. The name follows the
# blocks -> attn -> hook_v layout in the module tree printed for reference_gpt2,
# and avoids needing `utils`, which is only imported in a later cell.
value_hook_name = f"blocks.{layer_to_ablate}.attn.hook_v"

# Pure inference: no_grad avoids building an autograd graph. Despite the names,
# the classifier returns class logits, not a loss.
with torch.no_grad():
    original_loss = model(gpt2_tokens[0])
    ablated_loss = model.run_with_hooks(
        gpt2_tokens[0],
        fwd_hooks=[(value_hook_name, head_ablation_hook)],
    )

original_loss, ablated_loss

In [None]:
import torch
import torch.nn as nn
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  
from transformer_lens import HookedTransformer, FactoredMatrix

In [None]:
# Stock pretrained GPT-2 small through TransformerLens. Note: this rebinds `model`
# away from the fine-tuned classifier until a later cell reloads it.
device = torch.device("cpu")
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
)


In [None]:
# Sanity check: language-modelling loss of the stock model on a sample paragraph.
# The text is model input; editing it (even its typos) changes the reported loss.
model_description_text = """## Loading Models

HookedTransformer comes loaded with >40 open source GPT-style models. You can load any of them in with `HookedTransformer.from_pretrained(MODEL_NAME)`. See my explainer for documentation of all supported models, and this table for hyper-parameters and the name used to load them. Each model is loaded into the consistent HookedTransformer architecture, designed to be clean, consistent and interpretability-friendly. 

For this demo notebook we'll look at GPT-2 Small, an 80M parameter model. To try the model the model out, let's find the loss on this paragraph!"""
loss = model(
    model_description_text,
    return_type="loss",
)
print("Model loss:", loss)

In [None]:
# Rebinds gpt2_tokens: from here on it is a token tensor, not a tokenize() tuple.
gpt2_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
gpt2_tokens = model.to_tokens(gpt2_text)
print(gpt2_tokens.device)
# remove_batch_dim=True strips the (size-1) batch axis from the cached activations.
gpt2_logits, gpt2_cache = model.run_with_cache(
    gpt2_tokens,
    remove_batch_dim=True,
)

In [None]:
layer_to_ablate = 0
head_index_to_ablate = 8

def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Zero one head's value vectors in place and return them.

    The type annotations are not required; they only document the expected shape.
    """
    print(f"Shape of the value tensor: {value.shape}")
    value[:, :, head_index_to_ablate, :] = 0.
    return value

# return_type="loss" makes both runs return a scalar language-modelling loss. The
# variable names are then accurate, and the comparison is two numbers rather than
# two full logit tensors.
original_loss = model(gpt2_tokens, return_type="loss")
ablated_loss = model.run_with_hooks(
    gpt2_tokens,
    return_type="loss",
    fwd_hooks=[(
        utils.get_act_name("v", layer_to_ablate),
        head_ablation_hook
        )]
    )
original_loss, ablated_loss

In [None]:
config = reference_gpt2.cfg
num_labels = 2
model = CustomGPT2ForSequenceClassification(config)

model_path = '../trained_models/easy_transformer_gpt2small_qqp.pth'
device = torch.device("cpu")
# map_location remaps any CUDA tensors in the checkpoint onto `device`, so a
# GPU-trained checkpoint still loads here. torch.load unpickles: trusted files only.
state_dict = torch.load(model_path, map_location=device)
# strict=False silently skips keys that do not match; report them so a missing or
# misnamed classification head does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Missing keys:", load_result.missing_keys)
    print("Unexpected keys:", load_result.unexpected_keys)
model.to(device)
model.eval()

In [None]:
import transformer_lens.utils as utils

In [None]:
layer_to_ablate = 0
head_index_to_ablate = 8

def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Zero one head's value vectors in place and return them.

    The type annotations are not required; they only document the expected shape.
    """
    print(f"Shape of the value tensor: {value.shape}")
    value[:, :, head_index_to_ablate, :] = 0.
    return value

# The fine-tuned classifier's forward accepts only token ids and returns class
# logits, so both results below are logits despite their names. no_grad avoids
# building an autograd graph for this pure inference.
with torch.no_grad():
    original_loss = model(tokens)
    ablated_loss = model.run_with_hooks(
        tokens,
        fwd_hooks=[(
            utils.get_act_name("v", layer_to_ablate),
            head_ablation_hook
            )]
        )
original_loss, ablated_loss