In [29]:
from scipy.io import loadmat
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [2]:
# Get the device for training the model
device = "cuda" if torch.cuda.is_available() else "cpu"
# uncomment line below to force cpu use for model
# device = "cpu"

print(f"Using device: {device}")

Using device: cuda


In [3]:
# Import the data
data = loadmat("mnist_lda.mat")

data_train = data["train_data"]
data_test = data["test_data"]
class_train = data["train_class"]
class_test = data["test_class"]

In [13]:
# define the neural netowrk
# here we use ReLu for the activation function
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(9, 28),
            nn.ReLU(),
            nn.Linear(28, 28),
            nn.ReLU(),
            nn.Linear(28, 10)
        )
    
    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = Network().to(device)

In [19]:
# hyperparams
n_epochs = 10
batch_size = 100
learning_rate = 0.001
input_size = 9

In [36]:
# create the training loop for the model
def train(tl, lr=learning_rate, n=n_epochs, inpsz=input_size, bs=batch_size, ml=model):
    # We use cross entropy for the error function
    criterion = nn.CrossEntropyLoss()
    # And we use stochastic gradiant descent since it is faster
    opttimizer = torch.optim.SGD(ml.parameters(), lr=lr)
    
    # a vector to keep track of the error
    err_tracker = np.zeros(len(tl) * n)
    
    # Training loop
    for epoch in range(n):
        for i, (images, labels) in enumerate(tl):
            sample = images.reshape(bs, -1).to(device)
            labels = labels.view(labels.shape[0], 1).to(device)
            
            # forward
            output = ml(sample)
            error = criterion(output, labels)
            err_tracker[i*epoch] = error.item()
            
            # backward
            optimizer.zero_grad()
            error.backward()
            optimizer.step()
    
    return model, err_tracker

In [24]:
# create the evaluation loop
def validate(tl, ml=model, bs=batch_size):
    with torch.no_grad():
        label_est = []
        label_true = []
        
        n_samples = 0
        n_diff = 0
        
        for images, labels in tl:
            sample = images.reshape(bs, -1).to(device)
            labels = labels.view(labels.shape[0], 1).to(device)
            
            outputs = ml(sample)
            
            label_est += outputs.tolist()
            label_true += labels.tolist()
            
            n_diff += torch.mean(torch.abs(outputs-labels))
            n_samples += 1
        
        acc = n_diff/n_samples
    
    return label_est, label_true, acc

In [41]:
# Class to make the data usable for the model
# This is a quirk of torch, and not something relevant for the theory of this exercise
class NumbersDataset(Dataset):
    def __init__(self, samples, labels):
        self.samples = torch.from_numpy(samples).to(torch.float32)
        self.labels = torch.from_numpy(labels).to(torch.float32)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

In [42]:
# ready the data for the model
data_train = NumbersDataset(data_train, class_train)
data_test = NumbersDataset(data_test, class_test)

train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(data_test, batch_size=batch_size, shuffle=True)

ValueError: given numpy array has byte order different from the native byte order. Conversion between byte orders is currently not supported.

In [37]:
# Train the data
model_trained, error = train(train_loader)

ValueError: given numpy array has byte order different from the native byte order. Conversion between byte orders is currently not supported.