In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Fixed seed so weight initialisation (done later in the model cell) and batch
# shuffling are repeatable under Restart Kernel -> Run All.
SEED = 0
BATCH_SIZE = 64
torch.manual_seed(SEED)

# ToTensor turns each image into a float tensor scaled to [0, 1].
transform = transforms.ToTensor()

train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

# download=True here as well, so this dataset does not depend on the train-set
# call having already fetched the files (it is a no-op when they exist).
test_dataset = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform
)

# A dedicated generator makes the shuffle order depend only on SEED, not on how
# much of the global RNG other cells have consumed before this point.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _conv_block(in_channels, out_channels):
    """Return [3x3 same-padding Conv2d, ReLU, 2x2 MaxPool2d]; the pool halves H and W."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
    ]


# Two conv stages shrink a (B,1,28,28) batch to (B,64,7,7); the head flattens
# that to 64*7*7 features and maps them to 10 output logits.
model = nn.Sequential(
    *_conv_block(1, 32),    # (B,1,28,28)  -> (B,32,14,14)
    *_conv_block(32, 64),   # (B,32,14,14) -> (B,64,7,7)
    nn.Flatten(),           # -> (B, 64*7*7)
    nn.Linear(64 * 7 * 7, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
).to(device)

In [None]:
# CrossEntropyLoss expects unnormalised logits, which is what the model's final
# Linear layer produces. Parameters are updated with Adam at a fixed step size.
LEARNING_RATE = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
epochs = 5

for epoch in range(epochs):
    # Put layers into training mode; re-applied every epoch because an eval()
    # call elsewhere may have switched the model out of it.
    model.train()
    total_loss = 0.0
    n_seen = 0

    for images, labels in train_loader:
        # The model was moved to `device`, but DataLoader yields CPU tensors.
        # Move each batch over so the forward pass does not fail with a
        # device mismatch on GPU (this is a no-op when device is the CPU).
        images = images.to(device)
        labels = labels.to(device)

        # 1. Forward
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 2. Backward + parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `loss` is a per-batch mean; weight it by batch size so the epoch
        # figure is a true per-sample mean even when the last batch is smaller.
        batch_n = labels.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n

    print(f"Epoch {epoch+1}: loss = {total_loss/n_seen:.4f}")

Epoch 1: loss = 0.1743
Epoch 2: loss = 0.0487
Epoch 3: loss = 0.0327
Epoch 4: loss = 0.0249
Epoch 5: loss = 0.0192


In [None]:
# Switch layers to inference mode before measuring test accuracy.
model.eval()
correct = 0
total = 0

# No gradients are needed for evaluation; no_grad skips building the autograd graph.
with torch.no_grad():
    for images, labels in test_loader:
        # Batches arrive on the CPU; move them to the same device as the model's
        # weights, otherwise the forward pass fails on GPU.
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        preds = outputs.argmax(dim=1)  # index of the highest logit = predicted class

        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Accuracy: {100 * correct / total:.2f}%")

Accuracy: 99.07%
