In [5]:
# Import necessary libraries
from datasets import load_dataset
# import re
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import json
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
import numpy as np
import evaluate
# import heapq
from typing import List, Dict, Tuple
from collections import defaultdict
import sys
from typing import List

# Put the repository root at the front of sys.path so that the project packages
# (`utils`, `sparse_autoencoder`) can be imported from this notebook.
repo_root = Path(".").resolve()
_repo_root_str = str(repo_root)
if _repo_root_str not in sys.path:
    sys.path.insert(0, _repo_root_str)

# Drop already-imported project submodules from the module cache so the
# imports that follow load them again from disk.
_PROJECT_PREFIXES = ("utils.", "sparse_autoencoder.")
_stale_modules = [name for name in sys.modules if name.startswith(_PROJECT_PREFIXES)]
for _stale_name in _stale_modules:
    del sys.modules[_stale_name]

import utils.finbert
import utils.analysis
import utils.ablation
import utils.run_dirs
import sparse_autoencoder.finbert_sae
import utils.data_cleaning

from utils.finbert import compute_metrics
from utils.analysis import (
    FeatureStatsAggregator,
    FeatureTopTokenTracker,
    HeadlineFeatureAggregator
)
from utils.ablation import (
    create_intervention_hook,
    validate_feature_ids,
    normalize_decoder_weights,
    expand_features_with_similarity,
    run_baseline_inference,
    run_ablation_inference,
    find_flipped_predictions,
)
from utils.run_dirs import make_analysis_run_dir
from sparse_autoencoder.finbert_sae import SparseAutoencoder, load_sae
from utils.data_cleaning import clean_text, remove_leading_tickers

# --------- CUDA sanity check ----------
cuda_ok = torch.cuda.is_available()
print("Torch:", torch.__version__)
print("CUDA available:", cuda_ok)
if cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))

# Device shared by the cells below: GPU when present, otherwise CPU
device = torch.device("cuda") if cuda_ok else torch.device("cpu")

Torch: 2.6.0+cu124
CUDA available: True
GPU: NVIDIA GeForce RTX 3050 Ti Laptop GPU


In [2]:
# 1. Fetch the dataset; it ships with a train and a validation fold
ds = load_dataset("zeroshot/twitter-financial-news-sentiment")


def _apply_clean_text(example):
    """Return the example's text after passing it through `clean_text`."""
    return {"text": clean_text(example["text"])}


def _strip_leading_tickers(example):
    """Return the example's text after passing it through `remove_leading_tickers`."""
    return {"text": remove_leading_tickers(example["text"])}


# Clean every split: general text cleaning first, then leading-ticker removal
for _cleaning_step in (_apply_clean_text, _strip_leading_tickers):
    ds = ds.map(_cleaning_step)

# Split handles used by the rest of the notebook
train_ds = ds["train"]
test_ds = ds["validation"]  # the validation fold doubles as the analysis/"test" set


In [3]:
# --- Hyperparameters for training the SAEs ---
LAYER_TO_EXTRACT = 8  # encoder layer index (0-11 for base BERT), i.e. about 3/4 depth
LATENT_DIMS = [k * 1024 for k in (4, 8, 16, 32)]  # SAE widths: 4k, 8k, 16k, 32k features
L1_COEFFICIENT = 0.001  # weight of the sparsity penalty
LEARNING_RATE = 0.001
BATCH_SIZE = 32
NUM_EPOCHS = 3


# --- Settings for inference / analysis ---
MAX_SAMPLES = 100  # cap on processed samples (kept small for testing)
TOP_FEATURES = 100  # number of top features tracked per metric
TOP_TOKENS_PER_FEATURE = 20  # number of top-activating tokens kept per feature
MAX_SEQ_LENGTH = 64  # maximum sequence length to process
SAE_SIZE = "32k"  # which trained SAE to use: "4k", "8k", "16k" or "32k"

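Train Sparse Autoencoders (SAEs)

This trains one SAE per latent size to decompose FinBERT's 768-dimensional activations into roughly 4k to 32k sparse features that are intended to be interpretable. The cell below loads the fine-tuned checkpoint from ./finbert_twitter_ft/best, so run the "Finetune FinBERT Model" section first if that checkpoint does not exist yet.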


In [None]:
# This cell trains sparse autoencoders (SAEs) on activations of the fine-tuned FinBERT model.
# It needs `train_ds` from the dataset cell and the checkpoint saved in ./finbert_twitter_ft/best.
from torch.utils.data import TensorDataset, DataLoader

# Configuration
LAYER_TO_EXTRACT = 8  # Encoder layer index to hook (0-11 for base BERT)
LATENT_DIMS = [4096, 8192, 16384, 32768]  # Train SAEs with 4k, 8k, 16k, 32k features
L1_COEFFICIENT = 1e-3  # Sparsity penalty
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
NUM_EPOCHS = 3

# Create SAE save directory
Path("./finbert_sae").mkdir(exist_ok=True)

# Load the fine-tuned model
save_dir = "./finbert_twitter_ft/best"
tokenizer = AutoTokenizer.from_pretrained(save_dir)
model = AutoModelForSequenceClassification.from_pretrained(save_dir)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

print(f"Collecting activations from {len(train_ds)} training samples...")
print(f"Target layer: {LAYER_TO_EXTRACT}")
print(f"Will train SAEs with latent dimensions: {LATENT_DIMS}")

# Collect training activations
all_activations = []       # one CPU tensor of content-token activations per sample
captured_activations = []  # filled by the hook during each forward pass


def capture_hook(module, input, output):
    """Forward hook: store the hooked layer's hidden states for the current forward pass."""
    hidden_states = output[0] if isinstance(output, tuple) else output
    captured_activations.append(hidden_states.detach())  # Keep on GPU


# The special-token ids never change, so build the lookup tensor once instead of per sample.
special_ids_tensor = torch.tensor(tokenizer.all_special_ids, dtype=torch.long, device=device)

# Register hook
target_layer = model.bert.encoder.layer[LAYER_TO_EXTRACT]
hook_handle = target_layer.register_forward_hook(capture_hook)

# Collect activations from all training data
print("Extracting activations from training set...")
print("Filtering out ALL special tokens (CLS, SEP, PAD, UNK, MASK, etc.) - keeping only content tokens...")
try:
    with torch.no_grad():
        for sample in tqdm(train_ds):
            # 64 is the same token cap the inference cells use (MAX_SEQ_LENGTH)
            inputs = tokenizer(sample["text"], return_tensors="pt", truncation=True, max_length=64)
            inputs = inputs.to(device)

            captured_activations.clear()
            _ = model(**inputs)

            if not captured_activations:
                continue

            # Drop the batch dimension of the single-sample forward pass
            activation = captured_activations[0].squeeze(0)
            attention_mask = inputs["attention_mask"].squeeze(0).bool()
            token_ids = inputs["input_ids"].squeeze(0)

            # Keep attended positions whose token id is not a special token.
            # torch.isin does the membership test on-device, avoiding one
            # host sync per token.
            valid_mask = attention_mask & ~torch.isin(token_ids, special_ids_tensor)
            activation = activation[valid_mask]

            # Only store samples that still have content tokens; move to CPU for storage
            if activation.shape[0] > 0:
                all_activations.append(activation.cpu())
finally:
    # Always detach the hook, even if extraction fails part-way, so it does not
    # keep firing on later forward passes of `model`.
    hook_handle.remove()

if not all_activations:
    raise RuntimeError("No content-token activations were collected; cannot train SAEs.")

# Flatten all activations into a single tensor [total_tokens, hidden_dim]
all_activations_tensor = torch.cat(all_activations, dim=0)
input_dim = all_activations_tensor.shape[1]  # read from the data rather than hard-coding it
print(f"\nCollected {all_activations_tensor.shape[0]} token activations")
print(f"Activation shape: {all_activations_tensor.shape}")

# The data pipeline is identical for every SAE size, so build it once
dataset = TensorDataset(all_activations_tensor)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Train SAEs for each latent dimension
for LATENT_DIM in LATENT_DIMS:
    print(f"\n{'='*80}")
    print(f"Training SAE with {LATENT_DIM} latent features ({LATENT_DIM//1024}k)")
    print(f"{'='*80}")

    # Create SAE
    sae = SparseAutoencoder(input_dim=input_dim, latent_dim=LATENT_DIM)
    sae.to(device)

    # Optimizer
    optimizer = optim.Adam(sae.parameters(), lr=LEARNING_RATE)

    # Training loop
    print(f"\nTraining SAE for {NUM_EPOCHS} epochs...")
    sae.train()

    for epoch in range(NUM_EPOCHS):
        total_loss = 0.0
        total_recon_loss = 0.0
        total_l1_loss = 0.0

        for (batch_x,) in dataloader:
            batch_x = batch_x.to(device)

            # Forward pass
            reconstruction, latent = sae(batch_x)

            # Loss = reconstruction MSE + L1 sparsity penalty on the latent code
            recon_loss = nn.functional.mse_loss(reconstruction, batch_x)
            l1_loss = latent.abs().mean()
            loss = recon_loss + L1_COEFFICIENT * l1_loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Renormalize decoder weights along dim 0 after every step.
            # NOTE(review): presumably this gives each latent feature a unit-norm
            # decoder direction (nn.Linear weight layout) - confirm in finbert_sae.
            with torch.no_grad():
                sae.decoder.weight.copy_(
                    nn.functional.normalize(sae.decoder.weight, dim=0)
                )

            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_l1_loss += l1_loss.item()

        num_batches = len(dataloader)
        avg_loss = total_loss / num_batches
        avg_recon = total_recon_loss / num_batches
        avg_l1 = total_l1_loss / num_batches

        print(f"Epoch {epoch+1}/{NUM_EPOCHS}: Loss={avg_loss:.4f}, "
              f"Recon={avg_recon:.4f}, L1={avg_l1:.4f}")

    # Save the trained SAE
    SAE_SAVE_PATH = f"./finbert_sae/layer_{LAYER_TO_EXTRACT}_{LATENT_DIM//1024}k.pt"
    print(f"\nSaving trained SAE to {SAE_SAVE_PATH}")
    torch.save({
        'encoder_weight': sae.encoder.weight.data.cpu(),
        'encoder_bias': sae.encoder.bias.data.cpu(),
        'decoder_weight': sae.decoder.weight.data.cpu(),
        'decoder_bias': sae.decoder.bias.data.cpu(),
        'config': {
            'input_dim': input_dim,
            'latent_dim': LATENT_DIM,
            'layer': LAYER_TO_EXTRACT,
            'model': save_dir,
        }
    }, SAE_SAVE_PATH)

    # Test sparsity on the first 1000 collected activations
    sae.eval()
    with torch.no_grad():
        sample_acts = all_activations_tensor[:1000].to(device)
        sample_latent = sae.encode(sample_acts)
        sparsity = (sample_latent > 0).float().mean().item()
        print(f"\n✓ SAE trained successfully!")
        print(f"  Average sparsity: {sparsity:.2%} of features active")
        print(f"  Saved to: {SAE_SAVE_PATH}")

print(f"\n{'='*80}")
print(f"All SAEs trained successfully!")
print(f"Available SAE models:")
for dim in LATENT_DIMS:
    print(f"  - layer_{LAYER_TO_EXTRACT}_{dim//1024}k.pt ({dim} features)")
print(f"\nThese SAEs can now be used in main.py for interpretability analysis!")
print(f"{'='*80}")


Finetune FinBERT Model

The FinBERT model is trained on the training fold of our dataset to improve its prediction accuracy.


In [None]:
# This cell fine-tunes the FinBERT model on the training split and saves the best checkpoint.
# It needs `train_ds`, `test_ds` and `device` from the earlier cells.

# Output locations (the other cells load the model from `best_checkpoint_dir`)
finetune_output_dir = "./finbert_twitter_ft"
best_checkpoint_dir = f"{finetune_output_dir}/best"

# 2) Load model/tokenizer
model_name = "ahmedrachid/FinancialBERT-Sentiment-Analysis"
tokenizer = AutoTokenizer.from_pretrained(model_name)

id2label = {0: "Bearish", 1: "Bullish", 2: "Neutral"}
label2id = {v: k for k, v in id2label.items()}

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=3,
    id2label=id2label,
    label2id=label2id,
)

# Move model to GPU
model.to(device)


# 3) Tokenize
def tokenize_fn(batch):
    """Tokenize a batch of examples' "text" field with truncation enabled."""
    return tokenizer(batch["text"], truncation=True)


train_tok = train_ds.map(tokenize_fn, batched=True)
val_tok = test_ds.map(tokenize_fn, batched=True)

# The label column is renamed to "labels" before being handed to the Trainer
train_tok = train_tok.rename_column("label", "labels")
val_tok = val_tok.rename_column("label", "labels")

cols_to_keep = ["input_ids", "attention_mask", "labels"]
train_tok.set_format(type="torch", columns=cols_to_keep)
val_tok.set_format(type="torch", columns=cols_to_keep)

# Pads each batch dynamically instead of padding the whole dataset up front
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# 4) Metrics: computed by `compute_metrics` imported from utils.finbert

# 5) Training config
use_fp16 = torch.cuda.is_available()  # fp16 only makes sense on GPU

training_args = TrainingArguments(
    output_dir=finetune_output_dir,
    eval_strategy="epoch",   # older transformers releases name this argument `evaluation_strategy`
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    weight_decay=0.01,
    logging_steps=50,
    load_best_model_at_end=True,
    # NOTE(review): must match a key reported by compute_metrics - presumably
    # "macro_f1"; verify in utils/finbert.py.
    metric_for_best_model="macro_f1",
    fp16=use_fp16,                 # <-- enables mixed precision on NVIDIA GPU
    dataloader_num_workers=0,      # safer on Windows; avoids hanging
    report_to="none",              # avoids needing wandb, etc.
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

# Keep and show the final evaluation metrics; a bare `trainer.evaluate()` in the
# middle of a cell would compute them and then silently throw them away.
eval_metrics = trainer.evaluate()
print(eval_metrics)

trainer.save_model(best_checkpoint_dir)
tokenizer.save_pretrained(best_checkpoint_dir)


Inference with Interpretability

We run the fine-tuned FinBERT model together with an SAE on held-out data (the dataset's validation split). Layer activations are extracted alongside the model's sentiment predictions (SAE-style analysis).

In [4]:
# Feature Ablation Experiment
# This cell configures the ablation run (zeroing out specified SAE features and
# comparing predictions), loads the model and SAE, and creates the trackers that
# the following cells fill in.

# Configuration
# ========== ABLATION CONFIGURATION ==========
ABLATION_CONFIG = {
    
    "mode": "manual",                 # Mode 1, Options: "manual" | "per_sample_top_k" | "union_top_k"
    #"mode": "per_sample_top_k",       # Mode 2
    #"mode": "union_top_k",             # Mode 3
    "k": 10,                           # Only for per_sample_top_k and union_top_k modes
    "skip_sae_reconstruction": False,  # If True, skip SAE hook entirely (true baseline)
    "similarity_expansion": False,      # If True, include top m similar features to selected features
    "similarity_top_m": 64            # Total features per seed feature (includes original)
}

# Mode 1: Manual features (only used if mode == "manual")
#MANUAL_FEATURES = [4456, 21508, 21969, 27518, 21110, 24583, 32601, 15959, 27518, 29555, 3993, 13142, 22354, 21858]
# MANUAL_FEATURES = [21110, 24583, 21508, 32601, 15959, 27518, 29555, 3993, 13142, 22354] # row 0 true top 10
MANUAL_FEATURES = [21508, 27518]
#MANUAL_FEATURES = []

print("=" * 60)
print("FEATURE ABLATION EXPERIMENT")
print("=" * 60)
print(f"Ablation Mode: {ABLATION_CONFIG['mode']}")
if ABLATION_CONFIG['mode'] != 'manual':
    print(f"K value: {ABLATION_CONFIG['k']}")
print(f"Similarity Expansion: {ABLATION_CONFIG['similarity_expansion']}")
if ABLATION_CONFIG['similarity_expansion']:
    print(f"Similarity Top-M: {ABLATION_CONFIG['similarity_top_m']}")

# Load model and tokenizer
save_dir = "./finbert_twitter_ft/best"
tokenizer = AutoTokenizer.from_pretrained(save_dir)
model = AutoModelForSequenceClassification.from_pretrained(save_dir)

# Define device and move model to it.
# torch.device() only accepts real device strings, so the CPU fallback must be
# "cpu"; the hint about installing a CUDA build is printed instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    print("CUDA not available - running on CPU only (install a CUDA-compatible Torch build to use the GPU).")
model.to(device)
model.eval()

# Load the SAE using the helper function
sae, sae_config = load_sae(layer=LAYER_TO_EXTRACT, latent_size=SAE_SIZE)

# Extract dimensions from the loaded config
SAE_INPUT_DIM = sae_config['input_dim']
SAE_LATENT_DIM = sae_config['latent_dim']

# (The ablation mode was already printed in the header above.)
print(f"Layer: {LAYER_TO_EXTRACT}")
print(f"SAE Size: {SAE_SIZE} ({SAE_LATENT_DIM} features)")
print(f"Max Samples: {MAX_SAMPLES}\n")

# Storage for results
baseline_predictions = []
ablated_predictions = []
sample_data = []

# Trackers for SAE feature statistics gathered during the ablated run
feature_stats_ablated = FeatureStatsAggregator(SAE_LATENT_DIM)
top_token_tracker_ablated = FeatureTopTokenTracker(SAE_LATENT_DIM, TOP_TOKENS_PER_FEATURE)
headline_aggregator_ablated = HeadlineFeatureAggregator(top_k=10)
all_prompt_metadata_ablated = []

# Per-sample scratch dict handed to run_ablation_inference (for tracking)
current_sample_data = {"sae_features": None, "token_ids": None, "prompt_tokens": None, "text": None, "idx": None}


FEATURE ABLATION EXPERIMENT
Ablation Mode: manual
Similarity Expansion: False
‚úì Loaded SAE from ./finbert_sae/layer_8_32k.pt
  Layer: 8
  Input dim: 768
  Latent dim: 32768
Ablation mode: manual
Layer: 8
SAE Size: 32k (32768 features)
Max Samples: 100



In [5]:
# Baseline pass: call run_baseline_inference with the SAE attached and nothing ablated
print("üî¨ Running baseline inference (no ablation)...")
baseline_kwargs = dict(
    model=model,
    tokenizer=tokenizer,
    test_ds=test_ds,
    device=device,
    sae=sae,
    layer_to_extract=LAYER_TO_EXTRACT,
    max_samples=MAX_SAMPLES,
    max_seq_length=MAX_SEQ_LENGTH,
)
baseline_outputs = run_baseline_inference(**baseline_kwargs)
# The helper returns (per-sample results, per-sample feature map, accuracy)
baseline_results, baseline_features_map, baseline_accuracy = baseline_outputs
print(f"‚úì Baseline accuracy: {baseline_accuracy:.2%}\n")


üî¨ Running baseline inference (no ablation)...
  Baseline: 20/100 samples
  Baseline: 40/100 samples
  Baseline: 60/100 samples
  Baseline: 80/100 samples
  Baseline: 100/100 samples
‚úì Baseline accuracy: 87.00%



In [6]:
# Feature Selection & Expansion
# Decides which SAE features get zeroed in the ablation pass. Results are left in
# FEATURES_TO_ABLATE (None => chosen per sample, or SAE hook skipped) and
# ORIGINAL_SEED_FEATURES (the manual seeds before any similarity expansion).


def _expand_if_enabled(seed_ids):
    """Return the ablation list for `seed_ids`.

    When similarity expansion is off this is a copy of `seed_ids` (so the caller's
    list is never aliased); otherwise it is whatever
    `expand_features_with_similarity` returns for the seeds.
    """
    if not similarity_enabled:
        return list(seed_ids)
    return expand_features_with_similarity(
        seed_ids, normalized_decoder, similarity_top_m, similarity_cache
    )


def _report_expansion(seed_count, final_count):
    """Print the before/after feature counts when similarity expansion is on."""
    if similarity_enabled:
        print(f"  Similarity expansion: {seed_count} → {final_count} (top_m={similarity_top_m})")


skip_hooks = ABLATION_CONFIG.get("skip_sae_reconstruction", False)

similarity_enabled = ABLATION_CONFIG.get("similarity_expansion", False)
similarity_top_m = int(ABLATION_CONFIG.get("similarity_top_m", 10))
if similarity_top_m < 1:
    raise ValueError("similarity_top_m must be >= 1")

normalized_decoder = None
similarity_cache = {}
if similarity_enabled and skip_hooks:
    # Nothing is ablated when the SAE hook is skipped, so expansion is pointless
    print("Similarity expansion requested, but skip_sae_reconstruction=True; expansion disabled.")
    similarity_enabled = False
elif similarity_enabled:
    normalized_decoder = normalize_decoder_weights(sae, device)
    print(f"Similarity expansion: enabled (top_m={similarity_top_m})")

FEATURES_TO_ABLATE = None
ORIGINAL_SEED_FEATURES = None

if skip_hooks:
    print("\nℹ️ Skipping SAE reconstruction (true baseline mode)")
    print("   Predictions will match baseline exactly\n")
else:
    ablation_mode = ABLATION_CONFIG["mode"]

    if ablation_mode == "manual":
        # Work on copies so MANUAL_FEATURES itself is never modified downstream
        ORIGINAL_SEED_FEATURES = list(MANUAL_FEATURES)
        validate_feature_ids(ORIGINAL_SEED_FEATURES, SAE_LATENT_DIM, "manual features")
        FEATURES_TO_ABLATE = _expand_if_enabled(ORIGINAL_SEED_FEATURES)
        print(f"\nMode 1 (Manual): Ablating {len(FEATURES_TO_ABLATE)} manually specified features")
        _report_expansion(len(ORIGINAL_SEED_FEATURES), len(FEATURES_TO_ABLATE))
    elif ablation_mode == "union_top_k":
        # Union of every sample's top-k baseline features
        top_k = ABLATION_CONFIG["k"]
        feature_set = set()
        for idx in baseline_features_map:
            sample_top = baseline_features_map[idx]["top_features"][:top_k]
            feature_set.update(f["feature_id"] for f in sample_top)
        union_features = sorted(feature_set)
        validate_feature_ids(union_features, SAE_LATENT_DIM, "union_top_k features")
        FEATURES_TO_ABLATE = _expand_if_enabled(union_features)
        print(f"\nMode 3 (Union Top-K): Collected {len(FEATURES_TO_ABLATE)} unique features from union of top-{top_k} across {len(baseline_results)} samples")
        _report_expansion(len(union_features), len(FEATURES_TO_ABLATE))
    elif ablation_mode == "per_sample_top_k":
        # Features are picked per sample during the ablation pass itself
        print(f"\nMode 2 (Per-Sample Top-K): Will ablate top-{ABLATION_CONFIG['k']} features individually for each sample")
    else:
        raise ValueError(f"Unknown ablation mode: {ablation_mode}")

    if ablation_mode != "per_sample_top_k":
        # Re-validate the final list, which may now include expanded features
        validate_feature_ids(FEATURES_TO_ABLATE, SAE_LATENT_DIM, "global ablation features")

    if ablation_mode == "manual":
        # Only the (short) manual list is echoed
        print(f"Features to ablate: {FEATURES_TO_ABLATE}\n")



Mode 1 (Manual): Ablating 2 manually specified features
Features to ablate: [21508, 27518]



In [7]:
# Ablation pass: same inference as the baseline, with the selected SAE features ablated
print("üî¨ Running ablation inference (features zeroed)...")
ablation_kwargs = dict(
    # model / data
    model=model,
    tokenizer=tokenizer,
    test_ds=test_ds,
    device=device,
    sae=sae,
    layer_to_extract=LAYER_TO_EXTRACT,
    max_samples=MAX_SAMPLES,
    max_seq_length=MAX_SEQ_LENGTH,
    # what to ablate
    ablation_config=ABLATION_CONFIG,
    features_to_ablate=FEATURES_TO_ABLATE,
    # baseline outputs
    baseline_results=baseline_results,
    baseline_features_map=baseline_features_map,
    # trackers created in the setup cell
    feature_stats_ablated=feature_stats_ablated,
    top_token_tracker_ablated=top_token_tracker_ablated,
    headline_aggregator_ablated=headline_aggregator_ablated,
    current_sample_data=current_sample_data,
    # similarity-expansion state from the feature-selection cell
    similarity_enabled=similarity_enabled,
    normalized_decoder=normalized_decoder,
    similarity_top_m=similarity_top_m,
    similarity_cache=similarity_cache,
)
ablation_outputs = run_ablation_inference(**ablation_kwargs)
(
    ablated_results,
    all_prompt_metadata_ablated,
    similarity_stats,
    all_ablated_features_set,
    ablated_accuracy,
) = ablation_outputs
print(f"‚úì Ablated accuracy: {ablated_accuracy:.2%}\n")


üî¨ Running ablation inference (features zeroed)...
  Ablated: 20/100 samples
  Ablated: 40/100 samples
  Ablated: 60/100 samples
  Ablated: 80/100 samples
  Ablated: 100/100 samples
‚úì Ablated accuracy: 88.00%



In [8]:
# Results Analysis
# Compares baseline and ablated runs: accuracy change, flip rate, and details for
# the first few samples whose predicted label changed.
MAX_FLIPS_SHOWN = 10

flipped_samples = find_flipped_predictions(
    model=model,
    tokenizer=tokenizer,
    device=device,
    sae=sae,
    layer_to_extract=LAYER_TO_EXTRACT,
    max_seq_length=MAX_SEQ_LENGTH,
    baseline_results=baseline_results,
    ablated_results=ablated_results,
)

print("=" * 60)
print("FEATURE ABLATION RESULTS")
print("=" * 60)

# `not skip_hooks` also covers falsy non-bool config values (None, 0)
if not skip_hooks:
    ablation_mode = ABLATION_CONFIG["mode"]
    print(f"Ablation Mode: {ablation_mode}")

    if ablation_mode == "manual" and ORIGINAL_SEED_FEATURES is not None:
        print(f"Seed Features: {ORIGINAL_SEED_FEATURES}")
        print(f"Total Features Ablated: {len(FEATURES_TO_ABLATE)}")
        if similarity_enabled and len(FEATURES_TO_ABLATE) != len(ORIGINAL_SEED_FEATURES):
            print(f"  (Expanded from {len(ORIGINAL_SEED_FEATURES)} seeds)")

    elif ablation_mode == "union_top_k" and FEATURES_TO_ABLATE is not None:
        print(f"Total Features Ablated: {len(FEATURES_TO_ABLATE)}")
        if similarity_enabled:
            print("  (After similarity expansion)")

    elif ablation_mode == "per_sample_top_k":
        print(f"Per-Sample: top-{ABLATION_CONFIG['k']} features")
        if len(all_ablated_features_set) > 0:
            print(f"Unique Features Ablated: {len(all_ablated_features_set)} (across {len(ablated_results)} samples)")
            if similarity_stats["expanded_counts"]:
                avg_per_sample = float(np.mean(similarity_stats["expanded_counts"]))
                print(f"Per-Sample Average: {avg_per_sample:.1f} features")
            else:
                print(f"Per-Sample Average: {ABLATION_CONFIG['k']} features")

num_samples = len(baseline_results)
num_flipped = len(flipped_samples)
# Guard against an empty run (e.g. MAX_SAMPLES = 0) instead of dividing by zero
flip_rate = num_flipped / num_samples if num_samples else 0.0

print(f"Baseline Accuracy: {baseline_accuracy:.2%}")
print(f"Ablated Accuracy: {ablated_accuracy:.2%}")
# Explicit sign so gains and losses are unambiguous
print(f"Accuracy Change: {(ablated_accuracy - baseline_accuracy):+.2%}")
print(f"\nFlipped Predictions: {num_flipped}/{num_samples} samples")
print(f"Flip Rate: {flip_rate:.2%}\n")

if flipped_samples:
    print("=" * 60)
    print(f"FLIPPED PREDICTIONS (showing first {min(MAX_FLIPS_SHOWN, num_flipped)}):")
    print("=" * 60)

    for flip in flipped_samples[:MAX_FLIPS_SHOWN]:
        text = flip['text']
        print(f"\n--- Sample #{flip['sample_idx']} ---")
        print(f"Text: {text[:120]}{'...' if len(text) > 120 else ''}")
        print(f"True Label: {flip['true_label']}")
        print(f"Original: {flip['baseline_pred']} (conf: {flip['baseline_conf']:.3f}) → "
              f"Ablated: {flip['ablated_pred']} (conf: {flip['ablated_conf']:.3f})")

        if flip['top_features']:
            print("Top 10 SAE Features:")
            for feat in flip['top_features']:
                ablated_marker = " [ABLATED]" if feat['ablated'] else ""
                print(f"  Feature {feat['feature_id']}: {feat['activation']:.4f}{ablated_marker}")
        print()
else:
    print("No predictions were flipped by ablating these features.")


FEATURE ABLATION RESULTS
Ablation Mode: manual
Seed Features: [21508, 27518]
Total Features Ablated: 2
Baseline Accuracy: 87.00%
Ablated Accuracy: 88.00%
Accuracy Change: 1.00%

Flipped Predictions: 3/100 samples
Flip Rate: 3.00%

FLIPPED PREDICTIONS (showing first 3):

--- Sample #5 ---
Text: Barclays cools on Molson Coors
True Label: Bearish
Original: Bearish (conf: 0.648) ‚Üí Ablated: Neutral (conf: 0.452)
Top 10 SAE Features:
  Feature 4456: 8.8932
  Feature 32601: 6.1672
  Feature 15991: 4.6884
  Feature 21508: 4.4652 [ABLATED]
  Feature 29952: 4.4536
  Feature 5111: 4.2631
  Feature 28660: 4.2182
  Feature 7927: 4.2108
  Feature 687: 4.1628
  Feature 27757: 4.0974


--- Sample #38 ---
Text: Alliance Global Partners starts at Buy
True Label: Bullish
Original: Neutral (conf: 0.528) ‚Üí Ablated: Bullish (conf: 0.504)
Top 10 SAE Features:
  Feature 4456: 9.0261
  Feature 21110: 6.6836
  Feature 25797: 6.2264
  Feature 24583: 4.7995
  Feature 18425: 4.3438
  Feature 687: 4.2223
  Feat

In [9]:
# Save Results
# Writes the ablated run's prompts, feature statistics, top tokens, headline
# features and metadata into a fresh analysis run directory.


def _write_json(path, payload):
    """Serialize `payload` to `path` as indented JSON, always using UTF-8."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


print("\n💾 Saving ablated results for visualization...")

ablated_run_dir = make_analysis_run_dir(str(repo_root))
print(f"💾 Saving ablated results to: {ablated_run_dir}")

stats_ablated = feature_stats_ablated.get_stats()

# "mean_act_squared" is stored in full further down and is not ranked
ranked_metric_names = [name for name in stats_ablated if name != "mean_act_squared"]

# For each ranked metric keep the TOP_FEATURES highest-valued features, best first
top_features_by_metric_ablated = {}
for metric_name in ranked_metric_names:
    values = stats_ablated[metric_name]
    top_indices = np.argsort(values)[-TOP_FEATURES:][::-1]
    top_features_by_metric_ablated[metric_name] = [
        {
            "feature_id": int(idx),
            "value": float(values[idx]),
            "metrics": {
                "mean_activation": float(stats_ablated["mean_activation"][idx]),
                "max_activation": float(stats_ablated["max_activation"][idx]),
                "fraction_active": float(stats_ablated["fraction_active"][idx]),
            },
        }
        for idx in top_indices
    ]

# One JSON object per line
prompts_file = ablated_run_dir / "prompts.jsonl"
with open(prompts_file, "w", encoding="utf-8") as f:
    for meta in all_prompt_metadata_ablated:
        json.dump(meta, f)
        f.write("\n")

feature_stats_file = ablated_run_dir / "feature_stats.json"
feature_stats_data = {
    "num_features": SAE_LATENT_DIM,
    "total_tokens": feature_stats_ablated.total_tokens,
    "top_feature_count": TOP_FEATURES,
    "accuracy": ablated_accuracy,
    "num_samples": len(all_prompt_metadata_ablated),
    "mean_act_squared": stats_ablated["mean_act_squared"].tolist(),
    "metrics": {
        metric_name: {
            "description": f"{metric_name.replace('_', ' ').title()} for each feature",
            "top_features": top_features_by_metric_ablated[metric_name],
        }
        for metric_name in ranked_metric_names
    },
}
_write_json(feature_stats_file, feature_stats_data)

feature_tokens_file = ablated_run_dir / "feature_tokens.json"
feature_tokens_data = {"features": top_token_tracker_ablated.export()}
_write_json(feature_tokens_file, feature_tokens_data)

headline_features_file = ablated_run_dir / "headline_features.json"
_write_json(headline_features_file, headline_aggregator_ablated.export())

metadata_file = ablated_run_dir / "metadata.json"
_write_json(
    metadata_file,
    {
        "model": save_dir,
        "layer_extracted": LAYER_TO_EXTRACT,
        "num_samples": len(all_prompt_metadata_ablated),
        "total_tokens": feature_stats_ablated.total_tokens,
        "accuracy": ablated_accuracy,
        "dataset": "zeroshot/twitter-financial-news-sentiment",
        "split": "validation",
        "hidden_dim": SAE_INPUT_DIM,
        "latent_dim": SAE_LATENT_DIM,
        "sae_path": f"./finbert_sae/layer_{LAYER_TO_EXTRACT}_{SAE_SIZE}.pt",
        "top_features_per_metric": TOP_FEATURES,
        "top_tokens_per_feature": TOP_TOKENS_PER_FEATURE,
        "ablation_mode": ABLATION_CONFIG["mode"],
        # None means the features were chosen per sample (or the SAE hook was skipped)
        "ablated_features": FEATURES_TO_ABLATE if FEATURES_TO_ABLATE is not None else "per_sample_dynamic",
        "ablation_k": ABLATION_CONFIG.get("k"),
        "skip_sae_reconstruction": ABLATION_CONFIG.get("skip_sae_reconstruction", False),
        "similarity_expansion": {"enabled": similarity_enabled, "top_m": similarity_top_m},
        "note": f"SAE sparse features with predictions (mode: {ABLATION_CONFIG['mode']})",
    },
)

print("\n✅ Ablation experiment complete!")
print(f"   📁 Ablated results saved to: {ablated_run_dir.name}")
print(f"   🎯 Ablated Accuracy: {ablated_accuracy:.2%}")
print(f"   🔢 Total tokens: {feature_stats_ablated.total_tokens}")
print(f"   ✨ SAE features: {SAE_LATENT_DIM}")
print("\n🌐 Start the viewer to see ablated results:")
print("   python viz_analysis/feature_probe_server.py")
print("   cd sae-viewer && npm start")



üíæ Saving ablated results for visualization...
üíæ Saving ablated results to: C:\Users\andre\OneDrive - National University of Singapore\Desktop\FYP\sparse_autoencoder_openai\analysis_data\2026-02-03T15-07-52_run-091

‚úÖ Ablation experiment complete!
   üìÅ Ablated results saved to: 2026-02-03T15-07-52_run-091
   üéØ Ablated Accuracy: 88.00%
   üî¢ Total tokens: 1542
   ‚ú® SAE features: 32768

üåê Start the viewer to see ablated results:
   python viz_analysis/feature_probe_server.py
   cd sae-viewer && npm start


Testing Inference based on Best Model

In [24]:
# Sanity check: classify a handful of hand-written headlines with the saved model
save_dir = "./finbert_twitter_ft/best"

example_sentences = [
    "TSLA beats earnings expectations and raises full-year guidance.",
    "Apple shares fall after reporting weaker-than-expected iPhone sales.",
    "The company reported results largely in line with analyst expectations.",
    "Amazon warns of margin pressure due to rising logistics costs.",
    "NVIDIA stock surges as demand for AI chips remains strong.",
    "The firm announced a restructuring plan, sending shares lower.",
    "Revenue growth slowed quarter-over-quarter, but profitability improved.",
    "Investors remain cautious ahead of the Federal Reserve meeting.",
    "Strong cash flow and reduced debt boosted investor confidence.",
    "The outlook remains uncertain amid macroeconomic headwinds."
]

tokenizer = AutoTokenizer.from_pretrained(save_dir)
model = AutoModelForSequenceClassification.from_pretrained(save_dir)

# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()


def predict_sentiment(text: str):
    """Return the label name (from `model.config.id2label`) predicted for `text`."""
    encoded = tokenizer(text, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        logits = model(**encoded).logits
    best_class = logits.argmax(dim=-1).item()
    return model.config.id2label[best_class]


# Predictions are produced lazily, one headline at a time, in list order
predicted_labels = map(predict_sentiment, example_sentences)
for label, text in zip(predicted_labels, example_sentences):
    print(f"{label.upper():8} | {text}")

BULLISH  | TSLA beats earnings expectations and raises full-year guidance.
BEARISH  | Apple shares fall after reporting weaker-than-expected iPhone sales.
NEUTRAL  | The company reported results largely in line with analyst expectations.
BEARISH  | Amazon warns of margin pressure due to rising logistics costs.
BULLISH  | NVIDIA stock surges as demand for AI chips remains strong.
BEARISH  | The firm announced a restructuring plan, sending shares lower.
NEUTRAL  | Revenue growth slowed quarter-over-quarter, but profitability improved.
NEUTRAL  | Investors remain cautious ahead of the Federal Reserve meeting.
BULLISH  | Strong cash flow and reduced debt boosted investor confidence.
BEARISH  | The outlook remains uncertain amid macroeconomic headwinds.


In [25]:
# Dataset preview: display the first 200 cleaned texts of the validation split
test_ds = ds["validation"]

test_ds["text"][:200]

['Ally Financial pulls outlook',
 'Dell, HPE targets trimmed on compute headwinds',
 "Moody's turns negative on Party City",
 'Deutsche Bank cuts to Hold',
 'Compass Point cuts to Sell',
 'Barclays cools on Molson Coors',
 'Barclays cuts to Equal Weight',
 'Analysts Eviscerate Musk\'s Cybertruck: "0% Of Responses Felt It Will Be A Success"',
 'Barclays assigns only a 20% chance that studies on a Gilead antiviral drug being done in China will succeed against‚Ä¶',
 "BTIG points to breakfast pressure for Dunkin' Brands",
 "Children's Place downgraded to neutral from outperform at Wedbush, price target slashed to $60 from $130",
 'Clovis Oncology downgraded to in line from outperform at Evercore ISI',
 'Downgrades 4/7: $AAN $BDN $BECN $BTE $CDEV $CHK $COOP $CPE $CVA $DAN $DOC $DRH $EPR $ESRT $ETM $FAST $FBM $GM $GMS‚Ä¶',
 "Goldman pulls Progressive from Goldman's conviction list; shares -2.7%",
 'Hanesbrands downgraded to underperform vs. neutral at BofA Merrill Lynch',
 'Intelsat cut to M

In [26]:
# Inference WITHOUT SAEs - plain model accuracy on the validation split
# Serves as a reference point for the SAE/ablation accuracies reported above.

print("=" * 60)
print("MODEL INFERENCE WITHOUT SAEs")
print("=" * 60)

# Load the fine-tuned model
save_dir = "./finbert_twitter_ft/best"
tokenizer = AutoTokenizer.from_pretrained(save_dir)
model = AutoModelForSequenceClassification.from_pretrained(save_dir)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Use validation set for evaluation
test_ds = ds["validation"]

# Cell-local sample cap: reads the notebook-wide MAX_SAMPLES without overwriting
# it, so re-running earlier cells is unaffected. Set to None for the whole split.
eval_limit = MAX_SAMPLES
num_eval_samples = len(test_ds) if eval_limit is None else min(eval_limit, len(test_ds))

# Selecting the subset up front lets the progress bar show the real workload
# instead of "N / full split size" followed by an early break.
eval_subset = test_ds.select(range(num_eval_samples))

print(f"\n🔬 Running inference on {num_eval_samples} test samples...")
print(f"   Device: {device}")
print(f"   Model: {save_dir}\n")

correct_predictions = 0
total_predictions = 0

# Process samples
with torch.no_grad():
    for sample in tqdm(eval_subset, desc="Processing"):
        text = sample["text"]
        true_label = sample["label"]

        # Tokenize with truncation (same cap as the SAE/ablation cells)
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH)
        inputs = inputs.to(device)

        # Forward pass
        outputs = model(**inputs)
        pred_id = outputs.logits.argmax(dim=-1).item()

        # Check if prediction is correct
        if pred_id == true_label:
            correct_predictions += 1
        total_predictions += 1

# Calculate accuracy (0 when nothing was evaluated)
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

print(f"\n{'=' * 60}")
print(f"✅ INFERENCE COMPLETE (WITHOUT SAEs)")
print(f"{'=' * 60}")
print(f"   📊 Total Samples: {total_predictions}")
print(f"   ✓ Correct Predictions: {correct_predictions}")
print(f"   ✗ Incorrect Predictions: {total_predictions - correct_predictions}")
print(f"   🎯 Model Accuracy: {accuracy:.2%}")
print(f"{'=' * 60}")


MODEL INFERENCE WITHOUT SAEs

üî¨ Running inference on 100 test samples...
   Device: cuda
   Model: ./finbert_twitter_ft/best



Processing:   4%|‚ñç         | 100/2388 [00:01<00:25, 91.45it/s]


‚úÖ INFERENCE COMPLETE (WITHOUT SAEs)
   üìä Total Samples: 100
   ‚úì Correct Predictions: 87
   ‚úó Incorrect Predictions: 13
   üéØ Model Accuracy: 87.00%



