In [1]:
import datetime
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Download the data
def load_data(seed: int = 0) -> Tuple[Dataset, Dataset, Dataset]:
    """Load FashionMNIST and return training, validation and test datasets.

    Both official splits are requested with ``download=True`` into ``data/``
    and converted with ``ToTensor``. The official training split is returned
    unchanged; the official test split is divided 60/40 into a validation
    set and a held-out test set.

    Args:
        seed: Seed of the generator that drives the validation/test split,
            so the same samples land in each subset on every run.

    Returns:
        A ``(training_data, validation_data, test_data)`` tuple.
    """
    from torchvision import datasets
    from torchvision.transforms import ToTensor

    training_data = datasets.FashionMNIST(
        root="data", train=True, download=True, transform=ToTensor()
    )

    held_out_data = datasets.FashionMNIST(
        root="data", train=False, download=True, transform=ToTensor()
    )

    validation_size = int(0.6 * len(held_out_data))
    test_size = len(held_out_data) - validation_size
    # Without an explicit generator random_split draws from the global RNG,
    # so validation/test membership would change between kernel restarts.
    validation_data, test_data = random_split(
        held_out_data,
        [validation_size, test_size],
        generator=torch.Generator().manual_seed(seed),
    )
    return training_data, validation_data, test_data


# Samples per mini-batch, shared by all three loaders.
BATCH_SIZE = 4

