# RQ4: Activation Patching for Domain Invariance

**Research Question:** Can activation patching reduce domain leakage without full retraining?

**Hypothesis:** By identifying layers where DANN diverges most from ERM (via CKA), we can "transplant" domain-invariant representations into ERM at inference time.

## Approach

1. **CKA Analysis**: Identify which layers show the largest representational difference between ERM and DANN
2. **Activation Patching**: During inference, replace the ERM activations at layer L with the DANN activations for the same input
3. **Evaluation**: Measure domain probe accuracy and EER on the patched models

## Expected Outcome

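If successful, this provides a lightweight way to improve the domain robustness of an existing ERM model without retraining it. The method still needs a trained DANN checkpoint as the donor and an extra donor forward pass at inference time.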

## Setup

In [None]:
import sys
sys.path.insert(0, '..')  # Add project root to path

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

from asvspoof5_domain_invariant_cm.utils import load_checkpoint, get_device
from asvspoof5_domain_invariant_cm.data import ASVspoof5Dataset, get_dataloader

device = get_device()
print(f"Using device: {device}")

## 1. Load Checkpoints

Load the ERM and DANN models (WavLM backbone).

In [None]:
# Update these paths to your checkpoint locations
ERM_CHECKPOINT = Path("../runs/wavlm_erm/checkpoints/best.pt")
DANN_CHECKPOINT = Path("../runs/wavlm_dann_exp/checkpoints/best.pt")  # Exponential schedule

# Load models
erm_model, erm_config = load_checkpoint(ERM_CHECKPOINT, device=device)
dann_model, dann_config = load_checkpoint(DANN_CHECKPOINT, device=device)

erm_model.eval()
dann_model.eval()

print(f"ERM config: {erm_config.get('training', {}).get('method')}")
print(f"DANN config: {dann_config.get('training', {}).get('method')}")

## 2. CKA Analysis

Centered Kernel Alignment (CKA) measures representational similarity between layers.
We compute CKA between ERM and DANN at each layer to identify where they diverge most.
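
For reference, linear CKA is usually written as below. This is the standard textbook definition, given here as background and not taken from the project code. $X$ and $Y$ are assumed to be column-centered activation matrices with one row per example:

$$
\mathrm{CKA}(X, Y) = \frac{\lVert Y^\top X \rVert_F^2}{\lVert X^\top X \rVert_F \, \lVert Y^\top Y \rVert_F}
= \frac{\operatorname{tr}(K L)}{\sqrt{\operatorname{tr}(K K)\,\operatorname{tr}(L L)}},
\qquad K = X X^\top,\; L = Y Y^\top
$$

A value near 1 means the two layers agree up to rotation and isotropic scaling. Lower values mark the layers where ERM and DANN diverge.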

In [None]:
def linear_CKA(X, Y):
    """Compute linear CKA between two representation matrices.
    
    Args:
        X: (n_samples, n_features_x) - Representations from model 1
        Y: (n_samples, n_features_y) - Representations from model 2
    
    Returns:
        CKA similarity score in [0, 1]
    """
    # Center the representations
    X = X - X.mean(dim=0, keepdim=True)
    Y = Y - Y.mean(dim=0, keepdim=True)
    
    # Compute Gram matrices
    XXT = X @ X.T
    YYT = Y @ Y.T
    
    # CKA = HSIC(X,Y) / sqrt(HSIC(X,X) * HSIC(Y,Y))
    hsic_xy = (XXT * YYT).sum()
    hsic_xx = (XXT * XXT).sum()
    hsic_yy = (YYT * YYT).sum()
    
    return (hsic_xy / torch.sqrt(hsic_xx * hsic_yy)).item()
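
A quick sanity check can help before trusting per-layer scores. The sketch below uses a feature-space variant, hypothetically named `linear_cka_features`, which avoids building the n x n Gram matrices. It checks on synthetic data that the score is 1 for identical inputs, stays at 1 under rotation and rescaling, and is low for unrelated random features. The shapes and the seed are arbitrary assumptions.

```python
import torch


def linear_cka_features(X, Y):
    """Linear CKA computed from (features x features) cross-products."""
    X = X - X.mean(dim=0, keepdim=True)
    Y = Y - Y.mean(dim=0, keepdim=True)
    cross = torch.linalg.norm(Y.T @ X) ** 2  # ||Y^T X||_F^2
    norm_x = torch.linalg.norm(X.T @ X)      # ||X^T X||_F
    norm_y = torch.linalg.norm(Y.T @ Y)      # ||Y^T Y||_F
    return (cross / (norm_x * norm_y)).item()


torch.manual_seed(0)
X = torch.randn(200, 16, dtype=torch.float64)

# Random orthogonal matrix from a QR decomposition
Q, _ = torch.linalg.qr(torch.randn(16, 16, dtype=torch.float64))

cka_self = linear_cka_features(X, X)
cka_rotated = linear_cka_features(X, 3.0 * X @ Q)  # rotate and rescale
cka_random = linear_cka_features(X, torch.randn(200, 16, dtype=torch.float64))

print(f"self: {cka_self:.4f}, rotated: {cka_rotated:.4f}, random: {cka_random:.4f}")
```

The feature-space form costs O(n d^2) instead of O(n^2 d), which matters when the number of utterances is much larger than the hidden dimension.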

In [None]:
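def extract_layer_representations(model, dataloader, num_batches=10, num_layers=12):
    """Extract utterance-level representations from each backbone layer.

    Returns:
        dict: {layer_idx: tensor of shape (n_samples, hidden_dim)}
    """
    layer_reps = {i: [] for i in range(num_layers)}  # 12 transformer layers by default

    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, total=num_batches)):
            if batch_idx >= num_batches:
                break

            audio = batch['audio'].to(device)

            # TODO: Adapt this call to your model's forward method.
            # ASSUMPTION: model.backbone follows the Hugging Face convention and
            # returns `hidden_states` as a tuple of (batch, time, hidden_dim)
            # tensors, with index 0 holding the pre-transformer embeddings.
            outputs = model.backbone(audio, output_hidden_states=True)
            hidden_states = outputs.hidden_states[1:]  # skip the embedding output

            for i in range(num_layers):
                # Mean-pool over time so each utterance contributes one vector
                layer_reps[i].append(hidden_states[i].mean(dim=1).cpu())

    return {k: torch.cat(v, dim=0) for k, v in layer_reps.items()}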
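
If the backbone does not expose hidden states directly, forward hooks are one alternative for collecting per-layer outputs. The sketch below is a minimal illustration on a hypothetical toy stack (`toy_layers`, four linear blocks), not the WavLM backbone. The layer container and tensor shapes are assumptions to adapt to the real model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a backbone: 4 sequential blocks over (batch, time, dim)
toy_layers = nn.ModuleList(
    [nn.Sequential(nn.Linear(32, 32), nn.GELU()) for _ in range(4)]
).eval()


def run_stack(x):
    for layer in toy_layers:
        x = layer(x)
    return x


captured = {}


def make_capture_hook(idx):
    def hook(module, inputs, output):
        # Mean-pool over time: (batch, time, dim) -> (batch, dim)
        captured[idx] = output.detach().mean(dim=1).cpu()
    return hook


handles = [
    layer.register_forward_hook(make_capture_hook(i))
    for i, layer in enumerate(toy_layers)
]

with torch.no_grad():
    final = run_stack(torch.randn(8, 50, 32))

# Always remove hooks afterwards so later forward passes are unaffected
for handle in handles:
    handle.remove()

print({idx: tuple(rep.shape) for idx, rep in captured.items()})
```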

In [None]:
# TODO: Load a subset of eval data for CKA analysis
# eval_loader = get_dataloader(...)

# Extract representations
# erm_reps = extract_layer_representations(erm_model, eval_loader)
# dann_reps = extract_layer_representations(dann_model, eval_loader)

# Compute CKA per layer
# cka_scores = []
# for layer_idx in range(12):
#     cka = linear_CKA(erm_reps[layer_idx], dann_reps[layer_idx])
#     cka_scores.append(cka)
#     print(f"Layer {layer_idx}: CKA = {cka:.4f}")

In [None]:
# TODO: Plot CKA scores
# plt.figure(figsize=(10, 5))
# plt.bar(range(12), cka_scores)
# plt.xlabel('Layer')
# plt.ylabel('CKA Similarity')
# plt.title('ERM vs DANN Representation Similarity (CKA)')
# plt.axhline(y=0.9, color='r', linestyle='--', label='High similarity threshold')
# plt.legend()
# plt.show()
#
# # Identify most divergent layers (lowest CKA)
# divergent_layers = np.argsort(cka_scores)[:3]
# print(f"Most divergent layers: {divergent_layers}")

## 3. Activation Patching

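Replace ERM activations at specific layers with DANN activations during inference. In a strictly sequential layer stack, overwriting the full output of a layer discards everything the base model computed before it, so when several layers are patched only the deepest one influences the final output (unless the classifier head also reads intermediate layers).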

In [None]:
class PatchedModel(torch.nn.Module):
    """Model that patches activations from a donor model at specified layers."""
    
    def __init__(self, base_model, donor_model, patch_layers):
        """
        Args:
            base_model: The model to run inference on (ERM)
            donor_model: The model to take activations from (DANN)
            patch_layers: List of layer indices to patch
        """
        super().__init__()
        self.base_model = base_model
        self.donor_model = donor_model
        self.patch_layers = set(patch_layers)
        
        # Register hooks for patching
        self._donor_activations = {}
        self._setup_hooks()
    
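    def _setup_hooks(self):
        """Setup forward hooks to capture and replace activations."""
        # TODO: Implement hooks based on model architecture
        # This requires knowing the specific layer names in your backbone.
        # Fail loudly until then: without hooks, forward() would silently
        # return unpatched ERM outputs.
        raise NotImplementedError("Patching hooks are not implemented for this backbone yet")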
    
    def forward(self, audio):
        """Forward pass with activation patching."""
        # 1. Run donor model to capture activations
        with torch.no_grad():
            _ = self.donor_model(audio)
        
        # 2. Run base model with patched activations (via hooks)
        output = self.base_model(audio)
        
        return output
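
The capture-and-replace mechanism can be sketched with PyTorch forward hooks, since a forward hook that returns a value replaces the module's output. The example below uses a hypothetical `ToyModel` with four blocks and random weights in place of ERM and DANN. The block layout and the choice of `patch_layer` are assumptions for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyModel(nn.Module):
    """Hypothetical stand-in: 4 sequential blocks plus a 2-class head."""

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(16, 16), nn.Tanh()) for _ in range(4)]
        )
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.head(x)


base, donor = ToyModel().eval(), ToyModel().eval()
patch_layer = 2
donor_cache = {}


def capture_hook(module, inputs, output):
    donor_cache["act"] = output.detach()


def patch_hook(module, inputs, output):
    # Returning a tensor from a forward hook replaces the module's output
    return donor_cache["act"]


h_donor = donor.blocks[patch_layer].register_forward_hook(capture_hook)
h_base = base.blocks[patch_layer].register_forward_hook(patch_hook)

x = torch.randn(5, 16)
with torch.no_grad():
    donor(x)           # fills donor_cache
    patched = base(x)  # base model continues from the donor activation

h_donor.remove()
h_base.remove()

with torch.no_grad():
    unpatched = base(x)
    # Reference: push the donor activation through the remaining base layers
    ref = donor_cache["act"]
    for block in base.blocks[patch_layer + 1:]:
        ref = block(ref)
    ref = base.head(ref)

print(torch.allclose(patched, ref), torch.allclose(patched, unpatched))
```

The donor pass must run on the same batch immediately before the base pass, otherwise the cached activation belongs to a different input.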

In [None]:
# TODO: Create patched model
# Target the most divergent layers identified by CKA
# patched_model = PatchedModel(
#     base_model=erm_model,
#     donor_model=dann_model,
#     patch_layers=divergent_layers
# )

## 4. Evaluation

Compare:
1. Original ERM
2. Original DANN
3. Patched ERM (with DANN activations at selected layers)

In [None]:
def evaluate_model(model, dataloader):
    """Compute EER and collect predictions for analysis."""
    model.eval()
    all_scores = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader):
            audio = batch['audio'].to(device)
            labels = batch['label']
            
            # Get predictions
            outputs = model(audio)
            scores = torch.softmax(outputs['logits'], dim=-1)[:, 1]  # P(bonafide)
            
            all_scores.extend(scores.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    # Compute EER
    # TODO: Import your EER computation function
    # eer = compute_eer(all_labels, all_scores)
    # return eer, all_scores, all_labels
    pass
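
Until the project's own EER function is imported, a small NumPy stand-in can be used to test the evaluation loop. The sketch below is not the project's `compute_eer`. It assumes label 1 means bonafide and that higher scores mean "more bonafide", which matches the `P(bonafide)` comment above. It takes the threshold where false-acceptance and false-rejection rates are closest.

```python
import numpy as np


def compute_eer(labels, scores):
    """Approximate equal error rate (label 1 = bonafide, higher score = bonafide)."""
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    thresholds = np.unique(scores)  # sorted unique candidate thresholds
    bonafide = scores[labels == 1]
    spoof = scores[labels == 0]

    # FRR: bonafide rejected (score < t); FAR: spoof accepted (score >= t)
    frr = np.array([(bonafide < t).mean() for t in thresholds])
    far = np.array([(spoof >= t).mean() for t in thresholds])

    idx = np.argmin(np.abs(far - frr))
    return float((far[idx] + frr[idx]) / 2)


toy_labels = [1, 1, 1, 1, 0, 0, 0, 0]
toy_scores = [0.9, 0.8, 0.7, 0.3, 0.6, 0.4, 0.2, 0.1]
toy_eer = compute_eer(toy_labels, toy_scores)
separable_eer = compute_eer([1, 1, 0, 0], [0.9, 0.8, 0.2, 0.1])
print(f"toy EER: {toy_eer:.2%}, separable EER: {separable_eer:.2%}")
```

The per-threshold loop is quadratic in the number of distinct scores, so a sorted cumulative-count version is preferable for the full evaluation set.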

In [None]:
# TODO: Run evaluation
# evaluate_model returns (eer, scores, labels), so unpack before formatting
# print("Evaluating ERM...")
# erm_eer, erm_scores, erm_labels = evaluate_model(erm_model, eval_loader)
# print(f"ERM EER: {erm_eer:.2%}")

# print("Evaluating DANN...")
# dann_eer, dann_scores, dann_labels = evaluate_model(dann_model, eval_loader)
# print(f"DANN EER: {dann_eer:.2%}")

# print("Evaluating Patched ERM...")
# patched_eer, patched_scores, patched_labels = evaluate_model(patched_model, eval_loader)
# print(f"Patched ERM EER: {patched_eer:.2%}")

## 5. Domain Probe on Patched Model

Check whether patching reduces domain leakage by training codec probes on the patched representations and comparing their accuracy with probes trained on the ERM and DANN representations.
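
A linear probe can be sketched as a softmax regression trained on frozen features, with accuracy measured on a held-out split. The example below uses synthetic features and a hypothetical helper `probe_accuracy`. The "leaky" features encode the domain in one dimension and the "invariant" ones do not. The data, the 50/50 split, and the hyperparameters are illustrative assumptions, not the project's probe setup.

```python
import numpy as np


def probe_accuracy(feats, domains, n_classes, epochs=300, lr=0.5, seed=0):
    """Train a softmax-regression probe on half the data; return held-out accuracy."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(feats))
    split = len(feats) // 2
    tr, te = order[:split], order[split:]

    # Standardize using training statistics only
    mu, sd = feats[tr].mean(axis=0), feats[tr].std(axis=0) + 1e-8
    X_tr, X_te = (feats[tr] - mu) / sd, (feats[te] - mu) / sd

    W = np.zeros((feats.shape[1], n_classes))
    b = np.zeros(n_classes)
    onehot = np.eye(n_classes)[domains[tr]]

    for _ in range(epochs):
        logits = X_tr @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / len(tr)            # cross-entropy gradient
        W -= lr * X_tr.T @ grad
        b -= lr * grad.sum(axis=0)

    preds = (X_te @ W + b).argmax(axis=1)
    return float((preds == domains[te]).mean())


rng = np.random.default_rng(0)
domains = rng.integers(0, 2, size=400)  # two synthetic "codecs"
invariant = rng.normal(size=(400, 8))   # carries no domain information
leaky = invariant.copy()
leaky[:, 0] += 4.0 * domains            # domain shifts one feature

acc_leaky = probe_accuracy(leaky, domains, n_classes=2)
acc_invariant = probe_accuracy(invariant, domains, n_classes=2)
print(f"leaky: {acc_leaky:.2%}, invariant: {acc_invariant:.2%}")
```

Probe accuracy close to the majority-class rate suggests the representation carries little linearly decodable domain information. It does not rule out nonlinear leakage.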

In [None]:
# TODO: Run domain probes on patched model representations
# Compare probe accuracy: ERM vs Patched vs DANN
# Lower probe accuracy = more domain invariant

## 6. Results Summary

In [None]:
# TODO: Create summary table and visualizations
# | Model       | Eval EER | Domain Probe Acc | Notes |
# |-------------|----------|------------------|-------|
# | ERM         | X.XX%    | XX.X%            | Baseline |
# | DANN        | X.XX%    | XX.X%            | Full adversarial training |
# | Patched ERM | X.XX%    | XX.X%            | Layers [X,Y,Z] patched |

## Conclusions

- **Key finding:** [Does patching work? Which layers matter?]
- **Trade-offs:** [Computational cost vs performance gain]
- **Limitations:** [What doesn't work?]
- **Future work:** [Potential extensions]