# Purpose

This notebook contains a collection of exploratory cells I used to examine GPT-2 small (through both Hugging Face and TransformerLens), the IOI dataset, and some alternative approaches.
Most of it is not needed to reproduce the experiments.

In [None]:
# Reload edited project modules (e.g. pyfunctions.*) automatically before each cell runs
%load_ext autoreload
%autoreload 2

In [None]:
import argparse

import numpy as np
import os
import pandas as pd
import scipy as sp
import sys
import torch
import torch.nn.functional as F
import warnings
import random
import collections

# CD-T Imports
import math
import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import itertools

from torch import nn

# Silence all warnings for the rest of the notebook
warnings.filterwarnings("ignore")

# Make the parent of the notebook's working directory importable (methods.*, pyfunctions.*)
base_dir = os.path.dirname(os.getcwd())
sys.path.append(base_dir)

from argparse import Namespace
from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars
from pyfunctions.general import extractListFromDic, readJson, combine_token_attn, compute_word_intervals
from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels
from pyfunctions.cdt_basic import *
from pyfunctions.ioi_dataset import IOIDataset
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AutoTokenizer, AutoModel
from transformers import GPT2Tokenizer, GPT2Model

In [None]:
# Inference only: turn autograd off globally so forward passes do not build graphs
torch.set_grad_enabled(False)

## Load Model

Note: unlike the BERT model on the medical dataset, GPT-2 does not need to be trained or fine-tuned for the IOI task.
GPT-2 small can already perform IOI; that is part of the point of the "Interpretability in the Wild" paper (Wang et al., 2022).
We only need to examine how it does so.

In [None]:
device = "cpu"  # target device used when moving input tensors in later cells

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face GPT-2 small (with its causal-LM head) and its tokenizer, from the same checkpoint
hf_checkpoint = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(hf_checkpoint)
model = AutoModelForCausalLM.from_pretrained(hf_checkpoint)

In [None]:
# Module tree of the Hugging Face model (nn.Module's str() is its repr())
print(repr(model))

In [None]:
# Sanity check: generate a continuation with the Hugging Face GPT-2 model and decode it.
text = "Replace me by any text you'd"
encoded_input = tokenizer(text, return_tensors='pt')  # has 'input_ids' and 'attention_mask'
# GPT-2 defines no pad token (tokenizer.pad_token_id is None), so reuse EOS as the pad id, and pass
# the attention mask explicitly instead of letting generate() infer it.
# (output_scores=True was dropped: without return_dict_in_generate=True the scores are not returned.)
gen_tokens = model.generate(
    encoded_input.input_ids,
    attention_mask=encoded_input.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
)
print(gen_tokens)
gen_text = tokenizer.batch_decode(gen_tokens)
gen_text

In [None]:
# Other exploratory notes (kept as comments; nothing here runs):
# model.to_tokens(text)  # only exists on TransformerLens models (transformer_lens), not on the HF model above
# print(output.past_key_values[0][0].shape)  # presumably layer 0's cached keys; `output` is from the commented-out forward pass in the previous cell
#print(output.values())
#output.logits

In [None]:
!pip install datasets

In [None]:
from datasets import load_dataset

# Hugging Face copy of the IOI prompts. Mostly unused: the raw IOIDataset class from the paper
# (pyfunctions.ioi_dataset) and the related EasyTransformer notebook provide many more utilities
# for working with the data. Named distinctly so it is not confused with the IOIDataset instance
# that later cells bind to `ioi_dataset`.
ioi_hf_dataset = load_dataset("fahamu/ioi")

In [None]:
!pip install transformer_lens
!pip install einops

In [None]:
# Model code adapted from Callum McDougall's ARENA notebook on reproducing the IOI paper with TransformerLens.
# EasyTransformer, the library released by the IOI authors, was forked from TransformerLens. Loading the
# model through TransformerLens therefore keeps this reproduction close to theirs, because it exposes the
# same weight-processing options (centering the unembedding and the writing weights, folding LayerNorms,
# refactoring factored attention matrices). The Hugging Face source proved much harder to navigate for this.
#
# NOTE(review): the paragraph above mentions folding LayerNorms, but fold_ln=False below, so LayerNorm
# weights are NOT folded. Presumably this is deliberate (e.g. the CD-T code needs explicit LN parameters);
# confirm.
from transformer_lens import utils, HookedTransformer, ActivationCache
model = HookedTransformer.from_pretrained("gpt2-small",
                                          center_unembed=True,
                                          center_writing_weights=True,
                                          fold_ln=False,
                                          refactor_factored_attn_matrices=True,
                                          # Put the model on the notebook's `device`. Otherwise TransformerLens
                                          # picks its own default (it prefers a GPU when available), which can
                                          # disagree with the `.to(device)` calls on inputs in later cells.
                                          device=device)

## Example forward pass

In [None]:
# Run the IOI example sentence through the model, caching every intermediate activation
text = "After John and Mary went to the store, John gave a bottle of milk to"
tokens = model.to_tokens(text).to(device)
logits, cache = model.run_with_cache(tokens)
probs = torch.softmax(logits, dim=-1)
# Greedy next-token prediction at every position of the single prompt, decoded to strings
top_token_ids = logits[0].argmax(dim=-1)
most_likely_next_tokens = model.tokenizer.batch_decode(top_token_ids)

