## Import libraries

In [1]:
import numpy as np
import torch
from scipy.special import softmax
from sklearn.metrics import (
    accuracy_score,
    cohen_kappa_score,
    confusion_matrix,
    mean_absolute_error,
)
from skorch import NeuralNetClassifier
from torch import cuda, nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision import models
from torchvision.transforms.v2 import Compose, ToDtype, ToImage

from dlordinal.datasets import FGNet
from dlordinal.metrics import accuracy_off1, amae, mmae, ranked_probability_score
from dlordinal.output_layers.copoc import COPOC

## Import FGNet dataset

In [2]:
# Build the train split first, then the test split. Both use the same on-disk
# root and an identical preprocessing pipeline: convert to a torchvision image
# tensor, then cast to float32 with value scaling enabled (scale=True).
fgnet_train, fgnet_test = (
    FGNet(
        root="./datasets",
        download=True,
        train=is_train_split,
        transform=Compose([ToImage(), ToDtype(torch.float32, scale=True)]),
    )
    for is_train_split in (True, False)
)

Files already downloaded and verified
Files already processed and verified
Files already split and verified
Files already downloaded and verified
Files already processed and verified
Files already split and verified


## Model training

In [3]:
# Seed the RNGs before any new weights are created so that the randomly
# initialised classification head starts from the same values on every
# "Restart Kernel -> Run All".
# NOTE(review): some CUDA kernels can still be non-deterministic; this fixes
# the initialisation, not necessarily every bit of the training trajectory.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = "cuda" if cuda.is_available() else "cpu"

# One output unit per ordinal class exposed by the dataset
num_classes = len(fgnet_train.classes)

# Initialize ResNet18 model with the IMAGENET1K_V1 pretrained weights
model = models.resnet18(weights="IMAGENET1K_V1")

# Replace the final fully connected layer with a linear projection to
# `num_classes` scores followed by the COPOC output layer
model.fc = nn.Sequential(nn.Linear(model.fc.in_features, num_classes), COPOC())
model = model.to(device)

# Skorch estimator wrapping the already-instantiated module
estimator = NeuralNetClassifier(
    module=model,
    criterion=CrossEntropyLoss().to(device),
    optimizer=Adam,
    lr=0.001,
    max_epochs=30,
    device=device,
    batch_size=200,
)

# Prepare training labels as an int64 tensor; they are passed to fit()
# alongside the dataset object
y_train = torch.tensor(fgnet_train.targets, dtype=torch.long)

