In [1]:
!pip install torch
!pip install torchmetrics
!pip install torchvision

Collecting torchmetrics
  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.9.0-py3-none-any.whl (23 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.9.0 torchmetrics-1.2.0


## 1. Import Library

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T

from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from torchmetrics import Accuracy
from torchmetrics.aggregation import MeanMetric

## 2. Build Config

In [3]:
# Build config
# Fixed seed so torch-driven randomness (weight init, loader shuffling) repeats across runs.
seed = 42
torch.manual_seed(seed)

title = 'FashionMNIST'  # used as the checkpoint file-name prefix
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_root = 'data'  # dataset download / cache directory
batch_size = 64
base_lr = 0.01  # initial learning rate handed to the optimizer
momentum = 0.9
epochs = 5
checkpoint_dir = 'checkpoint'

## 3. Build Directory

In [4]:
# Ensure the checkpoint directory exists; an already-existing directory is not an error.
os.makedirs(name=checkpoint_dir, exist_ok=True)

## 4. Build Dataset

In [5]:
# Build dataset

# Training split: tensor conversion followed by Normalize((0.5,), (0.5,)); shuffled every epoch.
train_transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
train_data = FashionMNIST(
    root=data_root,
    train=True,
    transform=train_transform,
    download=True,
)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

# Validation split: same preprocessing, iterated in the dataset's own order.
val_transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
val_data = FashionMNIST(
    root=data_root,
    train=False,
    transform=val_transform,
    download=True,
)
val_loader = DataLoader(dataset=val_data, batch_size=batch_size)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 13607267.55it/s]


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 195449.04it/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3745787.24it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5553569.19it/s]

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw






## 5. Build Model

In [6]:
# Define model
class MyModel(nn.Module):
    """Single linear layer mapping a flattened 28*28 input to 10 output scores."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Flatten every non-batch dimension of `x` and apply the linear head."""
        n_rows = x.shape[0]
        flat = x.reshape((n_rows, -1))
        return self.head(flat)

# Build model (constructed exactly once)
model = MyModel()
print(model)

# Move model to device
model = model.to(device)

MyModel(
  (head): Linear(in_features=784, out_features=10, bias=True)
)


## 6. Set Optimizer, Scheduler, Loss function

In [7]:
# Optimizer: SGD with momentum over every model parameter
optimizer = optim.SGD(params=model.parameters(), lr=base_lr, momentum=momentum)

# Scheduler: cosine annealing whose horizon is the total number of batches in the run
total_steps = epochs * len(train_loader)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

# Loss: cross-entropy
loss_fn = nn.CrossEntropyLoss()

# Metric: 10-class accuracy, placed on the same device as the model
metric_fn = Accuracy(task='multiclass', num_classes=10).to(device)

## 7. Define Train Loop

