# Installs

In [None]:
!pip install git+https://github.com/andres-vs/TransformerLens.git@b0de195fa5a0f427427e142e9a7066f47bf193f9
!pip install datasets --upgrade

# Set-up

In [None]:
import torch
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# SECURITY: a Hugging Face access token used to be hard-coded in this cell and has
# been published with the notebook -- treat it as leaked and revoke it at
# https://huggingface.co/settings/tokens.
# The token is now read from the HF_TOKEN environment variable, falling back to an
# interactive prompt, so it never ends up in the saved notebook.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=hf_token, add_to_git_credential=True)

In [None]:
# Which fine-tuned model / dataset variant to use (names are built in the next cell).
DEPTH = 1          # number appended to the dataset and model names
QDep = False       # True -> "-QDep{DEPTH}-NoRconc" variants, False -> "-depth{DEPTH}"
RETRAINED = False  # True -> load the "_retrained" checkpoint
BATCH_SIZE = 64    # DataLoader batch size
ALL_HEADS = True   # NOTE(review): not read by any of the visible cells

# Test-split filtering: keep everything, or only examples whose "depth" equals
# PROOF_DEPTH and whose "proof_strategy" equals PROOF_STRATEGY.
ALL_EXAMPLES = False
PROOF_DEPTH = 1
PROOF_STRATEGY = "proof"

tokenizer_name = "bert-base-uncased"
model_name = "andres-vs/bert-base-uncased-finetuned_Att-Noneg"  # suffixed in the next cell

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
# Derive the Hub dataset name and the checkpoint name from the config cell.
# The checkpoint suffix is only appended when model_name does not already end
# with it, so re-running this cell with unchanged settings no longer produces
# names such as "...-depth1-depth1".
if QDep:
    dataset_name = f"andres-vs/ruletaker-Att-Noneg-QDep{DEPTH}-NoRconc"
    model_suffix = f"-QDep{DEPTH}-NoRconc" + ("_retrained" if RETRAINED else "")
else:
    dataset_name = f"andres-vs/ruletaker-Att-Noneg-depth{DEPTH}"
    # NOTE(review): the retrained depth checkpoint ends in "_retrained-1" while the
    # QDep one ends in "_retrained" -- confirm both match the published repo names.
    model_suffix = f"-depth{DEPTH}" + ("_retrained-1" if RETRAINED else "")

if not model_name.endswith(model_suffix):
    model_name = model_name + model_suffix

In [None]:
# Show which checkpoint and dataset this run will use, one per line.
print(model_name, dataset_name, sep="\n")

# Load dataset and preprocessing

In [None]:
# Tokenizer of the base checkpoint and the selected dataset, both fetched from the Hub.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_name)
dataset = load_dataset(path=dataset_name)

In [None]:
def _matches_proof_filter(example):
    """Keep examples whose 'depth' is PROOF_DEPTH and 'proof_strategy' is PROOF_STRATEGY."""
    return example['depth'] == PROOF_DEPTH and example['proof_strategy'] == PROOF_STRATEGY


# Use the whole test split, or only the selected proof subset of it.
filtered_dataset = (
    dataset['test'] if ALL_EXAMPLES else dataset['test'].filter(_matches_proof_filter)
)

In [None]:
def tokenize_function(examples):
    """Tokenize a batch of examples and one-hot encode their labels.

    Meant for ``Dataset.map(..., batched=True)``: ``examples`` maps column names
    to lists and must contain "input" (text) and "label" (class ids in {0, 1}).
    No padding is applied here; the collate function pads each batch dynamically.
    """
    encoded = tokenizer(examples["input"], truncation=True, padding=False)
    label_ids = torch.tensor(examples['label'], dtype=torch.int64)
    # e.g. 1 -> [0.0, 1.0]; stored as plain lists so the datasets library can write them.
    encoded['label'] = torch.nn.functional.one_hot(label_ids, num_classes=2).float().tolist()
    return encoded

In [None]:
tokenized_dataset = filtered_dataset.map(function=tokenize_function, batched=True)

