### Initialization

In [None]:
# Libraries
import torch
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset, SubsetRandomSampler
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchmetrics.classification import MulticlassAccuracy
import matplotlib.pyplot as plt
from pathlib import Path
from timeit import default_timer as timer
from tqdm.auto import tqdm
from sklearn.model_selection import KFold
from helper_functions import multi_class_predictions

# Location of the saved weights; the directory is created up front so saving never fails on a fresh checkout
MODEL_PATH = Path("models")
MODEL_NAME = "fashionMNIST_model.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME
MODEL_PATH.mkdir(parents=True, exist_ok=True)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"PyTorch: {torch.__version__}")

# Setup Data

In [None]:
BATCH_SIZE = 32
NUM_SPLITS = 5
# Seed for the KFold shuffle: without it the fold assignment differs on every
# run, so metrics (and the saved model) are not reproducible.
SPLIT_SEED = 42

# Training data
train_dataset = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

# Testing data
test_dataset = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# Build one (train, test) pair of dataloaders per fold for k-fold cross validation.
# NOTE: the official train and test splits are pooled here and re-split by KFold,
# so the original FashionMNIST test images can appear in any fold's training split.
dataset = ConcatDataset([train_dataset, test_dataset])
kfold = KFold(n_splits=NUM_SPLITS, shuffle=True, random_state=SPLIT_SEED)
classes = train_dataset.classes
train_dataloaders = []
test_dataloaders = []

for train_ids, test_ids in kfold.split(dataset):
    # Samplers restrict each loader to its fold's indices and shuffle within them
    train_subsampler = SubsetRandomSampler(train_ids)
    test_subsampler = SubsetRandomSampler(test_ids)

    train_dataloaders.append(DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_subsampler))
    test_dataloaders.append(DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_subsampler))

### Visualize a random batch from fold 0's training split

In [None]:
# The samplers were built without an explicit generator (see data setup), so
# seeding the global torch RNG fixes which batch is drawn here.
torch.manual_seed(42)
fig = plt.figure(figsize=(18, 9))
rows, cols = 4, 8

# Get random batch
batch_features, batch_labels = next(iter(train_dataloaders[0]))

# Fill the grid, but never index past the batch (BATCH_SIZE may be < rows * cols)
num_images = min(rows * cols, len(batch_features))
for i in range(num_images):
    img, label = batch_features[i], batch_labels[i]
    fig.add_subplot(rows, cols, i + 1)
    plt.imshow(img.squeeze(), cmap="gray")
    plt.title(classes[label])
    plt.axis(False)

# Build Convolutional Neural Network

