In [1]:
# Enable IPython's autoreload extension in mode 2, so edited project modules are
# re-imported automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import numpy as np
import os
import pandas as pd
import scipy as sp
import sys
import torch
import torch.nn.functional as F
import warnings
import random
import collections

# CD-T Imports
import math
import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import itertools

from torch import nn

warnings.filterwarnings("ignore")

base_dir = os.path.split(os.getcwd())[0]
sys.path.append(base_dir)

from argparse import Namespace
from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars
from pyfunctions.general import extractListFromDic, readJson, combine_token_attn, compute_word_intervals
from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels
from pyfunctions.cdt_basic import *
from pyfunctions.cdt_source_to_target import *
from pyfunctions.cdt_from_source_nodes import *
from pyfunctions.ioi_dataset import IOIDataset
from pyfunctions.wrappers import *
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AutoTokenizer, AutoModel
from transformers import GPT2Tokenizer, GPT2Model

In [3]:
# Switch off gradient tracking globally for every cell that follows.
torch.autograd.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f86d4056c60>

In [5]:
# Device string passed to the helpers below (encodings, attention masks, decompositions).
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first tensor transfer.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [6]:
# Model code adapted from Callum McDougall's notebook for ARENA on reproducing the IOI paper using TransformerLens.
# This makes some sense, since EasyTransformer, the repo/lib released by the IOI authors, was forked from TransformerLens.
# It also makes the reproduction a little more faithful: they most likely apply weight-processing tricks such as
# "folding" LayerNorms to improve their interpretability results, and TransformerLens exposes the same options.
# HuggingFace, by contrast, was much harder to work with here (hard-to-follow docs, outdated APIs, source code that
# is difficult to traverse), so it was abandoned early on.
#
# NOTE(review): LayerNorm folding is explicitly switched OFF below (fold_ln=False), which keeps the separate
# ln1/ln2 modules that the propagation tests further down wrap. Confirm this is the intended setting,
# since the paragraph above talks about folding.

from transformer_lens import utils, HookedTransformer, ActivationCache

# Pass `device` explicitly so the model lives on the same device as every tensor the
# notebook creates with `device=device`, instead of relying on the library's own default.
model = HookedTransformer.from_pretrained("gpt2-small",
                                          center_unembed=True,
                                          center_writing_weights=True,
                                          fold_ln=False,
                                          refactor_factored_attn_matrices=True,
                                          device=device)
                                          

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Loaded pretrained model gpt2-small into HookedTransformer


In [7]:
# Turn off the model's default beginning-of-sequence prepending. The IOI dataset further down is
# also built with prepend_bos=False, and the equivalence checks compare encodings that must agree on this.
model.cfg.default_prepend_bos = False

## Model inspection

In [None]:
# Only the last expression of a cell is echoed, so the individual settings are printed
# explicitly; as bare expressions they were evaluated and silently discarded.
print(model.cfg.default_prepend_bos)  # whether a beginning-of-sequence token is prepended by default (GPT-2 itself presumably does not do this)
print(model.tokenizer.padding_side)  # 'right'
print(model.cfg.positional_embedding_type)  # 'standard', as opposed to 'shortformer', which puts the positional embedding in the attention pattern
model.cfg

## Example forward pass

## Inspect model / hooks


In [None]:
# Example forward pass.
# The checks in the following cells read `cache[...]` and `logits` from a cached run on a single
# prompt, but nothing defined them before this point, so a fresh "Restart & Run All" stopped here
# with a NameError. The prompt is the same sentence used by the prop_GPT test below; it is tokenized
# without a BOS token so that it lines up with the encodings compared against it.
example_text = "After John and Mary went to the store, John gave a bottle of milk to"
tokens = model.to_tokens(example_text, prepend_bos=False)
logits, cache = model.run_with_cache(tokens)

# List the cached activations together with their shapes.
for activation_name, activation in cache.items():
    # Only print for first layer (plus the activations that do not belong to any block)
    if ".0." in activation_name or "blocks" not in activation_name:
        print(f"{activation_name:30} {tuple(activation.shape)}")

## Test functioning of prop_GPT

