# MNIST + MLP Classifier

## 1. Load MNIST

In [3]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [4]:
# ToTensor turns each PIL image into a float tensor scaled to [0, 1].
transform = transforms.ToTensor()


def _mnist_split(train):
    """Return the MNIST train split (train=True) or test split (train=False), downloading into ./data if needed."""
    return datasets.MNIST(root="./data", train=train, download=True, transform=transform)


train_ds, test_ds = _mnist_split(True), _mnist_split(False)

# Only the training loader is shuffled; the test set keeps a fixed order.
train_loader = DataLoader(dataset=train_ds, shuffle=True, batch_size=128)
test_loader = DataLoader(dataset=test_ds, shuffle=False, batch_size=256)

100.0%
100.0%
100.0%
100.0%


In [15]:
# Pull one shuffled training batch to sanity-check tensor shapes before modelling.
iterator = iter(train_loader)
first_batch = next(iterator)
feat, labels = first_batch

In [16]:
print(*(tensor.shape for tensor in (feat, labels)))  # image batch shape, then label batch shape

torch.Size([128, 1, 28, 28]) torch.Size([128])


## 2. Build and Train an MLP Classifier

In [18]:
import torch.nn as nn

In [19]:
class MLPClassifier(nn.Module):
    """Two-layer perceptron that maps flattened inputs to class logits.

    Args:
        in_features: Number of input features after flattening (28*28 for MNIST).
        hidden_features: Width of the single ReLU hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int = 28 * 28, hidden_features: int = 256, num_classes: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized logits of shape (batch, num_classes).

        Every dimension after the first is flattened, so ``x`` may be e.g.
        (batch, 1, 28, 28) as long as the per-sample element count equals
        ``in_features``.
        """
        # flatten (unlike view) also handles non-contiguous tensors and empty batches.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


In [33]:
# One pass (a single epoch) over train_loader with plain SGD, logging the loss
# of every mini-batch step and the batch accuracy every ACC_EVERY steps.
torch.manual_seed(0)  # fixes weight init and the shuffle order drawn from the global RNG

mlp = MLPClassifier()

optim = torch.optim.SGD(params=mlp.parameters(), lr=0.2)

loss_func = nn.CrossEntropyLoss()

ACC_EVERY = 5  # report mini-batch accuracy every N optimizer steps

for step, (x, y) in enumerate(train_loader, start=1):
    optim.zero_grad()

    # Call the module rather than .forward() so registered hooks run.
    y_pred = mlp(x)

    loss = loss_func(y_pred, y)

    loss.backward()

    optim.step()

    # Each iteration is one mini-batch step, not an epoch.
    print(f"Step: {step} | Loss: {loss.item()}")

    if step % ACC_EVERY == 0:
        # Accuracy on this training batch, from logits computed before this step's update.
        # A distinct name avoids clobbering `labels` from the batch-inspection cell.
        pred_labels = y_pred.argmax(dim=1)
        print(f"Accuracy={(pred_labels == y).float().mean()}")

Epoch: 1 | Loss: 2.313735246658325
Epoch: 2 | Loss: 2.254995346069336
Epoch: 3 | Loss: 2.216642379760742
Epoch: 4 | Loss: 2.2066233158111572
Epoch: 5 | Loss: 2.116137981414795
Accuracy=0.40625
Epoch: 6 | Loss: 2.1082653999328613
Epoch: 7 | Loss: 2.0303168296813965
Epoch: 8 | Loss: 1.9880492687225342
Epoch: 9 | Loss: 1.9071979522705078
Epoch: 10 | Loss: 1.9142886400222778
Accuracy=0.6015625
Epoch: 11 | Loss: 1.8285441398620605
Epoch: 12 | Loss: 1.789801001548767
Epoch: 13 | Loss: 1.7015624046325684
Epoch: 14 | Loss: 1.6269309520721436
Epoch: 15 | Loss: 1.6034806966781616
Accuracy=0.671875
Epoch: 16 | Loss: 1.471725583076477
Epoch: 17 | Loss: 1.4991382360458374
Epoch: 18 | Loss: 1.4325751066207886
Epoch: 19 | Loss: 1.3637492656707764
Epoch: 20 | Loss: 1.2786868810653687
Accuracy=0.7734375
Epoch: 21 | Loss: 1.1804617643356323
Epoch: 22 | Loss: 1.1847728490829468
Epoch: 23 | Loss: 1.2464265823364258
Epoch: 24 | Loss: 1.0078176259994507
Epoch: 25 | Loss: 1.0445510149002075
Accuracy=0.796875