In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to the project-local
# packages take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import numpy as np
import os
import pandas as pd
import scipy as sp
import sys
import torch
import torch.nn.functional as F
import warnings
import random
import collections

# CD-T Imports
import math
import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import itertools

from torch import nn

# Suppress every warning raised from here on.
warnings.filterwarnings("ignore")

# Put the parent of the current working directory on the import path so the
# project-local packages imported below can be resolved.
base_dir = os.path.dirname(os.getcwd())
sys.path.append(base_dir)

from argparse import Namespace
from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars
from pyfunctions.general import extractListFromDic, readJson, combine_token_attn, compute_word_intervals
from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels
from pyfunctions.patch_hh import *
from pyfunctions.ioi_dataset import IOIDataset
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AutoTokenizer, AutoModel
from transformers import GPT2Tokenizer, GPT2Model

In [3]:
# Disable gradient tracking globally for every cell that follows.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7e56804eac60>

In [4]:
# Device string passed to the tokenization and propagation calls in later cells.
device = "cpu"

In [5]:
# Model code adapted from Callum McDougall's notebook for ARENA on reproducing the IOI paper using TransformerLens.
# This makes some sense, since EasyTransformer, the repo/lib released by the IOI authors, was forked from TransformerLens.
# In fact, this makes the reproduction a little bit more faithful, since they most likely do certain things such as
# "folding" LayerNorms to improve their interpretability results, and TransformerLens exposes the same options.
# HuggingFace's implementation was tried first but proved much harder to navigate and instrument, so it is not used.
#
# NOTE(review): fold_ln is explicitly disabled below even though LayerNorm folding is mentioned above;
# confirm this is intentional (e.g. because the propagation code handles LayerNorm itself).

from transformer_lens import utils, HookedTransformer, ActivationCache

# Load the model onto the device configured for this notebook. Without the explicit
# `device` argument the library chooses a device on its own, which can leave the model
# on a GPU while the rest of the notebook builds its inputs for `device`.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=False,
    refactor_factored_attn_matrices=True,
    device=device,
)
                                          

Loaded pretrained model gpt2-small into HookedTransformer


## Example forward pass

In [6]:
# Run a single prompt through the model, keeping the intermediate activations in `cache`.
text = "After John and Mary went to the store, John gave a bottle of milk to"
tokens = model.to_tokens(text).to(device)
logits, cache = model.run_with_cache(tokens)

# Softmax over the last (vocabulary) axis, then the argmax token at each
# position of the first (and only) prompt in the batch, decoded to strings.
probs = torch.softmax(logits, dim=-1)
most_likely_next_tokens = model.tokenizer.batch_decode(logits[0].argmax(dim=-1))

## Inspect model or hooks


In [20]:
# Show the shape of the query weights in the first block's attention layer.
# Presumably laid out as (n_heads, d_model, d_head) — confirm against TransformerLens.
print(model.blocks[0].attn.W_Q.shape)

torch.Size([12, 768, 64])


In [26]:
# List every cached activation with its shape, restricted to block 0 plus the
# hooks that live outside the transformer blocks.
for activation_name, activation in cache.items():
    # Skip per-block entries that do not belong to the first block.
    if "blocks" in activation_name and ".0." not in activation_name:
        continue
    print(f"{activation_name:30} {tuple(activation.shape)}")

hook_embed                     (1, 17, 768)
hook_pos_embed                 (1, 17, 768)
blocks.0.hook_resid_pre        (1, 17, 768)
blocks.0.ln1.hook_scale        (1, 17, 1)
blocks.0.ln1.hook_normalized   (1, 17, 768)
blocks.0.attn.hook_q           (1, 17, 12, 64)
blocks.0.attn.hook_k           (1, 17, 12, 64)
blocks.0.attn.hook_v           (1, 17, 12, 64)
blocks.0.attn.hook_attn_scores (1, 12, 17, 17)
blocks.0.attn.hook_pattern     (1, 12, 17, 17)
blocks.0.attn.hook_z           (1, 17, 12, 64)
blocks.0.hook_attn_out         (1, 17, 768)
blocks.0.hook_resid_mid        (1, 17, 768)
blocks.0.ln2.hook_scale        (1, 17, 1)
blocks.0.ln2.hook_normalized   (1, 17, 768)
blocks.0.mlp.hook_pre          (1, 17, 3072)
blocks.0.mlp.hook_post         (1, 17, 3072)
blocks.0.hook_mlp_out          (1, 17, 768)
blocks.0.hook_resid_post       (1, 17, 768)
ln_final.hook_scale            (1, 17, 1)
ln_final.hook_normalized       (1, 17, 768)


## Current tests (work in progress)

In [14]:
# Nodes are 3-tuples; presumably (layer, sequence position, head) — confirm
# against pyfunctions.patch_hh.
source_list = [(0, 0, 0), (1, 1, 1)]
target_nodes = [
    (7, 82, 11),
    (7, 82, 0),
    (7, 82, 6),
    (9, 82, 0),
    (9, 91, 7),
    (8, 82, 0),
]

text = "After John and Mary went to the store, John gave a bottle of milk to"
encoding = get_encoding(text, model.tokenizer, device)

# Propagate the decomposition without patching or mean ablation; the third
# return value is not used here.
out_decomps, target_decomps, _ = prop_GPT(
    encoding,
    model,
    source_list,
    target_nodes,
    device=device,
    patched_values=None,
    mean_ablated=False,
)

In [17]:
# Inspect the decompositions returned by the prop_GPT call.
print(len(out_decomps)) # one entry per source node
print(out_decomps[0].rel) 
print(out_decomps[0].rel.shape) # (512, d_vocab) -- the effective output is the last row ([-1]); most entries are zero, and values are nonzero only at target node indices
# TODO: given the formula derived for relevance, is a more compact representation of this output available?

# TODO: compute the relevance to the logits using this function -- or is that already what it returns?
# TODO: check correctness of basic invariants: if no target node is specified, should irrel equal the full output? Should the output decompositions, summed over all input source nodes, equal the original output?


2
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(512, 50257)


In [22]:
# test
# Candidate values for each of the three node coordinates: 0..11 for the first
# and last, and only 0 for the middle one.
pos_specific_hs = [list(range(12)), [0], list(range(12))]

# Cartesian product of the candidates: every (first, 0, last) node tuple.
all_heads = list(itertools.product(*pos_specific_hs))
target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)] # not meaningful in a GPT context

# One single-node source group per enumerated node that is not also a target.
source_list = [[head] for head in all_heads if head not in target_nodes]

text = "After John and Mary went to the store, John gave a bottle of milk to"
encoding = get_encoding(text, model.tokenizer, device)
# encoding.input_ids.shape # 512-long vector, not sure why the tokens change from EOS to 0 at some point
# embedding = model.embed(encoding.input_ids)

# Batched propagation over the source groups, without patching or mean ablation.
# Alternative (mean-ablated) arguments: patched_values=mean_act, mean_ablated=True
out_decomps, target_decomps = prop_model_hh_batched(
    encoding,
    model,
    source_list,
    target_nodes,
    device=device,
    patched_values=None,
    mean_ablated=False,
    num_at_time=1,
)
                                                                

## Generate mean activations