In [111]:


## Import
import matplotlib.pyplot as plt
#%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision ## Contains some utilities for working with the image data
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST



In [112]:
# Seed PyTorch's global RNG so that steps drawing from it after this cell
# give the same results on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# download = True fetches the dataset into `root` when it is not there yet and
# reuses the local copy otherwise, so the cell also works on a fresh machine.
data = MNIST(root = './', train = True, download = True, transform = transforms.ToTensor())

# Hold out 10,000 samples for validation (the two lengths must add up to len(data)).
# A dedicated seeded generator keeps the split identical across runs, independent of
# how much of the global RNG stream has already been consumed.
train_data, validation_data = random_split(
    data, [50000, 10000], generator = torch.Generator().manual_seed(SEED)
)

In [113]:
# Number of samples per mini-batch, shared by both loaders.
batch_size = 128

# Training batches are drawn in shuffled order; validation batches keep dataset order.
train_loader = DataLoader(train_data, batch_size = batch_size, shuffle = True)
val_loader = DataLoader(validation_data, batch_size = batch_size, shuffle = False)

In [118]:
class MnistModel(nn.Module):
    """Single linear layer classifier trained with cross-entropy loss.

    The model flattens every input sample to a vector of ``input_size`` values
    and maps it to ``num_classes`` raw scores (logits).

    Args:
        input_size: Number of features per flattened sample. Defaults to 784,
            the width the forward pass previously hard-coded.
        num_classes: Number of output scores per sample. Defaults to 10.
    """

    def __init__(self, input_size = 784, num_classes = 10):
        super().__init__()
        # Sizes are constructor arguments with defaults instead of module-level
        # globals, so the class can be instantiated on a fresh kernel.
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        """Flatten ``xb`` to ``(-1, in_features)`` and return the linear layer's output."""
        # Use the layer's own input width so the reshape can never disagree
        # with the size the layer was built with.
        xb = xb.reshape(-1, self.linear.in_features)
        return self.linear(xb)

    def training_step(self, batch):
        """Return the cross-entropy loss for one ``(images, labels)`` batch."""
        images, labels = batch
        out = self(images)  ## Generate predictions
        return F.cross_entropy(out, labels)  ## Calculate the loss

    def validation_step(self, batch):
        """Return ``{'val_loss', 'val_acc'}`` for one ``(images, labels)`` batch.

        ``val_acc`` is computed by the module-level ``accuracy`` helper.
        """
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average per-batch metrics into one ``{'val_loss', 'val_acc'}`` dict of floats.

        Args:
            outputs: Sequence of dicts as produced by ``validation_step``; the
                values must be tensors that ``torch.stack`` can combine.
        """
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()
        # .item() converts the 0-dim tensors to plain Python floats.
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print the validation loss and accuracy stored in ``result`` for ``epoch``."""
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))
    
# Instantiate the classifier; this is the object later handed to fit().
model = MnistModel()

In [123]:
def evaluate(model, val_loader):
    """Run the model over every validation batch and aggregate the results.

    Args:
        model: Object providing ``validation_step(batch)`` and
            ``validation_epoch_end(outputs)``.
        val_loader: Iterable of validation batches.

    Returns:
        Whatever ``model.validation_epoch_end`` returns for the list of
        per-batch ``validation_step`` results.
    """
    # Nothing here calls backward(), so disable gradient tracking: the forward
    # passes then skip building autograd graphs that would never be used.
    with torch.no_grad():
        outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        outputs: Tensor of scores; the prediction for each row is the index of
            its maximum along ``dim = 1``.
        labels: Tensor of target indices compared element-wise with the predictions.

    Returns:
        A 0-dim tensor holding ``correct / number_of_predictions``.
    """
    predicted = torch.max(outputs, dim = 1).indices
    num_correct = torch.sum(predicted == labels).item()
    fraction = num_correct / len(predicted)
    return torch.tensor(fraction)
def fit(epochs, lr, model, train_loader, val_loader, opt_func = torch.optim.SGD):
    """Train ``model`` for ``epochs`` passes, validating after every pass.

    Args:
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate, passed to ``opt_func`` as its second positional argument.
        model: Object providing ``parameters()``, ``training_step(batch)`` and
            ``epoch_end(epoch, result)``.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches, forwarded to ``evaluate``.
        opt_func: Optimizer factory, called as ``opt_func(model.parameters(), lr)``.

    Returns:
        List with one ``evaluate`` result per epoch, in epoch order.
    """
    optimizer = opt_func(model.parameters(), lr)
    metrics_per_epoch = []

    for epoch in range(epochs):
        ## Training phase: one optimizer update per batch
        for train_batch in train_loader:
            model.training_step(train_batch).backward()
            optimizer.step()
            # Clear the gradients so the next batch starts from zero.
            optimizer.zero_grad()

        ## Validation phase: measure, report, record
        epoch_metrics = evaluate(model, val_loader)
        model.epoch_end(epoch, epoch_metrics)
        metrics_per_epoch.append(epoch_metrics)

    return metrics_per_epoch

In [124]:
# Train for 10 epochs with learning rate 1e-4 (default optimizer) and keep the
# list of per-epoch validation metrics that fit() returns.
result = fit(
    epochs = 10,
    lr = 1e-4,
    model = model,
    train_loader = train_loader,
    val_loader = val_loader,
)

Epoch [0], val_loss: 2.1988, val_acc: 0.2633
Epoch [1], val_loss: 2.1604, val_acc: 0.3349
Epoch [2], val_loss: 2.1235, val_acc: 0.3951
Epoch [3], val_loss: 2.0878, val_acc: 0.4468
Epoch [4], val_loss: 2.0534, val_acc: 0.4902
Epoch [5], val_loss: 2.0200, val_acc: 0.5259
Epoch [6], val_loss: 1.9877, val_acc: 0.5571
Epoch [7], val_loss: 1.9563, val_acc: 0.5813
Epoch [8], val_loss: 1.9259, val_acc: 0.6030
Epoch [9], val_loss: 1.8964, val_acc: 0.6263


In [125]:
# Show the raw per-epoch validation metrics returned by fit().
print(result)

[{'val_loss': 2.198805332183838, 'val_acc': 0.26325157284736633}, {'val_loss': 2.1604018211364746, 'val_acc': 0.3349485695362091}, {'val_loss': 2.123467445373535, 'val_acc': 0.3950751721858978}, {'val_loss': 2.087832450866699, 'val_acc': 0.44679588079452515}, {'val_loss': 2.053379774093628, 'val_acc': 0.4902096390724182}, {'val_loss': 2.0200271606445312, 'val_acc': 0.5259097814559937}, {'val_loss': 1.9876917600631714, 'val_acc': 0.5570608973503113}, {'val_loss': 1.956333041191101, 'val_acc': 0.581289529800415}, {'val_loss': 1.9259015321731567, 'val_acc': 0.6030458807945251}, {'val_loss': 1.8963555097579956, 'val_acc': 0.6262856125831604}]
