In [1]:
import math
import os
from pathlib import Path
from typing import Callable, Optional
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm
from wilds import get_dataset

from models import WaterbirdResNet18, SPDTwoLayerFC
from spd.run_spd import get_lr_schedule_fn, get_lr_with_warmup
from spd.hooks import HookedRootModule
from spd.log import logger
from spd.models.base import SPDModel
from spd.module_utils import (
    get_nested_module_attr,
    collect_nested_module_attrs,
)
from spd.types import Probability
from spd.utils import set_seed
from train_resnet import WaterbirdsSubset


In [2]:
import torch
import torch.nn as nn

# Define a combined model that uses the pretrained ResNet backbone 
# but replaces the FC layers with your SPD module
class CombinedWaterbirdModel(nn.Module):
    def __init__(self, resnet_model, spd_model):
        super().__init__()
        # Use only the feature extractor part of the ResNet18 model
        self.features = resnet_model.features
        # Use the SPD model for the fully connected part
        self.spd_model = spd_model
        
    def forward(self, x):
        # Extract features using ResNet backbone
        feats = self.features(x)               # [B, 512, 1, 1]
        feats = feats.flatten(start_dim=1)     # [B, 512]
        
        # Pass features to SPD module
        out = self.spd_model(feats)
        return out

# 1. Load the pretrained ResNet model
resnet_model = WaterbirdResNet18(num_classes=2, hidden_dim=512)
resnet_ckpt_path = "checkpoints/waterbird_resnet18_best.pth"

# The file showed the checkpoint has a different structure - it's a dictionary
resnet_checkpoint = torch.load(resnet_ckpt_path, map_location="cpu")
resnet_model.load_state_dict(resnet_checkpoint['model_state_dict'])
resnet_model.eval()

# 2. Load the SPD model 
spd_model = SPDTwoLayerFC(
    in_features=512,
    hidden_dim=512,
    num_classes=2,
    C=40,
    m_fc1=16,
    m_fc2=16,
)

# Load SPD state dict
spd_ckpt_path = "waterbird_spd_out/waterbird_spd_final.pth"
spd_state_dict = torch.load(spd_ckpt_path, map_location="cpu")
spd_model.load_state_dict(spd_state_dict)
spd_model.eval()

# 3. Create the combined model
combined_model = CombinedWaterbirdModel(resnet_model, spd_model)
combined_model.eval()

# To use for inference:
# with torch.no_grad():
#     output = combined_model(input_image)

  resnet_checkpoint = torch.load(resnet_ckpt_path, map_location="cpu")
  spd_state_dict = torch.load(spd_ckpt_path, map_location="cpu")


CombinedWaterbirdModel(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_run

In [3]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
import torchvision.transforms as T
from wilds import get_dataset
import torch.nn.functional as F

# First, setup validation dataset
waterbird_dataset = get_dataset(dataset="waterbirds", download=False)
dataset_size = len(waterbird_dataset)
print(f"Total dataset size: {dataset_size}")

# Get indices
all_indices = np.arange(dataset_size)
np.random.shuffle(all_indices)
train_indices = all_indices[:2000].tolist()
val_indices = all_indices[2000:3000].tolist()  # Taking 1000 samples after the 2000th index

# Setup validation transform
val_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor()
])

# Create validation subset
val_subset = WaterbirdsSubset(
    waterbird_dataset, 
    indices=val_indices,
    transform=val_transform
)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)

# Load the combined model
# 1. Load ResNet
resnet_model = WaterbirdResNet18(num_classes=2, hidden_dim=512)
resnet_ckpt_path = "checkpoints/waterbird_resnet18_best.pth"
resnet_checkpoint = torch.load(resnet_ckpt_path, map_location="cpu")
resnet_model.load_state_dict(resnet_checkpoint['model_state_dict'])
resnet_model.eval()

# 2. Load SPD model
spd_model = SPDTwoLayerFC(
    in_features=512,
    hidden_dim=512,
    num_classes=2,
    C=40,
    m_fc1=16,
    m_fc2=16,
)
spd_ckpt_path = "waterbird_spd_out/waterbird_spd_final.pth"
spd_state_dict = torch.load(spd_ckpt_path, map_location="cpu")
spd_model.load_state_dict(spd_state_dict)
spd_model.eval()

