In [98]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import v2
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore")

In [99]:
class MultilayerPerceptron(nn.Module):
    def __init__(self, in_features, out_features):
        super(MultilayerPerceptron, self).__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 32),
            nn.ReLU(),
            nn.Linear(32, out_features),
        )

    def forward(self, x):
        return self.layers(x)

In [100]:
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
EPOCHS = 5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [101]:
transformations = v2.Compose(
    (
        v2.ToTensor(),
    )
)

net = MultilayerPerceptron(784, 10).to(DEVICE)

dataset = datasets.MNIST("../data", True, transformations, download=True)
train_dataset, test_dataset = random_split(dataset, [0.8, 0.2])
dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
optimizer = optim.Adam(net.parameters(), LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ../data\MNIST\raw\train-images-idx3-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data\MNIST\raw\train-labels-idx1-ubyte.gz


100.0%


Extracting ../data\MNIST\raw\train-labels-idx1-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting ../data\MNIST\raw\t10k-images-idx3-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100.0%

Extracting ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw






In [102]:
net.train()

for epoch in range(EPOCHS):
    for images, labels in tqdm(dataloader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        net.zero_grad()

        predictions = net(images)
        loss = criterion(predictions, labels)

        loss.backward()
        optimizer.step()
        
    print(f"Epoch: {epoch + 1} Loss: {loss}")

100%|██████████| 750/750 [00:03<00:00, 189.51it/s]


Epoch: 1 Loss: 0.24516916275024414


100%|██████████| 750/750 [00:04<00:00, 184.18it/s]


Epoch: 2 Loss: 0.13111528754234314


100%|██████████| 750/750 [00:04<00:00, 181.98it/s]


Epoch: 3 Loss: 0.13048899173736572


100%|██████████| 750/750 [00:04<00:00, 177.66it/s]


Epoch: 4 Loss: 0.1356794685125351


100%|██████████| 750/750 [00:04<00:00, 175.55it/s]

Epoch: 5 Loss: 0.09040802717208862





In [103]:
net.eval()
correct = 0

for image, label in test_dataset:
    prediction = torch.argmax(F.softmax(net(image)))
    
    if prediction == label:
        correct += 1

print(f"Accuracy: {correct / len(test_dataset) * 100}%")

Accuracy: 94.93333333333334%
