In [1]:
!pip install biopython



# Download 16s rRNA

In [3]:
from Bio import SeqIO
import gzip
import urllib.request

# Download the full E. coli genome
genome_url = "https://ftp.ncbi.nlm.nih.gov/genomes/refseq/bacteria/Escherichia_coli/reference/GCF_000005845.2_ASM584v2/GCF_000005845.2_ASM584v2_genomic.fna.gz"

print("Downloading genome...")
urllib.request.urlretrieve(genome_url, "ecoli_genome.fna.gz")

# Also download the annotation to find 16S gene coordinates
gff_url = "https://ftp.ncbi.nlm.nih.gov/genomes/refseq/bacteria/Escherichia_coli/reference/GCF_000005845.2_ASM584v2/GCF_000005845.2_ASM584v2_genomic.gff.gz"
urllib.request.urlretrieve(gff_url, "ecoli_genome.gff.gz")

print("Files downloaded!")

# Parse the genome
with gzip.open("ecoli_genome.fna.gz", "rt") as f:
    genome_record = SeqIO.read(f, "fasta")
    genome_seq = str(genome_record.seq)

print(f"Genome length: {len(genome_seq):,} bp")

# Parse GFF to find 16S rRNA coordinates
print("\nSearching for 16S rRNA genes...")

ribosomal_genes = []
with gzip.open("ecoli_genome.gff.gz", "rt") as f:
    for line in f:
        if line.startswith("#"):
            continue
        if "16S ribosomal RNA" in line or "product=16S ribosomal RNA" in line:
            parts = line.split("\t")
            if len(parts) >= 9:
                start = int(parts[3])
                end = int(parts[4])
                strand = parts[6]
                ribosomal_genes.append({
                    'start': start,
                    'end': end,
                    'strand': strand,
                    'length': end - start + 1
                })
                print(f"  Found: {start:,}-{end:,} ({strand} strand), length: {end-start+1} bp")

# E. coli has multiple 16S rRNA copies, let's use the first one
if ribosomal_genes:
    gene = ribosomal_genes[1]

    # Extract with flanking regions => more prob on ground truth token
    flank_size = 2000  # flanking regions

    start_with_flank = max(0, gene['start'] - 1 - flank_size)  # -1 for 0-indexing
    end_with_flank = min(len(genome_seq), gene['end'] + flank_size)

    extended_sequence = genome_seq[start_with_flank:end_with_flank]

    # Calculate where the actual 16S gene is within this extended sequence
    gene_start_in_extended = gene['start'] - 1 - start_with_flank
    gene_end_in_extended = gene['end'] - start_with_flank

    print(f"\n{'='*70}")
    print("EXTENDED SEQUENCE EXTRACTED")
    print(f"{'='*70}")
    print(f"Genomic coordinates: {start_with_flank+1:,} - {end_with_flank:,}")
    print(f"Total length: {len(extended_sequence):,} bp")
    print(f"  - Upstream flank: {gene_start_in_extended:,} bp")
    print(f"  - 16S rRNA gene: {gene['length']:,} bp")
    print(f"  - Downstream flank: {len(extended_sequence) - gene_end_in_extended:,} bp")
    print(f"\n16S gene position in extended sequence: {gene_start_in_extended}-{gene_end_in_extended}")

    # Save the extended sequence
    with open("ecoli_16S_extended.fasta", "w") as f:
        f.write(f">E.coli_16S_extended |genomic:{start_with_flank+1}-{end_with_flank} |gene:{gene_start_in_extended}-{gene_end_in_extended}\n")
        # Write in 80-character lines
        for i in range(0, len(extended_sequence), 80):
            f.write(extended_sequence[i:i+80] + "\n")

    print("\nExtended sequence saved to 'ecoli_16S_extended.fasta'")

    # Now recalculate mutation positions in the extended sequence
    # The original positions (1408, 1491, 1494) are relative to the 16S gene start
    mutation_positions_in_extended = {
        'in_seq_pos_1408': gene_start_in_extended + 1408 - 1,  # -1 for 0-indexing
    }

    print(f"\n{'='*70}")
    print("MUTATION POSITIONS IN EXTENDED SEQUENCE")
    print(f"{'='*70}")
    for name, pos in mutation_positions_in_extended.items():
        nt = extended_sequence[pos]
        print(f"{name}: position {pos+1} (nucleotide: {nt})")

