In [38]:
import random
from typing import List, Tuple, Dict, Callable, Optional
from dataclasses import dataclass
import torch
from transformer_lens import HookedTransformer
from transformer_lens import patching
import transformer_lens.utils as utils
import ollama
import re
import numpy as np
import os
from dotenv import load_dotenv
from tqdm.notebook import tqdm
import json
import matplotlib.pyplot as plt
from neel_plotly import line, imshow, scatter

In [2]:
def set_seed(seed: int):
    """Seed the torch, NumPy and stdlib ``random`` generators with the same value.

    Args:
        seed: Integer seed handed to each library's seeding function.
    """
    # Same order as before: torch first, then NumPy, then the stdlib generator.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)


# Fix the seed once, up front, so randomly generated lists are repeatable.
set_seed(42)

## Infrastructure and Setup 
Helper functions and classes that set up a framework for testing and analyzing prompts.

In [3]:
# Name of the model this notebook loads into HookedTransformer.
MODEL_NAME = "Phi-1"
# Presumably models that handle the list-copy task acceptably (TODO confirm).
# Each name is its own entry so membership checks such as `MODEL_NAME in OK_COPY` work.
OK_COPY = ["qwen-1.8b", "Phi-1"]
# Presumably models that copy lists poorly (TODO confirm).
BAD_COPY = ["mistralai/Mistral-7B-v0.1", "EleutherAI/gpt-neo-2.7B", "gpt2-xl"]
# Model tags that are served through Ollama instead of HookedTransformer.
OLLAMA_MODELS = ["mistral:7b"]

In [4]:
def ollama_generate(prompt, model):
    """Send ``prompt`` to the Ollama model ``model`` and return the response text.

    Args:
        prompt: Prompt text passed to ``ollama.generate``.
        model: Ollama model tag, passed as the first positional argument.

    Returns:
        The ``"response"`` entry of the result returned by ``ollama.generate``.
    """
    return ollama.generate(model, prompt=prompt)["response"]

In [5]:
@dataclass
class PromptCase:
    """A single prompt together with the answer it should be scored against."""

    task_id: str       # identifier for this case
    prompt: str        # text given to the model
    ground_truth: str  # expected answer, as a string
    metadata: Dict     # free-form extra information about the case

    def test_behavior(self, model_output: str) -> Dict:
        """Categorize a model response with standard checks.

        Intended to return tags such as
        ``{"has_extra_text": True, "wrong_numbers": True}``.
        Not implemented on this class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def analyze_tokens(self, tokenizer) -> Dict:
        """Tokenize the prompt and describe it.

        Intended to report the total token count, where the numbers sit, and
        the positions of special tokens. Not implemented on this class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

class PromptFamily:
    """Interface for a group of related prompts; every method must be overridden."""

    def name(self) -> str:
        """Return the family's name. Not implemented on the base class."""
        raise NotImplementedError()

    def generate_cases(self, n: int) -> List[PromptCase]:
        """Produce ``n`` prompt cases. Not implemented on the base class."""
        raise NotImplementedError()

    def test_behavior(self, case: PromptCase, model_output: str) -> Dict:
        """Tag a model response for ``case``.

        Implementations should return standardized error tags
        (e.g., extra_text, wrong_values).
        """
        raise NotImplementedError()

    def analyze_tokens(self, case: PromptCase, tokenizer) -> Dict:
        """Describe the tokenization of ``case``.

        Implementations should return token-level information such as the
        total token count and the positions of number tokens.
        """
        raise NotImplementedError()



In [39]:
# NOTE(review): PromptCase is also defined in an earlier cell; once this cell
# runs, this later definition is the one in effect. Consider keeping only one.
@dataclass
class PromptCase:
    """A single prompt together with the answer it should be scored against."""

    task_id: str       # identifier for this case
    prompt: str        # text given to the model
    ground_truth: str  # expected answer, as a string
    metadata: Dict     # free-form extra information about the case

    def test_behavior(self, model_output: str) -> Dict:
        """Categorize a model response; not implemented on this class."""
        raise NotImplementedError()

    def analyze_tokens(self, tokenizer) -> Dict:
        """Describe the prompt's tokenization; not implemented on this class."""
        raise NotImplementedError()

