In [None]:
# MNIST classifier in PyTorch, built on a custom Dataset subclass.
# torchvision.datasets already ships a built-in MNIST Dataset; this notebook re-implements one for practice.
# The MNIST data in CSV form was obtained from https://github.com/phoebetronic/mnist

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import random

In [None]:
class MyDataset(Dataset):
    """Map-style dataset backed by a header-less CSV file.

    Each CSV row is read as one sample: column 0 holds the label and the
    remaining columns hold the flattened image (784 values, since they are
    reshaped to 1x28x28).
    """

    def __init__(self, csv_file):
        """Load the whole CSV into memory as a DataFrame (no header row)."""
        self.df = pd.read_csv(csv_file, sep=",", header=None)

    def __len__(self):
        """Number of samples, i.e. number of CSV rows."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``.

        ``image`` is the row's pixel columns reshaped to (1, 28, 28) and
        divided by 255 (presumably mapping 0-255 intensities onto 0-1);
        ``label`` is column 0 wrapped in a 0-dim tensor.
        """
        row = self.df.iloc[index]
        # everything after the label column is pixel data
        pixels = torch.from_numpy(row[1:].to_numpy()).reshape(1, 28, 28)
        return pixels / 255, torch.tensor(row[0])

In [None]:
# Training split: wrap the CSV in the custom Dataset, then serve it in shuffled batches of 100
train_dataset = MyDataset(csv_file="mnist_train.csv")
train_dataloader = DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)

# Test split: same treatment
test_dataset = MyDataset(csv_file="mnist_test.csv")
test_dataloader = DataLoader(dataset=test_dataset, batch_size=100, shuffle=True)

In [None]:
# http://machinelearningmastery.com/building-a-convolutional-neural-network-in-pytorch/
# with amendments made to fit my 1 channel, 28x28 pixel images

class CIFAR10Model(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images, 10 classes.

    Layout: conv -> ReLU -> dropout -> conv -> ReLU -> max-pool -> flatten
    -> linear -> ReLU -> dropout -> linear. The forward pass returns raw
    scores (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 28 channels; a 3x3 kernel with padding 1 keeps the 28x28 extent
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=28, kernel_size=(3, 3), stride=1, padding=1)
        self.act1 = nn.ReLU()
        self.drop1 = nn.Dropout(p=0.3)

        # Stage 2: 28 -> 28 channels, then 2x2 max-pooling halves each spatial side
        self.conv2 = nn.Conv2d(in_channels=28, out_channels=28, kernel_size=(3, 3), stride=1, padding=1)
        self.act2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))

        self.flat = nn.Flatten()

        # Stage 3: fully connected layer on the flattened 28 * 14 * 14 = 5488 features
        self.fc3 = nn.Linear(in_features=28 * 14 * 14, out_features=512)
        self.act3 = nn.ReLU()
        self.drop3 = nn.Dropout(p=0.5)

        # Stage 4: one output score per class
        self.fc4 = nn.Linear(in_features=512, out_features=10)

    def forward(self, x):
        """Map a batch of 1x28x28 images to a batch of 10 class scores."""
        # 1x28x28 -> 28x28x28
        features = self.drop1(self.act1(self.conv1(x)))

        # 28x28x28 -> 28x28x28 -> 28x14x14
        features = self.pool2(self.act2(self.conv2(features)))

        # 28x14x14 -> 5488 -> 512
        hidden = self.drop3(self.act3(self.fc3(self.flat(features))))

        # 512 -> 10
        return self.fc4(hidden)

In [None]:
# Hyperparameters are declared first so the objects below are actually built from them
# (previously the optimizer hard-coded its own lr and `learning_rate` was never used).
learning_rate = 1e-3
epochs = 20
batch_size = 100  # should match the batch_size the DataLoaders were constructed with

model = CIFAR10Model()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [None]:
# https://docs.pytorch.org/tutorials/beginner/basics/optimization_tutorial.html

def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass (one epoch) over ``dataloader``.

    For every batch: forward pass, loss, gradient reset, backward pass and
    optimizer step. Every 100th batch the current loss and the number of
    samples processed so far are printed.

    Args:
        dataloader: iterable of ``(X, y)`` batches; ``dataloader.dataset``
            must support ``len()``.
        model: module to train; it is switched to training mode.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        optimizer: optimizer already bound to ``model``'s parameters.

    Returns:
        None. Progress is reported on stdout only.
    """
    size = len(dataloader.dataset)
    # Count samples from the batches themselves rather than from a global
    # ``batch_size``: no hidden dependency on another cell, and the figure
    # stays correct when the final batch is smaller than the rest.
    samples_seen = 0
    model.train()
    for batch_num, (X, y) in enumerate(dataloader):
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        samples_seen += len(X)
        if (batch_num + 1) % 100 == 0:
            print(f"loss: {loss.item():>7f}  [{samples_seen:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_of_batches = len(dataloader)
    test_loss, correct = 0, 0
    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_of_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")    

In [None]:
# Each epoch is one optimisation pass over the training data
# followed by one evaluation pass over the test data.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(
        dataloader=train_dataloader,
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
    )
    test_loop(
        dataloader=test_dataloader,
        model=model,
        loss_fn=loss_fn,
    )
print("Done!")

In [None]:
# Sanity check: display a few random test images together with the model's prediction

# Inference only: make sure dropout is disabled here instead of relying on
# whichever function happened to touch the model last.
model.eval()

# Draw from the dataset's actual size rather than a hard-coded 10000
num_examples = min(10, len(test_dataset))
indices = random.sample(range(len(test_dataset)), num_examples)

for i in indices:
    image, label = test_dataset[i]
    # add a leading batch dimension: the model expects (batch, channel, height, width)
    image = torch.reshape(image, (1, 1, 28, 28))
    with torch.no_grad():
        # .item() turns the one-element tensor into a plain int so the message reads "7", not "tensor([7])"
        pred = model(image).argmax(1).item()
    # "gray" is the portable colormap name; NOTE(review): the "grey" spelling appears to be
    # accepted only as an alias by newer matplotlib releases - confirm before reverting
    plt.imshow(image[0,0,:,:], cmap="gray")
    plt.title(f"Image {i} from the MNIST test dataset")
    plt.show()
    print(f"Model suggests this should show {pred}.\n")
    

In [None]:
# Persist only the model's state_dict to disk (not the module object itself)
torch.save(obj=model.state_dict(), f='model_weights.pth')

In [None]:
# Round-trip check: read the saved parameters back and load them into a freshly built network
restored_weights = torch.load('model_weights.pth', weights_only=True)
model = CIFAR10Model()
model.load_state_dict(restored_weights)
# switch to evaluation mode; as the cell's last expression this also displays the module
model.eval()

In [None]:
# All done!