# 3. Create combined model
class CombinedWaterbirdModel(nn.Module):
    def __init__(self, resnet_model, spd_model):
        super().__init__()
        self.features = resnet_model.features
        self.spd_model = spd_model
        
    def forward(self, x):
        feats = self.features(x)
        feats = feats.flatten(start_dim=1)
        out = self.spd_model(feats)
        return out

# Create the combined model and move to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
combined_model = CombinedWaterbirdModel(resnet_model, spd_model).to(device)
combined_model.eval()

# Evaluate on validation set
correct = 0
total = 0
metadata_correct = {}  # For analysis by metadata
metadata_total = {}

print(f"Evaluating combined model on {len(val_indices)} validation samples...")

with torch.no_grad():
    for inputs, labels, metadata in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        # Forward pass
        outputs = combined_model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        
        # Update overall stats
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Calculate accuracy
accuracy = 100 * correct / total
print(f"Combined model accuracy on validation set: {accuracy:.2f}%")

Total dataset size: 11788


  resnet_checkpoint = torch.load(resnet_ckpt_path, map_location="cpu")
  spd_state_dict = torch.load(spd_ckpt_path, map_location="cpu")


Evaluating combined model on 1000 validation samples...
Combined model accuracy on validation set: 71.00%


In [4]:
import torch
from wilds import get_dataset
import torchvision.transforms as T
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm import tqdm

# Get the dataset
dataset = get_dataset(dataset="waterbirds", download=False)

# Create transform
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor()
])

# Define a function to get samples by group
def get_group_indices(dataset, bird_type, background):
    """
    Get indices of samples where:
    - bird_type: 0 for landbird, 1 for waterbird
    - background: 0 for land, 1 for water
    """
    indices = []
    for i in range(len(dataset)):
        x, y, metadata = dataset[i]
        # y is the bird type, metadata[1] is the background type
        if y == bird_type and metadata[0] == background:
            indices.append(i)
    return indices

# Get indices for each group
landbird_land = get_group_indices(dataset, 0, 0)  # Majority group
waterbird_water = get_group_indices(dataset, 1, 1)  # Majority group
landbird_water = get_group_indices(dataset, 0, 1)  # Minority group
waterbird_land = get_group_indices(dataset, 1, 0)  # Minority group

print(f"Landbirds on land: {len(landbird_land)} samples")
print(f"Waterbirds on water: {len(waterbird_water)} samples")
print(f"Landbirds on water: {len(landbird_water)} samples")  # This should be smaller
print(f"Waterbirds on land: {len(waterbird_land)} samples")  # This should be smaller

# Function to evaluate model on a specific group
def evaluate_group(model, dataset, indices, transform, device, batch_size=32):
    subset = Subset(dataset, indices)
    
    # Create a custom dataset that applies the transform
    class TransformSubset:
        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform
            
        def __len__(self):
            return len(self.subset)
            
        def __getitem__(self, idx):
            x, y, metadata = self.subset[idx]
            if self.transform:
                x = self.transform(x)
            return x, y, metadata
    
    # Create the loader
    loader = DataLoader(
        TransformSubset(subset, transform),
        batch_size=batch_size,
        shuffle=False
    )
    
    # Evaluate
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels, metadata in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total if total > 0 else 0
    return accuracy

# Now you can evaluate both your models on each group
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Evaluate original ResNet model
print("\nEvaluating original ResNet model:")
resnet_model = resnet_model.to(device)
landbird_land_acc = evaluate_group(resnet_model, dataset, landbird_land, transform, device)
waterbird_water_acc = evaluate_group(resnet_model, dataset, waterbird_water, transform, device)
landbird_water_acc = evaluate_group(resnet_model, dataset, landbird_water, transform, device)
waterbird_land_acc = evaluate_group(resnet_model, dataset, waterbird_land, transform, device)

print(f"Landbirds on land: {landbird_land_acc:.2f}%")
print(f"Waterbirds on water: {waterbird_water_acc:.2f}%")
print(f"Landbirds on water: {landbird_water_acc:.2f}%")
print(f"Waterbirds on land: {waterbird_land_acc:.2f}%")

