In [1]:
import copy
import csv
import json
from collections import defaultdict

import pandas as pd
import torch
import transformer_lens.utils as utils
from sae_lens import HookedSAETransformer, SAE
from tqdm import tqdm
from transformers import AutoTokenizer

## Setup

In [2]:
MODEL_NAME   = "meta-llama/Llama-3.1-8B"
# NOTE(review): never commit a real token here; prefer a secrets manager or an
# environment variable before sharing this notebook.
HF_TOKEN     = "..."
FEATURES_CSV = "./latent_outputs_yusser/chosen_features.csv"
DEV_PROMPTS  = "./prompts_dataset_dev.json"
TEST_PROMPTS = "./prompts_dataset_test.json"  # not used in the cells shown here
# Layers whose SAEs are loaded and grid-searched. The features CSV can list
# other layers (the displayed candidate_features includes 17); those are
# skipped unless added here.
LAYERS       = [15, 16]
# Scaling factors tried for every candidate feature in the grid search.
ALPHAS       = [1.5, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
DEVICE       = 'cuda' if torch.cuda.is_available() else 'cpu'

# Minimum per-tense accuracy a (layer, feature, alpha) configuration must reach
# to be labelled with that tense.
ACCURACY_THRESHOLD = {
    'past': 0.74,
    'present': 0.64,
    'future': 0.8,
}

# Seed torch's RNG and disable autograd globally (inference only).
torch.manual_seed(0)
# Trailing ';' stops Jupyter from echoing the returned grad-mode object.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x14e3696ca800>

Load Model & Tokenizer

In [3]:
# Hugging Face tokenizer, used to turn prompts into input ids. Later cells
# tokenize with padding=True, which fails without a pad token, so EOS is
# reused as the pad token when none is defined.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): the SAE loader (next cells) warns that these SAEs expect a model
# loaded via `from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)`;
# confirm the default weight processing used here matches SAE training.
model = HookedSAETransformer.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    use_auth_token=HF_TOKEN,
).to(DEVICE)
model.eval()



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Llama-3.1-8B into HookedTransformer
Moving model to device:  cuda


Collect candidate feature indices and load the SAE for each layer

In [4]:
# Read the CSV of chosen features (columns: label, layer, feature) and build a
# mapping layer -> unique feature ids in first-seen order. A feature listed
# under several labels appears only once for its layer.
df_feats = pd.read_csv(FEATURES_CSV)

candidate_features = defaultdict(list)
for layer_id, feature_id in zip(df_feats['layer'], df_feats['feature']):
    bucket = candidate_features[int(layer_id)]
    if int(feature_id) not in bucket:
        bucket.append(int(feature_id))

candidate_features

defaultdict(list,
            {15: [30777,
              702,
              28855,
              5112,
              23112,
              32090,
              26492,
              12722,
              15316,
              7890,
              15706],
             16: [5215,
              1221,
              32043,
              3689,
              9951,
              6922,
              25624,
              17716,
              7895,
              3638,
              12508,
              28602,
              23504],
             17: [26956,
              30639,
              23198,
              2164,
              9067,
              363,
              1383,
              32205,
              6612,
              17052,
              3212,
              10675,
              469,
              12337]})

In [5]:
# Pre-load one SAE per entry of LAYERS, keyed by the hook name stored in the
# SAE's own config (e.g. 'blocks.15.hook_resid_post').
saes = {}
for lyr in LAYERS:
    loaded_sae, _sae_cfg_dict, _sparsity = SAE.from_pretrained(
        release="Yusser/multilingual_llama3.1-8B_saes",
        sae_id=f"blocks.{lyr}.hook_resid_post",
        device=DEVICE,
    )
    saes[loaded_sae.cfg.hook_name] = loaded_sae
    loaded_sae.eval()

print(saes.keys())

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


dict_keys(['blocks.15.hook_resid_post', 'blocks.16.hook_resid_post'])


In [6]:
# ScaledSAE wrapper at module scope
torch.set_grad_enabled(False)


