# GBSG2 Benchmark Analysis

Comprehensive analysis of GBSG2 benchmark results comparing different loss functions for survival analysis.

## Loss Functions Compared
1. **NLL** (Negative Log-Likelihood) - Baseline
2. **CPL** - Concordance Pairwise Loss
3. **CPL (IPCW)** - CPL with Inverse Probability of Censoring Weighting
4. **NLL+CPL (Fixed)** - Normalized Combination Loss with fixed weights
5. **NLL+CPL (IPCW, Fixed)** - Normalized IPCW Loss with fixed weights
6. **NLL+CPL (GradNorm)**
7. **NLL+CPL (IPCW, GradNorm)** 


In [1]:
# Setup: imports and plotting style
import os
from datetime import datetime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(style="whitegrid", context="talk")
plt.rcParams["figure.dpi"] = 300
RESULTS_DIR = "results"
assert os.path.isdir(RESULTS_DIR), f"Missing directory: {RESULTS_DIR}"

# Preferred display names for methods
LOSS_DISPLAY_MAP = {
    "nll": "NLL",
    "cpl": "CPL",
    "cpl_ipcw": "CPL (ipcw)",
    "cpl_nll_fixed": "NLL+CPL (Fixed)",
    "cpl_nll_gradnorm": "NLL+CPL (GradNorm)",
    "cpl_ipcw_nll_fixed": "NLL+CPL (IPCW, Fixed)",
    "cpl_ipcw_nll_gradnorm": "NLL+CPL (IPCW, GradNorm)",
}


In [2]:
# Load CSVs
files = {
    "WHAS500": os.path.join(RESULTS_DIR, "WHAS500_tabular_results_20250924_151428.csv"),
    "GBSG2": os.path.join(RESULTS_DIR, "GBSG2_tabular_results_20250924_151402.csv"),
    "METABRIC": os.path.join(RESULTS_DIR, "METABRIC_tabular_results_20250924_151349.csv"),
}

datasets = {}
for name, path in files.items():
    df = pd.read_csv(path)
    # Normalize column names (lowercase) just in case
    df.columns = [c.strip().lower() for c in df.columns]
    # Ensure expected columns exist
    required = {"loss_type", "uno_cindex", "cumulative_auc", "brier_score"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"{name} missing columns: {missing}")
    datasets[name] = df


In [3]:
# Select best-per-loss by Uno's C-index for each dataset
best_rows = {}
for name, df in datasets.items():
    idx = df.groupby("loss_type")["uno_cindex"].idxmax()
    best = df.loc[idx].copy().reset_index(drop=True)
    # Map display names and order
    best["method"] = best["loss_type"].map(LOSS_DISPLAY_MAP).fillna(best["loss_type"]) 
    best_rows[name] = best

# Concatenate for multi-dataset plotting
long_best = []
for name, best in best_rows.items():
    temp = best.copy()
    temp["dataset"] = name
    long_best.append(temp)
long_best = pd.concat(long_best, ignore_index=True)

# Enforce method ordering consistently across plots
method_order = [
    "NLL",
    "CPL",
    "CPL (ipcw)",
    "NLL+CPL (Fixed)",
    "NLL+CPL (IPCW, Fixed)",
    "NLL+CPL (GradNorm)",
    "NLL+CPL (IPCW, GradNorm)",
]
long_best["method"] = pd.Categorical(long_best["method"], categories=method_order, ordered=True)
long_best = long_best.sort_values(["dataset", "method"])


In [4]:
# Plotting helpers

def save_fig(fig, base_name: str):
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    path = os.path.join(RESULTS_DIR, f"tabular_compare_{base_name}_{ts}.png")
    fig.savefig(path, bbox_inches="tight")
    print(f"Saved: {path}")

# Barplot per metric across datasets
METRICS = [
    ("uno_cindex", "Uno's C-index", True),
    ("cumulative_auc", "Cumulative AUC", True),
    ("brier_score", "Brier score (lower is better)", False),
]

palette = sns.color_palette("Set2", n_colors=len(method_order))


In [5]:
# Create fixed-size figures: one per metric (methods x datasets)
for metric, title, higher_is_better in METRICS:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)
    for ax, dataset in zip(axes, ["WHAS500", "GBSG2", "METABRIC"]):
        data = long_best[long_best["dataset"] == dataset]
        sns.barplot(
            data=data,
            x="method",
            y=metric,
            palette=palette,
            ax=ax,
            order=method_order,
        )
        ax.set_title(f"{dataset}")
        ax.set_xlabel("")
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
        ax.set_ylabel(title if ax is axes[0] else "")
        # Annotate bars
        for container in ax.containers:
            ax.bar_label(container, fmt="{:.3f}", fontsize=9, rotation=90, padding=2)
    fig.suptitle(f"Best-per-loss comparison on {title}")
    fig.tight_layout(rect=[0, 0.02, 1, 0.92])
    save_fig(fig, base_name=metric)
    plt.close(fig)

# Also print the best rows for traceability
display_cols = [
    "dataset", "method", "uno_cindex", "cumulative_auc", "brier_score",
    "harrell_cindex", "incident_auc", "hidden_dim", "best_lr", "best_temperature"
]
long_best_display = long_best.copy()
long_best_display = long_best_display[[c for c in display_cols if c in long_best_display.columns]]
long_best_display



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")


Saved: results\tabular_compare_uno_cindex_20250924_162349.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")


Saved: results\tabular_compare_cumulative_auc_20250924_162350.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")


Saved: results\tabular_compare_brier_score_20250924_162350.png


Unnamed: 0,dataset,method,uno_cindex,cumulative_auc,brier_score,harrell_cindex,incident_auc,hidden_dim,best_lr,best_temperature
13,GBSG2,NLL,0.671626,0.763143,0.130271,0.702687,0.735333,64,0.0005,1.0
7,GBSG2,CPL,0.669344,0.744505,0.132273,0.688374,0.744562,128,0.0005,1.0
8,GBSG2,CPL (ipcw),0.68134,0.764165,0.110747,0.707876,0.76615,64,0.05,1.0
11,GBSG2,NLL+CPL (Fixed),0.669225,0.75959,0.120788,0.697999,0.730059,128,0.0005,0.5
9,GBSG2,"NLL+CPL (IPCW, Fixed)",0.663655,0.744639,0.132631,0.689211,0.727752,64,0.0005,0.5
12,GBSG2,NLL+CPL (GradNorm),0.662927,0.744769,0.129339,0.686951,0.717535,64,0.0005,1.0
10,GBSG2,"NLL+CPL (IPCW, GradNorm)",0.674228,0.753121,0.138268,0.694568,0.731707,128,0.0005,1.0
20,METABRIC,NLL,0.655401,0.68249,0.161361,0.643439,0.756097,128,0.05,1.0
14,METABRIC,CPL,0.649567,0.709076,0.150907,0.661311,0.738178,64,0.05,1.0
15,METABRIC,CPL (ipcw),0.652904,0.691644,0.206407,0.655442,0.738514,128,0.05,1.0