else:
    print("No 16S rRNA genes found in GFF file!")

Downloading genome...
Files downloaded!
Genome length: 4,641,652 bp

Searching for 16S rRNA genes...
  Found: 223,771-225,312 (+ strand), length: 1542 bp
  Found: 223,771-225,312 (+ strand), length: 1542 bp
  Found: 2,729,616-2,731,157 (- strand), length: 1542 bp
  Found: 2,729,616-2,731,157 (- strand), length: 1542 bp
  Found: 3,427,221-3,428,762 (- strand), length: 1542 bp
  Found: 3,427,221-3,428,762 (- strand), length: 1542 bp
  Found: 3,941,808-3,943,349 (+ strand), length: 1542 bp
  Found: 3,941,808-3,943,349 (+ strand), length: 1542 bp
  Found: 4,035,531-4,037,072 (+ strand), length: 1542 bp
  Found: 4,035,531-4,037,072 (+ strand), length: 1542 bp
  Found: 4,166,659-4,168,200 (+ strand), length: 1542 bp
  Found: 4,166,659-4,168,200 (+ strand), length: 1542 bp
  Found: 4,208,147-4,209,688 (+ strand), length: 1542 bp
  Found: 4,208,147-4,209,688 (+ strand), length: 1542 bp

EXTENDED SEQUENCE EXTRACTED
Genomic coordinates: 221,771 - 227,312
Total length: 5,542 bp
  - Upstream flank

In [3]:
#print the extracted sequnce
print(gene)
print(len(gene))
print(extended_sequence)
print(len(extended_sequence))

{'start': 223771, 'end': 225312, 'strand': '+', 'length': 1542}
4
ACGTTGAAACGACGCGCGGTTTCAGAAAGCAGTGGGGCATCGACCGATTGACCGGTAAACTCCAGACGCAGCATCGGCACGCAGTCAGTAAATGGCTCCGCTTGCAGACGTTCCTGGTAATCTTCCGGGATATCCAGATGCAGGGTCGACTGAATAAACTTCTGCGCCAGCGGCGTTTTCGGATGCGAGAACACTTCACTTACCGTGTCCTGCTCGATCAGTTCTCCATTGCTGATGACCGCCACGCAATCACAAATGCGCTTCACAACGTCCATTTCGTGGGTGATCAACAGAATCGTCAACCCCAGACGGCGGTTGATGTCTTTCAGCAGTTCGAGAATAGAACGTGTCGTTGCCGGGTCCAGCGCGCTGGTGGCTTCATCACACAGCAATACTTTGGGATTGCTGGCTAACGCACGGGCAATTGCCACACGTTGTTTCTGCCCACCGGAAAGATTCGACGGGTAGCTATCATGCTTATCGCCAAGACCAACCAATGACAGCAATTCCGTCACGCGACGTTTGACCTCGTCTTTCGGTGTGTTGTCCAGCTCCAGCGGCAGAGCCACGTTGCCAAAAACAGTACGCGAAGAGAGCAGGTTAAAATGCTGGAAAATCATACCAATCTGGCGGCGAGCTTTGGTCAACTCGGATTCTGACAGCGTGGTCAGTTCCTGGCCATCGACCAGCACGCTACCCTCGGTTGGGCGCTCCAGCAGGTTTACACAACGTATAAGCGTACTCTTACCCGCGCCTGAGGCACCGATAACGCCATAAATTTGTCCAGCTGGCACATGCAGGCTGACGTTGTTCAACGCCTGGATGGTGCGGGTGCCCTGGTGGAACACTTTGGTGATATTCGAAAGTTTTATCATTGATTATTTATTATCGTCATTAAGTTAGTCGTGGCATCTCGAATGCCTGAAACGGGCAA

