> Simple ResNet implementation for Fashion-MNIST (ResNet was originally used for ImageNet which contains 1,000 categories).

In [None]:
%matplotlib inline

import os
import torch

from torch import nn
from torch.nn import functional as F
from typing import Any, Union, List

# Setup Device

In [None]:
# Report whether CUDA is usable, then bind the device everything will run on:
# the GPU when one is available, the CPU otherwise.
print(torch.cuda.is_available())
device: Any = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load Data

In [None]:
from utils.utils import load_fashion_mnist

# Dataset directory, two levels above the current working directory.
mnist_path: str = os.path.join('..', '..', 'data')
# Worker count and batch size handed to load_fashion_mnist below.
num_dataloader_workers: int = 4
batch_size: int = 128

# Resize images to 224x224, matching the 224x224 shape expected by the Reshape layer below.
# Resizing the Fashion-MNIST images this way is not an ideal approach, but it is simple
# and works for demonstration purposes.
train_iter, test_iter = load_fashion_mnist(batch_size, mnist_path, num_dataloader_workers, resize=224)

# ResNet-18 Builder

In [None]:
# Define a reshape layer that makes inputs compatible with Conv2d layers
class Reshape(nn.Module):
    """Reshape incoming tensors to ``(N, channels, height, width)`` for Conv2d layers.

    The batch dimension ``N`` is inferred from the number of elements in the input.

    Args:
        channels: Number of channels in the reshaped output. Defaults to 1.
        height: Height of the reshaped output. Defaults to 224.
        width: Width of the reshaped output. Defaults to 224.

    With the default arguments the layer produces ``(-1, 1, 224, 224)``.
    """

    def __init__(self, channels: int = 1, height: int = 224, width: int = 224) -> None:
        super().__init__()
        self.channels: int = channels
        self.height: int = height
        self.width: int = width

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return ``X`` reshaped to ``(-1, channels, height, width)``."""
        # reshape() returns a view when the memory layout allows it and copies
        # otherwise, so inputs that view() would reject are also handled.
        return X.reshape(-1, self.channels, self.height, self.width)

In [None]:
# Define an identity layer that returns its input unchanged
class Identity(nn.Module):
    """Pass-through layer: the output is the very tensor it was given."""

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return ``X`` untouched (same object, no copy)."""
        return X

