In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from torchvision.models import resnet50, ResNet50_Weights

# Pick the compute device: CUDA when PyTorch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Preprocessing pipeline applied to every image, in this order:
#   1. resize to 224x224 (the ResNet-50 input size)
#   2. replicate the single grayscale channel into 3 channels
#   3. convert the image to a tensor
#   4. normalize each channel with the ImageNet mean/std statistics
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Fetch the complete MNIST train and test splits (downloaded into ./data when
# missing); the train split is built first, then the test split.
trainset, testset = (
    torchvision.datasets.MNIST(
        root='./data',
        train=is_train_split,
        download=True,
        transform=transform,
    )
    for is_train_split in (True, False)
)

# Draw random index sets without replacement: 1000 for training, 200 for testing.
# No seed is set here, so the selection differs from run to run.
train_indices = np.random.choice(len(trainset), size=1000, replace=False)
test_indices = np.random.choice(len(testset), size=200, replace=False)

# Restrict each split to its sampled indices
train_subset = Subset(dataset=trainset, indices=train_indices)
test_subset = Subset(dataset=testset, indices=test_indices)

# Batch iterators: training batches are reshuffled, test batches keep their order
trainloader = DataLoader(
    dataset=train_subset,
    batch_size=32,
    num_workers=2,
    shuffle=True,
)
testloader = DataLoader(
    dataset=test_subset,
    batch_size=32,
    num_workers=2,
    shuffle=False,
)

# Start from a ResNet-50 initialised with the IMAGENET1K_V1 pre-trained weights
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# Swap the final fully connected layer for a fresh 10-output head (one output
# per MNIST digit) that accepts the same number of input features.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=10)

# Place the whole network on the selected device
model = model.to(device)

# Cross-entropy loss, optimised with SGD (momentum + weight decay) over every
# model parameter; SGD with momentum is a common choice for transfer learning.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-4,
)

# Training loop (5 epochs for demo; increase for better results).
# Each epoch makes one pass over trainloader and prints the average
# per-sample training loss for that pass.
num_epochs = 5
for epoch in range(num_epochs):
    model.train()  # switch the model to training mode
    running_loss = 0.0  # sum of per-sample losses seen this epoch
    num_samples = 0     # number of samples seen this epoch
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Standard step: clear old gradients, forward, backward, update
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion (CrossEntropyLoss with its default reduction) yields the
        # mean loss of the batch. Weight it by the batch size so that a
        # smaller final batch does not count as much as a full one.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size

    # Sample-weighted mean over the epoch (not a mean of batch means)
    avg_loss = running_loss / num_samples
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

# Evaluation: top-1 accuracy of the model over everything testloader yields
model.eval()  # switch the model to evaluation mode
correct = 0  # number of correctly classified samples
total = 0    # number of samples evaluated
with torch.no_grad():  # gradients are not needed for evaluation
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # Predicted class = index of the largest output along dim 1.
        # argmax replaces the legacy `torch.max(outputs.data, 1)` form;
        # `.data` is unnecessary inside no_grad().
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
# Report the number of samples actually evaluated instead of a hard-coded count
print(f"Test Accuracy on {total} samples: {accuracy:.2f}%")

Using device: cpu


100.0%
100.0%
100.0%
100.0%


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/silicon/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100.0%


Epoch 1/5, Average Loss: 2.0221
Epoch 2/5, Average Loss: 0.9454
Epoch 3/5, Average Loss: 0.3054