# NOTE(review): PromptFamily is also defined in an earlier cell; once this cell
# runs, this later definition is the one in effect. Consider keeping only one.
class PromptFamily:
    """Interface for a group of related prompts; every method must be overridden."""

    def name(self) -> str:
        """Return the family's name. Not implemented on the base class."""
        raise NotImplementedError()

    def generate_cases(self, n: int) -> List[PromptCase]:
        """Produce ``n`` prompt cases. Not implemented on the base class."""
        raise NotImplementedError()

    def test_behavior(self, case: PromptCase, model_output: str) -> Dict:
        """Tag a model response for ``case``. Not implemented on the base class."""
        raise NotImplementedError()

    def analyze_tokens(self, case: PromptCase, tokenizer) -> Dict:
        """Describe the tokenization of ``case``. Not implemented on the base class."""
        raise NotImplementedError()

class RandomListPromptFamily(PromptFamily):
    """Prompt family that asks a model to print or transform a short list of integers.

    Every prompt template ends with ``List: [`` so the model is expected to
    continue the bracketed list. Cases produced by ``generate_cases`` are kept
    on ``self.cases`` and consumed by ``run_cases_and_report_failures``.
    """

    def __init__(self, min_val=1, max_val=10, list_size=5):
        """
        Args:
            min_val: Lower bound (inclusive) for randomly drawn elements.
            max_val: Upper bound for randomly drawn elements. It is passed as
                ``high`` to ``np.random.randint`` and is therefore exclusive.
            list_size: Number of elements in each randomly generated list.
        """
        self.min_val = min_val
        self.max_val = max_val
        self.list_size = list_size
        self.cases = []
        # Prompt templates, selected by ``prompt_idx`` in ``generate_cases``.
        self.prompts = [
            lambda numbers: f"Print out this list of numbers: {numbers}. List: [",
            lambda numbers: f"""Append -1 to the end of this list {numbers} \n List: [""",
            lambda numbers: f"""Add 1 to every element in this list: {numbers} \n List: [""",
            ]

    def name(self):
        """Return the family identifier used as the prefix of every task id."""
        return "list-init-random"

    def generate_cases(
        self,
        n: int = 0,
        prompt_idx: int = 0,
        transform_fn: Optional[Callable[[List[int]], List[int]]] = None,
        manual_lists: Optional[List[List[int]]] = None
    ) -> List[PromptCase]:
        """Build prompt cases and store them on ``self.cases`` (replacing any previous ones).

        Args:
            n: Number of random lists to draw; ignored when ``manual_lists`` is given.
            prompt_idx: Index into ``self.prompts`` selecting the prompt template.
            transform_fn: Maps an input list to the expected output list. When
                omitted, the expected output is the input list itself.
            manual_lists: Explicit input lists to use instead of random ones.

        Returns:
            The list of generated ``PromptCase`` objects (the same list as ``self.cases``).

        Raises:
            IndexError: if ``prompt_idx`` is out of range for ``self.prompts``.
        """
        self.cases = []
        if manual_lists is not None:
            # Copy so the caller's lists are never aliased by case metadata.
            inputs = [list(numbers) for numbers in manual_lists]
        else:
            inputs = [
                np.random.randint(self.min_val, self.max_val, size=self.list_size).tolist() for _ in range(n)
            ]

        for i, numbers in enumerate(inputs):
            prompt = self.prompts[prompt_idx](numbers)

            # Hand the transform its own copy: a transform that mutates its
            # argument in place (e.g. ``l.append(-1)``) must not change the
            # input list recorded in the metadata.
            if transform_fn:
                expected_list = transform_fn(list(numbers))
            else:
                expected_list = list(numbers)

            expected = str(expected_list)

            self.cases.append(PromptCase(
                task_id=f"{self.name()}-{prompt_idx}-{i}",
                prompt=prompt,
                ground_truth=expected,
                metadata={
                    "list": numbers,
                    "expected": expected_list,
                    "family": self.name(),
                    "transformation": prompt_idx
                }
            ))
        return self.cases


    def analyze_tokens(self, case: PromptCase, model: HookedTransformer) -> Dict:
        """Report how the prompt and each input number are tokenized.

        Args:
            case: Case whose prompt and ``metadata["list"]`` are analyzed.
            model: Model whose ``to_tokens`` / ``to_str_tokens`` are used.

        Returns:
            Dict with ``total_tokens``, ``token_ids``, ``token_strs``,
            ``number_token_spans`` (number -> token count), ``split_tokens``
            (numbers needing more than one token) and ``num_splits``.
        """
        tokens = model.to_tokens(case.prompt)[0]
        token_strs = model.to_str_tokens(tokens)

        number_token_spans = {}
        split_tokens = []
        numbers = case.metadata.get("list", [])

        for num in numbers:
            num_str = str(num)
            # prepend_bos=False: otherwise the BOS token is counted and every
            # number looks like it spans more than one token.
            # NOTE(review): the number is tokenized in isolation, which may
            # differ from how it is tokenized inside the prompt.
            tokenized = model.to_tokens(num_str, prepend_bos=False)[0]
            number_token_spans[num] = len(tokenized)
            if len(tokenized) > 1:
                split_tokens.append(num)

        return {
            "total_tokens": len(tokens),
            "token_ids": tokens.tolist(),
            "token_strs": token_strs,
            "number_token_spans": number_token_spans,
            "split_tokens": split_tokens,
            "num_splits": len(split_tokens)
        }

    @staticmethod
    def _empty_device_cache():
        """Release cached accelerator memory, if an MPS or CUDA backend is present."""
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _generate(model, prompt: str, max_tokens: int, use_ollama: bool) -> Tuple[str, str]:
        """Run one prompt and return ``(display_text, completion)``.

        ``display_text`` is what gets printed in the report (for a
        HookedTransformer this is the decoded prompt plus continuation);
        ``completion`` is only the newly generated text, used for scoring.
        """
        if use_ollama:
            response = ollama_generate(prompt, model=model).strip()
            return response, response

        tokens = model.to_tokens(prompt)
        generated_tokens = model.generate(
            tokens,
            max_new_tokens=max_tokens,
            temperature=0.0,
            top_k=0
        )
        decoded = model.tokenizer.decode(generated_tokens[0]).strip()
        # Everything after the prompt tokens is the model's own continuation.
        completion = model.tokenizer.decode(generated_tokens[0, tokens.shape[1]:])
        return decoded, completion

    @staticmethod
    def _matches_expected(prompt: str, completion: str, expected: str) -> bool:
        """Return True when ``expected`` appears in the model's continuation.

        The prompt itself is excluded from the search, because it can already
        contain the expected list (the "print out this list" template does).
        When the prompt ends with ``[``, that bracket is re-attached so a
        continuation such as ``1, 2, 3]`` matches ``[1, 2, 3]``.
        """
        opener = "[" if prompt.rstrip().endswith("[") else ""
        return expected in completion or expected in (opener + completion)

    def run_cases_and_report_failures(self, model: HookedTransformer, max_tokens: int = 100, ollama=False, suppress_correct=False):
        """Generate a completion for every case in ``self.cases`` and print a report.

        Args:
            model: A HookedTransformer, or an Ollama model tag when ``ollama`` is true.
            max_tokens: Maximum number of new tokens to generate per case.
            ollama: When true, generate through ``ollama_generate`` instead of ``model.generate``.
            suppress_correct: When true, only failing cases are printed.

        Failing cases are printed first. A case whose generation raises is
        reported as failed with ``<ERROR>`` as the expected value and the
        exception text as the output.
        """
        print("Running evaluation...")
        results = []

        for case in tqdm(self.cases, desc="Evaluating cases"):
            expected = case.ground_truth.strip()
            try:
                decoded, completion = self._generate(model, case.prompt, max_tokens, ollama)
                correct = self._matches_expected(case.prompt, completion, expected)
                results.append((not correct, case.task_id, case.prompt, expected, decoded, correct))
            except Exception as e:
                results.append((True, case.task_id, case.prompt, "<ERROR>", str(e), False))
            finally:
                # Cleanup lives outside the try body so a failure here can never
                # add a second (error) entry for a case that was already recorded.
                if not ollama:
                    self._empty_device_cache()

        # Sort so incorrect first
        results.sort(key=lambda x: x[0], reverse=True)

        for is_incorrect, task_id, prompt, expected, output, correct in results:
            if not correct or not suppress_correct:
                print(f"\n{'Failed' if is_incorrect else 'Passed'}: {task_id}")
                print(f"Prompt:\n{prompt}")
                print(f"Expected: {expected}")
                print(f"Output  : {output}")

    @staticmethod
    def run_activation_patching_grid(
        model: HookedTransformer,
        clean_prompts: List[str],
        corrupted_prompts: List[str],
        expected_outputs: List[str],
        patch_type: str = "resid_pre",
        max_tokens: int = 100,
        output_file: str = "activation_patch_results.json",
        visualize: bool = False
    ):
        """Patch clean activations into every corrupted prompt and save the results as JSON.

        For each (clean, corrupted) pair the clean run's cache is patched into
        the corrupted run with the TransformerLens patching helper selected by
        ``patch_type``.

        Args:
            model: Model to run and patch.
            clean_prompts: Prompts whose activations are cached and patched in.
            corrupted_prompts: Prompts that are run with patched activations.
            expected_outputs: Expected text for each corrupted prompt (same length
                as ``corrupted_prompts``).
            patch_type: One of ``"resid_pre"``, ``"block"``, ``"attention"``, ``"mlp"``.
            max_tokens: Currently unused; kept for interface compatibility.
            output_file: Path of the JSON file the results are written to.
            visualize: When true, show a heatmap for every pair.

        Raises:
            AssertionError: if ``patch_type`` is not one of the supported values.
            ValueError: if ``corrupted_prompts`` and ``expected_outputs`` differ in
                length, or a clean/corrupted pair tokenizes to different shapes.
        """
        assert patch_type in {"resid_pre", "block", "attention", "mlp"}, f"Invalid patch type: {patch_type}"
        if len(corrupted_prompts) != len(expected_outputs):
            raise ValueError(
                f"corrupted_prompts ({len(corrupted_prompts)}) and expected_outputs "
                f"({len(expected_outputs)}) must have the same length"
            )

        # Names of the helpers in transformer_lens.patching for each patch type.
        # "block" uses the *_every variant, which per the TransformerLens docs
        # stacks residual-stream, attention-output and MLP-output results.
        patch_fn_names = {
            "resid_pre": "get_act_patch_resid_pre",
            "block": "get_act_patch_block_every",
            "attention": "get_act_patch_attn_out",
            "mlp": "get_act_patch_mlp_out",
        }
        patch_fn = getattr(patching, patch_fn_names[patch_type])

        def decode_logits(logits):
            return model.tokenizer.decode(logits.argmax(dim=-1)[0])

        results = {}

        # No gradients are needed for patching; disabling them keeps memory down.
        with torch.no_grad():
            for i, clean_prompt in enumerate(clean_prompts):
                clean_tokens = model.to_tokens(clean_prompt)
                _, clean_cache = model.run_with_cache(clean_tokens)

                for j, (corrupted_prompt, expected_output) in enumerate(zip(corrupted_prompts, expected_outputs)):
                    corrupted_tokens = model.to_tokens(corrupted_prompt)
                    # Patching is done position by position, so both prompts
                    # must tokenize to the same shape.
                    if corrupted_tokens.shape != clean_tokens.shape:
                        raise ValueError(
                            f"clean prompt {i} and corrupted prompt {j} tokenize to different shapes "
                            f"({tuple(clean_tokens.shape)} vs {tuple(corrupted_tokens.shape)})"
                        )

                    # NOTE(review): this metric is a 0/1 indicator computed on the
                    # argmax decoding of a single forward pass over the prompt, not
                    # on generated text, so it may rarely (or never) be 1 — confirm
                    # it measures what is intended.
                    def metric(logits, expected_output=expected_output):
                        decoded = decode_logits(logits)
                        return torch.tensor(float(expected_output in decoded))

                    patch_result = patch_fn(model, corrupted_tokens, clean_cache, metric)

                    key = f"clean_{i}_corrupt_{j}"
                    results[key] = {
                        "clean_prompt": clean_prompt,
                        "corrupted_prompt": corrupted_prompt,
                        "expected_output": expected_output,
                        "patch_type": patch_type,
                        "patch_result": patch_result.tolist() if hasattr(patch_result, "tolist") else patch_result
                    }

                    if visualize:
                        print(f"Visualizing patch for clean {i} → corrupt {j}")
                        position_labels = [
                            f"{tok} {pos}" for pos, tok in enumerate(model.to_str_tokens(clean_tokens[0]))
                        ]
                        facet_kwargs = {}
                        if getattr(patch_result, "ndim", 2) == 3:
                            # Stacked result from the "block" patch type: one panel per component.
                            facet_kwargs = {
                                "facet_col": 0,
                                "facet_labels": ["Residual Stream", "Attn Output", "MLP Output"],
                            }
                        imshow(patch_result,
                               yaxis="Layer",
                               xaxis="Position",
                               x=position_labels,
                               title=f"{patch_type} Activation Patching",
                               **facet_kwargs)
                    RandomListPromptFamily._empty_device_cache()

        with open(output_file, "w") as f:
            json.dump(results, f, indent=2)

        print(f"Saved activation patching results to {output_file}")

    