In [4]:
from Bio import SeqIO

# Read the fasta file
record = SeqIO.read("ecoli_16S_extended.fasta", "fasta")

# Extract just the sequence as a string
ribosome_gene_seq = str(record.seq)

# Print info
print(f"ID: {record.id}")
print(f"Description: {record.description}")
print(f"Length: {len(ribosome_gene_seq)}")

ID: E.coli_16S_extended
Description: E.coli_16S_extended |genomic:221771-227312 |gene:2000-3542
Length: 5542


## Model loading

In [4]:
# Load model
from transformers import AutoTokenizer, AutoModelForMaskedLM

n_params = 500
model = AutoModelForMaskedLM.from_pretrained(f"InstaDeepAI/nucleotide-transformer-v2-{n_params}m-multi-species", trust_remote_code=True)



## Data preprocessing

In [6]:
import sys
import os

# Add the src directory to the Python path
# Assumes you are currently in the /content directory after running %cd ..
sys.path.append(os.path.join(os.getcwd(), 'sae_for_glm', 'src'))

print("Added sae_for_glm/src to Python path.")
print(sys.path)

Added sae_for_glm/src to Python path.
['/home/maiwald/miniconda3/envs/prokka_env/lib/python312.zip', '/home/maiwald/miniconda3/envs/prokka_env/lib/python3.12', '/home/maiwald/miniconda3/envs/prokka_env/lib/python3.12/lib-dynload', '', '/home/maiwald/miniconda3/envs/prokka_env/lib/python3.12/site-packages', '/home/maiwald/Downloads/sae_for_glm_repo', '/home/maiwald/.cache/huggingface/modules', '/home/maiwald/Downloads/sae_for_glm_repo/notebooks/sae_for_glm/src']


In [7]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch


# Use the extended sequence for predictions
extended_seq = extended_sequence

# Tokenize
tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-v2-500m-multi-species", trust_remote_code=True)
max_length = tokenizer.model_max_length

sequences_extended = [extended_seq]
tokens_ids_extended = tokenizer(
    sequences_extended,
    return_tensors="pt",
    padding="max_length",
    max_length=max_length
)["input_ids"]

print(f"\nTokenized extended sequence:")
print(f"  Sequence length: {len(extended_seq):,} bp")
print(f"  Token length: {tokens_ids_extended.shape[1]}")
print(f"  Max model length: {max_length}")

# Recalculate token positions for mutations
nt_per_token = 6
mutation_token_positions_extended = {
    name: pos // nt_per_token
    for name, pos in mutation_positions_in_extended.items()
}

print(f"\n{'='*70}")
print("MUTATION TOKEN POSITIONS IN EXTENDED SEQUENCE")
print(f"{'='*70}")
for name, token_pos in mutation_token_positions_extended.items():
    print(f"{name}: token {token_pos}")

#masked_seq=extended_seq[:6*(token_pos+1)]+"<mask>"+extended_seq[6*(token_pos+1)+6:]
masked_seq=extended_seq[:6*567]+"<mask>"+extended_seq[567*6+6:]
print(f"\nMasked extended sequence:{extended_seq[6*567:6*567+6]}")
print(masked_seq)
print(masked_seq[6*(token_pos)])




Tokenized extended sequence:
  Sequence length: 5,542 bp
  Token length: 2048
  Max model length: 2048

MUTATION TOKEN POSITIONS IN EXTENDED SEQUENCE
in_seq_pos_1408: token 567

