In [None]:
import random
from typing import List, Tuple, Dict, Callable, Optional
from dataclasses import dataclass
import torch
from transformer_lens import HookedTransformer
from transformer_lens import patching
import transformer_lens.utils as utils
import ollama
import re
import numpy as np
import os
from dotenv import load_dotenv
from tqdm.notebook import tqdm



In [25]:
def set_seed(seed: int):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
set_seed(42)

## Infrastructure and Setup 
Helpful helper functions to set up a framework for testing and analyzing prompts

In [26]:
MODEL_NAME = "Phi-1"
OK_COPY = ["qwen-1.8b, Phi-1"]
BAD_COPY = ["mistralai/Mistral-7B-v0.1", "EleutherAI/gpt-neo-2.7B", "gpt2-xl"]
OLLAMA_MODELS = ["mistral:7b"]

In [27]:
def ollama_generate(prompt, model):
    output = ollama.generate(model, prompt=prompt)
    return output["response"]

In [28]:
@dataclass
class PromptCase:
    task_id: str
    prompt: str
    ground_truth: str
    metadata: Dict

    def test_behavior(self, model_output: str) -> Dict:
        """
        Run standard checks to categorize the response.
        Returns tags like: {"has_extra_text": True, "wrong_numbers": True}
        """
        raise NotImplementedError()

    def analyze_tokens(self, tokenizer) -> Dict:
        """
        Tokenize prompt and return metadata:
        - total_tokens
        - location of numbers
        - special token positions
        """
        raise NotImplementedError()

class PromptFamily:
    def name(self) -> str:
        raise NotImplementedError()

    def generate_cases(self, n: int) -> List[PromptCase]:
        raise NotImplementedError()

    def test_behavior(self, case: PromptCase, model_output: str) -> Dict:
        """
        Should return standardized error tags (e.g., extra_text, wrong_values)
        """
        raise NotImplementedError()

    def analyze_tokens(self, case: PromptCase, tokenizer) -> Dict:
        """
        Return token-level info: total tokens, number token positions, etc.
        """
        raise NotImplementedError()


In [21]:
@dataclass
class PromptCase:
    task_id: str
    prompt: str
    ground_truth: str
    metadata: Dict

    def test_behavior(self, model_output: str) -> Dict:
        raise NotImplementedError()

    def analyze_tokens(self, tokenizer) -> Dict:
        raise NotImplementedError()

class PromptFamily:
    def name(self) -> str:
        raise NotImplementedError()

    def generate_cases(self, n: int) -> List[PromptCase]:
        raise NotImplementedError()

    def test_behavior(self, case: PromptCase, model_output: str) -> Dict:
        raise NotImplementedError()

    def analyze_tokens(self, case: PromptCase, tokenizer) -> Dict:
        raise NotImplementedError()

