In [1]:
# Disable widget-based progress bars
import warnings
warnings.filterwarnings('ignore')

import os
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # Add this!

# Use simple progress bars instead of widgets
import transformers
transformers.logging.set_verbosity_error()

print("✓ Environment configured")

✓ Environment configured


### Clear GPU Memory

In [2]:
import torch
import gc

# Force cleanup
gc.collect()
torch.cuda.empty_cache()

# Check memory
print(f"GPU Memory available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB total")
print(f"GPU Memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
print(f"GPU Memory reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")

GPU Memory available: 94.97 GB total
GPU Memory allocated: 0.00 GB
GPU Memory reserved: 0.00 GB


## Load the model

In [3]:
# from transformers import AutoModelForCausalLM, AutoTokenizer
# import torch

# # Point to local model directory
# model_id = "./gpt-oss-20b/"

# # Load model and tokenizer
# print("Loading model with MXFP4 quantization...")
# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     torch_dtype="auto",
#     device_map="auto",
#     trust_remote_code=True,
#     local_files_only=True,
#     low_cpu_mem_usage=True,
# )

# print("Loading tokenizer...")
# tokenizer = AutoTokenizer.from_pretrained(
#     model_id,
#     local_files_only=True
# )

# print("✓ Model loaded successfully!")
# print(f"GPU Memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

## Run inference


In [4]:
# from transformers import TextStreamer
# # Test inference
# messages = [
#     {"role": "user", "content": "What is 2+2?"}
# ]

# # Apply chat template (harmony format)
# text = tokenizer.apply_chat_template(
#     messages,
#     tokenize=False,
#     add_generation_prompt=True
# )

# inputs = tokenizer(text, return_tensors="pt").to(model.device)

# # Generate
# print("Generating...")


# # Use streamer to see tokens as they're generated
# streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# outputs = model.generate(
#     **inputs,
#     max_new_tokens=256,
#     do_sample=False,
#     streamer=streamer  # Add this
# )

# response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
# print("\nResponse:", response)

# Steering


In [14]:
import logging
import os
import sys
import timez
from collections import defaultdict
from pathlib import Path

import circuitsvis as cv
import einops
import numpy as np
import torch as t
from IPython.display import display
from jaxtyping import Float
from nnsight import CONFIG, LanguageModel
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from torch import Tensor

# Hide some info logging messages from nnsight
logging.disable(sys.maxsize)

t.set_grad_enabled(False)
device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)

# Import from local plotly_utils.py file (not the pip package)
import importlib.util
notebook_dir = Path(__file__).parent if "__file__" in globals() else Path.cwd()
plotly_utils_path = notebook_dir / "plotly_utils.py"
if not plotly_utils_path.exists():
    plotly_utils_path = Path("/root/Agent-Robustness-Via-ToM/plotly_utils.py")
spec = importlib.util.spec_from_file_location("plotly_utils_local", str(plotly_utils_path))
plotly_utils_local = importlib.util.module_from_spec(spec)
spec.loader.exec_module(plotly_utils_local)
imshow = plotly_utils_local.imshow

MAIN = __name__ == "__main__"

In [None]:
# Determine the correct path to the model
model_path = Path("./gpt-oss-20b")
if not model_path.exists():
    model_path = Path("/root/Agent-Robustness-Via-ToM/gpt-oss-20b")
    
print(f"Loading model from: {model_path}")
model = LanguageModel(str(model_path), device_map="auto", torch_dtype=t.bfloat16)
tokenizer = model.tokenizer

print()

N_HEADS = model.config.num_attention_heads
N_LAYERS = model.config.num_hidden_layers
D_MODEL = model.config.hidden_size
D_HEAD = model.config.head_dim

print(f"Number of heads: {N_HEADS}")
print(f"Number of layers: {N_LAYERS}")
print(f"Model (hidden) dimension: {D_MODEL}")
print(f"Head dimension: {D_HEAD}\n")

print("Entire config: ", model.config)

