Setup

In [None]:
Key reference used: https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=rIsrKr2QzpBV%pip install git+https://github.com/neelnanda-io/TransformerLens.git
%pip install gputil
%pip install psutil
%pip install humanize

In [1]:
import torch
import einops
import pysvelte
import tqdm
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [2]:
from functools import partial
from jaxtyping import Float, Int
import plotly.express as px
import os,sys,humanize,psutil,GPUtil
import gc
import copy

In [3]:
def mem_report():
  print("CPU RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ))
  
  GPUs = GPUtil.getGPUs()
  for i, gpu in enumerate(GPUs):
    print('GPU {:d} ... Mem Free: {:.0f}MB / {:.0f}MB | Utilization {:3.0f}%'.format(i, gpu.memoryFree, gpu.memoryTotal, gpu.memoryUtil*100))


In [4]:
gpt2_large = HookedTransformer.from_pretrained(
    "gpt2-large",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True
)
torch.cuda.empty_cache()

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


In [5]:
races = ["Indian", "Polish"]
for race in tqdm.tqdm(races):
  text = "An example of an " + race + " food is the"
  for i in range(2):
      logits = gpt2_large(text, return_type="logits")
      text += gpt2_large.tokenizer.decode(logits[-1, -1].argmax())
  print(text)

 50%|█████     | 1/2 [00:03<00:03,  3.64s/it]

An example of an Indian food is the dal


100%|██████████| 2/2 [00:03<00:00,  1.93s/it]

An example of an Polish food is the "K





In [11]:
races = ["Indian", "Chinese"]
for race in tqdm.tqdm(races):
  text =  "The " + race + " have an annual festival called"
  for i in range(15):
      logits = gpt2_large(text, return_type="logits")
      text += gpt2_large.tokenizer.decode(logits[-1, -1].argmax())
  print(text)

 50%|█████     | 1/2 [00:01<00:01,  1.40s/it]

The Indian have an annual festival called Diwali, which is celebrated on the first day of the Hindu lunar


100%|██████████| 2/2 [00:03<00:00,  1.69s/it]

The Chinese have an annual festival called the Spring Festival, which is celebrated on the first day of the lunar new





In [5]:
text = "Indian people's right"
for i in range(30):
      logits = gpt2_large(text, return_type="logits")
      text += gpt2_large.tokenizer.decode(logits[-1, -1].argmax())
print(text)

Indian students should be admitted over their peers because they are more likely to be successful in the job market, the country's top education regulator has said.


The National Council for Technical Education


In [5]:
mem_report()

CPU RAM Free: 20.4 GB
GPU 0 ... Mem Free: 4040MB / 16384MB | Utilization  74%


In [6]:
clean_prompts = ["The Polish live in the city of", "The Indian live in the city of",
                 "An example of an Indian food is the", "The Indian have an annual festival called",
                 "The Chinese have an annual festival called"]
corrupted_prompts = ["The Indian live in the city of", "The Polish live in the city of",
                     "An example of an Polish food is the", "The Chinese have an annual festival called",
                     "The Indian have an annual festival called"]
answers = ["Mumbai", "Krakow", "dal", "Diwali", "the Spring Festival"]

In [None]:
def get_output_logits(prompt, answer_tokens):
    prompt_tokens = gpt2_large.to_tokens(prompt, prepend_bos=True)
    logits, cache = gpt2_large.run_with_cache(prompt_tokens, return_type="logits")
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    return answer_logits.mean()

In [16]:
clean_caches = []
for i in range(5):
    clean_prompt = clean_prompts[i]
    corrupted_prompt = corrupted_prompts[i]
    answer = answers[i]

    prompt_tokens = gpt2_large.to_tokens(clean_prompt, prepend_bos=True)
    logits, clean_cache = gpt2_large.run_with_cache(prompt_tokens, device="cpu", return_type="logits")
    clean_caches.append(clean_cache)

In [22]:
mlp_list = []

def get_mlp_hook(
  hook, 
  pos, 
  cache):
  # patch corrupted component for some position in text with corresponding
  # compontent from the clean version
  mlp_list.append(cache[hook.name][:, pos, :])
  return 

def get_mlp(layer, position, cache):
  gpt2_large.run_with_hooks(
        fwd_hooks = [(utils.get_act_name("mlp_out", layer), position, cache,
                      get_mlp_hook)], 
        return_type="None"
  )

get_mlp()
#clean_caches[0].get_neuron_results(layer=1)
#print(clean_caches[1].get_neuron_results(1, pos_slice=(2))[0::].shape)
#clean_caches[2].get_neuron_results(1, pos_slice=(2))[0::]
# 5, 2

SyntaxError: ignored

In [None]:
answer_tokens = gpt2_large.to_tokens(answer, prepend_bos=False)
corrupted_tokens = gpt2_large.to_tokens(corrupted_prompt, prepend_bos=True)
corrupted_answer_logits = get_output_logits(corrupted_prompt, answer_tokens)
clean_answer_logits = get_output_logits(clean_prompt, answer_tokens)

In [17]:
print(answer_tokens, clean_answer_logits, corrupted_answer_logits)

tensor([[   35, 14246,  7344]], device='cuda:0') tensor(3.1139, device='cuda:0', grad_fn=<MeanBackward0>) tensor(2.1499, device='cuda:0', grad_fn=<MeanBackward0>)


In [7]:
prompt_tokens = gpt2_large.to_tokens(clean_prompt, prepend_bos=True)
logits, clean_cache = gpt2_large.run_with_cache(prompt_tokens, device="cpu", return_type="logits")

In [8]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

In [21]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook, 
    pos, 
    clean_cache):
    # patch corrupted component for some position in text with corresponding
    # compontent from the clean version
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component

