## 1. Data preprocessing

In [13]:
from functools import lru_cache

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split


@lru_cache
def _download_data(root: str = "path"):
    """Return the FashionMNIST training split, downloading it into ``root`` if missing.

    Wrapped in ``lru_cache`` so repeated calls with the same ``root`` (e.g. one
    per ``get_train_val_loaders`` call) reuse the same Dataset object instead
    of constructing it again.

    Args:
        root: directory torchvision stores the dataset files in. Defaults to
            the previously hard-coded ``"path"`` (relative to the working
            directory).

    Returns:
        The ``torchvision.datasets.FashionMNIST`` training split, with each
        image converted by ``transforms.ToTensor()``.
    """
    return torchvision.datasets.FashionMNIST(
        root, train=True, download=True, transform=transforms.ToTensor()
    )


def get_train_val_loaders(
    batch_size: int,
    fraction_of_train_set: float = 1.0,
    train_split: float = 0.8,
) -> tuple[DataLoader, DataLoader]:
    """Split the dataset from ``_download_data`` into train/validation DataLoaders.

    The dataset is split once into a ``train_split`` share for training and the
    remainder for validation. ``fraction_of_train_set`` then keeps only that
    fraction of the training share (handy for quick experiments). The samples
    it excludes are dropped, NOT moved into validation, so the validation set
    size does not depend on ``fraction_of_train_set``.

    Args:
        batch_size: batch size of both loaders.
        fraction_of_train_set: share of the training portion to keep, in (0, 1].
        train_split: share of the whole dataset reserved for training, in (0, 1).

    Returns:
        ``(train_loader, val_loader)``. The training loader reshuffles every
        epoch; the validation loader keeps a fixed order.

    Raises:
        ValueError: if a fraction is out of range, or the resulting training or
            validation subset would be empty.
    """
    if not 0.0 < fraction_of_train_set <= 1.0:
        raise ValueError(
            f"fraction_of_train_set must be in (0, 1], got {fraction_of_train_set}"
        )
    if not 0.0 < train_split < 1.0:
        raise ValueError(f"train_split must be in (0, 1), got {train_split}")

    dataset = _download_data()
    n_total = len(dataset)
    n_train_full = int(train_split * n_total)
    n_val = n_total - n_train_full
    n_train = int(fraction_of_train_set * n_train_full)
    if n_train == 0 or n_val == 0:
        raise ValueError(
            f"split produces an empty subset (train={n_train}, val={n_val}) "
            f"for a dataset of {n_total} samples"
        )

    # The third subset holds the training samples excluded by
    # fraction_of_train_set and is intentionally unused. Splitting all three at
    # once draws a single random permutation over the whole dataset.
    train_dataset, val_dataset, _unused = random_split(
        dataset, [n_train, n_val, n_train_full - n_train]
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return (train_loader, val_loader)

## 2. Data presentation tools

In [14]:
import matplotlib.pyplot as plt
from matplotlib.axes import Axes


def plot_metrics(results: list[dict[str, float]]):
    """Plot per-epoch accuracy, recall and F1 score on a single figure.

    Epochs are numbered from 1, matching the ``Epoch [k/N]`` debug output of
    ``train_model``. Points carry markers so that a run with a single epoch is
    still visible (a marker-less line through one point draws nothing).

    Args:
        results: one dict per epoch with at least the keys ``'accuracy'``,
            ``'recall'`` and ``'f1'``, as returned by ``train_model``.

    Raises:
        ValueError: if ``results`` is empty.
    """
    if not results:
        raise ValueError("results is empty: nothing to plot")

    n_epochs = len(results)
    epochs = range(1, n_epochs + 1)

    _, ax = plt.subplots(figsize=(8, 6))
    for key, color, label in (
        ('accuracy', 'b', 'Accuracy'),
        ('recall', 'r', 'Recall'),
        ('f1', 'g', 'F1 Score'),
    ):
        ax.plot(
            epochs,
            [epoch_result[key] for epoch_result in results],
            linestyle='-',
            marker='o',
            markersize=4,
            color=color,
            label=label,
        )

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Metric score')
    ax.set_title('Plot of metric scores')

    # One epoch of padding on each side keeps the first/last points off the frame.
    ax.set_xlim(0, n_epochs + 1)
    # All three metrics lie in [0, 1]; a fixed range makes runs comparable.
    ax.set_ylim(0, 1)

    ax.legend()
    ax.grid(True)
    plt.show()


def plot_cost(results: list[dict[str, float]]):
    """Plot the per-epoch value of the cost (validation loss).

    Epochs are numbered from 1, matching the ``Epoch [k/N]`` debug output of
    ``train_model``. Points carry markers so that a run with a single epoch is
    still visible (a marker-less line through one point draws nothing).

    Args:
        results: one dict per epoch with at least the key ``'cost'``, as
            returned by ``train_model``.

    Raises:
        ValueError: if ``results`` is empty.
    """
    if not results:
        raise ValueError("results is empty: nothing to plot")

    n_epochs = len(results)
    epochs = range(1, n_epochs + 1)
    costs = [epoch_result['cost'] for epoch_result in results]

    _, ax = plt.subplots(figsize=(8, 6))
    ax.plot(epochs, costs, linestyle='-', marker='o', markersize=4, color='r', label='Cost')

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Cost function value')
    ax.set_title('Plot of cost value over time (epochs)')

    # One epoch of padding on each side keeps the first/last points off the frame.
    ax.set_xlim(0, n_epochs + 1)

    ax.legend()
    ax.grid(True)
    plt.show()


def compare_plots(
    results_with_titles: list[tuple[list[dict[str, float]], str]],
    cost: bool = False,
    max_cols: int = 2,
):
    """Draw one subplot per training run so several runs can be compared.

    Args:
        results_with_titles: ``(results, title)`` pairs. ``results`` is the
            per-epoch list of metric dicts returned by ``train_model``, and
            ``title`` labels that run's subplot.
        cost: if True, plot the ``'cost'`` series. Otherwise plot accuracy,
            recall and F1, with the y-axis fixed to [0, 1].
        max_cols: maximum number of subplots per row (previously fixed at 2).

    Raises:
        ValueError: if ``results_with_titles`` is empty or ``max_cols < 1``.
    """
    count = len(results_with_titles)
    if count == 0:
        raise ValueError("results_with_titles is empty: nothing to plot")
    if max_cols < 1:
        raise ValueError(f"max_cols must be >= 1, got {max_cols}")
    cols = min(count, max_cols)
    rows = (count - 1) // max_cols + 1  # ceil(count / max_cols)

    # squeeze=False guarantees a 2-D array of Axes. Without it a single run
    # (1x1 grid) returns a bare Axes object, which cannot be indexed.
    fig, axs = plt.subplots(
        nrows=rows, ncols=cols, figsize=(12, 6 * rows), squeeze=False
    )
    # Row-major order: subplot k shows results_with_titles[k]; any spare cells
    # in the last row are hidden.
    for k, ax in enumerate(axs.flat):
        if k >= count:
            ax.set_visible(False)
            continue
        res, title = results_with_titles[k]
        epochs = range(1, len(res) + 1)
        # Series are built into fresh locals: the `cost` flag must not be
        # rebound here, or one empty run would flip every later subplot (and
        # the figure labels) into metric mode.
        if cost:
            ax.plot(
                epochs,
                [epoch_result['cost'] for epoch_result in res],
                linestyle='-', marker='o', markersize=4, color='r', label='Cost',
            )
        else:
            for key, color, label in (
                ('accuracy', 'b', 'Accuracy'),
                ('recall', 'r', 'Recall'),
                ('f1', 'g', 'F1 Score'),
            ):
                ax.plot(
                    epochs,
                    [epoch_result[key] for epoch_result in res],
                    linestyle='-', marker='o', markersize=4, color=color, label=label,
                )
            ax.set_ylim(0, 1)
        ax.set_title(title)
        ax.set_xlim(0, len(res) + 1)
        ax.legend()
        ax.grid(True)

    fig.supxlabel('Epoch', fontsize=12)
    if cost:
        fig.supylabel('Cost function value', fontsize=12)
        fig.suptitle('Plot of cost value over time (epochs)', fontsize=14)
    else:
        fig.supylabel('Metric scores', fontsize=12)
        fig.suptitle('Plot of metric scores over time (epochs)', fontsize=14)

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.show()

## 3. Model

In [15]:
from collections.abc import Callable
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class HiddenLayerConfig:
    """Configuration of one hidden layer of ``ImageClassifier``.

    Attributes:
        neurons: number of output units of the layer's ``nn.Linear``.
        f_activ: activation module applied right after that ``nn.Linear``.
    """

    neurons: int
    f_activ: nn.Module


# Define the model
class ImageClassifier(nn.Module):
    """Fully connected classifier: flatten -> [Linear -> activation]* -> Linear.

    Args:
        input_features: number of values per sample after flattening
            (e.g. ``28 * 28`` for a 28x28 image).
        output_labels: number of classes, i.e. the width of the output logits.
        hidden_layers_config: hidden layers in order. May be empty, in which
            case the model is a single ``nn.Linear`` from the input to the
            logits.
    """

    def __init__(
        self,
        input_features: int,
        output_labels: int,
        hidden_layers_config: list[HiddenLayerConfig]
    ):
        super().__init__()
        layers: list[nn.Module] = []
        # Width feeding the next Linear: the raw input first, then each hidden layer's output.
        in_size = input_features
        for layer_cfg in hidden_layers_config:
            layers.append(nn.Linear(in_size, layer_cfg.neurons))
            layers.append(layer_cfg.f_activ)
            in_size = layer_cfg.neurons
        # Output layer. Using `in_size` (rather than indexing the last config
        # entry) also works when there are no hidden layers.
        layers.append(nn.Linear(in_size, output_labels))
        self.layers = nn.Sequential(*layers)
        # Init weights
        self.layers.apply(self._init_weights)
        self.flatten = nn.Flatten()

    def _init_weights(self, module: nn.Module):
        """Xavier-uniform weights and zero biases for ``nn.Linear``; other modules untouched."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Flatten every dimension after the first (batch) one, then apply the layer stack."""
        x = self.flatten(x)
        return self.layers(x)

## 4. Training

In [None]:
from sklearn.metrics import accuracy_score, f1_score, recall_score
from torch.nn.modules import Module
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader


def train_one_epoch(
    model: ImageClassifier,
    train_loader: DataLoader,
    loss_fn: Module,
    optimizer: Optimizer
):
    """Run one optimization pass over every batch of ``train_loader``.

    For each ``(inputs, targets)`` batch: clear the previous gradients, run the
    forward pass, compute the loss, backpropagate and take an optimizer step.
    This function does not change the train/eval mode of ``model``; the caller
    is expected to have put it in training mode.
    """
    for batch_inputs, batch_targets in train_loader:
        # PyTorch accumulates gradients across backward() calls, so reset per batch.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()


def train_model(
    model: ImageClassifier,
    max_epochs: int,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss_fn: Module,
    optimizer: Optimizer,
    debug: bool = False,
    average: str = 'weighted',
) -> list[dict[str, float]]:
    """Train ``model`` for ``max_epochs`` epochs, evaluating on ``val_loader`` after each.

    Args:
        model: network to train (updated in place).
        max_epochs: number of full passes over ``train_loader``.
        train_loader: batches used for optimization.
        val_loader: batches used for the per-epoch evaluation; must be non-empty.
        loss_fn: criterion applied to ``(model(X), y)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        debug: print the average validation loss (and its change) each epoch.
        average: averaging mode passed to scikit-learn's ``recall_score`` and
            ``f1_score``. With the default ``'weighted'``, recall is
            mathematically identical to accuracy (support-weighted per-class
            recall reduces to correct / total), so the two curves coincide.
            Use ``'macro'`` for a class-balanced recall.

    Returns:
        One dict per epoch with keys ``'accuracy'``, ``'recall'``, ``'f1'`` and
        ``'cost'`` (mean of the per-batch validation losses).

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    results = []
    prev_avg_val_loss = None  # no previous epoch to compare against yet
    for epoch in range(max_epochs):
        model.train()  # Set model to training mode
        train_one_epoch(
            model=model,
            train_loader=train_loader,
            loss_fn=loss_fn,
            optimizer=optimizer,
        )

        metrics = _evaluate(model, val_loader, loss_fn, average)
        results.append(metrics)

        avg_val_loss = metrics["cost"]
        if debug:
            # The first epoch has nothing to compare with, so report n/a
            # instead of a misleading difference against 0.
            delta = (
                "n/a" if prev_avg_val_loss is None
                else f"{avg_val_loss - prev_avg_val_loss:+.4f}"
            )
            print(f"Epoch [{epoch+1}/{max_epochs}] - avg val loss: {avg_val_loss:.4f} (delta: {delta})")
        prev_avg_val_loss = avg_val_loss

    print("Training complete!")
    return results


def _evaluate(
    model: Module,
    val_loader: DataLoader,
    loss_fn: Module,
    average: str = 'weighted',
) -> dict[str, float]:
    """Put ``model`` in eval mode and compute validation loss and classification metrics."""
    if len(val_loader) == 0:
        raise ValueError("val_loader yields no batches; cannot compute validation metrics")

    model.eval()
    running_val_loss = 0.0
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for X, y in val_loader:
            y_pred = model(X)
            running_val_loss += loss_fn(y_pred, y).item()
            # Highest-scoring output index is the predicted class.
            all_labels.extend(y.cpu().numpy())
            all_preds.extend(torch.argmax(y_pred, dim=1).cpu().numpy())

    # zero_division=0 yields the same values as the default ("warn") but
    # without UndefinedMetricWarning when some class is never predicted.
    return {
        "accuracy": accuracy_score(all_labels, all_preds),
        "recall": recall_score(all_labels, all_preds, average=average, zero_division=0),
        "f1": f1_score(all_labels, all_preds, average=average, zero_division=0),
        "cost": running_val_loss / len(val_loader),
    }

In [17]:
import torch.nn as nn

# Configuration for this run.
SEED = 0
BATCH_SIZE = 30
# Matches the default lr of recent PyTorch SGD; stated explicitly so the run
# does not depend on that default.
LEARNING_RATE = 1e-3
MAX_EPOCHS = 1

# Seed the global RNG so the train/val split, weight init and shuffling are reproducible.
torch.manual_seed(SEED)

train_loader, val_loader = get_train_val_loaders(batch_size=BATCH_SIZE)

model = ImageClassifier(
    input_features=28 * 28,  # FashionMNIST images are 28x28 pixels
    output_labels=10,  # FashionMNIST has 10 classes
    hidden_layers_config=[
        HiddenLayerConfig(14, nn.Sigmoid()),
        HiddenLayerConfig(7, nn.Sigmoid())
    ]
)

optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

results = train_model(
    model=model,
    max_epochs=MAX_EPOCHS,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=nn.CrossEntropyLoss(),
    optimizer=optimizer,
)

plot_metrics(results)
plot_cost(results)

ValueError: Classification metrics can't handle a mix of multiclass and continuous-multioutput targets