In [148]:
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.multiprocessing as mp
from torchinfo import summary

# Fix random seeds up front so weight init, DataLoader shuffling, dropout and
# augmentation are repeatable under Restart & Run All. (Bitwise-identical GPU
# runs may additionally need torch.use_deterministic_algorithms.)
SEED = 42
torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)

In [149]:
# EMNIST "digits" split, converted to float tensors with ToTensor.
# Both splits are downloaded (if needed) into the local "Torch MNIST" folder.
def _load_emnist_digits(train):
    """Return the EMNIST digits train (train=True) or test (train=False) split."""
    return datasets.EMNIST(
        root="Torch MNIST",
        split="digits",
        train=train,
        download=True,
        transform=transforms.ToTensor(),
    )


training_data = _load_emnist_digits(train=True)
test_data = _load_emnist_digits(train=False)

In [150]:
# Very large batches: the whole training split is covered in only a few
# optimiser steps per epoch (the training log below shows 4 per epoch).
batch_size = 2**16
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
# The model is evaluated in eval mode (BatchNorm uses running statistics), so
# sample order cannot change the metric; shuffling the test split only adds
# nondeterminism and RNG consumption.
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [151]:
# Prefer an accelerator when one is present: CUDA first, then Apple MPS,
# falling back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [152]:
class AddGaussianNoise(nn.Module):
    """Add Gaussian noise to a tensor: ``tensor + randn * std + mean``.

    Noise is drawn on the same device as the input, so the module works for
    CPU and accelerator tensors alike without relying on a global ``device``.

    Args:
        mean: constant offset added to every element.
        std: scale applied to the standard-normal noise.
    """

    def __init__(self, mean=0., std=1., *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.std = std
        self.mean = mean

    def forward(self, tensor):
        # Follow the input's device rather than the notebook-global `device`;
        # the global breaks for inputs living elsewhere (device mismatch) and
        # makes the class unusable outside this notebook (NameError).
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [153]:
# Train-time augmentation used by NeuralNetwork.forward: a random affine warp
# (rotation within ±15°, shift up to 10% of width/height, scale 0.9–1.1)
# followed by additive Gaussian noise with std 0.1.
# NOTE(review): RandomAffine samples ONE set of affine parameters per call, so
# every image in a batch gets the same warp; with batches of 2**16 this yields
# very few distinct augmentations per epoch. Consider per-sample augmentation
# (e.g. in the dataset transform) if that matters.
transform_sequence = nn.Sequential(
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    AddGaussianNoise(mean=0.0, std=0.1),
)

In [154]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier producing 10 logits from 1x28x28 images.

    Layout: flatten, then four stages of
    ``dropout(0.1) -> linear -> LeakyReLU -> BatchNorm1d`` with widths
    784 -> 1024 -> 1024 -> 256 -> 64, then a final linear layer to 10 outputs.
    While in training mode the input is first passed through the global
    ``transform_sequence`` augmentation.
    """

    def __init__(self):
        super().__init__()
        widths = [28 * 28 * 1, 1024, 1024, 256, 64]
        layers = [("flatten", nn.Flatten())]
        # Layer names (dropoutN, linearN, actN, batch_normN) are kept so the
        # state_dict keys are stable across checkpoints.
        for stage, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            layers.append((f"dropout{stage}", nn.Dropout(0.1)))
            layers.append((f"linear{stage}", nn.Linear(fan_in, fan_out)))
            layers.append((f"act{stage}", nn.LeakyReLU()))
            layers.append((f"batch_norm{stage}", nn.BatchNorm1d(fan_out)))
        layers.append(("linear5", nn.Linear(widths[-1], 10)))
        self.linear_stack = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        """Return class logits; augment the batch first when training."""
        inputs = transform_sequence(x) if self.training else x
        return self.linear_stack(inputs)

In [155]:
# Instantiate the network, move it to the selected device, and display a
# layer-by-layer summary for a training-sized input batch.
model = NeuralNetwork()
model = model.to(device)
summary(model, input_size=(batch_size, 1, 28, 28))

Layer (type:depth-idx)                   Output Shape              Param #
NeuralNetwork                            [65536, 10]               --
├─Sequential: 1-1                        [65536, 10]               --
│    └─Flatten: 2-1                      [65536, 784]              --
│    └─Dropout: 2-2                      [65536, 784]              --
│    └─Linear: 2-3                       [65536, 1024]             803,840
│    └─LeakyReLU: 2-4                    [65536, 1024]             --
│    └─BatchNorm1d: 2-5                  [65536, 1024]             2,048
│    └─Dropout: 2-6                      [65536, 1024]             --
│    └─Linear: 2-7                       [65536, 1024]             1,049,600
│    └─LeakyReLU: 2-8                    [65536, 1024]             --
│    └─BatchNorm1d: 2-9                  [65536, 1024]             2,048
│    └─Dropout: 2-10                     [65536, 1024]             --
│    └─Linear: 2-11                      [65536, 256]              

In [161]:
# Train with SGD + momentum, printing the loss averaged over every LOG_EVERY
# batches.
NUM_EPOCHS = 50
LOG_EVERY = 1  # 1 => print the loss of every batch

loss_fn = nn.CrossEntropyLoss()
Optimiser = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model.train(True)

for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for batch_num, (data, labels) in enumerate(train_dataloader):
        data = data.to(device)
        # CrossEntropyLoss accepts integer class indices directly; one-hot
        # probability targets give the same mean loss but waste memory.
        labels = labels.to(device)

        Optimiser.zero_grad()
        prediction = model(data)
        loss = loss_fn(prediction, labels)
        loss.backward()
        Optimiser.step()

        running_loss += loss.item()
        if batch_num % LOG_EVERY == LOG_EVERY - 1:
            print(f'[{epoch + 1}, {batch_num + 1:5d}] loss: {running_loss / LOG_EVERY:.3f}')
            running_loss = 0.0

[1,     1] loss: 0.041
[1,     2] loss: 0.036
[1,     3] loss: 0.031
[1,     4] loss: 0.032
[2,     1] loss: 0.045
[2,     2] loss: 0.031
[2,     3] loss: 0.032
[2,     4] loss: 0.030
[3,     1] loss: 0.034
[3,     2] loss: 0.054
[3,     3] loss: 0.036
[3,     4] loss: 0.032
[4,     1] loss: 0.034
[4,     2] loss: 0.035
[4,     3] loss: 0.034
[4,     4] loss: 0.045
[5,     1] loss: 0.029
[5,     2] loss: 0.028
[5,     3] loss: 0.039
[5,     4] loss: 0.039
[6,     1] loss: 0.045
[6,     2] loss: 0.050
[6,     3] loss: 0.060
[6,     4] loss: 0.045
[7,     1] loss: 0.071
[7,     2] loss: 0.042
[7,     3] loss: 0.029
[7,     4] loss: 0.032
[8,     1] loss: 0.049
[8,     2] loss: 0.028
[8,     3] loss: 0.025
[8,     4] loss: 0.044
[9,     1] loss: 0.044
[9,     2] loss: 0.055
[9,     3] loss: 0.047
[9,     4] loss: 0.055
[10,     1] loss: 0.021
[10,     2] loss: 0.067
[10,     3] loss: 0.033
[10,     4] loss: 0.048
[11,     1] loss: 0.049
[11,     2] loss: 0.028
[11,     3] loss: 0.031
[11,

In [162]:
# Evaluate on the held-out test split: count arg-max predictions that match the
# label and report the error rate (fraction of misclassified test images).
total = 0
correct = 0
model.eval()  # dropout off, BatchNorm uses running stats, no augmentation

with torch.no_grad():
    for data, labels in test_dataloader:
        data = data.to(device)
        labels = labels.to(device)
        predicted = model(data).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    raise RuntimeError("test_dataloader yielded no samples; cannot compute an error rate")

# (total - correct) / total avoids the float artifact of 1 - correct / total.
error_rate = (total - correct) / total
print(f"test error rate: {error_rate:.4%} ({total - correct}/{total} misclassified)")

0.005249999999999977