def normalized_patched_logit(patched_logits):
    # 0 -> no change between corrupted before and after aptch.
    # < 1 -> logits before patch were larger -> dal more probable before patch
    # > 1 -> logits after patch were larger -> dal more probable after patch
    return (patched_logits - corrupted_answer_logits)/(clean_answer_logits - corrupted_answer_logits)


In [10]:
def runPatchedMLP(layer, position):
    hook_fn = partial(patch_residual_component, pos=position, clean_cache=clean_cache)
    patched_MLP_logits = gpt2_large.run_with_hooks(
            corrupted_tokens, 
            fwd_hooks = [(utils.get_act_name("mlp_out", layer), 
                hook_fn)], 
            return_type="logits"
        )
    return normalized_patched_logit(patched_MLP_logits[:, -1, :].gather(dim=-1, index=answer_tokens).mean())
  
def runPatchedResids(layer, position):
    hook_fn = partial(patch_residual_component, pos=position, clean_cache=clean_cache)
    patched_resid_stream_logits = gpt2_large.run_with_hooks(
            corrupted_tokens, 
            fwd_hooks = [(utils.get_act_name("resid_pre", layer), 
                hook_fn)], 
            return_type="logits"
        )
    return normalized_patched_logit(patched_resid_stream_logits[:, -1, :].gather(dim=-1, index=answer_tokens).mean())

def runPatchedAttention(layer, position):
    hook_fn = partial(patch_residual_component, pos=position, clean_cache=clean_cache)
    patched_attn_logits = gpt2_large.run_with_hooks(
            corrupted_tokens, 
            fwd_hooks = [(utils.get_act_name("attn_out", layer), 
                hook_fn)], 
            return_type="logits"
        )
    return normalized_patched_logit(patched_attn_logits[:, -1, :].gather(dim=-1, index=answer_tokens).mean())

def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook, 
    head_index, 
    clean_cache):
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][:, :, head_index, :]
    return corrupted_head_vector

def runPatchedHeads(layer, head_index):
    hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=clean_cache)
    patched_head_logits = gpt2_large.run_with_hooks(
        corrupted_tokens, 
        fwd_hooks = [(utils.get_act_name("z", layer, "attn"), 
            hook_fn)], 
        return_type="logits"
    )
    return normalized_patched_logit(patched_head_logits[:, -1, :].gather(dim=-1, index=answer_tokens).mean())


In [13]:
def MLPPatching():
  patched_mlp_logits = torch.zeros(gpt2_large.cfg.n_layers, corrupted_tokens.shape[1], device="cpu", dtype=torch.float32)
  for layer in tqdm.tqdm(range(gpt2_large.cfg.n_layers)):
    for position in range(corrupted_tokens.shape[1]):
      # Define a hook function to replacing an activation with its cleaned counterpart
      # evaluate patched model on corrupted prompt
      patched_mlp_logits[layer, position] = copy.copy(runPatchedMLP(layer, position))
  prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(gpt2_large.to_str_tokens(corrupted_tokens[0]))]
  imshow(patched_mlp_logits, x=prompt_position_labels, title="Logit From Patched MLP Layer", labels={"x":"Position", "y":"Layer"})    