class RandomListPromptFamily(PromptFamily):
    def __init__(self, min_val=1, max_val=10, list_size=5):
        self.min_val = min_val
        self.max_val = max_val
        self.list_size = list_size
        self.cases = []
        self.prompts = [
            lambda numbers: f"Print out this list of numbers: {numbers}. List: [",
            lambda numbers: f"""Append -1 to the end of this list {numbers} \n List: [""",
            lambda numbers: f"""Add 1 to every element in this list: {numbers} \n List: [""",
            ]

    def name(self):
        return "list-init-random"

    def generate_cases(
        self,
        n: int = 0,
        prompt_idx: int = 0,
        transform_fn: Callable[[List[int]], List[int]] = None,
        manual_lists: Optional[List[List[int]]] = None
    ) -> List[PromptCase]:
        self.cases = []
        inputs = manual_lists if manual_lists is not None else [
            np.random.randint(self.min_val, self.max_val, size=self.list_size).tolist() for _ in range(n)
        ]

        for i, numbers in enumerate(inputs):
            prompt = self.prompts[prompt_idx](numbers)

            if transform_fn:
                expected_list = transform_fn(numbers)
            else:
                expected_list = numbers

            expected = str(expected_list)

            self.cases.append(PromptCase(
                task_id=f"{self.name()}-{prompt_idx}-{i}",
                prompt=prompt,
                ground_truth=expected,
                metadata={
                    "list": numbers,
                    "expected": expected_list,
                    "family": self.name(),
                    "transformation": prompt_idx
                }
            ))
        return self.cases


    def analyze_tokens(self, case: PromptCase, model: HookedTransformer) -> Dict:
        tokens = model.to_tokens(case.prompt)[0]
        token_strs = model.to_str_tokens(tokens)

        number_token_spans = {}
        split_tokens = []
        numbers = case.metadata.get("list", [])

        for num in numbers:
            num_str = str(num)
            tokenized = model.to_tokens(num_str)[0]
            number_token_spans[num] = len(tokenized)
            if len(tokenized) > 1:
                split_tokens.append(num)

        return {
            "total_tokens": len(tokens),
            "token_ids": tokens.tolist(),
            "token_strs": token_strs,
            "number_token_spans": number_token_spans,
            "split_tokens": split_tokens,
            "num_splits": len(split_tokens)
        }

    def run_cases_and_report_failures(self, model: HookedTransformer, max_tokens: int = 100, ollama=False, suppress_correct=False):
        print("Running evaluation...")
        results = []

        for case in tqdm(self.cases, desc="Evaluating cases"):
            try:
                if ollama:
                    decoded = ollama_generate(case.prompt, model=model).strip()
                else:
                    tokens = model.to_tokens(case.prompt)
                    generated_tokens = model.generate(tokens, max_new_tokens=max_tokens)
                    decoded = model.tokenizer.decode(generated_tokens[0]).strip()

                model_output = decoded
                expected = case.ground_truth.strip()

                correct = expected in model_output 

                results.append((not correct, case.task_id, case.prompt, expected, decoded, correct))

                if not ollama:
                    del tokens, generated_tokens, decoded, model_output
                    torch.mps.empty_cache()

            except Exception as e:
                results.append((True, case.task_id, case.prompt, "<ERROR>", str(e), False))

        # Sort so incorrect first
        results.sort(key=lambda x: x[0], reverse=True)

        for is_incorrect, task_id, prompt, expected, output, correct in results:
            if not correct or not suppress_correct:
                print(f"\n{'Failed' if is_incorrect else 'Passed'}: {task_id}")
                print(f"Prompt:\n{prompt}")
                print(f"Expected: {expected}")
                print(f"Output  : {output}")


In [6]:
from dotenv import load_dotenv
import os

load_dotenv()
hf_token = os.getenv("HF_TOKEN")

!huggingface-cli login --token {hf_token}

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
The token `transformerlens` has been saved to /Users/johnwu/.cache/huggingface/stored_tokens
Your token has been saved to /Users/johnwu/.cache/huggingface/token
Login successful.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [7]:
torch.mps.empty_cache()

In [8]:
# model = HookedTransformer.from_pretrained(MODELS[0]) <- I didn't have enough memory for this
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    device="mps",                
)

Loaded pretrained model Phi-1 into HookedTransformer


In [9]:
family = RandomListPromptFamily(max_val = 10)
cases = family.generate_cases(25, prompt_idx=0)
family.run_cases_and_report_failures(model)

Running evaluation...


Evaluating cases:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]