Masked extended sequence:CCGTCA
ACGTTGAAACGACGCGCGGTTTCAGAAAGCAGTGGGGCATCGACCGATTGACCGGTAAACTCCAGACGCAGCATCGGCACGCAGTCAGTAAATGGCTCCGCTTGCAGACGTTCCTGGTAATCTTCCGGGATATCCAGATGCAGGGTCGACTGAATAAACTTCTGCGCCAGCGGCGTTTTCGGATGCGAGAACACTTCACTTACCGTGTCCTGCTCGATCAGTTCTCCATTGCTGATGACCGCCACGCAATCACAAATGCGCTTCACAACGTCCATTTCGTGGGTGATCAACAGAATCGTCAACCCCAGACGGCGGTTGATGTCTTTCAGCAGTTCGAGAATAGAACGTGTCGTTGCCGGGTCCAGCGCGCTGGTGGCTTCATCACACAGCAATACTTTGGGATTGCTGGCTAACGCACGGGCAATTGCCACACGTTGTTTCTGCCCACCGGAAAGATTCGACGGGTAGCTATCATGCTTATCGCCAAGACCAACCAATGACAGCAATTCCGTCACGCGACGTTTGACCTCGTCTTTCGGTGTGTTGTCCAGCTCCAGCGGCAGAGCCACGTTGCCAAAAACAGTACGCGAAGAGAGCAGGTTAAAATGCTGGAAAATCATACCAATCTGGCGGCGAGCTTTGGTCAACTCGGATTCTGACAGCGTGGTCAGTTCCTGGCCATCGACCAGCACGCTACCCTCGGTTGGGCGCTCCAGCAGGTTTACACAACGTATAAGCGTACTCTTACCCGCGCCTGAGGCACCGATAACGCCATAAATTTGTCCAGCT

In [10]:
import sys
import os
import importlib

# Add the repository root to Python path  
repo_root = os.path.abspath('..')
if repo_root not in sys.path:
    sys.path.append(repo_root)

# Import and reload
import src.steering_utils_pj
importlib.reload(src.steering_utils_pj)


<module 'src.steering_utils_pj' from '/home/maiwald/Downloads/sae_for_glm_repo/src/steering_utils_pj.py'>

In [13]:

# Replace with the actual path to your SAE model file in Google Drive
SAE_MODEL_PATH = '/home/maiwald/Downloads/ae.pt'


In [15]:
# --- imports
import os, json, time, math
from typing import List, Sequence, Dict, Any

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer

# You provide this:
from src.steering_utils_pj import steer_with_sae  # must return dict with "logits_steered"

# --- configuration (edit these)
MODEL_NAME = "InstaDeepAI/nucleotide-transformer-v2-500m-multi-species"
SAE_MODEL_PATH = SAE_MODEL_PATH        # e.g., "/path/to/sae.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Input sequence (string). If you mask, do it in the same way you used in your experiments.
masked_seq = masked_seq

# Position whose distribution you inspect (where you read the logits/probs)
TOKEN_POSITION = 568

# Mutation (A1408G) token group — add all k-mers you consider “mutation”
TARGET_TOKENS = ["CCGTCG"]    # <- adjust to your exact 6-mer(s) for A1408G context

# Optional: WT token group to compare against (leave [] if not using)
WT_TOKENS = ["CCGTCA"]                # e.g., ["<WT_6mer>"]

# Latents to test
LATENTS = [
    5066,   # KanR/kanMX (aminoglycoside)
    628,    # MES psi
    4420,   # SmR (control ABR)
    3322,   # rrnB T2 terminator
    5715,   # PGK promoter
    5971,   # rop
    6437,   # lacI
    4903,   # AAV2 ITR
    3104,   # gag (truncated)
    5561,   # small t antigen
    3619,   # T3 promoter
    5784,   # Ac5 promoter
    1609,   # 3xFLAG
    # add more if you like (including some unknown/randoms)
]

LATENTS =[

5561,
4837,
628,
4903,
3104,
3322,
1271,
4420,
4067,
5715,
189,
3619,
5066,
954,
5971,
7319,
6437,
681,
5784,
3258,
1540,
7746,
4170,
1609,
]