def AttnPatching():
  patched_attn_logits = torch.zeros(gpt2_large.cfg.n_layers, corrupted_tokens.shape[1], device="cpu", dtype=torch.float32)
  for layer in tqdm.tqdm(range(gpt2_large.cfg.n_layers)):
    for position in range(corrupted_tokens.shape[1]):
      patched_attn_logits[layer, position] = copy.copy(runPatchedAttention(layer, position))
  prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(gpt2_large.to_str_tokens(corrupted_tokens[0]))]
  imshow(patched_attn_logits, x=prompt_position_labels, title="Logit From Patched Attention Layer", labels={"x":"Position", "y":"Layer"})    

def ResidStreamPatching():
  patched_resid_logits = torch.zeros(gpt2_large.cfg.n_layers, corrupted_tokens.shape[1], device="cpu", dtype=torch.float32)
  for layer in tqdm.tqdm(range(gpt2_large.cfg.n_layers)):
    for position in range(corrupted_tokens.shape[1]):
      patched_resid_logits[layer, position] = copy.copy(runPatchedResids(layer, position))
  prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(gpt2_large.to_str_tokens(corrupted_tokens[0]))]
  imshow(patched_resid_logits, x=prompt_position_labels, title="Logit From Patched Residual Stream", labels={"x":"Position", "y":"Layer"})    


def HeadPatching():
  patched_head_logits = torch.zeros(gpt2_large.cfg.n_layers, gpt2_large.cfg.n_heads, device="cpu", dtype=torch.float32)
  for layer in range(gpt2_large.cfg.n_layers):
    for head_index in range(gpt2_large.cfg.n_heads):
      patched_head_logits[layer, head_index] = copy.copy(runPatchedHeads(layer, head_index))
      prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(gpt2_large.to_str_tokens(corrupted_tokens[0]))]
  imshow(patched_head_logits, title="Logit From Patched Head output", labels={"x":"Head", "y":"Layer"})    

MLPPatching()
#AttnPatching()
#ResidStreamPatching()
#HeadPatching()
# Clean: Polish city is Warsaw. 0, 1, 2, 3, 4, 5, 6, 8, 9
# Clean: Indian city is Mumbai. 0, 1, 2, 18
# Clean: Indian food is dal. 0, 1, 2, 18
# Clean: Indian festival is Diwali. 0, 1, 2, 3, 5, 7, 18
# Clean: Chinese festival is the spring festival. 0, 1, 2

100%|██████████| 36/36 [00:29<00:00,  1.23it/s]


In [None]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook, 
    pos, 
    clean_cache):
    # patch corrupted component for some position in text with corresponding
    # compontent from the clean version
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component
    
def MLPPatching():
  patched_mlp_logits = torch.zeros(gpt2_large.cfg.n_layers, corrupted_tokens.shape[1], device="cpu", dtype=torch.float32)
  for layer in tqdm.tqdm(range(gpt2_large.cfg.n_layers)):
    for position in range(corrupted_tokens.shape[1]):
      # Define a hook function to replacing an activation with its cleaned counterpart
      # evaluate patched model on corrupted prompt
      patched_mlp_logits[layer, position] = copy.copy(runPatchedMLP(layer, position))
  prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(gpt2_large.to_str_tokens(corrupted_tokens[0]))]
  imshow(patched_mlp_logits, x=prompt_position_labels, title="Logit From Patched MLP Layer", labels={"x":"Position", "y":"Layer"})    

def zoomIntoMLP(layer, position):
    hook_fn = partial(patch_residual_component, pos=position, clean_cache=clean_cache)
    patched_MLP_logits = gpt2_large.run_with_hooks(
            corrupted_tokens, 
            fwd_hooks = [(utils.get_act_name("mlp_out", layer), 
                hook_fn)], 
            return_type="logits"
        )
    return normalized_patched_logit(patched_MLP_logits[:, -1, :].gather(dim=-1, index=answer_tokens))
# MLP layers 0, 1, 2, 4, 5, 18
layers = [0, 1, 2, 4, 5, 18]
for layer in layers:


In [None]:
import copy
# [n_layers, position]
patched_mlp_logits = torch.zeros(gpt2_large.cfg.n_layers, corrupted_tokens.shape[1], device="cpu", dtype=torch.float32)
for layer in range(gpt2_large.cfg.n_layers):
  for position in range(corrupted_tokens.shape[1]):
    print(layer, position)
    # Define a hook function to  replacing an activation with its cleaned counterpart
    # evaluate patched model on corrupted prompt
    patched_mlp_logits[layer, position] = copy.copy(runPatched(layer, position))
    print(patched_mlp_logits.get_device()) 

prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(gpt2_large.to_str_tokens(corrupted_tokens[0]))]
imshow(patched_mlp_logits, x=prompt_position_labels, title="Logit From Patched MLP Layer", labels={"x":"Position", "y":"Layer"})

In [7]:
torch.cuda.empty_cache()