In [None]:
# Imports
import pandas as pd
from sksurv.util import Surv
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from typing import Callable, Tuple
import seaborn as sns

from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, confusion_matrix, roc_curve

from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.metrics.auc import Auc
import torch

# Set random seed for reproducibility
_ = torch.manual_seed(42)

# Metric calculation

In [None]:
# Load model predictions and ground-truth targets for the CAIRO training split
results_CAIRO_train = pd.read_csv("results/multi_task_model_predictions_training.csv")
targets_CAIRO_train = pd.read_csv("data/training_data.csv")

# Row selector: True wherever path_resp is not -1
# (presumably -1 marks a missing pathological-response label — confirm)
path_mask_CAIRO_train = targets_CAIRO_train["path_resp"].ne(-1)

# Pathological response: probabilities, hard predictions and labels for the selected rows,
# re-indexed from 0 so the three Series line up positionally
path_probs_CAIRO_train = results_CAIRO_train.loc[path_mask_CAIRO_train, "path_probs"].reset_index(drop=True)
path_preds_CAIRO_train = results_CAIRO_train.loc[path_mask_CAIRO_train, "path_preds"].reset_index(drop=True)
path_labels_CAIRO_train = targets_CAIRO_train.loc[path_mask_CAIRO_train, "path_resp"].reset_index(drop=True)

# PFS: predictions and targets as float32 tensors, event indicator as bool (all rows, no mask)
pfs_preds_CAIRO_train = torch.tensor(results_CAIRO_train["pfs_preds"].to_numpy(), dtype=torch.float32)
pfs_targets_CAIRO_train = torch.tensor(targets_CAIRO_train["pfs"].to_numpy(), dtype=torch.float32)
pfs_events_CAIRO_train = torch.tensor(targets_CAIRO_train["pfsstat"].to_numpy(), dtype=torch.bool)

# OS: same layout as PFS
os_preds_CAIRO_train = torch.tensor(results_CAIRO_train["os_preds"].to_numpy(), dtype=torch.float32)
os_targets_CAIRO_train = torch.tensor(targets_CAIRO_train["OS"].to_numpy(), dtype=torch.float32)
os_events_CAIRO_train = torch.tensor(targets_CAIRO_train["OSSTAT"].to_numpy(), dtype=torch.bool)

In [None]:
# Load model predictions and ground-truth targets for the CAIRO test split
results_CAIRO = pd.read_csv("./results/multi_task_model_predictions_test_CAIRO.csv")
targets_CAIRO = pd.read_csv("./data/test_data_CAIRO.csv")

# Row selector: True wherever path_resp is not -1
# (presumably -1 marks a missing pathological-response label — confirm)
path_mask_CAIRO = targets_CAIRO["path_resp"].ne(-1)

# Pathological response: probabilities, hard predictions and labels for the selected rows,
# re-indexed from 0 so the three Series line up positionally
path_probs_CAIRO = results_CAIRO.loc[path_mask_CAIRO, "path_probs"].reset_index(drop=True)
path_preds_CAIRO = results_CAIRO.loc[path_mask_CAIRO, "path_preds"].reset_index(drop=True)
path_labels_CAIRO = targets_CAIRO.loc[path_mask_CAIRO, "path_resp"].reset_index(drop=True)

# PFS: predictions and targets as float32 tensors, event indicator as bool (all rows, no mask)
pfs_preds_CAIRO = torch.tensor(results_CAIRO["pfs_preds"].to_numpy(), dtype=torch.float32)
pfs_targets_CAIRO = torch.tensor(targets_CAIRO["pfs"].to_numpy(), dtype=torch.float32)
pfs_events_CAIRO = torch.tensor(targets_CAIRO["pfsstat"].to_numpy(), dtype=torch.bool)

# OS: same layout as PFS
os_preds_CAIRO = torch.tensor(results_CAIRO["os_preds"].to_numpy(), dtype=torch.float32)
os_targets_CAIRO = torch.tensor(targets_CAIRO["OS"].to_numpy(), dtype=torch.float32)
os_events_CAIRO = torch.tensor(targets_CAIRO["OSSTAT"].to_numpy(), dtype=torch.bool)

