In [4]:
# --- 1. Imports and Setup ---
import sys
import os
import random
import numpy as np
from dataclasses import dataclass
from collections import defaultdict
import matplotlib.pyplot as plt

# Add project root to path and import local modules
sys.path.append('../../')
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from models import ColorMNISTCNN
from data import BiasedColorizedMNIST, UnbiasedColorizedMNIST, CNNActivationDatasetWithColors
from train_models import train_model

def set_seed(seed_value=42):
    """Sets the seed for reproducibility."""
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)

# --- 2. New SCAR-Specific Model and Training Functions ---

class ScarSAE(nn.Module):
    """
    An SAE implementation that follows the SCAR paper's architecture.
    It includes a specific forward pass for steering.
    """
    def __init__(self, input_dim, latent_dim, k, steer_idx=0):
        super().__init__()
        self.encoder = nn.Linear(input_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, input_dim, bias=False)
        self.k = k
        self.steer_idx = steer_idx # The neuron we will condition and steer

        with torch.no_grad():
            self.decoder.weight.data = F.normalize(self.decoder.weight.data, p=2, dim=1)

    def forward(self, x, alpha=1.0):
        # Encode
        pre_activations = self.encoder(x) # This is 'h' in the paper
        
        # Sparsify with TopK + ReLU
        topk_values, _ = torch.topk(pre_activations, self.k, dim=-1)
        threshold = topk_values[:, -1].unsqueeze(-1)
        sparse_activations = F.relu(pre_activations - threshold) # This is 'f' in the paper
        
        # --- SCAR Steering Logic (from Eq. 7) ---
        # If steering, replace the activated value of the conditioned neuron
        # with the scaled pre-activation value.
        if alpha != 1.0:
            sparse_activations[:, self.steer_idx] = alpha * pre_activations[:, self.steer_idx]

        # Decode
        reconstruction = self.decoder(sparse_activations)
        return reconstruction, pre_activations

# --- 3. Main Experiment Script ---
# Configuration
@dataclass
class SCAR_Config:
    seed: int = 42
    device: str = 'cuda'
    cnn_lr: float = 0.001
    cnn_epochs: int = 50
    sae_lr: float = 1e-4 # SCAR paper uses a smaller LR
    sae_steps: int = 15000
    batch_size: int = 512
    concepts: int = 1024
    k: int = 128

# Setup
config = SCAR_Config()
set_seed(config.seed)
device = config.device
print(f"Using device: {device}")

# Data Loading
print("--- Loading Datasets ---")
biased_train_dataset = BiasedColorizedMNIST('../../colorized-MNIST/training')
unbiased_val_dataset = UnbiasedColorizedMNIST('../../colorized-MNIST/testing')
train_loader = DataLoader(biased_train_dataset, batch_size=config.batch_size, shuffle=True)
val_loader = DataLoader(unbiased_val_dataset, batch_size=config.batch_size, shuffle=False)
print("✅ Datasets loaded.")

# Train Biased CNN
print("\n--- 🧠 Training Baseline CNN on Biased Data ---")
biased_model = ColorMNISTCNN(input_size=28).to(device)
train_model(biased_model, train_loader, val_loader, num_epochs=config.cnn_epochs, learning_rate=config.cnn_lr, device=device)
print("✅ Baseline CNN training complete.")


Using device: cuda
--- Loading Datasets ---
Loaded 2957 biased images

Dataset Statistics:
------------------------------
Digit 0 (red): 298 images
Digit 1 (red): 329 images
Digit 2 (red): 298 images
Digit 3 (red): 311 images
Digit 4 (red): 258 images
Digit 5 (green): 258 images
Digit 6 (green): 302 images
Digit 7 (green): 328 images
Digit 8 (green): 282 images
Digit 9 (green): 293 images

Total red images (digits 0-4): 1494
Total green images (digits 5-9): 1463
Total images: 2957
Loaded 6711 unbiased images