training_data, validation_data, test_data = load_data()
# Only the training loader shuffles its samples.
training_loader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
# needs to be reproducible, so no shuffling
validation_loader = DataLoader(validation_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Model architecture:
```mermaid
flowchart 
    A{Input: \n2D 28x28 array} 
    A -->|28, 28| B[Flatten]
    B -->|784| C[Linear + ReLU]
    C -->|512| D[Linear + ReLU]
    D -->|512| E[Linear]
    E -->|10| F{Output: \nlogits (unnormalized scores)\nfor each label}
```


In [3]:
# Create the model
# Pick the compute backend: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")


class Model(nn.Module):
    """Fully connected classifier: flatten, two 512-unit ReLU layers, 10 outputs."""

    def __init__(self, config):
        """Build the layers; ``config`` is only stored on the instance."""
        super().__init__()
        self.config = config
        self.flatten = nn.Flatten()

        in_features, hidden_features, out_features = 28 * 28, 512, 10
        layers = [
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        ]
        # Positional (unnamed) modules keep the state_dict keys at "0".."4".
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` per sample and return the raw logits of the stack."""
        return self.linear_relu_stack(self.flatten(x))


# No options are defined yet; the empty dict is simply stored on the model.
config = dict()

# Build the network, then place its parameters on the selected device.
model = Model(config)
model = model.to(device)

Using cpu device


In [4]:
# Train the model
# Optimizer hyperparameters, named so they can be tuned in one place.
LEARNING_RATE = 0.001
MOMENTUM = 0.9

# The loss is applied directly to the logits returned by the model.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM
)


def train_one_epoch(tb_writer=None, epoch_index=0):
    """Run one pass over ``training_loader`` and update ``model`` in place.

    Relies on the module-level ``model``, ``optimizer``, ``loss_fn``,
    ``training_loader`` and ``device``.

    Args:
        tb_writer: Optional writer with an ``add_scalar`` method; when given,
            the mean loss of every 1000-batch window is logged under
            ``'Training loss'``.
        epoch_index: Index of the current epoch, used to place the logged
            points on a global step axis.

    Returns:
        The mean per-batch loss of the last completed 1000-batch window, or
        ``0.0`` if the epoch has fewer than 1000 batches.
    """
    report_every = 1000  # number of batches per reporting window
    running_loss = 0.0
    last_loss = 0.0

    # enumerate() gives the batch index needed for intra-epoch reporting.
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        # The model was moved to `device`; the batch has to live there too,
        # otherwise the forward pass fails on cuda/mps with a device mismatch.
        inputs, labels = inputs.to(device), labels.to(device)

        # Gradients accumulate by default, so reset them for every batch.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # Gather data and report once per window.
        running_loss += loss.item()
        if i % report_every == report_every - 1:
            last_loss = running_loss / report_every  # loss per batch
            print("  batch {} loss: {}".format(i + 1, last_loss))
            if tb_writer is not None:
                tb_writer.add_scalar(
                    'Training loss',
                    last_loss,
                    epoch_index * len(training_loader) + i,
                )

            running_loss = 0.0

    return last_loss


# Train for several epochs, validating after each one and saving the weights
# whenever the average validation loss improves.
start_timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
writer = SummaryWriter(f'runs/testing_{start_timestamp}')

EPOCHS = 5
best_validation_loss = 1e9

# NOTE(review): epoch indices run from EPOCHS to 2*EPOCHS - 1, presumably to
# continue the numbering of an earlier run -- confirm this is still intended
# when the notebook is executed on a fresh kernel.
for epoch in range(EPOCHS, 2*EPOCHS):
    print(f"EPOCH {epoch + 1}")

    # Train the model
    model.train(True)
    avg_loss = train_one_epoch(writer, epoch)

    running_validation_loss = 0.0

    # Test its intermediate performance
    model.eval()
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, vlabels in validation_loader:
            # The model was moved to `device`; validation batches must be
            # there too or the forward pass fails on cuda/mps.
            vinputs, vlabels = vinputs.to(device), vlabels.to(device)
            voutputs = model(vinputs)
            # .item() keeps the running sum (and everything derived from it)
            # a plain Python float instead of a tensor.
            running_validation_loss += loss_fn(voutputs, vlabels).item()
    # Divide by the real batch count rather than a leaked loop index.
    avg_vloss = running_validation_loss / len(validation_loader)
    print("LOSS train {} valid {}".format(avg_loss, avg_vloss))
    writer.add_scalars('Training vs. Validation loss', {
                       'Training': avg_loss, 'Validation': avg_vloss}, epoch + 1)
    writer.flush()

    # Keep a checkpoint of every new best validation result.
    if avg_vloss < best_validation_loss:
        best_validation_loss = avg_vloss
        model_path = f"model_{start_timestamp}_{epoch}"
        torch.save(model.state_dict(), model_path)

EPOCH 6


  batch 1000 loss: 1.3545449294298888


  batch 2000 loss: 0.7317184730656445


  batch 3000 loss: 0.6615200548283756


  batch 4000 loss: 0.5781671963580884


  batch 5000 loss: 0.5578867159232032


  batch 6000 loss: 0.5419525289270095


  batch 7000 loss: 0.5187083787210286


  batch 8000 loss: 0.4792804029965773


  batch 9000 loss: 0.4838538914558012


  batch 10000 loss: 0.4693503861064091


  batch 11000 loss: 0.462582439044374


  batch 12000 loss: 0.4676974238224793


  batch 13000 loss: 0.43523717983881943


  batch 14000 loss: 0.426569092213409


  batch 15000 loss: 0.4150264247608138


LOSS train 0.4150264247608138 valid 0.4451819360256195
EPOCH 7


  batch 1000 loss: 0.3998089502919465


  batch 2000 loss: 0.41008523176563905


  batch 3000 loss: 0.4033077511799056


  batch 4000 loss: 0.3955201131055364


  batch 5000 loss: 0.4093074310820084


  batch 6000 loss: 0.39660053290193903


  batch 7000 loss: 0.37872266429720913


  batch 8000 loss: 0.3795764488193672


  batch 9000 loss: 0.39418133225117347


  batch 10000 loss: 0.40835364494629905


  batch 11000 loss: 0.39210483338602353


  batch 12000 loss: 0.3859634173109371


  batch 13000 loss: 0.36857391634742814


  batch 14000 loss: 0.3992565610330785


  batch 15000 loss: 0.3720108799263544


LOSS train 0.3720108799263544 valid 0.3879810869693756
EPOCH 8


  batch 1000 loss: 0.3589565509983877


  batch 2000 loss: 0.37594136465271005


  batch 3000 loss: 0.3478714046175446


  batch 4000 loss: 0.35116436267009704


  batch 5000 loss: 0.3498368435918237


  batch 6000 loss: 0.3458454228466726


  batch 7000 loss: 0.35142021947255125


  batch 8000 loss: 0.36657327230589


  batch 9000 loss: 0.3566670378212584


  batch 10000 loss: 0.33185284006212895


  batch 11000 loss: 0.34707213478761695


  batch 12000 loss: 0.32892330991991914


  batch 13000 loss: 0.34242258520697944


  batch 14000 loss: 0.352343448432337


  batch 15000 loss: 0.3551879076836849


LOSS train 0.3551879076836849 valid 0.4083085060119629
EPOCH 9


  batch 1000 loss: 0.3293057084519532


  batch 2000 loss: 0.3232165395982738


  batch 3000 loss: 0.3297821080376743


  batch 4000 loss: 0.3323437011573842


  batch 5000 loss: 0.32138048060561414


  batch 6000 loss: 0.3422932093202253


  batch 7000 loss: 0.31971878160431516


  batch 8000 loss: 0.3043320016366488


  batch 9000 loss: 0.33493755099902045


  batch 10000 loss: 0.327584721623949


  batch 11000 loss: 0.3405358667632408


  batch 12000 loss: 0.30782074210689464


  batch 13000 loss: 0.3126215653450781


  batch 14000 loss: 0.31839185775439544


  batch 15000 loss: 0.31149232537587523


LOSS train 0.31149232537587523 valid 0.3739641010761261
EPOCH 10


  batch 1000 loss: 0.30649648767196774


  batch 2000 loss: 0.3260837357753189


  batch 3000 loss: 0.30214775839423236


  batch 4000 loss: 0.292137747337194


  batch 5000 loss: 0.32010161901660467


  batch 6000 loss: 0.2933321541266778


  batch 7000 loss: 0.2955069772548286


  batch 8000 loss: 0.3197626822444727


  batch 9000 loss: 0.32605947726782325


  batch 10000 loss: 0.309505619931675


  batch 11000 loss: 0.28475356966631216


  batch 12000 loss: 0.2920735901299649


  batch 13000 loss: 0.29579464676836503


  batch 14000 loss: 0.29988778893547713


  batch 15000 loss: 0.3047588032895583


LOSS train 0.3047588032895583 valid 0.3930676579475403
