# Implementing a digit classifier

In [1]:
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from ignite.metrics import ConfusionMatrix
import itertools
from tqdm import tqdm
import time

In [2]:
## Set device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [3]:
# Set the seeds
seed = 16
torch.manual_seed(seed) # Set CPU seed
# Set GPU seeds
if torch.cuda.is_available():
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
# Make torch algos deterministic
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Create a rng
rng = torch.Generator().manual_seed(seed)

In [4]:
root = "./resources/data/"
val_size = 10000

# Transform data 
t = transforms.Compose([
  transforms.PILToTensor(),
  transforms.ConvertImageDtype(torch.float),
  torch.flatten,
])

train_data = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=t)
train_set, val_set = torch.utils.data.random_split(train_data, [len(train_data) - val_size, val_size], generator=rng)
print(f"Train size: {len(train_set)}\nValidation size: {len(val_set)}")
train_set[0][0].shape

Train size: 50000
Validation size: 10000


torch.Size([784])

In [5]:
# Model parameters
model_params = dict(
	sizes=[784, 30, 10],
	learning_rate=3,
  device=device,
)

In [9]:
class DigitClassifier(nn.Module):
  def __init__(self, sizes: tuple[int], learning_rate: float, device: torch.device):
    """
    Args:
      sizes: size of each layer
      learning_rate: learning rate when optimising parameters
      device: torch device type (cuda or cpu)
    """
    super().__init__()
    self.num_layers = len(sizes)
    self.act_fn = nn.Sigmoid()
    # Define linear weights between each layer:
    self.linears = nn.ModuleList(
        [nn.Linear(ip, op) for ip, op in zip(sizes, sizes[1:])]
    )

    self.num_classes = sizes[-1]
    self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
    self.loss_module = nn.MSELoss(reduction='mean')

    # Set device
    self.to(device)

  def forward(self, x):
    """Forward pass over all neurons"""
    for layer in self.linears:
      x = layer(x)
      x = self.act_fn(x)
    return x

  def train_loop(self, train_data_loader: DataLoader, num_epochs: int = 30, val_data_loader: DataLoader = None):
    """Training neurons"""
    # Set model to train mode
    model.train()

    # Training loop
    for epoch in range(num_epochs):
      for data_inputs, data_labels in train_data_loader:
        # 1. Move input data to device
        data_inputs = data_inputs.to(device)
        data_labels = data_labels.to(device)

        # 2. Run model on input data
        preds = self.forward(data_inputs)

        # 3. Calculate loss
        loss = self.loss_module(
            preds, nn.functional.one_hot(
                data_labels, num_classes=self.num_classes).float()
        )

        # 4. Perform backpropogation
        self.optimizer.zero_grad()
        loss.backward()

        # 5. Update parameters
        self.optimizer.step()

      # After epoch
      # TODO: Evaluate model on train set
      # Evaluate model on val set
      if val_data_loader:
        precision = self.precision(val_data_loader)
        total = len(val_data_loader.dataset)
        print(f"Epoch: {epoch}: {precision*total} / {total}")

  @torch.no_grad()
  def precision(self, data_loader: DataLoader):
    """Return precision of data"""
    model.eval()
    predicted, labels = zip(*((self.forward(ip), lbl)
                            for ip, lbl in data_loader))
    # Stack and reshape predictions and labels
    predicted = torch.stack(predicted).reshape(-1, self.num_classes)
    labels = torch.stack(labels).flatten()
    # Use argmax to calculate predicted labels
    predicted = predicted.argmax(-1)

    # Calculate true positives
    true_positives = (predicted == labels).sum()

    return true_positives / len(labels)


model = DigitClassifier(**model_params)
model

DigitClassifier(
  (act_fn): Sigmoid()
  (linears): ModuleList(
    (0): Linear(in_features=784, out_features=30, bias=True)
    (1): Linear(in_features=30, out_features=10, bias=True)
  )
  (loss_module): MSELoss()
)

In [10]:
# Check forward model is working
model.forward(train_data[0][0])

tensor([0.4591, 0.5053, 0.4917, 0.5166, 0.5069, 0.6360, 0.4782, 0.4576, 0.3895,
        0.5229], grad_fn=<SigmoidBackward0>)

In [11]:
mini_batch = 10
epochs = 30
train_dataloader = DataLoader(train_set, batch_size=mini_batch, shuffle=True, num_workers=0, drop_last=False)
val_dataloader = DataLoader(val_set, batch_size=mini_batch, shuffle=False, num_workers=0, drop_last=False)
model.train_loop(num_epochs=epochs, train_data_loader=train_dataloader, val_data_loader=val_dataloader)