Unbiased Dataset Statistics:
----------------------------------------
Per-digit breakdown:
  Digit 0: 321 red, 324 green (total: 645)
  Digit 1: 377 red, 392 green (total: 769)
  Digit 2: 353 red, 320 green (total: 673)
  Digit 3: 350 red, 326 green (total: 676)
  Digit 4: 363 red, 315 green (total: 678)
  Digit 5: 301 red, 291 green (total: 592)
  Digit 6: 322 red, 324 green (total: 646)
  Digit 7: 345 red, 360 green (total: 705)
  Digit 8: 308 red, 338 green (total: 646)
  Digit


poch 1/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.83it/s, Loss=2.2269, Acc=17.35%]

Epoch [1/50]:
  Train Loss: 2.2760, Train Acc: 17.35%
  Val Loss: 2.3046, Val Acc: 15.42%
--------------------------------------------------



poch 2/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.86it/s, Loss=1.9923, Acc=22.83%]

Epoch [2/50]:
  Train Loss: 2.0939, Train Acc: 22.83%
  Val Loss: 2.5626, Val Acc: 11.13%
--------------------------------------------------



poch 3/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.86it/s, Loss=1.9167, Acc=20.97%]

Epoch [3/50]:
  Train Loss: 1.9078, Train Acc: 20.97%
  Val Loss: 2.9412, Val Acc: 13.16%
--------------------------------------------------



poch 4/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.87it/s, Loss=1.7628, Acc=21.58%]

Epoch [4/50]:
  Train Loss: 1.8115, Train Acc: 21.58%
  Val Loss: 2.8342, Val Acc: 17.93%
--------------------------------------------------



poch 5/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.80it/s, Loss=1.7641, Acc=21.41%]

Epoch [5/50]:
  Train Loss: 1.7730, Train Acc: 21.41%
  Val Loss: 2.9083, Val Acc: 14.71%
--------------------------------------------------



poch 6/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.88it/s, Loss=1.7125, Acc=24.82%]

Epoch [6/50]:
  Train Loss: 1.7074, Train Acc: 24.82%
  Val Loss: 3.1642, Val Acc: 17.00%
--------------------------------------------------



poch 7/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.83it/s, Loss=1.6094, Acc=26.72%]

Epoch [7/50]:
  Train Loss: 1.6614, Train Acc: 26.72%
  Val Loss: 3.2016, Val Acc: 18.48%
--------------------------------------------------



poch 8/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.45it/s, Loss=1.5849, Acc=29.02%]

Epoch [8/50]:
  Train Loss: 1.6027, Train Acc: 29.02%
  Val Loss: 3.2792, Val Acc: 21.07%
--------------------------------------------------



poch 9/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.88it/s, Loss=1.5340, Acc=31.18%]

Epoch [9/50]:
  Train Loss: 1.5548, Train Acc: 31.18%
  Val Loss: 3.4180, Val Acc: 21.37%
--------------------------------------------------



poch 10/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.84it/s, Loss=1.4937, Acc=32.84%]

Epoch [10/50]:
  Train Loss: 1.5122, Train Acc: 32.84%
  Val Loss: 3.4134, Val Acc: 26.85%
--------------------------------------------------



poch 11/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.88it/s, Loss=1.4516, Acc=35.27%]

Epoch [11/50]:
  Train Loss: 1.4630, Train Acc: 35.27%
  Val Loss: 3.7741, Val Acc: 21.56%
--------------------------------------------------



poch 12/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.83it/s, Loss=1.4274, Acc=36.46%]

Epoch [12/50]:
  Train Loss: 1.4248, Train Acc: 36.46%
  Val Loss: 3.7181, Val Acc: 25.48%
--------------------------------------------------



poch 13/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.91it/s, Loss=1.3774, Acc=39.26%]

Epoch [13/50]:
  Train Loss: 1.3779, Train Acc: 39.26%
  Val Loss: 3.9468, Val Acc: 23.47%
--------------------------------------------------



poch 14/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.83it/s, Loss=1.3880, Acc=39.23%]

Epoch [14/50]:
  Train Loss: 1.3507, Train Acc: 39.23%
  Val Loss: 4.1086, Val Acc: 25.47%
--------------------------------------------------



