In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer_lens
    %pip install torchtyping
    # Install my janky personal plotting utils
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")
    # Import stuff
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from jaxtyping import Float
from typing import List, Union, Optional, Callable
from functools import partial
import copy
import itertools
import json
import gc

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML, Markdown
#import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from neel_plotly import line, imshow, scatter
import transformer_lens.patching as patching

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
model = HookedTransformer.from_pretrained(
    'EleutherAI/pythia-12b-deduped-v0',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
    n_devices=7,
    move_to_device=True,
    dtype='float16'
)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)
model.set_use_attn_in(True)
model.tokenizer.padding_side = "left"
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-12b-deduped-v0')
print(f"Using tokenizer {tokenizer}")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-12b-deduped-v0 into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Using tokenizer GPTNeoXTokenizerFast(name_or_path='EleutherAI/pythia-12b-deduped-v0', vocab_size=50254, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<|padding|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	50254: AddedToken("                        ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50255: AddedToken("                       ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50256: AddedToken("                      ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50257: AddedToken

In [3]:
# load data
import yaml
import pickle
import os
class DotDict(dict):
    """ Dot notation access to dictionary attributes """
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
yaml_file_path = "./conf/config.yaml"
with open(yaml_file_path, "r") as f:
    args = DotDict(yaml.safe_load(f))

file_name = args.data_dir
file_name += '/' + str(args.model)
file_name += '/intervention_' + str(args.n_shots) + '_shots_max_' + str(args.max_n) + '_' + args.representation
file_name += '_further_templates' if args.extended_templates else ''
file_name += '_acdc' if args.acdc_data else ''
file_name += '.pkl'

with open(file_name, 'rb') as f:
    intervention_list = pickle.load(f)
print("Loaded data from", file_name)
if args.debug_run:
    intervention_list = intervention_list[:2]

from demos import intervention_dataset
intervention_data = intervention_dataset.InterventionDataset(intervention_list, device, model.tokenizer)
intervention_data.create_intervention_dataset()
intervention_data.shuffle()

Loaded data from /shared-network/shared/2024_ml_master/data/EleutherAI/pythia-12b-deduped-v0/intervention_1_shots_max_20_arabic_further_templates_acdc.pkl


In [4]:
def ave_logit_difference(
    logits: Float[Tensor, 'batch seq d_vocab'],
    intervention_dataset,
    per_prompt: bool = False
):
    batch_size = logits.size(0)
    correct_logits = logits[range(batch_size), -1, intervention_dataset.res_base_toks[:batch_size]]
    incorrect_logits = logits[range(batch_size), -1, intervention_dataset.pred_res_alt_toks[:batch_size]]
    logit_diff = correct_logits - incorrect_logits
    return logit_diff if per_prompt else logit_diff.mean()

def logits_in_batches(model, tokens, attn_mask, bsize):
    model.eval()
    seq_len = tokens.size(0)
    all_logits = []

    with t.no_grad():
        for i in range(0, seq_len, bsize):
            input = tokens[i:i+bsize].to(model.cfg.device)
            attn_m = attn_mask[i:i+bsize].to(model.cfg.device)
            logits = model(input=input, attention_mask=attn_m)
            logits = logits.detach().cpu()
            input = input.detach().cpu()
            attn_mask = attn_mask.detach().cpu()
            all_logits.append(logits)
            del input
            del logits
    return t.cat(all_logits, dim=0)

clean_logits = logits_in_batches(model, intervention_data.base_string_toks, intervention_data.base_attention_mask, 9)
corrupt_logits = logits_in_batches(model, intervention_data.alt_string_toks, intervention_data.alt_attention_mask, 9)
clean_logit_diff = ave_logit_difference(clean_logits, intervention_data, per_prompt=False).item()
corrupt_logit_diff = ave_logit_difference(corrupt_logits, intervention_data, per_prompt=False).item()
print(clean_logit_diff)
print(corrupt_logit_diff)

def metric(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    intervention_dataset: intervention_data = intervention_data,
    per_prompt: bool = False
 ):
    patched_logit_diff = ave_logit_difference(logits, intervention_dataset, per_prompt)
    metric_result = (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)
    return metric_result

with t.no_grad():   
    clean_metric = metric(clean_logits, corrupt_logit_diff, clean_logit_diff, intervention_data, per_prompt = False)
    corrupt_metric = metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff, intervention_data, per_prompt = False)

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

2.037109375
-0.78173828125
Clean direction: 2.037109375, Corrupt direction: -0.78173828125
Clean metric: 1.0, Corrupt metric: 0.0


In [5]:
Metric = Callable[[TT["batch_and_pos_dims", "d_model"]], float]

In [6]:
# filter_not_qkv_input = lambda name: "_input" not in name
# def get_cache_fwd_and_bwd(model, tokens, metric):
#     model.reset_hooks()
#     cache = {}
#     def forward_cache_hook(act, hook):
#         cache[hook.name] = act.detach()
#     model.add_hook(filter_not_qkv_input, forward_cache_hook, "fwd")

#     grad_cache = {}
#     def backward_cache_hook(act, hook):
#         grad_cache[hook.name] = act.detach()
#     model.add_hook(filter_not_qkv_input, backward_cache_hook, "bwd")

#     value = metric(model(tokens))
#     value.backward()
#     model.reset_hooks()
#     return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)

# clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, intervention_data.base_string_toks, metric)
# print("Clean Value:", clean_value)
# print("Clean Activations Cached:", len(clean_cache))
# print("Clean Gradients Cached:", len(clean_grad_cache))
# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(model, intervention_data.alt_string_toks, metric)
# print("Corrupted Value:", corrupted_value)
# print("Corrupted Activations Cached:", len(corrupted_cache))
# print("Corrupted Gradients Cached:", len(corrupted_grad_cache))

In [7]:
# def get_cache_fwd_and_bwd_batched(model, tokens, metric, batch_size):
#     model.reset_hooks()
#     cache = {}
#     grad_cache = {}
#     num_tokens = tokens.size(0)
#     total_batches = (num_tokens + batch_size - 1) // batch_size  # Calculate the total number of batches
#     values = []
#     tokens = tokens.to("cpu")

#     def forward_cache_hook(act, hook):
#         if hook.name in cache:
#             cache[hook.name] = t.cat((cache[hook.name], act.cpu().detach()), dim=0)
#         else:
#             cache[hook.name] = act.cpu().detach()

#     def backward_cache_hook(act, hook):
#         if hook.name in grad_cache:
#             grad_cache[hook.name] = t.cat((grad_cache[hook.name], act.cpu().detach()), dim=0)
#         else:
#             grad_cache[hook.name] = act.cpu().detach()

#     filter_not_qkv_input = lambda name: "_input" not in name
#     model.add_hook(filter_not_qkv_input, forward_cache_hook, "fwd")
#     model.add_hook(filter_not_qkv_input, backward_cache_hook, "bwd")

#     for batch_num in range(total_batches):
#         start_idx = batch_num * batch_size
#         end_idx = min((batch_num + 1) * batch_size, num_tokens)
#         batch_tokens = tokens[start_idx:end_idx]
#         batch_tokens = batch_tokens.to(model.cfg.device)

#         # Ensure model's forward method can handle slicing if your tokens tensor represents more complex inputs
#         value = metric(model(batch_tokens))
#         value.backward()  # Make sure your metric calculation and model's backward can handle batched inputs
#         values.append(value.item())
#         batch_tokens = batch_tokens.detach().cpu()

#     model.reset_hooks()

#     # Aggregate or average the metric values across batches if necessary
#     aggregated_value = sum(values) / len(values)

#     return aggregated_value, ActivationCache(cache, model), ActivationCache(grad_cache, model)

# # Example of calling the modified function with batch size:
# clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd_batched(model, intervention_data.base_string_toks, metric, batch_size=1)

# HEAD_NAMES = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
# HEAD_NAMES_SIGNED = [f"{name}{sign}" for name in HEAD_NAMES for sign in ["+", "-"]]
# HEAD_NAMES_QKV = [f"{name}{act_name}" for name in HEAD_NAMES for act_name in ["Q", "K", "V"]]


In [9]:
filter_not_qkv_input = lambda name: "_input" not in name
def get_cache_fwd(model, tokens, metric, bsize):
    model.reset_hooks()
    cache = {}
    tokens = tokens.to("cpu")
    values = []
    def forward_cache_hook(act, hook):
        if hook.name in cache:
            cache[hook.name] = t.cat((cache[hook.name], act.detach().cpu()), dim=0)
        else:
            cache[hook.name] = act.detach().cpu()
    model.add_hook(filter_not_qkv_input, forward_cache_hook, "fwd")

    with t.no_grad():
        for i in tqdm.tqdm(range(0, tokens.size(0), bsize)):
            input = tokens[i:i+bsize].to(model.cfg.device)
            value = metric(model(input))
            values.append(value.item())
            input = input.detach().cpu()
            del input
            del value
        model.reset_hooks()
        value_result = sum(values) / len(values)
    del tokens
    gc.collect()
    return value_result, ActivationCache(cache, model)


clean_value, clean_cache = get_cache_fwd(model, intervention_data.base_string_toks, metric, bsize=2)




  0%|          | 0/54 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
clean_value

0.3899516352900752

In [None]:
filter_not_qkv_input = lambda name: "_input" not in name
def get_cache_fwd_and_bwd(model, tokens, metric, bsize):
    model.reset_hooks()
    values = []
    #tokens = tokens.to("cpu")
    cache = {}
    def forward_cache_hook(act, hook):
        if hook.name in cache:
            cache[hook.name] = t.cat((cache[hook.name], act.detach().cpu()), dim=0)
        else:
            cache[hook.name] = act.detach().cpu()
    model.add_hook(filter_not_qkv_input, forward_cache_hook, "fwd")

    grad_cache = {}
    def backward_cache_hook(act, hook):
        if hook.name in grad_cache:
            grad_cache[hook.name] = t.cat((grad_cache[hook.name], act.cpu()), dim=0)
        else:
            grad_cache[hook.name] = act.cpu()
    model.add_hook(filter_not_qkv_input, backward_cache_hook, "bwd")

    for i in tqdm.tqdm(range(0, tokens.size(0), bsize)):
        input = tokens[i:i+bsize].to(model.cfg.device)
        value = metric(model(input))
        input = input.detach().cpu()
        del input
        value.backward()
        values.append(value.item())
        del value
    model.reset_hooks()
    value_result = sum(values) / len(values)
    gc.collect()
    return value_result, ActivationCache(cache, model), ActivationCache(grad_cache, model)

corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(model, intervention_data.alt_string_toks, metric, bsize=1)

  0%|          | 0/54 [00:00<?, ?it/s]

RuntimeError: [enforce fail at alloc_cpu.cpp:75] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 3656908800 bytes. Error code 12 (Cannot allocate memory)

In [None]:
print(f"clean cache shape {clean_cache['blocks.0.hook_resid_pre'].shape}")
print(f"corrupted cache shape {corrupted_cache['blocks.0.hook_resid_pre'].shape}")
print(f"clean grad cache shape {corrupted_grad_cache['blocks.0.hook_resid_pre'].shape}")

clean cache shape torch.Size([108, 24, 5120])
corrupted cache shape torch.Size([2, 24, 5120])
clean grad cache shape torch.Size([2, 24, 5120])


In [None]:
def attr_patch_layer_out(
        clean_cache: ActivationCache, 
        corrupted_cache: ActivationCache, 
        corrupted_grad_cache: ActivationCache,
    ) -> TT["component", "pos"]:
    clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)
    print(clean_layer_out.shape)
    corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)
    corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(-1, return_labels=False)
    layer_out_attr = einops.reduce(
        corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),
        "component batch pos d_model -> component pos",
        "sum"
    )
    return layer_out_attr, labels

layer_out_attr, layer_out_labels = attr_patch_layer_out(clean_cache, corrupted_cache, corrupted_grad_cache)
imshow(layer_out_attr, y=layer_out_labels, yaxis="Component", xaxis="Position", title="Layer Output Attribution Patching")

torch.Size([73, 108, 24, 5120])


RuntimeError: The size of tensor a (108) must match the size of tensor b (2) at non-singleton dimension 1