# Ordinal Regression with CORN

This notebook demonstrates how to train and evaluate a rank-consistent ordinal regression model based on CORN (Conditional Ordinal Regression for Neural Networks) on the FG-NET dataset, using a ResNet-18 backbone and the Skorch framework.

## 1. Setup and imports

First, we install the `skorch` library. It is not a dependency of `dlordinal`, but it simplifies training and evaluating the model.

In [1]:
!pip install skorch



This cell imports the required libraries: NumPy, SciPy, PyTorch, Torchvision, Skorch, Scikit-learn, and the ordinal components from `dlordinal`.

In [2]:
import numpy as np
import torch
from scipy.special import softmax
from sklearn.metrics import (
    accuracy_score,
    cohen_kappa_score,
    confusion_matrix,
    mean_absolute_error,
)
from skorch import NeuralNetClassifier
from skorch.callbacks import EarlyStopping, LRScheduler

# Importing Skorch utilities and callbacks
from skorch.dataset import ValidSplit
from torch import cuda
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.models import resnet18
from torchvision.transforms import Compose, ToTensor

# Importing custom ordinal components
from dlordinal.datasets import FGNet
from dlordinal.losses import CORNLoss
from dlordinal.metrics import accuracy_off1, amae, mmae, ranked_probability_score
from dlordinal.wrappers.corn import CORNClassifierWrapper

## 2. Define Evaluation Metrics Function

This utility function takes the predicted probabilities, computes several ordinal and classification metrics (including AMAE, MMAE, QWK, and RPS), and prints them together with the confusion matrix.

In [3]:
def calculate_metrics(y_true, y_pred):
    """
    Processes the model output (logits or probabilities) and calculates
    a comprehensive set of ordinal and classification metrics.

    y_true: True ordinal labels (e.g., [0, 1, 2])
    y_pred: Predicted probabilities (for CORN, these are the one-hot results
            from the wrapper's predict_proba)
    """

    # Ensure y_pred is in probability space for metrics like RPS
    if np.allclose(np.sum(y_pred, axis=1), 1):
        y_pred_proba = y_pred
    else:
        # If the input is raw logits, convert to softmax probabilities
        y_pred_proba = softmax(y_pred, axis=1)

    # Determine the predicted class by finding the index with the max probability
    y_pred_max = np.argmax(y_pred, axis=1)

    # --- Metric Calculation ---
    amae_metric = amae(y_true, y_pred_max)
    mmae_metric = mmae(y_true, y_pred_max)
    mae = mean_absolute_error(y_true, y_pred_max)
    acc = accuracy_score(y_true, y_pred_max)
    acc_1off = accuracy_off1(y_true, y_pred_max)
    qwk = cohen_kappa_score(y_true, y_pred_max, weights="quadratic")
    rps = ranked_probability_score(y_true, y_pred_proba)

    metrics = {
        "ACC": acc,
        "1OFF": acc_1off,
        "MAE": mae,
        "QWK": qwk,
        "AMAE": amae_metric,
        "MMAE": mmae_metric,
        "RPS": rps,
    }

    # --- Output ---
    print("\n--- Evaluation Metrics ---")
    for key, value in metrics.items():
        print(f"{key}: {value:.4f}")

    print("\nConfusion Matrix:")
    print(confusion_matrix(y_true, y_pred_max))

    return metrics
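
To make the ordinal metrics above more concrete, here is a minimal NumPy sketch of how AMAE, MMAE, and RPS can be computed. It is not `dlordinal`'s implementation: the function names (`per_class_mae`, `amae_sketch`, `mmae_sketch`, `rps_sketch`) are hypothetical, and the RPS shown assumes the unnormalised definition (sum of squared differences between cumulative distributions, averaged over samples).

```python
import numpy as np


def per_class_mae(y_true, y_pred):
    """MAE computed separately for each class that appears in y_true."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return np.array(
        [np.abs(y_pred[y_true == c] - c).mean() for c in np.unique(y_true)]
    )


def amae_sketch(y_true, y_pred):
    # Average MAE: mean of the per-class MAEs, so rare classes count equally.
    return per_class_mae(y_true, y_pred).mean()


def mmae_sketch(y_true, y_pred):
    # Maximum MAE: the MAE of the worst-predicted class.
    return per_class_mae(y_true, y_pred).max()


def rps_sketch(y_true, y_proba):
    # Assumed (unnormalised) RPS: squared distance between the predicted and
    # the true cumulative distributions, summed over classes, averaged over samples.
    y_true = np.asarray(y_true)
    y_proba = np.asarray(y_proba, dtype=float)
    one_hot = np.eye(y_proba.shape[1])[y_true]
    cum_diff = np.cumsum(y_proba, axis=1) - np.cumsum(one_hot, axis=1)
    return np.mean(np.sum(cum_diff**2, axis=1))


# Toy labels and predictions (illustrative only, not FG-NET data)
y_true = np.array([0, 0, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 2, 0, 2])

print("AMAE:", amae_sketch(y_true, y_pred))
print("MMAE:", mmae_sketch(y_true, y_pred))
# With one-hot "probabilities", this RPS reduces to the plain MAE.
print("RPS :", rps_sketch(y_true, np.eye(3)[y_pred]))
```

Under this definition, feeding one-hot predictions into RPS yields the same value as MAE, because each sample contributes exactly |predicted class − true class|.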

## 3. Load Data and Setup Device

This cell loads the FG-NET dataset for training and testing and sets up the device, preferring CUDA if available.

In [4]:
# --- Data Loading ---
fgnet_train = FGNet(
    root="./datasets",
    download=True,
    train=True,
    transform=Compose([ToTensor()]),  # Converts images to PyTorch tensors
)

fgnet_test = FGNet(
    root="./datasets",
    download=True,
    train=False,
    transform=Compose([ToTensor()]),
)

num_classes = len(fgnet_train.classes)  # J: total number of ordinal classes (6 in this run)
classes = fgnet_train.classes
targets = fgnet_train.targets

# --- Device Setup ---
device = "cuda" if cuda.is_available() else "cpu"
print(f"Using {device} device")

Files already downloaded and verified
Files already processed and verified
Files already split and verified
Files already downloaded and verified
Files already processed and verified
Files already split and verified
Using cpu device


## 4. Model and Skorch Wrapper Definition

This cell defines the ResNet-18 model, adapts its final layer for the CORN method (outputting J−1 logits), and wraps it with the Skorch and CORN classifiers.

In [5]:
# --- ResNet Adaptation for CORN (J-1 Logits) ---
base_model = resnet18(weights="IMAGENET1K_V1")
# Replace the final fully connected layer to output J-1 logits
base_model.fc = torch.nn.Linear(base_model.fc.in_features, num_classes - 1)
base_model = base_model.to(device)

# --- Skorch NeuralNetClassifier Configuration ---
clf = NeuralNetClassifier(
    module=base_model,
    # Use CORNLoss, passing the total number of classes (J) to the constructor
    criterion=CORNLoss(num_classes=num_classes),
    optimizer=AdamW,
    lr=0.001,
    max_epochs=30,
    # verbose=0, # Uncomment to suppress Skorch output
    train_split=ValidSplit(
        0.1, random_state=0
    ),  # Use 10% of the training data for internal validation
    callbacks=[
        # Stop early if validation loss doesn't improve for 5 epochs
        EarlyStopping(patience=5, monitor="valid_loss"),
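        # Halve the LR (factor=0.5) after 3 epochs without improvement.
        # Note: skorch's LRScheduler monitors train_loss by default;
        # pass monitor="valid_loss" to track the validation loss instead.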
        LRScheduler(policy=ReduceLROnPlateau, patience=3, factor=0.5),
    ],
    device=device,
    batch_size=200,
)

# --- Final CORN Wrapper ---
# The CORNClassifierWrapper wraps the (not yet trained) Skorch classifier (clf).
# It exposes the Scikit-learn interface and implements the CORN aggregation logic.
corn_clf = CORNClassifierWrapper(clf)

print(f"Skorch classifier ready to train with {num_classes-1} output logits.")

Skorch classifier ready to train with 5 output logits.
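
For intuition about why the network outputs J−1 logits, the following is a minimal NumPy sketch of the CORN loss idea as described in the CORN literature. It is not the `CORNLoss` implementation from `dlordinal`, and the names `log_sigmoid` and `corn_loss_sketch` are hypothetical. Each logit k is treated as a binary task "is y > k?", trained only on the samples that already satisfy y > k−1.

```python
import numpy as np


def log_sigmoid(z):
    # Numerically stable log(sigmoid(z)) = -log(1 + exp(-z)).
    return -np.logaddexp(0.0, -z)


def corn_loss_sketch(logits, y, num_classes):
    """Binary cross-entropy summed over the J-1 conditional tasks,
    divided by the total number of (sample, task) pairs used."""
    logits = np.asarray(logits, dtype=float)
    y = np.asarray(y)
    total_loss, n_used = 0.0, 0
    for k in range(num_classes - 1):
        mask = y >= k  # conditional subset: samples with y > k-1
        if not mask.any():
            continue
        target = (y[mask] > k).astype(float)  # binary label for task k
        z = logits[mask, k]
        # log(1 - sigmoid(z)) equals log_sigmoid(z) - z
        total_loss -= np.sum(
            target * log_sigmoid(z) + (1.0 - target) * (log_sigmoid(z) - z)
        )
        n_used += int(mask.sum())
    return total_loss / n_used


# Toy example with J = 3 classes, hence 2 logits per sample
toy_logits = np.zeros((4, 2))
toy_labels = np.array([0, 1, 2, 2])
print(corn_loss_sketch(toy_logits, toy_labels, num_classes=3))
```

With all-zero logits every binary task predicts 0.5, so the loss equals log(2) regardless of the labels, which gives a useful reference point when reading the `train_loss` column.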


## 5. Train the Model

This cell executes the training process using the Skorch/CORN wrapper.

In [6]:
print("--- Starting Training ---")

# Skorch automatically handles data loading and iteration.
# We pass the FGNet dataset instance and its targets (converted to a LongTensor).
corn_clf.fit(X=fgnet_train, y=torch.tensor(fgnet_train.targets, dtype=torch.long))

print("--- Training Complete ---")

--- Starting Training ---
  epoch    train_loss    valid_acc    valid_loss      lr      dur
-------  ------------  -----------  ------------  ------  -------
      1        0.5372       0.0617        0.5114  0.0010  13.2285
      2        0.2260       0.0741        0.4839  0.0010  12.5660
      3        0.0920       0.0494        0.6455  0.0010  12.1744
      4        0.0306       0.0617        0.6849  0.0010  11.9297
      5        0.0155       0.0617        0.3541  0.0010  13.0846
      6        0.0056       0.0617        0.3857  0.0010  12.6702
      7        0.0042       0.0370        0.4297  0.0010  12.3104
      8        0.0019       0.0370        0.4507  0.0010  12.4775
      9        0.0013       0.0370        0.4369  0.0010  12.5651
Stopping since valid_loss has not improved in the last 5 epochs.
--- Training Complete ---


## 6. Evaluate on Test Set

This cell uses the trained classifier to predict probabilities on the test set and calls the metric calculation function.

In [7]:
print("\n--- Evaluating on Test Set ---")

# Get the raw targets for comparison
test_targets = fgnet_test.targets

# Get prediction probabilities using the CORN aggregation logic (one-hot output)
# This uses the thresholding logic inside the CORNClassifierWrapper
test_probs = corn_clf.predict_proba(fgnet_test)

# Calculate and display all metrics
metrics = calculate_metrics(np.array(test_targets), test_probs)


--- Evaluating on Test Set ---

--- Evaluation Metrics ---
ACC: 0.5124
1OFF: 0.8756
MAE: 0.6517
QWK: 0.7337
AMAE: 0.7544
MMAE: 1.6429
RPS: 0.6517

Confusion Matrix:
[[15  6  0  1  0  0]
 [ 8 37  6  7  2  0]
 [ 1 12  4 15  0  1]
 [ 0  3  7 28  4  0]
 [ 0  0  3  7 17  3]
 [ 0  1  2  4  5  2]]
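
The thresholding logic mentioned in the evaluation cell can be sketched as follows. This is an illustrative NumPy version of the usual CORN inference rule, not necessarily how `CORNClassifierWrapper` is implemented, and `corn_predict_sketch` is a hypothetical name. The sigmoid of each logit is read as a conditional probability P(y > k | y > k−1); their cumulative product gives P(y > k), and the predicted rank is the number of those probabilities above 0.5.

```python
import numpy as np


def corn_predict_sketch(logits):
    """Turn J-1 CORN logits into rank labels and one-hot 'probabilities'."""
    logits = np.asarray(logits, dtype=float)
    cond = 1.0 / (1.0 + np.exp(-logits))  # P(y > k | y > k-1)
    exceed = np.cumprod(cond, axis=1)  # P(y > k), non-increasing in k
    labels = (exceed > 0.5).sum(axis=1)  # count of thresholds passed
    one_hot = np.eye(logits.shape[1] + 1)[labels]  # shape (n_samples, J)
    return labels, one_hot


# Toy logits for J = 4 classes (3 logits per sample)
toy_logits = np.array(
    [
        [3.0, 3.0, -3.0],
        [-3.0, 3.0, 3.0],
        [3.0, -3.0, 3.0],
    ]
)
labels, one_hot = corn_predict_sketch(toy_logits)
print(labels)
print(one_hot)
```

Because the cumulative product can only decrease with k, a low early probability caps all later ones. In the second and third toy rows the later logits are large, yet the predicted rank stays low, which is the rank-consistency property CORN is designed for.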