poch 15/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.83it/s, Loss=1.2920, Acc=42.81%]

Epoch [15/50]:
  Train Loss: 1.2985, Train Acc: 42.81%
  Val Loss: 4.2364, Val Acc: 28.57%
--------------------------------------------------



poch 16/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.78it/s, Loss=1.2525, Acc=43.05%]

Epoch [16/50]:
  Train Loss: 1.2784, Train Acc: 43.05%
  Val Loss: 4.6302, Val Acc: 26.46%
--------------------------------------------------



poch 17/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.88it/s, Loss=1.2961, Acc=45.15%]

Epoch [17/50]:
  Train Loss: 1.2453, Train Acc: 45.15%
  Val Loss: 4.7328, Val Acc: 28.97%
--------------------------------------------------



poch 18/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.88it/s, Loss=1.2442, Acc=46.57%]

Epoch [18/50]:
  Train Loss: 1.2203, Train Acc: 46.57%
  Val Loss: 4.8340, Val Acc: 27.61%
--------------------------------------------------



poch 19/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.74it/s, Loss=1.2396, Acc=46.97%]

Epoch [19/50]:
  Train Loss: 1.2152, Train Acc: 46.97%
  Val Loss: 4.9603, Val Acc: 29.89%
--------------------------------------------------



poch 20/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.88it/s, Loss=1.1450, Acc=51.47%]

Epoch [20/50]:
  Train Loss: 1.1525, Train Acc: 51.47%
  Val Loss: 5.2659, Val Acc: 30.34%
--------------------------------------------------



poch 21/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.77it/s, Loss=1.1560, Acc=50.49%]

Epoch [21/50]:
  Train Loss: 1.1417, Train Acc: 50.49%
  Val Loss: 5.4291, Val Acc: 33.26%
--------------------------------------------------



poch 22/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.81it/s, Loss=1.1020, Acc=54.38%]

Epoch [22/50]:
  Train Loss: 1.1157, Train Acc: 54.38%
  Val Loss: 5.6064, Val Acc: 33.75%
--------------------------------------------------



poch 23/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.84it/s, Loss=1.0791, Acc=53.47%]

Epoch [23/50]:
  Train Loss: 1.0833, Train Acc: 53.47%
  Val Loss: 5.6013, Val Acc: 34.03%
--------------------------------------------------



poch 24/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.79it/s, Loss=1.0355, Acc=56.17%]

Epoch [24/50]:
  Train Loss: 1.0480, Train Acc: 56.17%
  Val Loss: 5.8276, Val Acc: 33.18%
--------------------------------------------------



poch 25/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.86it/s, Loss=1.0183, Acc=57.96%]

Epoch [25/50]:
  Train Loss: 1.0259, Train Acc: 57.96%
  Val Loss: 5.8604, Val Acc: 36.49%
--------------------------------------------------



poch 26/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.78it/s, Loss=0.9349, Acc=58.88%]

Epoch [26/50]:
  Train Loss: 0.9919, Train Acc: 58.88%
  Val Loss: 5.9469, Val Acc: 33.90%
--------------------------------------------------



poch 27/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.88it/s, Loss=0.9403, Acc=60.26%]

Epoch [27/50]:
  Train Loss: 0.9857, Train Acc: 60.26%
  Val Loss: 6.0250, Val Acc: 36.19%
--------------------------------------------------



poch 28/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.85it/s, Loss=0.8958, Acc=62.43%]

Epoch [28/50]:
  Train Loss: 0.9524, Train Acc: 62.43%
  Val Loss: 6.1879, Val Acc: 37.89%
--------------------------------------------------



poch 29/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.88it/s, Loss=0.8865, Acc=64.19%]

Epoch [29/50]:
  Train Loss: 0.9055, Train Acc: 64.19%
  Val Loss: 6.3415, Val Acc: 38.06%
--------------------------------------------------



poch 30/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.81it/s, Loss=0.8306, Acc=65.78%]

Epoch [30/50]:
  Train Loss: 0.8858, Train Acc: 65.78%
  Val Loss: 6.6029, Val Acc: 37.95%