# Evaluate combined model
print("\nEvaluating combined model:")
combined_model = combined_model.to(device)
landbird_land_acc = evaluate_group(combined_model, dataset, landbird_land, transform, device)
waterbird_water_acc = evaluate_group(combined_model, dataset, waterbird_water, transform, device)
landbird_water_acc = evaluate_group(combined_model, dataset, landbird_water, transform, device)
waterbird_land_acc = evaluate_group(combined_model, dataset, waterbird_land, transform, device)

print(f"Landbirds on land: {landbird_land_acc:.2f}%")
print(f"Waterbirds on water: {waterbird_water_acc:.2f}%")
print(f"Landbirds on water: {landbird_water_acc:.2f}%")
print(f"Waterbirds on land: {waterbird_land_acc:.2f}%")

Landbirds on land: 6220 samples
Waterbirds on water: 1832 samples
Landbirds on water: 2905 samples
Waterbirds on land: 831 samples

Evaluating original ResNet model:


100%|██████████| 195/195 [00:26<00:00,  7.33it/s]
100%|██████████| 58/58 [00:07<00:00,  7.49it/s]
100%|██████████| 91/91 [00:12<00:00,  7.14it/s]
100%|██████████| 26/26 [00:03<00:00,  7.02it/s]


Landbirds on land: 98.26%
Waterbirds on water: 86.08%
Landbirds on water: 25.44%
Waterbirds on land: 4.93%

Evaluating combined model:


100%|██████████| 195/195 [00:26<00:00,  7.34it/s]
100%|██████████| 58/58 [00:07<00:00,  7.41it/s]
100%|██████████| 91/91 [00:12<00:00,  7.43it/s]
100%|██████████| 26/26 [00:03<00:00,  7.14it/s]

Landbirds on land: 98.31%
Waterbirds on water: 85.81%
Landbirds on water: 25.99%
Waterbirds on land: 4.81%





In [5]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
import torchvision.transforms as T
from wilds import get_dataset
import torch.nn.functional as F

# --- Same data-loading code as before ---
# (omitted for brevity)

# --- Same checkpoint loading for resnet_model and spd_model ---
# (omitted for brevity)

# Evaluate on validation set, but ablate circuit #0
correct = 0
total = 0

print(f"Evaluating combined model on {len(val_indices)} validation samples, "
      "ablating circuit #0...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet_model.to(device).eval()
spd_model.to(device).eval()

with torch.no_grad():
    for inputs, labels, metadata in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # 1) Extract features from ResNet trunk
        feats = resnet_model.features(inputs)
        feats = feats.flatten(start_dim=1)  # shape [batch_size, 512]

        # 2) Build the topk_mask that ablates circuit #0
        batch_size = feats.size(0)
        topk_mask = torch.ones((batch_size, spd_model.C), dtype=torch.bool, device=device)
        topk_mask[:, 0] = False  # Turn off circuit #0 for every example

        # 3) Forward pass through SPD with ablation
        outputs = spd_model(feats, topk_mask=topk_mask)

        # 4) Compute predictions
        _, predicted = torch.max(outputs, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Ablating circuit #0 => Accuracy on validation set: {accuracy:.2f}%")


Evaluating combined model on 1000 validation samples, ablating circuit #0...
Ablating circuit #0 => Accuracy on validation set: 71.10%


In [6]:
import torch
from wilds import get_dataset
import torchvision.transforms as T
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm import tqdm

# Get the dataset
dataset = get_dataset(dataset="waterbirds", download=False)

# Create transform
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor()
])

# Define a function to get samples by group
def get_group_indices(dataset, bird_type, background):
    """
    Get indices of samples where:
    - bird_type: 0 for landbird, 1 for waterbird
    - background: 0 for land, 1 for water
    """
    indices = []
    for i in range(len(dataset)):
        x, y, metadata = dataset[i]
        # y is the bird type, metadata[0] is the background type
        if y == bird_type and metadata[0] == background:
            indices.append(i)
    return indices

