# FashionMNIST training (Colab)

Self-contained training script from `train.py` adapted for Google Colab. It will download the dataset automatically and use GPU when available.

In [2]:
import torch
from torchvision import datasets
import torchvision.transforms.v2 as transforms
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Running on: {device}')

sd_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(dtype=torch.float32, scale=True)
])

ds_train = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=sd_transform
)
ds_test = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=sd_transform
)

batch_size = 64
dataloader_train = torch.utils.data.DataLoader(
    ds_train,
    batch_size=batch_size,
    shuffle=True
)
dataloader_test = torch.utils.data.DataLoader(
    ds_test,
    batch_size=batch_size,
)


Running on: cpu


In [3]:
from torch import nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.network = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.network(x)
        return logits


def test_accuracy(model, dataloader):
    n_corrects = 0
    model.eval()
    with torch.no_grad():
        for image_batch, label_batch in dataloader:
            image_batch = image_batch.to(device)
            label_batch = label_batch.to(device)
            logits_batch = model(image_batch)
            predict_batch = logits_batch.argmax(dim=1)
            n_corrects += (predict_batch == label_batch).sum().item()

    accuracy = n_corrects / len(dataloader.dataset)
    return accuracy


def train_epoch(model, dataloader, loss_fn, optimizer):
    model.train()
    last_loss = None
    for image_batch, label_batch in dataloader:
        image_batch = image_batch.to(device)
        label_batch = label_batch.to(device)
        logits_batch = model(image_batch)
        loss = loss_fn(logits_batch, label_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        last_loss = loss.item()
    return last_loss


def test_epoch(model, dataloader, loss_fn):
    model.eval()
    with torch.no_grad():
        total_loss = 0.0
        for image_batch, label_batch in dataloader:
            image_batch = image_batch.to(device)
            label_batch = label_batch.to(device)
            logits_batch = model(image_batch)
            loss = loss_fn(logits_batch, label_batch)
            total_loss += loss.item()

    return total_loss / len(dataloader)


model = MyModel().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)


In [4]:
n_epochs = 10

train_loss_log = []
val_loss_log = []
train_acc_log = []
val_acc_log = []

for epoch in range(n_epochs):
    print(f'Epoch {epoch + 1}/{n_epochs}')

    train_loss = train_epoch(model, dataloader_train, loss_fn, optimizer)
    print(f'    training loss: {train_loss}')
    train_loss_log.append(train_loss)

    val_loss = test_epoch(model, dataloader_test, loss_fn)
    print(f'    validation loss: {val_loss}')
    val_loss_log.append(val_loss)

    train_acc = test_accuracy(model, dataloader_train)
    print(f'    training accuracy: {train_acc * 100:.3f}%')
    train_acc_log.append(train_acc)

    val_acc = test_accuracy(model, dataloader_test)
    print(f'    validation accuracy: {val_acc * 100:.3f}%')
    val_acc_log.append(val_acc)


Epoch 1/10
    training loss: 2.191484212875366
    validation loss: 2.146845406028116
    training accuracy: 44.433%
    validation accuracy: 44.390%
Epoch 2/10
    training loss: 1.9149912595748901


KeyboardInterrupt: 

In [None]:
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(range(1, n_epochs + 1), train_loss_log, label='train loss')
plt.xticks(range(1, n_epochs + 1))
plt.xlabel('epochs')
plt.ylabel('loss')
plt.grid()

plt.subplot(1, 2, 2)
plt.plot(range(1, n_epochs + 1), val_acc_log, label='validation accuracy')
plt.xticks(range(1, n_epochs + 1))
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.grid()

plt.tight_layout()
plt.show()