--------------------------------------------------



poch 31/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.86it/s, Loss=0.9154, Acc=66.35%]

Epoch [31/50]:
  Train Loss: 0.8575, Train Acc: 66.35%
  Val Loss: 6.6351, Val Acc: 40.10%
--------------------------------------------------



poch 32/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.77it/s, Loss=0.8764, Acc=67.81%]

Epoch [32/50]:
  Train Loss: 0.8353, Train Acc: 67.81%
  Val Loss: 6.6504, Val Acc: 36.61%
--------------------------------------------------



poch 33/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.75it/s, Loss=0.8144, Acc=67.81%]

Epoch [33/50]:
  Train Loss: 0.8153, Train Acc: 67.81%
  Val Loss: 6.8243, Val Acc: 40.08%
--------------------------------------------------



poch 34/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.91it/s, Loss=0.7237, Acc=70.48%]

Epoch [34/50]:
  Train Loss: 0.7767, Train Acc: 70.48%
  Val Loss: 6.9369, Val Acc: 40.49%
--------------------------------------------------



poch 35/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.85it/s, Loss=0.7874, Acc=71.69%]

Epoch [35/50]:
  Train Loss: 0.7499, Train Acc: 71.69%
  Val Loss: 7.0261, Val Acc: 40.62%
--------------------------------------------------



poch 36/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.93it/s, Loss=0.7013, Acc=73.55%]

Epoch [36/50]:
  Train Loss: 0.7100, Train Acc: 73.55%
  Val Loss: 7.2028, Val Acc: 42.07%
--------------------------------------------------



poch 37/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.92it/s, Loss=0.6725, Acc=73.93%]

Epoch [37/50]:
  Train Loss: 0.6922, Train Acc: 73.93%
  Val Loss: 7.3794, Val Acc: 41.90%
--------------------------------------------------



poch 38/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.90it/s, Loss=0.6960, Acc=75.01%]

Epoch [38/50]:
  Train Loss: 0.6713, Train Acc: 75.01%
  Val Loss: 7.2781, Val Acc: 42.41%
--------------------------------------------------



poch 39/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.89it/s, Loss=0.6674, Acc=73.99%]

Epoch [39/50]:
  Train Loss: 0.6770, Train Acc: 73.99%
  Val Loss: 7.5035, Val Acc: 41.77%
--------------------------------------------------



poch 40/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.88it/s, Loss=0.6468, Acc=75.55%]

Epoch [40/50]:
  Train Loss: 0.6578, Train Acc: 75.55%
  Val Loss: 7.4580, Val Acc: 41.65%
--------------------------------------------------



poch 41/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.81it/s, Loss=0.6059, Acc=75.79%]

Epoch [41/50]:
  Train Loss: 0.6520, Train Acc: 75.79%
  Val Loss: 7.3785, Val Acc: 42.81%
--------------------------------------------------



poch 42/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.81it/s, Loss=0.6285, Acc=77.75%]

Epoch [42/50]:
  Train Loss: 0.6110, Train Acc: 77.75%
  Val Loss: 7.5573, Val Acc: 43.79%
--------------------------------------------------



poch 43/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.82it/s, Loss=0.5889, Acc=78.46%]

Epoch [43/50]:
  Train Loss: 0.5933, Train Acc: 78.46%
  Val Loss: 7.3819, Val Acc: 43.72%
--------------------------------------------------



poch 44/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.75it/s, Loss=0.5279, Acc=79.51%]

Epoch [44/50]:
  Train Loss: 0.5709, Train Acc: 79.51%
  Val Loss: 7.3670, Val Acc: 42.36%
--------------------------------------------------



poch 45/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.81it/s, Loss=0.5990, Acc=78.29%]

Epoch [45/50]:
  Train Loss: 0.5861, Train Acc: 78.29%
  Val Loss: 7.5807, Val Acc: 43.81%
--------------------------------------------------



poch 46/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.80it/s, Loss=0.5641, Acc=78.63%]

Epoch [46/50]:
  Train Loss: 0.5574, Train Acc: 78.63%
  Val Loss: 7.5252, Val Acc: 43.93%
