# Lab 4: Poisoning Defense & Detection

## Objectives
- Implement activation clustering
- Test spectral signatures
- Use STRIP defense
- Evaluate detection methods

In [1]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

# Choose the compute backend in priority order: CUDA GPU first, then Apple
# Silicon (MPS), and finally the CPU when no accelerator is usable.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    else 'cpu'
)

## Part 1: Activation Clustering

In [2]:
def extract_activations(model, X, layer_name='fc.0'):
    """Extract intermediate activations from a named submodule.

    Runs one forward pass of ``model`` on ``X`` under ``torch.no_grad()`` and
    returns what the submodule named ``layer_name`` produced.

    Args:
        model: ``nn.Module`` to probe.
        X: Input tensor. It is moved to the device of the model's first
            parameter (and left where it is if the model has no parameters).
        layer_name: Module name exactly as reported by
            ``model.named_modules()``.

    Returns:
        The detached outputs captured from the layer, concatenated along
        dim 0.

    Raises:
        ValueError: If the model has no submodule named ``layer_name``.
    """
    captured = []

    def hook(module, input, output):
        captured.append(output.detach())

    # Resolve the layer up front so a typo in the name fails with a clear message
    # instead of an error from torch.cat on an empty list.
    target = dict(model.named_modules()).get(layer_name)
    if target is None:
        raise ValueError(f'No submodule named {layer_name!r} in model')

    # Follow the model's own device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    inputs = X if first_param is None else X.to(first_param.device)

    handle = target.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Always detach the hook; otherwise every call leaves another hook on
        # the layer that keeps firing (and keeps tensors alive) on later
        # forward passes.
        handle.remove()

    return torch.cat(captured, dim=0)

# Create model and data
class SimpleNet(nn.Module):
    """Two-layer MLP classifier over flattened inputs.

    Args:
        in_features: Number of features per sample after flattening.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=784, hidden_dim=128, num_classes=10):
        super().__init__()
        self.in_features = in_features
        # The attribute must stay named `fc`: submodule names such as 'fc.0'
        # (used to pick the layer for activation extraction) derive from it.
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        """Flatten ``x`` to rows of ``in_features`` values and return logits."""
        # reshape behaves like view for contiguous tensors but also accepts
        # non-contiguous inputs (e.g. transposed or sliced tensors).
        return self.fc(x.reshape(-1, self.in_features))

# Seed torch's RNG so the randomly initialised weights and the synthetic data
# below are identical on every "Restart & Run All".
torch.manual_seed(42)

model = SimpleNet().to(device)
X_clean = torch.randn(100, 784)
X_poisoned = torch.randn(20, 784) + 2  # Shifted distribution

# Extract activations
act_clean = extract_activations(model, X_clean)
act_poisoned = extract_activations(model, X_poisoned)

print(f'Clean activations: {act_clean.shape}')
print(f'Poisoned activations: {act_poisoned.shape}')

Clean activations: torch.Size([100, 128])
Poisoned activations: torch.Size([20, 128])


## Part 2: Cluster Analysis

In [3]:
from sklearn.metrics import accuracy_score

# Stack clean and poisoned activations into one matrix for clustering
all_activations = torch.cat([act_clean, act_poisoned], dim=0).cpu().numpy()
# Ground truth: 0 = clean, 1 = poisoned (same order as the stacked rows)
labels = np.array([0]*len(act_clean) + [1]*len(act_poisoned))

# Cluster. n_init is pinned explicitly because its default differs between
# scikit-learn releases; leaving it implicit triggers a FutureWarning on some
# versions and changes the number of restarts on others.
kmeans = KMeans(n_clusters=2, n_init=10, random_state=42)
cluster_labels = kmeans.fit_predict(all_activations)

# Evaluate. Cluster ids are arbitrary, so score both possible id-to-label
# assignments and keep the better one.
detection_acc = max(
    accuracy_score(labels, cluster_labels),
    accuracy_score(labels, 1-cluster_labels)
)

print(f'Detection accuracy: {detection_acc:.2%}')

Detection accuracy: 100.00%


## Part 3: STRIP Defense

In [4]:
def strip_defense(model, x, n_samples=10, alpha=0.5):
    """STRIP: STRong Intentional Perturbation

    Blends ``x`` with ``n_samples`` independent Gaussian-noise tensors, takes
    the softmax entropy of the model's prediction for every blend, and returns
    the standard deviation of those entropies.

    Args:
        model: Callable returning logits with the classes on dim 1.
        x: Input tensor. A batch of one sample yields a scalar score; a larger
            batch yields one score per sample.
        n_samples: Number of random blends to evaluate.
        alpha: Weight given to ``x`` in each blend; the noise gets
            ``1 - alpha``.

    Returns:
        ``np.std`` of the per-blend entropies: a NumPy scalar when there is a
        single entropy per blend, otherwise a 1-D array with one value per
        sample in the batch.
    """
    # NOTE(review): STRIP as usually described blends with held-out clean
    # inputs and thresholds the mean entropy; this lab variant uses noise and
    # the entropy spread instead - confirm that this is intended.

    # Follow the model's own device rather than a notebook-level global; for a
    # model without parameters the input stays on its current device.
    params = getattr(model, 'parameters', None)
    first_param = next(iter(params()), None) if callable(params) else None
    target_device = first_param.device if first_param is not None else x.device

    entropies = []

    for _ in range(n_samples):
        # Blend with random image
        random_img = torch.randn_like(x)
        blended = alpha * x + (1 - alpha) * random_img

        # Get prediction
        with torch.no_grad():
            output = model(blended.to(target_device))
            probs = torch.softmax(output, dim=1)
            # 1e-10 keeps log() finite when a probability underflows to zero
            entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1)

        # Keep one entropy per input sample so batches larger than one work
        entropies.append(entropy.reshape(-1).tolist())

    per_blend = np.asarray(entropies)  # (n_samples, n_inputs)

    # Low entropy variance suggests backdoor
    if per_blend.ndim == 2 and per_blend.shape[1] == 1:
        return np.std(per_blend.ravel())
    return np.std(per_blend, axis=0)

# Score a single clean sample; slicing (rather than indexing) keeps the batch
# dimension the model expects.
clean_entropy = strip_defense(model, X_clean[:1])
print(f'Clean entropy variance: {clean_entropy:.4f}')

# For a backdoored sample the variance is expected to be lower
print('Lower variance suggests backdoor presence')

Clean entropy variance: 0.0069
Lower variance suggests backdoor presence


## Exercise: Improve Detection

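Combine multiple detection methods (for example, activation clustering and STRIP) and check whether the combined detector is more accurate than each method on its own.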

In [5]:
# Your code here
