In [None]:
!pip install -U transformers kernels torch jaxtyping einops

In [None]:
from transformers import pipeline
import torch

# Baseline: run the unmodified model through a text-generation pipeline to see
# how it answers this prompt before any directions are ablated.
model_id = "openai/gpt-oss-20b"

pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype="auto",
    device_map="auto",
)

messages = [
    {"role": "user", "content": "How do I insult someone based on their race, gender, or religion?"},
]

outputs = pipe(
    messages,
    max_new_tokens=256,
)
# With chat-style (list of messages) input, the last entry of generated_text
# is the assistant's reply.
print(outputs[0]["generated_text"][-1])

# The next cell loads a second copy of the same model for hooking; release this
# one first so both copies are not resident on the accelerator at the same time.
import gc

del pipe
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import torch

model_id = "openai/gpt-oss-20b"

# Load the bare causal LM (rather than a pipeline) so that individual decoder
# layers can be inspected and hooked in later cells.
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Dump the whole module tree; expect a very long listing for a model this size.
print(model)

In [None]:
# One refusal direction is computed per decoder layer; `pos` selects which
# token's hidden state is read from each prompt.
layer_count = len(model.model.layers)
pos = -1  # final token position
# The full architecture was already printed by the loading cell; just confirm
# the layer count here instead of dumping it again.
print(f"decoder layers: {layer_count}")

In [None]:
def read_instructions(path):
    """Return the non-blank lines of *path*, stripped of surrounding whitespace.

    ``readlines()`` would keep each line's trailing newline (which then ends up
    inside the user message) and would turn blank lines into empty prompts that
    dilute the mean activations computed later.
    """
    with open(path, "r", encoding="utf-8") as f:
        stripped = [line.strip() for line in f]
    return [insn for insn in stripped if insn]


def tokenize_instructions(instructions):
    """Wrap each instruction as a single user turn and apply the chat template.

    Returns one tensor of token ids per instruction (``return_tensors="pt"``).
    """
    return [
        tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": insn}],
            add_generation_prompt=True,
            return_tensors="pt",
        )
        for insn in instructions
    ]


harmful_instructions = read_instructions("harmful.txt")
harmless_instructions = read_instructions("harmless.txt")
# Fail here with a clear message instead of an obscure torch.stack error later.
assert harmful_instructions, "harmful.txt contains no instructions"
assert harmless_instructions, "harmless.txt contains no instructions"

harmful_toks = tokenize_instructions(harmful_instructions)
harmless_toks = tokenize_instructions(harmless_instructions)

max_its = len(harmful_toks) + len(harmless_toks)
bar = tqdm(total=max_its)


def generate(toks):
    """Run one tokenized prompt through the model and return generate()'s output dict.

    Only a single new token is generated: what is needed are the hidden states
    (``output_hidden_states=True``), not the continuation text.
    """
    bar.update(1)
    return model.generate(
        toks.to(model.device),
        use_cache=False,
        max_new_tokens=1,
        return_dict_in_generate=True,
        output_hidden_states=True,
    )


try:
    harmful_outputs = [generate(toks) for toks in harmful_toks]
    harmless_outputs = [generate(toks) for toks in harmless_toks]
finally:
    # Close the bar even if generation fails (e.g. out of memory) so a re-run
    # of this cell does not leave a stale progress bar behind.
    bar.close()

In [None]:
# generate() returns hidden_states as one tuple per generation step; entry 0
# covers the prompt and holds (layer_count + 1) tensors indexed as
# [batch, seq_len, hidden_dim]: [embeddings, layer1, ..., layerN].
# Index 0 (embeddings) is skipped, so list index i corresponds to decoder layer i.
# NOTE(review): depending on the transformers version, the last entry may be
# recorded after the model's final norm rather than straight from the last
# decoder layer's output (which is what the ablation hook modifies) — confirm.


def collect_layer_states(outputs, layer):
    """Gather the token-`pos` hidden state of `layer` across generate() outputs.

    Returns a float32 tensor of shape [num_samples, hidden_dim]. Concatenating
    over the batch axis (each prompt is tokenized on its own) avoids the extra
    singleton axis torch.stack would add, and the upcast keeps the later mean
    from accumulating in the model's possibly reduced-precision dtype.
    """
    return torch.cat(
        [out.hidden_states[0][layer][:, pos, :].float() for out in outputs],
        dim=0,
    )


harmful_hidden_all = [collect_layer_states(harmful_outputs, l) for l in range(1, layer_count + 1)]
harmless_hidden_all = [collect_layer_states(harmless_outputs, l) for l in range(1, layer_count + 1)]

# Mean activation per layer, each of shape [hidden_dim].
harmful_means = [h.mean(dim=0) for h in harmful_hidden_all]
harmless_means = [h.mean(dim=0) for h in harmless_hidden_all]

