In [1]:
# ============================================================================
# CIFAR-10 Drift Detection Training - Complete Colab Notebook
# No external files needed - everything is self-contained
# ============================================================================

# Cell 1: Setup and Imports
# Run this first to import dependencies and check GPU availability

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import pickle
import json
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel
from tqdm import tqdm

# Seed the random number generators before any model or DataLoader is built so
# that weight initialisation, shuffling and augmentation draw from a fixed
# stream on every "Restart & Run All". GPU kernels may still introduce small
# run-to-run differences; this only pins the RNG side.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

print("Checking GPU availability...")
# `device` is a module-level name reused by the training/evaluation cells below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name and total memory (bytes -> GB, base 10) of GPU index 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# ============================================================================
# Cell 2: Utility Functions (from utils.py)
# ============================================================================

def compute_mmd(X, Y, kernel='rbf', gamma=None):
    """Compute the Maximum Mean Discrepancy (MMD) between two samples.

    The squared MMD is estimated as
    ``mean(k(X, X)) + mean(k(Y, Y)) - 2 * mean(k(X, Y))`` where each mean runs
    over the full kernel matrix, and the square root of that value is returned.

    Args:
        X: 2-D array of shape (n_samples_x, n_features).
        Y: 2-D array of shape (n_samples_y, n_features).
        kernel: ``'rbf'`` (default) or ``'linear'``.
        gamma: RBF kernel coefficient. Defaults to ``1 / n_features`` taken
            from ``X.shape[1]``. Not used by the linear kernel.

    Returns:
        The non-negative MMD estimate as a float.

    Raises:
        ValueError: If ``kernel`` is not one of the supported names.
    """
    if kernel == 'rbf':
        if gamma is None:
            gamma = 1.0 / X.shape[1]

        XX = rbf_kernel(X, X, gamma=gamma)
        YY = rbf_kernel(Y, Y, gamma=gamma)
        XY = rbf_kernel(X, Y, gamma=gamma)
    elif kernel == 'linear':
        # Plain inner-product Gram matrices; no bandwidth parameter involved.
        X_arr = np.asarray(X, dtype=float)
        Y_arr = np.asarray(Y, dtype=float)
        XX = X_arr @ X_arr.T
        YY = Y_arr @ Y_arr.T
        XY = X_arr @ Y_arr.T
    else:
        # Previously any value was accepted and silently treated as 'rbf'.
        raise ValueError(
            f"Unsupported kernel {kernel!r}; expected 'rbf' or 'linear'"
        )

    mmd_squared = XX.mean() + YY.mean() - 2 * XY.mean()

    # Floating-point error can push the estimate slightly below zero; clamp
    # before taking the square root.
    return np.sqrt(max(mmd_squared, 0))


def compute_simple_stats(activations):
    """Summarise an activation array by its overall mean and variance.

    Both statistics are taken over every element of ``activations`` (no axis
    is passed to NumPy), so each entry of the result is a single scalar.

    Returns:
        A dict with the keys ``'mean'`` and ``'variance'``.
    """
    overall_mean = np.mean(activations)
    overall_variance = np.var(activations)
    return dict(mean=overall_mean, variance=overall_variance)