# Time points for the time-dependent AUC-ROC metrics: 1 and 3 years, expressed in days.
# TODO: revisit once the AmCORE data is available; CAIRO and AmCORE may need separate time points.
time_points_CAIRO = torch.tensor([365, 1095], dtype=torch.float32)


In [None]:
# Load model predictions and ground-truth targets for the AmCORE test set
results_AmCore = pd.read_csv("./results/multi_task_model_predictions_test_AmCORE.csv")
targets_AmCore = pd.read_csv("./data/test_data_AmCore.csv")

# Row selector: True wherever path_resp is not -1
# (presumably -1 marks a missing pathological-response label — confirm)
path_mask_AmCore = targets_AmCore["path_resp"].ne(-1)

# Pathological response: probabilities, hard predictions and labels for the selected rows,
# re-indexed from 0 so the three Series line up positionally
path_probs_AmCore = results_AmCore.loc[path_mask_AmCore, "path_probs"].reset_index(drop=True)
path_preds_AmCore = results_AmCore.loc[path_mask_AmCore, "path_preds"].reset_index(drop=True)
path_labels_AmCore = targets_AmCore.loc[path_mask_AmCore, "path_resp"].reset_index(drop=True)

# PFS: predictions and targets as float32 tensors, event indicator as bool (all rows, no mask)
pfs_preds_AmCore = torch.tensor(results_AmCore["pfs_preds"].to_numpy(), dtype=torch.float32)
pfs_targets_AmCore = torch.tensor(targets_AmCore["pfs"].to_numpy(), dtype=torch.float32)
pfs_events_AmCore = torch.tensor(targets_AmCore["pfsstat"].to_numpy(), dtype=torch.bool)

# OS: same layout as PFS
os_preds_AmCore = torch.tensor(results_AmCore["os_preds"].to_numpy(), dtype=torch.float32)
os_targets_AmCore = torch.tensor(targets_AmCore["OS"].to_numpy(), dtype=torch.float32)
os_events_AmCore = torch.tensor(targets_AmCore["OSSTAT"].to_numpy(), dtype=torch.bool)

# Time points for the time-dependent AUC-ROC metrics: 1 and 3 years, expressed in days
time_points_AmCORE = torch.tensor([365, 1095], dtype=torch.float32)

In [None]:
# Bootstrap function for path. response metrics

def compute_confidence_interval(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    metric_fn: Callable[[np.ndarray, np.ndarray], float],
    n_bootstraps: int = 1000,
    ci: float = 0.95,
    random_state: int = 42
) -> Tuple[float, float, float]:
    """
    Compute a metric and its percentile confidence interval using bootstrapping.

    Samples are drawn with replacement ``n_bootstraps`` times. A resample whose
    ``y_true`` holds fewer than two distinct values is discarded, so the
    statistics may be based on fewer than ``n_bootstraps`` evaluations.

    Args:
        y_true: Ground-truth labels. Any array-like (ndarray, list, pandas
            Series) is accepted; it is converted to an ndarray so that rows
            are always selected by position, never by index label.
        y_pred: Predictions or scores, same length as ``y_true``.
        metric_fn: Callable invoked as ``metric_fn(y_true_sample, y_pred_sample)``.
        n_bootstraps: Number of resamples to draw.
        ci: Confidence level in [0, 1], e.g. 0.95 for a 95% interval.
        random_state: Seed for the NumPy random generator.

    Returns:
        ``(mean, lower_bound, upper_bound)`` over the retained resamples.
        All three are NaN when no resample could be evaluated (empty input,
        ``n_bootstraps <= 0`` or labels containing a single class).

    Raises:
        ValueError: If the inputs differ in length or ``ci`` is outside [0, 1].
    """
    # Positional indexing: a pandas Series indexed with an integer array is
    # looked up by label, which breaks for any non-default index.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    n_samples = len(y_true)
    if len(y_pred) != n_samples:
        raise ValueError(
            f"y_true and y_pred must have the same length, got {n_samples} and {len(y_pred)}"
        )
    if not 0.0 <= ci <= 1.0:
        raise ValueError(f"ci must be between 0 and 1, got {ci}")

    nan_result = (float("nan"), float("nan"), float("nan"))
    if n_samples == 0:
        return nan_result

    rng = np.random.default_rng(random_state)
    metrics = []

    for _ in range(n_bootstraps):
        indices = rng.integers(0, n_samples, n_samples)
        true_sample = y_true[indices]
        if len(np.unique(true_sample)) < 2:
            continue  # Skip iteration if not enough class diversity
        metrics.append(metric_fn(true_sample, y_pred[indices]))

    # Nothing to summarise: report NaN explicitly instead of reducing an empty list.
    if not metrics:
        return nan_result

    tail = (1 - ci) / 2
    metric_mean = float(np.mean(metrics))
    lower_bound = float(np.percentile(metrics, tail * 100))
    upper_bound = float(np.percentile(metrics, (1 - tail) * 100))
    return metric_mean, lower_bound, upper_bound