In [8]:
# Define training loop
def train(loader, model, optimizer, scheduler, loss_fn, metric_fn, device):
    """Train `model` for one pass over `loader` and return averaged statistics.

    Args:
        loader: Iterable yielding `(inputs, targets)` batches.
        model: Module to optimize; it is switched to train mode here.
        optimizer: Optimizer, stepped once per batch.
        scheduler: Learning-rate scheduler, stepped once per batch after the optimizer.
        loss_fn: Callable `(outputs, targets)` returning a scalar loss that supports `.backward()`.
        metric_fn: Callable `(outputs, targets)` returning a scalar metric for the batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        dict with keys 'loss' and 'metric' holding the sample-weighted epoch means.
    """
    # Set model to train mode
    model.train()

    # Create average meters to measure loss and metric
    loss_mean = MeanMetric()
    metric_mean = MeanMetric()

    # Train model for one epoch
    for inputs, targets in loader:
        # Move data to device
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        metric = metric_fn(outputs, targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update statistics. Each batch value is already a mean over the batch, so
        # weight it by the batch size; otherwise a smaller final batch would count
        # as much as a full one and bias the epoch average.
        n_samples = targets.shape[0]
        loss_mean.update(loss.detach().to('cpu'), weight=n_samples)
        metric_mean.update(metric.detach().to('cpu'), weight=n_samples)

        # Update learning rate
        scheduler.step()

    # Summarize statistics
    summary = {'loss': loss_mean.compute(), 'metric': metric_mean.compute()}

    return summary

## 8. Define Evaluation Loop

In [9]:
# Define evaluation function
def evaluate(loader, model, loss_fn, metric_fn, device):
    """Evaluate `model` over `loader` without gradients and return averaged statistics.

    Args:
        loader: Iterable yielding `(inputs, targets)` batches.
        model: Module to evaluate; it is switched to evaluation mode here.
        loss_fn: Callable `(outputs, targets)` returning a scalar loss for the batch.
        metric_fn: Callable `(outputs, targets)` returning a scalar metric for the batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        dict with keys 'loss' and 'metric' holding the sample-weighted means over the loader.
    """
    # Set model to evaluation mode
    model.eval()

    # Create average meters to measure loss and accuracy
    loss_mean = MeanMetric()
    metric_mean = MeanMetric()

    # Evaluate model for one epoch; nothing here needs gradients, so disable them
    # for the whole pass rather than only around the forward call.
    with torch.no_grad():
        for inputs, targets in loader:
            # Move data to device
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            metric = metric_fn(outputs, targets)

            # Update statistics. Each batch value is already a mean over the batch,
            # so weight it by the batch size; otherwise a smaller final batch would
            # count as much as a full one and bias the reported average.
            n_samples = targets.shape[0]
            loss_mean.update(loss.to('cpu'), weight=n_samples)
            metric_mean.update(metric.to('cpu'), weight=n_samples)

    # Summarize statistics
    summary = {'loss': loss_mean.compute(), 'metric': metric_mean.compute()}

    return summary

## 9. Define Main Loop

In [10]:
# Main loop
# The checkpoint path is the same every epoch, so build it once up front.
checkpoint_path = f'{checkpoint_dir}/{title}_last.pth'

for epoch in range(epochs):
    # train one epoch
    train_summary = train(train_loader, model, optimizer, scheduler, loss_fn, metric_fn, device)

    # evaluate one epoch
    val_summary = evaluate(val_loader, model, loss_fn, metric_fn, device)

    # print log
    print(f'Epoch {epoch+1}: '
          f'Train Loss {train_summary["loss"]:.04f}, '
          f'Train Accuracy {train_summary["metric"]:.04f}, '
          f'Test Loss {val_summary["loss"]:.04f}, '
          f'Test Accuracy {val_summary["metric"]:.04f}')

    # save model (overwrites the previous epoch's file)
    state_dict = {
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        # Stored so the learning-rate schedule can be restored if training is resumed.
        'scheduler': scheduler.state_dict(),
    }
    torch.save(state_dict, checkpoint_path)

Epoch 1: Train Loss 0.5293, Train Accuracy 0.8143, Test Loss 0.4886, Test Accuracy 0.8312
Epoch 2: Train Loss 0.4489, Train Accuracy 0.8435, Test Loss 0.4666, Test Accuracy 0.8354
Epoch 3: Train Loss 0.4231, Train Accuracy 0.8530, Test Loss 0.4689, Test Accuracy 0.8313
Epoch 4: Train Loss 0.4064, Train Accuracy 0.8596, Test Loss 0.4454, Test Accuracy 0.8421
Epoch 5: Train Loss 0.3965, Train Accuracy 0.8633, Test Loss 0.4423, Test Accuracy 0.8438


## 10. Load Model

In [11]:
# Load model
model_pretrained = MyModel()

checkpoint_path = f'{checkpoint_dir}/{title}_last.pth'
# map_location remaps stored tensors onto the current device, so a checkpoint saved
# from a GPU run still loads on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load(checkpoint_path, map_location=device)

model_pretrained.load_state_dict(state_dict['model'])

<All keys matched successfully>

## 11. Comparison with a Randomly Initialized Model

In [12]:
# Untrained baseline with freshly initialized weights, placed on the evaluation device
model_random = MyModel().to(device)
model_pretrained = model_pretrained.to(device)

# Score both models on the validation loader, baseline first
random_summary = evaluate(val_loader, model_random, loss_fn, metric_fn, device)
pretrained_summary = evaluate(val_loader, model_pretrained, loss_fn, metric_fn, device)

print(f'[Random] Test Acc {random_summary["metric"]:.04f}')
print(f'[Pretrained] Test Acc {pretrained_summary["metric"]:.04f}')

[Random] Test Acc 0.0616
[Pretrained] Test Acc 0.8438
