In [1]:
# LeNet classifier for MNIST + feature extractor for FID
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# Compute device: use CUDA whenever PyTorch can see a GPU, else the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
# Mini-batch size shared by the train and test DataLoaders below.
BATCH_SIZE = 128

class LeNet(nn.Module):
    """Small two-conv-layer CNN classifier for single-channel 28x28 images.

    Besides returning 10-way class logits, ``forward`` can return the
    128-d output of ``fc1`` (the penultimate layer), so a trained instance
    can be reused as a feature extractor (e.g. for an FID-style score).
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),   # (N, 1, 28, 28) -> (N, 32, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2),                  # -> (N, 32, 14, 14)
            nn.Conv2d(32, 64, 3, padding=1),  # -> (N, 64, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),                  # -> (N, 64, 7, 7)
        )
        # Penultimate layer; its output is what forward(return_features=True)
        # returns. NOTE(review): there is no nonlinearity between fc1 and fc2,
        # so the two layers compose to a single affine map and the "features"
        # are a linear projection of the flattened conv output. Adding a ReLU
        # here would change the outputs of checkpoints trained with this
        # definition, so it is intentionally left unchanged.
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x, return_features=False):
        """Run the network on a batch of images.

        Args:
            x: Image batch, expected shape (N, 1, 28, 28). The first conv's
                single input channel and fc1's 64*7*7 input size are sized
                for this.
            return_features: If True, return the fc1 activations instead of
                the class logits.

        Returns:
            A (N, 128) feature tensor when ``return_features`` is True,
            otherwise (N, 10) unnormalized class logits.
        """
        x = self.conv(x)
        # flatten (unlike .view) also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        features = self.fc1(x)
        if return_features:
            # Skip fc2 entirely: its output would be discarded on this path.
            return features
        return self.fc2(features)

# --- Load MNIST ---
# Convert images to tensors, then standardize them. (0.1307, 0.3081) appear to
# be the commonly used MNIST training-set mean/std. NOTE(review): any images
# later fed to this classifier (e.g. generated samples for FID) presumably need
# the same preprocessing; confirm at the call site.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Build the train split first, then the test split (same order as before).
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Shuffle only for training; evaluation iterates the test set in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# --- Train LeNet ---
# Model, Adam optimizer (lr=1e-3) and cross-entropy loss on the class logits.
model = LeNet().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

num_train = len(train_loader.dataset)
for epoch in range(15):
    model.train()
    total_loss = 0.0
    for x, y in train_loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        # criterion returns a per-batch mean; re-weight by batch size so the
        # epoch figure is a true per-sample average (last batch may be short).
        total_loss += loss.item() * x.size(0)

    avg_loss = total_loss / num_train
    print(f"Epoch {epoch+1} | Train Loss: {avg_loss:.4f}")

# --- Evaluate ---
# Top-1 accuracy on the test split. LeNet has no dropout or batch-norm layers,
# so eval() does not change its outputs, but it is kept as good practice.
model.eval()
with torch.no_grad():  # inference only: no autograd graph is needed
    correct = sum(
        (model(x.to(DEVICE)).argmax(dim=1) == y.to(DEVICE)).sum().item()
        for x, y in test_loader
    )
accuracy = correct / len(test_dataset) * 100
print(f"Test Accuracy: {accuracy:.2f}%")

# --- Save model ---
# Persist only the weights (state_dict); reload with LeNet().load_state_dict(...).
# NOTE(review): the "models" directory is created but the checkpoint is written
# to the current working directory. The path is left unchanged in case other
# cells load "lenet_mnist.pth" from there; confirm which location is intended.
os.makedirs("models", exist_ok=True)
checkpoint_path = "lenet_mnist.pth"
weights = model.state_dict()
torch.save(weights, checkpoint_path)
print("✅ LeNet model saved.")


Epoch 1 | Train Loss: 0.1454
Epoch 2 | Train Loss: 0.0436
Epoch 3 | Train Loss: 0.0325
Epoch 4 | Train Loss: 0.0252
Epoch 5 | Train Loss: 0.0209
Epoch 6 | Train Loss: 0.0162
Epoch 7 | Train Loss: 0.0135
Epoch 8 | Train Loss: 0.0128
Epoch 9 | Train Loss: 0.0097
Epoch 10 | Train Loss: 0.0109
Epoch 11 | Train Loss: 0.0086
Epoch 12 | Train Loss: 0.0094
Epoch 13 | Train Loss: 0.0075
Epoch 14 | Train Loss: 0.0054
Epoch 15 | Train Loss: 0.0073
Test Accuracy: 98.95%
✅ LeNet model saved.