## Accuracy - Path. Resp.

In [None]:
# Point estimate of pathological-response accuracy on the CAIRO training split
accuracy_CAIRO_train = accuracy_score(
    y_true=path_labels_CAIRO_train,
    y_pred=path_preds_CAIRO_train,
)
print(f"Accuracy (CAIRO train): {accuracy_CAIRO_train:.3f}")

# Bootstrap mean and 95% percentile interval of the same metric
mean, lower_bound, upper_bound = compute_confidence_interval(
    y_true=path_labels_CAIRO_train,
    y_pred=path_preds_CAIRO_train,
    metric_fn=accuracy_score,
)
print(
    f"Bootstrap accuracy (CAIRO train): {mean:.3f} "
    f"(95% CI: [{lower_bound:.3f}, {upper_bound:.3f}])"
)

### AmCore

In [None]:
# Point estimate of pathological-response accuracy on AmCore
accuracy_AmCore = accuracy_score(
    y_true=path_labels_AmCore,
    y_pred=path_preds_AmCore,
)
print(f"Accuracy (AmCore): {accuracy_AmCore:.3f}")

# Bootstrap mean and 95% percentile interval of the same metric
mean, lower_bound, upper_bound = compute_confidence_interval(
    y_true=path_labels_AmCore,
    y_pred=path_preds_AmCore,
    metric_fn=accuracy_score,
)
print(
    f"Bootstrap accuracy (AmCore): {mean:.3f} "
    f"(95% CI: [{lower_bound:.3f}, {upper_bound:.3f}])"
)

## AUC-ROC - Path. Resp.

### AmCore

In [None]:
# AUC-ROC of the predicted pathological-response probabilities on AmCore
auc_roc_AmCore = roc_auc_score(y_true=path_labels_AmCore, y_score=path_probs_AmCore)
print(f"AUC-ROC (AmCore): {auc_roc_AmCore:.3f}")

# Bootstrap mean and 95% percentile interval of the AUC-ROC
mean, lower_bound, upper_bound = compute_confidence_interval(
    y_true=path_labels_AmCore,
    y_pred=path_probs_AmCore,
    metric_fn=roc_auc_score,
)
print(
    f"Bootstrap AUC-ROC (AmCore): {mean:.3f} "
    f"(95% CI: [{lower_bound:.3f}, {upper_bound:.3f}])"
)

# ROC curve with the chance diagonal for reference
fpr, tpr, thresholds = roc_curve(y_true=path_labels_AmCore, y_score=path_probs_AmCore)
fig, ax = plt.subplots()
ax.plot(fpr, tpr, label=f'AUC = {auc_roc_AmCore:.3f}')
ax.plot([0, 1], [0, 1], 'k--', label='Random')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve (AmCORE)')
ax.legend(loc='lower right')
plt.show()

## F1 Score - Path. Resp. 

In [None]:
# F1 score of the hard pathological-response predictions on AmCore
f1_AmCore = f1_score(y_true=path_labels_AmCore, y_pred=path_preds_AmCore)
print(f"F1 Score (AmCore): {f1_AmCore:.3f}")

