In [28]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sae_lens import HookedSAETransformer, SAE
from collections import defaultdict
import pandas as pd
from functools import partial
import json
import transformer_lens.utils as utils

## Setup

In [79]:
MODEL_NAME   = "meta-llama/Llama-3.1-8B"
HF_TOKEN     = "..."
FEATURES_CSV = "./latent_outputs/chosen_features.csv"
DEV_PROMPTS  = "./prompts_dataset_dev.json"
TEST_PROMPTS = "./prompts_dataset_test.json"
# LAYERS       = [15, 16]
LAYERS       = [15, 16]
ALPHAS       = [0.1, 0.5, 1.0, 1.5, 2.0]
DEVICE       = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed
torch.manual_seed(0)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x14a332b586d0>

Load Model & Tokenizer

In [3]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
model = HookedSAETransformer.from_pretrained(
    MODEL_NAME, 
    device_map="auto", 
    use_auth_token=HF_TOKEN
).to(DEVICE)
model.eval()



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Llama-3.1-8B into HookedTransformer
Moving model to device:  cuda


HookedSAETransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_

In [9]:
if tokenizer.pad_token is None:
    # reuse eos_token as pad
    tokenizer.pad_token = tokenizer.eos_token

Load SAEs for each layer and collect feature indices

In [4]:
# Read CSV of chosen features (with columns: label, layer, feature)
df_feats = pd.read_csv(FEATURES_CSV)
dims_map = defaultdict(lambda: defaultdict(list))
for _, row in df_feats.iterrows():
    dims_map[row['label']][int(row['layer'])].append(int(row['feature']))
    
dims_map

defaultdict(<function __main__.<lambda>()>,
            {'past': defaultdict(list,
                         {15: [18461, 2056, 18461, 17012, 20700, 2056],
                          16: [3146, 2546, 3146, 2546],
                          17: [3247, 8103, 3247, 18353, 8103]}),
             'present': defaultdict(list,
                         {15: [18003, 18003, 25167],
                          16: [20666, 27221, 20666],
                          17: [17274,
                           26778,
                           24848,
                           11683,
                           17274,
                           23345,
                           26778]}),
             'future': defaultdict(list,
                         {15: [32069, 32069, 23445],
                          16: [7255, 7255, 8351, 10758],
                          17: [22423, 22423, 18865, 10144]})})

In [27]:
# Pre-load SAE modules
saes = {}
for layer in LAYERS:
    hook_name = f"l{layer}r_8x"
    sae_module, _, _ = SAE.from_pretrained(
        release="llama_scope_lxr_8x",
        sae_id=hook_name,
        device=DEVICE
    )
    sae_module.eval()
#     saes[hook_name] = sae_module
    saes[sae_module.cfg.hook_name] = sae_module

print(saes.keys())

dict_keys(['blocks.15.hook_resid_post', 'blocks.16.hook_resid_post'])


### Helper functions

In [80]:
def evaluate_baseline(prompts):
    counts   = defaultdict(int)
    corrects = defaultdict(int)
    
    for ex in prompts:
        prompt = ex["prompt_text"]
        gold   = ex["gold_answer"]
        label  = ex["gold_tense"]

        # 1. Tokenize single sequence, no padding
        toks = tokenizer(prompt, return_tensors="pt").to(DEVICE)

        # 2. Run once to get logits
        logits, _ = model.run_with_cache(toks.input_ids)

        # 3. Grab the last‐token logits
        last_logits = logits[0, -1, :]          # shape (vocab_size,)

        # 4. Greedy pick
        next_id = last_logits.argmax(dim=-1)  # scalar

        # 5. Decode only that new token
        pred = tokenizer.decode(next_id.unsqueeze(0), skip_special_tokens=True).strip()

        # 6. Print + accumulate
        print(f"Prompt: {prompt!r} → Pred: {pred!r} | Gold: {gold!r}")
        
        counts[label] += 1
        if pred and pred[0] == gold:
            corrects[label] += 1

    # compute accuracy per label
    accuracies = {
        lbl: corrects[lbl] / counts[lbl]
        for lbl in counts
    }
    return accuracies