In [None]:
class FashionMNISTCNN(nn.Module):
    """Two-block convolutional classifier for square single-size images.

    Each conv block is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2),
    so the spatial side length is halved (floored) twice before the classifier.

    Args:
        input_shape: number of input channels.
        hidden1: output channels of the first conv block.
        hidden2: output channels of the second conv block.
        output_shape: number of classes (length of the returned logits).
        imgsize: number of pixels per image, i.e. side * side. Images are
            assumed square — TODO confirm if non-square inputs are ever needed.
    """

    def __init__(self, input_shape: int, hidden1: int, hidden2: int, output_shape: int, imgsize: int):
        super().__init__()
        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape, out_channels=hidden1, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.block_2 = nn.Sequential(
            nn.Conv2d(in_channels=hidden1, out_channels=hidden2, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Flattened feature count after two 2x2 poolings: hidden2 channels of
        # (side // 4) x (side // 4). For hidden2=64 and 28x28 inputs this equals
        # imgsize * 4 (3136); a hard-coded `imgsize * 4` breaks for any other hidden2.
        side = int(round(imgsize ** 0.5))
        pooled_side = side // 4
        flat_features = hidden2 * pooled_side * pooled_side
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=flat_features, out_features=600),
            nn.Dropout(0.25),
            # NOTE(review): no activation between the Linear layers below, so they
            # compose to one affine map. Inserting a ReLU would shift Sequential
            # indices and break existing checkpoints, so it is left unchanged.
            nn.Linear(in_features=600, out_features=120),
            nn.Linear(in_features=120, out_features=output_shape)
        )
        # Same modules as above, chained for a single-call forward pass
        self.layer_stack = nn.Sequential(
            self.block_1,
            self.block_2,
            self.classifier
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits; x is presumably (batch, input_shape, side, side)."""
        return self.layer_stack(x)

torch.manual_seed(42)  # make the initial weights reproducible
model = FashionMNISTCNN(
    input_shape=1,
    hidden1=32,
    hidden2=64,
    output_shape=len(classes),
    imgsize=28 * 28,
).to(device)

## Test Predictions before Training

In [None]:
# A freshly built module is in train mode: Dropout would be active and BatchNorm
# would update its running statistics during this "prediction". Switch to eval
# mode first (harmless if multi_class_predictions already does so — not visible
# here); the training loop calls model.train() at the start of every epoch.
model.eval()
batch_features, batch_labels = next(iter(train_dataloaders[0]))
multi_class_predictions(model, batch_features, batch_labels, classes, device)

### Loss function, optimizer, and evaluation metrics

In [None]:
# Cross-entropy on the model's raw logits, Adam with a small fixed learning rate,
# and a multiclass accuracy metric living on the same device as the model.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
acc_metric = MulticlassAccuracy(num_classes=len(classes)).to(device)

### Train model

In [None]:
# Rotating `epoch % NUM_SPLITS` for one model trains on every fold's training
# split; their union is the whole dataset, so the "test" fold stops being unseen
# after the first few epochs and test metrics are leaked. Train and validate on a
# single fixed fold instead. (True k-fold CV needs a fresh model + optimizer per fold.)
VALIDATION_FOLD = 0


def _run_train_epoch(net, dataloader, criterion, optim, metric, dev):
    """Run one optimisation pass of `net` over `dataloader`.

    Returns:
        (mean batch loss, mean batch accuracy in percent)
    """
    net.train()
    total_loss, total_acc = 0.0, 0.0
    for X, y in dataloader:
        X, y = X.to(dev), y.to(dev)

        y_pred = net(X)
        loss = criterion(y_pred, y)

        total_loss += loss.item()
        total_acc += metric(y_pred.argmax(dim=1), y).item() * 100

        optim.zero_grad()
        loss.backward()
        optim.step()
    # Average over the loader actually iterated
    return total_loss / len(dataloader), total_acc / len(dataloader)


def _run_eval_epoch(net, dataloader, criterion, metric, dev):
    """Evaluate `net` on `dataloader` without gradient tracking.

    Returns:
        (mean batch loss, mean batch accuracy in percent)
    """
    net.eval()
    total_loss, total_acc = 0.0, 0.0
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(dev), y.to(dev)
            test_pred = net(X)
            total_loss += criterion(test_pred, y).item()
            total_acc += metric(test_pred.argmax(dim=1), y).item() * 100
    return total_loss / len(dataloader), total_acc / len(dataloader)


torch.manual_seed(42)
start = timer()
epochs = 100
train_loss_values = []
train_acc_values = []
test_loss_values = []
test_acc_values = []

train_loader = train_dataloaders[VALIDATION_FOLD]
val_loader = test_dataloaders[VALIDATION_FOLD]

for epoch in tqdm(range(epochs)):
    ### Training ###
    avg_train_loss, avg_train_acc = _run_train_epoch(model, train_loader, loss_fn, optimizer, acc_metric, device)
    train_loss_values.append(avg_train_loss)
    train_acc_values.append(avg_train_acc)

    ### Testing (held-out fold) ###
    avg_test_loss, avg_test_acc = _run_eval_epoch(model, val_loader, loss_fn, acc_metric, device)
    test_loss_values.append(avg_test_loss)
    test_acc_values.append(avg_test_acc)

end = timer()
train_time = end - start
print(f"Final Loss: {avg_test_loss:.5f} | Final Accuracy: {avg_test_acc:.2f}%")
print(f"Training Time: {(str(int(train_time / 60)) + ' min ') if train_time >= 60 else ''}{(train_time % 60):.3f} sec")

# Plot Loss and Accuracy Curves

In [None]:
# Plot against the number of epochs actually recorded so the cell still works if
# training was interrupted before `epochs` completed (train and test lists may
# then differ in length by one).
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 6))

ax_loss.plot(range(len(train_loss_values)), train_loss_values, label="Train")
ax_loss.plot(range(len(test_loss_values)), test_loss_values, label="Test")
ax_loss.set(title="Loss", xlabel="Epoch", ylabel="Loss")
ax_loss.legend()

ax_acc.plot(range(len(train_acc_values)), train_acc_values, label="Train")
ax_acc.plot(range(len(test_acc_values)), test_acc_values, label="Test")
ax_acc.set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy (%)")
ax_acc.legend();

# Save Model

In [None]:
# Persist only the learned parameters/buffers, not the pickled module object
weights = model.state_dict()
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(weights, MODEL_SAVE_PATH)

## Test Predictions after Training

In [None]:
# Rebuild the architecture and restore the saved weights. map_location lets a
# checkpoint written on a GPU load on a CPU-only machine (and vice versa).
loaded_model_0 = FashionMNISTCNN(1, 32, 64, len(classes), 28 * 28).to(device)
loaded_model_0.load_state_dict(torch.load(f=MODEL_SAVE_PATH, map_location=device))
# Inference mode for the layers that behave differently in training: Dropout off,
# BatchNorm uses its running statistics.
loaded_model_0.eval()

# Draw the sample from fold 0's test split rather than its training split
batch_features, batch_labels = next(iter(test_dataloaders[0]))
multi_class_predictions(loaded_model_0, batch_features, batch_labels, classes, device)