def compute_accuracy(predictions, labels):
    """Compute classification accuracy as the fraction of matching entries.

    Args:
        predictions: Array-like of predicted class ids.
        labels: Array-like of ground-truth class ids, same shape as
            ``predictions``.

    Returns:
        The mean of the element-wise equality, a value in [0, 1].

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    # Coerce to arrays first: with plain Python lists `==` compares the lists
    # as a whole and yields a single bool instead of an element-wise result.
    preds = np.asarray(predictions)
    targets = np.asarray(labels)

    # Refuse mismatched shapes rather than letting NumPy broadcast them
    # (e.g. (N,) against (N, 1) would silently produce an N x N comparison).
    if preds.shape != targets.shape:
        raise ValueError(
            f"predictions shape {preds.shape} does not match labels shape {targets.shape}"
        )

    return np.mean(preds == targets)


def extract_activations(model, dataloader, layer_names, device='cuda', max_samples=1000):
    """Run ``model`` over ``dataloader`` and capture outputs of selected layers.

    Forward hooks are registered on each module named in ``layer_names``
    (names as produced by ``model.named_modules()``), batches are fed through
    the model without gradient tracking, and the hooks are removed again even
    if the forward pass or the loader raises.

    Side effects: the model is switched to eval mode and moved to ``device``;
    neither is reverted on return.

    Args:
        model: The ``torch.nn.Module`` to probe.
        dataloader: Iterable yielding ``(images, labels)`` batches of tensors.
        layer_names: Names of the sub-modules whose outputs are captured.
        device: Device on which the forward pass runs.
        max_samples: Upper bound on the number of samples returned. Whole
            batches are processed until this count is reached, then every
            result is truncated to ``max_samples`` rows.

    Returns:
        A tuple ``(activations, labels, predictions)`` where ``activations``
        maps each layer name to a 2-D NumPy array (one flattened row per
        sample), and ``labels`` / ``predictions`` are 1-D NumPy arrays
        (predictions are the argmax over dimension 1 of the model output).

    Raises:
        KeyError: If a name in ``layer_names`` is not a module of ``model``.
        ValueError: If no batch was processed (empty loader or
            ``max_samples <= 0``).
    """
    model.eval()
    model.to(device)

    # Walk the module tree once instead of once per requested layer, and fail
    # early with a readable message when a name does not exist.
    modules_by_name = dict(model.named_modules())
    missing = [name for name in layer_names if name not in modules_by_name]
    if missing:
        raise KeyError(f"Unknown layer name(s): {missing}")

    activations = {name: [] for name in layer_names}
    all_labels = []
    all_predictions = []

    def get_activation(name):
        def hook(module, inputs, output):
            # Detach and move to CPU right away so captured batches do not
            # accumulate on the accelerator.
            activations[name].append(output.detach().cpu())
        return hook

    hooks = []
    try:
        for name in layer_names:
            hooks.append(modules_by_name[name].register_forward_hook(get_activation(name)))

        samples_processed = 0
        with torch.no_grad():
            for images, labels in dataloader:
                if samples_processed >= max_samples:
                    break

                images = images.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)

                # .cpu() is a no-op for CPU tensors and makes this safe when a
                # loader yields labels that already live on another device.
                all_labels.append(labels.cpu().numpy())
                all_predictions.append(predictions.cpu().numpy())

                samples_processed += images.size(0)
    finally:
        # Always detach the hooks; otherwise a failure mid-run would leave them
        # attached to the model and they would keep firing on later forwards.
        for hook in hooks:
            hook.remove()

    if not all_labels:
        raise ValueError(
            "No samples were processed: the dataloader is empty or max_samples <= 0"
        )

    result_activations = {}
    for name in layer_names:
        acts = torch.cat(activations[name], dim=0)
        # Flatten everything after the batch dimension into one feature axis.
        acts = acts.reshape(acts.size(0), -1)
        result_activations[name] = acts.numpy()[:max_samples]

    all_labels = np.concatenate(all_labels)[:max_samples]
    all_predictions = np.concatenate(all_predictions)[:max_samples]

    return result_activations, all_labels, all_predictions

# ============================================================================
# Cell 3: Data Loading
# ============================================================================

def load_cifar10_data(batch_size=128, data_dir='./data', num_workers=2):
    """Build CIFAR-10 train/test DataLoaders.

    The training split is augmented with a padded random crop and a random
    horizontal flip; the test split is only converted to tensors. Both splits
    are normalised with the same per-channel mean/std constants. The dataset
    is downloaded into ``data_dir`` when it is not already present.

    Args:
        batch_size: Batch size used for both loaders.
        data_dir: Directory in which the dataset is stored / downloaded.
        num_workers: Number of worker processes per loader.

    Returns:
        ``(train_loader, test_loader)``; the train loader shuffles, the test
        loader does not.
    """
    # Single definition of the normalisation so the two splits cannot drift.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                     transform=train_transform)
    test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                    transform=test_transform)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only
    # machines, where requesting it merely triggers a warning.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers,
                             pin_memory=pin_memory)

    return train_loader, test_loader

# Build both loaders once at module level; the training, evaluation and
# baseline-extraction steps further down reuse them.
print("Loading CIFAR-10 data...")
train_loader, test_loader = load_cifar10_data(batch_size=128)
print("✓ CIFAR-10 loaded")
# One print call, one line per split.
print(
    f"  Training samples: {len(train_loader.dataset)}",
    f"  Test samples: {len(test_loader.dataset)}",
    sep="\n",
)

# ============================================================================
# Cell 4: Model Definition
# ============================================================================

def get_resnet18_model(num_classes=10):
    """Create a randomly initialised ResNet18 adapted to 32x32 inputs.

    Starting from torchvision's ResNet18 without pretrained weights, the stem
    is replaced by a 3x3 stride-1 convolution, the stem max-pool is replaced
    by an identity, and the final fully connected layer is replaced by one
    with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes (10 for CIFAR-10).

    Returns:
        The modified ``torch.nn.Module``.
    """
    model = models.resnet18(weights=None)

    # Smaller stem for 32x32 images: 3x3 conv at stride 1, and no max-pool.
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()

    # Swap the classification head for the requested number of classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model

print("Creating ResNet18 model...")
model = get_resnet18_model()
print("✓ Model created")
# Print the monitored layers from a variable instead of a hand-written string
# inside a placeholder-free f-string; the rendered text is unchanged.
layer_names = ['layer3', 'fc']
print(f"  Monitoring layers: {layer_names}")

# ============================================================================
# Cell 5: Training
# ============================================================================

def train_model(model, train_loader, device='cuda', epochs=50,
                lr=0.1, momentum=0.9, weight_decay=5e-4, milestones=(25, 38)):
    """Train ``model`` with SGD, cross-entropy loss and a step LR schedule.

    Args:
        model: The ``torch.nn.Module`` to train; it is moved to ``device``.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device on which training runs.
        epochs: Number of passes over ``train_loader``.
        lr: Initial SGD learning rate.
        momentum: SGD momentum.
        weight_decay: L2 weight decay passed to SGD.
        milestones: Epoch indices at which the learning rate is multiplied by
            0.1. These are absolute epoch numbers, so adjust them when
            ``epochs`` differs from the default; otherwise the decay steps
            may never be reached.

    Returns:
        The same ``model`` instance, trained in place.

    Raises:
        ValueError: If ``train_loader`` yields no samples during an epoch.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                          weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(milestones),
                                               gamma=0.1)

    print(f"Starting training for {epochs} epochs...")

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The loss is a batch mean; weight it by the batch size so that
            # running_loss / total is a per-sample average.
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': f'{running_loss/total:.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

        if total == 0:
            # Without this guard the epoch summary below would fail with a
            # ZeroDivisionError that hides the real cause.
            raise ValueError("train_loader yielded no samples; nothing to train on")

        scheduler.step()
        epoch_acc = 100. * correct / total
        print(f"Epoch {epoch+1}: Loss={running_loss/total:.4f}, Accuracy={epoch_acc:.2f}%")

    return model


def evaluate_model(model, test_loader, device='cuda'):
    """Compute top-1 accuracy (in percent) of ``model`` over ``test_loader``.

    The model is moved to ``device`` and put in eval mode; neither is reverted
    afterwards.

    Args:
        model: The ``torch.nn.Module`` to evaluate.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device on which inference runs.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    # Match train_model / extract_activations: make sure the weights live on
    # the same device the batches are sent to.
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        # Explicit error instead of a ZeroDivisionError in the return below.
        raise ValueError("test_loader yielded no samples; accuracy is undefined")

    return 100. * correct / total

# Fit the network on the training split; `model` is rebound to the value
# returned by train_model.
model = train_model(
    model,
    train_loader,
    device=device,
    epochs=50,
)

# Score the trained network on the held-out test split.
print("\nEvaluating on test set...")
test_accuracy = evaluate_model(
    model,
    test_loader,
    device=device,
)
print(f"✓ Test Accuracy: {test_accuracy:.2f}%")

# ============================================================================
# Cell 6: Extract Baseline Activations
# ============================================================================

print("\nExtracting baseline activations...")
layer_names = ['layer3', 'fc']

# Capture the outputs of the monitored layers on the first batches of the
# (unshuffled) test loader, capped at 1000 samples.
baseline_activations, labels, predictions = extract_activations(
    model, test_loader, layer_names, device=device, max_samples=1000
)

baseline_accuracy = compute_accuracy(predictions, labels)
# Report the number of samples actually returned rather than a hard-coded
# figure, so the message stays true if the loader holds fewer samples.
print(f"✓ Baseline accuracy on {len(labels)} samples: {baseline_accuracy*100:.2f}%")

# Compute statistics
baseline_stats = {
    'accuracy': float(baseline_accuracy),
    'layer_stats': {}
}

for layer_name in layer_names:
    acts = baseline_activations[layer_name]
    stats = compute_simple_stats(acts)
    # Cast NumPy scalars to plain floats so the dict is JSON-serialisable.
    baseline_stats['layer_stats'][layer_name] = {
        'mean': float(stats['mean']),
        'variance': float(stats['variance'])
    }
    print(f"  {layer_name}: mean={stats['mean']:.4f}, variance={stats['variance']:.4f}")

# ============================================================================
# Cell 7: Save Results
# ============================================================================

print("\nSaving results...")

# Persist the trained weights (the state dict only, not the module object).
torch.save(model.state_dict(), 'model_cifar.pth')
print("✓ Saved: model_cifar.pth")

# Bundle the captured activations with their labels and predictions, then
# pickle the bundle as a single object.
baseline_payload = {
    'activations': baseline_activations,
    'labels': labels,
    'predictions': predictions
}
with open('baseline_activations_cifar.pkl', 'wb') as f:
    pickle.dump(baseline_payload, f)
print("✓ Saved: baseline_activations_cifar.pkl")

# Write the summary statistics as indented JSON.
with open('baseline_stats_cifar.json', 'w') as f:
    f.write(json.dumps(baseline_stats, indent=2))
print("✓ Saved: baseline_stats_cifar.json")

# Closing summary banner with follow-up instructions.
print("\n" + "="*60)
print("✓ Training complete!")
print("="*60)
print(f"  Model accuracy: {test_accuracy:.2f}%")
print(f"  Baseline samples: {len(labels)}")
print("\nNext steps:")
print("  1. Download files using: files.download('model_cifar.pth')")
print("  2. Use these files with your drift experiment scripts on Mac")
print("="*60)

# ============================================================================
# Cell 8: Download Files (Optional)
# Uncomment to auto-download files to your Mac
# ============================================================================

# from google.colab import files
# files.download('model_cifar.pth')
# files.download('baseline_activations_cifar.pkl')
# files.download('baseline_stats_cifar.json')

Checking GPU availability...
Using device: cuda
GPU: Tesla T4
Memory: 15.83 GB
Loading CIFAR-10 data...


100%|██████████| 170M/170M [00:03<00:00, 46.8MB/s]


✓ CIFAR-10 loaded
  Training samples: 50000
  Test samples: 10000
Creating ResNet18 model...
✓ Model created
  Monitoring layers: ['layer3', 'fc']
Starting training for 50 epochs...


Epoch 1/50: 100%|██████████| 391/391 [00:41<00:00,  9.48it/s, loss=1.9263, acc=30.09%]


Epoch 1: Loss=1.9263, Accuracy=30.09%


Epoch 2/50: 100%|██████████| 391/391 [00:39<00:00,  9.79it/s, loss=1.3892, acc=48.52%]


Epoch 2: Loss=1.3892, Accuracy=48.52%


Epoch 3/50: 100%|██████████| 391/391 [00:39<00:00,  9.85it/s, loss=1.0807, acc=61.22%]


Epoch 3: Loss=1.0807, Accuracy=61.22%


Epoch 4/50: 100%|██████████| 391/391 [00:39<00:00,  9.80it/s, loss=0.8963, acc=68.37%]


Epoch 4: Loss=0.8963, Accuracy=68.37%


Epoch 5/50: 100%|██████████| 391/391 [00:39<00:00,  9.81it/s, loss=0.7466, acc=73.98%]


Epoch 5: Loss=0.7466, Accuracy=73.98%


Epoch 6/50: 100%|██████████| 391/391 [00:40<00:00,  9.76it/s, loss=0.6364, acc=77.89%]


Epoch 6: Loss=0.6364, Accuracy=77.89%


Epoch 7/50: 100%|██████████| 391/391 [00:39<00:00,  9.80it/s, loss=0.5863, acc=79.96%]


Epoch 7: Loss=0.5863, Accuracy=79.96%


Epoch 8/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.5399, acc=81.24%]


Epoch 8: Loss=0.5399, Accuracy=81.24%


Epoch 9/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.5148, acc=82.36%]


Epoch 9: Loss=0.5148, Accuracy=82.36%


Epoch 10/50: 100%|██████████| 391/391 [00:39<00:00,  9.82it/s, loss=0.4908, acc=83.09%]


Epoch 10: Loss=0.4908, Accuracy=83.09%


Epoch 11/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.4728, acc=83.69%]


Epoch 11: Loss=0.4728, Accuracy=83.69%


Epoch 12/50: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s, loss=0.4553, acc=84.41%]


Epoch 12: Loss=0.4553, Accuracy=84.41%


Epoch 13/50: 100%|██████████| 391/391 [00:39<00:00,  9.82it/s, loss=0.4411, acc=84.81%]


Epoch 13: Loss=0.4411, Accuracy=84.81%


Epoch 14/50: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s, loss=0.4341, acc=85.27%]


Epoch 14: Loss=0.4341, Accuracy=85.27%


Epoch 15/50: 100%|██████████| 391/391 [00:39<00:00,  9.81it/s, loss=0.4218, acc=85.34%]


Epoch 15: Loss=0.4218, Accuracy=85.34%


Epoch 16/50: 100%|██████████| 391/391 [00:39<00:00,  9.81it/s, loss=0.4081, acc=86.03%]


Epoch 16: Loss=0.4081, Accuracy=86.03%


Epoch 17/50: 100%|██████████| 391/391 [00:39<00:00,  9.80it/s, loss=0.4041, acc=86.37%]


Epoch 17: Loss=0.4041, Accuracy=86.37%


Epoch 18/50: 100%|██████████| 391/391 [00:39<00:00,  9.80it/s, loss=0.3909, acc=86.61%]


Epoch 18: Loss=0.3909, Accuracy=86.61%


Epoch 19/50: 100%|██████████| 391/391 [00:39<00:00,  9.82it/s, loss=0.3877, acc=86.94%]


Epoch 19: Loss=0.3877, Accuracy=86.94%


Epoch 20/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.3829, acc=86.89%]


Epoch 20: Loss=0.3829, Accuracy=86.89%


Epoch 21/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.3794, acc=86.92%]


Epoch 21: Loss=0.3794, Accuracy=86.92%


Epoch 22/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.3708, acc=87.34%]


Epoch 22: Loss=0.3708, Accuracy=87.34%


Epoch 23/50: 100%|██████████| 391/391 [00:39<00:00,  9.81it/s, loss=0.3664, acc=87.40%]


Epoch 23: Loss=0.3664, Accuracy=87.40%


Epoch 24/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.3629, acc=87.70%]


Epoch 24: Loss=0.3629, Accuracy=87.70%


Epoch 25/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.3656, acc=87.45%]


Epoch 25: Loss=0.3656, Accuracy=87.45%


Epoch 26/50: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s, loss=0.2064, acc=93.12%]


Epoch 26: Loss=0.2064, Accuracy=93.12%


Epoch 27/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.1551, acc=94.75%]


Epoch 27: Loss=0.1551, Accuracy=94.75%


Epoch 28/50: 100%|██████████| 391/391 [00:39<00:00,  9.79it/s, loss=0.1370, acc=95.37%]


Epoch 28: Loss=0.1370, Accuracy=95.37%


Epoch 29/50: 100%|██████████| 391/391 [00:39<00:00,  9.81it/s, loss=0.1202, acc=95.88%]


Epoch 29: Loss=0.1202, Accuracy=95.88%


Epoch 30/50: 100%|██████████| 391/391 [00:39<00:00,  9.81it/s, loss=0.1082, acc=96.30%]


Epoch 30: Loss=0.1082, Accuracy=96.30%


Epoch 31/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.0995, acc=96.63%]


Epoch 31: Loss=0.0995, Accuracy=96.63%


Epoch 32/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.0902, acc=96.96%]


Epoch 32: Loss=0.0902, Accuracy=96.96%


Epoch 33/50: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s, loss=0.0827, acc=97.18%]


Epoch 33: Loss=0.0827, Accuracy=97.18%


Epoch 34/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.0764, acc=97.41%]


Epoch 34: Loss=0.0764, Accuracy=97.41%


Epoch 35/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.0684, acc=97.70%]


Epoch 35: Loss=0.0684, Accuracy=97.70%


Epoch 36/50: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s, loss=0.0671, acc=97.77%]


Epoch 36: Loss=0.0671, Accuracy=97.77%


Epoch 37/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.0610, acc=97.99%]


Epoch 37: Loss=0.0610, Accuracy=97.99%


Epoch 38/50: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s, loss=0.0580, acc=98.06%]


Epoch 38: Loss=0.0580, Accuracy=98.06%


Epoch 39/50: 100%|██████████| 391/391 [00:39<00:00,  9.82it/s, loss=0.0377, acc=98.84%]


Epoch 39: Loss=0.0377, Accuracy=98.84%


Epoch 40/50: 100%|██████████| 391/391 [00:40<00:00,  9.77it/s, loss=0.0319, acc=99.05%]


Epoch 40: Loss=0.0319, Accuracy=99.05%


Epoch 41/50: 100%|██████████| 391/391 [00:39<00:00,  9.82it/s, loss=0.0280, acc=99.19%]


Epoch 41: Loss=0.0280, Accuracy=99.19%


Epoch 42/50: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s, loss=0.0259, acc=99.25%]


Epoch 42: Loss=0.0259, Accuracy=99.25%


Epoch 43/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.0254, acc=99.28%]


Epoch 43: Loss=0.0254, Accuracy=99.28%


Epoch 44/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.0233, acc=99.37%]


Epoch 44: Loss=0.0233, Accuracy=99.37%


Epoch 45/50: 100%|██████████| 391/391 [00:39<00:00,  9.85it/s, loss=0.0223, acc=99.39%]


Epoch 45: Loss=0.0223, Accuracy=99.39%


Epoch 46/50: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s, loss=0.0213, acc=99.46%]


Epoch 46: Loss=0.0213, Accuracy=99.46%


Epoch 47/50: 100%|██████████| 391/391 [00:39<00:00,  9.85it/s, loss=0.0202, acc=99.46%]


Epoch 47: Loss=0.0202, Accuracy=99.46%


Epoch 48/50: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s, loss=0.0202, acc=99.48%]


Epoch 48: Loss=0.0202, Accuracy=99.48%


Epoch 49/50: 100%|██████████| 391/391 [00:39<00:00,  9.82it/s, loss=0.0182, acc=99.54%]


Epoch 49: Loss=0.0182, Accuracy=99.54%


Epoch 50/50: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s, loss=0.0194, acc=99.53%]

Epoch 50: Loss=0.0194, Accuracy=99.53%

Evaluating on test set...





✓ Test Accuracy: 93.50%

Extracting baseline activations...
✓ Baseline accuracy on 1000 samples: 93.50%
  layer3: mean=0.0452, variance=0.0134
  fc: mean=0.0001, variance=16.0656

Saving results...
✓ Saved: model_cifar.pth
✓ Saved: baseline_activations_cifar.pkl
✓ Saved: baseline_stats_cifar.json

✓ Training complete!
  Model accuracy: 93.50%
  Baseline samples: 1000

Next steps:
  1. Download files using: files.download('model_cifar.pth')
  2. Use these files with your drift experiment scripts on Mac


In [2]:
from google.colab import files

# Trigger a browser download for each artifact written by the training cell,
# in the same order as they were saved.
for artifact_name in (
    'model_cifar.pth',
    'baseline_activations_cifar.pkl',
    'baseline_stats_cifar.json',
):
    files.download(artifact_name)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>