In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms as T

In [2]:
device = torch.device("mps" if torch.mps.is_available() else "cpu")
print(device)

mps


In [None]:
transform = T.Compose(
    [
        T.ToTensor(),
        T.RandomHorizontalFlip(),
        T.Normalize((0.5,), (0.5,))
    ]
)

In [4]:
trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
testset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

In [None]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=4) # pin_memory and num_workers probably causes overhead here that is not worth it with this datset but wanted to experiment
testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=True, num_workers=4)

In [6]:
class NNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 2, 1, 1)
        self.conv2 = nn.Conv2d(32, 64, 2, 1, 1)

        self.pool = nn.AvgPool2d(2)

        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [7]:
model = NNet()
model.to(device) 

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

In [8]:
epochs = 4

for i in range(epochs):
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model(inputs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    
    print(f"EPOCH: {i} -- LOSS: {running_loss/len(trainloader)}")

EPOCH: 0 -- LOSS: 0.5384796292838951
EPOCH: 1 -- LOSS: 0.25168742509810255
EPOCH: 2 -- LOSS: 0.1745916185802004
EPOCH: 3 -- LOSS: 0.13646455237001645


In [9]:
# Eval
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        out = model(inputs)

        _, pred = torch.max(out, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()

print(f"ACC: {100 * correct / total}%")

ACC: 96.11%