# Steering values: negative = suppress, positive = enhance; include 0 as baseline
STEERING_VALUES = [-150, -100, -70, -60, -50, -30, -20, 0, 20, 30, 50, 60, 70, 100, 150]
#STEERING_VALUES = [-50, 0, 50,70]
#STEERING_VALUES = [20]
# Steering positions (relative to TOKEN_POSITION & a few far controls)
p = TOKEN_POSITION
POSITIONS_TO_STEER = [
    [p],                     # masked
    [p-1],                   # nearby (left)
    [p,p-1],              # nearby + mask
    [p-100],                   # far

]

# Output folder
STAMP = time.strftime("%Y%m%d_%H%M%S")
OUT_DIR = f"./steering_sweep_{STAMP}"
os.makedirs(OUT_DIR, exist_ok=True)
CSV_PATH = os.path.join(OUT_DIR, "steering_full_results.csv")
CONFIG_JSON = os.path.join(OUT_DIR, "run_config.json")

# --- tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)




In [16]:
def to_np(x):
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)

def softmax_np(logits_1d: np.ndarray) -> np.ndarray:
    m = float(np.max(logits_1d))
    e = np.exp(logits_1d - m)
    return e / float(np.sum(e))

def encode_token(tok: str) -> int:
    ids = tokenizer.encode(tok, add_special_tokens=False)
    if not ids:
        raise ValueError(f"Token '{tok}' cannot be encoded; got {ids}")
    if len(ids) > 1:
        print(f"[warn] token '{tok}' split into {len(ids)} ids; using first {ids[0]}")
    return int(ids[0])

def rank_of_id(logits: np.ndarray, tid: int) -> int:
    # rank 1 = highest logit
    return 1 + int((logits > logits[tid]).sum())

def group_prob_and_rank(logits: np.ndarray, probs: np.ndarray, ids: Sequence[int]):
    if len(ids) == 0:
        return math.nan, math.nan
    gp = float(np.sum(probs[ids]))
    best_rank = min(rank_of_id(logits, i) for i in ids)
    return gp, best_rank

def top1_belongs_to(ids: Sequence[int], top1_id: int) -> bool:
    return int(top1_id) in set(int(i) for i in ids)

def classify_position_category(position_list: Sequence[int], masked_idx: int) -> str:
    pos = list(position_list)
    if pos == [masked_idx]:
        return "masked"
    if any(abs(p - masked_idx) <= 3 for p in pos):
      if len(pos) == 2:
        return "masked and nearby"
      else:
        return "nearby"

    if len(pos) >= 3:
        return "broad"
    return "far"


In [17]:
TARGET_IDS = [encode_token(t) for t in TARGET_TOKENS]
WT_IDS = [encode_token(t) for t in WT_TOKENS]

print("Model:", MODEL_NAME)
print("SAE path:", SAE_MODEL_PATH)
print("Device:", DEVICE)
print("Sequence length:", len(masked_seq))
print("Inspecting token position:", TOKEN_POSITION)
print("Mutation tokens -> ids:", list(zip(TARGET_TOKENS, TARGET_IDS)))
if WT_IDS:
    print("WT tokens -> ids:", list(zip(WT_TOKENS, WT_IDS)))
print("Latents:", LATENTS)
print("Steering values:", STEERING_VALUES)
print("Positions:", POSITIONS_TO_STEER)


Model: InstaDeepAI/nucleotide-transformer-v2-500m-multi-species
SAE path: /home/maiwald/Downloads/ae.pt
Device: cpu
Sequence length: 5542
Inspecting token position: 568
Mutation tokens -> ids: [('CCGTCG', 2785)]
WT tokens -> ids: [('CCGTCA', 2782)]
Latents: [5561, 4837, 628, 4903, 3104, 3322, 1271, 4420, 4067, 5715, 189, 3619, 5066, 954, 5971, 7319, 6437, 681, 5784, 3258, 1540, 7746, 4170, 1609]
Steering values: [-150, -100, -70, -60, -50, -30, -20, 0, 20, 30, 50, 60, 70, 100, 150]
Positions: [[568], [567], [568, 567], [468]]