# Return only these columns, as PyTorch tensors, from __getitem__. The raw text and
# metadata columns are hidden, so the collate function only sees tokenizer outputs
# and labels.
tokenized_dataset.set_format(
    type='torch',
    columns=['input_ids', 'attention_mask', 'label', 'token_type_ids'],
)

In [None]:
from torch.utils.data import DataLoader


def collate_fn(batch):
    """Merge a list of tokenized examples into one dynamically padded batch.

    ``batch`` is a list of per-example dicts from ``tokenized_dataset``. The values
    of each key are gathered into a list, and ``tokenizer.pad`` pads every sequence
    to the longest one in the batch, returning PyTorch tensors.
    """
    column_names = batch[0].keys()
    per_column = {name: [row[name] for row in batch] for name in column_names}
    return tokenizer.pad(per_column, padding="longest", return_tensors="pt")


# Batches come out in dataset order (DataLoader does not shuffle by default).
dataloader = DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

# Load the model and set up hooks

In [None]:
from transformer_lens import HookedEncoder

# Load the fine-tuned checkpoint with a classification head. head_type is
# presumably an option of the pinned TransformerLens fork installed above.
model = HookedEncoder.from_pretrained(
    model_name,
    tokenizer=tokenizer,
    head_type='classification',
)
# Emit per-head attention outputs through the attn.hook_result hook points.
model.cfg.use_attn_result = True
model.to(device)

In [None]:
def extract_layer_number_and_head_number(tup):
    """Return ``(layer, head)`` for a ``(hook_name, head)`` pair, e.g. as a sort key.

    The layer is the integer in the second dot-separated field of the hook name:
    ``("blocks.3.attn.hook_result", 11) -> (3, 11)``.
    """
    hook_name = tup[0]
    layer_field = hook_name.split('.', 2)[1]
    return int(layer_field), tup[1]

def extract_layer_number(hook_name):
    """Return the block index from a hook name such as ``"blocks.3.attn.hook_z"`` (-> 3).

    The index is the integer in the second dot-separated field, so names like
    ``"blocks.3.hook_resid_pre"`` work too. Raises IndexError when the name has no
    '.' and ValueError when that field is not an integer.
    """
    fields = hook_name.split('.')
    return int(fields[1])

In [None]:
# Module-level buffers filled by the forward hooks below. The extraction loop
# rebinds these names to fresh lists for every batch. Each hook looks its buffer
# up by name when it is called, so it always appends to the current list.
(
    pre_attention_layer_residual,
    attention_layer_output,
    post_attention_layer_residual,
    attention_layer_result,
    attention_layer_z_score,
    attention_layer_value_vectors,
) = ([], [], [], [], [], [])


def pre_attention_layer_residual_hook_fn(module_output, hook=None):
    """Forward hook: record the activation (detached) in pre_attention_layer_residual."""
    detached = module_output.detach()
    pre_attention_layer_residual.append(detached)


def attention_layer_output_hook_fn(module_output, hook=None):
    """Forward hook: record the activation (detached) in attention_layer_output."""
    detached = module_output.detach()
    attention_layer_output.append(detached)


def post_attention_layer_residual_hook_fn(module_output, hook=None):
    """Forward hook: record the activation (detached) in post_attention_layer_residual."""
    detached = module_output.detach()
    post_attention_layer_residual.append(detached)


def attention_layer_result_hook_fn(module_output, hook=None):
    """Forward hook: record the activation (detached) in attention_layer_result."""
    detached = module_output.detach()
    attention_layer_result.append(detached)


def attention_layer_z_score_hook_fn(module_output, hook=None):
    """Forward hook: record the activation (detached) in attention_layer_z_score.

    It is attached to TransformerLens' ``attn.hook_z`` (per-head attention-weighted
    values), not to a statistical z-score.
    """
    detached = module_output.detach()
    attention_layer_z_score.append(detached)


def attention_layer_value_vectors_hook_fn(module_output, hook=None):
    """Forward hook: record the activation (detached) in attention_layer_value_vectors."""
    detached = module_output.detach()
    attention_layer_value_vectors.append(detached)

