In [21]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

In [47]:
# Fully connected classifier: a flattened 28x28 image goes in, 10 logits come out,
# with three hidden layers of 64 units, each followed by a ReLU.
layer_widths = [28*28, 64, 64, 64, 10]

layers = []
for fan_in, fan_out in zip(layer_widths[:-1], layer_widths[1:]):
    layers.append(nn.Linear(fan_in, fan_out))
    layers.append(nn.ReLU())
layers.pop()  # the output layer produces raw logits, so drop the trailing ReLU

model = nn.Sequential(*layers)

In [48]:
# Plain stochastic gradient descent with a fixed step size of 0.01 over every
# parameter exposed by `model`.
# NOTE: `model.parameters()` returns a generator; once the optimizer has consumed
# it, `params` is exhausted and should not be iterated again.
params = model.parameters()
optimizer = optim.SGD(params, lr=0.01)

In [49]:
# Cross-entropy on raw logits vs. integer class labels, averaged over the batch
# ('mean' is the library default, written out here for clarity).
loss = nn.CrossEntropyLoss(reduction='mean')

In [50]:
# MNIST training split, stored in (and downloaded to, if missing) ./data.
# ToTensor() turns each image into a tensor so it can be batched by DataLoader.
train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())

# Hold out 5,000 samples for validation; the two lengths must add up to
# len(train_data) or random_split raises.
train, val = random_split(train_data, [55000, 5000])

# Reshuffle the training samples every epoch so SGD does not see the exact same
# sequence of mini-batches each time. Validation order is irrelevant to the
# reported metrics, so it stays unshuffled.
train_loader = DataLoader(train, batch_size=32, shuffle=True)
val_loader = DataLoader(val, batch_size=32)

In [51]:
# Train for a fixed number of epochs; after each epoch report the mean per-batch
# loss and accuracy on the training set, then on the validation set.
n_epochs = 5
for epoch in range(n_epochs):
    # ---------------- training pass ----------------
    losses = list()
    accuracies = list()
    # Put the network in training mode for the optimisation pass. (The original
    # called model.eval() here; that only goes unnoticed while the network has no
    # mode-dependent layers such as Dropout or BatchNorm.)
    model.train()
    for batch in train_loader:

        x, y = batch
        b = x.size(0)

        # flatten image: keep the batch dimension, collapse everything else
        x = x.view(b, -1)

        l = model(x)      # logits, one row per sample
        J = loss(l, y)    # scalar objective for this mini-batch

        # Clear gradients left over from the previous step before backprop,
        # otherwise they would accumulate.
        model.zero_grad()

        J.backward()

        optimizer.step()

        losses.append(J.item())
        # Fraction of samples in this batch whose arg-max logit equals the label.
        accuracies.append(y.eq(l.detach().argmax(dim=1)).float().mean())

    print(f'Epoch {epoch + 1}', end = ',')
    print(f'train loss:{torch.tensor(losses).mean():.2f}', end = ',')
    print(f'train accuracy:{torch.tensor(accuracies).mean():.2f}')

    # ---------------- validation pass ----------------
    losses = list()
    accuracies = list()
    # Inference mode for evaluation; no parameters are updated below.
    model.eval()
    # No gradients are needed for evaluation, so skip building the autograd
    # graph for both the forward pass and the loss.
    with torch.no_grad():
        for batch in val_loader:

            x, y = batch
            b = x.size(0)
            # flatten image
            x = x.view(b, -1)

            l = model(x)
            J = loss(l, y)

            losses.append(J.item())
            accuracies.append(y.eq(l.detach().argmax(dim=1)).float().mean())

    print(f'Epoch {epoch + 1}',end= ',')
    print(f'validation loss:{torch.tensor(losses).mean():.2f}',end = ',')
    print(f'validation accuracy: {torch.tensor(accuracies).mean():.2f}')

Epoch 1,train loss:1.83,train accuracy:0.43
Epoch 1,validation loss:0.74,validation accuracy: 0.80
Epoch 2,train loss:0.50,train accuracy:0.85
Epoch 2,validation loss:0.40,validation accuracy: 0.88
Epoch 3,train loss:0.37,train accuracy:0.89
Epoch 3,validation loss:0.33,validation accuracy: 0.91
Epoch 4,train loss:0.31,train accuracy:0.91
Epoch 4,validation loss:0.29,validation accuracy: 0.92
Epoch 5,train loss:0.27,train accuracy:0.92
Epoch 5,validation loss:0.26,validation accuracy: 0.93


In [57]:
# MNIST test split (train=False), read from / downloaded into the same 'data'
# directory and converted to tensors like the training set.
test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
test_data  # bare last expression so the notebook displays the dataset summary

Dataset MNIST
    Number of datapoints: 10000
    Root location: data
    Split: Test
    StandardTransform
Transform: ToTensor()

In [53]:
# Forward pass and loss on the batch currently bound to `x` / `y`.
# NOTE(review): `x` and `y` are not defined in this cell; they are whatever an
# earlier cell left in the kernel (the training cell's last validation batch,
# already flattened). Confirm that is the intended input.
x_hat = model(x)
# Score the predictions just computed. The original passed the stale logits `l`
# left over from the training cell, so `x_hat` was never actually evaluated.
J = loss(x_hat, y)