In [1]:
'''*** Importing Libraries ***'''
import torch as t
from torch import nn
import numpy as np
import gzip
import pickle   

In [2]:
'''Utility Functions'''

def to_numpy(x: t.Tensor) -> np.ndarray:
    return x.detach().cpu().numpy()

def vectorized_result(j: int) -> np.ndarray:
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e

def load_mnist():
    with gzip.open('./data/mnist.pkl.gz', 'rb') as f:
        tr_d, va_d, te_d = pickle.load(f, encoding='latin1')
    
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = zip(training_inputs, training_results)
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = zip(validation_inputs, va_d[1])
    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = zip(test_inputs, te_d[1])
    return (training_data, validation_data, test_data)


In [3]:
class ClassifierNetwork(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        layers = []
        layers.append(nn.Linear(input_size, hidden_size))
        layers.append(nn.Sigmoid())
        layers.append(nn.Linear(hidden_size, hidden_size))
        layers.append(nn.Sigmoid())
        layers.append(nn.Linear(hidden_size, output_size))
        layers.append(nn.Sigmoid())
        self.classifier = nn.Sequential(*layers)
    
    def forward(self, T: t.Tensor) -> t.Tensor:
        return self.classifier(T) 

In [4]:
class NumberClassifier:

    def __init__(self, input_size, hidden_size, output_size, lr = 1e-3):
        self.device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
        self.model = ClassifierNetwork(input_size, hidden_size, output_size).to(self.device)
        self.loss_fn = nn.MSELoss()
        self.optimizer = t.optim.SGD(self.model.parameters(), lr = lr)
    
    def train(self, training_data, epochs = 100, batch_size = 32, test_data = None):
        epoch_loss = []
        training_data = list(training_data)

        if test_data:
            test_data = list(test_data)

        for epoch in range(epochs):
            mini_batch_loss = []
            np.random.shuffle(training_data)
            mini_batches = [training_data[i:i+batch_size] for i in range(0, len(training_data), batch_size)]
            for mini_batch in mini_batches:
                x = [x[0] for x in mini_batch]
                y = [x[1] for x in mini_batch]

                x = t.tensor(x).squeeze().to(self.device)
                y = t.tensor(y).squeeze().to(self.device)

                mini_batch_loss.append(self.update(x, y))

            epoch_loss.append(np.mean(mini_batch_loss))
            print(f'---------------Epoch {epoch}------------------')
            print(f'Loss: {epoch_loss[-1]}')
            if test_data:
                x_test = [x[0] for x in test_data]
                y_test = [x[1] for x in test_data]
                x_test = t.tensor(x_test).squeeze().to(self.device)
                y_test = t.tensor(y_test).squeeze()
                print(f'Test Accuracy: {self.evaluate(x_test, y_test)} / {len(test_data)}')

        return epoch_loss
    
    def update(self, x: t.Tensor, y: t.Tensor):
        self.optimizer.zero_grad()
        pred = self.model(x).to(t.double)
        loss = self.loss_fn(pred, y)
        loss.backward()
        self.optimizer.step()
        return to_numpy(loss)

    def evaluate(self, x: t.Tensor, y: t.Tensor):
        pred = self.predict(x)
        picks = np.argmax(pred, axis = 1) 
        logical = [int(x == y) for x,y in zip(picks, y)]
        return sum(logical)
    
    def predict(self, x: t.Tensor):
        return to_numpy(self.model(x))


            
                

In [5]:
training_data, validation_data, test_data = load_mnist()

classifier = NumberClassifier(784, 30, 10, lr=1)

loss = classifier.train(training_data, test_data=test_data)

  x = t.tensor(x).squeeze().to(self.device)


---------------Epoch 0------------------
Loss: 0.09035380888344269
Test Accuracy: 2937 / 10000
---------------Epoch 1------------------
Loss: 0.08781168427888705
Test Accuracy: 2652 / 10000


KeyboardInterrupt: 