In [None]:
# List cached activation shapes: every non-block hook, plus the hooks of layer 0 only
for hook_name, act in cache.items():
    is_layer_zero = ".0." in hook_name
    is_outside_blocks = "blocks" not in hook_name
    if is_layer_zero or is_outside_blocks:
        print(f"{hook_name:30} {tuple(act.shape)}")

In [None]:
# hack to get model dtype out, for compatibility with other code
next(model.parameters()).dtype

In [None]:
# Module tree of the HookedTransformer (nn.Module's str() is its repr())
print(repr(model))
# Things that did not work here:
# - print(model.config): Hugging Face only, not available on HookedTransformer
# - print(model.embed.dtype): likewise not available; use the parameter-dtype workaround instead
# print(type(model))
#model.state_dict().keys()#.blocks[0].mlp

In [None]:
import inspect
# Method resolution order of the model's class (what HookedTransformer inherits from).
# inspect.getclasstree(inspect.getmro(type(model)))
type(model).__mro__

In [None]:
try:
    import torchviz  # optional and not installed by this notebook; only needed for the make_dot hint below
except ImportError:
    torchviz = None
# Graph-visualization hint: torchviz.make_dot expects an output tensor (e.g. the `logits` of a forward pass),
# not the module itself, and it needs autograd enabled (this notebook disables it globally):
# torchviz.make_dot(logits, params=dict(model.named_parameters()))
# model._modules
dir(model)

In [None]:
!pip install torchsummary
!pip install torchinfo

In [None]:
import pdb  # for ad-hoc debugging only; do not leave pdb.set_trace() breakpoints in cells
from torchinfo import summary

# Layer-by-layer summary of the HookedTransformer on the IOI example sentence.
text = "After John and Mary went to the store, John gave a bottle of milk to"
encoding = get_encoding(text, model.tokenizer, device)
# embedding_output = model.embed(encoding.input_ids)
input_shape = encoding.input_ids.shape
print(input_shape)
# Pass the real token ids rather than only a shape: given just `input_size`, torchinfo builds a
# random floating-point input, which is not valid input for the token embedding.
summary(model, input_data=encoding.input_ids, device=device)


In [None]:
# Same as in the notebook, example
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = "Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)o

## Generate dataset/Explore types

In [None]:
# 500 IOI prompts with the "ABBA" prompt type, to inspect the per-prompt records.
data = IOIDataset(N=500, prompt_type="ABBA", tokenizer=model.tokenizer)
#data.tokenized_prompts
# Only a cell's last expression is auto-displayed, so show the first prompt record explicitly.
display(data.ioi_prompts[0])
# Template index of each of the first 10 prompts
[x['TEMPLATE_IDX'] for x in data.ioi_prompts[0:10]]

In [None]:
# test
pos_specific_hs = [
        [i for i in range(12)],
        [0],
        [i for i in range(12)]
    ]
all_heads = list(itertools.product(*pos_specific_hs))
target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)] # not meaningful in a GPT context
source_list = [[node] for node in all_heads if node not in target_nodes]

text = "After John and Mary went to the store, John gave a bottle of milk to"
encoding = get_encoding(text, model.tokenizer, device)
# encoding.input_ids.shape # 512-long vector, not sure why the tokens change from EOS to 0 at some point
# embedding = model.embed(encoding.input_ids)

out_decomps, target_decomps = prop_model_hh_batched(encoding, model, source_list, target_nodes,
                                                                   device=device,
                                                                   patched_values=None, mean_ablated=False, num_at_time=1)
                                                                   # patched_values=mean_act, mean_ablated=True)
                                                                

## Explore IOI Dataset


In [None]:
# IOIDataset is already imported in the top import cell.
# 50 prompts of prompt_type "mixed" (presumably a mix of ABBA and BABA templates), without a prepended BOS token.
ioi_dataset = IOIDataset(prompt_type="mixed", N=50, tokenizer=model.tokenizer, prepend_bos=False)

In [None]:
# Only a cell's last expression is auto-displayed, so show the intermediate values explicitly.
display(ioi_dataset.toks.shape, ioi_dataset.word_idx, ioi_dataset.sentences[:4])

ioi_dataset.groups

# ioi_dataset.toks[ioi_dataset.groups[-1]]
# Sentences of the same group are identical except for the choice of nouns:
# [ioi_dataset.sentences[x] for x in ioi_dataset.groups[3]]

In [None]:
# Do NOT build the ABC dataset this way:
#   abc_dataset = IOIDataset(prompt_type="ABC mixed", N=50, tokenizer=model.tokenizer, prepend_bos=False)
# Its sentences do not follow the format described in the paper, and it is not what the authors'
# experiments.py does.
#
# Instead, derive it from `ioi_dataset` by successively flipping the IO, S and S1 names to random
# names (apparently the approach experiments.py uses). Generating several of these in a row gives
# different random names each time, which is useful for a quick mean ablation.
abc_dataset = ioi_dataset
for flip in (("IO", "RAND"), ("S", "RAND"), ("S1", "RAND")):
    abc_dataset = abc_dataset.gen_flipped_prompts(flip)

In [None]:
# First few ABC (name-flipped) sentences, for comparison with ioi_dataset.sentences[:4]
abc_dataset.sentences[0:4]