In [18]:
records: List[Dict[str, Any]] = []

# Compute baseline once (sv==0, no steering). We'll attach its values to rows via columns.
print("[baseline] computing baseline (no steering)…")
base_results = steer_with_sae(
    model_name=MODEL_NAME,
    sae_model_path=SAE_MODEL_PATH,
    latents_to_max=[],               # no steering
    latents_to_zero=[],
    input_sequence=masked_seq,
    layer_num=9,
    steering_value=0.0,
    position_to_steer=[],
    steering_value_method="fixed",
    device=DEVICE,
)
base_logits_all = to_np(base_results["logits_steered"])
base_logits = base_logits_all[TOKEN_POSITION]
base_probs = softmax_np(base_logits)
base_target_prob, base_target_rank = group_prob_and_rank(base_logits, base_probs, TARGET_IDS)
base_wt_prob, base_wt_rank = group_prob_and_rank(base_logits, base_probs, WT_IDS) if WT_IDS else (math.nan, math.nan)

total = len(LATENTS) * len(STEERING_VALUES) * len(POSITIONS_TO_STEER)
k = 0

for latent in LATENTS:
    for sv in STEERING_VALUES:
        for pos_list in POSITIONS_TO_STEER:
            k += 1
            print(f"[{k}/{total}] latent={latent} steer={sv} pos={pos_list}")
            try:
                if sv == 0:
                    # reuse baseline
                    logits = base_logits
                    probs = base_probs
                else:
                    results = steer_with_sae(
                        model_name=MODEL_NAME,
                        sae_model_path=SAE_MODEL_PATH,
                        latents_to_max=[latent],
                        latents_to_zero=[],
                        input_sequence=masked_seq,
                        layer_num=9,
                        steering_value=float(sv),
                        position_to_steer=pos_list,
                        steering_value_method="fixed", #"fixed",#alternative is "max_activation"
                        device=DEVICE,
                    )
                    logits_all = to_np(results["logits_steered"])
                    logits = logits_all[TOKEN_POSITION]
                    probs = softmax_np(logits)

                # top-5
                top5_idx = np.argsort(logits)[-5:][::-1]
                top5_logits = [float(logits[i]) for i in top5_idx]
                top5_probs  = [float(probs[i])  for i in top5_idx]
                top5_tokens = [tokenizer.convert_ids_to_tokens([int(i)])[0] for i in top5_idx]
                top1_id = int(top5_idx[0])

                # mutation & WT groups
                target_prob, target_best_rank = group_prob_and_rank(logits, probs, TARGET_IDS)
                wt_prob, wt_best_rank = group_prob_and_rank(logits, probs, WT_IDS) if WT_IDS else (math.nan, math.nan)

                # delta to baseline (useful for plots)
                delta_target_prob = float(target_prob - base_target_prob) if not math.isnan(target_prob) else math.nan
                delta_wt_prob = float(wt_prob - base_wt_prob) if not math.isnan(wt_prob) else math.nan

                row = dict(
                    latent=int(latent),
                    steering_value=float(sv),
                    position=str(pos_list),
                    position_len=len(pos_list),
                    position_category=classify_position_category(pos_list, TOKEN_POSITION),
                    token_position=int(TOKEN_POSITION),

                    # rank-1 identity
                    top1_id=top1_id,
                    top1_token=top5_tokens[0],
                    top1_logit=top5_logits[0],
                    top1_prob=top5_probs[0],
                    top1_is_target_group=top1_belongs_to(TARGET_IDS, top1_id),
                    top1_is_wt_group=(top1_belongs_to(WT_IDS, top1_id) if WT_IDS else False),

                    # mutation group metrics
                    target_group_prob=float(target_prob),
                    target_group_best_rank=float(target_best_rank),
                    delta_target_group_prob=float(delta_target_prob),

                    # optional WT group
                    wt_group_prob=float(wt_prob),
                    wt_group_best_rank=float(wt_best_rank),
                    delta_wt_group_prob=float(delta_wt_prob),

                    # first mutation token convenience
                    target_first_id=int(TARGET_IDS[0]),
                    target_first_logit=float(logits[TARGET_IDS[0]]),
                    target_first_prob=float(probs[TARGET_IDS[0]]),

                    # baseline for reference
                    baseline_target_group_prob=float(base_target_prob),
                    baseline_wt_group_prob=float(base_wt_prob),

                    # raw top5
                    top5_ids="|".join(str(int(i)) for i in top5_idx),
                    top5_tokens="|".join(top5_tokens),
                    top5_logits="|".join(f"{v:.6f}" for v in top5_logits),
                    top5_probs="|".join(f"{v:.6f}" for v in top5_probs),

                    error="",
                )
                records.append(row)

            except Exception as e:
                records.append(dict(
                    latent=int(latent),
                    steering_value=float(sv),
                    position=str(pos_list),
                    position_len=len(pos_list),
                    position_category=classify_position_category(pos_list, TOKEN_POSITION),
                    token_position=int(TOKEN_POSITION),
                    top1_id=-1,
                    top1_token="",
                    top1_logit=math.nan,
                    top1_prob=math.nan,
                    top1_is_target_group=False,
                    top1_is_wt_group=False,
                    target_group_prob=math.nan,
                    target_group_best_rank=math.nan,
                    delta_target_group_prob=math.nan,
                    wt_group_prob=math.nan,
                    wt_group_best_rank=math.nan,
                    delta_wt_group_prob=math.nan,
                    target_first_id=(int(TARGET_IDS[0]) if TARGET_IDS else -1),
                    target_first_logit=math.nan,
                    target_first_prob=math.nan,
                    baseline_target_group_prob=float(base_target_prob),
                    baseline_wt_group_prob=float(base_wt_prob),
                    top5_ids="",
                    top5_tokens="",
                    top5_logits="",
                    top5_probs="",
                    error=str(e),
                ))

