In [52]:
import sys
import os
import yaml
import random
import numpy as np
from dataclasses import dataclass, asdict

# Add project root to path
sys.path.append('../../')

from tqdm import tqdm
from Faithful_SAE.models import Faithful_SAE 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from models import ColorMNISTCNN
from data import BiasedColorizedMNIST, UnbiasedColorizedMNIST, CNNActivationDatasetWithColors
from train_models import train_model, train_cnn_sae_with_color_conditioning
from data import create_biased_dataset, create_unbiased_dataset
from torch.utils.data import random_split
from run import create_ablated_model

In [54]:
import yaml 
# Prefer the GPU, but fall back to CPU so the notebook still runs end to end
# on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyper-parameters of the best run; only 'concepts' and 'k' are consumed below.
with open('best_run_artifacts/best_config.yaml', 'r') as file:
    data = yaml.safe_load(file)
print(data)

cnn = ColorMNISTCNN(input_size=28).to(device)
sae = Faithful_SAE(
    input_dim=128,
    latent_dim=data['concepts'],
    hidden_dim=64,
    k=data['k'],
    use_topk=True,
).to(device)

# NOTE(security): torch.load unpickles the file, so only load checkpoints from
# a trusted source. map_location remaps tensors saved on a GPU onto `device`,
# which is required for the CPU fallback above to work.
cnn_dict = torch.load('best_run_artifacts/cnn_model.pth', map_location=device)
sae_dict = torch.load('best_run_artifacts/sae_color_good_one.pth', map_location=device)

cnn.load_state_dict(cnn_dict)
sae.load_state_dict(sae_dict)

# Model derived from the CNN and SAE with SAE latents 0 and 1 ablated.
ablated_model = create_ablated_model(cnn, sae, indices_to_ablate=[0, 1], device=device)
ablated_model.eval()

# Running counters consumed by the accuracy loop in the next cell.
correct = 0
total = 0

# shuffle=False keeps the evaluation order deterministic across runs.
unbiased_val_dataset = UnbiasedColorizedMNIST('../../colorized-MNIST/testing')
unbiased_val_loader = DataLoader(unbiased_val_dataset, batch_size=512, shuffle=False)


{'concepts': 1024, 'cond_lam': 0.5016787095206698, 'faithful_lam': 1.168462365792744, 'k': 16, 'l1_lam': 1.447664740275463, 'recon_lam': 1.7964982558265474, 'sae_lr': 0.0007095408650229278, 'sae_steps': 32000}
Loaded 6711 unbiased images

Unbiased Dataset Statistics:
----------------------------------------
Per-digit breakdown:
  Digit 0: 321 red, 324 green (total: 645)
  Digit 1: 377 red, 392 green (total: 769)
  Digit 2: 353 red, 320 green (total: 673)
  Digit 3: 350 red, 326 green (total: 676)
  Digit 4: 363 red, 315 green (total: 678)
  Digit 5: 301 red, 291 green (total: 592)
  Digit 6: 322 red, 324 green (total: 646)
  Digit 7: 345 red, 360 green (total: 705)
  Digit 8: 308 red, 338 green (total: 646)
  Digit 9: 332 red, 349 green (total: 681)

Overall:
  Total red images: 3372
  Total green images: 3339
  Total images: 6711
  Red/Green ratio: 0.50/0.50


In [55]:

# Reset the counters inside this cell so that re-running it (or running it
# without the setup cell's initialisation) never accumulates on top of stale
# values from an earlier execution.
correct = 0
total = 0

with torch.no_grad():
    for images, labels, _ in unbiased_val_loader:
        outputs = ablated_model(images.to(device))
        # Index of the largest logit along dim 1 is the predicted class.
        preds = torch.max(outputs, dim=1).indices
        correct += (preds == labels.to(device)).sum().item()
        total += labels.size(0)

print(f'Accuracy: {correct/total}')
        

Accuracy: 0.5108031589926986