In [None]:
# Run prop_GPT once on a single prompt with two source nodes and a handful of target nodes.
# Nodes are presumably (layer, sequence_position, attention_head_idx) tuples, as in the
# batching section further down — TODO confirm.
text = "After John and Mary went to the store, John gave a bottle of milk to"
source_list = [(0, 0, 0), (1, 1, 1)]
target_nodes = [
    (7, 82, 11),
    (7, 82, 0),
    (7, 82, 6),
    (9, 82, 0),
    (9, 91, 7),
    (8, 82, 0),
]

# Tokenize the prompt and build the extended attention mask expected by the propagation code.
encoding = get_encoding(text, model.tokenizer, device)
encoding_idxs = encoding.input_ids
attention_mask = encoding.attention_mask
input_shape = encoding_idxs.size()
extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape, model, device)

# The third return value is not needed here.
out_decomps, target_decomps, _ = prop_GPT(
    encoding_idxs,
    extended_attention_mask,
    model,
    source_list,
    target_nodes,
    device=device,
    patched_values=None,
    mean_ablated=False,
)

In [None]:
# Look at the output decompositions returned by prop_GPT; there is one entry per source node.
first_decomp = out_decomps[0]
print(len(out_decomps))
print(first_decomp.rel)
# 512, d_vocab -- the effective output is this[-1], but most of this is zeroes, and is only nonzero at target node indices
print(first_decomp.rel.shape)
print(first_decomp.irrel)
# TODO: given the formula we have calculated for relevance, is there a more compact representation of the output available?

# TODO: check correctness of various basic operations:
#   - if no target node is specified, should irrel be the same as the output?
#   - should the sum over all input source nodes of the output decomposition be equal to the original output?


## Test correctness of forward pass

Given that the above code runs, the forward pass "type checks", i.e., it is at least able to run. We still have to test correctness; this is especially hard given that TL GPT has so many discrepancies from HF BERT.

In [None]:
# This loosely verifies that the out-decomposition's rel + irrel components form the output of the forward pass
compare_same(out_decomps[0].rel + out_decomps[0].irrel, out_decomps[1].rel + out_decomps[1].irrel)

# compare_same(cache['blocks.0.attn.hook_q'], cache['blocks.0.attn.hook_k']) # checking that there aren't the inputs, because the source code names them as such

In [None]:
# Verify an assumption about the model's input conventions: the residual stream entering
# block 0 should be the token embedding plus the positional embedding.
# (Disabled variant that ignores the positional part:)
# compare_same(cache['hook_embed'], cache['blocks.0.hook_resid_pre'])
in_0 = cache['blocks.0.hook_resid_pre']
token_plus_pos_embed = cache['hook_embed'] + cache['hook_pos_embed']
compare_same(token_plus_pos_embed, in_0)

# Ensure that the encodings are the same (they won't be if you set prepend_bos=True for one and not the other!)
# print(tokens)
# print(encoding.input_ids)

# Note that model.embed alone isn't the same thing as what we might normally call "the embedding"!
encoded_ids = encoding.input_ids
embedding_output = model.embed(encoded_ids) + model.pos_embed(encoded_ids)
# Only the first 16 positions of the encoding are compared against the cached run.
compare_same(in_0, embedding_output[:, :16, :])



In [None]:
# Extra analysis of the value matrix calculation: shapes of block 0's value and output weights.
attn0 = model.blocks[0].attn
print(attn0.W_V.shape)
print(attn0.W_O.shape)
# print(attn0.hook_pattern) # 1, 12, seq_len, seq_len



In [None]:
# Test equivalence of attention layer
# Re-runs block 0's attention sub-layer step by step through the decomposition helpers
# (prop_layer_norm -> prop_self_attention_hh -> prop_linear) and compares each stage
# against the activations stored in `cache`.
# NOTE(review): `einsum` is imported here but not used anywhere in this cell.
from fancy_einsum import einsum
# The string literal below is only a reference listing of cached activation shapes; it is never used.
# NOTE(review): it shows a sequence length of 17, while the slices below keep 16 positions —
# presumably recorded from a run with one extra (BOS) token; confirm.
'''
blocks.0.attn.hook_q           (1, 17, 12, 64)
blocks.0.attn.hook_k           (1, 17, 12, 64)
blocks.0.attn.hook_v           (1, 17, 12, 64)
blocks.0.attn.hook_attn_scores (1, 12, 17, 17)
blocks.0.attn.hook_pattern     (1, 12, 17, 17)
blocks.0.attn.hook_z           (1, 17, 12, 64)
'''

