In [33]:
import torch
import torchvision
from torchvision import transforms, datasets
import torch.optim as optim

In [34]:
train = datasets.MNIST("", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
test = datasets.MNIST("", train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))


In [35]:
trainset = torch.utils.data.DataLoader(train, batch_size=8, shuffle=True)
testset = torch.utils.data.DataLoader(test, batch_size=8, shuffle=True)

In [36]:
import torch.nn as nn
import torch.nn.functional as F

In [37]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, 10)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return F.log_softmax(x, dim=1)

net = Net()
print(net)

Net(
  (fc1): Linear(in_features=784, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=64, bias=True)
  (fc4): Linear(in_features=64, out_features=10, bias=True)
)


In [40]:
## TRAIN

optimizer = optim.Adam(net.parameters(), lr=0.001)

EPOCHS=2

for epoch in range(EPOCHS):
    print(f"Epoch {epoch}")
    for data in trainset:
        # X contains the images of a batch (8 images), and y the corresponding labels
        X, y = data
        net.zero_grad()
        output = net(X.view(-1, 28*28))
        
        loss = F.nll_loss(output, y)
        loss.backward()
        optimizer.step()
    print(f"Loss: {loss}\n")

Epoch 0
Loss: 0.04139792174100876

Epoch 1
Loss: 0.0015932309906929731



In [41]:
## TEST

total=0
correct=0

with torch.no_grad():
    for data in testset:
        X, y = data
        output = net(X.view(-1, 28*28))
        for idx, i in enumerate(output):
            if torch.argmax(i) == y[idx]:
                correct += 1
            total += 1

print(f"Accuracy: {(correct/total)*100}")

Accuracy: 97.25