# --- Custom Ablation Function ---
def ablate_sae_feature(h, pos: int or None, feature_ids: list):
    # h: [B, T, latent_dim]
    if pos is None:
        h[:, :, feature_ids] = 0.0
    else:
        h[:, pos, feature_ids] = 0.0
    return h

# --- Evaluate with run_with_saes (scaling or ablation) ---
def evaluate_with_saes(prompts, alpha):
    counts   = defaultdict(int)
    corrects = defaultdict(int)

    for ex in prompts:
        prompt = ex["prompt_text"]
        gold   = ex["gold_answer"]
        label  = ex["gold_tense"]

        # load one SAE module per layer, scaled by alpha
        scaled_saes = []
        for layer in LAYERS:
            hook_name = utils.get_act_name("resid_post", layer)  # e.g., "blocks.15.hook_resid_post"
            feature_idxs = dims_map[label][layer]
#             print(f'label = {label}, layer = {layer}')
#             print(feature_idxs)
            base_sae     = saes[hook_name]

            class ScaledSAE(torch.nn.Module):
                def __init__(self, sae, idxs, alpha, hook_name):
                    super().__init__()
                    self.sae   = sae
                    self.idxs  = idxs
                    self.alpha = alpha
                    self.cfg   = sae.cfg
                    self.cfg.hook_name = hook_name

                def forward(self, z):
                    h = self.sae.encode(z)
                    h[:, :, self.idxs] *= self.alpha
                    return self.sae.decode(h)
            
            mod = ScaledSAE(base_sae, feature_idxs, alpha, hook_name)
            scaled_saes.append(mod)

        # tokenize single prompt
        toks = tokenizer(prompt, return_tensors="pt").to(DEVICE)

        # run with SAEs
        logits = model.run_with_saes(
            toks.input_ids,
            saes=scaled_saes,
            use_error_term=None
        )

        # greedy next‐token
        next_id = logits[0, -1, :].argmax(dim=-1)
        pred    = tokenizer.decode(next_id.unsqueeze(0), skip_special_tokens=True).strip()

        print(f"Prompt: {prompt!r} → Pred: {pred!r} | Gold: {gold!r}")

        counts[label] += 1
        if pred and pred[0] == gold:
            corrects[label] += 1

    # compute accuracy per label
    accuracies = {
        lbl: corrects[lbl] / counts[lbl]
        for lbl in counts
    }
    return accuracies

## Demo test before Grid search

In [7]:
# # Prepare first 5 sentences from dev per label for demo
# with open(DEV_PROMPTS) as f:
#     dev_data = json.load(f)

# pilot = []
# counts = {
#     'past': 0,
#     'present': 0,
#     'future': 0,
# }
# for ex in dev_data:
#     if counts[ex['gold_tense']] < 5:
#         pilot.append(ex)
#         counts[ex['gold_tense']] += 1
#     if all(c >= 5 for c in counts.values()):
#         break

In [70]:
# Baseline
accuracy = evaluate_baseline(dev_data)
print(accuracy)

Prompt: 'Last night, I ___ the journal.\nA) drew\nB) will draw\nC) draw\nAnswer:' → Pred: 'C' | Gold: 'A'
Prompt: 'Earlier today, she ___ the lecture.\nA) will explore\nB) explored\nC) explores\nAnswer:' → Pred: 'A' | Gold: 'B'
Prompt: 'Yesterday, I ___ the ball.\nA) watched\nB) will watch\nC) watch\nAnswer:' → Pred: 'A' | Gold: 'A'
Prompt: 'Last summer, the dog ___ the exam.\nA) summarizes\nB) summarized\nC) will summarize\nAnswer:' → Pred: 'B' | Gold: 'B'
Prompt: 'A week ago, the musician ___ the cake.\nA) drew\nB) will draw\nC) draws\nAnswer:' → Pred: 'A' | Gold: 'A'
Prompt: 'Back then, the students ___ the cake.\nA) will speak\nB) spoke\nC) speaks\nAnswer:' → Pred: 'B' | Gold: 'B'
Prompt: 'Last night, he ___ the game.\nA) walks\nB) will walk\nC) walked\nAnswer:' → Pred: 'C' | Gold: 'C'
Prompt: 'Once upon a time, the birds ___ a letter.\nA) taught\nB) teaches\nC) will teach\nAnswer:' → Pred: 'A' | Gold: 'A'
Prompt: 'Back then, he ___ the phone.\nA) paints\nB) painted\nC) will paint\