input_shape = encoding['input_ids'].size()

# Note: from here on `attention_mask` holds the *extended* mask, not the tokenizer's raw mask.
attention_mask = get_extended_attention_mask(encoding['attention_mask'], 
                                                        input_shape, 
                                                        model,
                                                        device)
# Start the decomposition with the whole embedding in the irrelevant part and zeros in the relevant part.
sh = list(embedding_output.shape)
rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
irrel[:] = embedding_output[:]
layer_module = model.blocks[0]
attention_module = layer_module.attn


# Propagate both parts through the first LayerNorm of block 0.
rel_ln, irrel_ln = prop_layer_norm(rel, irrel, GPTLayerNormWrapper(layer_module.ln1))
# The string literal below is a disabled check of the LayerNorm output and the value projection;
# it is kept verbatim and is not executed.
'''
print("LayerNorm: ")
compare_same(cache['blocks.0.ln1.hook_normalized'], (rel_ln + irrel_ln)[0, :16, :])

desired_value_output = cache['blocks.0.attn.hook_v']
rel_value, irrel_value = prop_linear(rel_ln, irrel_ln, GPTAttentionWrapper(attention_module, layer_module.ln2).value)
value = rel_value + irrel_value
print("Value matrix:")
# desired_value_output = desired_value_output.transpose(-1, -2)
old_shape = desired_value_output.size()
new_shape = old_shape[:-2] + (old_shape[-2] * old_shape[-1],)
desired_value_output = desired_value_output.reshape(new_shape)
print(desired_value_output.shape)
print(value.shape)

compare_same(value[0, :16, :], desired_value_output, rtol = 1e-2) # some elements differ by more than 1e-3 in proportion, not great
'''

# desired_block_output = cache['blocks.0.hook_resid_mid']

# Reference activations from the cached run: the attention sub-layer output and the attention pattern.
# NOTE(review): `desired_attention_pattern` is read here but not compared against anything in this cell.
desired_block_output = cache['blocks.0.hook_attn_out']
desired_attention_pattern = cache['blocks.0.attn.hook_pattern']
# level = 0
# source_node_list = [(0, 0, 0), (1, 1, 1)]
# target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]
attn_wrapper = GPTAttentionWrapper(attention_module)
# Decomposed self-attention; the attention probabilities are requested back via output_att_prob=True.
rel_summed_values, irrel_summed_values, returned_att_probs = prop_self_attention_hh(rel_ln, irrel_ln, attention_mask, 
                                                                        None,
                                                                        attn_wrapper,
                                                                        att_probs=None, output_att_prob=True)
print("'Z' matrix:")
# Merge the last two axes of the cached z before comparing it with the decomposed result
# (per the reference listing above these are the head and per-head dimensions).
desired_z = cache['blocks.0.attn.hook_z']
old_shape = desired_z.size()
new_shape = old_shape[:-2] + (old_shape[-2] * old_shape[-1],)
desired_z = desired_z.reshape(new_shape)
z = rel_summed_values + irrel_summed_values
compare_same(desired_z, z[:, :16, :])

# Apply the attention output projection to both parts.
rel_attn_residual, irrel_attn_residual = prop_linear(rel_summed_values, irrel_summed_values, attn_wrapper.output)

print("Output matrix:")
output = rel_attn_residual + irrel_attn_residual
# output = rel_attn_residual + irrel_attn_residual + rel_ln + irrel_ln
compare_same(output[:, :16, :], desired_block_output)


# Add the attention output back onto the (pre-LayerNorm) inputs and compare with the cached
# residual stream after the attention sub-layer.
desired_mid_output = cache['blocks.0.hook_resid_mid']
rel_mid, irrel_mid = rel + rel_attn_residual, irrel + irrel_attn_residual
mid_output = rel_mid + irrel_mid
compare_same(mid_output[:, :16, :], desired_mid_output)




In [None]:
# Test equivalence of an entire transformer block: propagate the embedding through block 0
# with the decomposition code and compare against the cached input of block 1.
in_1 = cache['blocks.1.hook_resid_pre']

input_shape = encoding['input_ids'].size()
extended_attention_mask = get_extended_attention_mask(
    encoding['attention_mask'],
    input_shape,
    model,
    device,
)

