In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])


train_dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='data', train=False, transform=transform, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

In [2]:

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=64*7*7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


In [9]:

torch.manual_seed(42)
epochs = 10
model = MnistModel()
model.to(device)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

model.train()
for epoch in range(epochs):
    for image, label in train_loader:
        image, label = image.to(device), label.to(device)
        y_pred = model(image)
        loss = loss_fn(y_pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"epoch: {epoch}, loss: {loss}")


epoch: 0, loss: 0.011304040439426899
epoch: 1, loss: 0.0034760849084705114
epoch: 2, loss: 0.003990110941231251
epoch: 3, loss: 0.00018936427659355104
epoch: 4, loss: 0.0012940316228196025
epoch: 5, loss: 0.00011874186020577326
epoch: 6, loss: 0.0018641944043338299
epoch: 7, loss: 6.377564659487689e-06
epoch: 8, loss: 7.01978278812021e-05
epoch: 9, loss: 1.1175863079415649e-07


In [10]:
model.eval()
correct = 0
total = 0
total_loss = 0

with torch.inference_mode():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        y_pred = model(images)

        # Calculate loss
        loss = loss_fn(y_pred, labels)
        total_loss += loss.item()

        # Calculate accuracy
        predicted_labels = torch.argmax(y_pred, dim=1)
        correct += (predicted_labels == labels).sum().item()
        total += labels.size(0)

accuracy = 100 * correct / total
average_loss = total_loss / len(test_loader)
print(f'Accuracy: {accuracy:.2f}%, Loss: {average_loss:.4f}')



Accuracy: 99.20%, Loss: 0.0368