OSError: Can't load the configuration of './gpt-oss-20b'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure './gpt-oss-20b' is the correct path to a directory containing a config.json file

## Access Hidden States

### GPT-OSS-20B Quick Reference

- 24 layers | 2880 hidden | 64 heads | 201K vocab
- **Key difference**: `model.model.layers[i]` not `model.transformer.h[i]`

### Basic Usage
```python
with model.trace(prompt, remote=False):
    # Hidden states from layer i
    hidden = model.model.layers[i].output[0].save()
    
    # Logits (last token)
    logits = model.lm_head.output[0, -1].save()
```

### Steering
```python
with model.trace(prompt, remote=False):
    hidden = model.model.layers[i].output[0]
    hidden[:, -1, :] += steering_vector
    logits = model.lm_head.output[0, -1].save()
```

### Multi-Layer Extraction
```python
with model.trace(prompt, remote=False):
    all_hidden = [model.model.layers[i].output[0].save() for i in range(24)]
```

## Layer Components
```python
model.model.layers[i].self_attn     # Attention block
model.model.layers[i].mlp           # MoE block
model.model.embed_tokens            # Embeddings
model.model.norm                    # Final norm
model.lm_head                       # Output layer
```

## Expected Shapes

- Hidden: `(1, seq_len, 2880)`
- Logits: `(201088,)` for single token

In [10]:
# If you have an API key & want to work remotely, then set REMOTE = True and replace "YOUR-API-KEY"
# with your actual key. If not, then leave REMOTE = False.
REMOTE = False
prompt = "The Eiffel Tower is in the city of"

with model.trace(prompt, remote=REMOTE):
    # Save the model's hidden states
    hidden_states = model.model.layers[-1].output[0].save()

    # Save the model's logit output
    logits = model.lm_head.output[0, -1].save()

# Get the model's logit output, and it's next token prediction
print(f"logits.shape = {logits.shape} = (vocab_size,)")
print("Predicted token ID =", predicted_token_id := logits.argmax().item())
print(f"Predicted token = {tokenizer.decode(predicted_token_id)!r}")

# Print the shape of the model's residual stream
print(f"\nresid.shape = {hidden_states.shape} = (batch_size, seq_len, d_model)")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

logits.shape = torch.Size([201088]) = (vocab_size,)
Predicted token ID = 12650
Predicted token = ' Paris'

resid.shape = torch.Size([8, 2880]) = (batch_size, seq_len, d_model)


# Datasets


## Load ToM Datasets

Instead of using the ICL dataset, we'll load two Theory of Mind (ToM) datasets:
1. **ToM Dataset** (`first_order_1_tom_prompts.jsonl`) - Questions that require theory of mind reasoning
2. **No-ToM Dataset** (`first_order_1_no_tom_prompts.jsonl`) - Questions that don't require theory of mind


In [15]:
import json
from pathlib import Path

class ToMDataset:
    """
    Dataset for Theory of Mind (ToM) tasks from JSONL files.
    
    Inputs:
        jsonl_path: Path to the JSONL file containing prompts
        size: Optional limit on number of examples to load (None = load all)
    """
    
    def __init__(self, jsonl_path: str, size: int = None):
        self.jsonl_path = Path(jsonl_path)
        self.data = []
        self.prompts = []
        self.completions = []
        self.answers = []
        self.question_types = []
        self.story_types = []
        self.requires_tom = []
        
        # Load data from JSONL file
        with open(self.jsonl_path, 'r') as f:
            for i, line in enumerate(f):
                if size is not None and i >= size:
                    break
                item = json.loads(line)
                self.data.append(item)
                self.prompts.append(item['prompt'])
                # Completions are the answers with a leading space
                self.completions.append(' ' + item['answer'])
                self.answers.append(item['answer'])
                self.question_types.append(item['question_type'])
                self.story_types.append(item['story_type'])
                self.requires_tom.append(item['requires_tom'])
        
        self.size = len(self.data)
    
    def __len__(self):
        return self.size
    
    def __getitem__(self, idx: int):
        return self.data[idx]
    
    def __repr__(self):
        return f"ToMDataset(path={self.jsonl_path.name}, size={self.size}, requires_tom={self.requires_tom[0] if self.size > 0 else 'N/A'})"


