In [1]:
from collections import OrderedDict

from torch import Generator, nn, optim, tensor, no_grad, randn
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms

In [2]:
# Load the MNIST training split from the local 'data' directory (download = True
# asks torchvision to fetch it when it is not already there).
# Each image is passed through ToTensor() when it is read from the dataset.
to_tensor = transforms.ToTensor()
mnist_data = datasets.MNIST(root = 'data', train = True, download = True, transform = to_tensor)

In [3]:
def get_abs_split(dataset, train_fraction = 0.8):
    """
    Returns the absolute number of items to split into (train, test) from the input dataset.

    Arguments
    ---------
    dataset : The dataset to split. Its size is taken from ``len(dataset)``, which is the
        size ``random_split`` checks the returned lengths against. Objects that do not
        implement ``__len__`` fall back to ``dataset.data.size()[0]``.

    Parameters
    ----------
    train_fraction : float (Default = 0.8)
        The fraction of data points that will be used for training. Must be within [0, 1].

    Returns
    -------
    train_size, test_size : Tuple (int, int)
        The absolute number of data points in train vs test. ``train_size`` is rounded
        down and ``test_size`` takes the remainder, so the two always sum to the dataset size.

    Raises
    ------
    ValueError
        If ``train_fraction`` is outside [0, 1] (this also rejects NaN).
    """
    # Fail early with a clear message instead of handing negative sizes to the caller.
    if not 0 <= train_fraction <= 1:
        raise ValueError(f'train_fraction must be within [0, 1], got {train_fraction!r}')

    try:
        size = len(dataset)
    except TypeError:
        # Backward compatibility: dataset-like objects exposing only a `.data` tensor.
        size = dataset.data.size()[0]

    train_size = int(size * train_fraction)
    test_size = size - train_size
    return train_size, test_size

In [4]:
# Split the data into train and validation subsets.
# A dedicated, seeded generator keeps the membership of each subset identical
# across kernel restarts, so losses from different runs are comparable.
train_val_split = 0.8
split_seed = 42
train, val = random_split(
    mnist_data,
    get_abs_split(mnist_data, train_val_split),
    generator = Generator().manual_seed(split_seed),
)

In [5]:
# Create data loaders for the train and validation data
train_loader_batch_size = 32
# shuffle = True re-draws the batch order every epoch; without it SGD sees the
# exact same sequence of mini-batches in each epoch.
train_loader = DataLoader(train, batch_size = train_loader_batch_size, shuffle = True)

# Validation only measures the loss, so a fixed order is fine here.
val_loader_batch_size = 32
val_loader = DataLoader(val, batch_size = val_loader_batch_size)

In [6]:
# Layer widths used to build the model
image_dims = mnist_data.data.size()
mnist_image_size = image_dims[1] * image_dims[2]  # product of dims 1 and 2 of the image tensor
linear1_dims = linear2_dims = 64
output_dims = 10 # predicting 10 digits

In [7]:
# A simple Sequential model, kept commented out for reference only (the notebook uses MNISTModel instead)

# model = nn.Sequential(OrderedDict([
#     ('linear1', nn.Linear(mnist_image_size, linear1_dims)),
#     ('relu1', nn.ReLU()),
#     ('linear2', nn.Linear(linear1_dims, linear2_dims)),
#     ('relu2', nn.ReLU()),
#     ('dropout', nn.Dropout(0.1)), # in case of overfitting
#     ('output', nn.Linear(linear2_dims, output_dims))
# ]))