# Bootstrap mean and 95% percentile interval of the F1 score
mean, lower_bound, upper_bound = compute_confidence_interval(
    y_true=path_labels_AmCore,
    y_pred=path_preds_AmCore,
    metric_fn=f1_score,
)
print(
    f"Bootstrap F1 Score (AmCore): {mean:.3f} "
    f"(95% CI: [{lower_bound:.3f}, {upper_bound:.3f}])"
)

## Confusion matrix - Path. Resp.

In [None]:

# Confusion matrix of true vs. predicted pathological response on AmCore,
# drawn as an annotated heatmap (rows: true class, columns: predicted class)
cm_AmCore = confusion_matrix(y_true=path_labels_AmCore, y_pred=path_preds_AmCore)
print("\nConfusion Matrix (AmCore):")

pr_class_names = ['No', 'Complete']
fig, ax = plt.subplots(figsize=(4, 4))
sns.heatmap(
    cm_AmCore,
    annot=True,
    fmt='d',
    cmap='Blues',
    cbar=False,
    xticklabels=pr_class_names,
    yticklabels=pr_class_names,
    ax=ax,
)
ax.set_xlabel('Predicted PR')
ax.set_ylabel('True PR')
ax.set_title('Pathological Response (AmCORE)')
plt.show()


## C-Index - PFS & OS

In [None]:
# Concordance index for OS on the CAIRO training split.
# The metric object is kept so its confidence interval can be queried after the call.
c_index = ConcordanceIndex()
c_index_os_cairo_train = c_index(
    os_preds_CAIRO_train,
    os_events_CAIRO_train,
    os_targets_CAIRO_train,
)
print(
    f"Concordance Index (OS, CAIRO train): {c_index_os_cairo_train.item():.3f}, "
    f"Confidence Interval: {c_index.confidence_interval()}"
)

### CAIRO

In [None]:
# Concordance index for PFS on the CAIRO test set.
# A fresh metric object is used per endpoint so each confidence interval
# refers to the call made just before it.
c_index = ConcordanceIndex()
c_index_pfs = c_index(
    pfs_preds_CAIRO,
    pfs_events_CAIRO,
    pfs_targets_CAIRO,
)
print(
    f"Concordance Index (PFS, CAIRO): {c_index_pfs.item():.3f}, "
    f"confidence interval: {c_index.confidence_interval()}"
)

# Concordance index for OS on the CAIRO test set
c_index = ConcordanceIndex()
c_index_os = c_index(
    os_preds_CAIRO,
    os_events_CAIRO,
    os_targets_CAIRO,
)
print(
    f"Concordance Index (OS, CAIRO): {c_index_os.item():.3f}, "
    f"Confidence Interval: {c_index.confidence_interval()}"
)

### AmCore

In [None]:
# Concordance index for PFS on AmCore.
# A fresh metric object is used per endpoint so each confidence interval
# refers to the call made just before it.
c_index = ConcordanceIndex()
c_index_pfs = c_index(
    pfs_preds_AmCore,
    pfs_events_AmCore,
    pfs_targets_AmCore,
)
print(
    f"Concordance Index (PFS, AmCore): {c_index_pfs.item():.3f}, "
    f"confidence interval: {c_index.confidence_interval()}"
)

# Concordance index for OS on AmCore, computed on all patients.
# NOTE: a variant restricted to OS times in [1, 720] days was sketched here
# (mask on os_targets_AmCore) but is not applied.
c_index = ConcordanceIndex()
c_index_os = c_index(
    os_preds_AmCore,
    os_events_AmCore,
    os_targets_AmCore,
)
print(
    f"Concordance Index (OS, AmCore): {c_index_os.item():.3f}, "
    f"Confidence Interval: {c_index.confidence_interval()}"
)

## Time dependent AUC-ROC - PFS & OS

### CAIRO

In [None]:
# Time-dependent AUC for PFS on the CAIRO test set, reported at every entry
# of time_points_CAIRO.
# The previous version called .item() on the result and unpacked the
# confidence interval into two scalars, which only works for a single time
# point, and labelled the output "365 days" regardless of the time used.
auc = Auc()
auc_pfs = auc(pfs_preds_CAIRO, pfs_events_CAIRO, pfs_targets_CAIRO, new_time=time_points_CAIRO)