# Load the two datasets
tom_dataset = ToMDataset(
    "tom_benchmarks/tomi/tomi_pairs/first_order_1_tom_prompts.jsonl",
    size=10
)
no_tom_dataset = ToMDataset(
    "tom_benchmarks/tomi/tomi_pairs/first_order_1_no_tom_prompts.jsonl",
    size=10
)

print(f"ToM dataset: {len(tom_dataset)} samples")
print(f"No-ToM dataset: {len(no_tom_dataset)} samples")

print("ToM Dataset (requires ToM):")
print(f"  {tom_dataset}")
print(f"  Example prompt:\n{tom_dataset.prompts[0][:200]}...")
print(f"  Example answer: {tom_dataset.answers[0]}")

print("\nNo-ToM Dataset (doesn't require ToM):")
print(f"  {no_tom_dataset}")
print(f"  Example prompt:\n{no_tom_dataset.prompts[0][:200]}...")
print(f"  Example answer: {no_tom_dataset.answers[0]}")


ToM dataset: 10 samples
No-ToM dataset: 10 samples
ToM Dataset (requires ToM):
  ToMDataset(path=first_order_1_tom_prompts.jsonl, size=10, requires_tom=True)
  Example prompt:
Story:
Isabella entered the den.
Olivia entered the den.
Isabella dislikes the pumpkin
The broccoli is in the blue_pantry.
Isabella exited the den.
Olivia moved the broccoli to the red_drawer.
Abigail...
  Example answer: blue_pantry