In [56]:
def create_sae_replaced_model(original_model, sae, device):
    """
    Creates a new model where the FC1 layer's weight is fully REPLACED by the
    SAE's effective encoder.

    Args:
        original_model: Model exposing an ``fc1`` layer with a ``weight``
            parameter. It is left untouched.
        sae: Object exposing ``effective_encoder()``. The transpose of its
            result must have exactly the shape of ``fc1.weight``.
        device: Device the returned model (and the new weights) live on.

    Returns:
        An independent copy of ``original_model`` on ``device`` whose
        ``fc1.weight`` holds the transposed effective encoder. Every other
        parameter and buffer is copied unchanged, and the copy inherits the
        training/eval mode of ``original_model``.

    Raises:
        ValueError: If the transposed effective encoder does not match the
            shape of ``fc1.weight``.
    """
    import copy  # local import keeps this cell self-contained

    # A real deep copy. Re-instantiating via type(original_model)() would
    # silently fall back to default constructor arguments (or fail outright
    # when the constructor has required ones) and so may not match the
    # architecture the original was built with.
    sae_replaced_model = copy.deepcopy(original_model).to(device)

    with torch.no_grad():
        # The new weights are simply the SAE's effective encoder, transposed
        # into the layout of fc1.weight.
        new_weights = sae.effective_encoder().to(device).T

        target = sae_replaced_model.fc1.weight
        # copy_ would broadcast a compatible-but-wrong shape, so check first.
        if tuple(new_weights.shape) != tuple(target.shape):
            raise ValueError(
                f"effective encoder (transposed) has shape {tuple(new_weights.shape)}, "
                f"but fc1.weight has shape {tuple(target.shape)}"
            )
        # In-place copy keeps the Parameter object, its dtype and its device.
        target.copy_(new_weights)

    return sae_replaced_model


In [57]:
import torch
from collections import defaultdict


def create_sae_replaced_model(original_model, sae, device):
    """
    Creates a new model where the FC1 layer's weight is fully REPLACED by the
    SAE's effective encoder.

    Args:
        original_model: Model exposing an ``fc1`` layer with a ``weight``
            parameter. It is left untouched.
        sae: Object exposing ``effective_encoder()``. The transpose of its
            result must have exactly the shape of ``fc1.weight``.
        device: Device the returned model (and the new weights) live on.

    Returns:
        An independent copy of ``original_model`` on ``device`` whose
        ``fc1.weight`` holds the transposed effective encoder. Every other
        parameter and buffer is copied unchanged, and the copy inherits the
        training/eval mode of ``original_model``.

    Raises:
        ValueError: If the transposed effective encoder does not match the
            shape of ``fc1.weight``.
    """
    import copy  # local import keeps this cell self-contained

    # A real deep copy. Re-instantiating via type(original_model)() would
    # silently fall back to default constructor arguments (or fail outright
    # when the constructor has required ones) and so may not match the
    # architecture the original was built with.
    sae_replaced_model = copy.deepcopy(original_model).to(device)

    with torch.no_grad():
        # The new weights are simply the SAE's effective encoder, transposed
        # into the layout of fc1.weight.
        new_weights = sae.effective_encoder().to(device).T

        target = sae_replaced_model.fc1.weight
        # copy_ would broadcast a compatible-but-wrong shape, so check first.
        if tuple(new_weights.shape) != tuple(target.shape):
            raise ValueError(
                f"effective encoder (transposed) has shape {tuple(new_weights.shape)}, "
                f"but fc1.weight has shape {tuple(target.shape)}"
            )
        # In-place copy keeps the Parameter object, its dtype and its device.
        target.copy_(new_weights)

    return sae_replaced_model

    
# --- Assumed variables (defined in earlier cells) ---
# cnn: Your original trained CNN model
# ablated_model: The model created with create_ablated_model
# unbiased_val_loader: The DataLoader for the unbiased validation set;
#                      yields (images, labels, colors) batches
# device: Your 'cuda' or 'cpu' device

# --- Evaluation Setup ---
model_dict = {'Original': cnn, 'Ablated': ablated_model}
# 'low' = digits 0-4, 'high' = digits 5-9 (see the partition key below).
partitions = ['overall', 'red_low', 'red_high', 'green_low', 'green_high']
results = {name: {part: {'correct': 0, 'total': 0} for part in partitions} for name in model_dict}

# Set models to evaluation mode
for model in model_dict.values():
    model.eval()

# --- Evaluation Loop ---
print("📊 Evaluating models...")
with torch.no_grad():
    for images, labels, colors in unbiased_val_loader:
        batch = images.to(device)

        # Get predictions from both models. `preds` stays a dict of device
        # tensors because a following cell displays it.
        preds = {name: torch.max(model(batch), 1)[1] for name, model in model_dict.items()}

        # Pull each tensor to the host once per batch. Calling .item() per
        # sample on a CUDA tensor forces one device sync per sample per model.
        label_list = labels.tolist()
        pred_lists = {name: pred_tensor.tolist() for name, pred_tensor in preds.items()}

        # Collate results for each sample in the batch
        for i, label in enumerate(label_list):
            color = colors[i]
            partition_key = f"{color}_{'high' if label >= 5 else 'low'}"

            for name, pred_list in pred_lists.items():
                is_correct = int(pred_list[i] == label)

                # Update partition and overall counts
                for key in (partition_key, 'overall'):
                    results[name][key]['correct'] += is_correct
                    results[name][key]['total'] += 1