--------------------------------------------------



poch 47/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.71it/s, Loss=0.5579, Acc=79.81%]

Epoch [47/50]:
  Train Loss: 0.5578, Train Acc: 79.81%
  Val Loss: 7.5244, Val Acc: 44.21%
--------------------------------------------------



poch 48/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.74it/s, Loss=0.4879, Acc=80.25%]

Epoch [48/50]:
  Train Loss: 0.5224, Train Acc: 80.25%
  Val Loss: 7.8054, Val Acc: 44.27%
--------------------------------------------------



poch 49/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.75it/s, Loss=0.4881, Acc=79.61%]

Epoch [49/50]:
  Train Loss: 0.5322, Train Acc: 79.61%
  Val Loss: 7.9068, Val Acc: 44.79%
--------------------------------------------------



poch 50/50 - Training: 100%|██████████| 6/6 [00:00<00:00,  8.76it/s, Loss=0.5859, Acc=80.83%]

Epoch [50/50]:
  Train Loss: 0.5255, Train Acc: 80.83%
  Val Loss: 7.8320, Val Acc: 44.46%
--------------------------------------------------
✅ Baseline CNN training complete.


In [7]:
class ScarSAE(nn.Module):
    """
    A corrected SAE that maps from an input_dim to an output_dim,
    replicating the transformation of the layer it replaces.
    """
    def __init__(self, input_dim, output_dim, latent_dim, k, steer_idx=0):
        super().__init__()
        self.encoder = nn.Linear(input_dim, latent_dim)
        # CORRECTED: Decoder projects to the target layer's output dimension
        self.decoder = nn.Linear(latent_dim, output_dim, bias=False)
        self.k = k
        self.steer_idx = steer_idx

        with torch.no_grad():
            self.decoder.weight.data = F.normalize(self.decoder.weight.data, p=2, dim=1)

    def forward(self, x, alpha=1.0):
        pre_activations = self.encoder(x)
        
        topk_values, _ = torch.topk(pre_activations, self.k, dim=-1)
        threshold = topk_values[:, -1].unsqueeze(-1)
        sparse_activations = F.relu(pre_activations - threshold)
        
        if alpha != 1.0:
            sparse_activations[:, self.steer_idx] = alpha * pre_activations[:, self.steer_idx]

        reconstruction = self.decoder(sparse_activations)
        return reconstruction, pre_activations

def train_scar_module(sae: ScarSAE, dataset, device, steps, lr, batch_size):
    """
    Trains the SCAR module. The reconstruction target is now correctly
    set to the output of the original fc1 layer.
    """
    optimizer = torch.optim.Adam(sae.parameters(), lr=lr)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    
    pbar = tqdm(range(steps), desc='Training SCAR Module')
    step = 0
    while step < steps:
        for fc1_acts, fc2_targets, _, color_label_strings in dataloader:
            if step >= steps:
                break
            
            fc1_acts, fc2_targets = fc1_acts.to(device), fc2_targets.to(device)
            concept_labels = torch.tensor([1.0 if c == 'red' else 0.0 for c in color_label_strings], device=device)
            
            # The SAE takes the fc1_input
            reconstruction, pre_activations = sae(fc1_acts)
            
            # --- THIS IS THE CORRECTED SECTION ---
            # Lr: The reconstruction loss must be calculated against the true target,
            # which is the output of the original layer (`fc2_targets`).
            recon_error = reconstruction - fc2_targets 
            recon_loss = (recon_error.norm(p=2, dim=-1)**2 / (fc2_targets.norm(p=2, dim=-1)**2 + 1e-8)).mean()
            # --- END CORRECTION ---

            # Lc: Conditioning Loss on h0
            h0 = pre_activations[:, sae.steer_idx]
            cond_loss = F.binary_cross_entropy_with_logits(h0, concept_labels)

            loss = recon_loss + cond_loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if step % 100 == 0:
                pbar.set_postfix(recon_loss=f"{recon_loss.item():.4f}", cond_loss=f"{cond_loss.item():.4f}")
            pbar.update(1)
            step += 1

