In [27]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm.notebook import tqdm
from ipywidgets import IntProgress

In [39]:
# Load MNIST data
#
# The datasets are bound to `train_data` / `test_data` (not `train` / `test`) because
# later cells define functions named `train()` and `test()`, which would otherwise
# silently replace these objects in the kernel namespace.

DATA_ROOT = ""  # directory torchvision stores/downloads MNIST under ("" = relative to the working directory)
BATCH_SIZE = 8
SEED = 0

# Seed the global torch RNG before the loaders shuffle and before the model's weights
# are initialised in the next cell, so Restart & Run All reproduces the same numbers.
torch.manual_seed(SEED)

to_tensor = transforms.ToTensor()
train_data = datasets.MNIST(DATA_ROOT, train=True, download=True, transform=to_tensor)
test_data = datasets.MNIST(DATA_ROOT, train=False, download=True, transform=to_tensor)

trainset = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
# Sample order cannot change accuracy, so keep the evaluation loader deterministic.
testset = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [40]:
class Net(nn.Module):
    """Fully connected classifier: in_features -> hidden -> hidden -> hidden -> n_classes.

    The layers are always named ``fc1``..``fc4``, so with the default sizes the
    ``state_dict`` keys and shapes are identical to the original hard-coded model
    and previously saved checkpoints (e.g. ``nn_mnist.pth``) still load.

    Args:
        in_features: length of each flattened input sample (28*28 for MNIST images).
        hidden: width of each of the three hidden layers.
        n_classes: number of output classes.
    """

    def __init__(self, in_features=28 * 28, hidden=64, n_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, n_classes)

    def forward(self, x):
        """Return per-class log-probabilities for a batch ``x``.

        Expects ``x`` of shape (N, in_features); callers flatten images first.
        ``log_softmax`` is taken over dim 1 (the class dimension), so the output
        is meant to be fed to ``F.nll_loss``.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return F.log_softmax(x, dim=1)


net = Net()
net

Net(
  (fc1): Linear(in_features=784, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=64, bias=True)
  (fc4): Linear(in_features=64, out_features=10, bias=True)
)

In [41]:
## TRAIN

LOAD_MODEL = True  # Load a pre-trained model if it exists. Set to False to re-train the model

EPOCHS = 1
LEARNING_RATE = 0.001
LOG_EVERY = 2000  # number of mini-batches averaged into each running-loss report


def train(model=None, loader=None, epochs=None, lr=LEARNING_RATE, log_every=LOG_EVERY):
    """Train ``model`` with Adam on negative log-likelihood loss, printing progress.

    Args:
        model: module to train; defaults to the global ``net`` (looked up at call time).
            Assumed to return log-probabilities, since the loss is ``F.nll_loss``.
        loader: iterable of ``(X, y)`` batches; defaults to the global ``trainset``.
        epochs: number of passes over ``loader``; defaults to the global ``EPOCHS``.
        lr: Adam learning rate.
        log_every: print the mean loss of every ``log_every`` consecutive mini-batches.
    """
    # Resolve defaults at call time so re-running earlier cells is picked up.
    model = net if model is None else model
    loader = trainset if loader is None else loader
    epochs = EPOCHS if epochs is None else epochs

    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()  # no effect on pure Linear/ReLU models, but required once dropout/batch-norm appear
    for epoch in range(epochs):
        print(f"Epoch {epoch}")
        running_loss = 0.0
        loss = None  # stays None if the loader yields no batches

        # Wrap the loader itself (not enumerate(loader)) so tqdm can read len(loader)
        # and draw a determinate progress bar.
        for i, (X, y) in enumerate(tqdm(loader)):
            # X holds the images of one batch, y the corresponding labels.
            optimizer.zero_grad()
            # Flatten every sample to 1-D, e.g. (N, 1, 28, 28) -> (N, 784) for MNIST.
            output = model(X.flatten(start_dim=1))

            loss = F.nll_loss(output, y)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % log_every == log_every - 1:  # print every `log_every` mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0

        # Note: this is the loss of the epoch's final batch only, not an epoch average.
        if loss is not None:
            print(f"Loss: {loss}\n")


PATH = 'nn_mnist.pth'


def _load_pretrained(path):
    """Load weights from ``path`` into ``net``; return True on success.

    A missing file is the normal "no checkpoint yet" case and returns False quietly.
    Any other failure (unreadable file, weights from a different architecture) is
    reported rather than silently swallowed, and also returns False so the caller
    re-trains and overwrites the checkpoint.
    """
    try:
        # NOTE(review): torch.load unpickles the file; only load checkpoints you created/trust.
        net.load_state_dict(torch.load(path))
    except FileNotFoundError:
        return False
    except Exception as err:
        print(f"Could not load {path!r} ({type(err).__name__}: {err}); re-training.")
        return False
    return True


# Short-circuit: with LOAD_MODEL False no load is attempted and we always re-train.
if not (LOAD_MODEL and _load_pretrained(PATH)):
    train()
    torch.save(net.state_dict(), PATH)
    

Epoch 0


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[1,  2000] loss: 0.583
[1,  4000] loss: 0.296
[1,  6000] loss: 0.234
[1,  8000] loss: 0.209
[1, 10000] loss: 0.187
[1, 12000] loss: 0.182
[1, 14000] loss: 0.159

Loss: 0.014889711514115334



In [42]:
## TEST


def test(model=None, loader=None):
    """Print and return the classification accuracy (in percent) of ``model`` on ``loader``.

    Args:
        model: classifier returning per-class scores of shape (N, n_classes);
            defaults to the global ``net`` (looked up at call time).
        loader: iterable of ``(X, y)`` batches; defaults to the global ``testset``.

    Returns:
        Accuracy in percent, or ``None`` when ``loader`` yields no samples.
    """
    model = net if model is None else model
    loader = testset if loader is None else loader

    # Evaluate in eval mode, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    total = 0
    correct = 0
    try:
        with torch.no_grad():
            for X, y in loader:
                # Flatten every sample to 1-D, e.g. (N, 1, 28, 28) -> (N, 784) for MNIST.
                output = model(X.flatten(start_dim=1))
                # Vectorised per-row argmax over the class dimension for the whole batch.
                correct += (output.argmax(dim=1) == y).sum().item()
                total += len(y)
    finally:
        model.train(was_training)

    if total == 0:  # avoid ZeroDivisionError on an empty loader
        print("Accuracy: n/a (no test samples)")
        return None

    accuracy = (correct / total) * 100
    print(f"Accuracy: {accuracy}")
    return accuracy


test_accuracy = test()

Accuracy: 95.58