# Train model (kept as the last expression so the notebook shows the fitted estimator)
estimator.fit(fgnet_train, y_train)

  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        [36m1.7528[0m       [32m0.1925[0m        [35m1.8009[0m  11.5622
      2        [36m1.7169[0m       [32m0.2360[0m        [35m1.7633[0m  14.0602
      3        [36m1.6682[0m       0.1615        1.8305  10.6949
      4        [36m1.6394[0m       0.2360        1.7851  10.9021
      5        [36m1.6377[0m       [32m0.2671[0m        [35m1.7353[0m  19.3132
      6        [36m1.5851[0m       0.1366        1.8213  14.1549
      7        [36m1.5634[0m       [32m0.3913[0m        [35m1.6269[0m  12.9707
      8        [36m1.5358[0m       0.3043        1.7113  12.1856
      9        [36m1.4657[0m       0.3727        1.6548  13.1590
     10        1.4998       0.3602        1.6590  12.9906
     11        [36m1.4155[0m       0.3106        1.7164  17.4253
     12        [36m1.3896[0m       [32m0.4161[0m        [35m1.6075[0m  24.2347

<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(6

## Evaluation

In [4]:
def is_unimodal(probs):
    """Tell whether a 1D sequence rises to a single peak and then falls.

    The peak is the first position of the maximum value. The sequence counts
    as unimodal when no step up to and including the peak goes down and no
    step from the peak onwards goes up. Flat steps (ties) are allowed on both
    sides, so a peak at either end or a single-element input is unimodal.

    Parameters
    ----------
    probs : 1D array-like
        Values to inspect (for example one row of class probabilities).

    Returns
    -------
    bool-like
        Truthy when both the rising and the falling condition hold.
    """
    mode = np.argmax(probs)

    # Consecutive differences on each side of the peak; the peak element is
    # included in both slices so the steps into and out of it are checked.
    steps_up_to_peak = np.diff(probs[: mode + 1])
    steps_after_peak = np.diff(probs[mode:])

    non_decreasing_before = np.all(steps_up_to_peak >= 0)
    non_increasing_after = np.all(steps_after_peak <= 0)
    return non_decreasing_before and non_increasing_after


def check_unimodality(y_pred):
    """Print and return the share of rows in ``y_pred`` that are unimodal.

    Every row is tested with ``is_unimodal``. A one-line summary is printed
    with the number of unimodal rows, the total number of rows, and the share
    formatted as a percentage with two decimals (e.g. ``100.00%``).

    Parameters
    ----------
    y_pred : iterable of 1D array-likes
        One sequence of scores/probabilities per sample.

    Returns
    -------
    float
        Fraction of unimodal rows in ``[0, 1]``; NaN when ``y_pred`` has no rows.
    """
    # dtype=bool keeps the count an integer even when there are no rows
    unimodal_flags = np.array([is_unimodal(row) for row in y_pred], dtype=bool)
    n_rows = len(unimodal_flags)

    # np.mean on an empty array emits a RuntimeWarning before returning NaN;
    # produce the NaN explicitly so empty input is handled quietly.
    proportion = np.mean(unimodal_flags) if n_rows else float("nan")

    print(
        f"Unimodal predictions: {np.sum(unimodal_flags)} / {n_rows} ({proportion:.2%})"
    )
    return proportion


def calculate_metrics(y_true, y_pred):
    """Compute, print and return ordinal-classification metrics.

    ``y_pred`` may hold either class probabilities or unnormalised scores
    (logits). It is used as-is when it already looks like a probability
    matrix (all entries non-negative and every row summing to 1); otherwise a
    row-wise softmax is applied to obtain probabilities.

    Parameters
    ----------
    y_true : array-like of int
        True class labels, one per sample.
    y_pred : 2D array-like
        Per-class scores or probabilities, one row per sample.

    Returns
    -------
    dict
        Metric values keyed by "ACC", "1OFF", "MAE", "QWK", "AMAE", "MMAE",
        "RPS" and "Unimodality".

    Side effects
    ------------
    Prints the unimodality summary, each metric, and the confusion matrix.
    """
    y_pred = np.asarray(y_pred)

    # A row summing to 1 is not enough to call it a probability vector: logits
    # such as [2, -1, 0] also sum to 1. Require non-negative entries as well,
    # otherwise normalise with softmax.
    rows_sum_to_one = np.allclose(np.sum(y_pred, axis=1), 1)
    if rows_sum_to_one and np.all(y_pred >= 0):
        y_pred_proba = y_pred
    else:
        y_pred_proba = softmax(y_pred, axis=1)

    # Hard label per sample; softmax is monotone, so the argmax of the raw
    # scores equals the argmax of the probabilities.
    y_pred_max = np.argmax(y_pred, axis=1)

    # Metrics
    # NOTE(review): the dlordinal label metrics below receive the 2D score
    # matrix, as in the original code - presumably they reduce it to labels
    # internally; confirm against the dlordinal documentation.
    amae_metric = amae(y_true, y_pred)
    mmae_metric = mmae(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred_max)
    acc = accuracy_score(y_true, y_pred_max)
    acc_1off = accuracy_off1(y_true, y_pred)
    qwk = cohen_kappa_score(y_true, y_pred_max, weights="quadratic")
    rps = ranked_probability_score(y_true, y_pred_proba)
    # Check unimodality on the normalised probabilities
    unimodal_prop = check_unimodality(y_pred_proba)

    metrics = {
        "ACC": acc,
        "1OFF": acc_1off,
        "MAE": mae,
        "QWK": qwk,
        "AMAE": amae_metric,
        "MMAE": mmae_metric,
        "RPS": rps,
        "Unimodality": unimodal_prop,
    }

    for key, value in metrics.items():
        print(f"{key}: {value}")

    print(confusion_matrix(y_true, y_pred_max))

    return metrics


# Score the held-out split: get class probabilities from the fitted estimator,
# compute the metric report (calculate_metrics prints it as it goes), then
# echo the returned dictionary.
test_probs = estimator.predict_proba(fgnet_test)
test_metrics = calculate_metrics(fgnet_test.targets, test_probs)
print(test_metrics)

Unimodal predictions: 201 / 201 (100.00%)
ACC: 0.5422885572139303
1OFF: 0.9353233830845771
MAE: 0.527363184079602
QWK: 0.8252978168618028
AMAE: 0.5437950937950937
MMAE: 0.9285714285714286
RPS: 0.6556535991962896
Unimodality: 1.0
[[18  3  1  0  0  0]
 [ 7 31 16  5  1  0]
 [ 0  6 18  9  0  0]
 [ 0  1 12 24  5  0]
 [ 0  0  3 10 15  2]
 [ 0  0  0  2  9  3]]
{'ACC': 0.5422885572139303, '1OFF': 0.9353233830845771, 'MAE': 0.527363184079602, 'QWK': 0.8252978168618028, 'AMAE': 0.5437950937950937, 'MMAE': 0.9285714285714286, 'RPS': 0.6556535991962896, 'Unimodality': 1.0}
