In [1]:
# Install/upgrade this notebook's dependencies into the running kernel's environment.
# Versions are unpinned and --upgrade is set, so reruns may pull newer releases than the
# ones this notebook was last run with; restart the kernel after an upgrade (see pip's note).
# ipywidgets is presumably here for the tqdm.notebook progress bars used in the training cell.
%pip install torch numpy matplotlib tqdm torchvision ipywidgets --upgrade

Collecting numpy
  Using cached numpy-2.1.3-cp312-cp312-macosx_14_0_arm64.whl.metadata (62 kB)
Using cached numpy-2.1.3-cp312-cp312-macosx_14_0_arm64.whl (5.1 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.1.2
    Uninstalling numpy-2.1.2:
      Successfully uninstalled numpy-2.1.2
Successfully installed numpy-2.1.3

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [4]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

# MNIST lives under root='.' and is downloaded on demand. Images only go through
# ToTensor (no normalization); a single ToTensor instance serves both splits.
to_tensor = transforms.ToTensor()
train_dataset = datasets.MNIST(root='.', train=True, download=True, transform=to_tensor)
test_dataset = datasets.MNIST(root='.', train=False, download=True, transform=to_tensor)

# Only the training split is shuffled; evaluation order is kept fixed.
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Pick the best available accelerator: CUDA first, then Apple's MPS backend, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [39]:
class CNN(nn.Module):
    """Small convolutional classifier for single-channel images that outputs 10 logits.

    Layer stack (ReLU after every conv layer and after affine4):
        conv1(1->8, k5) -> maxpool(3) -> conv2(8->16, k5, pad 1)
        -> conv3(16->32, k3, stride 3) -> maxpool(2) -> flatten
        -> affine4(32->20) -> affine5(20->10)

    For 1x28x28 inputs (MNIST, as loaded in this notebook) the conv stack reduces
    each image to a 32x1x1 feature map, i.e. exactly the 32 features affine4
    expects. Input sizes that leave a larger spatial map now raise in affine4
    instead of being silently reshaped into extra "batch" rows.

    Args:
        seed: Passed to ``torch.manual_seed`` before the layers are created so
            weight initialization is reproducible. NOTE: this reseeds torch's
            *global* RNG, so it also affects later random ops (e.g. DataLoader
            shuffling without an explicit generator). Pass ``None`` to leave
            the global RNG untouched.
    """

    def __init__(self, seed=2024):
        super().__init__()
        if seed is not None:
            torch.manual_seed(seed)

        # Layers are created in a fixed order so a given seed always yields the same weights.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=3)

        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=3, padding=0)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.affine4 = nn.Linear(32, 20)
        self.affine5 = nn.Linear(20, 10)

        # One stateless ReLU module reused after every hidden layer.
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images (N, 1, H, W) to raw class logits (N, 10).

        No softmax is applied; pair the output with ``nn.CrossEntropyLoss``.
        """
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.activation(self.conv2(x))
        x = self.pool3(self.activation(self.conv3(x)))

        # flatten(1) keeps dim 0 as the batch. The old view(-1, 32) would silently
        # fold any leftover spatial positions into the batch dimension.
        x = torch.flatten(x, 1)

        x = self.activation(self.affine4(x))
        return self.affine5(x)

model = CNN().to(device)

# Total parameter count across all layers of the model.
n_params = sum(map(torch.Tensor.numel, model.parameters()))
print(n_params)

8934


In [40]:
# The model's forward returns raw logits (no softmax), which is what CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()

# Starting learning rate for AdamW; the training cell's CosineAnnealingLR anneals it from here.
lr = 1e-3
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

In [41]:
# Smoke test: push one all-zeros batch through the (untrained) model and check the logits shape.
# no_grad avoids building and keeping an autograd graph that is never used.
with torch.no_grad():
    a = torch.zeros(batch_size, 1, 28, 28, device=device)
    b: torch.Tensor = model(a)

assert b.shape == (batch_size, 10), f"unexpected output shape {tuple(b.shape)}"
print(b.shape)

torch.Size([64, 10])


In [42]:
from tqdm.notebook import tqdm

num_epochs = 10

# T_max is tied to num_epochs (it used to be a hard-coded 10) so the cosine schedule
# reaches its minimum at the final epoch even when num_epochs is changed.
# NOTE: re-running this cell keeps training the existing model/optimizer (created in
# earlier cells); re-run from the model cell for a clean training run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one pass over `loader`, taking one optimizer step per batch."""
    model.train()
    for image, label in tqdm(loader, "batch", leave=False):
        image, label = image.to(device), label.to(device)
        loss = criterion(model(image), label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def _evaluate_accuracy(model, loader, device):
    """Return top-1 accuracy over `loader` as a percentage (argmax of logits vs. label)."""
    model.eval()
    num_correct = 0
    with torch.no_grad():
        for image, label in tqdm(loader, "test", leave=False):
            image, label = image.to(device), label.to(device)
            pred = model(image).argmax(dim=1)
            # .item() turns the device tensor into a Python int for plain accumulation.
            num_correct += (pred == label).sum().item()
    return num_correct / len(loader.dataset) * 100


for epoch in tqdm(range(num_epochs), "epoch"):
    _train_one_epoch(model, train_loader, optimizer, criterion, device)
    acc = _evaluate_accuracy(model, test_loader, device)
    # The schedule is per-epoch, matching T_max=num_epochs above.
    scheduler.step()

    print(f"Epoch: {epoch + 1} / Acc: {acc:.2f}%")

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

batch:   0%|          | 0/938 [00:00<?, ?it/s]

test:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 1 / Acc: 94.51%


batch:   0%|          | 0/938 [00:00<?, ?it/s]

test:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 2 / Acc: 95.78%


batch:   0%|          | 0/938 [00:00<?, ?it/s]

test:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 3 / Acc: 95.85%


batch:   0%|          | 0/938 [00:00<?, ?it/s]

test:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 4 / Acc: 97.38%


batch:   0%|          | 0/938 [00:00<?, ?it/s]

test:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 5 / Acc: 97.50%


batch:   0%|          | 0/938 [00:00<?, ?it/s]

test:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 6 / Acc: 97.94%


batch:   0%|          | 0/938 [00:00<?, ?it/s]

test:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 7 / Acc: 98.08%


batch:   0%|          | 0/938 [00:00<?, ?it/s]

test:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 8 / Acc: 98.32%


batch:   0%|          | 0/938 [00:00<?, ?it/s]

test:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 9 / Acc: 98.39%


batch:   0%|          | 0/938 [00:00<?, ?it/s]

test:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 10 / Acc: 98.36%