In [7]:
from dotenv import load_dotenv
import os

load_dotenv()
hf_token = os.getenv("HF_TOKEN")

!huggingface-cli login --token {hf_token}

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
The token `transformerlens` has been saved to /Users/johnwu/.cache/huggingface/stored_tokens
Your token has been saved to /Users/johnwu/.cache/huggingface/token
Login successful.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [8]:
# Release cached MPS memory before loading the model; skipped on machines
# that have no MPS backend so the cell still runs there.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [9]:
# An earlier attempt to load a different (presumably larger) model ran out of
# memory on this machine, so only MODEL_NAME is loaded.
# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU so the
# notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    device=DEVICE,
)

Loaded pretrained model Phi-1 into HookedTransformer


In [None]:
# Baseline: 25 random lists with prompt template 0 (print the list unchanged).
family = RandomListPromptFamily(max_val=10)
cases = family.generate_cases(n=25, prompt_idx=0)
family.run_cases_and_report_failures(model=model)

A common failure seems to be of the form:
```python
<s> Print out this list of numbers: {numbers}.
```

This is odd: the numbers are usually correct, but the model either ignores the instruction not to include other text or copies an entire section of the prompt.

In [13]:
def p2(l):
    """Expected answer for prompt template 1: a new list with -1 appended to ``l``."""
    return l + [-1]