# --- Display Results Table ---
print("\n--- Accuracy Comparison: Original vs. Ablated ---")

def calculate_accuracy(data):
    """Return accuracy in percent for a {'correct', 'total'} counter dict.

    An empty partition (total == 0) yields 0.0 instead of dividing by zero.
    """
    if data['total'] == 0:
        return 0.0
    return (data['correct'] / data['total']) * 100

header = f"{'Partition':<15} | {'Original Model':<22} | {'Ablated Model':<22} | {'Difference':<12}"
print(header)
print("=" * len(header))

for part in partitions:
    original_data = results['Original'][part]
    ablated_data = results['Ablated'][part]

    original_acc = calculate_accuracy(original_data)
    ablated_acc = calculate_accuracy(ablated_data)

    # Difference in percentage points, ablated minus original.
    diff = ablated_acc - original_acc

    original_str = f"{original_data['correct']}/{original_data['total']} ({original_acc:.1f}%)"
    ablated_str = f"{ablated_data['correct']}/{ablated_data['total']} ({ablated_acc:.1f}%)"

    print(f"{part.replace('_', ' ').title():<15} | {original_str:<22} | {ablated_str:<22} | {diff:+#.2f}%")

📊 Evaluating models...

--- Accuracy Comparison: Original vs. Ablated ---
Partition       | Original Model         | Ablated Model          | Difference  
Overall         | 3197/6711 (47.6%)      | 3428/6711 (51.1%)      | +3.44%
Red Low         | 1698/1764 (96.3%)      | 1522/1764 (86.3%)      | -9.98%
Red High        | 0/1608 (0.0%)          | 335/1608 (20.8%)       | +20.83%
Green Low       | 0/1677 (0.0%)          | 126/1677 (7.5%)        | +7.51%
Green High      | 1499/1662 (90.2%)      | 1445/1662 (86.9%)      | -3.25%


In [58]:
# Inspect the per-model predictions left over from the final batch of the
# evaluation loop in the previous cell.
preds

{'Original': tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
         9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
         9, 9, 9, 9, 9, 9, 9], device='cuda:0'),
 'Ablated': tensor([9, 9, 9, 9, 9, 9, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 7, 9, 7,
         9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 7, 9, 9, 9, 9, 9,
         9, 9, 9, 9, 9, 9, 9], device='cuda:0')}

In [59]:
import torch
from collections import defaultdict

# --- Assumed variables (defined in earlier cells) ---
# cnn: Your original trained CNN model
# sae: Your trained Faithful_SAE model
# ablated_model: The model created with your original create_ablated_model
# unbiased_val_loader: The DataLoader for the unbiased validation set;
#                      yields (images, labels, colors) batches
# device: Your 'cuda' or 'cpu' device

# --- Create the new SAE-Replaced model ---
sae_replaced_model = create_sae_replaced_model(cnn, sae, device)

# --- Evaluation Setup ---
model_dict = {
    'Original': cnn,
    'SAE (Replaced)': sae_replaced_model,
    'Ablated': ablated_model
}
# 'low' = digits 0-4, 'high' = digits 5-9 (see the partition key below).
partitions = ['overall', 'red_low', 'red_high', 'green_low', 'green_high']
results = {name: {part: {'correct': 0, 'total': 0} for part in partitions} for name in model_dict}

# Set models to evaluation mode
for model in model_dict.values():
    model.eval()

# --- Evaluation Loop ---
print("📊 Evaluating all three models...")
with torch.no_grad():
    for images, labels, colors in unbiased_val_loader:
        batch = images.to(device)

        # Get predictions from all models
        preds = {name: torch.max(model(batch), 1)[1] for name, model in model_dict.items()}

        # Pull each tensor to the host once per batch. Calling .item() per
        # sample on a CUDA tensor forces one device sync per sample per model.
        label_list = labels.tolist()
        pred_lists = {name: pred_tensor.tolist() for name, pred_tensor in preds.items()}

        # Collate results for each sample in the batch
        for i, label in enumerate(label_list):
            color = colors[i]
            partition_key = f"{color}_{'high' if label >= 5 else 'low'}"

            for name, pred_list in pred_lists.items():
                is_correct = int(pred_list[i] == label)

                # Update partition and overall counts
                for key in (partition_key, 'overall'):
                    results[name][key]['correct'] += is_correct
                    results[name][key]['total'] += 1