class ScaledSAE(torch.nn.Module):
    """Wrap an SAE so that selected latents are multiplied by a constant.

    Intended for ``HookedSAETransformer.run_with_saes``: the wrapper is spliced
    in at ``cfg.hook_name``. Its ``forward`` returns ``decode(encode(z))`` with
    the latents at ``idxs`` scaled by ``alpha``. No SAE error term is added
    back, so the hooked activation is replaced by the (scaled) reconstruction.

    Args:
        base_sae: object exposing ``encode``, ``decode`` and ``cfg``; in this
            notebook, an SAE loaded with sae_lens.
        idxs: list of latent indices to scale. An empty list scales nothing.
        alpha: multiplicative factor applied to those latents.
        hook_name: hook point the wrapper attaches to. It is stored on a private
            shallow copy of ``base_sae.cfg``, so the shared base SAE is not
            modified.
    """

    def __init__(self, base_sae, idxs, alpha, hook_name):
        super().__init__()
        self.sae   = base_sae
        self.idxs  = idxs
        self.alpha = alpha
        # The same base SAE object is reused by every wrapper built during the
        # grid search. Writing hook_name onto its own cfg would be a hidden
        # global side effect, so write it onto a copy instead.
        self.cfg   = copy.copy(base_sae.cfg)
        self.cfg.hook_name = hook_name

    def forward(self, z):
        h = self.sae.encode(z)
        if self.idxs:
            # Scale along the last (latent) axis, so any leading shape works,
            # not only 3-D activations.
            h[..., self.idxs] *= self.alpha
        return self.sae.decode(h)

### Helper functions

In [10]:
# Evaluation function
def evaluate_with_saes(layer, feature, alpha, prompts):
    """Measure per-tense next-token accuracy while scaling one SAE feature.

    A ScaledSAE wrapper is spliced in at every layer in LAYERS. Only the
    wrapper at ``layer`` multiplies latent ``feature`` by ``alpha``; the
    others use no indices and alpha 1.0. Each prompt is run separately and
    the greedy next token is compared with the gold answer.

    Args:
        layer: layer whose SAE feature is scaled; must be one of LAYERS.
        feature: index of the SAE latent to scale.
        alpha: multiplicative factor for that latent.
        prompts: iterable of dicts with keys "prompt_text", "gold_answer"
            and "gold_tense".

    Returns:
        dict mapping 'past', 'present' and 'future' to accuracy. A tense with
        no prompts reports 0.0. Prompts whose gold_tense is some other value
        are counted but not reported.

    Raises:
        ValueError: if ``layer`` is not in LAYERS. In that case no wrapper
            would scale anything, and the result would silently be the
            unsteered baseline.
    """
    if layer not in LAYERS:
        raise ValueError(
            f"layer {layer} is not in LAYERS={LAYERS}; no SAE is spliced there, "
            "so the feature would not be scaled"
        )

    counts   = defaultdict(int)
    corrects = defaultdict(int)

    # One wrapper per loaded layer: the target layer scales `feature` by
    # `alpha`, and every other layer passes its latents through unchanged.
    scaled_saes = []
    for lyr in LAYERS:
        hook_name = utils.get_act_name("resid_post", lyr)
        is_target = lyr == layer
        scaled_saes.append(ScaledSAE(
            saes[hook_name],
            [feature] if is_target else [],
            alpha if is_target else 1.0,
            hook_name,
        ))
    # NOTE(review): ScaledSAE.forward only encodes -> scales -> decodes, so every
    # spliced layer (target or not) replaces the residual stream with the SAE
    # reconstruction; use_error_term=None below does not add the error back.
    # Confirm this is intended for the non-target layers.

    with torch.no_grad():
        for ex in prompts:
            prompt = ex["prompt_text"]
            gold   = ex["gold_answer"]
            label  = ex["gold_tense"]

            toks = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
            logits = model.run_with_saes(
                toks.input_ids,
                saes=scaled_saes,
                use_error_term=None
            )
            # Greedy next token at the last position of the single sequence.
            next_id = logits[0, -1, :].argmax(dim=-1)
            pred    = tokenizer.decode(next_id.unsqueeze(0), skip_special_tokens=True).strip()

            counts[label] += 1
            # Only the first character of the decoded token is compared, so this
            # assumes gold_answer is a single character — TODO confirm.
            if pred and pred[0] == gold:
                corrects[label] += 1

    # The denominator defaults to 1, so a tense with no prompts reports 0.0.
    return {
        tense: corrects[tense] / counts.get(tense, 1)
        for tense in ['past', 'present', 'future']
    }

