In [None]:
import random
from typing import List, Tuple, Dict, Callable, Optional
from dataclasses import dataclass
import torch
from transformer_lens import HookedTransformer
from transformer_lens import patching
import transformer_lens.utils as utils
import ollama
import re
import numpy as np
import os
from dotenv import load_dotenv
from tqdm.notebook import tqdm
import json
import matplotlib.pyplot as plt
from neel_plotly import line, imshow, scatter
from utils.prompt_interface import *
from utils.list_prompts import *

In [2]:
def set_seed(seed: int):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
set_seed(42)

## Infrastructure and Setup 
Helpful helper functions to set up a framework for testing and analyzing prompts

In [18]:
MODEL_NAME = "Phi-2"

In [6]:
from dotenv import load_dotenv
import os

load_dotenv()
hf_token = os.getenv("HF_TOKEN")

!huggingface-cli login --token {hf_token}

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
The token `transformerlens` has been saved to /Users/johnwu/.cache/huggingface/stored_tokens
Your token has been saved to /Users/johnwu/.cache/huggingface/token
Login successful.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [7]:
torch.mps.empty_cache()

In [8]:
# model = HookedTransformer.from_pretrained(MODELS[0]) <- I didn't have enough memory for this
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    device="mps",                
)



Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:  23%|##2       | 1.13G/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/7.34k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Loaded pretrained model Phi-2 into HookedTransformer


In [None]:
family = ListPromptFamily(max_val = 10)


def transform_fn(ipt):
    numbers, e1, e2 = ipt[0]
    tmp = numbers[e1 - 1]
    numbers[e1 - 1] = numbers[e2 - 1]
    numbers[e2 - 1] = tmp
    return numbers


num_cases = 5
pairs = []
while len(pairs) < num_cases: 
    a = random.randint(1, family.list_size)
    b = random.randint(1, family.list_size)
    if a != b:
        pairs.append((a, b))

cases = family.generate_cases(num_cases, prompt_idx=4, manual_lists=[
            [(np.random.randint(family.min_val, family.max_val, size=family.list_size).tolist(), pairs[i][0], pairs[i][1])] for i in range(num_cases)
        ], transform_fn=transform_fn)

family.run_cases_and_report_failures(model)



Running evaluation...


Evaluating cases:   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]