# Get indices for each group
landbird_land = get_group_indices(dataset, 0, 0)  # Majority group
waterbird_water = get_group_indices(dataset, 1, 1)  # Majority group
landbird_water = get_group_indices(dataset, 0, 1)  # Minority group
waterbird_land = get_group_indices(dataset, 1, 0)  # Minority group

print(f"Landbirds on land: {len(landbird_land)} samples")
print(f"Waterbirds on water: {len(waterbird_water)} samples")
print(f"Landbirds on water: {len(landbird_water)} samples")
print(f"Waterbirds on land: {len(waterbird_land)} samples")

# Modified function to evaluate model on a specific group with optional circuit ablation
def evaluate_group(resnet_model, spd_model, dataset, indices, transform, device, 
                  batch_size=32, ablate_circuits=None):
    """
    Evaluate model on a specific group with optional circuit ablation
    
    Args:
        resnet_model: Feature extractor model
        spd_model: SPD model with circuits
        dataset: Dataset object
        indices: Indices to evaluate on
        transform: Image transform
        device: Device to run on
        batch_size: Batch size for evaluation
        ablate_circuits: List of circuit indices to ablate (set to None for no ablation)
    
    Returns:
        accuracy: Accuracy on the evaluated group
    """
    subset = Subset(dataset, indices)
    
    # Create a custom dataset that applies the transform
    class TransformSubset:
        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform
            
        def __len__(self):
            return len(self.subset)
            
        def __getitem__(self, idx):
            x, y, metadata = self.subset[idx]
            if self.transform:
                x = self.transform(x)
            return x, y, metadata
    
    # Create the loader
    loader = DataLoader(
        TransformSubset(subset, transform),
        batch_size=batch_size,
        shuffle=False
    )
    
    # Evaluate
    correct = 0
    total = 0
    resnet_model.eval()
    spd_model.eval()
    
    with torch.no_grad():
        for inputs, labels, metadata in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Extract features using ResNet
            feats = resnet_model.features(inputs)
            feats = feats.flatten(start_dim=1)  # [B, 512]
            
            # Create ablation mask if needed
            if ablate_circuits is not None and len(ablate_circuits) > 0:
                batch_size = feats.size(0)
                topk_mask = torch.ones((batch_size, spd_model.C), dtype=torch.bool, device=device)
                for circuit_idx in ablate_circuits:
                    topk_mask[:, circuit_idx] = False  # Turn off specified circuits
                
                # Forward pass through SPD with ablation
                outputs = spd_model(feats, topk_mask=topk_mask)
            else:
                # Normal forward pass through SPD without ablation
                outputs = spd_model(feats)
            
            # Compute predictions
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total if total > 0 else 0
    return accuracy

