0. Imports and Random Seed setting

In [1]:
import torch
import torchvision
import numpy as np

torch.manual_seed(42)

<torch._C.Generator at 0x20af0597010>

1. Data Loading and Preperation

In [14]:
train_data = torchvision.datasets.MNIST(root = '../data',train = True,download= True)
test_data = torchvision.datasets.MNIST(root = '../data',train = False,download=True)
#Has been ran once to obtain and store data, now download paramter has been changed to False

x_train = train_data.data.float() / 255 #Normalising Greyscale Image Pixel Intesity
y_train = train_data.targets

x_test = test_data.data.float() / 255
y_test = test_data.targets

x_train = x_train.view(-1,28*28)
x_test = x_test.view(-1,28*28)  #Flattens the images to 2d, could use -1 or 60_000 and 10_000

def one_hot(labels,num_classes = 10):
    output = torch.zeros(labels.size(0),num_classes)
    row = torch.arange(labels.size(0))
    output[row,labels] = 1.0
    return output
#One hot encode label data so it can be usefully analysed

y_test_oh = one_hot(y_test,10)
y_train_oh = one_hot(y_train,10)

def get_batches(x,y,batch_size = 64):
    idx = torch.randperm(x.size(0))
    for i in range(0,x.size(0),batch_size):
        batches_idx = idx[i:i+batch_size]
        yield x[batches_idx],y[batches_idx]
        
        

2. Building and Training MLP without High Level help

In [23]:
def accuracy_of_pred(probs,y_oh):
    predicted = torch.argmax(probs,dim = 1)
    true_classes = torch.argmax(y_oh,dim = 1)
    correct = (predicted == true_classes).float()
    accuracy = correct.mean()
    return accuracy

In [33]:
w1 = torch.randn(28*28,64,requires_grad= True)
w2 = torch.randn(64,10,requires_grad= True)
b1 = torch.randn(64,requires_grad= True)
b2 = torch.randn(10,requires_grad= True)

epoch_n = 60
running_loss = 0.0
batch_count = 0
learning_rate = 0.01


for i in range(1,epoch_n+1):
    epoch_probs = []
    epoch_ys = []
    for x,y in get_batches(x_train,y_train_oh):
        z1 = x @ w1 + b1
        relu = torch.maximum(z1,torch.zeros_like(z1))
        z2 = relu @ w2 + b2
        
        probs = torch.softmax(z2,dim = 1)
        shift = 1e-10 #Prevent taking log of 0 choose small shift to minimally effect results
        class_probs = torch.sum(probs * y,dim= 1)
        cross_entropy = -torch.log(class_probs + shift)
        loss = cross_entropy.mean()
        
        epoch_probs.append(probs.detach())
        epoch_ys.append(y.detach())
        
        loss.backward()
        batch_count +=1
        running_loss += loss.item()
        
        with torch.no_grad():
            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad
            b1 -= learning_rate * b1.grad
            b2 -= learning_rate * b2.grad #Note to self must be done in place for requires_grad = True to stay
        
        w1.grad.zero_()
        w2.grad.zero_()
        b1.grad.zero_()
        b2.grad.zero_()
    
    
    print(f"Epoch {i}: avg loss = {running_loss / batch_count:.4f}")
    all_probs = torch.cat(epoch_probs,dim=0)
    all_ys = torch.cat(epoch_ys,dim = 0)
    print(f'Accuracy on Epoch {i} : {accuracy_of_pred(all_probs,all_ys)}')


Epoch 1: avg loss = 15.6582
Accuracy on Epoch 1 : 0.2758333384990692
Epoch 2: avg loss = 13.5688
Accuracy on Epoch 2 : 0.44795000553131104
Epoch 3: avg loss = 11.9998
Accuracy on Epoch 3 : 0.565416693687439
Epoch 4: avg loss = 10.6455
Accuracy on Epoch 4 : 0.659166693687439
Epoch 5: avg loss = 9.6828
Accuracy on Epoch 5 : 0.7016666531562805
Epoch 6: avg loss = 9.0084
Accuracy on Epoch 6 : 0.715149998664856
Epoch 7: avg loss = 8.5093
Accuracy on Epoch 7 : 0.721916675567627
Epoch 8: avg loss = 8.1235
Accuracy on Epoch 8 : 0.7268333435058594
Epoch 9: avg loss = 7.8148
Accuracy on Epoch 9 : 0.7305333614349365
Epoch 10: avg loss = 7.5620
Accuracy on Epoch 10 : 0.7340333461761475
Epoch 11: avg loss = 7.3503
Accuracy on Epoch 11 : 0.7364833354949951
Epoch 12: avg loss = 7.1703
Accuracy on Epoch 12 : 0.7383166551589966
Epoch 13: avg loss = 7.0148
Accuracy on Epoch 13 : 0.7401999831199646
Epoch 14: avg loss = 6.8789
Accuracy on Epoch 14 : 0.7421833276748657
Epoch 15: avg loss = 6.7592
Accuracy 

Evaluating on Test data

In [35]:
probs = []
ys= []

for x,y in get_batches(x_test,y_test_oh):
        z1 = x @ w1 + b1
        relu = torch.maximum(z1,torch.zeros_like(z1))
        z2 = relu @ w2 + b2
        preds = torch.softmax(z2,dim = 1)
        class_probs = torch.sum(preds * y,dim= 1)
        
        probs.append(preds.detach())
        ys.append(y.detach())
        
probs = torch.cat(probs,dim=0)
ys = torch.cat(ys,dim = 0)       
print(accuracy_of_pred(probs,ys).item())


0.9046000242233276
