In [1]:
import torchvision
import torch
import torchvision.transforms.v2 as transforms
from torch.utils.data import DataLoader
# Configuration: where MNIST lives on disk and the training mini-batch size.
image_path = "../NNs with PyTorch/"
batch_size = 64

# Convert each PIL image to a tensor image, then cast uint8 [0, 255] pixels to float32 in [0, 1].
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(dtype=torch.float32, scale=True),
])
# download=True skips the download when the files already exist under image_path,
# and fetches them on a fresh machine so "Restart & Run All" works anywhere.
mnist_train_dataset = torchvision.datasets.MNIST(root=image_path, train=True, transform=transform, download=True)
mnist_test_dataset = torchvision.datasets.MNIST(root=image_path, train=False, transform=transform, download=True)
# Seed the global RNG here; the nn.Linear layers built in the model cell draw their initial weights from it.
torch.manual_seed(1)
train_dl = DataLoader(mnist_train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Construct the NN model: Flatten -> (Linear -> ReLU) for each hidden layer -> Linear to class logits.

import torch.nn as nn
hidden_units = [32, 16]
num_classes = 10  # MNIST digits 0-9
image_size = mnist_train_dataset[0][0].shape
# Number of features after flattening one sample; works for any input rank, not just (C, H, W).
input_size = image_size.numel()
all_layers = [nn.Flatten()]
for hidden_unit in hidden_units:
    all_layers.append(nn.Linear(input_size, hidden_unit))
    all_layers.append(nn.ReLU())
    input_size = hidden_unit

# Use the running input_size (not hidden_units[-1]) so an empty hidden_units list still builds
# a valid single-layer (logistic-regression) model instead of raising IndexError.
all_layers.append(nn.Linear(input_size, num_classes))
model = nn.Sequential(*all_layers)
model

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=16, bias=True)
  (4): ReLU()
  (5): Linear(in_features=16, out_features=10, bias=True)
)

In [4]:
# Train the model with Adam + cross-entropy, printing training-set accuracy after each epoch.

# Prefer Apple's MPS backend (the original target), then CUDA, then CPU, so this cell
# also runs on machines without an Apple GPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Move the model before constructing the optimizer, as the torch.optim docs recommend.
model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Re-seed so the DataLoader's shuffle order is identical on every re-run of this cell.
torch.manual_seed(1)
num_epochs = 20
for epoch in range(num_epochs):
    # No effect on this Flatten/Linear/ReLU stack, but correct if dropout/batch-norm are added later.
    model.train()
    # Count correct predictions on-device; converting once per epoch avoids a host sync every batch.
    num_correct = torch.zeros((), device=device)
    for x_batch, y_batch in train_dl:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        pred: torch.Tensor = model(x_batch)
        loss = loss_fn(pred, y_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Accuracy uses each batch's predictions from before this step's parameter update.
        num_correct += (pred.argmax(1) == y_batch).float().sum()
    accuracy_hist_train = num_correct.item() / len(train_dl.dataset)
    print(f"Epoch {epoch} Accuracy {accuracy_hist_train:.4f}")

Epoch 0 Accuracy 0.8531
Epoch 1 Accuracy 0.9287
Epoch 2 Accuracy 0.9413
Epoch 3 Accuracy 0.9506
Epoch 4 Accuracy 0.9556
Epoch 5 Accuracy 0.9593
Epoch 6 Accuracy 0.9628
Epoch 7 Accuracy 0.9648
Epoch 8 Accuracy 0.9672
Epoch 9 Accuracy 0.9690
Epoch 10 Accuracy 0.9711
Epoch 11 Accuracy 0.9730
Epoch 12 Accuracy 0.9740
Epoch 13 Accuracy 0.9750
Epoch 14 Accuracy 0.9768
Epoch 15 Accuracy 0.9780
Epoch 16 Accuracy 0.9782
Epoch 17 Accuracy 0.9801
Epoch 18 Accuracy 0.9806
Epoch 19 Accuracy 0.9811


In [5]:
# Evaluate the trained model on the held-out test set.
# Run on whichever device the model's parameters already live on, rather than a hard-coded one.
test_device = next(model.parameters()).device
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph for all test images
    # Reads the raw uint8 tensors directly (bypassing `transform`), so scale to [0, 1] by hand to
    # match ToDtype(scale=True) used in training. nn.Flatten handles the missing channel dim.
    test_images = mnist_test_dataset.data.to(test_device, dtype=torch.float32) / 255.0
    pred = model(test_images)
is_correct = (pred.argmax(dim=1) == mnist_test_dataset.targets.to(test_device)).float()
print(f"Test accuracy: {is_correct.mean().item():.4f}")

Test accuracy: 0.9640