In [8]:
# Define a more flexible model
class MNISTModel(nn.Module):
    """
    Multi-layer perceptron with two hidden layers and a residual connection
    around the second hidden layer.

    Parameters
    ----------
    input_dims : int
        Size of the last dimension of the input (the flattened image).
    linear1_dims : int
        Width of the first hidden layer.
    linear2_dims : int
        Width of the second hidden layer. May differ from linear1_dims; in that
        case the residual branch is projected to linear2_dims by a bias-free
        linear layer so the two branches can be added.
    output_dims : int
        Number of output logits.
    dropout_rate : float (Default = 0.1)
        Dropout probability applied after the residual sum.
    """
    def __init__(self, input_dims, linear1_dims, linear2_dims, output_dims, dropout_rate = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(input_dims, linear1_dims)
        self.linear2 = nn.Linear(linear1_dims, linear2_dims)
        self.linear3 = nn.Linear(linear2_dims, output_dims)
        self.dropout = nn.Dropout(dropout_rate)
        # Built after the main layers so their random initialisation order is unchanged.
        # nn.Identity has no parameters, so equal-width models keep the same state_dict.
        if linear1_dims == linear2_dims:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Linear(linear1_dims, linear2_dims, bias = False)

    def forward(self, x):
        """
        Computes the output logits for x.

        x must have input_dims as its last dimension. The returned logits have
        output_dims as their last dimension; no softmax is applied.
        """
        hidden1 = nn.functional.relu(self.linear1(x))
        hidden2 = nn.functional.relu(self.linear2(hidden1))
        # Residual connection: the shortcut matches hidden1 to hidden2's width.
        dropout = self.dropout(hidden2 + self.shortcut(hidden1))
        logits = self.linear3(dropout)
        return logits

In [10]:
# Instantiate the network with the layer widths chosen in the parameters cell
model = MNISTModel(
    input_dims = mnist_image_size,
    linear1_dims = linear1_dims,
    linear2_dims = linear2_dims,
    output_dims = output_dims,
)

In [11]:
# Stochastic gradient descent over all of the model's parameters
lr = 0.01
optimizer = optim.SGD(params = model.parameters(), lr = lr)

In [12]:
# Loss function: cross-entropy on the raw logits, averaged over the batch
loss = nn.CrossEntropyLoss(reduction = 'mean')

In [13]:
# Training loop: for every epoch, run one optimisation pass over train_loader and
# then measure the loss on val_loader without updating the model.
epochs = 5

for epoch in range(epochs):
    # ---- Training phase ----
    # Put the model in training mode so dropout is active while learning.
    model.train()
    train_losses = list()
    for batch in train_loader:
        x, y = batch
        batch_size = x.size(0)

        # flatten x : keep the batch dimension and collapse everything else into one axis
        x = x.view(batch_size, -1)

        # 1. Forward pass
        logits = model(x)

        # 2. Compute objective
        J = loss(logits, y)

        # 3. Clean the gradients. We are updating gradients per batch and not accumulating so clear them between each batch.
        model.zero_grad() # Note: Can also zero gradients from the optimizer. optimizer.zero_grad()

        # 4. Backward pass. Accumulate the partial derivatives.
        J.backward()

        # 5. Learn and update weights and biases
        optimizer.step()
        train_losses.append(J.item())
    print(f'Epoch {epoch + 1}, training loss: {tensor(train_losses).mean():.2f}')

    # ---- Validation phase ----
    # Switch to evaluation mode so dropout is disabled; otherwise the validation
    # loss is measured on a randomly thinned network and is not deterministic.
    # The model is left in evaluation mode once the last epoch finishes.
    model.eval()
    val_losses = list()
    # No gradients are needed for validation, so skip building the autograd graph
    # for both the forward pass and the loss.
    with no_grad():
        for batch in val_loader:
            x, y = batch
            batch_size = x.size(0)

            # flatten x : keep the batch dimension and collapse everything else into one axis
            x = x.view(batch_size, -1)

            # 1. Forward pass
            logits = model(x)

            # 2. Compute objective
            J = loss(logits, y)
            val_losses.append(J.item())
    print(f'Epoch {epoch + 1}, validation loss: {tensor(val_losses).mean():.2f}')


Epoch 1, training loss: 0.87
Epoch 1, validation loss: 0.45
Epoch 2, training loss: 0.39
Epoch 2, validation loss: 0.37
Epoch 3, training loss: 0.33
Epoch 3, validation loss: 0.32
Epoch 4, training loss: 0.29
Epoch 4, validation loss: 0.28
Epoch 5, training loss: 0.26
Epoch 5, validation loss: 0.26
