In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import numpy as np
import os
import pandas as pd
import scipy as sp
import sys
import torch
import torch.nn.functional as F
import warnings
import random
import collections

# CD-T Imports
import math
import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import itertools

from torch import nn

warnings.filterwarnings("ignore")

base_dir = os.path.split(os.getcwd())[0]
sys.path.append(base_dir)

from argparse import Namespace
from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars
from pyfunctions.general import extractListFromDic, readJson, combine_token_attn, compute_word_intervals
from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels
from pyfunctions.cdt_basic import *
from pyfunctions.ioi_dataset import IOIDataset
from pyfunctions.wrappers import *
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AutoTokenizer, AutoModel
from transformers import GPT2Tokenizer, GPT2Model

In [3]:
torch.autograd.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x71046f237800>

In [4]:
device = 'cpu'

In [5]:
# Model code adapted from Callum McDougall's notebook for ARENA on reproducing the IOI paper using TransformerLens.
# This makes some sense, since EasyTransformer, the repo/lib released by the IOI guys, was forked from TransformerLens.
# In fact, this makes the reproduction a little bit more faithful, since they most likely do certain things such as 
# "folding" LayerNorms to improve their interpretability results, and we are able to do the same by using TransformerLens.
# HuggingFace, by contrast, has the most impenetrable docs and tons of outdated APIs and etc.; even their source 
# code is impossible to traverse, and I gave up on it, thankfully quickly.

from transformer_lens import utils, HookedTransformer, ActivationCache
model = HookedTransformer.from_pretrained("gpt2-small",
                                          center_unembed=True,
                                          center_writing_weights=True,
                                          fold_ln=False,
                                          refactor_factored_attn_matrices=True)
                                          

Loaded pretrained model gpt2-small into HookedTransformer


In [6]:
model.cfg.default_prepend_bos = False

## Model inspection

