In [3]:
pip install torch

Collecting torch
  Downloading torch-2.9.1-cp311-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting filelock (from torch)
  Downloading filelock-3.20.1-py3-none-any.whl.metadata (2.1 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx>=2.5.1 (from torch)
  Downloading networkx-3.6.1-py3-none-any.whl.metadata (6.8 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec>=0.8.5 (from torch)
  Downloading fsspec-2025.12.0-py3-none-any.whl.metadata (10 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Using cached markupsafe-3.0.3-cp311-cp311-macosx_11_0_arm64.whl.metadata (2.7 kB)
Downloading torch-2.9.1-cp311-none-macosx_11_0_arm64.whl (74.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m74.5/74.5 MB[0m [31m2.6 

In [5]:
pip install torchvision numpy

Collecting torchvision
  Downloading torchvision-0.24.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (5.9 kB)
Collecting numpy
  Downloading numpy-2.4.0-cp311-cp311-macosx_14_0_arm64.whl.metadata (6.6 kB)
Collecting pillow!=8.3.*,>=5.3.0 (from torchvision)
  Downloading pillow-12.0.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (8.8 kB)
Downloading torchvision-0.24.1-cp311-cp311-macosx_11_0_arm64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m823.8 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading numpy-2.4.0-cp311-cp311-macosx_14_0_arm64.whl (5.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading pillow-12.0.0-cp311-cp311-macosx_11_0_arm64.whl (4.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: pi

In [7]:
# Step 1: Train a basic CNN model on CIFAR-10 until training saturates

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
import numpy as np

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define a simple CNN architecture
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        
        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        
        # Fully connected layers
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 10)
        
        # Activation and dropout
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        # Conv block 1
        x = self.relu(self.conv1(x))
        x = self.pool(x)  # 32x32 -> 16x16
        
        # Conv block 2
        x = self.relu(self.conv2(x))
        x = self.pool(x)  # 16x16 -> 8x8
        
        # Conv block 3
        x = self.relu(self.conv3(x))
        x = self.pool(x)  # 8x8 -> 4x4
        
        # Flatten
        x = x.view(-1, 128 * 4 * 4)
        
        # Fully connected layers
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

# Data loading and preprocessing
def load_cifar10():
    # Simple normalization transform
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform_train)
    trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                         download=True, transform=transform_test)
    testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
    
    return trainloader, testloader

# Calculate F1 score
def calculate_f1_score(y_true, y_pred, num_classes=10):
    f1_scores = []
    
    for class_id in range(num_classes):
        # True positives, false positives, false negatives
        tp = ((y_pred == class_id) & (y_true == class_id)).sum().item()
        fp = ((y_pred == class_id) & (y_true != class_id)).sum().item()
        fn = ((y_pred != class_id) & (y_true == class_id)).sum().item()
        
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        f1_scores.append(f1)
    
    return np.mean(f1_scores)

# Evaluation function
def evaluate(model, testloader):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
    
    all_preds = torch.tensor(all_preds)
    all_labels = torch.tensor(all_labels)
    
    accuracy = (all_preds == all_labels).float().mean().item()
    f1 = calculate_f1_score(all_labels, all_preds)
    
    return accuracy, f1

# Training function
def train_model(model, trainloader, testloader, epochs=50, learning_rate=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    print(f"\nTraining with learning_rate={learning_rate}")
    print("="*70)
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        
        for i, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        # Evaluate every epoch
        avg_loss = running_loss / len(trainloader)
        test_acc, test_f1 = evaluate(model, testloader)
        
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f} - "
              f"Test Acc: {test_acc:.4f} - Test F1: {test_f1:.4f}")
    
    return test_acc, test_f1

# Main execution
if __name__ == "__main__":
    print("Loading CIFAR-10 dataset...")
    trainloader, testloader = load_cifar10()
    
    print("\nInitializing CNN model...")
    model = SimpleCNN().to(device)
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
    
    # Train until saturation
    print("\nTraining model until saturation...")
    final_acc, final_f1 = train_model(model, trainloader, testloader, epochs=50)
    
    print("\n" + "="*70)
    print("FINAL RESULTS:")
    print(f"Test Accuracy: {final_acc:.4f}")
    print(f"Test F1 Score: {final_f1:.4f}")
    print("="*70)
    
    # Save the model
    torch.save(model.state_dict(), 'cifar10_cnn_baseline.pth')
    print("\nModel saved as 'cifar10_cnn_baseline.pth'")

Using device: cpu
Loading CIFAR-10 dataset...

Initializing CNN model...
Total parameters: 620362

Training model until saturation...

Training with learning_rate=0.001
Epoch [1/50] - Loss: 1.5175 - Test Acc: 0.5788 - Test F1: 0.5665
Epoch [2/50] - Loss: 1.1066 - Test Acc: 0.6753 - Test F1: 0.6749
Epoch [3/50] - Loss: 0.9186 - Test Acc: 0.7042 - Test F1: 0.6997
Epoch [4/50] - Loss: 0.8034 - Test Acc: 0.7289 - Test F1: 0.7279
Epoch [5/50] - Loss: 0.7189 - Test Acc: 0.7333 - Test F1: 0.7300
Epoch [6/50] - Loss: 0.6555 - Test Acc: 0.7486 - Test F1: 0.7495
Epoch [7/50] - Loss: 0.5876 - Test Acc: 0.7672 - Test F1: 0.7648
Epoch [8/50] - Loss: 0.5336 - Test Acc: 0.7625 - Test F1: 0.7618
Epoch [9/50] - Loss: 0.4846 - Test Acc: 0.7675 - Test F1: 0.7639
Epoch [10/50] - Loss: 0.4398 - Test Acc: 0.7791 - Test F1: 0.7800
Epoch [11/50] - Loss: 0.4057 - Test Acc: 0.7780 - Test F1: 0.7789
Epoch [12/50] - Loss: 0.3730 - Test Acc: 0.7794 - Test F1: 0.7791
Epoch [13/50] - Loss: 0.3512 - Test Acc: 0.7808 