No-ToM Dataset (doesn't require ToM):
  ToMDataset(path=first_order_1_no_tom_prompts.jsonl, size=10, requires_tom=False)
  Example prompt:
Story:
Aria entered the front_yard.
Aiden entered the front_yard.
The grapefruit is in the green_bucket.
Aria moved the grapefruit to the blue_container.
Aiden exited the front_yard.
Noah entered the ...
  Example answer: blue_container


### Usage Notes

The `ToMDataset` class has the same interface as `ICLDataset`, so it can be used interchangeably:
- `.prompts` - list of prompt strings
- `.completions` - list of completion strings (with leading space)
- `.answers` - list of answer strings
- `.size` - number of examples

**Key difference from ICLDataset**: 
- ToMDataset does **not** have a `create_corrupted_dataset()` method
- If you need corrupted data for contrastive analysis, you'll need to implement that separately


In [16]:
# Example: View a complete example from each dataset
print("=" * 80)
print("EXAMPLE FROM TOM DATASET (requires ToM reasoning):")
print("=" * 80)
print(tom_dataset[0]['prompt'])
print(f"Correct answer: {tom_dataset[0]['answer']}")
print(f"Story type: {tom_dataset[0]['story_type']}")
print(f"Requires ToM: {tom_dataset[0]['requires_tom']}")

print("\n" + "=" * 80)
print("EXAMPLE FROM NO-TOM DATASET (doesn't require ToM):")
print("=" * 80)
print(no_tom_dataset[0]['prompt'])
print(f"Correct answer: {no_tom_dataset[0]['answer']}")
print(f"Story type: {no_tom_dataset[0]['story_type']}")
print(f"Requires ToM: {no_tom_dataset[0]['requires_tom']}")


EXAMPLE FROM TOM DATASET (requires ToM reasoning):
Story:
Isabella entered the den.
Olivia entered the den.
Isabella dislikes the pumpkin
The broccoli is in the blue_pantry.
Isabella exited the den.
Olivia moved the broccoli to the red_drawer.
Abigail entered the garden.
Isabella entered the garden.

Question: Where will Isabella look for the broccoli?
Answer:
Correct answer: blue_pantry
Story type: false_belief
Requires ToM: True

EXAMPLE FROM NO-TOM DATASET (doesn't require ToM):
Story:
Aria entered the front_yard.
Aiden entered the front_yard.
The grapefruit is in the green_bucket.
Aria moved the grapefruit to the blue_container.
Aiden exited the front_yard.
Noah entered the playroom.

Question: Where will Aiden look for the grapefruit?
Answer:
Correct answer: blue_container
Story type: true_belief
Requires ToM: False


In [17]:
prompt = """Isabella entered the den.
Olivia entered the den.
Isabella dislikes the pumpkin
The broccoli is in the blue_pantry.
Isabella exited the den.
Olivia moved the broccoli to the red_drawer.
Abigail entered the garden.
Isabella entered the garden.

Question: Where will Isabella look for the broccoli?
Answer:
"""

# Tokenize the prompt and move to GPU
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Use nnsight's generate method (pass tokenized inputs)
output = model.generate(
    inputs["input_ids"],
    max_new_tokens=500,
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id
)


# Or decode just the new tokens
new_tokens = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print("\nGenerated text only:", new_tokens)


Generated text only: We have a story with characters: Isabella, Olivia, Abigail. Items: broccoli, pumpkin, blue_pantry, red_drawer. The question: "Where will Isabella look for the broccoli?" We need to infer from the story. Let's parse the story:

- Isabella entered the den.
- Olivia entered the den.
- Isabella dislikes the pumpkin.
- The broccoli is in the blue_pantry.
- Isabella exited the den.
- Olivia moved the broccoli to the red_drawer.
- Abigail entered the garden.
- Isabella entered the garden.

We need to answer: Where will Isabella look for the broccoli? The story says the broccoli was originally in the blue_pantry. Then Olivia moved the broccoli to the red_drawer. So the broccoli is now in the red_drawer. Isabella is in the garden. So where will she look? She will look in the red_drawer. But maybe she will look in the garden? But the broccoli is not in the garden. The question: "Where will Isabella look for the broccoli?" The answer: She will look in the red_drawer. But may

# Find the relevant heads

In [None]:
def calculate_fn_vectors_and_intervene(
    model: LanguageModel,
    dataset: ICLDataset,
    layers: list[int] | None = None,
) -> Float[Tensor, "layers heads"]:
    """
    Returns a tensor of shape (layers, heads), containing the CIE for each head.

    Inputs:
        model: LanguageModel
            the transformer you're doing this computation with
        dataset: ICLDataset
            the dataset of clean prompts from which we'll extract the function vector (we'll also
            create a corrupted version of this dataset for interventions)
        layers: list[int] | None
            the layers which this function will calculate score for (if None, this means all layers)
    """
    layers = range(model.config.n_layer) if (layers is None) else layers
    heads = range(model.config.n_head)

    # Get corrupted dataset
    corrupted_dataset = dataset.create_corrupted_dataset()
    N = len(dataset)

    # Get correct token ids, so we can get correct token logprobs
    correct_completion_ids = [toks[0] for toks in tokenizer(dataset.completions)["input_ids"]]

    with model.trace(remote=REMOTE) as tracer:
        # Run a forward pass on clean prompts, where we store attention head outputs
        z_dict = {}
        with tracer.invoke(dataset.prompts):
            for layer in layers:
                # Get hidden states, reshape to get head dimension, store the mean tensor
                z = model.transformer.h[layer].attn.out_proj.input[:, -1]
                z_reshaped = z.reshape(N, N_HEADS, D_HEAD).mean(dim=0)
                for head in heads:
                    z_dict[(layer, head)] = z_reshaped[head]

        # Run a forward pass on corrupted prompts, where we don't intervene or store activations (just so we can get the
        # correct-token logprobs to compare with our intervention)
        with tracer.invoke(corrupted_dataset.prompts):
            logits = model.lm_head.output[:, -1]
            correct_logprobs_corrupted = logits.log_softmax(dim=-1)[
                t.arange(N), correct_completion_ids
            ].save()

        # For each head, run a forward pass on corrupted prompts (here we need multiple different forward passes, since
        # we're doing different interventions each time)
        correct_logprobs_dict = {}
        for layer in layers:
            for head in heads:
                with tracer.invoke(corrupted_dataset.prompts):
                    # Get hidden states, reshape to get head dimension, then set it to the a-vector
                    z = model.transformer.h[layer].attn.out_proj.input[:, -1]
                    z.reshape(N, N_HEADS, D_HEAD)[:, head] = z_dict[(layer, head)]
                    # Get logprobs at the end, which we'll compare with our corrupted logprobs
                    logits = model.lm_head.output[:, -1]
                    correct_logprobs_dict[(layer, head)] = logits.log_softmax(dim=-1)[
                        t.arange(N), correct_completion_ids
                    ].save()

    # Get difference between intervention logprobs and corrupted logprobs, and take mean over batch dim
    all_correct_logprobs_intervention = einops.rearrange(
        t.stack([v for v in correct_logprobs_dict.values()]),
        "(layers heads) batch -> layers heads batch",
        layers=len(layers),
    )
    logprobs_diff = (
        all_correct_logprobs_intervention - correct_logprobs_corrupted
    )  # shape [layers heads batch]

    # Return mean effect of intervention, over the batch dimension
    return logprobs_diff.mean(dim=-1)

## Calculate Function Vector

In [18]:
def calculate_fn_vector(
    model: LanguageModel,
    dataset: ToMDataset,
    head_list: list[tuple[int, int]],
) -> Float[Tensor, "D_MODEL"]:
    """
    Returns a vector of length `D_MODEL`, containing the sum of vectors written to the residual
    stream by the attention heads in `head_list`, averaged over all inputs in `dataset`.

    Inputs:
        model: LanguageModel
            the transformer you're doing this computation with
        dataset: ToMDataset
            the dataset of prompts from which we'll extract the function vector
        head_list: list[tuple[int, int]]
            list of attention heads we're calculating the function vector from
    """
    # Turn head_list into a dict of {layer: heads we need in this layer}
    head_dict = defaultdict(set)
    for layer, head in head_list:
        head_dict[layer].add(head)

    fn_vector_list = []

    with model.trace(dataset.prompts, remote=REMOTE):
        for layer, heads_in_layer in head_dict.items():
            # Get the output projection layer (GPT-OSS uses o_proj, not out_proj)
            o_proj = model.model.layers[layer].self_attn.o_proj

            # Get the mean output projection input (note, setting values of this tensor will not
            # have downstream effects on other tensors)
            hidden_states = o_proj.input[:, -1].mean(dim=0)

            # Zero-ablate all heads which aren't in our list, then get the output (which
            # will be the sum over the heads we actually do want!)
            heads_to_ablate = set(range(N_HEADS)) - head_dict[layer]
            for head in heads_to_ablate:
                hidden_states.reshape(N_HEADS, D_HEAD)[head] = 0.0

            # Now that we've zeroed all unimportant heads, get the output & add it to the list
            # (we need a single batch dimension so we can use `o_proj`)
            o_proj_output = o_proj(hidden_states.unsqueeze(0)).squeeze()
            fn_vector_list.append(o_proj_output.save())

    # We sum all attention head outputs to get our function vector
    fn_vector = sum([v for v in fn_vector_list])

    assert fn_vector.shape == (D_MODEL,)
    return fn_vector

In [21]:
def intervene_with_fn_vector(
    model: LanguageModel,
    prompt: str,
    layer: int,
    fn_vector: Float[Tensor, "D_MODEL"],
    n_tokens: int = 50,
) -> tuple[str, str]:
    """
    Intervenes with a function vector, by adding it at the last sequence position of a generated
    prompt.

    Inputs:
        model: LanguageModel
            the transformer you're doing this computation with
        prompt: str
            The prompt to generate from (no longer template-based, just pass full prompt)
        layer: int
            The layer we'll make the intervention (by adding the function vector)
        fn_vector: Float[Tensor, "D_MODEL"]
            The vector we'll add to the final sequence position for each new token to be generated
        n_tokens: int
            The number of additional tokens we'll generate for our unsteered / steered completions

    Returns:
        completion: str
            The full completion (including original prompt) for the no-intervention case
        completion_intervention: str
            The full completion (including original prompt) for the intervention case
    """
    with model.generate(
        remote=REMOTE, max_new_tokens=n_tokens, pad_token_id=tokenizer.pad_token_id
    ) as generator:
        with model.all():
            with generator.invoke(prompt):
                tokens = model.generator.output.save()

            with generator.invoke(prompt):
                # Fixed: Need [:, -1, :] not [:, -1] to match dimensions (batch, seq, hidden)
                model.model.layers[layer].output[0][:, -1, :] += fn_vector
                tokens_intervention = model.generator.output.save()

    completion, completion_intervention = tokenizer.batch_decode(
        [tokens.squeeze().tolist(), tokens_intervention.squeeze().tolist()]
    )
    return completion, completion_intervention

## Test the intervention

In [37]:
# Instead of removing a word, we'll use the ToM datasets directly
# The goal is to extract a "ToM reasoning" function vector

# Define our datasets
# We'll use the no_tom_dataset to extract what the model does when NOT using ToM reasoning
# (This is analogous to using the clean antonym pairs in the original)
dataset = no_tom_dataset  # or tom_dataset, depending on what vector you want to extract

# Define the attention heads you'll use
# NOTE: You'll need to identify which heads are important for ToM reasoning
# This was done in the original via analysis - you may need to run similar analysis
head_list = [
    (8, 0),
    (8, 1),
    (9, 14),
    (11, 0),
    (12, 10),
    (13, 12),
    (13, 13),
    (14, 9),
    (15, 5),
    (16, 14),
]

# Extract the function vector
fn_vector = calculate_fn_vector(model, dataset, head_list)



# test_prompt = tom_dataset.prompts[0]  # or select a specific test case
test_prompt = "tell me a short story."

# Simple generation with steering at a specific layer
with model.generate(remote=REMOTE, max_new_tokens=500, pad_token_id=tokenizer.pad_token_id) as generator:
    # No intervention
    with generator.invoke(test_prompt):
        tokens_no_intervention = model.generator.output.save()
    
    # With intervention (adding the function vector)
    with generator.invoke(test_prompt):
        # During generation, layer output at position [0] is 2D: (batch, hidden_dim)
        # because it processes one token at a time
        # So we just add directly without the sequence dimension
        model.model.layers[9].output[0] += 1.5 * fn_vector
        tokens_with_intervention = model.generator.output.save()

completion_no_intervention = tokenizer.decode(tokens_no_intervention.squeeze().tolist())
completion_with_intervention = tokenizer.decode(tokens_with_intervention.squeeze().tolist())

print("NO INTERVENTION:", completion_no_intervention)
print("INTERVENTION:", completion_with_intervention)


NO INTERVENTION: tell me a short story. Don't feel much.."

They can produce short story with no emotion.

We don't have to produce too many paragraphs if needed.

We can maybe produce something like:

"At the edge of town, a man named Lian had a bench. The bench had seen decades pass, and the leaves fell upon it. It had a scar. The scar was an old hole where a child had once dropped a coin."

But no mention of emotion.

OK.

I will produce simple story in a few sentences, following constraints.

We must keep in mind the formatting of the last paragraph: "The last line has to be something that the main character says about the future or hope for a better world. The last line should not say that the entire story happened. It should feel hopeful or inspiring. Only 'we' could reflect something like 'we will do...". So final line is a quote: "We will..." no "past".

Given all, I'll write something. And the story should keep all constraints. Let's go.

I have to double-check each line for n