# Lab 1: Model Extraction

## Objectives
- Extract model via queries
- Build substitute model
- Test transfer attacks
- Evaluate extraction success

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# Detect device (supports CUDA, Apple Silicon MPS, and CPU)
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

## Part 1: Target Model

In [2]:
class TargetModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )
    
    def forward(self, x):
        return self.fc(x.view(-1, 784))

target = TargetModel().to(device)
print('✓ Target model created')

✓ Target model created


## Part 2: Query and Extract

In [3]:
# Generate query data
X_query = torch.randn(5000, 784)

# Query target model
with torch.no_grad():
    y_query = target(X_query.to(device)).argmax(1).cpu()

print(f'Collected {len(X_query)} labeled samples')

Collected 5000 labeled samples


## Part 3: Train Substitute

In [4]:
substitute = TargetModel().to(device)
optimizer = torch.optim.Adam(substitute.parameters())
criterion = nn.CrossEntropyLoss()

for epoch in range(20):
    optimizer.zero_grad()
    outputs = substitute(X_query.to(device))
    loss = criterion(outputs, y_query.to(device))
    loss.backward()
    optimizer.step()

print('✓ Substitute model trained')

✓ Substitute model trained


## Part 4: Evaluate Extraction

In [5]:
X_test = torch.randn(1000, 784)

with torch.no_grad():
    target_preds = target(X_test.to(device)).argmax(1)
    substitute_preds = substitute(X_test.to(device)).argmax(1)
    agreement = (target_preds == substitute_preds).float().mean()

print(f'Model agreement: {agreement:.2%}')
print(f'Extraction success: {"High" if agreement > 0.9 else "Medium" if agreement > 0.7 else "Low"}')

Model agreement: 44.90%
Extraction success: Low