Epoch: 0: 9205.0 / 10000
Epoch: 1: 9357.0 / 10000
Epoch: 2: 9421.0 / 10000
Epoch: 3: 9446.0 / 10000
Epoch: 4: 9492.0 / 10000
Epoch: 5: 9524.0 / 10000
Epoch: 6: 9519.0 / 10000
Epoch: 7: 9536.0 / 10000
Epoch: 8: 9554.0 / 10000
Epoch: 9: 9568.0 / 10000
Epoch: 10: 9559.0 / 10000
Epoch: 11: 9557.0 / 10000
Epoch: 12: 9565.0 / 10000
Epoch: 13: 9578.0 / 10000
Epoch: 14: 9579.0 / 10000
Epoch: 15: 9581.0 / 10000
Epoch: 16: 9584.0 / 10000
Epoch: 17: 9583.0 / 10000
Epoch: 18: 9583.0 / 10000
Epoch: 19: 9590.0 / 10000
Epoch: 20: 9580.0 / 10000
Epoch: 21: 9599.0 / 10000
Epoch: 22: 9598.0 / 10000
Epoch: 23: 9600.0 / 10000
Epoch: 24: 9594.0 / 10000
Epoch: 25: 9592.0 / 10000
Epoch: 26: 9611.0 / 10000
Epoch: 27: 9606.0 / 10000
Epoch: 28: 9595.0 / 10000
Epoch: 29: 9592.0 / 10000


In [24]:
# Save model parameters
state_dict = model.state_dict()
timestr = time.strftime("%Y%m%d-%H%M%S")
model_fname = f"resources/model/{timestr}.tar"
torch.save(state_dict, model_fname)

In [27]:
# Load model
state_dict = torch.load(model_fname)
model = DigitClassifier(**model_params)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [28]:
# Evaluate model on val_set
precision = model.precision(DataLoader(val_set, batch_size=8, shuffle=False, drop_last=False))
precision

tensor(0.9516)

# Exercise - No hidden layer

Aim: Find out how well NN does without a hidden layer

In [31]:
model_2layer_params = dict(
  sizes=[784, 10],
  learning_rate=3,
  device=device,
)
model_2layer = DigitClassifier(**model_2layer_params)
model_2layer

DigitClassifier(
  (act_fn): Sigmoid()
  (linears): ModuleList(
    (0): Linear(in_features=784, out_features=10, bias=True)
  )
  (loss_module): CrossEntropyLoss()
)

In [32]:
# Train model
epochs = 10
model_2layer.train_loop(num_epochs=epochs, train_data_loader=train_dataloader, val_data_loader=val_dataloader)

Epoch: 0: 8927.0 / 10000
Epoch: 1: 8999.0 / 10000
Epoch: 2: 8931.0 / 10000
Epoch: 3: 9023.0 / 10000
Epoch: 4: 9016.0 / 10000
Epoch: 5: 9082.0 / 10000
Epoch: 6: 9066.0 / 10000
Epoch: 7: 9091.0 / 10000
Epoch: 8: 9072.0 / 10000
Epoch: 9: 9041.0 / 10000


In [36]:
# Perform hyperparameter search over mini_batch and learning rate:

learning_rates = [0.1,1,1.5,2,4]
precisions = []
epochs = 5

for learning_rate in learning_rates:
  m = DigitClassifier(sizes=[784,10], learning_rate=learning_rate, device=device)
  m.train_loop(num_epochs=epochs, train_data_loader=train_dataloader)
  precision = m.precision(DataLoader(val_set, batch_size=8))
  print(f"Learning rate: {learning_rate}\tPrecision: {precision:.2f}")
  precisions.append(precision)

Learning rate: 0.1	Precision: 0.8964999914169312
Learning rate: 1	Precision: 0.9092000126838684
Learning rate: 1.5	Precision: 0.9049000144004822
Learning rate: 2	Precision: 0.9067000150680542
Learning rate: 4	Precision: 0.9052000045776367


## Playground

In [166]:
# Create a dataset for playground work
idea_set, _ = torch.utils.data.random_split(val_set, [16, len(val_set) - 16], rng)

In [179]:
result, labels = zip(*((model.forward(ip), lbl) for ip, lbl in DataLoader(idea_set, batch_size=8)))
labels

(tensor([4, 7, 0, 2, 2, 8, 8, 4]), tensor([1, 2, 2, 8, 2, 4, 7, 8]))

In [180]:
result = torch.stack(result).reshape(-1, 10)
labels = torch.stack(labels).flatten()
labels

tensor([4, 7, 0, 2, 2, 8, 8, 4, 1, 2, 2, 8, 2, 4, 7, 8])

In [184]:
(result.argmax(-1) == labels).sum() / len(labels)

tensor(1.)

In [200]:
len(val_dataloader.dataset)

12000