# 25 random lists with prompt template 1 (append -1 to the list).
family = RandomListPromptFamily(max_val=10)
better_cases = family.generate_cases(n=25, prompt_idx=1, transform_fn=p2)
family.run_cases_and_report_failures(model=model)

Running evaluation...


Evaluating cases:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


Failed: list-init-random-1-0
Prompt:
Append -1 to the end of this list [8, 7, 3, 1, 1] 
 List: [
Expected: [8, 7, 3, 1, 1, -1]
Output  : <|endoftext|>Append -1 to the end of this list [8, 7, 3, 1, 1] 
 List: [8, 7, 3, 1, 1]
Append -2 to the end of this list [8, 7, 3, 1, 1, -2]
 List: [8, 7, 3, 1, 1, -2]
Append -3 to the end of this list [8, 7, 3, 1, 1, -2, -3]
 List: [8, 7, 3, 1, 1, -2, -3]

Failed: list-init-random-1-1
Prompt:
Append -1 to the end of this list [3, 6, 7, 6, 6] 
 List: [
Expected: [3, 6, 7, 6, 6, -1]
Output  : <|endoftext|>Append -1 to the end of this list [3, 6, 7, 6, 6] 
 List: [3, 6, 7, 6, 6]
 Append -2 to the end of this list [3, 6, 7, 6, 6] 
 List: [3, 6, 7, 6, 6, -2]
 Append -3 to the end of this list [3, 6, 7, 6, 6, -2] 
 List: [3, 6, 7, 6, 6, -2, -3]
"""

Failed: list-init-random-1-3
Prompt:
Append -1 to the end of this list [5, 1, 1, 5, 3] 
 List: [
Expected: [5, 1, 1, 5, 3, -1]
Output  : <|endoftext|>Append -1 to the end of this list [5, 1, 1, 5, 3] 
 List: [

In [11]:
def p3(l):
    """Expected answer for prompt template 2: a new list with every element of ``l`` incremented by 1."""
    return [value + 1 for value in l]


# 25 random lists with prompt template 2 (add 1 to every element).
family = RandomListPromptFamily(max_val=10)
manual_cases = family.generate_cases(n=25, prompt_idx=2, transform_fn=p3)
family.run_cases_and_report_failures(model=model)

Running evaluation...


Evaluating cases:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


Failed: list-init-random-2-1
Prompt:
Add 1 to every element in this list: [9, 9, 4, 9, 3] 
 List: [
Expected: [10, 10, 5, 10, 4]
Output  : <|endoftext|>Add 1 to every element in this list: [9, 9, 4, 9, 3] 
 List: [10, 10, 4, 10, 3]
"""

def modify_all_elements_by_reflection(li: List[float]) -> None:
      """
      Modifies all elements in the input list by adding 1 to each element.
      The modified list is then updated in place.
      """
      for i in range(len(li)):
          li[i] += 1

<|endoftext|>

Failed: list-init-random-2-4
Prompt:
Add 1 to every element in this list: [9, 4, 1, 1, 4] 
 List: [
Expected: [10, 5, 2, 2, 5]
Output  : <|endoftext|>Add 1 to every element in this list: [9, 4, 1, 1, 4] 
 List: [10, 4, 1, 1, 4]
"""

def add_one_list(li: List[int]) -> List[int]:
      """
      Takes a list of integers as input and returns a new list where each element has been incremented by 1.

      Args:
      li (List[int]): A list of integers.

      Returns:
      List[int]:

In [12]:
# family = RandomListPromptFamily(max_val = 10)
# cases = family.generate_cases(1, prompt_idx=2)
# family.run_cases_and_report_failures(model)

In [13]:
# # Try it again with OLLAMA 
# family = RandomListPromptFamily(max_val = 10)
# cases = family.generate_cases(10, prompt_idx=0)
# family.run_cases_and_report_failures(OLLAMA_MODELS[0], ollama=True)

In [14]:
# # Try it again with OLLAMA 
# family = RandomListPromptFamily(max_val = 10)
# cases = family.generate_cases(10, prompt_idx=1)
# family.run_cases_and_report_failures(OLLAMA_MODELS[0], ollama=True)

# Activation Patching
Let's sort the earlier examples into those that passed and those that failed, and then run activation patching on them.

## Append -1 to the end of the list 

### Passed 
[5, 7, 1, 3, 2] 

[7, 7, 3, 2, 9]

[1, 8, 3, 7, 2]


In [None]:
# Verify that outputs are deterministic, i.e. that the same prompt can be studied repeatedly.
# These should all succeed.
family = RandomListPromptFamily(max_val=10)
def p2(l):
    """Expected answer for prompt template 1: a new list with -1 appended to ``l``.

    Builds a new list instead of calling ``l.append(-1)``, so the input lists
    passed in below are not modified.
    """
    return l + [-1]
better_cases = family.generate_cases(prompt_idx=1, transform_fn=p2, manual_lists=[[5, 7, 1, 3, 2], [7, 7, 3, 2, 9], [1, 8, 3, 7, 2]])
family.run_cases_and_report_failures(model)

In [None]:
# Lists taken from the "Passed" section above (clean) and lists used as the
# corrupted counterpart in the patching experiment.
clean_lists = [[5, 7, 1, 3, 2], [7, 7, 3, 2, 9], [1, 8, 3, 7, 2]]
corrupt_lists = [[8, 7, 3, 1, 1], [2, 6, 7, 2, 2], [5, 7, 4, 6, 4]]

family = RandomListPromptFamily(max_val=10)


def append_neg1(l):
    """Expected answer for prompt template 1: a new list with -1 appended to ``l``."""
    return l + [-1]


# Both sets of cases use prompt template 1 (append -1).
clean_cases = family.generate_cases(prompt_idx=1, transform_fn=append_neg1, manual_lists=clean_lists)
corrupted_cases = family.generate_cases(prompt_idx=1, transform_fn=append_neg1, manual_lists=corrupt_lists)

clean_prompts = [clean_case.prompt for clean_case in clean_cases]
corrupted_prompts = [corrupted_case.prompt for corrupted_case in corrupted_cases]
# Each corrupted prompt is scored against its own ground truth.
expected_outputs = [corrupted_case.ground_truth for corrupted_case in corrupted_cases]

# Patch every clean prompt's residual-stream activations into every corrupted prompt.
RandomListPromptFamily.run_activation_patching_grid(
    model=model,
    clean_prompts=clean_prompts,
    corrupted_prompts=corrupted_prompts,
    expected_outputs=expected_outputs,
    patch_type="resid_pre",
    visualize=True,
)


  0%|          | 0/648 [00:00<?, ?it/s]