# --- Display Results Table ---
print("\n--- Accuracy Comparison: Original vs. SAE vs. Ablated ---")

def calculate_accuracy(data):
    """Return accuracy in percent for a {'correct', 'total'} counter dict.

    An empty partition (total == 0) yields 0.0 instead of dividing by zero.
    """
    if data['total'] == 0:
        return 0.0
    return (data['correct'] / data['total']) * 100

header = f"{'Partition':<15} | {'Original':<22} | {'SAE (Replaced)':<22} | {'Ablated':<22} | {'SAE Δ':<10} | {'Ablated Δ':<10}"
print(header)
print("=" * len(header))

for part in partitions:
    original_data = results['Original'][part]
    sae_data = results['SAE (Replaced)'][part]
    ablated_data = results['Ablated'][part]

    original_acc = calculate_accuracy(original_data)
    sae_acc = calculate_accuracy(sae_data)
    ablated_acc = calculate_accuracy(ablated_data)

    # Deltas are in percentage points, each relative to the original model.
    sae_diff = sae_acc - original_acc
    ablated_diff = ablated_acc - original_acc

    original_str = f"{original_data['correct']}/{original_data['total']} ({original_acc:.1f}%)"
    sae_str = f"{sae_data['correct']}/{sae_data['total']} ({sae_acc:.1f}%)"
    ablated_str = f"{ablated_data['correct']}/{ablated_data['total']} ({ablated_acc:.1f}%)"

    print(f"{part.replace('_', ' ').title():<15} | {original_str:<22} | {sae_str:<22} | {ablated_str:<22} | {sae_diff:+#.2f}% | {ablated_diff:+#.2f}%")

📊 Evaluating all three models...

--- Accuracy Comparison: Original vs. SAE vs. Ablated ---
Partition       | Original               | SAE (Replaced)         | Ablated                | SAE Δ      | Ablated Δ 
Overall         | 3197/6711 (47.6%)      | 3056/6711 (45.5%)      | 3428/6711 (51.1%)      | -2.10% | +3.44%
Red Low         | 1698/1764 (96.3%)      | 1614/1764 (91.5%)      | 1522/1764 (86.3%)      | -4.76% | -9.98%
Red High        | 0/1608 (0.0%)          | 0/1608 (0.0%)          | 335/1608 (20.8%)       | +0.00% | +20.83%
Green Low       | 0/1677 (0.0%)          | 0/1677 (0.0%)          | 126/1677 (7.5%)        | +0.00% | +7.51%
Green High      | 1499/1662 (90.2%)      | 1442/1662 (86.8%)      | 1445/1662 (86.9%)      | -3.43% | -3.25%


In [60]:
# Transposed fc1 weights of the original CNN, shown in the same layout as
# sae.effective_encoder() for a side-by-side comparison.
cnn.fc1.weight.data.T

tensor([[-0.0199, -0.0373, -0.0625,  ..., -0.0618, -0.1308, -0.0772],
        [-0.0509,  0.0072,  0.0631,  ..., -0.0484,  0.0748,  0.0504],
        [ 0.0215, -0.0079,  0.0697,  ..., -0.0339,  0.0432,  0.0050],
        ...,
        [ 0.0767, -0.0002,  0.1014,  ..., -0.0099, -0.0533,  0.0427],
        [ 0.0334,  0.0177, -0.0490,  ...,  0.0183, -0.0065,  0.0061],
        [-0.0670, -0.0131, -0.0083,  ..., -0.0609,  0.0556, -0.0289]],
       device='cuda:0')

In [61]:
# The SAE's effective encoder, displayed for comparison with the transposed
# fc1 weights of the CNN shown in the previous cell.
sae.effective_encoder()

tensor([[-0.0195, -0.0360, -0.0605,  ..., -0.0585, -0.1310, -0.0720],
        [-0.0509,  0.0072,  0.0631,  ..., -0.0484,  0.0748,  0.0504],
        [ 0.0434,  0.0088,  0.0719,  ..., -0.0323,  0.1153,  0.0668],
        ...,
        [ 0.0767, -0.0002,  0.1014,  ..., -0.0099, -0.0533,  0.0427],
        [ 0.0270,  0.0018, -0.0600,  ...,  0.0131,  0.0060,  0.0192],
        [-0.0294, -0.0140,  0.0376,  ..., -0.0211,  0.0848, -0.0471]],
       device='cuda:0', grad_fn=<SumBackward1>)