# Start with everything in the irrelevant part and nothing in the relevant part.
sh = list(embedding_output.shape)
rel = torch.zeros(sh, dtype=embedding_output.dtype, device=device)
irrel = torch.zeros(sh, dtype=embedding_output.dtype, device=device)
irrel.copy_(embedding_output)

layer_module = model.blocks[0]
source_node_list = [(0, 0, 0), (1, 1, 1)]
target_nodes = [
    (7, 82, 11),
    (7, 82, 0),
    (7, 82, 6),
    (9, 82, 0),
    (9, 91, 7),
    (8, 82, 0),
]

# Positional arguments are passed in exactly the order prop_GPT_layer_hh expects them.
rel_out, irrel_out, layer_target_decomps, _ = prop_GPT_layer_hh(
    rel,
    irrel,
    extended_attention_mask,
    None,
    source_node_list,
    target_nodes,
    0,
    None,
    layer_module,
    device,
    None,
    False,
    mean_ablated=False,
)

layer0_output = rel_out + irrel_out
compare_same(in_1, layer0_output[:, :16, :])

# Still to check:
#   - layer_target_decomps is what we think it is (currently blocked: the decomposition part is not implemented yet)
#   - rel + irrel sums to the correct output, i.e.
#     layer_target_decomps[0].rel + layer_target_decomps[0].irrel == layer_target_decomps[1].rel + layer_target_decomps[1].irrel


In [None]:
# Test equivalence of model run to logits
print(logits.shape)
model_out = out_decomps[0].rel + out_decomps[0].irrel
print(model_out.shape)
compare_same(logits[0], model_out[:16, :])

## Test correctness of batching

In [10]:
# IOIDataset is already imported in the notebook's import cell, so it is not re-imported here.

# The dataset is generated randomly and was previously unseeded, so every run produced different
# prompts (and therefore different mean activations). Seed the global RNGs so that
# "Restart & Run All" is reproducible; this is harmless if IOIDataset also seeds internally.
IOI_SEED = 0
random.seed(IOI_SEED)
np.random.seed(IOI_SEED)
torch.manual_seed(IOI_SEED)

# Generate a dataset all consisting of one template, randomly chosen.
# nb_templates = 2 due to some logic internal to IOIDataset:
# essentially, the nouns can be an ABBA or ABAB order and that counts as separate templates.
ioi_dataset = IOIDataset(prompt_type="mixed", N=3, tokenizer=model.tokenizer, prepend_bos=False, nb_templates=2)

# This is the P_ABC that is mentioned in the IOI paper, which we use for mean ablation.
# Importantly, passing in prompt_type="ABC" or similar is NOT the same thing as this.
abc_dataset = (
    ioi_dataset.gen_flipped_prompts(("IO", "RAND"))
    .gen_flipped_prompts(("S", "RAND"))
    .gen_flipped_prompts(("S1", "RAND"))
)

2024-08-30 13:20:47.277592: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [11]:
# Mean attention output per layer over the ABC dataset, used later for mean ablation.
# Note: this rebinds `logits` and `cache`, which the single-prompt checks above also use.
logits, cache = model.run_with_cache(abc_dataset.toks)  # run on entire dataset along batch dimension

# hook_attn_out holds one tensor per layer, so stacking along dim=1 adds a *layer* axis
# (not a head axis). The layer count comes from the model config instead of a hard-coded 12.
attention_outputs = torch.stack(
    [cache[f'blocks.{layer_idx}.hook_attn_out'] for layer_idx in range(model.cfg.n_layers)],
    dim=1,
)  # now batch, layer, seq, d_model
mean_acts = torch.mean(attention_outputs, dim=0)  # average over the batch -> layer, seq, d_model
mean_acts.shape


torch.Size([12, 16, 768])

In [24]:
import functools

# Encode the first IOI prompt and build the extended attention mask for it.
text = ioi_dataset.sentences[0]
encoding = model.tokenizer.encode_plus(
    text,
    add_special_tokens=True,
    max_length=512,
    truncation=True,
    padding="longest",
    return_attention_mask=True,
    return_tensors="pt",
).to(device)
encoding_idxs = encoding.input_ids
attention_mask = encoding.attention_mask
input_shape = encoding_idxs.size()
extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape, model, device)