Passed: list-init-random-0-0
Prompt:
Print out this list of numbers: [9, 6, 6, 7, 1]. List: [
Expected: [9, 6, 6, 7, 1]
Output  : <|endoftext|>Print out this list of numbers: [9, 6, 6, 7, 1]. List: [9, 6, 6, 7, 1]
2 * 3 = 6
2 * 7 = 14
1 * 3 = 3
1

Passed: list-init-random-0-1
Prompt:
Print out this list of numbers: [8, 5, 7, 5, 3]. List: [
Expected: [8, 5, 7, 5, 3]
Output  : <|endoftext|>Print out this list of numbers: [8, 5, 7, 5, 3]. List: [8, 5, 7, 5, 3]
      Print out this list of letters: ['c', 'a', 't', 'e

Passed: list-init-random-0-2
Prompt:
Print out this list of numbers: [8, 6, 4, 9, 5]. List: [
Expected: [8, 6, 4, 9, 5]
Output  : <|endoftext|>Print out this list of numbers: [8, 6, 4, 9, 5]. List: [8 64 4 9 5]"""
      numbers = []
      for i in range(0, len(numbers), 2):

Passed: list-init-random-0-3
Prompt:
Print out this list of numbers: [2, 1, 7, 1, 5]. List: [
Expected: [2, 1, 7, 1, 5]
Output  : <|endoftext|>Print out this list of numbers: [2, 1, 7, 1, 5]. List: [7, 1

Common Failure seems to be of the form: 
```python
<s> Print out this list of numbers: {numbers}.
```

But this is weird since the numbers are usually correct but its just ignoring instructions to not include other text or copying the entire section of the prompt. 

In [10]:
family = RandomListPromptFamily(max_val = 10)
def p2(l): 
    l.append(-1)
    return l
better_cases = family.generate_cases(25, prompt_idx=1, transform_fn = p2)
family.run_cases_and_report_failures(model)

Running evaluation...


Evaluating cases:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]


Failed: list-init-random-1-1
Prompt:
Append -1 to the end of this list [5, 9, 8, 2, 6] 
 List: [
Expected: [5, 9, 8, 2, 6, -1]
Output  : <|endoftext|>Append -1 to the end of this list [5, 9, 8, 2, 6] 
 List: [5, 9, 8, 2, 6]
Append 3 to the end of this list [5, 9, 8, 2, 6

Failed: list-init-random-1-2
Prompt:
Append -1 to the end of this list [2, 8, 6, 7, 9] 
 List: [
Expected: [2, 8, 6, 7, 9, -1]
Output  : <|endoftext|>Append -1 to the end of this list [2, 8, 6, 7, 9] 
 List: [2, 8, 6, 7, 9]
 Append 9 to the end of this list [2, 8, 6, 7, 9

Failed: list-init-random-1-4
Prompt:
Append -1 to the end of this list [9, 4, 4, 5, 1] 
 List: [
Expected: [9, 4, 4, 5, 1, -1]
Output  : <|endoftext|>Append -1 to the end of this list [9, 4, 4, 5, 1] 
 List: [9, 4, 4, 5, 1, 6, 3, 2, 5]
Append 2 to the end of this list [9

Failed: list-init-random-1-5
Prompt:
Append -1 to the end of this list [4, 2, 2, 4, 7] 
 List: [
Expected: [4, 2, 2, 4, 7, -1]
Output  : <|endoftext|>Append -1 to the end of this 

In [20]:
family = RandomListPromptFamily(max_val = 10)
def p3(l): 
    return [x + 1 for x in l]
manual_cases = family.generate_cases(25, prompt_idx=2, transform_fn = p3)
family.run_cases_and_report_failures(model)

Running evaluation...


Evaluating cases:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]


Failed: list-init-random-2-0
Prompt:
Add 1 to every element in this list: [2, 4, 5, 3, 8] 
 List: [
Expected: [3, 5, 6, 4, 9]
Output  : <|endoftext|>Add 1 to every element in this list: [2, 4, 5, 3, 8] 
 List: [3, 5, 7, 2, 4]
"""

def add_one_to_every_element(li: List[int

Failed: list-init-random-2-1
Prompt:
Add 1 to every element in this list: [1, 8, 5, 3, 3] 
 List: [
Expected: [2, 9, 6, 4, 4]
Output  : <|endoftext|>Add 1 to every element in this list: [1, 8, 5, 3, 3] 
 List: [2, 7, 6, 4, 4]
"""

def count_all_ascii_figures(fig: List[

Failed: list-init-random-2-2
Prompt:
Add 1 to every element in this list: [7, 2, 2, 9, 3] 
 List: [
Expected: [8, 3, 3, 10, 4]
Output  : <|endoftext|>Add 1 to every element in this list: [7, 2, 2, 9, 3] 
 List: [8, 3, 4, 10, 5]
"""
li = [7, 2, 2, 9, 3]
li = [

Failed: list-init-random-2-3
Prompt:
Add 1 to every element in this list: [7, 2, 2, 5, 7] 
 List: [
Expected: [8, 3, 3, 6, 8]
Output  : <|endoftext|>Add 1 to every element in this list: [7, 2, 

In [12]:
# family = RandomListPromptFamily(max_val = 10)
# cases = family.generate_cases(1, prompt_idx=2)
# family.run_cases_and_report_failures(model)

In [13]:
# # Try it again with OLLAMA 
# family = RandomListPromptFamily(max_val = 10)
# cases = family.generate_cases(10, prompt_idx=0)
# family.run_cases_and_report_failures(OLLAMA_MODELS[0], ollama=True)

In [14]:
# # Try it again with OLLAMA 
# family = RandomListPromptFamily(max_val = 10)
# cases = family.generate_cases(10, prompt_idx=1)
# family.run_cases_and_report_failures(OLLAMA_MODELS[0], ollama=True)

# Activation Patching
Alright let's categorize our examples where they worked and failed from earlier and do activation patching on that 

## Append -1 to the end of the list 

### Passed 
[1, 1, 9, 4, 2] 

[6, 3, 2, 1, 6] 

[6, 1, 5, 9, 8] 



### Failed - Copied List
[5, 9, 8, 2, 6] 

[2, 8, 6, 7, 9] 

[4, 2, 2, 4, 7]

[9, 6, 4, 9, 6] 

[4, 3, 6, 1, 1]

[5, 4, 3, 7, 5]

[1, 8, 9, 2, 5]

[3, 5, 7, 4, 2]

[3, 6, 9, 8, 3]

[3, 5, 9, 2, 3]

[2, 8, 8, 3, 9]

[5, 8, 7, 7, 8] 

[5, 1, 3, 9, 5]

[1, 9, 6, 6, 6]

[3, 8, 6, 8, 4]

[8, 7, 6, 3, 5]

[3, 8, 1, 9, 7]

[7, 9, 1, 1, 2] 




### Failed - Bad/More Append
[7, 1, 3, 6, 9] - Bad Append

[6, 4, 4, 3, 4] - Appended 2 wrong numbers

[6, 9, 3, 6, 5] - Appended many wrong numbers

[9, 4, 4, 5, 1] - Appended many wrong numbers


In [23]:
# Verify that outputs are deterministic/can be deterministically studied from the prompt 
# These should all succeed. 
family = RandomListPromptFamily(max_val = 10)
def p2(l): 
    l.append(-1)
    return l
better_cases = family.generate_cases(prompt_idx=1, transform_fn = p2, manual_lists=[[1, 1, 9, 4, 2], [6, 3, 2, 1, 6], [6, 1, 5, 9, 8]])
family.run_cases_and_report_failures(model)

Running evaluation...


Evaluating cases:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


Failed: list-init-random-1-1
Prompt:
Append -1 to the end of this list [6, 3, 2, 1, 6] 
 List: [
Expected: [6, 3, 2, 1, 6, -1]
Output  : <|endoftext|>Append -1 to the end of this list [6, 3, 2, 1, 6] 
 List: [6, 3, 2, 1, 6]
 Append 9 to the end of this list [6, 3, 2, 1, 9]
 List: [6, 3, 2, 1, 9, 9]
 Append -1 to the end of this list [6, 3, 2, 1, 6, 9]
 List: [6, 3, 2, 1, 6, 9]
<|endoftext|>

Failed: list-init-random-1-2
Prompt:
Append -1 to the end of this list [6, 1, 5, 9, 8] 
 List: [
Expected: [6, 1, 5, 9, 8, -1]
Output  : <|endoftext|>Append -1 to the end of this list [6, 1, 5, 9, 8] 
 List: [6, 1, 5, 9, 8]
Sort the list: [1, 5, 6, 8, 9]
 Replace 7 with -1: [1, 5, -1, 9, 8]
 Append -1 to the end of the list: [1, 5, -1, 9, 8, -1]

<|endoftext|>

Passed: list-init-random-1-0
Prompt:
Append -1 to the end of this list [1, 1, 9, 4, 2] 
 List: [
Expected: [1, 1, 9, 4, 2, -1]
Output  : <|endoftext|>Append -1 to the end of this list [1, 1, 9, 4, 2] 
 List: [1, 1, 9, 4, 2, -1]
Append -2 to