df = pd.DataFrame.from_records(records)
df.to_csv(CSV_PATH, index=False)
print(f"[done] wrote {len(df)} rows -> {CSV_PATH}")

# provenance/config
cfg = {
    "model": MODEL_NAME,
    "sae_model_path": SAE_MODEL_PATH,
    "device": DEVICE,
    "sequence_length": len(masked_seq),
    "token_position": TOKEN_POSITION,
    "latents": LATENTS,
    "steering_values": STEERING_VALUES,
    "positions_to_steer": POSITIONS_TO_STEER,
    "target_tokens": TARGET_TOKENS,
    "wt_tokens": WT_TOKENS,
}
with open(CONFIG_JSON, "w") as f:
    json.dump(cfg, f, indent=2)
print(f"[done] wrote config -> {CONFIG_JSON}")


[baseline] computing baseline (no steering)…
Detected BatchTopKSAE format
Loading k=16 from state dict
Loading threshold=3.331448554992676 from state dict
Detected BatchTopKSAE format
Loading k=16 from state dict
Loading threshold=3.331448554992676 from state dict
[SAEWrapper] Initialized with BatchTopKSAE (k=16)
[1/1440] latent=5561 steer=-150 pos=[568]
Detected BatchTopKSAE format
Loading k=16 from state dict
Loading threshold=3.331448554992676 from state dict
Detected BatchTopKSAE format
Loading k=16 from state dict
Loading threshold=3.331448554992676 from state dict
[SAEWrapper] Initialized with BatchTopKSAE (k=16)


KeyboardInterrupt: 