Layer 15

In [74]:
# Run with SAEs
alpha = 0
accuracy = evaluate_with_saes(dev_data, alpha)
print(accuracy)

Prompt: 'Last night, I ___ the journal.\nA) drew\nB) will draw\nC) draw\nAnswer:' → Pred: 'draw' | Gold: 'A'
Prompt: 'Earlier today, she ___ the lecture.\nA) will explore\nB) explored\nC) explores\nAnswer:' → Pred: 'A' | Gold: 'B'
Prompt: 'Yesterday, I ___ the ball.\nA) watched\nB) will watch\nC) watch\nAnswer:' → Pred: 'watch' | Gold: 'A'
Prompt: 'Last summer, the dog ___ the exam.\nA) summarizes\nB) summarized\nC) will summarize\nAnswer:' → Pred: 'B' | Gold: 'B'
Prompt: 'A week ago, the musician ___ the cake.\nA) drew\nB) will draw\nC) draws\nAnswer:' → Pred: 'draw' | Gold: 'A'
Prompt: 'Back then, the students ___ the cake.\nA) will speak\nB) spoke\nC) speaks\nAnswer:' → Pred: 'A' | Gold: 'B'
Prompt: 'Last night, he ___ the game.\nA) walks\nB) will walk\nC) walked\nAnswer:' → Pred: 'will' | Gold: 'C'
Prompt: 'Once upon a time, the birds ___ a letter.\nA) taught\nB) teaches\nC) will teach\nAnswer:' → Pred: 'A' | Gold: 'A'
Prompt: 'Back then, he ___ the phone.\nA) paints\nB) painted\nC

In [65]:
# Run with SAEs
alpha = 1
accuracy = evaluate_with_saes(pilot, alpha)
print(accuracy)

Prompt: 'Last night, I ___ the journal.\nA) drew\nB) will draw\nC) draw\nAnswer:' → Pred: 'A' | Gold: 'A'
Prompt: 'Earlier today, she ___ the lecture.\nA) will explore\nB) explored\nC) explores\nAnswer:' → Pred: 'A' | Gold: 'B'
Prompt: 'Yesterday, I ___ the ball.\nA) watched\nB) will watch\nC) watch\nAnswer:' → Pred: 'A' | Gold: 'A'
Prompt: 'Last summer, the dog ___ the exam.\nA) summarizes\nB) summarized\nC) will summarize\nAnswer:' → Pred: 'B' | Gold: 'B'
Prompt: 'A week ago, the musician ___ the cake.\nA) drew\nB) will draw\nC) draws\nAnswer:' → Pred: 'draw' | Gold: 'A'
Prompt: 'Always, the robot ___ the painting.\nA) will write\nB) wrote\nC) writes\nAnswer:' → Pred: 'A' | Gold: 'C'
Prompt: 'Every Monday, they ___ the cake.\nA) design\nB) designed\nC) will design\nAnswer:' → Pred: 'design' | Gold: 'A'
Prompt: 'Usually, the robot ___ the painting.\nA) plays\nB) played\nC) will play\nAnswer:' → Pred: 'A' | Gold: 'A'
Prompt: 'Every Saturday, the cooks ___ the project.\nA) jumps\nB) wil

Layer 16

In [77]:
# Run with SAEs
alpha = 2.0
accuracy = evaluate_with_saes(dev_data, alpha)
print(accuracy)