# Extract and save attention-layer activations

In [None]:
NUMBER_OF_EXAMPLES = len(filtered_dataset)

# Block (LAYER) and head (HEAD_NUMBER) whose activations are extracted.
LAYER = 3
HEAD_NUMBER = 11

# (hook name, hook function) pairs for model.hooks(fwd_hooks=...): the residual
# stream before and after the attention sub-layer, the attention output, and the
# per-head result / z / value activations of block LAYER.
HOOKS_TO_MONITOR = [
    (f"blocks.{LAYER}.hook_resid_pre", pre_attention_layer_residual_hook_fn),
    (f"blocks.{LAYER}.hook_resid_mid", post_attention_layer_residual_hook_fn),
    (f"blocks.{LAYER}.hook_attn_out", attention_layer_output_hook_fn),
    (f"blocks.{LAYER}.attn.hook_result", attention_layer_result_hook_fn),
    (f"blocks.{LAYER}.attn.hook_z", attention_layer_z_score_hook_fn),
    (f"blocks.{LAYER}.attn.hook_v", attention_layer_value_vectors_hook_fn),
]

In [None]:
import os

# Directory that receives one activation file per processed batch.
folder_path = "./extracted_activations"

# exist_ok=True makes this a no-op when the folder already exists, without the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(folder_path, exist_ok=True)

In [None]:
# Run the model over the test batches with the capture hooks attached. For each
# batch, save the block-LAYER activations (head HEAD_NUMBER for the per-head
# tensors) to one .pt file.
MAX_BATCHES = 1  # at least one batch is processed; set to None to process every batch


def _select_head(activation):
    """Return head HEAD_NUMBER of a (batch, pos, head, ...) activation as a compact CPU tensor."""
    # torch.save writes a view's entire underlying storage. Without .contiguous(),
    # every file would contain all heads rather than just the selected one.
    return activation[:, :, HEAD_NUMBER, :].contiguous().cpu()


with model.hooks(fwd_hooks=HOOKS_TO_MONITOR), torch.no_grad():
    for batch_num, batch in enumerate(tqdm(dataloader)):
        # Start each batch with empty buffers. The hook functions look these
        # global names up at call time, so rebinding them is enough.
        pre_attention_layer_residual = []
        attention_layer_output = []
        post_attention_layer_residual = []
        attention_layer_result = []
        attention_layer_z_score = []
        attention_layer_value_vectors = []

        model(
            input=batch['input_ids'].to(device),
            one_zero_attention_mask=batch['attention_mask'].to(device),
        )

        output_file = os.path.join(
            folder_path,
            f"extracted_activations_depth{DEPTH}_batch{batch_num}_head_a{LAYER}h{HEAD_NUMBER}.pt",
        )
        # Tensors are moved to the CPU so the files load without a GPU (no map_location needed).
        torch.save(
            {
                "pre_attention_layer_residual": pre_attention_layer_residual[0].cpu(),
                "post_attention_layer_residual": post_attention_layer_residual[0].cpu(),
                "attention_layer_output": attention_layer_output[0].cpu(),
                "attention_layer_result": _select_head(attention_layer_result[0]),
                "attention_layer_z_score": _select_head(attention_layer_z_score[0]),
                "attention_layer_value_vectors": _select_head(attention_layer_value_vectors[0]),
            },
            output_file,
        )

        if MAX_BATCHES is not None and batch_num + 1 >= MAX_BATCHES:
            break


  0%|          | 0/11 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  0%|          | 0/11 [00:06<?, ?it/s]


In [None]:
# Shapes of the activations captured for the last processed batch.
for captured in (
    attention_layer_output,
    pre_attention_layer_residual,
    post_attention_layer_residual,
    attention_layer_result,
    attention_layer_z_score,
    attention_layer_value_vectors,
):
    print(captured[0].shape)

torch.Size([64, 142, 768])
torch.Size([64, 142, 768])
torch.Size([64, 142, 768])
torch.Size([64, 142, 12, 768])
torch.Size([64, 142, 12, 64])
torch.Size([64, 142, 12, 64])
