In [1]:
# Import necessary libraries
from datasets import load_dataset
import re
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import json
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
import numpy as np
import evaluate
import heapq
from typing import List, Dict, Tuple
from collections import defaultdict
import sys
from typing import List
# from transformers import BertTokenizer, BertForSequenceClassification, pipeline

# Make the notebook's working directory importable so that the project-local
# packages imported further down (`utils`, `sparse_autoencoder`) resolve.
repo_root = Path(".").resolve()
_repo_root_entry = str(repo_root)
if _repo_root_entry not in sys.path:
    # Prepend so the project copy takes priority over any same-named package.
    sys.path.insert(0, _repo_root_entry)

# Import helper utilities from organized modules
import importlib
import utils.finbert
import utils.analysis
import utils.ablation
import utils.run_dirs
import sparse_autoencoder.finbert_sae

# Force reload to get latest code
importlib.reload(utils.finbert)
importlib.reload(utils.analysis)
importlib.reload(utils.ablation)
importlib.reload(utils.run_dirs)
importlib.reload(sparse_autoencoder.finbert_sae)

from utils.finbert import compute_metrics
from utils.analysis import (
    FeatureStatsAggregator,
    FeatureTopTokenTracker,
    HeadlineFeatureAggregator
)
from utils.ablation import create_intervention_hook
from utils.run_dirs import make_analysis_run_dir
from sparse_autoencoder.finbert_sae import SparseAutoencoder, load_sae

# --------- CUDA sanity check ----------
# Query CUDA availability once and reuse the answer for the report and the device.
_cuda_ok = torch.cuda.is_available()
print("Torch:", torch.__version__)
print("CUDA available:", _cuda_ok)
if _cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))

# Device used by later cells: the first GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda" if _cuda_ok else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


Torch: 2.6.0+cu124
CUDA available: True
GPU: NVIDIA GeForce RTX 3050 Ti Laptop GPU


In [2]:
# 1. Load the Twitter financial-news sentiment dataset.
#    The code below reads its "train" and "validation" splits.
ds = load_dataset("zeroshot/twitter-financial-news-sentiment")

def clean_text(text):
    """Return ``text`` with URLs removed and whitespace normalised.

    Every run of non-whitespace characters starting with ``http`` is dropped,
    then each run of whitespace is collapsed to a single space and the result
    is trimmed at both ends.
    """
    without_urls = re.sub(r"http\S+", "", text)
    single_spaced = re.sub(r"\s+", " ", without_urls)
    return single_spaced.strip()

def remove_leading_tickers(text):
    """Strip cashtags (``$`` + 1-6 uppercase letters) from the start of ``text``.

    One or more leading cashtags are removed together with the whitespace
    around them and an optional ``-`` or ``:`` separator that follows.
    Cashtags that appear later in the text are left untouched.
    """
    leading_cashtags = r'^\s*(?:\$[A-Z]{1,6}\s*)+(?:[-:]\s*)?'
    return re.sub(leading_cashtags, '', text)

# Clean dataset: strip URLs / normalise whitespace, then drop leading cashtags.
# Both steps are applied in a single map() pass; the resulting text is the same
# as running them as two separate passes, but the dataset is only traversed once.
ds = ds.map(lambda x: {"text": remove_leading_tickers(clean_text(x["text"]))})

# Select the splits used by the rest of the notebook
train_ds = ds["train"]
test_ds = ds["validation"]  # Use validation set for analysis


In [3]:
# Constants/Hyperparameters for training model and SAE
LAYER_TO_EXTRACT = 8  # 3/4 layer of BERT (0-11 for base BERT)
# One SAE is trained per entry: 4096, 8192, 16384 and 32768 latent features.
LATENT_DIMS = [size_k * 1024 for size_k in (4, 8, 16, 32)]
L1_COEFFICIENT = 1e-3  # Weight of the L1 sparsity penalty in the SAE loss
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
NUM_EPOCHS = 3


# Configuration for Inference
MAX_SAMPLES = 100  # Cap on the number of samples processed (for testing)
TOP_FEATURES = 100  # Number of top features kept per metric
TOP_TOKENS_PER_FEATURE = 20  # Number of top-activating tokens kept per feature
MAX_SEQ_LENGTH = 64  # Tokenizer truncation length
# Which trained SAE to load: "4k", "8k", "16k", or "32k"
SAE_SIZE = "32k"

Train Sparse Autoencoders (SAEs) on FinBERT Activations

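The cell below trains one SAE per configured latent size (4k, 8k, 16k and 32k features) to decompose FinBERT's 768-dimensional layer-8 activations into interpretable sparse features. It loads the fine-tuned checkpoint from `./finbert_twitter_ft/best`, which is written by the "Finetune FinBERT Model" cell further down, so that cell must have been run at least once beforehand.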


In [None]:
# This cell trains sparse autoencoders (SAEs) on hidden activations of the
# fine-tuned BERT classifier loaded below.
# Configuration (same values as the constants cell earlier in the notebook)
LAYER_TO_EXTRACT = 8  # Encoder layer whose output is hooked (index 8 of 0-11 for base BERT)
LATENT_DIMS = [4096, 8192, 16384, 32768]  # Train SAEs with 4k, 8k, 16k, 32k features
L1_COEFFICIENT = 1e-3  # Sparsity penalty
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
NUM_EPOCHS = 3

# Create SAE save directory
Path("./finbert_sae").mkdir(exist_ok=True)

# Load the fine-tuned model.
# This is the directory the fine-tuning cell saves the model and tokenizer to,
# so that cell has to have been run before this one.
save_dir = "./finbert_twitter_ft/best"
tokenizer = AutoTokenizer.from_pretrained(save_dir)
model = AutoModelForSequenceClassification.from_pretrained(save_dir)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference mode: activations are only read, the classifier is not trained here

# Dataset: `train_ds` is not created in this cell; it must already exist
# (it is assigned in the dataset-loading cell).

print(f"Collecting activations from {len(train_ds)} training samples...")
print(f"Target layer: {LAYER_TO_EXTRACT}")
print(f"Will train SAEs with latent dimensions: {LATENT_DIMS}")

# Buffers used while collecting training activations:
#   all_activations      - list that the collection loop appends kept activations to
#   captured_activations - scratch list the forward hook below appends to
all_activations, captured_activations = [], []

def capture_hook(module, input, output):
    """Forward hook that records the hooked module's hidden states.

    If the module returns a tuple, its first element is taken as the hidden
    states; otherwise the output itself is used. The tensor is detached and
    appended to ``captured_activations`` without any device transfer.
    """
    hidden_states = output[0] if isinstance(output, tuple) else output
    captured_activations.append(hidden_states.detach())

# Register hook on the encoder layer whose output the SAEs will be trained on
target_layer = model.bert.encoder.layer[LAYER_TO_EXTRACT]
hook_handle = target_layer.register_forward_hook(capture_hook)

# The set of special-token ids is the same for every sample, so build the
# lookup tensor once here instead of rebuilding it inside the loop.
special_ids_tensor = torch.tensor(
    sorted(set(tokenizer.all_special_ids)), dtype=torch.long, device=device
)

# Collect activations from all training data
print("Extracting activations from training set...")
print("Filtering out ALL special tokens (CLS, SEP, PAD, UNK, MASK, etc.) - keeping only content tokens...")
try:
    with torch.no_grad():
        for sample in tqdm(train_ds):
            inputs = tokenizer(sample["text"], return_tensors="pt", truncation=True, max_length=64)
            inputs = inputs.to(device)

            # The hook appends on every forward pass; start each sample from an empty buffer.
            captured_activations.clear()
            _ = model(**inputs)

            if not captured_activations:
                continue

            # Drop the batch dimension (one sample per forward pass): one row per token.
            activation = captured_activations[0].squeeze(0)
            attention_mask = inputs["attention_mask"].squeeze(0).bool()
            token_ids = inputs["input_ids"].squeeze(0)

            # Keep positions that are attended to AND are not special tokens.
            not_special = ~torch.isin(token_ids, special_ids_tensor)
            valid_mask = attention_mask & not_special
            activation = activation[valid_mask]

            # Only store samples that still contain at least one content token.
            if activation.shape[0] > 0:
                # Move to CPU only when storing for later processing
                all_activations.append(activation.cpu())
finally:
    # Always detach the hook, even if the loop fails part-way, so that it does not
    # keep firing on later forward passes through the model.
    hook_handle.remove()

if not all_activations:
    raise RuntimeError(
        "No content-token activations were collected; cannot build the SAE training set."
    )

# Flatten all activations into a single tensor [total_tokens, 768]
all_activations_tensor = torch.cat(all_activations, dim=0)
print(f"\nCollected {all_activations_tensor.shape[0]} token activations")
print(f"Activation shape: {all_activations_tensor.shape}")

# The training data does not depend on the latent size, so the dataset and the
# DataLoader are built once, outside the per-size loop. With shuffle=True the
# loader still draws a fresh permutation every time it is iterated.
from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(all_activations_tensor)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Train SAEs for each latent dimension
for LATENT_DIM in LATENT_DIMS:
    print(f"\n{'='*80}")
    print(f"Training SAE with {LATENT_DIM} latent features ({LATENT_DIM//1024}k)")
    print(f"{'='*80}")

    # Create SAE
    sae = SparseAutoencoder(input_dim=768, latent_dim=LATENT_DIM)
    sae.to(device)

    # Optimizer
    optimizer = optim.Adam(sae.parameters(), lr=LEARNING_RATE)

    # Training loop
    print(f"\nTraining SAE for {NUM_EPOCHS} epochs...")
    sae.train()

    for epoch in range(NUM_EPOCHS):
        total_loss = 0
        total_recon_loss = 0
        total_l1_loss = 0

        for (batch_x,) in dataloader:
            batch_x = batch_x.to(device)

            # Forward pass
            reconstruction, latent = sae(batch_x)

            # Reconstruction loss (MSE)
            recon_loss = nn.functional.mse_loss(reconstruction, batch_x)

            # L1 sparsity loss
            l1_loss = latent.abs().mean()

            # Combined loss
            loss = recon_loss + L1_COEFFICIENT * l1_loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Renormalize decoder weights after every optimizer step.
            # NOTE(review): dim=0 gives unit-norm columns; this presumably corresponds
            # to one decoder direction per latent feature - confirm against the
            # SparseAutoencoder definition.
            with torch.no_grad():
                sae.decoder.weight.data = nn.functional.normalize(
                    sae.decoder.weight.data, dim=0
                )

            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_l1_loss += l1_loss.item()

        # Per-epoch averages over the number of batches
        avg_loss = total_loss / len(dataloader)
        avg_recon = total_recon_loss / len(dataloader)
        avg_l1 = total_l1_loss / len(dataloader)

        print(f"Epoch {epoch+1}/{NUM_EPOCHS}: Loss={avg_loss:.4f}, "
              f"Recon={avg_recon:.4f}, L1={avg_l1:.4f}")

    # Save the trained SAE (weights on CPU plus the config needed to rebuild it)
    SAE_SAVE_PATH = f"./finbert_sae/layer_{LAYER_TO_EXTRACT}_{LATENT_DIM//1024}k.pt"
    print(f"\nSaving trained SAE to {SAE_SAVE_PATH}")
    torch.save({
        'encoder_weight': sae.encoder.weight.data.cpu(),
        'encoder_bias': sae.encoder.bias.data.cpu(),
        'decoder_weight': sae.decoder.weight.data.cpu(),
        'decoder_bias': sae.decoder.bias.data.cpu(),
        'config': {
            'input_dim': 768,
            'latent_dim': LATENT_DIM,
            'layer': LAYER_TO_EXTRACT,
            'model': save_dir,
        }
    }, SAE_SAVE_PATH)

    # Test sparsity: fraction of strictly positive latent entries on the first
    # 1000 collected activations.
    sae.eval()
    with torch.no_grad():
        sample_acts = all_activations_tensor[:1000].to(device)
        sample_latent = sae.encode(sample_acts)
        sparsity = (sample_latent > 0).float().mean()
        print(f"\n✓ SAE trained successfully!")
        print(f"  Average sparsity: {sparsity:.2%} of features active")
        print(f"  Saved to: {SAE_SAVE_PATH}")

print(f"\n{'='*80}")
print(f"All SAEs trained successfully!")
print(f"Available SAE models:")
for dim in LATENT_DIMS:
    print(f"  - layer_{LAYER_TO_EXTRACT}_{dim//1024}k.pt ({dim} features)")
print(f"\nThese SAEs can now be used in main.py for interpretability analysis!")
print(f"{'='*80}")


Finetune FinBERT Model

The FinBERT model is trained on the training fold of our dataset to improve its prediction accuracy.


In [None]:
# This cell finetunes the FINBERT model.

# 2) Load model/tokenizer (pretrained checkpoint that fine-tuning starts from)
model_name = "ahmedrachid/FinancialBERT-Sentiment-Analysis"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Class-index <-> class-name mapping stored in the model config.
# NOTE(review): presumably this matches the dataset's integer label encoding -
# confirm against the dataset card.
id2label = {0: "Bearish", 1: "Bullish", 2: "Neutral"}
label2id = {v: k for k, v in id2label.items()}

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=3,
    id2label=id2label,
    label2id=label2id,
)

# Move model to `device` (not set in this cell; the setup cell defines it as the
# GPU when CUDA is available, otherwise the CPU)
model.to(device)

# 3) Tokenize
def tokenize_fn(batch):
    """Tokenize the ``"text"`` column of ``batch`` with truncation enabled.

    No padding is requested here; the data collator created below is given
    the same tokenizer and handles padding.
    """
    texts = batch["text"]
    return tokenizer(texts, truncation=True)

# Tensor columns exposed to the Trainer
cols_to_keep = ["input_ids", "attention_mask", "labels"]

# Tokenize each split and rename the target column from "label" to "labels".
train_tok = train_ds.map(tokenize_fn, batched=True).rename_column("label", "labels")
val_tok = test_ds.map(tokenize_fn, batched=True).rename_column("label", "labels")

# Return torch tensors for the selected columns only.
for _tokenized_split in (train_tok, val_tok):
    _tokenized_split.set_format(type="torch", columns=cols_to_keep)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# 4) Metrics
# Note: `acc` and `f1` are loaded here but the Trainer below is given the
# imported `compute_metrics` helper rather than these objects.
acc = evaluate.load("accuracy")
f1 = evaluate.load("f1")

# 5) Training config
# Mixed precision is switched on only when a CUDA GPU is available.
use_fp16 = torch.cuda.is_available()

# Training configuration: evaluate and checkpoint once per epoch, then reload the
# checkpoint that scored best on `metric_for_best_model`.
training_args = TrainingArguments(
    output_dir="./finbert_twitter_ft",
    eval_strategy="epoch",   # NOTE(review): some transformers versions name this argument `evaluation_strategy` instead - confirm against the installed version
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    weight_decay=0.01,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1",  # assumes compute_metrics reports a "macro_f1" entry - TODO confirm in utils.finbert
    fp16=use_fp16,                 # <-- enables mixed precision on NVIDIA GPU
    dataloader_num_workers=0,      # safer on Windows; avoids hanging
    report_to="none",              # avoids needing wandb, etc.
)

# Fine-tune on the tokenized training split, evaluating on the tokenized validation split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    tokenizer=tokenizer,  # NOTE(review): newer transformers releases prefer `processing_class=` - confirm against the installed version
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

# trainer.evaluate() is not the last expression of the cell, so its return value
# would be dropped without being shown; keep it and print it explicitly.
eval_metrics = trainer.evaluate()
print(eval_metrics)

# Persist model and tokenizer to the directory the SAE-training and ablation
# cells load from.
trainer.save_model("./finbert_twitter_ft/best")
tokenizer.save_pretrained("./finbert_twitter_ft/best")


Inference with Interpretability

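We run the fine-tuned FinBERT model together with a trained SAE on held-out data (the dataset's validation split). For each headline we extract the target layer's activations alongside the sentiment prediction (SAE-style analysis).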

In [7]:
# Feature Ablation Experiment
# Zero out selected SAE features and compare the model's predictions with and
# without the intervention.

# ========== ABLATION CONFIGURATION ==========
# "mode" is one of: "manual" | "per_sample_top_k" | "union_top_k"
# "k" is only read by the per_sample_top_k and union_top_k modes.
ABLATION_CONFIG = {
    "mode": "union_top_k",
    "k": 10,
}

# Mode 1: Manual features (only used if mode == "manual")
# Note: feature 27518 is listed twice.
MANUAL_FEATURES = [
    4456, 21508, 21969, 27518, 21110, 24583, 32601,
    15959, 27518, 29555, 3993, 13142, 22354, 21858,
]

_banner_rule = "=" * 60
print(_banner_rule)
print("FEATURE ABLATION EXPERIMENT")
print(_banner_rule)
print(f"Ablation Mode: {ABLATION_CONFIG['mode']}")
if not ABLATION_CONFIG['mode'] == 'manual':
    print(f"K value: {ABLATION_CONFIG['k']}")

# Load the fine-tuned model and tokenizer
save_dir = "./finbert_twitter_ft/best"
tokenizer = AutoTokenizer.from_pretrained(save_dir)
model = AutoModelForSequenceClassification.from_pretrained(save_dir)

# Define device and move model to it.
# The fallback must be the literal device string "cpu": torch.device() rejects
# free-form text, so an explanatory sentence here would crash on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    print("CUDA not available - running on CPU (install a CUDA-compatible Torch build for GPU support)")
model.to(device)
model.eval()

# Load the SAE using the helper function.
# NOTE(review): `device` is not passed to load_sae; presumably the helper places the
# SAE on a suitable device itself - confirm in sparse_autoencoder.finbert_sae.
sae, sae_config = load_sae(layer=LAYER_TO_EXTRACT, latent_size=SAE_SIZE)

# Extract dimensions from the loaded config
SAE_INPUT_DIM = sae_config['input_dim']
SAE_LATENT_DIM = sae_config['latent_dim']

print(f"Ablation mode: {ABLATION_CONFIG['mode']}")
print(f"Layer: {LAYER_TO_EXTRACT}")
print(f"SAE Size: {SAE_SIZE} ({SAE_LATENT_DIM} features)")
print(f"Max Samples: {MAX_SAMPLES}\n")

# Storage for results (three independent, initially empty lists)
baseline_predictions, ablated_predictions, sample_data = [], [], []

# Trackers for SAE feature statistics gathered during the ablated run
feature_stats_ablated = FeatureStatsAggregator(SAE_LATENT_DIM)
top_token_tracker_ablated = FeatureTopTokenTracker(SAE_LATENT_DIM, TOP_TOKENS_PER_FEATURE)
headline_aggregator_ablated = HeadlineFeatureAggregator(top_k=10)
all_prompt_metadata_ablated = []

# Shared per-sample scratch dict (every slot starts as None). It is handed to the
# intervention hook factory; the ablation loop fills in the token/text/idx slots
# before each forward pass and reads "sae_features" back afterwards.
current_sample_data = dict.fromkeys(
    ("sae_features", "token_ids", "prompt_tokens", "text", "idx")
)

# Run baseline inference (no ablation)
print("🔬 Running baseline inference (no ablation)...")
baseline_results = []
baseline_features_map = {}  # sample idx -> top baseline SAE features, for comparison with the ablated run

# Buffer filled by the capture hook on each forward pass
captured_acts = []

def capture_hook(module, input, output):
    """Forward hook: record the hooked layer's hidden states (detached).

    If the layer returns a tuple, its first element is taken as the hidden states.
    The hook returns None, so the layer's output - and therefore the model's
    prediction - is left unchanged.
    """
    hidden_states = output[0] if isinstance(output, tuple) else output
    captured_acts.append(hidden_states.detach())

# Loop-invariant values, computed once
special_ids_tensor = torch.tensor(
    sorted(set(tokenizer.all_special_ids)), dtype=torch.long, device=device
)
n_baseline_total = min(MAX_SAMPLES, len(test_ds))

# The hook is registered once for the whole loop, so a single forward pass per
# sample yields both the prediction and the target-layer activations.
target_layer_baseline = model.bert.encoder.layer[LAYER_TO_EXTRACT]
temp_hook = target_layer_baseline.register_forward_hook(capture_hook)
try:
    with torch.no_grad():
        for idx, sample in enumerate(test_ds):
            if idx >= MAX_SAMPLES:
                break

            text = sample["text"]
            true_label = sample["label"]

            # Tokenize
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH)
            inputs = inputs.to(device)

            # Forward pass (no intervention; activations are captured by the hook)
            captured_acts.clear()
            outputs = model(**inputs)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1)

            pred_id = logits.argmax(dim=-1).item()
            pred_label = model.config.id2label[pred_id]
            confidence = probs[0, pred_id].item()

            baseline_results.append({
                "sample_idx": idx,
                "text": text,
                "true_label": model.config.id2label[true_label],
                "predicted_label": pred_label,
                "predicted_id": pred_id,
                "confidence": confidence,
                "logits": logits.cpu().numpy(),
                "probs": probs.cpu().numpy()
            })

            # Default entry, used when nothing was captured or no content tokens remain
            sample_features = {"top_features": [], "total_activation": 0.0}

            if captured_acts:
                bert_activation = captured_acts[0].squeeze(0)

                # Keep attended positions that are not special tokens
                attention_mask = inputs["attention_mask"].squeeze(0).bool()
                token_ids_tensor = inputs["input_ids"].squeeze(0)
                not_special = ~torch.isin(token_ids_tensor, special_ids_tensor)
                valid_mask = attention_mask & not_special
                bert_activation = bert_activation[valid_mask]

                if bert_activation.shape[0] > 0:
                    # Get SAE features
                    sae_features = sae.encode(bert_activation)
                    sae_features_cpu = sae_features.detach().cpu().numpy()

                    # Get max activation per feature across all tokens
                    max_activations_per_feature = sae_features_cpu.max(axis=0)

                    # Get top 10 features, strongest first
                    top_10_indices = np.argsort(max_activations_per_feature)[-10:][::-1]
                    top_features = [
                        {
                            "feature_id": int(fid),
                            "activation": float(max_activations_per_feature[fid])
                        }
                        for fid in top_10_indices
                    ]
                    total_activation = sum(feat["activation"] for feat in top_features)

                    sample_features = {
                        "top_features": top_features,
                        "total_activation": total_activation
                    }

            baseline_features_map[idx] = sample_features

            if (idx + 1) % 20 == 0:
                print(f"  Baseline: {idx + 1}/{n_baseline_total} samples")
finally:
    # Detach the capture hook even if the loop fails, so it cannot leak into the
    # ablation run that follows.
    temp_hook.remove()

if not baseline_results:
    raise RuntimeError("Baseline run processed no samples; check MAX_SAMPLES and test_ds.")

baseline_accuracy = sum(1 for r in baseline_results if r["predicted_id"] == test_ds[r["sample_idx"]]["label"]) / len(baseline_results)
print(f"✓ Baseline accuracy: {baseline_accuracy:.2%}\n")

# Determine features to ablate based on mode.
# FEATURES_TO_ABLATE is a list of feature ids for the "manual" and "union_top_k"
# modes, and None for "per_sample_top_k" (features are then chosen per sample).
if ABLATION_CONFIG["mode"] == "manual":
    FEATURES_TO_ABLATE = MANUAL_FEATURES
    print(f"\nMode 1 (Manual): Ablating {len(FEATURES_TO_ABLATE)} manually specified features")
elif ABLATION_CONFIG["mode"] == "union_top_k":
    # Union of every sample's top-k baseline features
    feature_set = set()
    for sample_features in baseline_features_map.values():
        feature_set.update(
            f['feature_id'] for f in sample_features['top_features'][:ABLATION_CONFIG["k"]]
        )
    FEATURES_TO_ABLATE = sorted(feature_set)
    print(f"\nMode 3 (Union Top-K): Collected {len(FEATURES_TO_ABLATE)} unique features from union of top-{ABLATION_CONFIG['k']} across {len(baseline_results)} samples")
elif ABLATION_CONFIG["mode"] == "per_sample_top_k":
    FEATURES_TO_ABLATE = None
    print(f"\nMode 2 (Per-Sample Top-K): Will ablate top-{ABLATION_CONFIG['k']} features individually for each sample")
else:
    raise ValueError(f"Unknown ablation mode: {ABLATION_CONFIG['mode']}")

# Verify features are within valid range (single scan)
if FEATURES_TO_ABLATE is not None:
    invalid = [fid for fid in FEATURES_TO_ABLATE if fid < 0 or fid >= SAE_LATENT_DIM]
    if invalid:
        raise ValueError(f"Invalid feature IDs (must be 0-{SAE_LATENT_DIM-1}): {invalid}")

# Test against None rather than truthiness, so that an empty feature list is
# reported as an empty list and not mistaken for per-sample mode.
print(f"Features to ablate: {FEATURES_TO_ABLATE if FEATURES_TO_ABLATE is not None else 'Per-sample dynamic'}\n")

# Register intervention hook (Mode 1 & 3: one global hook for the whole run)
if ABLATION_CONFIG["mode"] != "per_sample_top_k":
    target_layer = model.bert.encoder.layer[LAYER_TO_EXTRACT]
    intervention_hook = create_intervention_hook(sae, FEATURES_TO_ABLATE, device, current_sample_data)
    hook_handle = target_layer.register_forward_hook(intervention_hook)

# Run ablation inference
print("🔬 Running ablation inference (features zeroed)...")
ablated_results = []

baseline_lookup = {r["sample_idx"]: r for r in baseline_results}

# Loop-invariant values, computed once
per_sample_mode = ABLATION_CONFIG["mode"] == "per_sample_top_k"
special_ids = set(tokenizer.all_special_ids)
n_ablation_total = min(MAX_SAMPLES, len(test_ds))

try:
    with torch.no_grad():
        for idx, sample in enumerate(test_ds):
            if idx >= MAX_SAMPLES:
                break

            text = sample["text"]
            true_label = sample["label"]

            # Tokenize
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH)
            token_ids = inputs["input_ids"][0].tolist()

            # String tokens for display, with the "##" prefix stripped
            prompt_tokens = [
                tok[2:] if tok.startswith("##") else tok
                for tok in tokenizer.convert_ids_to_tokens(token_ids)
            ]

            inputs = inputs.to(device)

            # Features ablated for this sample. Mode 2 derives them from this sample's
            # baseline top-k and registers a hook just for this forward pass;
            # modes 1 & 3 reuse the global list and the already-registered hook.
            if per_sample_mode:
                features_ablated_for_this_sample = [
                    f["feature_id"]
                    for f in baseline_features_map[idx]["top_features"][:ABLATION_CONFIG["k"]]
                ]
                target_layer = model.bert.encoder.layer[LAYER_TO_EXTRACT]
                intervention_hook = create_intervention_hook(
                    sae, features_ablated_for_this_sample, device, current_sample_data
                )
                hook_handle = target_layer.register_forward_hook(intervention_hook)
            else:
                features_ablated_for_this_sample = FEATURES_TO_ABLATE

            # Reset the shared scratch dict for this sample
            current_sample_data["sae_features"] = None
            current_sample_data["token_ids"] = token_ids
            current_sample_data["prompt_tokens"] = prompt_tokens
            current_sample_data["text"] = text
            current_sample_data["idx"] = idx

            # Forward pass with intervention (features ablated)
            try:
                outputs = model(**inputs)
            finally:
                # Mode 2: the per-sample hook must never outlive its sample
                if per_sample_mode:
                    hook_handle.remove()

            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1)

            pred_id = logits.argmax(dim=-1).item()
            pred_label = model.config.id2label[pred_id]
            confidence = probs[0, pred_id].item()

            ablated_results.append({
                "sample_idx": idx,
                "text": text,
                "true_label": model.config.id2label[true_label],
                "predicted_label": pred_label,
                "predicted_id": pred_id,
                "confidence": confidence,
                "logits": logits.detach().cpu().numpy(),
                "probs": probs.detach().cpu().numpy(),
                "ablated_features": features_ablated_for_this_sample
            })

            # Track SAE features (after ablation) if the hook stored them.
            # NOTE(review): assumes the hook stores one row of SAE features per input
            # token (specials included), so it can be indexed with the token-level
            # mask below - confirm in utils.ablation.
            if current_sample_data["sae_features"] is not None:
                sae_features_cpu = current_sample_data["sae_features"].cpu().numpy()

                # Keep attended positions that are not special tokens
                attention_mask = inputs["attention_mask"].squeeze(0).bool().cpu().numpy()
                token_ids_tensor = inputs["input_ids"].squeeze(0).cpu().numpy()
                not_special = np.array([tid not in special_ids for tid in token_ids_tensor])
                valid_mask = attention_mask & not_special

                # Filter features and tokens with the same mask so they stay aligned
                sae_features_filtered = sae_features_cpu[valid_mask]
                filtered_token_ids = [tid for tid, valid in zip(token_ids, valid_mask) if valid]
                filtered_prompt_tokens = [tok for tok, valid in zip(prompt_tokens, valid_mask) if valid]

                if sae_features_filtered.shape[0] > 0:
                    seq_len = sae_features_filtered.shape[0]

                    # Update feature statistics
                    feature_stats_ablated.update(sae_features_filtered)

                    # Track top tokens per feature
                    top_token_tracker_ablated.update(
                        sae_features_filtered,
                        filtered_token_ids,
                        prompt_idx=idx,
                        prompt_text=text,
                        prompt_tokens=filtered_prompt_tokens,
                        predicted_label=pred_label,
                        true_label=model.config.id2label[true_label]
                    )

                    # Aggregate top features at headline level
                    baseline_data = baseline_lookup[idx]
                    headline_aggregator_ablated.add_headline_with_ablation_metrics(
                        prompt_idx=idx,
                        prompt_text=text,
                        token_activations=sae_features_filtered,
                        token_ids=filtered_token_ids,
                        token_strings=filtered_prompt_tokens,
                        predicted_label=pred_label,
                        true_label=model.config.id2label[true_label],
                        confidence=confidence,
                        baseline_features=baseline_features_map[idx],
                        features_to_ablate=features_ablated_for_this_sample,
                        baseline_prediction=baseline_data["predicted_label"],
                        baseline_confidence=baseline_data["confidence"]
                    )

                    # Save prompt metadata
                    all_prompt_metadata_ablated.append({
                        "row_id": idx,
                        "seq_len": seq_len,
                        "prompt": text,
                        "predicted_label": pred_label,
                        "true_label": model.config.id2label[true_label],
                        "correct": pred_id == true_label
                    })

            # Progress report: inside the loop so that it fires every 20 samples
            if (idx + 1) % 20 == 0:
                print(f"  Ablated: {idx + 1}/{n_ablation_total} samples")
finally:
    # Remove hook (Mode 1 & 3 only, Mode 2 removes per-sample). Done in `finally` so
    # that a failure mid-run cannot leave the intervention attached to the model.
    if not per_sample_mode:
        hook_handle.remove()

if not ablated_results:
    raise RuntimeError("Ablation run processed no samples; check MAX_SAMPLES and test_ds.")

ablated_accuracy = sum(1 for r in ablated_results if r["predicted_id"] == test_ds[r["sample_idx"]]["label"]) / len(ablated_results)
print(f"✓ Ablated accuracy: {ablated_accuracy:.2%}\n")

# Compare results and find flipped predictions.
# The un-ablated top-10 SAE features of every sample were already computed during
# the baseline run and stored in baseline_features_map, so they are looked up here
# instead of re-tokenizing each flipped sample and running the model again.
flipped_samples = []
for baseline, ablated in zip(baseline_results, ablated_results):
    if baseline["predicted_id"] == ablated["predicted_id"]:
        continue

    # Feature ids zeroed for this sample (set for fast membership tests)
    ablated_ids = set(ablated.get("ablated_features") or [])

    baseline_top = baseline_features_map.get(baseline["sample_idx"], {}).get("top_features", [])
    top_features = [
        {
            "feature_id": feat["feature_id"],
            "activation": feat["activation"],
            "ablated": feat["feature_id"] in ablated_ids,
        }
        for feat in baseline_top
    ]

    flipped_samples.append({
        "sample_idx": baseline["sample_idx"],
        "text": baseline["text"],
        "true_label": baseline["true_label"],
        "baseline_pred": baseline["predicted_label"],
        "baseline_conf": baseline["confidence"],
        "ablated_pred": ablated["predicted_label"],
        "ablated_conf": ablated["confidence"],
        "top_features": top_features
    })

# Print results: summary first, then up to ten flipped predictions in detail
_results_rule = "=" * 60
print(_results_rule)
print("FEATURE ABLATION RESULTS")
print(_results_rule)
print(f"Ablation Mode: {ABLATION_CONFIG['mode']}")
if FEATURES_TO_ABLATE is not None:
    print(f"Ablated Features: {FEATURES_TO_ABLATE}")
else:
    print(f"Ablated Features: Per-sample top-{ABLATION_CONFIG['k']}")
print(f"Baseline Accuracy: {baseline_accuracy:.2%}")
print(f"Ablated Accuracy: {ablated_accuracy:.2%}")
print(f"Accuracy Change: {(ablated_accuracy - baseline_accuracy):.2%}")
print(f"\nFlipped Predictions: {len(flipped_samples)}/{len(baseline_results)} samples")
print(f"Flip Rate: {len(flipped_samples)/len(baseline_results):.2%}\n")

if flipped_samples:
    print(_results_rule)
    print(f"FLIPPED PREDICTIONS (showing first {min(10, len(flipped_samples))}):")
    print(_results_rule)

    for flip in flipped_samples[:10]:
        print(f"\n--- Sample #{flip['sample_idx']} ---")
        # Long headlines are cut to 120 characters for display
        print(f"Text: {flip['text'][:120]}{'...' if len(flip['text']) > 120 else ''}")
        print(f"True Label: {flip['true_label']}")
        print(f"Original: {flip['baseline_pred']} (conf: {flip['baseline_conf']:.3f}) → "
              f"Ablated: {flip['ablated_pred']} (conf: {flip['ablated_conf']:.3f})")

        if flip['top_features']:
            print("Top 10 SAE Features:")
            for feat in flip['top_features']:
                ablated_marker = " [ABLATED]" if feat['ablated'] else ""
                print(f"  Feature {feat['feature_id']}: {feat['activation']:.4f}{ablated_marker}")
        print()
else:
    print("No predictions were flipped by ablating these features.")

# Save ablated results in the same format as inference cell
print("\nüíæ Saving ablated results for visualization...")

# The ablated run gets its own analysis directory; files are written straight into it.
ablated_run_dir = make_analysis_run_dir(str(repo_root))
print(f"üíæ Saving ablated results to: {ablated_run_dir}")

# Final per-feature statistics gathered during the ablated pass.
stats_ablated = feature_stats_ablated.get_stats()

# Rank features under every metric except "mean_act_squared", keeping the
# TOP_FEATURES highest values in descending order.
top_features_by_metric_ablated = {}
for metric_name, values in stats_ablated.items():
    if metric_name != "mean_act_squared":
        top_indices = np.argsort(values)[-TOP_FEATURES:][::-1]
        ranked_entries = []
        for feature_idx in top_indices:
            ranked_entries.append({
                "feature_id": int(feature_idx),
                "value": float(values[feature_idx]),
                "metrics": {
                    "mean_activation": float(stats_ablated["mean_activation"][feature_idx]),
                    "max_activation": float(stats_ablated["max_activation"][feature_idx]),
                    "fraction_active": float(stats_ablated["fraction_active"][feature_idx]),
                },
            })
        top_features_by_metric_ablated[metric_name] = ranked_entries

# 1. Save prompts metadata (one JSON object per line)
prompts_file = ablated_run_dir / "prompts.jsonl"
with open(prompts_file, "w", encoding="utf-8") as f:
    for meta in all_prompt_metadata_ablated:
        json.dump(meta, f)
        f.write("\n")

# 2. Save feature statistics
# "mean_act_squared" is stored as a full per-feature list; the other metrics
# only keep their top-feature rankings.
feature_stats_file = ablated_run_dir / "feature_stats.json"
feature_stats_data = {
    "num_features": SAE_LATENT_DIM,
    "total_tokens": feature_stats_ablated.total_tokens,
    "top_feature_count": TOP_FEATURES,
    "accuracy": ablated_accuracy,
    "num_samples": len(all_prompt_metadata_ablated),
    "mean_act_squared": stats_ablated["mean_act_squared"].tolist(),
    "metrics": {
        metric_name: {
            "description": f"{metric_name.replace('_', ' ').title()} for each feature",
            "top_features": top_features_by_metric_ablated[metric_name]
        }
        for metric_name in stats_ablated.keys() if metric_name != "mean_act_squared"
    }
}
with open(feature_stats_file, "w", encoding="utf-8") as f:
    json.dump(feature_stats_data, f, indent=2)

# 3. Save top tokens per feature
feature_tokens_file = ablated_run_dir / "feature_tokens.json"
feature_tokens_data = {
    "features": top_token_tracker_ablated.export()
}
with open(feature_tokens_file, "w", encoding="utf-8") as f:
    json.dump(feature_tokens_data, f, indent=2)

# 4. Save headline-level features
headline_features_file = ablated_run_dir / "headline_features.json"
with open(headline_features_file, "w", encoding="utf-8") as f:
    json.dump(headline_aggregator_ablated.export(), f, indent=2)

# 5. Save metadata
# Use the same None-check as the printed summary so both agree: any explicit
# feature list (even an empty one) is recorded as a list, and only None means
# the features were chosen per sample. Ids are coerced to plain ints so that
# NumPy integer types cannot break JSON serialisation.
if FEATURES_TO_ABLATE is not None:
    ablated_features_record = [int(feature_id) for feature_id in FEATURES_TO_ABLATE]
else:
    ablated_features_record = "per_sample_dynamic"

metadata_file = ablated_run_dir / "metadata.json"
with open(metadata_file, "w", encoding="utf-8") as f:
    json.dump({
        "model": save_dir,
        "layer_extracted": LAYER_TO_EXTRACT,
        "num_samples": len(all_prompt_metadata_ablated),
        "total_tokens": feature_stats_ablated.total_tokens,
        "accuracy": ablated_accuracy,
        "dataset": "zeroshot/twitter-financial-news-sentiment",
        "split": "validation",
        "hidden_dim": SAE_INPUT_DIM,
        "latent_dim": SAE_LATENT_DIM,
        "sae_path": f"./finbert_sae/layer_{LAYER_TO_EXTRACT}_{SAE_SIZE}.pt",
        "top_features_per_metric": TOP_FEATURES,
        "top_tokens_per_feature": TOP_TOKENS_PER_FEATURE,
        "ablation_mode": ABLATION_CONFIG["mode"],
        "ablated_features": ablated_features_record,
        "ablation_k": ABLATION_CONFIG.get("k"),
        "note": f"SAE sparse features with predictions (mode: {ABLATION_CONFIG['mode']})"
    }, f, indent=2)

print(f"\n‚úÖ Ablation experiment complete!")
print(f"   üìÅ Ablated results saved to: {ablated_run_dir.name}")
print(f"   üéØ Ablated Accuracy: {ablated_accuracy:.2%}")
print(f"   üî¢ Total tokens: {feature_stats_ablated.total_tokens}")
print(f"   ‚ú® SAE features: {SAE_LATENT_DIM}")
print(f"\nüåê Start the viewer to see ablated results:")
print(f"   python viz_analysis/feature_probe_server.py")
print(f"   cd sae-viewer && npm start")


FEATURE ABLATION EXPERIMENT
Ablation Mode: union_top_k
K value: 10
‚úì Loaded SAE from ./finbert_sae/layer_8_32k.pt
  Layer: 8
  Input dim: 768
  Latent dim: 32768
Ablation mode: union_top_k
Layer: 8
SAE Size: 32k (32768 features)
Max Samples: 100

üî¨ Running baseline inference (no ablation)...
  Baseline: 20/100 samples
  Baseline: 40/100 samples
  Baseline: 60/100 samples
  Baseline: 80/100 samples
  Baseline: 100/100 samples
‚úì Baseline accuracy: 87.00%


Mode 3 (Union Top-K): Collected 91 unique features from union of top-10 across 100 samples
Features to ablate: [423, 550, 602, 687, 2976, 3715, 3993, 4083, 4247, 4456, 5111, 6026, 6555, 6793, 6977, 7673, 7776, 7828, 7927, 8051, 8232, 8262, 8784, 9368, 9395, 9718, 9847, 10033, 10604, 11814, 12053, 12193, 12514, 13142, 13185, 13644, 14807, 15370, 15540, 15818, 15959, 15991, 16205, 16393, 16907, 17220, 17323, 17433, 17639, 18188, 18291, 18317, 18425, 18464, 19876, 20245, 20268, 20637, 21110, 21508, 21969, 22130, 22354, 22992, 23041

In [None]:
# Inference (non refactored)
import os
import json
from pathlib import Path
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import heapq
from typing import List, Tuple
import sys



# Start of Inference
# Put the repository's sparse_autoencoder directory at the front of sys.path
# (once) before importing the run-directory helper.
repo_root = Path(".").resolve()
sae_package_dir = str(repo_root / "sparse_autoencoder")
if sae_package_dir not in sys.path:
    sys.path.insert(0, sae_package_dir)

from utils.run_dirs import make_analysis_run_dir

# ---- Tunable settings for this cell ----
LAYER_TO_EXTRACT = 8          # 3/4 layer of BERT (0-11 for base BERT)
MAX_SAMPLES = 100             # Limit for testing
TOP_FEATURES = 100            # Top features to track per metric
TOP_TOKENS_PER_FEATURE = 20   # Top activating tokens per feature
MAX_SEQ_LENGTH = 64           # Maximum sequence length to process
SAE_SIZE = "32k"              # Which SAE to use: "4k", "8k", "16k", or "32k"

banner = "=" * 60
print(banner)
print("EXTRACTING SAE FEATURES FROM FINBERT")
print(banner)

# The helper returns the SAE together with a config dict holding its dimensions.
sae, sae_config = load_sae(layer=LAYER_TO_EXTRACT, latent_size=SAE_SIZE)
SAE_INPUT_DIM, SAE_LATENT_DIM = sae_config['input_dim'], sae_config['latent_dim']

print(f"‚úì SAE loaded: {SAE_INPUT_DIM} dims ‚Üí {SAE_LATENT_DIM} sparse features")

# Output directory for this run, created with the shared run-directory helper.
run_dir = make_analysis_run_dir(str(repo_root))
print(f"\nüíæ Saving results to: {run_dir}")

# Fine-tuned classifier and its tokenizer.
save_dir = "./finbert_twitter_ft/best"
tokenizer = AutoTokenizer.from_pretrained(save_dir)
model = AutoModelForSequenceClassification.from_pretrained(save_dir)

# Move both networks to the chosen device and switch them to eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
sae.to(device)
sae.eval()

# The validation split is the data analysed below.
test_ds = ds["validation"]
# Feature statistics tracker (per-token aggregation)
class FeatureStatsAggregator:
    """Accumulate running per-feature statistics over batches of token activations.

    For each of ``feature_dim`` features the aggregator keeps the sum, the
    maximum, the count of strictly positive values and the sum of squares,
    together with the total number of token rows seen so far.
    """

    def __init__(self, feature_dim: int):
        self.feature_dim = feature_dim
        self.total_tokens = 0
        self.sum_activations = np.zeros(feature_dim, dtype=np.float64)
        # Starts at zero, so the reported maximum is never below 0.
        self.max_activations = np.zeros(feature_dim, dtype=np.float64)
        self.nonzero_counts = np.zeros(feature_dim, dtype=np.float64)
        self.sum_of_squares = np.zeros(feature_dim, dtype=np.float64)  # Track squared activations

    def update(self, token_activations: np.ndarray):
        """Fold activations of shape [num_tokens, feature_dim] into the running totals.

        A batch with zero token rows is ignored: it carries no information and
        ``ndarray.max`` raises ``ValueError`` on a zero-length axis.
        """
        num_tokens = token_activations.shape[0]
        if num_tokens == 0:
            return

        self.total_tokens += num_tokens
        self.sum_activations += token_activations.sum(axis=0)
        self.max_activations = np.maximum(self.max_activations, token_activations.max(axis=0))
        self.nonzero_counts += (token_activations > 0).sum(axis=0)
        self.sum_of_squares += (token_activations ** 2).sum(axis=0)  # Accumulate squared values

    def get_stats(self):
        """Return a dict of per-feature arrays.

        Keys: ``mean_activation``, ``max_activation``, ``fraction_active`` and
        ``mean_act_squared``. Averages divide by ``max(total_tokens, 1)`` so the
        call is safe before any update.
        """
        denominator = max(self.total_tokens, 1)
        return {
            "mean_activation": self.sum_activations / denominator,
            "max_activation": self.max_activations,
            "fraction_active": self.nonzero_counts / denominator,
            "mean_act_squared": self.sum_of_squares / denominator,
        }

# Top token tracker per feature
class FeatureTopTokenTracker:
    """Keep the ``top_k`` highest-activating tokens seen for every feature.

    Each feature owns a min-heap of ``(activation, insertion_seq, metadata)``
    entries. ``insertion_seq`` is a unique, increasing integer: without it two
    equal activations would make ``heapq`` fall through to comparing the
    metadata dicts, which raises ``TypeError``.
    """

    # Number of strongest features examined per token.
    FEATURES_PER_TOKEN = 5

    def __init__(self, feature_dim: int, top_k: int):
        self.feature_dim = feature_dim
        self.top_k = top_k
        # One min-heap per feature; the smallest retained activation sits at index 0.
        self.heaps = [[] for _ in range(feature_dim)]
        # Tie-breaker for heap entries (see class docstring).
        self._insertion_seq = 0

    def update(self, token_activations: np.ndarray, token_ids: List[int],
               prompt_idx: int, prompt_text: str, prompt_tokens: List[str],
               predicted_label: str = None, true_label: str = None):
        """Update with tokens from one prompt.

        ``token_activations`` is iterated row by row in step with ``token_ids``.
        For every token only its ``FEATURES_PER_TOKEN`` strongest features are
        considered, and only strictly positive activations are recorded.
        """
        if self.top_k <= 0:
            # Nothing can be retained; also avoids indexing an empty heap below.
            return

        for token_pos, (act_vec, token_id) in enumerate(zip(token_activations, token_ids)):
            # For each token, find its strongest features
            top_features = np.argsort(act_vec)[-self.FEATURES_PER_TOKEN:]

            for feat_id in top_features:
                activation = float(act_vec[feat_id])
                if activation <= 0:
                    continue

                heap = self.heaps[feat_id]
                is_full = len(heap) >= self.top_k
                # Ties with the current minimum do not displace it.
                if is_full and activation <= heap[0][0]:
                    continue

                # Fall back to the raw id when no display string exists for this position.
                token_str = prompt_tokens[token_pos] if token_pos < len(prompt_tokens) else f"[{token_id}]"

                metadata = {
                    "activation": activation,
                    "token_str": token_str,
                    "token_id": int(token_id),
                    "token_position": int(token_pos),
                    "prompt_index": int(prompt_idx),
                    "row_id": int(prompt_idx),  # Add row_id for server compatibility
                    "prompt_snippet": prompt_text[:160],
                    "prompt": prompt_text,  # Changed from "full_prompt" to "prompt"
                    "prompt_tokens": prompt_tokens,
                    "predicted_label": predicted_label,  # Add prediction info
                    "true_label": true_label,
                }

                entry = (activation, self._insertion_seq, metadata)
                self._insertion_seq += 1

                if is_full:
                    heapq.heapreplace(heap, entry)
                else:
                    heapq.heappush(heap, entry)

    def export(self):
        """Export top tokens for each feature.

        Returns a dict mapping ``str(feature_id)`` to that feature's metadata
        dicts, strongest activation first (ties in insertion order).
        """
        result = {}
        for feat_id in range(self.feature_dim):
            ordered = sorted(self.heaps[feat_id], key=lambda entry: (-entry[0], entry[1]))
            result[str(feat_id)] = [entry[2] for entry in ordered]
        return result

# Aggregate top features per headline (sample-level view)
class HeadlineFeatureAggregator:
    """Collect, per headline, the features with the highest max activation."""

    def __init__(self, top_k: int = 10):
        self.top_k = top_k
        self.headlines = []  # List of headline metadata with top features

    def add_headline(self, prompt_idx: int, prompt_text: str,
                     token_activations: np.ndarray,
                     token_ids: List[int],
                     token_strings: List[str],
                     predicted_label: str, true_label: str,
                     confidence: float = None):
        """Aggregate features across all tokens in a headline.

        Args:
            prompt_idx: Row index of the headline; stored as ``row_id``.
            prompt_text: Full headline text.
            token_activations: Activations indexed as [token, feature].
            token_ids: Token ids aligned with the rows of ``token_activations``.
            token_strings: Display strings aligned with ``token_ids``.
            predicted_label: Label predicted for the headline.
            true_label: Ground-truth label for the headline.
            confidence: Optional prediction confidence. When given it is stored
                under ``"confidence"``; when omitted the record has no such key.

        Headlines with an empty activation array are skipped.
        """
        if token_activations.size == 0:
            return

        # Per feature: the strongest activation and the token row that produced it.
        max_token_idx_per_feature = token_activations.argmax(axis=0)  # [feature_dim]
        max_activation_per_feature = token_activations.max(axis=0)     # [feature_dim]

        # Top-k features by max activation, strongest first
        top_feature_ids = np.argsort(max_activation_per_feature)[-self.top_k:][::-1]

        features = []
        for fid in top_feature_ids:
            # Features that never fired on this headline are left out.
            if max_activation_per_feature[fid] <= 0:
                continue
            token_pos = int(max_token_idx_per_feature[fid])
            features.append({
                "feature_id": int(fid),
                "max_activation": float(max_activation_per_feature[fid]),
                "token_position": token_pos,
                "token_id": int(token_ids[token_pos]),
                "token_str": token_strings[token_pos],
            })

        record = {
            "row_id": int(prompt_idx),
            "prompt": prompt_text,
            "predicted_label": predicted_label,
            "true_label": true_label,
            "correct": predicted_label == true_label,
            "num_tokens": int(token_activations.shape[0]),
            "features": features
        }
        if confidence is not None:
            record["confidence"] = float(confidence)
        self.headlines.append(record)

    def export(self):
        """Return the list of headline records collected so far."""
        return self.headlines

# Trackers sized to the SAE latent space
feature_stats = FeatureStatsAggregator(feature_dim=SAE_LATENT_DIM)
top_token_tracker = FeatureTopTokenTracker(feature_dim=SAE_LATENT_DIM, top_k=TOP_TOKENS_PER_FEATURE)
headline_aggregator = HeadlineFeatureAggregator(10)

# Per-sample records filled in while samples are processed
all_prompt_metadata, all_prediction_metadata = [], []

# Buffer that the forward hook appends captured layer outputs to
captured_activations = []

def capture_hook(module, input, output):
    """Forward hook: store the layer's (detached) hidden states.

    When the layer returns a tuple, its first element is taken as the hidden
    states; otherwise the output itself is used. The tensor is only detached,
    not moved to another device.
    """
    hidden_states = output[0] if isinstance(output, tuple) else output
    captured_activations.append(hidden_states.detach())

# Attach the capture hook to the encoder layer being analysed
target_layer = model.bert.encoder.layer[LAYER_TO_EXTRACT]
hook_handle = target_layer.register_forward_hook(capture_hook)

# Run banner, emitted as a single print call
print("\n".join([
    f"\nüî¨ Processing {min(MAX_SAMPLES, len(test_ds))} samples...",
    f"   Layer: {LAYER_TO_EXTRACT}",
    f"   Using SAE: {SAE_LATENT_DIM} sparse features",
    "   Filtering: ALL special tokens excluded (content only)\n",
]))

# Process samples
# For each validation sample: run the classifier, take the hooked layer's
# activations for the content tokens, encode them with the SAE and feed the
# result to the three trackers.

# Loop-invariant: the tokenizer's special-token ids do not change between samples.
special_ids = set(tokenizer.all_special_ids)

# try/finally guarantees the forward hook is detached even if a sample fails.
try:
    with torch.no_grad():
        for idx, sample in enumerate(test_ds):
            if idx >= MAX_SAMPLES:
                break

            text = sample["text"]
            true_label = sample["label"]

            # Tokenize with truncation
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH)
            token_ids = inputs["input_ids"][0].tolist()

            # String tokens for display: strip the "##" prefix from subword
            # pieces, keep every other token (including special ones) as-is.
            raw_tokens = tokenizer.convert_ids_to_tokens(token_ids)
            prompt_tokens = [tok[2:] if tok.startswith("##") else tok for tok in raw_tokens]

            # Forward pass (the hook fills captured_activations)
            inputs = inputs.to(device)
            captured_activations.clear()
            outputs = model(**inputs)
            pred_id = outputs.logits.argmax(dim=-1).item()
            pred_label = model.config.id2label[pred_id]

            # Get captured activation and pass through SAE
            if captured_activations:
                # Drop the leading batch dimension - stays on the model's device
                bert_activation = captured_activations[0].squeeze(0)

                attention_mask = inputs["attention_mask"].squeeze(0).bool()

                # Filter out ALL special tokens (CLS, SEP, PAD, UNK, MASK, etc.).
                # Built from the Python id list so no per-token device sync is needed.
                not_special = torch.tensor(
                    [tid not in special_ids for tid in token_ids],
                    dtype=torch.bool, device=device,
                )

                valid_mask = attention_mask & not_special
                bert_activation = bert_activation[valid_mask]

                # Skip if no valid tokens
                if bert_activation.shape[0] == 0:
                    continue

                # Pass the remaining token activations through the SAE
                sae_features = sae.encode(bert_activation)

                # Only now move to CPU for numpy conversion and token filtering
                sae_features_cpu = sae_features.detach().cpu().numpy()
                valid_mask_cpu = valid_mask.cpu().numpy()
                filtered_token_ids = [tid for tid, valid in zip(token_ids, valid_mask_cpu) if valid]
                filtered_prompt_tokens = [tok for tok, valid in zip(prompt_tokens, valid_mask_cpu) if valid]

                seq_len = sae_features_cpu.shape[0]
                true_label_name = model.config.id2label[true_label]

                # Update feature statistics with SAE features
                feature_stats.update(sae_features_cpu)

                # Track top tokens per feature
                top_token_tracker.update(
                    sae_features_cpu,
                    filtered_token_ids,
                    prompt_idx=idx,
                    prompt_text=text,
                    prompt_tokens=filtered_prompt_tokens,
                    predicted_label=pred_label,
                    true_label=true_label_name,
                )

                # Aggregate top features at headline level.
                # No confidence score is computed in this cell, so none is passed.
                headline_aggregator.add_headline(
                    prompt_idx=idx,
                    prompt_text=text,
                    token_activations=sae_features_cpu,
                    token_ids=filtered_token_ids,
                    token_strings=filtered_prompt_tokens,
                    predicted_label=pred_label,
                    true_label=true_label_name,
                )

                # Save prompt metadata
                all_prompt_metadata.append({
                    "row_id": idx,
                    "seq_len": seq_len,
                    "prompt": text,
                    "predicted_label": pred_label,
                    "true_label": true_label_name,
                    "correct": pred_id == true_label
                })

            if (idx + 1) % 10 == 0:
                print(f"Processed {idx + 1}/{min(MAX_SAMPLES, len(test_ds))} samples")
finally:
    # Remove hook
    hook_handle.remove()

# Final per-feature statistics from the aggregator
print("\nüìä Computing feature statistics...")
stats = feature_stats.get_stats()

# Accuracy over the recorded samples (the denominator is clamped to 1 when none were recorded)
num_correct = sum(bool(p["correct"]) for p in all_prompt_metadata)
accuracy = num_correct / max(len(all_prompt_metadata), 1)
print(f"üéØ Model Accuracy: {accuracy:.2%}")

# For every metric, list the TOP_FEATURES highest-valued features in descending order
top_features_by_metric = {}
for metric_name, values in stats.items():
    top_indices = np.argsort(values)[-TOP_FEATURES:][::-1]
    ranked_entries = []
    for feature_idx in top_indices:
        ranked_entries.append({
            "feature_id": int(feature_idx),
            "value": float(values[feature_idx]),
            # Metrics are nested in a sub-dict for server compatibility
            "metrics": {
                "mean_activation": float(stats["mean_activation"][feature_idx]),
                "max_activation": float(stats["max_activation"][feature_idx]),
                "fraction_active": float(stats["fraction_active"][feature_idx]),
            },
        })
    top_features_by_metric[metric_name] = ranked_entries

# Write every artefact of this run into run_dir
print("\nüíæ Saving results...")

# 1. prompts.jsonl - one JSON object per processed sample
prompts_file = run_dir / "prompts.jsonl"
with prompts_file.open("w", encoding="utf-8") as f:
    f.writelines(json.dumps(meta) + "\n" for meta in all_prompt_metadata)

# 2. feature_stats.json - top-feature rankings per metric; "mean_act_squared"
#    is stored as a full per-feature list instead of a ranking
feature_stats_file = run_dir / "feature_stats.json"
metric_entries = {}
for metric_name in stats.keys():
    if metric_name == "mean_act_squared":
        continue
    metric_entries[metric_name] = {
        "description": f"{metric_name.replace('_', ' ').title()} for each feature",
        "top_features": top_features_by_metric[metric_name],
    }
feature_stats_data = {
    "num_features": SAE_LATENT_DIM,
    "total_tokens": feature_stats.total_tokens,
    "top_feature_count": TOP_FEATURES,
    "accuracy": accuracy,
    "num_samples": len(all_prompt_metadata),
    "mean_act_squared": stats["mean_act_squared"].tolist(),
    "metrics": metric_entries,
}
with feature_stats_file.open("w") as f:
    json.dump(feature_stats_data, f, indent=2)

# 3. feature_tokens.json - tracker export wrapped in a "features" key for the server
feature_tokens_file = run_dir / "feature_tokens.json"
feature_tokens_data = {"features": top_token_tracker.export()}
with feature_tokens_file.open("w") as f:
    json.dump(feature_tokens_data, f, indent=2)

# 4. headline_features.json - headline-level feature records
headline_features_file = run_dir / "headline_features.json"
with headline_features_file.open("w") as f:
    json.dump(headline_aggregator.export(), f, indent=2)

# 5. metadata.json - configuration and summary numbers of the run
metadata_file = run_dir / "metadata.json"
run_metadata = {
    "model": save_dir,
    "layer_extracted": LAYER_TO_EXTRACT,
    "num_samples": len(all_prompt_metadata),
    "total_tokens": feature_stats.total_tokens,
    "accuracy": accuracy,
    "dataset": "zeroshot/twitter-financial-news-sentiment",
    "split": "validation",
    "hidden_dim": SAE_INPUT_DIM,
    "latent_dim": SAE_LATENT_DIM,
    "sae_path": f"./finbert_sae/layer_{LAYER_TO_EXTRACT}_{SAE_SIZE}.pt",
    "top_features_per_metric": TOP_FEATURES,
    "top_tokens_per_feature": TOP_TOKENS_PER_FEATURE,
    "note": "SAE sparse features with predictions",
}
with metadata_file.open("w") as f:
    json.dump(run_metadata, f, indent=2)

# Console summary
print("\n‚úÖ COMPLETE!")
print(f"   üìÅ Results saved to: {run_dir.name}")
print(f"   üéØ Accuracy: {accuracy:.2%}")
print(f"   üî¢ Total tokens: {feature_stats.total_tokens}")
print(f"   ‚ú® SAE features: {SAE_LATENT_DIM}")
print("\nüìä Top 5 features by mean activation:")
for rank, feat in enumerate(top_features_by_metric["mean_activation"][:5], 1):
    metrics = feat['metrics']
    print("   {}. Feature {}: mean={:.4f}, max={:.4f}, frac={:.2%}".format(
        rank,
        feat['feature_id'],
        metrics['mean_activation'],
        metrics['max_activation'],
        metrics['fraction_active'],
    ))

print("\nüåê Start the viewer to see results:")
print("   python viz_analysis/feature_probe_server.py")
print("   cd sae-viewer && npm start")


Testing Inference based on Best Model

In [None]:
# Sanity check of the fine-tuned model on a few hand-written headlines
save_dir = "./finbert_twitter_ft/best"

example_sentences = [
    "TSLA beats earnings expectations and raises full-year guidance.",
    "Apple shares fall after reporting weaker-than-expected iPhone sales.",
    "The company reported results largely in line with analyst expectations.",
    "Amazon warns of margin pressure due to rising logistics costs.",
    "NVIDIA stock surges as demand for AI chips remains strong.",
    "The firm announced a restructuring plan, sending shares lower.",
    "Revenue growth slowed quarter-over-quarter, but profitability improved.",
    "Investors remain cautious ahead of the Federal Reserve meeting.",
    "Strong cash flow and reduced debt boosted investor confidence.",
    "The outlook remains uncertain amid macroeconomic headwinds.",
]

# Load the classifier and its tokenizer from the saved checkpoint
tokenizer = AutoTokenizer.from_pretrained(save_dir)
model = AutoModelForSequenceClassification.from_pretrained(save_dir)

# Use the GPU when one is available, otherwise stay on the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model.to(device)
model.eval()

def predict_sentiment(text: str):
    """Return the label name the loaded model assigns to ``text``.

    Tokenizes the text with truncation, runs the module-level ``model`` on
    ``device`` without gradient tracking, and maps the arg-max logit through
    ``model.config.id2label``.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        logits = model(**encoded).logits
    return model.config.id2label[logits.argmax(dim=-1).item()]

# One line per example: upper-cased label padded to 8 characters, then the headline
for text in example_sentences:
    label = predict_sentiment(text)
    print("{:8} | {}".format(label.upper(), text))

In [None]:
# Data Visualisation for Dataset
# Preview the first 200 validation texts; as the cell's last expression the
# slice is what the notebook renders.
test_ds = ds["validation"]  # Use validation set for analysis

test_ds["text"][:200]
#ds2 = load_dataset("zeroshot/twitter-financial-news-sentiment")
#ds2["validation"]["text"][34]

In [None]:
# Inference WITHOUT SAEs - Plain Model Accuracy on Test Data
# import torch
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# from tqdm import tqdm

banner = "=" * 60
print(banner)
print("MODEL INFERENCE WITHOUT SAEs")
print(banner)

# Fine-tuned checkpoint: classifier plus tokenizer
save_dir = "./finbert_twitter_ft/best"
tokenizer = AutoTokenizer.from_pretrained(save_dir)
model = AutoModelForSequenceClassification.from_pretrained(save_dir)

# Run on the GPU when available; eval mode for inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Evaluation data and limits
test_ds = ds["validation"]
# MAX_SAMPLES = len(test_ds)  # Process all samples, or set a limit if needed
MAX_SAMPLES = 100
MAX_SEQ_LENGTH = 64

print("\n".join([
    f"\nüî¨ Running inference on {MAX_SAMPLES} test samples...",
    f"   Device: {device}",
    f"   Model: {save_dir}\n",
]))

# Classify up to MAX_SAMPLES validation samples with the plain model (no SAE,
# no hooks) and report how many predictions match the dataset label.
correct_predictions = 0
total_predictions = 0

# Size the progress bar to the number of samples actually processed; without
# an explicit total, tqdm measures progress against the whole dataset.
num_to_process = min(MAX_SAMPLES, len(test_ds))

# Process samples
with torch.no_grad():
    for idx, sample in enumerate(tqdm(test_ds, desc="Processing", total=num_to_process)):
        if idx >= MAX_SAMPLES:
            break

        text = sample["text"]
        true_label = sample["label"]

        # Tokenize with truncation
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH)
        inputs = inputs.to(device)

        # Forward pass
        outputs = model(**inputs)
        pred_id = outputs.logits.argmax(dim=-1).item()

        # Check if prediction is correct
        if pred_id == true_label:
            correct_predictions += 1
        total_predictions += 1

# Calculate accuracy (0 when nothing was processed)
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

print(f"\n{'=' * 60}")
print(f"‚úÖ INFERENCE COMPLETE (WITHOUT SAEs)")
print(f"{'=' * 60}")
print(f"   üìä Total Samples: {total_predictions}")
print(f"   ‚úì Correct Predictions: {correct_predictions}")
print(f"   ‚úó Incorrect Predictions: {total_predictions - correct_predictions}")
print(f"   üéØ Model Accuracy: {accuracy:.2%}")
print(f"{'=' * 60}")