In [7]:
model.cfg.default_prepend_bos # prepends a beginning-of-sequence token? but GPT2 doesn't do this, I think?
model.tokenizer.padding_side # 'right'
model.cfg.positional_embedding_type # 'standard', as opposed to 'shortformer', which puts the positional embedding in the attention pattern
model.cfg

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': False,
 'device': device(type='cpu'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LN',
 'num_experts': None,
 'original_architecture': 'GPT2LMHeadModel',
 'output_logits_soft_cap': -1.0,
 'parallel_attn_mlp': False,
 'positional

## Example forward pass

In [7]:
text = "After John and Mary went to the store, John gave a bottle of milk to"
tokens = model.to_tokens(text, padding_side='right', truncate=False).to(device)
print(tokens.shape)
logits, cache = model.run_with_cache(tokens)
probs = logits.softmax(dim=-1)
most_likely_next_tokens = model.tokenizer.batch_decode(logits.argmax(dim=-1)[0])

torch.Size([1, 16])


## Inspect model / hooks


In [9]:
dir(model)

['OV',
 'QK',
 'T_destination',
 'W_E',
 'W_E_pos',
 'W_K',
 'W_O',
 'W_Q',
 'W_U',
 'W_V',
 'W_gate',
 'W_in',
 'W_out',
 'W_pos',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_compiled_call_impl',
 '_enable_hook',
 '_enable_hook_with_name',
 '_enable_hooks_for_points',
 '_forward_hooks',
 '_forward_hooks_always_called',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_init_weights_gpt2',
 '_in

In [78]:
print(dir(model.blocks[0]))

['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_state_dict_pre_hoo

In [11]:
for activation_name, activation in cache.items():
    # Only print for first layer
    if ".0." in activation_name or "blocks" not in activation_name:
        print(f"{activation_name:30} {tuple(activation.shape)}")

hook_embed                     (1, 16, 768)
hook_pos_embed                 (1, 16, 768)
blocks.0.hook_resid_pre        (1, 16, 768)
blocks.0.ln1.hook_scale        (1, 16, 1)
blocks.0.ln1.hook_normalized   (1, 16, 768)
blocks.0.attn.hook_q           (1, 16, 12, 64)
blocks.0.attn.hook_k           (1, 16, 12, 64)
blocks.0.attn.hook_v           (1, 16, 12, 64)
blocks.0.attn.hook_attn_scores (1, 12, 16, 16)
blocks.0.attn.hook_pattern     (1, 12, 16, 16)
blocks.0.attn.hook_z           (1, 16, 12, 64)
blocks.0.hook_attn_out         (1, 16, 768)
blocks.0.hook_resid_mid        (1, 16, 768)
blocks.0.ln2.hook_scale        (1, 16, 1)
blocks.0.ln2.hook_normalized   (1, 16, 768)
blocks.0.mlp.hook_pre          (1, 16, 3072)
blocks.0.mlp.hook_post         (1, 16, 3072)
blocks.0.hook_mlp_out          (1, 16, 768)
blocks.0.hook_resid_post       (1, 16, 768)
ln_final.hook_scale            (1, 16, 1)
ln_final.hook_normalized       (1, 16, 768)


## Test functioning of prop_GPT

In [8]:
source_list = [(0, 0, 0), (1, 1, 1)]
target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]

text = "After John and Mary went to the store, John gave a bottle of milk to"
encoding = get_encoding(text, model.tokenizer, device)
encoding_idxs, attention_mask = encoding.input_ids, encoding.attention_mask
input_shape = encoding_idxs.size()
extended_attention_mask = get_extended_attention_mask(attention_mask, 
                                                        input_shape, 
                                                        model,
                                                        device)
out_decomps, target_decomps, _ = prop_GPT(encoding_idxs, extended_attention_mask, model, source_list, target_nodes, device=device, patched_values=None, mean_ablated=False)

In [None]:
print(len(out_decomps)) # one per source node
print(out_decomps[0].rel) 
print(out_decomps[0].rel.shape) # 512, d_vocab-- the effective output is this[-1], but most of this is zeroes, and is only nonzero at target node indices
print(out_decomps[0].irrel)
# TODO: given the formula we have calculated for relevance, is there a more compact representation of the output available?

# TODO: check correctness of various basic operations: if no target node specified, should irrel be the same as output? should sum over all input source nodes of output decomposition be equal to original output?


2
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(512, 50257)
[[ 5.5012274  8.864304   3.4959507 ... -4.0528517 -2.6054635  7.205061 ]
 [ 5.5009346  8.865846   3.4954956 ... -4.053485  -2.6107361  7.206545 ]
 [ 5.499007   8.864225   3.4940565 ... -4.0518312 -2.6099834  7.2044487]
 ...
 [ 5.4935718  8.862474   3.4887767 ... -4.0536976 -2.6203256  7.198222 ]
 [ 5.4935718  8.862474   3.4887767 ... -4.0536976 -2.6203256  7.198222 ]
 [ 5.4935718  8.862474   3.4887767 ... -4.0536976 -2.6203256  7.198222 ]]


## Test correctness of forward pass

Given that the above code runs, the forward pass "type checks", e.g, it's at least able to run. We still have to test correctness; this is especially hard given that TL GPT has so many discrepancies from HF BERT.

In [9]:
def compare_same(a, b, atol=1e-4, rtol=1e-3):
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        comparison = torch.isclose(a, b, atol, rtol)
        print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
        return
    comparison = np.isclose(a, b, atol, rtol)
    print(f"{comparison.sum()/comparison.size:.2%} of the values are correct")

In [10]:
# This loosely verifies that the out-decomposition's rel + irrel components form the output of the forward pass
compare_same(out_decomps[0].rel + out_decomps[0].irrel, out_decomps[1].rel + out_decomps[1].irrel)

# compare_same(cache['blocks.0.attn.hook_q'], cache['blocks.0.attn.hook_k']) # checking that there aren't the inputs, because the source code names them as such

100.00% of the values are correct


In [11]:
# compare_same(cache['hook_embed'], cache['blocks.0.hook_resid_pre'])
# This just verifies an assumption about the input conventions to the model.
in_0 = cache['blocks.0.hook_resid_pre']
compare_same(cache['hook_embed'] + cache['hook_pos_embed'], cache['blocks.0.hook_resid_pre'])

# Ensure that the encodings are the same (they won't be if you set prepend_bos=True for one and not the other!)
# print(tokens)
# print(encoding.input_ids)

# Note that model.embed alone isn't the same thing as what we might normally call "the embedding"!
embedding_output = model.embed(encoding.input_ids) + model.pos_embed(encoding.input_ids)
compare_same(in_0, embedding_output[:, :16, :])



100.00% of the values are correct
100.00% of the values are correct


In [12]:
# Extra analysis of value matrix calcuation
print(model.blocks[0].attn.W_V.shape)
print(model.blocks[0].attn.W_O.shape)
# print(model.blocks[0].attn.hook_pattern) # 1, 12, seq_len, seq_len



torch.Size([12, 768, 64])
torch.Size([12, 64, 768])


In [13]:
# Test equivalence of attention layer
from fancy_einsum import einsum
'''
blocks.0.attn.hook_q           (1, 17, 12, 64)
blocks.0.attn.hook_k           (1, 17, 12, 64)
blocks.0.attn.hook_v           (1, 17, 12, 64)
blocks.0.attn.hook_attn_scores (1, 12, 17, 17)
blocks.0.attn.hook_pattern     (1, 12, 17, 17)
blocks.0.attn.hook_z           (1, 17, 12, 64)
'''

input_shape = encoding['input_ids'].size()

attention_mask = get_extended_attention_mask(encoding['attention_mask'], 
                                                        input_shape, 
                                                        model,
                                                        device)
sh = list(embedding_output.shape)
rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
irrel[:] = embedding_output[:]
layer_module = model.blocks[0]
attention_module = layer_module.attn


rel_ln, irrel_ln = prop_layer_norm(rel, irrel, GPTLayerNormWrapper(layer_module.ln1))
'''
print("LayerNorm: ")
compare_same(cache['blocks.0.ln1.hook_normalized'], (rel_ln + irrel_ln)[0, :16, :])

desired_value_output = cache['blocks.0.attn.hook_v']
rel_value, irrel_value = prop_linear(rel_ln, irrel_ln, GPTAttentionWrapper(attention_module, layer_module.ln2).value)
value = rel_value + irrel_value
print("Value matrix:")
# desired_value_output = desired_value_output.transpose(-1, -2)
old_shape = desired_value_output.size()
new_shape = old_shape[:-2] + (old_shape[-2] * old_shape[-1],)
desired_value_output = desired_value_output.reshape(new_shape)
print(desired_value_output.shape)
print(value.shape)

compare_same(value[0, :16, :], desired_value_output, rtol = 1e-2) # some elements differ by more than 1e-3 in proportion, not great
'''

# desired_block_output = cache['blocks.0.hook_resid_mid']

desired_block_output = cache['blocks.0.hook_attn_out']
desired_attention_pattern = cache['blocks.0.attn.hook_pattern']
# level = 0
# source_node_list = [(0, 0, 0), (1, 1, 1)]
# target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]
attn_wrapper = GPTAttentionWrapper(attention_module)
rel_summed_values, irrel_summed_values, returned_att_probs = prop_self_attention_hh(rel_ln, irrel_ln, attention_mask, 
                                                                        None,
                                                                        attn_wrapper,
                                                                        att_probs=None, output_att_prob=True)
print("'Z' matrix:")
desired_z = cache['blocks.0.attn.hook_z']
old_shape = desired_z.size()
new_shape = old_shape[:-2] + (old_shape[-2] * old_shape[-1],)
desired_z = desired_z.reshape(new_shape)
z = rel_summed_values + irrel_summed_values
compare_same(desired_z, z[:, :16, :])

rel_attn_residual, irrel_attn_residual = prop_linear(rel_summed_values, irrel_summed_values, attn_wrapper.output)

print("Output matrix:")
output = rel_attn_residual + irrel_attn_residual
# output = rel_attn_residual + irrel_attn_residual + rel_ln + irrel_ln
compare_same(output[:, :16, :], desired_block_output)


desired_mid_output = cache['blocks.0.hook_resid_mid']
rel_mid, irrel_mid = rel + rel_attn_residual, irrel + irrel_attn_residual
mid_output = rel_mid + irrel_mid
compare_same(mid_output[:, :16, :], desired_mid_output)




'Z' matrix:
99.71% of the values are correct
Output matrix:
99.89% of the values are correct
99.89% of the values are correct


In [14]:
# Test equivalence of entire transformer block
in_1 = cache['blocks.1.hook_resid_pre']
input_shape = encoding['input_ids'].size()

extended_attention_mask = get_extended_attention_mask(encoding['attention_mask'], 
                                                        input_shape, 
                                                        model,
                                                        device)
sh = list(embedding_output.shape)
rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
irrel[:] = embedding_output[:]
layer_module = model.blocks[0]

source_node_list = [(0, 0, 0), (1, 1, 1)]
target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]
rel_out, irrel_out, layer_target_decomps, _ = prop_GPT_layer_hh(rel, irrel, extended_attention_mask, 
                                                                                 None, source_node_list, 
                                                                                 target_nodes, 0, 
                                                                                 None,
                                                                                 layer_module, 
                                                                                 device,
                                                                                 None, False,
                                                                                 mean_ablated=False)

layer0_output = rel_out + irrel_out
compare_same(in_1, layer0_output[:, :16, :])

# want to check: layer_target_decomps is what we think it is (currently gated by, i didn't implement the decomposition part)
# also, rel + irrel sums to the correct output
# layer_target_decomps[0].rel + layer_target_decomps[0].irrel == layer_target_decomps[1].rel + layer_target_decomps[1].irrel


99.97% of the values are correct


In [17]:
# Test equivalence of model run to logits
print(logits.shape)
model_out = out_decomps[0].rel + out_decomps[0].irrel
print(model_out.shape)
compare_same(logits[0], model_out[:16, :])

torch.Size([1, 16, 50257])
(512, 50257)
100.00% of the values are correct


: 

## Generate mean activations