### Train a CNN on MNIST
- Simple CNN
- Report test accuracy

### Transfer model to SVHN
- Use pretrained weights
- Fine-tune on SVHN
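- Keep the final layer as is: SVHN has the same 10 digit classes as MNIST, so the MNIST classifier head can be reused without replacement
- Convert SVHN images to 28×28 grayscale so they match the MNIST input shape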

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, SVHN
from torch.utils.data import DataLoader
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All (GPU kernels may still add some
# nondeterminism).
SEED = 42
torch.manual_seed(SEED)

In [2]:
class SimpleCNN(nn.Module):
    """Small two-stage CNN classifier for 28x28 images.

    Architecture: [conv3x3(pad=1) -> ReLU -> maxpool2] x 2, followed by a
    two-layer fully connected head. The padded convolutions keep the spatial
    size, and the two 2x2 poolings reduce 28x28 to 7x7. That is why the head
    expects ``64 * 7 * 7`` input features.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of input image channels. Defaults to 1, which
            matches MNIST and the grayscale-converted SVHN used in this
            notebook.
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Layer order/indices are unchanged so existing state_dicts still load.
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return unnormalised logits of shape (batch, num_classes).

        ``x`` is expected to be (batch, in_channels, 28, 28).
        """
        x = self.conv(x)
        # Flatten per sample, keeping the batch dimension. The previous
        # view(-1, 64 * 7 * 7) could silently fold a wrongly sized input into
        # extra "samples"; now the first Linear layer raises a shape error.
        x = torch.flatten(x, 1)
        return self.fc(x)

In [3]:
print("Downloading and preparing MNIST...")

# ToTensor, then Normalize(mean=0.5, std=0.5): each value becomes (x - 0.5) / 0.5.
transform_mnist = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Build the train split first, then the test split; files live under ./data.
train_mnist, test_mnist = (
    MNIST(root='./data', train=is_train_split, download=True, transform=transform_mnist)
    for is_train_split in (True, False)
)

# Only the training loader shuffles, so evaluation order stays fixed.
train_loader_mnist, test_loader_mnist = (
    DataLoader(split_dataset, batch_size=64, shuffle=do_shuffle)
    for split_dataset, do_shuffle in ((train_mnist, True), (test_mnist, False))
)


Downloading and preparing MNIST...
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|████████████████████████████████████████████████████████████████████| 9912422/9912422 [01:23<00:00, 118141.84it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 252738.29it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|████████████████████████████████████████████████████████████████████| 1648877/1648877 [00:08<00:00, 188947.77it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 3013370.57it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [11]:
# train function
def train(model, loader, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        loop = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100)
        for inputs, labels in loop:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            loop.set_postfix(loss=loss.item())
        print(f"Epoch {epoch+1} avg loss: {total_loss/len(loader):.4f}")

def evaluate(model, loader):
    """Compute top-1 accuracy (in percent) of ``model`` over ``loader``.

    Batches are moved to the module-level ``device`` and run under
    ``torch.no_grad()``. The model is left in eval mode afterwards. A tqdm bar
    shows the running accuracy, and the final value is printed.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        loader: Iterable of (inputs, labels) batches.

    Returns:
        float: ``100 * correct / total``, or NaN if the loader yields no
        samples.
    """
    model.eval()
    correct = 0
    total = 0
    # Defined up front so an empty loader no longer raises UnboundLocalError.
    acc = float("nan")
    progress = tqdm(loader, desc="Evaluating", leave=True, ncols=100)

    with torch.no_grad():
        for inputs, labels in progress:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            acc = 100 * correct / total
            progress.set_postfix(acc=f"{acc:.2f}%")
    print(f"Final Accuracy: {acc:.2f}%")
    return acc


In [12]:
# Loss shared by both training runs (it is reused for the SVHN fine-tune).
criterion = nn.CrossEntropyLoss()

# Fresh model trained from scratch on MNIST.
model_mnist = SimpleCNN().to(device)
optimizer = optim.Adam(params=model_mnist.parameters(), lr=1e-3)

print("Training on MNIST...")
train(model=model_mnist, loader=train_loader_mnist, optimizer=optimizer, criterion=criterion, epochs=5)
evaluate(model=model_mnist, loader=test_loader_mnist)

Training on MNIST...


Epoch 1/5: 100%|█████████████████████████████████████| 938/938 [00:17<00:00, 53.21it/s, loss=0.0108]


Epoch 1 avg loss: 0.1535


Epoch 2/5: 100%|███████████████████████████████████| 938/938 [00:17<00:00, 52.94it/s, loss=0.000338]


Epoch 2 avg loss: 0.0445


Epoch 3/5: 100%|████████████████████████████████████| 938/938 [00:17<00:00, 53.01it/s, loss=0.00488]


Epoch 3 avg loss: 0.0312


Epoch 4/5: 100%|█████████████████████████████████████| 938/938 [00:17<00:00, 53.58it/s, loss=0.0154]


Epoch 4 avg loss: 0.0240


Epoch 5/5: 100%|█████████████████████████████████████| 938/938 [00:17<00:00, 53.78it/s, loss=0.0144]


Epoch 5 avg loss: 0.0173


Evaluating: 100%|█████████████████████████████████████| 157/157 [00:02<00:00, 66.06it/s, acc=98.72%]

Final Accuracy: 98.72%





98.72

In [13]:
print("Downloading and preparing SVHN...")

# Bring SVHN's 32x32 colour crops into the MNIST input format: single channel,
# resized to 28, then the same (x - 0.5) / 0.5 normalisation used for MNIST.
transform_svhn = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Build the 'train' split first, then 'test'; files live under ./data.
train_svhn, test_svhn = (
    SVHN(root='./data', split=split_name, download=True, transform=transform_svhn)
    for split_name in ('train', 'test')
)

# Only the training loader shuffles, so evaluation order stays fixed.
train_loader_svhn, test_loader_svhn = (
    DataLoader(split_dataset, batch_size=64, shuffle=do_shuffle)
    for split_dataset, do_shuffle in ((train_svhn, True), (test_svhn, False))
)


Downloading and preparing SVHN...
Using downloaded and verified file: ./data\train_32x32.mat
Using downloaded and verified file: ./data\test_32x32.mat


In [14]:
# Transfer: start from the MNIST weights. The classifier head is kept as is,
# because SVHN uses the same 10 digit classes as MNIST.
model_svhn = SimpleCNN().to(device)
model_svhn.load_state_dict(model_mnist.state_dict())

# Set to True to keep the MNIST convolutional features fixed and fine-tune
# only the fully connected head.
FREEZE_CONV = False
for param in model_svhn.conv.parameters():
    param.requires_grad = not FREEZE_CONV

# Give the optimizer only trainable tensors so the freeze option is respected.
optimizer_svhn = optim.Adam(
    (p for p in model_svhn.parameters() if p.requires_grad), lr=0.001
)

print("Transfer Learning on SVHN...")
train(model_svhn, train_loader_svhn, optimizer_svhn, criterion, epochs=5)
evaluate(model_svhn, test_loader_svhn)


Transfer Learning on SVHN...


Epoch 1/5:   0%|                                                           | 0/1145 [00:00<?, ?it/s]

Epoch 1/5: 100%|████████████████████████████████████| 1145/1145 [00:25<00:00, 44.59it/s, loss=0.557]


Epoch 1 avg loss: 0.6357


Epoch 2/5: 100%|████████████████████████████████████| 1145/1145 [00:25<00:00, 45.19it/s, loss=0.309]


Epoch 2 avg loss: 0.3845


Epoch 3/5: 100%|████████████████████████████████████| 1145/1145 [00:25<00:00, 45.31it/s, loss=0.404]


Epoch 3 avg loss: 0.3122


Epoch 4/5: 100%|████████████████████████████████████| 1145/1145 [00:25<00:00, 45.06it/s, loss=0.239]


Epoch 4 avg loss: 0.2676


Epoch 5/5: 100%|████████████████████████████████████| 1145/1145 [00:26<00:00, 43.72it/s, loss=0.119]


Epoch 5 avg loss: 0.2293


Evaluating: 100%|█████████████████████████████████████| 407/407 [00:08<00:00, 50.58it/s, acc=89.12%]

Final Accuracy: 89.12%





89.12492317148126