In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchinfo import summary

from tqdm import tqdm

from modules.blocks import BasicBlock
from modules.models import ResNet

# Build the MNIST training pipeline: images are converted to tensors and then
# normalized with the given single-channel mean and standard deviation.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# The training split is downloaded into ../data if it is not already there.
train_dataset = torchvision.datasets.MNIST(
    root="../data",
    train=True,
    transform=transform,
    download=True,
)

# Reshuffled batches of 2048 samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=2048,
    shuffle=True,
)

# Select the compute device: Apple-silicon GPU (MPS) when available, then
# CUDA, otherwise CPU, so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Instantiate the model, loss function, and optimizer.
# The model is moved to the target device *before* the optimizer is created so
# that the optimizer is built from the parameters in their final location.
model = ResNet(block=BasicBlock, in_channels=1, layers=[3, 4, 6, 3], num_classes=10)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop: one full pass over train_loader per epoch, reporting the
# per-sample average loss at the end of each epoch.
num_epochs = 5
for epoch in range(num_epochs):
    # Make sure the model is in training mode at the start of every epoch.
    model.train()
    running_loss = 0.0

    for images, labels in tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass: clear stale gradients, backpropagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that dividing by the
        # dataset size below yields a per-sample average.
        running_loss += images.size(0) * loss.item()

    epoch_loss = running_loss / len(train_dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

100%|██████████| 30/30 [00:22<00:00,  1.31it/s]


Epoch [1/5], Loss: 0.3186


100%|██████████| 30/30 [00:20<00:00,  1.47it/s]


Epoch [2/5], Loss: 0.0471


100%|██████████| 30/30 [00:19<00:00,  1.50it/s]


Epoch [3/5], Loss: 0.0263


100%|██████████| 30/30 [00:19<00:00,  1.51it/s]


Epoch [4/5], Loss: 0.0151


100%|██████████| 30/30 [00:19<00:00,  1.50it/s]

Epoch [5/5], Loss: 0.0151





In [5]:
# Held-out MNIST split, preprocessed with the same `transform` as training.
test_dataset = torchvision.datasets.MNIST(
    root="../data",
    train=False,
    transform=transform,
    download=True,
)

# Fixed-order batches of 64 samples (no shuffling for evaluation).
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
)

# Evaluation function
def evaluate(model, test_loader, device):
    """Compute the top-1 classification accuracy of ``model`` on ``test_loader``.

    Args:
        model: Module that maps a batch of images to per-class scores, one row
            per sample with the class index along dim 1. The function does not
            move the model; it only moves the batches to ``device``.
        test_loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The accuracy as a percentage (``100 * correct / total``).

    Raises:
        ValueError: If ``test_loader`` yields no samples, since the accuracy
            is undefined in that case.

    Note:
        The model is switched to evaluation mode and is left in that mode
        when the function returns.
    """
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # Disable gradient calculation during evaluation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # The predicted class is the index of the highest score per sample.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total

# Score the trained model on the held-out test split and report the result.
test_accuracy = evaluate(model=model, test_loader=test_loader, device=device)
print(f"Test Accuracy: {test_accuracy:.2f}%")

Test Accuracy: 98.43%