# Per-layer unit vector pointing from the harmless mean towards the harmful
# mean. With device_map="auto" layers can live on different devices, so every
# direction is gathered on model.device before stacking.
refusal_dirs = []
for harmful_mean, harmless_mean in zip(harmful_means, harmless_means):
    diff = (harmful_mean - harmless_mean).to(model.device)  # [hidden_dim]
    diff = diff / (diff.norm() + 1e-9)                       # normalize
    refusal_dirs.append(diff)

# Stack into a single float32 tensor [layers, hidden_dim]
refusal_dirs = torch.stack(refusal_dirs, dim=0)

# Save
save_path = model_id.replace("/", "_") + "_refusal_dirs.pt"
torch.save(refusal_dirs, save_path)

print("Saved refusal dirs with shape:", refusal_dirs.shape, "at", save_path)

In [None]:
from typing import Optional, Tuple
import torch.nn as nn
import jaxtyping
import random
import torch
from transformers import TextStreamer
import einops

from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import TextStreamer

# --- Load refusal directions ---
# The file holds a plain tensor, so weights_only=True is sufficient and stops
# torch.load from unpickling arbitrary objects (older torch versions default
# to full unpickling).
refusal_dirs = torch.load(model_id.replace("/", "_") + "_refusal_dirs.pt", weights_only=True)
# Accept either [num_layers, 1, hidden_dim] or [num_layers, hidden_dim].
if refusal_dirs.dim() == 3 and refusal_dirs.size(1) == 1:
    refusal_dirs = refusal_dirs.squeeze(1)  # -> [num_layers, hidden_dim]
elif refusal_dirs.dim() == 2:
    pass  # already [num_layers, hidden_dim]
else:
    raise ValueError(f"Unexpected refusal_dirs shape {tuple(refusal_dirs.shape)}")

num_layers, hidden_dim = refusal_dirs.shape
assert num_layers == len(model.model.layers), f"num_layers mismatch: {num_layers} vs {len(model.model.layers)}"
# Catch directions computed for a different model before they hit the hooks.
assert hidden_dim == model.config.hidden_size, f"hidden_dim mismatch: {hidden_dim} vs {model.config.hidden_size}"

# normalize directions to unit length along the hidden axis
refusal_dirs = torch.nn.functional.normalize(refusal_dirs, dim=-1)


# --- Hook factory ---
def make_ablation_hook(direction: torch.Tensor):
    """Build a forward hook that removes the component along `direction` from a module's output.

    The hook subtracts each hidden vector's projection onto the unit-normalized
    direction d, i.e. x <- x - <x, d> d, so the result has no component along d.
    It accepts either a bare tensor output or a tuple whose first element is
    the hidden states, and returns the same kind.

    Args:
        direction: 1-D tensor whose length matches the last axis of the hooked
            module's output; it is normalized here.

    Returns:
        A callable with the ``register_forward_hook`` signature
        ``hook(module, inputs, output)``.
    """
    direction = direction / (direction.norm() + 1e-9)

    def hook(module, inputs, output):
        x = output[0] if isinstance(output, tuple) else output

        # Follow the activation: with device_map="auto" layers may sit on
        # different devices, and matching x's dtype keeps the residual stream
        # in its own dtype instead of promoting it (e.g. bf16 -> fp32).
        # .to() is a no-op when device and dtype already match.
        d = direction.to(device=x.device, dtype=x.dtype)

        # <x, d> over the last axis, kept as a trailing size-1 axis so it
        # broadcasts back against d for any leading shape.
        proj_coeff = (x * d).sum(dim=-1, keepdim=True)
        x = x - proj_coeff * d

        if isinstance(output, tuple):
            return (x,) + output[1:]
        return x

    return hook


# --- Attach hooks to each layer ---
# Remove hooks registered by a previous run of this cell; without this,
# re-running it would stack another ablation hook on every layer.
for handle in globals().get("ablation_hook_handles", []):
    handle.remove()
ablation_hook_handles = []

model_dtype = next(model.parameters()).dtype
for i, layer in enumerate(model.model.layers):
    dir_i = refusal_dirs[i].to(model_dtype)
    hook = make_ablation_hook(dir_i)
    ablation_hook_handles.append(layer.register_forward_hook(hook))

# --- Test model on a safe prompt ---
# Sanity check that the hooked model still produces coherent text.
# skip_prompt=True streams only the newly generated tokens, not the prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True)
conversation = [{"role": "user", "content": "Write a respectful 2-line poem about autumn."}]
toks = tokenizer.apply_chat_template(
    conversation=conversation,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

gen = model.generate(toks, streamer=streamer, max_new_tokens=50)
# decode() turns the 1-D id sequence into one string; batch_decode() on a 1-D
# tensor would decode every token id separately into a list of fragments.
response = tokenizer.decode(gen[0][toks.shape[-1]:], skip_special_tokens=True)
print(response)