# Train SCAR Module
print("\n--- 🛠️ Training SCAR Module ---")
print("Creating activation dataset from unbiased data...")
activation_dataset = CNNActivationDatasetWithColors(model=biased_model, biased_dataset=unbiased_val_dataset, device=device)

scar_sae = ScarSAE(
    input_dim=128, 
    output_dim=64, # The output dimension of the fc1 layer
    latent_dim=config.concepts, 
    k=config.k
).to(device)
train_scar_module(
    sae=scar_sae,
    dataset=activation_dataset,
    device=device,
    steps=config.sae_steps,
    lr=config.sae_lr,
    batch_size=config.batch_size,
)
print("✅ SCAR module training complete.")

# --- 4. Final Evaluation and Comparison ---
print("\n--- 📊 Final Evaluation: Original vs. SCAR ---")
model_names = ['Original', 'SCAR (Reconstruction)']
partitions = ['overall', 'red_low', 'red_high', 'green_low', 'green_high']
final_results = {model: {part: {'correct': 0, 'total': 0} for part in partitions} for model in model_names}

biased_model.eval()
scar_sae.eval()

for images, labels, colors in tqdm(val_loader, desc="Final Evaluation"):
    batch, labels = images.to(device), labels.to(device)
    
    with torch.no_grad():
        # Original Model Prediction
        preds_original = torch.max(biased_model(batch), 1)[1]
        
        # SCAR Model Prediction (replacing FFN output with SAE reconstruction)
        x = biased_model.pool(F.relu(biased_model.conv1(batch)))
        x = biased_model.pool(F.relu(biased_model.conv2(x)))
        x = biased_model.pool(F.relu(biased_model.conv3(x)))
        x = biased_model.adaptive_pool(x)
        fc1_input = x.view(x.size(0), -1)
        
        sae_reconstruction, _ = scar_sae(fc1_input, alpha=1.0) # alpha=1 means no steering, just reconstruction
        final_output_scar = biased_model.fc2(F.relu(sae_reconstruction))
        preds_scar = torch.max(final_output_scar, 1)[1]

    # Collate results
    for i in range(batch.shape[0]):
        partition_key = f"{colors[i]}_{'high' if labels[i].item() >= 5 else 'low'}"
        preds_dict = {'Original': preds_original[i], 'SCAR (Reconstruction)': preds_scar[i]}
        for name, pred in preds_dict.items():
            is_correct = (pred == labels[i]).item()
            final_results[name][partition_key]['correct'] += is_correct
            final_results[name][partition_key]['total'] += 1
            final_results[name]['overall']['correct'] += is_correct
            final_results[name]['overall']['total'] += 1

# --- 5. Print Results Table ---
def format_cell(data):
    if data['total'] == 0: return "0/0 (N/A)"
    acc = 100 * data['correct'] / data['total']
    return f"{data['correct']}/{data['total']} ({acc:.1f}%)"

header = f"{'Partition':<15}" + "".join([f"{name:<28}" for name in model_names])
results_str = "="*len(header) + "\n"
results_str += "FINAL ACCURACY RESULTS (Original vs. SCAR)\n"
results_str += "="*len(header) + "\n" + header + "\n" + "-"*len(header) + "\n"
for partition_name in partitions:
    row_str = f"{partition_name.replace('_', ' ').title():<15}"
    for model_name in model_names:
        row_str += f"{format_cell(final_results[model_name][partition_name]):<28}"
    results_str += row_str + "\n"
results_str += "="*len(header) + "\n"
    
print("\n" + results_str)


--- 🛠️ Training SCAR Module ---
Creating activation dataset from unbiased data...





[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A






[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


Training SCAR Module:   0%|          | 0/15000 [20:17<?, ?it/s]230.41it/s, cond_loss=0.0404, recon_loss=0.0057]




[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A




✅ SCAR module training complete.

--- 📊 Final Evaluation: Original vs. SCAR ---



inal Evaluation:   0%|          | 0/14 [00:00<?, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x128 and 64x10)