In [1]:
import torch
from torch import optim
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

In [2]:
# Prepare data and dataloader
train = datasets.MNIST("data", train = True, download = True, transform = transforms.ToTensor())
train, valid = random_split(train, [55000, 5000]) # Perform data split
train_dl = DataLoader(train, batch_size = 32) # Train dataloader
valid_dl = DataLoader(valid, batch_size = 32) # Valid dataloader

In [3]:
# Create basic model
basic_model = nn.Sequential(
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 10)
)

In [4]:
# Flexy model
class FlexyModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.l1 = nn.Linear(28 * 28, 64)
    self.l2 = nn.Linear(64, 64)
    self.l3 = nn.Linear(64, 10)
    self.dropo = nn.Dropout(0.1)

  def forward(self, x):
    h1 = nn.functional.relu(self.l1(x))
    h2 = nn.functional.relu(self.l2(h1))

    # Here is the residual connection
    dropo = self.dropo(h2 + h1)

    logits = self.l3(dropo)
    return logits
  
flexy_model = FlexyModel()

In [5]:
def training(model):

  # Create optimiser
  optimiser = optim.SGD(model.parameters(), lr = 1e-2)

  # Create loss function CEL
  loss = nn.CrossEntropyLoss()
  n_epochs = 7

  for epoch in range(n_epochs):
    losses = list()
    accuracies = list()
    for batch in train_dl:
      # x = image
      x, y = batch

      # Each batch is a 28 x 28 image * number of images
      # i.e. number of images * 28 x 28
      b = x.size(0)
      x = x.view(b, -1)

      # Forward step
      logits = model(x)

      # Objective function
      # - calculate training loss using logits and actual
      obj = loss(logits, y)

      # Clean gradient
      model.zero_grad()

      # Accumulate partial derivs of obj w.r.t params
      obj.backward()

      # Updates! I.e. step opposite direction of grads
      optimiser.step()

      losses.append(obj.item())

    print("training loss:", torch.tensor(losses).mean())

    # REPEAT FOR VALIDATION, not as many steps!
    losses = list()
    model.eval()
    for batch in valid_dl:
      x, y = batch

      b = x.size(0)
      x = x.view(b, -1)

      # Note no_grad here (i.e. don't keep tracing gradients, graphs...)
      with torch.no_grad():
        logits = model(x)

      # Calculate validation loss
        obj = loss(logits, y)

      losses.append(obj.item())

    print("valid loss:", torch.tensor(losses).mean())

In [6]:
training(basic_model)

training loss: tensor(1.2649)
valid loss: tensor(0.4831)
training loss: tensor(0.3890)
valid loss: tensor(0.3399)
training loss: tensor(0.3155)
valid loss: tensor(0.2927)
training loss: tensor(0.2794)
valid loss: tensor(0.2628)
training loss: tensor(0.2531)
valid loss: tensor(0.2410)
training loss: tensor(0.2321)
valid loss: tensor(0.2231)
training loss: tensor(0.2145)
valid loss: tensor(0.2083)


In [7]:
training(flexy_model)

training loss: tensor(0.8328)
valid loss: tensor(0.3870)
training loss: tensor(0.3370)
valid loss: tensor(0.3029)
training loss: tensor(0.2841)
valid loss: tensor(0.2634)
training loss: tensor(0.2498)
valid loss: tensor(0.2343)
training loss: tensor(0.2228)
valid loss: tensor(0.2110)
training loss: tensor(0.2009)
valid loss: tensor(0.1920)
training loss: tensor(0.1829)
valid loss: tensor(0.1766)
