In [1]:
import torch
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [2]:
# hyperparameters

INPUT_NODES = 28 * 28
OUTPUT_NODES = 10
LEARNING_RATE = 1e-5
BATCH_SIZE = 32

In [3]:
training_data = MNIST(
    root="data",
    train=True,
    download=True,
    transform=T.ToTensor()
)

test_data = MNIST(
    root="data",
    train=False,
    download=True,
    transform=T.ToTensor()
)

In [240]:
class LinearLayer:
    def __init__(self, shape):
        # initialize weights with xavier initialization
        self.weights = torch.randn(shape) * torch.sqrt(torch.tensor(2 / (shape[0] + shape[1])))
        self.bias = torch.randn(shape[1])
        
    def __call__(self, input):
        self.input = input
        out = input @ self.weights + self.bias
        return out
    
    def backward(self, doutput):
        # --- comments are an example for linear layer with 100 input, 10 output nodes & BATCH_SIZE = 32 ---
        
        # doutput.shape [32, 1, 10]
        # input.shape [32, 1, 100]  => .T [100, 1, 32]
        # weights.shape [100, 10]
        self.weights.grad = self.input.squeeze(1).T @ doutput.squeeze(1) # [100, 32] @ [32, 10] = [100, 10]
        self.bias.grad = doutput.sum(0).squeeze(0) # doutput.shape [32, 1, 10] => .sum(0).squeeze(0) [10]
        
        dinput = doutput @ self.weights.T # [32, 1, 10] @ [10, 100] = [32, 1, 100]
        return dinput

    def parameters(self):
        return [self.weights, self.bias]

In [241]:
class TanH:
    def __call__(self, input):
        self.output = (torch.exp(input) - torch.exp(-input)) / (torch.exp(input) + torch.exp(-input))
        return self.output
    
    def backward(self, doutput):
        # d/dx (tanh(x)) = 1 - tanh²(x)
        dinput = (1 - self.output**2) * doutput
        return dinput
        
    def parameters(self):
        return []

In [242]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [245]:
class Model:
    def __init__(self):
        self.layers = [
            # (batch_size, 1, 784)
            LinearLayer((INPUT_NODES, 100)), # (batch_size, 1, hidden_nodes)
            TanH(),
            LinearLayer((100, OUTPUT_NODES)) # (batch_size, 1, 10)
        ]

        self.parameters = [p for layer in self.layers for p in layer.parameters()]
        
        for parameter in self.parameters:
            parameter.requires_grad = True
            parameter.retain_grad()

    def __call__(self, input):
        output = input
        for layer in self.layers:
            output = layer(output)

        return self.softmax(output)

    def train(self, epochs, train_data):
        print("------- START TRAINING -------\n")
        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE)

        for epoch in range(epochs):
            batch_losses = []

            for batch in train_loader:
                images, labels = batch

                # flatten imgs so that they can be passed to NN
                flattened_imgs = images.flatten(-2) # (batch_size, 1, 28, 28) => (batch_size, 1, 784)

                # get predictions from NN
                preds = flattened_imgs
                for layer in self.layers:
                    preds = layer(preds)

                # get loss for each output node
                loss, dlogits = self.cross_entropy_loss(preds.squeeze(1), labels)
                dlogits.retain_grad()
                
                # manual backprop through layers of model
                doutput = dlogits.unsqueeze(1)
                for layer in reversed(self.layers):
                    doutput = layer.backward(doutput)
                    
                # loss.backward()
                
                # update weights based on gradients
                for layer in self.layers:
                    for parameter in layer.parameters():
                        parameter.data = parameter.data - parameter.grad * LEARNING_RATE
                        parameter.grad.zero_()

                batch_losses.append(loss.item())

            epoch_loss = sum(batch_losses) / len(batch_losses)
            print(f"Epoch {epoch + 1} / {epochs}: tr_loss: {epoch_loss}")

    def softmax(self, logits):
        # Apply softmax to the predictions 
        norm_logits = logits - torch.max(logits, dim=1, keepdim=True)[0] # subtract max for numerical stability
        soft_preds = norm_logits.exp() / norm_logits.exp().sum(1, keepdim=True) # softmax formula
        
        return soft_preds
    
    def cross_entropy_loss(self, logits, labels):
        # predictions.shape => [batch_size, 1, 10]
        # labels.shape => [batch_size]
        soft_preds = self.softmax(logits)
        log_preds = soft_preds.log()
        loss = -log_preds[range(BATCH_SIZE), labels].mean()
        
        # manual backprop
        dlogits = soft_preds.clone()
        dlogits[range(BATCH_SIZE), labels] -= 1
        dlogits /= BATCH_SIZE
        
        return loss, dlogits

In [246]:
model = Model()
model.train(5, training_data)

------- START TRAINING -------

Epoch 1 / 5: tr_loss: 2.5868744150797527
Epoch 2 / 5: tr_loss: 2.511892426172892
Epoch 3 / 5: tr_loss: 2.4491195699055988
Epoch 4 / 5: tr_loss: 2.3963043631871543
Epoch 5 / 5: tr_loss: 2.3515102249145508


In [220]:
torch.randn((100, 1, 32)).squeeze(1).shape

torch.Size([100, 32])