# Source nodes are (layer, sequence_position, attention_head_idx) triples:
# layer 11 only, every sequence position, head 0 only.
ranges = [
    list(range(11, 12)),
    list(range(input_shape[-1])),
    list(range(1)),
]
source_nodes = list(itertools.product(*ranges))
# print(source_nodes[:64])
# target_nodes = [(7, 0, 1)]
target_nodes = []


def prop_fn(snl):
    """Run prop_GPT on the encoded prompt for the given list of source nodes."""
    return prop_GPT(
        encoding_idxs[0:1, :],
        extended_attention_mask,
        model,
        snl,
        target_nodes=target_nodes,
        device=device,
        mean_acts=mean_acts,
        set_irrel_to_mean=True,
    )


# Batched run over all source nodes.
batch_out_decomps, target_decomps = batch_run(prop_fn, source_nodes)

# batching is broken for now, just run one by one
# (reference run: one prop_GPT call per source node, visited in the same order as source_nodes)
iter_out_decomps = []
for layer, sequence_position, attention_head_idx in itertools.product(*ranges):
    source_node = (layer, sequence_position, attention_head_idx)
    target_nodes = []
    out_decomp, _, _ = prop_GPT(
        encoding_idxs[0:1, :],
        extended_attention_mask,
        model,
        [source_node],
        target_nodes,
        mean_acts=mean_acts,
        set_irrel_to_mean=True,
        device=device,
    )
    iter_out_decomps.append(out_decomp[0])
            


In [None]:

# Single prop_GPT call with all source nodes at once, bypassing batch_run.
manual_batch_out_decomps, _, _ = prop_GPT(
    encoding_idxs[0:1, :],
    extended_attention_mask,
    model,
    source_nodes,
    target_nodes=target_nodes,
    device=device,
    mean_acts=mean_acts,
    set_irrel_to_mean=True,
)

In [None]:
# Spot-check one entry of the first decomposition from the batched and the one-by-one runs,
# and show which source node the batched result belongs to.
batch_first, iter_first = batch_out_decomps[0], iter_out_decomps[0]
print(batch_first.rel[0][4][11])
print(iter_first.rel[0][4][11])
print(batch_first.source_node)

In [None]:
# Entry-by-entry comparison of the first batched vs. one-by-one decomposition.
# NOTE(review): head_idx indexes the last axis of `rel`; confirm that axis really is a head
# axis here (if it is a vocabulary axis, only its first 12 entries are being checked).
for seq_pos, head_idx in itertools.product(range(16), range(12)):
    print(seq_pos, head_idx)
    compare_same(
        batch_out_decomps[0].rel[0][seq_pos][head_idx],
        iter_out_decomps[0].rel[0][seq_pos][head_idx],
    )


In [25]:
# Whole-tensor check: the relevant part of the first decomposition from the batched run must match
# the one obtained by running the source nodes one at a time.
compare_same(batch_out_decomps[0].rel, iter_out_decomps[0].rel)
# Disabled variants that compare against the single manual prop_GPT call instead:
# compare_same(manual_batch_out_decomps[0].rel, iter_out_decomps[0].rel)
# compare_same(manual_batch_out_decomps[0].rel, batch_out_decomps[0].rel)



100.00% of the values are correct


In [None]:
import importlib
import pyfunctions.cdt_source_to_target

# Reload FIRST, then star-import. Names bound by `from module import *` are not updated by
# importlib.reload, so importing before the reload (as this cell used to) left the notebook
# using the stale, pre-reload functions.
importlib.reload(pyfunctions.cdt_source_to_target)
from pyfunctions.cdt_source_to_target import *


In [None]:
# Print the shape of the relevant part for each batched / one-by-one decomposition pair.
# The two lists are paired with zip instead of a hard-coded range(16), so the cell keeps working
# for prompts of any length; a length mismatch is reported up front instead of surfacing as an
# IndexError (or going unnoticed past index 15).
assert len(batch_out_decomps) == len(iter_out_decomps), (
    f"batched run produced {len(batch_out_decomps)} decompositions, "
    f"one-by-one run produced {len(iter_out_decomps)}"
)
for batch, it in zip(batch_out_decomps, iter_out_decomps):
    print(batch.rel.shape, it.rel.shape)