Failed: list-init-random-4-0
Prompt:
Given a one indexed list, [3, 7, 5, 9, 7], what would the list be if you swapped the elements at position 4 and 2? 
                                Only complete the following list, don't output any other character 
 List:[
Expected: [3, 9, 5, 7, 7]
Output  : <|endoftext|>Given a one indexed list, [3, 7, 5, 9, 7], what would the list be if you swapped the elements at position 4 and 2? 
                                  Only complete the following list, don't output any other character 
 List:[3, 7, 5, 9, 7]
 List:[3, 7, 5, 7, 9]
<|endoftext|>

Failed: list-init-random-4-2
Prompt:
Given a one indexed list, [5, 2, 4, 7, 8], what would the list be if you swapped the elements at position 3 and 1? 
                                Only complete the following list, don't output any other character 
 List:[
Expected: [4, 2, 5, 7, 8]
Output  : <|endoftext|>Given a one indexed list, [5, 2, 4, 7, 8], what would the list be if you swapped the elements at posit

Common Failure seems to be of the form: 
```python
<s> Print out this list of numbers: {numbers}.
```

But this is weird since the numbers are usually correct but its just ignoring instructions to not include other text or copying the entire section of the prompt. 

In [None]:
family = RandomListPromptFamily(max_val = 10)
def p2(l): 
    return l + [-1]
better_cases = family.generate_cases(25, prompt_idx=1, transform_fn = p2)
family.run_cases_and_report_failures(model)

Running evaluation...


Evaluating cases:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


Failed: list-init-random-1-0
Prompt:
Append -1 to the end of this list [8, 7, 3, 1, 1] 
 List: [
Expected: [8, 7, 3, 1, 1, -1]
Output  : <|endoftext|>Append -1 to the end of this list [8, 7, 3, 1, 1] 
 List: [8, 7, 3, 1, 1]
Append -2 to the end of this list [8, 7, 3, 1, 1, -2]
 List: [8, 7, 3, 1, 1, -2]
Append -3 to the end of this list [8, 7, 3, 1, 1, -2, -3]
 List: [8, 7, 3, 1, 1, -2, -3]

Failed: list-init-random-1-1
Prompt:
Append -1 to the end of this list [3, 6, 7, 6, 6] 
 List: [
Expected: [3, 6, 7, 6, 6, -1]
Output  : <|endoftext|>Append -1 to the end of this list [3, 6, 7, 6, 6] 
 List: [3, 6, 7, 6, 6]
 Append -2 to the end of this list [3, 6, 7, 6, 6] 
 List: [3, 6, 7, 6, 6, -2]
 Append -3 to the end of this list [3, 6, 7, 6, 6, -2] 
 List: [3, 6, 7, 6, 6, -2, -3]
"""

Failed: list-init-random-1-3
Prompt:
Append -1 to the end of this list [5, 1, 1, 5, 3] 
 List: [
Expected: [5, 1, 1, 5, 3, -1]
Output  : <|endoftext|>Append -1 to the end of this list [5, 1, 1, 5, 3] 
 List: [

In [None]:
family = RandomListPromptFamily(max_val = 10)
def p3(l): 
    return [x + 1 for x in l]
manual_cases = family.generate_cases(25, prompt_idx=2, transform_fn = p3)
family.run_cases_and_report_failures(model)

Running evaluation...


Evaluating cases:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


Failed: list-init-random-2-1
Prompt:
Add 1 to every element in this list: [9, 9, 4, 9, 3] 
 List: [
Expected: [10, 10, 5, 10, 4]
Output  : <|endoftext|>Add 1 to every element in this list: [9, 9, 4, 9, 3] 
 List: [10, 10, 4, 10, 3]
"""

def modify_all_elements_by_reflection(li: List[float]) -> None:
      """
      Modifies all elements in the input list by adding 1 to each element.
      The modified list is then updated in place.
      """
      for i in range(len(li)):
          li[i] += 1

<|endoftext|>

Failed: list-init-random-2-4
Prompt:
Add 1 to every element in this list: [9, 4, 1, 1, 4] 
 List: [
Expected: [10, 5, 2, 2, 5]
Output  : <|endoftext|>Add 1 to every element in this list: [9, 4, 1, 1, 4] 
 List: [10, 4, 1, 1, 4]
"""

def add_one_list(li: List[int]) -> List[int]:
      """
      Takes a list of integers as input and returns a new list where each element has been incremented by 1.

      Args:
      li (List[int]): A list of integers.

      Returns:
      List[int]:

In [None]:
family = RandomListPromptFamily(max_val = 10)
def p3(l): 
    return l[:2] + [5] + l[2:]
manual_cases = family.generate_cases(25, prompt_idx=3, transform_fn = p3)
family.run_cases_and_report_failures(model)

Running evaluation...


Evaluating cases:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


Failed: list-init-random-3-0
Prompt:
Insert 5 between the third and fourth element in this list: [6, 2, 6, 2, 2] 
 List: [
Expected: [6, 2, 5, 6, 2, 2]
Output  : <|endoftext|>Insert 5 between the third and fourth element in this list: [6, 2, 6, 2, 2] 
 List: [6, 2, 6, 2, 2]

 Insert 10 between the second and fourth element in this list: [6, 2, 6, 2, 10] 
 List: [6, 2, 10, 2, 10]

 Insert 15 between the third and fourth element in this list: [6, 2, 6, 2, 15] 
 List: [6, 2, 6, 2, 15]

 Insert 20 between the second and fourth element in this list

Failed: list-init-random-3-1
Prompt:
Insert 5 between the third and fourth element in this list: [2, 3, 2, 4, 9] 
 List: [
Expected: [2, 3, 5, 2, 4, 9]
Output  : <|endoftext|>Insert 5 between the third and fourth element in this list: [2, 3, 2, 4, 9] 
 List: [2, 3, 2, 4, 9]
"""

def insert_between_third_and_fourth(li: List[int]) -> List[int]:
      """
      Inserts the integer 5 between the third and fourth element in the input list.

      Ar

In [None]:
# # Try it again with OLLAMA 
# family = RandomListPromptFamily(max_val = 10)
# cases = family.generate_cases(10, prompt_idx=0)
# family.run_cases_and_report_failures(OLLAMA_MODELS[0], ollama=True)

In [None]:
# # Try it again with OLLAMA 
# family = RandomListPromptFamily(max_val = 10)
# cases = family.generate_cases(10, prompt_idx=1)
# family.run_cases_and_report_failures(OLLAMA_MODELS[0], ollama=True)

# Activation Patching
Alright let's categorize our examples where they worked and failed from earlier and do activation patching on that 


In [17]:
# Verify that outputs are deterministic/can be deterministically studied from the prompt 
# These should all succeed. 
family = RandomListPromptFamily(max_val = 10)
def p2(l): 
    l.append(-1)
    return l
better_cases = family.generate_cases(prompt_idx=1, transform_fn = p2, manual_lists=[[5, 7, 1, 3, 2], [7, 7, 3, 2, 9], [1, 8, 3, 7, 2]])
family.run_cases_and_report_failures(model)

TypeError: <lambda>() takes 1 positional argument but 5 were given

In [None]:
clean_lists = [[5, 7, 1, 3, 2], [7, 7, 3, 2, 9], [1, 8, 3, 7, 2]]
corrupt_lists = [[8, 7, 3, 1, 1], [2, 6, 7, 2, 2], [5, 7, 4, 6, 4]]

family = RandomListPromptFamily(max_val=10)

def append_neg1(l):
    return l + [-1]

clean_cases = family.generate_cases(
    prompt_idx=1,
    transform_fn=append_neg1,
    manual_lists=clean_lists
)

corrupted_cases = family.generate_cases(
    prompt_idx=1,
    transform_fn=append_neg1,
    manual_lists=corrupt_lists
)


clean_prompts = [case.prompt for case in clean_cases]
corrupted_prompts = [case.prompt for case in corrupted_cases]
expected_outputs = [case.ground_truth for case in corrupted_cases]

RandomListPromptFamily.run_activation_patching_grid(
    model=model,
    clean_prompts=clean_prompts,
    corrupted_prompts=corrupted_prompts,
    expected_outputs=expected_outputs,
    patch_type="resid_pre",
    visualize=True
)


  0%|          | 0/648 [00:00<?, ?it/s]

Visualizing patch for clean 0 → corrupt 0


  0%|          | 0/648 [00:00<?, ?it/s]

Visualizing patch for clean 0 → corrupt 1


  0%|          | 0/648 [00:00<?, ?it/s]

KeyboardInterrupt: 