In [None]:
# Build Residual Sub-block
class Residual18(nn.Module):
    """Residual unit made of two 3x3 convolutions plus a shortcut connection.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by both convolutions.
        use_1x1conv: When True the shortcut is a 1x1 convolution (using the
            same stride as the first convolution); otherwise the shortcut is
            an ``Identity`` layer.
        strides: Stride of the first 3x3 convolution (and of the 1x1
            shortcut convolution when it is used).
    """

    def __init__(self, in_channels: int, out_channels: int, use_1x1conv: bool = False, strides: int = 1) -> None:
        super().__init__()
        # Main path: conv -> BN -> (ReLU) -> conv -> BN
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=strides, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Shortcut path: projects the input when requested, otherwise passes it through
        if use_1x1conv:
            self.res_conn = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=strides)
        else:
            self.res_conn = Identity()

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(X) + shortcut(X))``."""
        main_path = self.conv1(X)
        main_path = F.relu(self.bn1(main_path))
        main_path = self.bn2(self.conv2(main_path))
        main_path += self.res_conn(X)  # add the shortcut in place
        return F.relu(main_path)

In [None]:
# Build Residual Block
def resnet18_block(in_channels: int, out_channels: int, num_residuals: int, first_block: bool = False) -> nn.Sequential:
    """Build one ResNet-18 stage out of ``num_residuals`` residual sub-blocks.

    Args:
        in_channels: Channel count of the tensor entering the stage.
        out_channels: Channel count produced by every sub-block of the stage.
        num_residuals: Number of ``Residual18`` sub-blocks to stack.
        first_block: When True the first sub-block keeps the spatial size
            (stride 1); otherwise it halves it (stride 2) and projects the
            shortcut with a 1x1 convolution.

    Returns:
        An ``nn.Sequential`` holding the sub-blocks in order. It can still be
        iterated or unpacked with ``*`` like the list returned previously.
    """
    blks: List[Residual18] = []
    for i in range(num_residuals):
        if i > 0:
            # Later sub-blocks already operate on out_channels at full stride.
            blks.append(Residual18(in_channels=out_channels, out_channels=out_channels))
        elif first_block:
            # Honour in_channels; a 1x1 projection is only needed when the
            # channel counts differ (identical to before when they match).
            blks.append(Residual18(in_channels=in_channels, out_channels=out_channels,
                                   use_1x1conv=in_channels != out_channels))
        else:
            blks.append(Residual18(in_channels=in_channels, out_channels=out_channels, use_1x1conv=True, strides=2))
    # Wrap in nn.Sequential so the return value matches the declared type.
    return nn.Sequential(*blks)

In [None]:
# Build ResNet-18 Architecture
def resnet18(in_channels: int) -> nn.Sequential:
    """Assemble the ResNet-18 network as one flat ``nn.Sequential``.

    Layout: ``Reshape`` -> stem (7x7 conv, BN, ReLU, 3x3 max-pool) -> four
    stages of two residual sub-blocks each (64, 128, 256, 512 channels) ->
    global average pooling -> flatten -> 10-way linear classifier.

    Args:
        in_channels: Channel count fed to the stem convolution.

    Returns:
        The assembled model.
    """
    stem = [
        nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    ]

    # (input channels, output channels, is this the first stage?)
    stage_specs = [(64, 64, True), (64, 128, False), (128, 256, False), (256, 512, False)]
    stages = []
    for c_in, c_out, is_first in stage_specs:
        stages.extend(resnet18_block(in_channels=c_in, out_channels=c_out,
                                     num_residuals=2, first_block=is_first))

    head = [
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(in_features=512, out_features=10),
    ]
    return nn.Sequential(Reshape(), *stem, *stages, *head)

# ResNet-18

In [None]:
# Instantiate the ResNet-18 network for single-channel input
resnet18_model: nn.Sequential = resnet18(1)

In [None]:
# Initialize model parameters
def init_weights(m: nn.Module) -> None:
    """Initialize the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply``.

    * ``nn.Conv2d`` / ``nn.Linear``: Xavier-uniform weights, zero bias
      (when the layer has a bias).
    * ``nn.BatchNorm2d``: weight set to 1 and bias set to 0 (when the layer
      has affine parameters).
    * Any other module is left untouched.

    Args:
        m: The module to initialize.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight and bias set to None, so guard
        # both before touching them.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Run init_weights over the model through nn.Module.apply, which calls it on every
# submodule and on the model itself. The return value is bound to `_` and discarded.
_ = resnet18_model.apply(init_weights)

In [None]:
# Check ResNet-18 model
# Push a dummy batch through the model one layer at a time and print the shape
# each layer produces. Nested nn.Sequential containers are stepped through
# sublayer by sublayer so every sublayer reports its own output shape.
# NOTE: eval() is not called before this point, so BatchNorm layers presumably
# update their running statistics on this random batch.
X: torch.Tensor = torch.rand(size=(10, 1, 224, 224), dtype=torch.float32)
with torch.no_grad():  # shape check only; no autograd graph is needed
    for layer in resnet18_model:
        if isinstance(layer, nn.Sequential):
            sub_shapes = []
            for sublayer in layer:
                X = sublayer(X)
                sub_shapes.append((sublayer.__class__.__name__, X.shape))
            print(layer.__class__.__name__, 'output shape:\t', X.shape)
            for sub_name, sub_shape in sub_shapes:
                print('\t', sub_name, 'output shape:\t', sub_shape)
        else:
            X = layer(X)
            print(layer.__class__.__name__, 'output shape:\t', X.shape)

# Train Model

In [None]:
# Training configuration: epoch count, learning rate, model placement, loss and optimizer
num_epochs: int = 10
lr: float = 0.05
net: Any = resnet18_model.to(device)  # model placed on the selected device (GPU or CPU)
loss: Any = nn.CrossEntropyLoss(reduction='mean')  # applies softmax internally, so the model outputs raw logits
trainer: Any = torch.optim.SGD(params=net.parameters(), lr=lr)  # plain SGD over all model parameters

In [None]:
# Run training with the train() helper imported from utils.trainer (its implementation
# is not part of this file). Arguments are passed positionally: model, training
# iterator, test iterator, loss, epoch count, optimizer, device.
from utils.trainer import train
train(net, train_iter, test_iter, loss, num_epochs, trainer, device)

# Verify Train Result

In [None]:
# Check the trained model on the test iterator with the project helper
# verify_trained_model (its implementation is not part of this file).
from utils.utils import verify_trained_model
verify_trained_model(net, test_iter, device)