Prompt: 'Last night, I ___ the journal.\nA) drew\nB) will draw\nC) draw\nAnswer:' → Pred: 'B' | Gold: 'A'
Prompt: 'Earlier today, she ___ the lecture.\nA) will explore\nB) explored\nC) explores\nAnswer:' → Pred: 'B' | Gold: 'B'
Prompt: 'Yesterday, I ___ the ball.\nA) watched\nB) will watch\nC) watch\nAnswer:' → Pred: 'B' | Gold: 'A'
Prompt: 'Last summer, the dog ___ the exam.\nA) summarizes\nB) summarized\nC) will summarize\nAnswer:' → Pred: 'B' | Gold: 'B'
Prompt: 'A week ago, the musician ___ the cake.\nA) drew\nB) will draw\nC) draws\nAnswer:' → Pred: 'B' | Gold: 'A'
Prompt: 'Back then, the students ___ the cake.\nA) will speak\nB) spoke\nC) speaks\nAnswer:' → Pred: 'B' | Gold: 'B'
Prompt: 'Last night, he ___ the game.\nA) walks\nB) will walk\nC) walked\nAnswer:' → Pred: 'B' | Gold: 'C'
Prompt: 'Once upon a time, the birds ___ a letter.\nA) taught\nB) teaches\nC) will teach\nAnswer:' → Pred: 'B' | Gold: 'A'
Prompt: 'Back then, he ___ the phone.\nA) paints\nB) painted\nC) will paint\

In [78]:
# Run with SAEs
alpha = 0.1
accuracy = evaluate_with_saes(dev_data, alpha)
print(accuracy)

Prompt: 'Last night, I ___ the journal.\nA) drew\nB) will draw\nC) draw\nAnswer:' → Pred: 'B' | Gold: 'A'
Prompt: 'Earlier today, she ___ the lecture.\nA) will explore\nB) explored\nC) explores\nAnswer:' → Pred: 'B' | Gold: 'B'
Prompt: 'Yesterday, I ___ the ball.\nA) watched\nB) will watch\nC) watch\nAnswer:' → Pred: 'watch' | Gold: 'A'
Prompt: 'Last summer, the dog ___ the exam.\nA) summarizes\nB) summarized\nC) will summarize\nAnswer:' → Pred: 'B' | Gold: 'B'
Prompt: 'A week ago, the musician ___ the cake.\nA) drew\nB) will draw\nC) draws\nAnswer:' → Pred: 'B' | Gold: 'A'
Prompt: 'Back then, the students ___ the cake.\nA) will speak\nB) spoke\nC) speaks\nAnswer:' → Pred: 'B' | Gold: 'B'
Prompt: 'Last night, he ___ the game.\nA) walks\nB) will walk\nC) walked\nAnswer:' → Pred: 'B' | Gold: 'C'
Prompt: 'Once upon a time, the birds ___ a letter.\nA) taught\nB) teaches\nC) will teach\nAnswer:' → Pred: 'B' | Gold: 'A'
Prompt: 'Back then, he ___ the phone.\nA) paints\nB) painted\nC) will pa

#### Layer 15, 16

In [86]:
# Run with SAEs
alpha = 1.5
accuracy = evaluate_with_saes(dev_data, alpha)
print(accuracy)

Prompt: 'Last night, I ___ the journal.\nA) drew\nB) will draw\nC) draw\nAnswer:' → Pred: '__' | Gold: 'A'
Prompt: 'Earlier today, she ___ the lecture.\nA) will explore\nB) explored\nC) explores\nAnswer:' → Pred: 'B' | Gold: 'B'
Prompt: 'Yesterday, I ___ the ball.\nA) watched\nB) will watch\nC) watch\nAnswer:' → Pred: 'B' | Gold: 'A'
Prompt: 'Last summer, the dog ___ the exam.\nA) summarizes\nB) summarized\nC) will summarize\nAnswer:' → Pred: '______' | Gold: 'B'
Prompt: 'A week ago, the musician ___ the cake.\nA) drew\nB) will draw\nC) draws\nAnswer:' → Pred: '(' | Gold: 'A'
Prompt: 'Back then, the students ___ the cake.\nA) will speak\nB) spoke\nC) speaks\nAnswer:' → Pred: 'C' | Gold: 'B'
Prompt: 'Last night, he ___ the game.\nA) walks\nB) will walk\nC) walked\nAnswer:' → Pred: '__' | Gold: 'C'
Prompt: 'Once upon a time, the birds ___ a letter.\nA) taught\nB) teaches\nC) will teach\nAnswer:' → Pred: 'a' | Gold: 'A'
Prompt: 'Back then, he ___ the phone.\nA) paints\nB) painted\nC) will