# Load prompts
def load_prompts(path):
    """Load a prompt dataset from a JSON file.

    The file is read as UTF-8, the encoding JSON mandates. Non-ASCII
    (multilingual) prompt text therefore decodes the same way on every
    platform, rather than depending on the system locale.

    Args:
        path: path to the JSON file.

    Returns:
        The parsed JSON value. This notebook iterates it as a list of prompt
        dicts.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
                    

# Run grid search and write to CSV with labeling and thresholding
def run_grid(prompts, output_csv):
    """Evaluate every (layer, feature, alpha) combination and log passing configs.

    For each layer in LAYERS, each candidate feature of that layer, and each
    alpha in ALPHAS, ``evaluate_with_saes`` is run on ``prompts``. A
    configuration gets one row per tense whose accuracy meets
    ACCURACY_THRESHOLD, so a single configuration can produce several rows.
    Configurations that pass no threshold are not written.

    Args:
        prompts: prompt dicts as consumed by ``evaluate_with_saes``.
        output_csv: destination path. The file is overwritten.
    """
    fieldnames = ['label', 'layer', 'feature', 'alpha_amplify',
                  'accuracy_past', 'accuracy_present', 'accuracy_future']
    # Build the grid up front so the progress bar knows its total length.
    grid = [
        (layer, feature, alpha)
        for layer in LAYERS
        for feature in candidate_features[layer]
        for alpha in ALPHAS
    ]
    with open(output_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for layer, feature, alpha in tqdm(grid, desc="grid search"):
            acc = evaluate_with_saes(layer, feature, alpha, prompts)
            for label, thr in ACCURACY_THRESHOLD.items():
                if acc[label] >= thr:
                    writer.writerow({
                        'label': label,
                        'layer': layer,
                        'feature': feature,
                        'alpha_amplify': alpha,
                        'accuracy_past': acc['past'],
                        'accuracy_present': acc['present'],
                        'accuracy_future': acc['future'],
                    })
            # Each evaluation is expensive. Flush so that rows already computed
            # survive a kernel crash partway through the grid.
            csvfile.flush()
                    
# Process grid results: select best alpha per (layer, feature, label)
def process_results(input_csv, output_csv):
    """Keep the best-scoring alpha of every (layer, feature) for each tense label.

    Reads the CSV written by ``run_grid``. For each label in past, present and
    future, only rows tagged with that label are considered. Within each
    (layer, feature) group, the row with the highest ``accuracy_<label>`` is
    kept; ties resolve to the first such row, as with ``idxmax``. The output
    has columns layer, label, feature, alpha_amplify and is written without
    an index.

    If no row survives (for example, no configuration passed any threshold),
    a header-only CSV is written so that ``pd.read_csv`` can still read it.
    """
    out_cols = ['layer', 'label', 'feature', 'alpha_amplify']
    grid = pd.read_csv(input_csv)
    parts = []
    for label in ['past', 'present', 'future']:
        labelled = grid[grid['label'] == label]
        if labelled.empty:
            continue
        best_idx = labelled.groupby(['layer', 'feature'])[f'accuracy_{label}'].idxmax()
        parts.append(labelled.loc[best_idx].assign(label=label)[out_cols])
    best = pd.concat(parts, ignore_index=True) if parts else pd.DataFrame(columns=out_cols)
    best.to_csv(output_csv, index=False)

## Grid search for features, layers, and alphas

In [8]:
# Development split, used by the grid search below.
dev_prompts = load_prompts(path=DEV_PROMPTS)

In [11]:
# Expensive: evaluates every (layer, feature, alpha) combination on the dev prompts.
run_grid(prompts=dev_prompts, output_csv="grid_search_dev_results_translation.csv")

## Labeling the features

In [12]:
# Post-process the dev grid-search results: keep the best alpha per (layer, feature, label).
process_results("grid_search_dev_results_translation.csv", "found_features.csv")