# Flatten / reshape so indexing is uniform whether or not the library squeezes
# single-time-point results (not verified here). Row 0 is read as the lower
# bound and row 1 as the upper bound, matching how the AmCore cells index it.
auc_pfs_values = auc_pfs.reshape(-1)
auc_pfs_ci = auc.confidence_interval().reshape(2, -1)

for i, t in enumerate(time_points_CAIRO):
    auc_val = auc_pfs_values[i].item()
    ci_low = auc_pfs_ci[0, i].item()
    ci_high = auc_pfs_ci[1, i].item()
    print(f"AUC at {t.item():.0f} days: {auc_val:.3f} (95% CI: [{ci_low:.3f}, {ci_high:.3f}])")


In [None]:
# Time-dependent AUC for OS on the CAIRO test set.
# The 720-day time point lives in its own variable: reassigning the shared
# time_points_CAIRO here made the PFS cell give different results depending
# on execution order. The result is likewise stored under an OS-specific name
# rather than reusing auc_pfs, and the printed label shows the time actually used.
auc = Auc()
time_points_os_CAIRO = torch.tensor([720.0])
auc_os = auc(os_preds_CAIRO, os_events_CAIRO, os_targets_CAIRO, new_time=time_points_os_CAIRO)

# Flatten / reshape so indexing is uniform whether or not the library squeezes
# single-time-point results (not verified here). Row 0 is read as the lower
# bound and row 1 as the upper bound, matching how the AmCore cells index it.
auc_os_values = auc_os.reshape(-1)
auc_os_ci = auc.confidence_interval().reshape(2, -1)

for i, t in enumerate(time_points_os_CAIRO):
    auc_val = auc_os_values[i].item()
    ci_low = auc_os_ci[0, i].item()
    ci_high = auc_os_ci[1, i].item()
    print(f"AUC at {t.item():.0f} days: {auc_val:.3f} (95% CI: [{ci_low:.3f}, {ci_high:.3f}])")

### AmCore

In [None]:
# PFS (AmCore)
# Time-dependent AUC for PFS at each entry of time_points_AmCORE.
# The cell previously referenced `time_points_pfs`, which is not defined
# anywhere in this notebook and raised NameError on a fresh kernel; the
# AmCORE time points defined with the AmCORE data are used instead
# (same ones as the OS cell below) — adjust if PFS needs its own time points.
auc = Auc()
auc_pfs = auc(pfs_preds_AmCore, pfs_events_AmCore, pfs_targets_AmCore, new_time=time_points_AmCORE)

# Compute the interval once; it is indexed as [bound, time point] with
# row 0 read as the lower bound and row 1 as the upper bound.
auc_pfs_ci = auc.confidence_interval()

for i, t in enumerate(time_points_AmCORE):
    auc_val = auc_pfs[i].item()
    ci_low = auc_pfs_ci[0, i]
    ci_high = auc_pfs_ci[1, i]
    print(f"AUC at {t.item():.0f} days: {auc_val:.3f} (95% CI: [{ci_low:.3f}, {ci_high:.3f}])")

In [None]:
# OS (AmCore)
# Time-dependent AUC for OS at each entry of time_points_AmCORE.
auc = Auc()
auc_os = auc(os_preds_AmCore, os_events_AmCore, os_targets_AmCore, new_time=time_points_AmCORE)

# The interval does not depend on the loop variable, so compute it once
# instead of twice per time point. It is indexed as [bound, time point] with
# row 0 read as the lower bound and row 1 as the upper bound.
auc_os_ci = auc.confidence_interval()

for i, t in enumerate(time_points_AmCORE):
    auc_val = auc_os[i].item()
    ci_low = auc_os_ci[0, i]
    ci_high = auc_os_ci[1, i]
    print(f"AUC at {t.item():.0f} days: {auc_val:.3f} (95% CI: [{ci_low:.3f}, {ci_high:.3f}])")