# Now you can evaluate both models on each group
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load models (assuming you've already loaded them before)
# resnet_model = ... (your ResNet model)
# spd_model = ... (your SPD model)

# First, evaluate the standard model performance (without ablation)
print("\nEvaluating combined model (no ablation):")
landbird_land_acc = evaluate_group(resnet_model, spd_model, dataset, landbird_land, transform, device)
waterbird_water_acc = evaluate_group(resnet_model, spd_model, dataset, waterbird_water, transform, device)
landbird_water_acc = evaluate_group(resnet_model, spd_model, dataset, landbird_water, transform, device)
waterbird_land_acc = evaluate_group(resnet_model, spd_model, dataset, waterbird_land, transform, device)

print(f"Landbirds on land: {landbird_land_acc:.2f}%")
print(f"Waterbirds on water: {waterbird_water_acc:.2f}%")
print(f"Landbirds on water: {landbird_water_acc:.2f}%")
print(f"Waterbirds on land: {waterbird_land_acc:.2f}%")

# Now evaluate with circuit #0 ablated
ablate_circuits = [0]  # Ablate circuit #0
print(f"\nEvaluating combined model with circuit(s) {ablate_circuits} ablated:")
landbird_land_acc_abl = evaluate_group(resnet_model, spd_model, dataset, landbird_land, transform, device, ablate_circuits=ablate_circuits)
waterbird_water_acc_abl = evaluate_group(resnet_model, spd_model, dataset, waterbird_water, transform, device, ablate_circuits=ablate_circuits)
landbird_water_acc_abl = evaluate_group(resnet_model, spd_model, dataset, landbird_water, transform, device, ablate_circuits=ablate_circuits)
waterbird_land_acc_abl = evaluate_group(resnet_model, spd_model, dataset, waterbird_land, transform, device, ablate_circuits=ablate_circuits)

print(f"Landbirds on land: {landbird_land_acc_abl:.2f}%")
print(f"Waterbirds on water: {waterbird_water_acc_abl:.2f}%")
print(f"Landbirds on water: {landbird_water_acc_abl:.2f}%")
print(f"Waterbirds on land: {waterbird_land_acc_abl:.2f}%")

# Print the differences
print("\nAccuracy differences (ablated - normal):")
print(f"Landbirds on land: {landbird_land_acc_abl - landbird_land_acc:.2f}%")
print(f"Waterbirds on water: {waterbird_water_acc_abl - waterbird_water_acc:.2f}%")
print(f"Landbirds on water: {landbird_water_acc_abl - landbird_water_acc:.2f}%")
print(f"Waterbirds on land: {waterbird_land_acc_abl - waterbird_land_acc:.2f}%")

# You can easily test ablating multiple circuits
ablate_circuits = [0, 1, 2]  # Ablate circuits 0, 1, and 2
print(f"\nEvaluating combined model with circuit(s) {ablate_circuits} ablated:")
landbird_land_acc_multi = evaluate_group(resnet_model, spd_model, dataset, landbird_land, transform, device, ablate_circuits=ablate_circuits)
waterbird_water_acc_multi = evaluate_group(resnet_model, spd_model, dataset, waterbird_water, transform, device, ablate_circuits=ablate_circuits)
landbird_water_acc_multi = evaluate_group(resnet_model, spd_model, dataset, landbird_water, transform, device, ablate_circuits=ablate_circuits)
waterbird_land_acc_multi = evaluate_group(resnet_model, spd_model, dataset, waterbird_land, transform, device, ablate_circuits=ablate_circuits)

print(f"Landbirds on land: {landbird_land_acc_multi:.2f}%")
print(f"Waterbirds on water: {waterbird_water_acc_multi:.2f}%")
print(f"Landbirds on water: {landbird_water_acc_multi:.2f}%")
print(f"Waterbirds on land: {waterbird_land_acc_multi:.2f}%")

Landbirds on land: 6220 samples
Waterbirds on water: 1832 samples
Landbirds on water: 2905 samples
Waterbirds on land: 831 samples

Evaluating combined model (no ablation):


100%|██████████| 195/195 [00:26<00:00,  7.30it/s]
100%|██████████| 58/58 [00:07<00:00,  7.53it/s]
100%|██████████| 91/91 [00:12<00:00,  7.39it/s]
100%|██████████| 26/26 [00:03<00:00,  7.45it/s]


Landbirds on land: 98.31%
Waterbirds on water: 85.81%
Landbirds on water: 25.99%
Waterbirds on land: 4.81%

Evaluating combined model with circuit(s) [0] ablated:


100%|██████████| 195/195 [00:26<00:00,  7.32it/s]
100%|██████████| 58/58 [00:07<00:00,  7.57it/s]
100%|██████████| 91/91 [00:12<00:00,  7.53it/s]
100%|██████████| 26/26 [00:03<00:00,  7.28it/s]


Landbirds on land: 98.50%
Waterbirds on water: 84.83%
Landbirds on water: 27.33%
Waterbirds on land: 4.45%

Accuracy differences (ablated - normal):
Landbirds on land: 0.19%
Waterbirds on water: -0.98%
Landbirds on water: 1.34%
Waterbirds on land: -0.36%

Evaluating combined model with circuit(s) [0, 1, 2] ablated:


100%|██████████| 195/195 [00:27<00:00,  7.17it/s]
100%|██████████| 58/58 [00:07<00:00,  7.55it/s]
100%|██████████| 91/91 [00:12<00:00,  7.50it/s]
100%|██████████| 26/26 [00:03<00:00,  7.10it/s]

Landbirds on land: 98.49%
Waterbirds on water: 84.99%
Landbirds on water: 27.26%
Waterbirds on land: 4.45%



