In [None]:
"""
Expected data path and required files: 

DATA_DIR
├── UK_65
│    ├── "dataset.csv"
│    ├── "to_censure.pkl"
│    ├── "inactive_ids.pkl"
│    ├── "stats.pkl"
│    ├── "deeper_main_meds.csv"
│    └── "all_meds_with_counts.csv"
│    
├── UK_70
├── FR_65
└── FR_70
"""

DATA_DIR =  '../../../data/datasets'
OUTPUT_DIR = '../../../data/results/'
SEED = 2023

In [None]:
import warnings
import os

from typing import Dict, List

import pandas as pd
import pickle
import ast

import numpy as np

from fractions import Fraction

from IPython.display import display

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, confusion_matrix, precision_recall_curve

from sklearn.linear_model import LogisticRegression
from imblearn.under_sampling import RandomUnderSampler

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline

import matplotlib.pyplot as plt
import seaborn as sns

# Plot like R
import matplotlib as mpl
mpl.style.use('ggplot')

# seaborn's theme call rewrites the colour palette, `font.sans-serif` and the font
# sizes (and replaces most of the ggplot axes look with its 'ticks' style), so it has
# to run BEFORE the manual overrides below. It used to run last, which silently
# discarded both the pastel palette and the Helvetica font.
sns.set(style='ticks', palette='pastel')

plt.rcParams['font.family'] = 'sans-serif'
# Helvetica is frequently not installed; list fallbacks so matplotlib does not emit a
# "findfont" warning for every figure.
plt.rcParams['font.sans-serif'] = ['Helvetica', 'Arial', 'DejaVu Sans']
plt.rcParams['font.size'] = 12

In [None]:
# Outcomes evaluated in this notebook. Labels previously toggled off here:
# parkinson, vascular_dementias, mci, alcohol_dementias, frontotemporal_dementias,
# other_dementias, parkinson_dementias.
DISEASES_OF_INTEREST = ['alzheimer', 'all_dementias', 'all_dementias+mci']

# Utils

### Dataset

In [None]:
# Model inputs besides sex (and optionally BMI/Charlson): presumably ATC level-2 drug
# classes (TODO confirm) plus the MCI-at-baseline flag.
PREDICTORS = ['A06', 'A10', 'B03', 'G04', 'N02', 'N03', 'N05', 'N06', 'mci_at_baseline']

# Pre-built cohorts, indexed as DATASETS_PRED[age][country] by get_data().
# NOTE: pickle.load can execute arbitrary code; only open files produced by this project.
with open(os.path.join(OUTPUT_DIR, 'datasets_prediction.pkl'), mode='rb') as f:
    DATASETS_PRED = pickle.load(f)

def get_data(country:str, age:int, disease:str, include_charlson_bmi:bool, 
             pred_up_to_year:int, include_deceased_as_zeros:bool=False, base_dir:str=DATA_DIR):
    """
    Load and prepare data for prediction task.

    A patient is labelled 1 when the disease of interest (any recorded disease when
    `disease == 'any'`) is first recorded within `pred_up_to_year` years of baseline,
    and 0 otherwise.

    Parameters:
    -----------
    country : str
        'UK' or 'FR'
    age : int
        65 or 70
    disease : str
        Disease of interest from DISEASES_OF_INTEREST or 'any'
    include_charlson_bmi : bool
        Whether to include BMI and CHARLSON features (patients missing either are dropped)
    pred_up_to_year : int
        Prediction horizon in years
    include_deceased_as_zeros : bool, default=False
        Concerns patients whose record ends within the horizon (non-active
        `person_state_code`, i.e. died or became inactive) without the disease.
        If True, keep them as negatives (0).
        If False, exclude them (original behavior).
    base_dir : str
        Currently unused: cohorts are taken from the module-level DATASETS_PRED cache.

    Returns:
    --------
    X : DataFrame
        Features
    y : Series
        Target labels (1=disease, 0=no disease)
    """
    assert country in {'UK', 'FR'}, f'`country` must be either UK or FR, not {country}.'
    assert age in {65, 70}, f'`age` must be either 65 or 70, not {age}.'
    # Parenthesised so both f-string halves form ONE message (the second half used to be a
    # separate no-op statement), and report the offending `disease`, not `age`.
    assert disease in DISEASES_OF_INTEREST or disease == 'any', (
        f'`disease` must be either in DISEASES_OF_INTEREST ({DISEASES_OF_INTEREST}) '
        f'or any (at least one neurodegenerative diseases), not {disease}.'
    )

    ## Load corresponding dataset
    dataset = DATASETS_PRED[age][country].copy()

    # filter out patients without bmi or charlson record
    if include_charlson_bmi:
        mask = dataset[['avg. BMI', 'avg. CHARLSON']].notna().all(axis=1)
        n = mask.sum()
        print(f'* Filtering out patients without BMI or CHARLSON record ({len(dataset)-n:,} patients {(1-n/len(dataset))*100:.2f} %)')
        print(f'\t{len(dataset):,} -> {n:,} patients')
        dataset = dataset.loc[mask]

    ## Filter out patients who died or became inactive / temporaire before a potential disease
    pred_up_to_days = pred_up_to_year*365.25

    def onset_of_interest(records):
        """Earliest onset (days from baseline) of the disease of interest, or None.

        `records` holds (disease_name, days) pairs; with disease == 'any' every pair counts.
        """
        onsets = [days for name, days in records if disease == 'any' or name == disease]
        return min(onsets) if onsets else None

    def is_sick(records):
        # Only the onset of the disease of interest matters: an earlier record of ANOTHER
        # disease inside the horizon must not turn a later diagnosis into a positive.
        onset = onset_of_interest(records)
        return onset is not None and onset <= pred_up_to_days

    mask_sick = dataset['diseases'].apply(is_sick).astype(bool)

    # dead or inactive
    mask_die_inac = (~dataset['person_state_code'].eq('A')) & (dataset['duration (days)'] <= pred_up_to_days)

    if include_deceased_as_zeros:
        # Keep ALL patients; those whose record ended without the disease are negatives (0)
        n_deaths_total = mask_die_inac.sum()
        n_deaths_with_disease = (mask_die_inac & mask_sick).sum()
        n_deaths_without_disease = (mask_die_inac & (~mask_sick)).sum()

        print(f'* Including ALL patients, including {n_deaths_total:,} deceased during period ({n_deaths_total/len(dataset)*100:.2f} %)')
        print(f'  - {n_deaths_with_disease:,} deceased WITH disease → counted as "1"')
        print(f'  - {n_deaths_without_disease:,} deceased WITHOUT disease → counted as "0"')
        print(f'\tNo patients excluded: {len(dataset):,} patients kept')

        filter_out = pd.Series(False, index=dataset.index)
    else:
        # Exclude patients whose record ended before they developed the disease
        filter_out = mask_die_inac & (~mask_sick)
        n = filter_out.sum()
        print(f'* Filtering out patients who died or became inactive before the disease of interest ({n:,} patients {n/len(dataset)*100:.2f} %)')
        print(f'\t{len(dataset):,} -> {len(dataset)-n:,} patients')

    # .copy(): the column assignment below must act on an independent frame, not on a
    # boolean-indexed slice (which triggers pandas' SettingWithCopyWarning).
    dataset = dataset.loc[~filter_out].copy()

    dataset['is.female'] = dataset.pop('gender_code').eq('F').astype(int)

    ## X, y
    y = mask_sick.loc[~filter_out].astype(int)

    predictors = ['is.female'] + PREDICTORS
    if include_charlson_bmi:
        predictors = ['avg. BMI', 'avg. CHARLSON'] + predictors

    X = dataset[predictors]  # TODO: add time to first_disease and person_id to analyse error ?

    n = y.sum()
    print(f"\n{len(X):,} patients among which {n:,} with the disease of interest ({n/len(X)*100:.2f} %)")

    if include_deceased_as_zeros and mask_die_inac.sum() > 0:
        n_deaths_without_disease = (mask_die_inac.loc[~filter_out] & (~y.astype(bool))).sum()
        print(f"Note: Includes {n_deaths_without_disease:,} patients who died without the disease (counted as negatives)")

    return X, y

def split(X, y, seed=SEED):
    """Random 75/25 train/test split that reports the positive rate of each part.

    Returns (X_train, X_test, y_train, y_test).
    """
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=seed)

    report = ["Datasets proportions:"]
    for part_name, labels in (("train", y_train), ("test", y_test)):
        total = len(labels)
        positives = labels.sum()
        report.append(
            f"    - {positives:,} diseases in the {part_name} set out of {total:,} patients "
            f"({positives/total*100:.2f} %)"
        )
    print("\n".join(report))

    return X_train, X_test, y_train, y_test

### model

In [None]:
def plot_roc(fpr, tpr, roc_auc, ax):
    """Draw a ROC curve plus the chance diagonal on `ax`.

    `roc_auc` is inserted verbatim in the legend (already formatted by the caller).
    """
    curve_kwargs = dict(color='navy', lw=2, label=f'AUC = {roc_auc}')
    ax.plot(fpr, tpr, 'o-', **curve_kwargs)
    ax.plot([0, 1], [0, 1], color='black', lw=1)

    ax.set_xlim(0, 1.01)
    ax.set_ylim(0, 1.05)
    for setter, text in ((ax.set_xlabel, 'False positive rate'), (ax.set_ylabel, 'Detection rate')):
        setter(text)
    ax.legend(loc="lower right")

def plot_density(preds, test, ax, label='Disease'):
    """Overlay filled KDEs of the predicted probability for cases (test == 1) and non-cases (test == 0)."""
    groups = (
        (1, label.capitalize()),
        (0, f'No {label.lower()}'),
    )
    for class_value, legend_text in groups:
        sns.kdeplot(preds[test == class_value], fill=True, ax=ax, label=legend_text)

    ax.set_xlabel('Predicted probability')
    ax.set_xlim(0, 1)
    ax.set_ylabel('Density')
    ax.legend()
    
def get_better_ax(ax, axis='y'):
    """Declutter `ax`: hide the top/right frame and ticks, add a dashed grid along `axis`."""
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(top=False, right=False)

    ax.grid(axis=axis, linestyle="--", alpha=0.5)

In [None]:
"""
Created on Tue Nov  6 10:06:52 2018

@author: yandexdataschool

Original Code found in:
https://github.com/yandexdataschool/roc_comparison

updated: Raul Sanchez-Vazquez
"""

import numpy as np
import scipy.stats
from scipy import stats

# AUC comparison adapted from
# https://github.com/Netflix/vmaf/
def compute_midrank(x):
    """Computes midranks.
    Args:
       x - a 1D numpy array (must not contain NaN)
    Returns:
       array of 1-based midranks: tied values all receive the mean of the ranks they span
    Raises:
       ValueError if x contains NaN. Since NaN != NaN, the tie scan below would never
       advance past a NaN and the loop would hang forever.
    """
    x = np.asarray(x)
    if x.dtype.kind in 'fc' and np.isnan(x).any():
        raise ValueError('compute_midrank: x contains NaN, midranks are undefined')
    J = np.argsort(x)
    Z = x[J]
    N = len(x)
    T = np.zeros(N, dtype=np.float64)
    i = 0
    while i < N:
        # extend j over the run of values tied with Z[i]
        j = i
        while j < N and Z[j] == Z[i]:
            j += 1
        # sorted positions i..j-1 share their mean (0-based) position
        T[i:j] = 0.5*(i + j - 1)
        i = j
    T2 = np.empty(N, dtype=np.float64)
    # Note(kazeevn) +1 is due to Python using 0-based indexing
    # instead of 1-based in the AUC formula in the paper
    T2[J] = T + 1
    return T2

def compute_midrank_weight(x, sample_weight):
    """Computes weighted midranks.
    Args:
       x - a 1D numpy array of scores (must not contain NaN)
       sample_weight - a 1D numpy array of per-example weights aligned with x
    Returns:
       array where every value receives the mean cumulative weight of its group of
       tied values (equals the 1-based midrank when all weights are 1)
    Raises:
       ValueError if x contains NaN (the tie scan would never advance past it and hang)
    """
    x = np.asarray(x)
    sample_weight = np.asarray(sample_weight)
    if x.dtype.kind in 'fc' and np.isnan(x).any():
        raise ValueError('compute_midrank_weight: x contains NaN, midranks are undefined')
    J = np.argsort(x)
    Z = x[J]
    cumulative_weight = np.cumsum(sample_weight[J])
    N = len(x)
    T = np.zeros(N, dtype=np.float64)
    i = 0
    while i < N:
        # extend j over the run of values tied with Z[i]
        j = i
        while j < N and Z[j] == Z[i]:
            j += 1
        T[i:j] = cumulative_weight[i:j].mean()
        i = j
    T2 = np.empty(N, dtype=np.float64)
    T2[J] = T
    return T2


def fastDeLong(predictions_sorted_transposed, label_1_count, sample_weight):
    """Dispatch to the weighted or the unweighted fast DeLong implementation.

    Returns (aucs, delong_covariance) from whichever implementation is used.
    """
    if sample_weight is not None:
        return fastDeLong_weights(predictions_sorted_transposed, label_1_count, sample_weight)
    return fastDeLong_no_weights(predictions_sorted_transposed, label_1_count)


def fastDeLong_weights(predictions_sorted_transposed, label_1_count, sample_weight):
    """
    The fast version of DeLong's method for computing the covariance of
    unadjusted AUC, with per-example weights.
    Args:
       predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples]
          sorted such as the examples with label "1" are first
       label_1_count: number of positive examples m (the first m columns)
       sample_weight: 1D array of weights, in the same column order as
          predictions_sorted_transposed
    Returns:
       (AUC values, DeLong covariance): one AUC per classifier (row); the covariance
       is n_classifiers x n_classifiers (np.cov returns a 0-d array for one classifier)
    Reference:
     @article{sun2014fast,
       title={Fast Implementation of DeLong's Algorithm for
              Comparing the Areas Under Correlated Receiver Operating Characteristic Curves},
       author={Xu Sun and Weichao Xu},
       journal={IEEE Signal Processing Letters},
       volume={21},
       number={11},
       pages={1389--1393},
       year={2014},
       publisher={IEEE}
     }
    """
    # Short variables are named as they are in the paper
    m = label_1_count
    n = predictions_sorted_transposed.shape[1] - m
    positive_examples = predictions_sorted_transposed[:, :m]
    negative_examples = predictions_sorted_transposed[:, m:]
    k = predictions_sorted_transposed.shape[0]

    # Weighted midranks within positives (tx), within negatives (ty) and over all examples (tz)
    tx = np.empty([k, m], dtype=np.float64)
    ty = np.empty([k, n], dtype=np.float64)
    tz = np.empty([k, m + n], dtype=np.float64)
    for r in range(k):
        tx[r, :] = compute_midrank_weight(positive_examples[r, :], sample_weight[:m])
        ty[r, :] = compute_midrank_weight(negative_examples[r, :], sample_weight[m:])
        tz[r, :] = compute_midrank_weight(predictions_sorted_transposed[r, :], sample_weight)
    total_positive_weights = sample_weight[:m].sum()
    total_negative_weights = sample_weight[m:].sum()
    # outer product: weight of every (positive, negative) pair
    pair_weights = np.dot(sample_weight[:m, np.newaxis], sample_weight[np.newaxis, m:])
    total_pair_weights = pair_weights.sum()
    # (tz - tx) for a positive ~ weight of negatives ranked below it; weight by the positive's own weight
    aucs = (sample_weight[:m]*(tz[:, :m] - tx)).sum(axis=1) / total_pair_weights
    v01 = (tz[:, :m] - tx[:, :]) / total_negative_weights
    v10 = 1. - (tz[:, m:] - ty[:, :]) / total_positive_weights
    # NOTE(review): np.cov below is unweighted even though sample weights are given;
    # kept as in the upstream yandexdataschool implementation.
    sx = np.cov(v01)
    sy = np.cov(v10)
    delongcov = sx / m + sy / n
    return aucs, delongcov

def fastDeLong_no_weights(predictions_sorted_transposed, label_1_count):
    """
    The fast version of DeLong's method for computing the covariance of
    unadjusted AUC.
    Args:
       predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples]
          sorted such as the examples with label "1" are first
       label_1_count: number of positive examples m (the first m columns)
    Returns:
       (AUC values, DeLong covariance): one AUC per classifier (row); the covariance
       is n_classifiers x n_classifiers (np.cov returns a 0-d array for one classifier)
    Reference:
     @article{sun2014fast,
       title={Fast Implementation of DeLong's Algorithm for
              Comparing the Areas Under Correlated Receiver Operating
              Characteristic Curves},
       author={Xu Sun and Weichao Xu},
       journal={IEEE Signal Processing Letters},
       volume={21},
       number={11},
       pages={1389--1393},
       year={2014},
       publisher={IEEE}
     }
    """
    # Short variables are named as they are in the paper
    m = label_1_count
    n = predictions_sorted_transposed.shape[1] - m
    positive_examples = predictions_sorted_transposed[:, :m]
    negative_examples = predictions_sorted_transposed[:, m:]
    k = predictions_sorted_transposed.shape[0]

    # Midranks within positives (tx), within negatives (ty) and over all examples (tz)
    tx = np.empty([k, m], dtype=np.float64)
    ty = np.empty([k, n], dtype=np.float64)
    tz = np.empty([k, m + n], dtype=np.float64)
    for r in range(k):
        tx[r, :] = compute_midrank(positive_examples[r, :])
        ty[r, :] = compute_midrank(negative_examples[r, :])
        tz[r, :] = compute_midrank(predictions_sorted_transposed[r, :])
    # Mann-Whitney form: AUC = (sum of positive ranks - m(m+1)/2) / (m*n)
    aucs = tz[:, :m].sum(axis=1) / m / n - float(m + 1.0) / 2.0 / n
    # per positive: share of negatives scored below it (ties count 1/2)
    v01 = (tz[:, :m] - tx[:, :]) / n
    # per negative: share of positives scored above it (ties count 1/2)
    v10 = 1.0 - (tz[:, m:] - ty[:, :]) / m
    sx = np.cov(v01)
    sy = np.cov(v10)
    delongcov = sx / m + sy / n
    return aucs, delongcov


def calc_pvalue(aucs, sigma):
    """Two-sided DeLong test of AUC equality, returned as log10(p-value).

    Args:
       aucs: 1D array holding the two AUCs to compare
       sigma: 2x2 DeLong covariance matrix of those AUCs
    Returns:
       log10(pvalue), as a (1, 1) array
    """
    contrast = np.array([[1, -1]])
    difference_variance = contrast @ sigma @ contrast.T
    z_score = np.abs(np.diff(aucs)) / np.sqrt(difference_variance)
    # work in log space so very small p-values do not underflow to 0
    log10_one_sided = scipy.stats.norm.logsf(z_score, loc=0, scale=1) / np.log(10)
    return np.log10(2) + log10_one_sided

def compute_ground_truth_statistics(ground_truth, sample_weight):
    """Order examples positives-first and count the positives.

    Returns (order, label_1_count, ordered_sample_weight); the weights are None when
    `sample_weight` is None, otherwise re-ordered with `order`.
    """
    assert np.array_equal(np.unique(ground_truth), [0, 1])
    positives_first = (-ground_truth).argsort()
    n_positive = int(ground_truth.sum())
    weights_in_order = None if sample_weight is None else sample_weight[positives_first]
    return positives_first, n_positive, weights_in_order

def delong_roc_variance(ground_truth, predictions, sample_weight=None):
    """
    Computes ROC AUC variance for a single set of predictions
    Args:
       ground_truth: array-like of 0 and 1 (numpy array or pandas Series)
       predictions: array-like of floats of the probability of being class 1
       sample_weight: optional array-like of per-example weights
    Returns:
       (auc, delong_variance) for this single classifier
    """
    # Coerce to numpy: `predictions[np.newaxis, order]` below fails on a pandas Series.
    ground_truth = np.asarray(ground_truth)
    predictions = np.asarray(predictions)
    if sample_weight is not None:
        sample_weight = np.asarray(sample_weight)

    order, label_1_count, ordered_sample_weight = compute_ground_truth_statistics(
        ground_truth, sample_weight)
    # one-row matrix (a single classifier), columns ordered positives-first
    predictions_sorted_transposed = predictions[np.newaxis, order]
    aucs, delongcov = fastDeLong(predictions_sorted_transposed, label_1_count, ordered_sample_weight)
    assert len(aucs) == 1, "There is a bug in the code, please forward this to the developers"
    return aucs[0], delongcov

In [None]:
def get_auc_with_ci(preds, test, ci=0.95):
    """
    DeLong AUC with a normal-approximation confidence interval.

    Parameters
    ----------
    preds : array-like of float
        Predicted probability of the positive class.
    test : array-like of {0, 1}
        Ground-truth labels.
    ci : float
        Confidence level (0.95 -> 2.5 % / 97.5 % quantiles).

    Returns
    -------
    (auc, lower, upper)
        Bounds are clipped to the valid AUC range [0, 1].
    """
    auc, auc_var = delong_roc_variance(test, preds)
    auc_std = float(np.sqrt(auc_var))
    if auc_std == 0:
        # scipy's normal ppf returns NaN for scale=0 (e.g. perfectly separated classes)
        return auc, auc, auc

    alpha = 1 - ci
    lower, upper = stats.norm.ppf([alpha / 2, 1 - alpha / 2], loc=auc, scale=auc_std)
    return auc, max(lower, 0.0), min(upper, 1.0)

In [None]:
def _scan_threshold(preds, test, track, target):
    """Lower the threshold from 1 in 0.001 steps until the tracked rate reaches `target`.

    track='fpr' follows FP / (FP + TN) (share of negatives flagged);
    track='tpr' follows TP / (TP + FN) (share of positives flagged, the detection rate).
    Raises ValueError if `test` has no example of the tracked class.
    """
    scores = np.asarray(preds)
    labels = np.asarray(test)
    # only the tracked class enters the rate: negatives for FPR, positives for TPR
    reference = labels == (0 if track == 'fpr' else 1)
    n_reference = np.count_nonzero(reference)
    if n_reference == 0:
        missing = 'negative' if track == 'fpr' else 'positive'
        raise ValueError(f'Cannot compute {track}: `test` contains no {missing} labels')

    best_thresh = 1
    rate = 0
    # `best_thresh > 0` guarantees termination even if some scores are NaN
    while rate < target and best_thresh > 0:
        best_thresh -= .001
        rate = np.count_nonzero((scores >= best_thresh) & reference) / n_reference
    return best_thresh


def find_best_threshold(preds, test, strategy='default'):
    """
    Choose a decision threshold on predicted probabilities.

    Parameters
    ----------
    preds : array-like of float
        Predicted probability of the positive class.
    test : array-like of {0, 1}
        Ground-truth labels.
    strategy : str
        'default'     -> 0.5
        'Youden'      -> threshold maximising TPR - FPR along the ROC curve
        'F1-score'    -> threshold maximising F1 along the precision-recall curve
        '5%-fpr'      -> first threshold, scanning down from 1 in 0.001 steps,
                         whose false positive rate reaches 5 %
        'detect half' -> first threshold (same scan) whose detection rate reaches 50 %

    Returns
    -------
    float

    Raises
    ------
    ValueError
        Unknown strategy, or a scan strategy with no example of the tracked class.
    """
    if strategy == 'default':
        return .5

    if strategy == 'Youden':
        fpr, tpr, thresholds = roc_curve(test, preds)
        return thresholds[np.argmax(tpr - fpr)]

    if strategy == 'F1-score':
        precision, recall, thresholds = precision_recall_curve(test, preds)
        # The final (precision=1, recall=0) point has no threshold: drop it so the
        # arrays stay aligned with `thresholds`.
        precision, recall = precision[:-1], recall[:-1]
        denom = precision + recall
        # F1 is 0 (not NaN) where precision = recall = 0; np.argmax would otherwise
        # return the index of the first NaN.
        fscore = np.divide(2 * precision * recall, denom,
                           out=np.zeros_like(denom, dtype=float), where=denom > 0)
        return thresholds[np.argmax(fscore)]

    if strategy == '5%-fpr':
        return _scan_threshold(preds, test, track='fpr', target=5/100)

    if strategy == 'detect half':
        return _scan_threshold(preds, test, track='tpr', target=.5)

    raise ValueError(f'The strategy {strategy} is not implemented')

In [None]:
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss
import matplotlib.pyplot as plt

def _expected_calibration_error(y_true, y_prob, n_bins=10):
    """Share-weighted mean |mean predicted probability - observed positive rate| over equal-width bins."""
    edges = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for k, (bin_lower, bin_upper) in enumerate(zip(edges[:-1], edges[1:])):
        # Bins are (lower, upper]; the first bin also includes 0 so that predictions
        # of exactly 0 are not silently left out of every bin.
        above_lower = (y_prob >= bin_lower) if k == 0 else (y_prob > bin_lower)
        in_bin = above_lower & (y_prob <= bin_upper)
        prop_in_bin = in_bin.mean()
        if prop_in_bin > 0:
            accuracy_in_bin = y_true[in_bin].mean()
            avg_confidence_in_bin = y_prob[in_bin].mean()
            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
    return ece


def plot_calibration_curve(y_true, y_prob, title="Calibration Plot", ax=None, n_bins=10):
    """
    Plot a reliability diagram and compute calibration metrics.

    Parameters
    ----------
    y_true : array-like of {0, 1}
        Ground-truth labels.
    y_prob : array-like of float
        Predicted probability of the positive class.
    title : str
        Axes title.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 8x6 figure is created when omitted.
    n_bins : int, default=10
        Number of equal-width probability bins for both the curve and the ECE.

    Returns
    -------
    (brier, ece) : tuple of float
        Brier score and Expected Calibration Error.
    """
    # positional numpy arrays: avoids any pandas index alignment in the boolean masks
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    fraction_of_positives, mean_predicted_value = calibration_curve(
        y_true, y_prob, n_bins=n_bins, strategy='uniform'
    )

    ax.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
    ax.plot(mean_predicted_value, fraction_of_positives, 's-', label='Model calibration')

    ax.set_xlabel('Mean Predicted Probability')
    ax.set_ylabel('Fraction of Positives')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

    brier = brier_score_loss(y_true, y_prob)
    ece = _expected_calibration_error(y_true, y_prob, n_bins)

    ax.text(0.05, 0.95, f'Brier Score: {brier:.3f}\nECE: {ece:.3f}',
            transform=ax.transAxes, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    return brier, ece

def _threshold_metrics(preds, test, threshold):
    """Confusion counts and derived rates when preds >= threshold is called positive."""
    predicted = np.where(preds >= threshold, 1, 0)
    tn, fp, fn, tp = confusion_matrix(test, predicted).ravel()

    if fp > 0:
        ratio = Fraction(f'{tp}/{fp}')
        true_to_false_ratio = f'{ratio.numerator} to {ratio.denominator}'
    else:
        true_to_false_ratio = 'nan'

    return {
        'threshold': threshold,
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'tp': tp,
        'detection_rate': tp / (tp + fn),
        'missed_per_10_cases': fn / (tp + fn) * 10,
        'fpr': fp / (fp + tn),
        'true_to_false_ratio': true_to_false_ratio,
    }


def evaluate(preds, test, title, disease_label):
    """
    Plot ROC/density and calibration figures and tabulate threshold-based metrics.

    Parameters
    ----------
    preds : np.ndarray
        Positive-class probabilities, either 1-D or the 2-D `predict_proba` output
        (column 1 is used).
    test : array-like of {0, 1}
        Ground-truth labels.
    title : str
        Figure title.
    disease_label : str
        Name shown in the density-plot legend.

    Returns
    -------
    dict
        One entry per threshold strategy ('5%-fpr', 'detect half', 'Youden',
        'F1-score') plus 'auc' -> (auc, lower, upper), 'brier_score' and 'ece'.
    """
    if preds.ndim == 2:
        preds = preds[:, 1].reshape(test.shape)

    # ROC curve with DeLong CI next to the class-conditional score densities
    fpr, tpr, _ = roc_curve(test, preds)
    auc, lower_ci, upper_ci = get_auc_with_ci(preds, test, ci=0.95)
    auc_text = f'{auc:.2f} ({lower_ci:.2f}-{upper_ci:.2f})'

    fig, (ax_roc, ax_density) = plt.subplots(1, 2, figsize=(12, 4))
    plt.suptitle(title)
    plot_roc(fpr, tpr, auc_text, ax_roc)
    plot_density(preds, test, ax_density, disease_label)
    get_better_ax(ax_roc, 'both')
    get_better_ax(ax_density, 'y')
    plt.tight_layout()
    plt.show()

    # Threshold-dependent metrics, one row per strategy
    res = {
        strategy: _threshold_metrics(preds, test, find_best_threshold(preds, test, strategy))
        for strategy in ('5%-fpr', 'detect half', 'Youden', 'F1-score')
    }
    display(pd.DataFrame(res).T)

    # Calibration metrics
    print("\n--- Calibration Metrics ---")
    brier, ece = plot_calibration_curve(test, preds, title=f"Calibration - {title}")
    plt.show()
    print(f"Brier Score: {brier:.3f}")
    print(f"Expected Calibration Error: {ece:.3f}")

    res.update({'auc': (auc, lower_ci, upper_ci), 'brier_score': brier, 'ece': ece})
    return res

In [None]:
def save_model(model, performances:Dict[str, Dict[str, float]], 
               disease:str, up_to_year:int, model_name:str, base_dir:str=OUTPUT_DIR):
    """
    Save model and performances in {base_dir}/{disease}_{up_to_year}/{model_name}
    respectively as model.pkl and performances.pkl.

    `model_name` gets its own sub-directory, so different models trained for the
    same disease and horizon no longer overwrite each other.
    """
    model_dir = os.path.join(base_dir, f"{disease}_{up_to_year}", model_name)
    os.makedirs(model_dir, exist_ok=True)

    for file_name, obj in (("model.pkl", model), ("performances.pkl", performances)):
        with open(os.path.join(model_dir, file_name), "wb") as file:
            pickle.dump(obj, file)

def load_model(disease:str, up_to_year:int, model_name:str, base_dir:str=OUTPUT_DIR):
    """
    Load model and performances saved in {base_dir}/{disease}_{up_to_year}/{model_name}
    respectively as model.pkl and performances.pkl.

    If that model-specific directory does not exist, falls back to
    {base_dir}/{disease}_{up_to_year}/ (the layout written when the model name is not
    part of the path).

    Returns
    -------
    (model, performances)

    Raises
    ------
    FileNotFoundError
        If the pickles are missing from the resolved directory.

    NOTE: pickle.load can execute arbitrary code; only load files this project wrote.
    """
    task_dir = os.path.join(base_dir, f"{disease}_{up_to_year}")
    model_dir = os.path.join(task_dir, model_name)
    if not os.path.isdir(model_dir):
        model_dir = task_dir

    # pickle.load needs an open file object, not a path string
    with open(os.path.join(model_dir, "model.pkl"), "rb") as file:
        model = pickle.load(file)

    with open(os.path.join(model_dir, "performances.pkl"), "rb") as file:
        performances = pickle.load(file)

    return model, performances

# Training

In [None]:
age = 65
disease = 'all_dementias'
include_charlson_bmi = True
pred_up_to_year = 2
exclude_mci_baseline = False
include_deceased_as_zeros = True


def _drop_mci_at_baseline(X, y, where=''):
    """Remove patients flagged with MCI at baseline, and the `mci_at_baseline` column itself."""
    keep = X['mci_at_baseline'] == 0
    print(f'Removed {(~keep).sum()} MCI patients at baseline{where}')
    return X.loc[keep].drop('mci_at_baseline', axis=1), y.loc[keep]


print(f'Task: detect patients that will developp {disease} in the following {pred_up_to_year} years')
if exclude_mci_baseline:
    print(f'Excluding patients with MCI at baseline')

print(f'\n\nUK {age}:')
X, y = get_data(country='UK', age=age, disease=disease, include_charlson_bmi=include_charlson_bmi,
                pred_up_to_year=pred_up_to_year, include_deceased_as_zeros=include_deceased_as_zeros)

# Exclude patients with MCI at baseline (and the feature itself)
if exclude_mci_baseline and 'mci_at_baseline' in X.columns:
    X, y = _drop_mci_at_baseline(X, y)

X_train, X_val, y_train, y_val = split(X, y)

print(f'\n\nFR {age}:')
# The external FR test set must be labelled with the same death/inactivity rule as the
# UK training data, so include_deceased_as_zeros is forwarded here as well (it used to
# fall back to get_data's default of False).
X_test, y_test = get_data(country='FR', age=age, disease=disease, include_charlson_bmi=include_charlson_bmi,
                          pred_up_to_year=pred_up_to_year, include_deceased_as_zeros=include_deceased_as_zeros)

# Same MCI exclusion for the test set
if exclude_mci_baseline and 'mci_at_baseline' in X_test.columns:
    X_test, y_test = _drop_mci_at_baseline(X_test, y_test, where=' from test set')

In [None]:
from sklearn.calibration import CalibratedClassifierCV


res = {}

X_train = X_train.fillna(0)
X_val = X_val.fillna(0)

# Hold out 20 % of the training set to fit the calibration map; the base model never sees it.
X_train_sub, X_cal, y_train_sub, y_cal = train_test_split(
    X_train, y_train, test_size=0.2, random_state=SEED, stratify=y_train
)

# A single fit on the sub-training split (an earlier fit on the full X_train was
# immediately overwritten and only cost time).
model = LogisticRegression(class_weight='balanced', random_state=SEED, max_iter=1_000)
model.fit(X_train_sub, y_train_sub)

# Post-hoc isotonic calibration of the already-fitted model.
# NOTE(review): cv='prefit' is deprecated in recent scikit-learn releases in favour of
# wrapping the model in FrozenEstimator.
calibrated_clf = CalibratedClassifierCV(model, method='isotonic', cv='prefit')
calibrated_clf.fit(X_cal, y_cal)
# Score the *calibrated* classifier (the uncalibrated `model` was evaluated before).
y_prob_calibrated = calibrated_clf.predict_proba(X_val)

res['class_weight'] = evaluate(y_prob_calibrated, y_val, f'Results of the baseline on val set using balanced class_weight', 'dementias')['auc']

In [None]:
from sklearn.dummy import DummyClassifier
# Reference Brier score: a classifier that always predicts the training class prior.
dummy = DummyClassifier(strategy='prior').fit(X_train_sub, y_train_sub)
y_prob_dummy = dummy.predict_proba(X_val)
brier_dummy = brier_score_loss(y_val, y_prob_dummy[:, 1])
print(f"Brier score dummy: {brier_dummy}")

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

# NOTE(review): this rebinds the notebook-wide SEED for every cell run afterwards (e.g. the
# LogisticRegression random_state below), while split() keeps the default seed it
# captured when it was defined (2023).
SEED = 1


def _fit_balanced_lr_roc(X, y):
    """Split, zero-fill NaNs, fit a class-balanced logistic regression and score it.

    Returns (fpr, tpr, (auc, lower_ci, upper_ci)) computed on the validation split.
    """
    X_train, X_val, y_train, y_val = split(X, y)
    X_train = X_train.fillna(0)
    X_val = X_val.fillna(0)

    model = LogisticRegression(class_weight='balanced', random_state=SEED, max_iter=1_000)
    model.fit(X_train, y_train)
    y_prob = model.predict_proba(X_val)[:, 1]

    fpr, tpr, _ = roc_curve(y_val, y_prob)
    return fpr, tpr, get_auc_with_ci(y_prob, y_val, ci=0.95)


def compare_roc_curves_mci(age=65, disease='all_dementias', include_charlson_bmi=True, 
                            pred_up_to_year=5, include_deceased_as_zeros=True, save_path=None):
    """
    Compare UK validation ROC curves of models trained with and without the patients
    who have MCI at baseline, plot both on one figure and save it.

    Parameters
    ----------
    age, disease, include_charlson_bmi, pred_up_to_year, include_deceased_as_zeros
        Forwarded to get_data() for the UK cohort.
    save_path : str, optional
        Output file of the figure; defaults to OUTPUT_DIR/figure4.pdf.
        NOTE(review): plot_roc_curves_multitime also saves to figure4.pdf by default.

    Returns
    -------
    dict
        AUCs, 95 % DeLong CIs and cohort sizes for both settings.
    """
    # The title follows `disease` (it was hard-coded to all-cause dementia).
    title_names = {'all_dementias': 'All-cause dementia disease', 'alzheimer': "Alzheimer's disease"}
    title_name = title_names.get(disease, disease)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Model WITH MCI patients
    print("=== AVEC patients MCI ===")
    X_with, y_with = get_data(country='UK', age=age, disease=disease,
                              include_charlson_bmi=include_charlson_bmi,
                              pred_up_to_year=pred_up_to_year,
                              include_deceased_as_zeros=include_deceased_as_zeros)
    fpr_with, tpr_with, (auc_with, lower_ci_with, upper_ci_with) = _fit_balanced_lr_roc(X_with, y_with)

    # Model WITHOUT MCI patients: get_data returns the same cohort for identical
    # arguments, so reuse it instead of loading it a second time.
    print("=== SANS patients MCI ===")
    X_without, y_without = X_with, y_with
    if 'mci_at_baseline' in X_without.columns:
        mask = X_without['mci_at_baseline'] == 0
        X_without = X_without[mask].drop('mci_at_baseline', axis=1)
        y_without = y_without[mask]
        print(f'Removed {(~mask).sum()} MCI patients at baseline')
    fpr_without, tpr_without, (auc_without, lower_ci_without, upper_ci_without) = _fit_balanced_lr_roc(X_without, y_without)

    # Both curves, with their confidence intervals in the legend
    ax.plot(fpr_with, tpr_with, 'b-', linewidth=2.5,
            label=f'MCI patients included (AUC = {auc_with:.2f} [{lower_ci_with:.2f}-{upper_ci_with:.2f}])')
    ax.plot(fpr_without, tpr_without, 'r-', linewidth=2.5,
            label=f'MCI patients excluded (AUC = {auc_without:.2f} [{lower_ci_without:.2f}-{upper_ci_without:.2f}])')

    # Reference line
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.6, linewidth=1.5, label='Random classifier')

    ax.set_xlabel('False Positive Rate', fontsize=16)
    ax.set_ylabel('True Positive Rate', fontsize=16)
    ax.set_title(f"{title_name} risk prediction ROC curve", fontsize=20, pad=20)
    ax.legend(loc='lower right', fontsize=16)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # Remove top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.tight_layout()
    fig.savefig(save_path if save_path is not None else os.path.join(OUTPUT_DIR, 'figure4.pdf'))
    plt.show()

    # Comparative statistics
    print(f"\n=== RÉSULTATS COMPARATIFS ===")
    print(f"Avec MCI - Taille dataset: {len(X_with)}, AUC: {auc_with:.2f} [{lower_ci_with:.2f}-{upper_ci_with:.2f}]")
    print(f"Sans MCI - Taille dataset: {len(X_without)}, AUC: {auc_without:.2f} [{lower_ci_without:.2f}-{upper_ci_without:.2f}]")
    print(f"Différence AUC: {auc_with - auc_without:.3f}")

    return {
        'auc_with_mci': auc_with,
        'ci_with_mci': (lower_ci_with, upper_ci_with),
        'auc_without_mci': auc_without,
        'ci_without_mci': (lower_ci_without, upper_ci_without),
        'n_with_mci': len(X_with),
        'n_without_mci': len(X_without)
    }

# Usage
results = compare_roc_curves_mci(disease='alzheimer', include_deceased_as_zeros=True)

In [None]:
def plot_roc_curves_multitime(country='UK', age=65, disease='all_dementias', 
                              include_charlson_bmi=True, exclude_mci_baseline=False,
                              include_deceased_as_zeros=True, prediction_years=(2, 5, 10),
                              save_path=None):
    """
    Plot ROC curves for several prediction horizons (default 2, 5 and 10 years) on one figure.

    For every horizon a class-balanced logistic regression is trained on a split of the
    cohort and scored on the held-out part.

    Parameters:
    -----------
    country : str
        'UK' or 'FR'
    age : int
        65 or 70
    disease : str
        Disease of interest (also used in the figure title)
    include_charlson_bmi : bool
        Whether to include BMI and CHARLSON features
    exclude_mci_baseline : bool
        Whether to exclude patients with MCI at baseline
    include_deceased_as_zeros : bool
        Whether to include deceased patients as negatives
    prediction_years : iterable of int
        Prediction horizons in years (e.g. (2, 5, 10))
    save_path : str, optional
        Output file of the figure; defaults to OUTPUT_DIR/figure4.pdf.
        NOTE(review): compare_roc_curves_mci saves to the same default file, so whichever
        runs last overwrites the other unless save_path is given.

    Returns:
    --------
    dict
        '{year}_years' -> auc, ci_lower, ci_upper, n_patients, n_cases, prevalence.
    """
    # The title follows `disease` (it was hard-coded to Alzheimer's disease).
    title_names = {'all_dementias': 'All-cause dementia disease', 'alzheimer': "Alzheimer's disease"}
    title_name = title_names.get(disease, disease)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Colors for the different time horizons
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

    results_summary = {}

    for i, pred_year in enumerate(prediction_years):
        print(f"\n=== PREDICTION À {pred_year} ANS ===")

        # Load data for this prediction horizon
        X, y = get_data(country=country, age=age, disease=disease,
                        include_charlson_bmi=include_charlson_bmi,
                        pred_up_to_year=pred_year,
                        include_deceased_as_zeros=include_deceased_as_zeros)

        # Exclude MCI patients if requested
        if exclude_mci_baseline and 'mci_at_baseline' in X.columns:
            mask = X['mci_at_baseline'] == 0
            X = X[mask].drop('mci_at_baseline', axis=1)
            y = y[mask]
            print(f'Removed {(~mask).sum()} MCI patients at baseline')

        X_train, X_val, y_train, y_val = split(X, y)
        X_train = X_train.fillna(0)
        X_val = X_val.fillna(0)

        model = LogisticRegression(class_weight='balanced', random_state=SEED, max_iter=1_000)
        model.fit(X_train, y_train)
        y_prob = model.predict_proba(X_val)[:, 1]

        fpr, tpr, _ = roc_curve(y_val, y_prob)
        auc_score, lower_ci, upper_ci = get_auc_with_ci(y_prob, y_val, ci=0.95)

        ax.plot(fpr, tpr, color=colors[i % len(colors)], linewidth=2.5,
                label=f'{pred_year} years (AUC = {auc_score:.2f} [{lower_ci:.2f}-{upper_ci:.2f}])')

        results_summary[f'{pred_year}_years'] = {
            'auc': auc_score,
            'ci_lower': lower_ci,
            'ci_upper': upper_ci,
            'n_patients': len(X),
            'n_cases': y.sum(),
            'prevalence': y.mean()
        }

    # Reference line (random classifier)
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.6, linewidth=1.5, label='Random classifier')

    ax.set_xlabel('False Positive Rate', fontsize=16)
    ax.set_ylabel('True Positive Rate', fontsize=16)
    ax.set_title(f"{title_name} risk prediction ROC curve", fontsize=20, pad=20)
    ax.legend(loc='lower right', fontsize=16)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # Remove top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.tight_layout()
    fig.savefig(save_path if save_path is not None else os.path.join(OUTPUT_DIR, 'figure4.pdf'))
    plt.show()

    # Summary table
    print(f"\n=== RÉSUMÉ COMPARATIF ===")
    print(f"{'Time':<8} {'AUC':<15} {'95% CI':<20} {'N patients':<12} {'N cases':<10} {'Prevalence':<12}")
    print("-" * 80)

    for pred_year in prediction_years:
        summary = results_summary[f'{pred_year}_years']
        print(f"{pred_year} ans{'':<4} "
              f"{summary['auc']:.3f}{'':<11} "
              f"[{summary['ci_lower']:.3f}-{summary['ci_upper']:.3f}]{'':<8} "
              f"{summary['n_patients']:<12,} "
              f"{summary['n_cases']:<10,} "
              f"{summary['prevalence']*100:.2f}%")

    return results_summary

# Example usage
print("Comparaison des courbes ROC pour différents horizons de prédiction")
results = plot_roc_curves_multitime(
    country='UK',
    age=65,
    disease='alzheimer',
    include_charlson_bmi=True,
    exclude_mci_baseline=False,
    include_deceased_as_zeros=True,
    prediction_years=[2, 5, 10]
)

In [None]:
def generate_performance_table(diseases=('all_dementias', 'alzheimer'),
                               prediction_years=(2, 5, 10),
                               age=65,
                               include_charlson_bmi=True,
                               exclude_mci_baseline=False,
                               include_deceased_as_zeros=True,
                               n_bootstrap=1000):
    """
    Generate a performance table for different diseases and prediction horizons.

    For each (disease, horizon) pair a class-balanced logistic regression is
    evaluated in two settings:

    * UK: trained and tested on a train/test split (``split``) of the UK cohort;
    * FR: trained on the whole UK cohort and tested on the FR cohort.

    Parameters
    ----------
    diseases : sequence of str
        Diseases to evaluate (forwarded to ``get_data``).
    prediction_years : sequence of int
        Prediction horizons in years.
    age : int
        Age threshold (65 or 70).
    include_charlson_bmi : bool
        Whether to include BMI and CHARLSON features.
    exclude_mci_baseline : bool
        Whether to exclude patients with MCI at baseline; the
        ``mci_at_baseline`` column is dropped as well.
    include_deceased_as_zeros : bool
        Whether to include deceased patients as negatives.
    n_bootstrap : int
        Number of bootstrap resamples for the detection-rate / FPR intervals.

    Returns
    -------
    pd.DataFrame
        One row per (disease, horizon) with formatted metric strings. When the
        evaluation for a country raises, that country's cells are 'N/A'. The
        columns are always in the documented order, even when no row is built.

    Notes
    -----
    Prints progress and per-country metrics, then, after the table is built,
    the operating point needed for an 80% detection rate on each UK test set.
    """
    target_fpr = 0.05      # operating point of 'Detection rate for 5% FPR'
    target_tpr = 0.50      # operating point of 'FPR for 50% detection rate'
    screening_tpr = 0.80   # sensitivity targeted by the screening summary

    def display_name(disease):
        """Human-readable disease label used in the table and screening output."""
        return 'Dementia' if disease == 'all_dementias' else disease.capitalize()

    def load_cohort(country, disease, pred_year):
        """Load (X, y) for one cohort and apply the optional MCI exclusion.

        Returns (X, y, n_removed); n_removed is None when no exclusion was applied.
        """
        X, y = get_data(country=country, age=age, disease=disease,
                        include_charlson_bmi=include_charlson_bmi,
                        pred_up_to_year=pred_year,
                        include_deceased_as_zeros=include_deceased_as_zeros)
        if not (exclude_mci_baseline and 'mci_at_baseline' in X.columns):
            return X, y, None
        keep = X['mci_at_baseline'] == 0
        return X[keep].drop('mci_at_baseline', axis=1), y[keep], (~keep).sum()

    def has_both_classes(y_true):
        """True when y_true holds at least two labels, i.e. a ROC curve is defined."""
        return np.unique(y_true).size >= 2

    def detection_rate_at_fpr(y_true, y_scores):
        """TPR (%) of the last ROC operating point whose FPR does not exceed target_fpr.

        Picking the first point with FPR >= target instead can overshoot the FPR
        budget by a wide margin when many scores are tied (few ROC points).
        Returns NaN for single-class samples.
        """
        if not has_both_classes(y_true):
            return np.nan
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        # fpr is non-decreasing and roc_curve starts at (0, 0), so idx >= 0.
        idx = np.searchsorted(fpr, target_fpr, side='right') - 1
        return tpr[idx] * 100

    def fpr_at_detection_rate(y_true, y_scores):
        """FPR (%) of the first ROC operating point whose TPR reaches target_tpr.

        Returns NaN for single-class samples.
        """
        if not has_both_classes(y_true):
            return np.nan
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        # tpr ends at 1.0, so a qualifying operating point always exists.
        return fpr[np.argmax(tpr >= target_tpr)] * 100

    def bootstrap_metric(y_true, y_scores, metric_func):
        """Bootstrap mean and 95% percentile interval of ``metric_func``.

        Resamples whose metric is NaN (e.g. single-class) are skipped; returns
        (nan, nan, nan) when no resample yields a value.
        """
        # A private, seeded RandomState keeps results reproducible without
        # reseeding NumPy's global RNG as a side effect.
        rng = np.random.RandomState(SEED)
        y_true = np.asarray(y_true)
        y_scores = np.asarray(y_scores)
        n_samples = len(y_true)
        values = []
        for _ in range(n_bootstrap):
            indices = rng.choice(n_samples, n_samples, replace=True)
            try:
                value = metric_func(y_true[indices], y_scores[indices])
            except ValueError:
                continue
            if not np.isnan(value):
                values.append(value)
        if not values:
            return (np.nan, np.nan, np.nan)
        return (np.mean(values),
                np.percentile(values, 2.5),
                np.percentile(values, 97.5))

    def top_fraction_precision(y_true, y_scores, fraction=0.01):
        """Share of cases among the highest-scored `fraction` of patients (at least one).

        Positional indexing avoids pandas label/position ambiguity; tied scores
        keep their input order (stable sort).
        """
        y_arr = np.asarray(y_true, dtype=float)
        k = max(1, int(fraction * len(y_arr)))
        top = np.argsort(-np.asarray(y_scores), kind='stable')[:k]
        return y_arr[top].sum() / k

    def print_screening_requirements(screen_data):
        """Print the first operating point reaching `screening_tpr` on one UK test set."""
        n_patients = screen_data['n_patients']
        n_cases = screen_data['n_cases']
        fpr, tpr, thresholds = roc_curve(screen_data['y_test'], screen_data['y_prob'])
        idx = np.argmax(tpr >= screening_tpr)

        print(f"\n{display_name(screen_data['disease'])} - {screen_data['pred_year']}-year prediction:")
        if tpr[idx] >= screening_tpr:
            prevalence = n_cases / n_patients
            detected_per_patient = tpr[idx] * prevalence
            # Patients screened (whole population) per case detected at this sensitivity.
            patients_to_screen_per_case = 1 / detected_per_patient if detected_per_patient > 0 else float('inf')
            print(f"  Threshold for 80% detection: {thresholds[idx]:.3f}")
            print(f"  Sensitivity (TPR): {tpr[idx]*100:.1f}%")
            print(f"  FPR: {fpr[idx]*100:.1f}%")
            print(f"  Prevalence in test set: {prevalence*100:.2f}% ({n_cases}/{n_patients})")
            print(f"  Number needed to screen: {patients_to_screen_per_case:.0f} patients per case detected")
        else:
            print(f"  Cannot achieve 80% detection rate with available data")
            print(f"  Maximum achievable TPR: {max(tpr)*100:.1f}%")

    results = []
    screening_results = []  # UK test-set predictions, reused for the screening summary

    for disease in diseases:
        for pred_year in prediction_years:
            print(f"\n{'='*60}")
            print(f"Processing {disease} - {pred_year} years prediction")
            print(f"{'='*60}")

            row_data = {
                'Disease': display_name(disease),
                'Prediction up to year': f'{pred_year} years'
            }

            for country in ['UK', 'FR']:
                print(f"\n--- Evaluating on {country} ---")

                try:
                    if country == 'UK':
                        # UK: train and test on a split of the same cohort.
                        X, y, n_removed = load_cohort('UK', disease, pred_year)
                        if n_removed is not None:
                            print(f'Removed {n_removed} MCI patients at baseline')
                        X_train, X_test, y_train, y_test = split(X, y)
                        X_train = X_train.fillna(0)
                        X_test = X_test.fillna(0)
                    else:
                        # FR: external validation of a model trained on the full UK cohort.
                        print("Training on UK data...")
                        X_train, y_train, _ = load_cohort('UK', disease, pred_year)
                        X_train = X_train.fillna(0)

                        print("Testing on FR data...")
                        X_test, y_test, _ = load_cohort('FR', disease, pred_year)
                        X_test = X_test.fillna(0)

                    print(f"Train set: {len(X_train)} patients ({y_train.sum()} cases)")
                    print(f"Test set: {len(X_test)} patients ({y_test.sum()} cases)")

                    model = LogisticRegression(class_weight='balanced', random_state=SEED, max_iter=1_000)
                    model.fit(X_train, y_train)
                    y_prob = model.predict_proba(X_test)[:, 1]

                    if country == 'UK':
                        screening_results.append({
                            'disease': disease,
                            'pred_year': pred_year,
                            'y_test': y_test,
                            'y_prob': y_prob,
                            'n_patients': len(y_test),
                            'n_cases': y_test.sum()
                        })

                    auc_value, auc_lower, auc_upper = get_auc_with_ci(y_prob, y_test, ci=0.95)

                    from sklearn.metrics import brier_score_loss
                    brier_score = brier_score_loss(y_test, y_prob)

                    roc_auc_str = f"{auc_value:.2f} ({auc_lower:.2f}-{auc_upper:.2f})"
                    brier_str = f"{brier_score:.3f}"

                    if country == 'UK':
                        # Bootstrapped operating-point metrics are only reported for UK.
                        det_mean, det_lower, det_upper = bootstrap_metric(y_test, y_prob, detection_rate_at_fpr)
                        fpr_mean, fpr_lower, fpr_upper = bootstrap_metric(y_test, y_prob, fpr_at_detection_rate)
                        det_rate_str = f"{det_mean:.1f}% ({det_lower:.1f}%-{det_upper:.1f}%)"
                        fpr_str = f"{fpr_mean:.1f}% ({fpr_lower:.1f}%-{fpr_upper:.1f}%)"

                        precision_top_1_percent = top_fraction_precision(y_test, y_prob)
                        dataset_prevalence = y_test.sum() / len(y_test)
                        lift = precision_top_1_percent / dataset_prevalence if dataset_prevalence > 0 else 0

                        row_data['ROC AUC on UK'] = roc_auc_str
                        row_data['Brier Score on UK'] = brier_str
                        row_data['Detection rate for 5% FPR'] = det_rate_str
                        row_data['FPR for 50% detection rate'] = fpr_str
                        row_data['Precision top 1%'] = f"{precision_top_1_percent*100:.1f}%"
                        row_data['Lift (vs dataset prevalence)'] = f"{lift:.1f}x"
                    else:
                        row_data['ROC AUC on FR'] = roc_auc_str
                        row_data['Brier Score on FR'] = brier_str

                    print(f"{country} Results:")
                    print(f"  ROC AUC: {roc_auc_str}")
                    print(f"  Brier Score: {brier_str}")
                    if country == 'UK':
                        print(f"  Detection rate at 5% FPR: {det_rate_str}")
                        print(f"  FPR at 50% detection: {fpr_str}")

                except Exception as e:
                    # Best effort: one failing country must not abort the whole table.
                    print(f"Error processing {country}: {str(e)}")
                    if country == 'UK':
                        row_data['ROC AUC on UK'] = 'N/A'
                        row_data['Brier Score on UK'] = 'N/A'
                        row_data['Detection rate for 5% FPR'] = 'N/A'
                        row_data['FPR for 50% detection rate'] = 'N/A'
                        row_data['Precision top 1%'] = 'N/A'
                        row_data['Lift (vs dataset prevalence)'] = 'N/A'
                    else:
                        row_data['ROC AUC on FR'] = 'N/A'
                        row_data['Brier Score on FR'] = 'N/A'

            results.append(row_data)

    column_order = ['Disease', 'Prediction up to year', 'ROC AUC on UK', 'Brier Score on UK',
                    'Detection rate for 5% FPR', 'FPR for 50% detection rate',
                    'ROC AUC on FR', 'Brier Score on FR',
                    'Precision top 1%', 'Lift (vs dataset prevalence)']
    # reindex (not df[column_order]) so an empty result still has the expected columns.
    df = pd.DataFrame(results).reindex(columns=column_order)

    if screening_results:
        print("\n" + "="*100)
        print("SCREENING REQUIREMENTS FOR 80% DETECTION RATE")
        print("="*100)
        for screen_data in screening_results:
            print_screening_requirements(screen_data)

    return df

In [None]:
def generate_multi_algorithm_table(diseases=('all_dementias', 'alzheimer'),
                                   prediction_years=(2, 5, 10),
                                   age=65,
                                   include_charlson_bmi=True,
                                   exclude_mci_baseline=False,
                                   include_deceased_as_zeros=True,
                                   n_bootstrap=500):
    """
    Compare several classifiers on each (disease, horizon) prediction task.

    Algorithms: Logistic Regression, Random Forest, SVM (RBF kernel on
    standardized inputs) and an MLP (standardized inputs). Each one is trained
    on the UK training split and evaluated on the UK test split and on the
    whole FR cohort.

    Parameters
    ----------
    diseases : sequence of str
        Diseases to evaluate (forwarded to ``get_data``).
    prediction_years : sequence of int
        Prediction horizons in years.
    age : int
        Age threshold (65 or 70).
    include_charlson_bmi : bool
        Whether to include BMI and CHARLSON features.
    exclude_mci_baseline : bool
        Whether to exclude patients with MCI at baseline; the
        ``mci_at_baseline`` column is dropped as well.
    include_deceased_as_zeros : bool
        Whether to include deceased patients as negatives.
    n_bootstrap : int
        Number of bootstrap resamples for the UK detection-rate / FPR intervals
        (kept lower than the single-model table for speed).

    Returns
    -------
    pd.DataFrame
        One row per (disease, horizon, algorithm). Metrics of an algorithm that
        raised while training or evaluating are reported as 'Error'. The columns
        are always in the documented order.

    Notes
    -----
    Data loading and splitting are not wrapped in a try block: errors from
    ``get_data`` or ``split`` propagate to the caller.
    """
    from sklearn.base import clone
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.neural_network import MLPClassifier
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC
    from sklearn.utils import resample, shuffle

    target_fpr = 0.05   # operating point of 'Detection rate for 5% FPR'
    target_tpr = 0.50   # operating point of 'FPR for 50% detection rate'

    # Unfitted templates; every task fits a fresh clone so no fitted state is shared.
    algorithms = {
        'Logistic Regression': {
            'model': LogisticRegression(class_weight='balanced', random_state=SEED, max_iter=1_000),
            'name': 'Logistic Regression'
        },
        'Random Forest': {
            'model': RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=SEED, n_jobs=-1),
            'name': 'Random Forest'
        },
        'SVM': {
            'model': Pipeline([
                ('scaler', StandardScaler()),
                ('svm', SVC(probability=True, class_weight='balanced', random_state=SEED, kernel='rbf'))
            ]),
            'name': 'Support Vector Machine'
        },
        'Neural Network': {
            'model': Pipeline([
                ('scaler', StandardScaler()),
                ('mlp', MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=1000, random_state=SEED))
            ]),
            'name': 'Neural Network (MLP)'
        }
    }

    def load_cohort(country, disease, pred_year):
        """Load (X, y) for one cohort with the function's feature options."""
        return get_data(country=country, age=age, disease=disease,
                        include_charlson_bmi=include_charlson_bmi,
                        pred_up_to_year=pred_year,
                        include_deceased_as_zeros=include_deceased_as_zeros)

    def exclude_mci(X, y, cohort):
        """Drop patients with MCI at baseline and the indicator column, if present."""
        if 'mci_at_baseline' not in X.columns:
            return X, y
        keep = X['mci_at_baseline'] == 0
        print(f'{cohort}: Removed {(~keep).sum()} MCI patients at baseline')
        return X[keep].drop('mci_at_baseline', axis=1), y[keep]

    def oversample_minority(X_train, y_train):
        """Upsample class 1 (with replacement) to the size of class 0, then shuffle.

        MLPClassifier has no ``class_weight`` option, so balancing is done on the
        data. The input is returned unchanged when class 1 is not the minority.
        """
        is_neg = y_train == 0
        is_pos = y_train == 1
        X_majority, y_majority = X_train[is_neg], y_train[is_neg]
        X_minority, y_minority = X_train[is_pos], y_train[is_pos]
        if len(X_minority) >= len(X_majority):
            return X_train, y_train
        X_up, y_up = resample(X_minority, y_minority, replace=True,
                              n_samples=len(X_majority), random_state=SEED)
        X_bal, y_bal = shuffle(pd.concat([X_majority, X_up]),
                               pd.concat([y_majority, y_up]),
                               random_state=SEED)
        return X_bal, y_bal

    def has_both_classes(y_true):
        """True when y_true holds at least two labels, i.e. a ROC curve is defined."""
        return np.unique(y_true).size >= 2

    def detection_rate_at_fpr(y_true, y_scores):
        """TPR (%) of the last ROC operating point whose FPR does not exceed target_fpr.

        Returns NaN for single-class samples.
        """
        if not has_both_classes(y_true):
            return np.nan
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        # fpr is non-decreasing and roc_curve starts at (0, 0), so idx >= 0.
        idx = np.searchsorted(fpr, target_fpr, side='right') - 1
        return tpr[idx] * 100

    def fpr_at_detection_rate(y_true, y_scores):
        """FPR (%) of the first ROC operating point whose TPR reaches target_tpr.

        Returns NaN for single-class samples.
        """
        if not has_both_classes(y_true):
            return np.nan
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        # tpr ends at 1.0, so a qualifying operating point always exists.
        return fpr[np.argmax(tpr >= target_tpr)] * 100

    def bootstrap_metric(y_true, y_scores, metric_func):
        """Bootstrap mean and 95% percentile interval of ``metric_func``.

        Resamples whose metric is NaN (e.g. single-class) are skipped; returns
        (nan, nan, nan) when no resample yields a value.
        """
        # Private seeded RandomState: reproducible, and leaves NumPy's global RNG untouched.
        rng = np.random.RandomState(SEED)
        y_true = np.asarray(y_true)
        y_scores = np.asarray(y_scores)
        n_samples = len(y_true)
        values = []
        for _ in range(n_bootstrap):
            indices = rng.choice(n_samples, n_samples, replace=True)
            try:
                value = metric_func(y_true[indices], y_scores[indices])
            except ValueError:
                continue
            if not np.isnan(value):
                values.append(value)
        if not values:
            return (np.nan, np.nan, np.nan)
        return (np.mean(values),
                np.percentile(values, 2.5),
                np.percentile(values, 97.5))

    results = []

    for disease in diseases:
        disease_label = 'Dementia' if disease == 'all_dementias' else disease.capitalize()
        for pred_year in prediction_years:
            print(f"\n{'='*80}")
            print(f"Processing {disease} - {pred_year} years prediction")
            print(f"{'='*80}")

            # Load data once for all algorithms
            print("Loading UK data...")
            X_uk, y_uk = load_cohort('UK', disease, pred_year)

            print("Loading FR data...")
            X_fr, y_fr = load_cohort('FR', disease, pred_year)

            if exclude_mci_baseline:
                X_uk, y_uk = exclude_mci(X_uk, y_uk, 'UK')
                X_fr, y_fr = exclude_mci(X_fr, y_fr, 'FR')

            # Split the (possibly MCI-filtered) UK cohort loaded above; must not
            # fall back on unrelated notebook globals.
            X_train, X_test_uk, y_train, y_test_uk = split(X_uk, y_uk)
            X_train = X_train.fillna(0)
            X_test_uk = X_test_uk.fillna(0)
            X_test_fr = X_fr.fillna(0)

            print(f"Train set: {len(X_train)} patients ({y_train.sum()} cases)")
            print(f"UK Test set: {len(X_test_uk)} patients ({y_test_uk.sum()} cases)")
            print(f"FR Test set: {len(X_test_fr)} patients ({y_fr.sum()} cases)")

            for algo_key, algo_info in algorithms.items():
                print(f"\n--- Training {algo_info['name']} ---")

                row_data = {
                    'Disease': disease_label,
                    'Prediction up to year': f'{pred_year} years',
                    'Algorithm': algo_info['name']
                }

                try:
                    model = clone(algo_info['model'])
                    if 'Neural Network' in algo_key:
                        X_fit, y_fit = oversample_minority(X_train, y_train)
                    else:
                        X_fit, y_fit = X_train, y_train
                    model.fit(X_fit, y_fit)

                    print("Evaluating on UK test set...")
                    y_prob_uk = model.predict_proba(X_test_uk)[:, 1]
                    auc_uk, lower_ci_uk, upper_ci_uk = get_auc_with_ci(y_prob_uk, y_test_uk, ci=0.95)
                    det_rate_uk, det_lower_uk, det_upper_uk = bootstrap_metric(
                        y_test_uk, y_prob_uk, detection_rate_at_fpr)
                    fpr_uk, fpr_lower_uk, fpr_upper_uk = bootstrap_metric(
                        y_test_uk, y_prob_uk, fpr_at_detection_rate)

                    print("Evaluating on FR test set...")
                    y_prob_fr = model.predict_proba(X_test_fr)[:, 1]
                    auc_fr, lower_ci_fr, upper_ci_fr = get_auc_with_ci(y_prob_fr, y_fr, ci=0.95)

                    row_data['Test AUC on UK'] = f"{auc_uk:.2f} ({lower_ci_uk:.2f}-{upper_ci_uk:.2f})"
                    row_data['Detection rate for 5% FPR'] = f"{det_rate_uk:.1f}% ({det_lower_uk:.1f}%-{det_upper_uk:.1f}%)"
                    row_data['FPR for 50% detection rate'] = f"{fpr_uk:.1f}% ({fpr_lower_uk:.1f}%-{fpr_upper_uk:.1f}%)"
                    row_data['Test AUC on FR'] = f"{auc_fr:.2f} ({lower_ci_fr:.2f}-{upper_ci_fr:.2f})"

                    print(f"  UK AUC: {auc_uk:.3f} ({lower_ci_uk:.3f}-{upper_ci_uk:.3f})")
                    print(f"  FR AUC: {auc_fr:.3f} ({lower_ci_fr:.3f}-{upper_ci_fr:.3f})")

                except Exception as e:
                    # Best effort: one failing algorithm must not abort the comparison.
                    print(f"Error with {algo_info['name']}: {str(e)}")
                    row_data['Test AUC on UK'] = 'Error'
                    row_data['Detection rate for 5% FPR'] = 'Error'
                    row_data['FPR for 50% detection rate'] = 'Error'
                    row_data['Test AUC on FR'] = 'Error'

                results.append(row_data)

    column_order = ['Disease', 'Prediction up to year', 'Algorithm', 'Test AUC on UK',
                    'Detection rate for 5% FPR', 'FPR for 50% detection rate', 'Test AUC on FR']
    # reindex (not df[column_order]) so an empty result still has the expected columns.
    return pd.DataFrame(results).reindex(columns=column_order)

# Multi-algorithm comparison: several models per task, each with bootstrap CIs (slow).
print("Generating multi-algorithm comparison table...")
print("This will take several minutes due to multiple model training and bootstrap calculations...")

multi_algo_config = {
    'diseases': ['all_dementias', 'alzheimer'],
    'prediction_years': [2, 5, 10],
    'age': 65,
    'include_charlson_bmi': True,
    'exclude_mci_baseline': False,
    'include_deceased_as_zeros': True,
    'n_bootstrap': 500,  # reduced for faster computation
}
multi_algo_table = generate_multi_algorithm_table(**multi_algo_config)

wide_rule = "=" * 120
print("\n" + wide_rule)
print("MULTI-ALGORITHM PERFORMANCE COMPARISON TABLE")
print(wide_rule)
display(multi_algo_table)

# Persist the comparison table alongside the other results.
output_file = os.path.join(OUTPUT_DIR, 'multi_algorithm_performance_table.csv')
multi_algo_table.to_csv(output_file, index=False)
print(f"\nTable saved to: {output_file}")

# Summary: best algorithm per task, ranked by UK test AUC.
narrow_rule = "=" * 80
print("\n" + narrow_rule)
print("BEST PERFORMING ALGORITHM SUMMARY (by UK AUC)")
print(narrow_rule)

# Extract AUC values for comparison
def extract_auc(auc_string):
    """Return the point estimate from a formatted AUC cell such as "0.81 (0.79-0.83)".

    Cells that do not start with a number ('Error', 'N/A') or are not strings
    (e.g. missing values) map to 0, so they rank below every real AUC.
    Interpreter-level exceptions such as KeyboardInterrupt are not swallowed.
    """
    try:
        return float(auc_string.split(' ')[0])
    except (AttributeError, ValueError):
        return 0

# Numeric UK AUC used only to rank algorithms within each task.
multi_algo_table['AUC_numeric'] = multi_algo_table['Test AUC on UK'].map(extract_auc)


def _summarize_best_algorithm(task_rows):
    """Summary record for the algorithm with the highest UK AUC in one task (first row wins ties)."""
    winner = task_rows.loc[task_rows['AUC_numeric'].idxmax()]
    return {
        'Task': f"{winner['Disease']} - {winner['Prediction up to year']}",
        'Best Algorithm': winner['Algorithm'],
        'UK AUC': winner['Test AUC on UK'],
        'FR AUC': winner['Test AUC on FR'],
    }


summary_results = [
    _summarize_best_algorithm(task_rows)
    for _, task_rows in multi_algo_table.groupby(['Disease', 'Prediction up to year'])
]
summary_df = pd.DataFrame(summary_results)
display(summary_df)

# Remove the temporary ranking column so later cells see the table's original columns.
multi_algo_table = multi_algo_table.drop(columns='AUC_numeric')


In [None]:
def generate_performance_table_screen(diseases=['all_dementias', 'alzheimer'],
                                      prediction_years=[2, 5, 10],
                                      age=65,
                                      include_charlson_bmi=True,
                                      exclude_mci_baseline=False,
                                      include_deceased_as_zeros=True,
                                      n_bootstrap=1000):
    """
    Generate a comprehensive performance table for different diseases and prediction horizons.

    For every (disease, horizon) pair a ``LogisticRegression`` with
    ``class_weight='balanced'`` is fitted and evaluated twice:

    * UK: fitted on 75% of the UK cohort, evaluated on the remaining 25%.
    * FR: fitted on the whole UK cohort, evaluated on the FR cohort.

    After the table is built, an 80%-detection screening summary is printed for
    every UK hold-out evaluation.

    NOTE(review): this notebook defines this function more than once; a later
    cell's definition shadows this one once that cell is executed.

    Parameters
    ----------
    diseases : list
        Diseases to evaluate ('all_dementias' is displayed as 'Dementia').
    prediction_years : list
        Prediction horizons in years.
    age : int
        Age threshold (65 or 70).
    include_charlson_bmi : bool
        Whether to include BMI and CHARLSON features.
    exclude_mci_baseline : bool
        Whether to exclude patients with MCI at baseline (the
        'mci_at_baseline' feature is then dropped).
    include_deceased_as_zeros : bool
        Whether to include deceased patients as negatives.
    n_bootstrap : int
        Number of bootstrap resamples for the detection-rate / FPR intervals.

    Returns
    -------
    pd.DataFrame
        One row per (disease, horizon) with columns 'Disease',
        'Prediction up to year', 'Test AUC on UK', 'Detection rate for 5% FPR',
        'FPR for 50% detection rate' and 'Test AUC on FR'. Cells of an
        evaluation that raised contain 'N/A'.

    Notes
    -----
    * 'Detection rate for 5% FPR' is the best TPR among ROC operating points
      whose FPR is at most 5%, so the reported TPR is attainable at that FPR.
    * Both bootstrap metrics are reported as the mean over resamples with a
      2.5th-97.5th percentile interval. Resamples containing only one class
      (ROC undefined) are skipped rather than counted as 0.
    * Per-1,000 screening figures weight FPR by (1 - prevalence) and TPR by
      prevalence: FPR is a rate among non-cases, TPR a rate among cases.
    """

    def _display_name(disease):
        """Human-readable disease label used in the table and printouts."""
        return 'Dementia' if disease == 'all_dementias' else disease.capitalize()

    def _load_cohort(country, disease, pred_year, report_removed=False):
        """Load one cohort, optionally drop baseline-MCI patients, fill NaNs with 0."""
        X, y = get_data(country=country, age=age, disease=disease,
                        include_charlson_bmi=include_charlson_bmi,
                        pred_up_to_year=pred_year,
                        include_deceased_as_zeros=include_deceased_as_zeros)
        if exclude_mci_baseline and 'mci_at_baseline' in X.columns:
            mask = X['mci_at_baseline'] == 0
            X = X[mask].drop('mci_at_baseline', axis=1)
            y = y[mask]
            if report_removed:
                print(f'Removed {(~mask).sum()} MCI patients at baseline')
        # Filling before a train/test split equals filling each part afterwards:
        # fillna(0) is element-wise and the (unstratified) split ignores values.
        return X.fillna(0), y

    def _has_both_classes(y_true):
        """ROC quantities are undefined unless both outcome classes are present."""
        return np.unique(np.asarray(y_true)).size >= 2

    def detection_rate_at_fpr(y_true, y_scores, target_fpr=0.05):
        """Best TPR (in %) over ROC points with FPR <= target_fpr; NaN if undefined."""
        if not _has_both_classes(y_true):
            return np.nan
        # Keep every threshold so no attainable operating point is skipped.
        fpr, tpr, _ = roc_curve(y_true, y_scores, drop_intermediate=False)
        # fpr[0] == 0 (the "flag nobody" point), so the selection is never empty.
        return tpr[fpr <= target_fpr].max() * 100

    def fpr_at_detection_rate(y_true, y_scores, target_tpr=0.5):
        """Smallest FPR (in %) at which TPR reaches target_tpr; NaN if undefined."""
        if not _has_both_classes(y_true):
            return np.nan
        fpr, tpr, _ = roc_curve(y_true, y_scores, drop_intermediate=False)
        # fpr is non-decreasing and tpr ends at 1, so the first hit has the lowest FPR.
        return fpr[np.argmax(tpr >= target_tpr)] * 100

    def bootstrap_metric(y_true, y_scores, metric_func):
        """Mean and 2.5th/97.5th percentiles of metric_func over bootstrap resamples."""
        # A private RandomState seeded with SEED draws the same indices as
        # np.random.seed(SEED) + np.random.choice, without resetting global state.
        rng = np.random.RandomState(SEED)
        y_true = np.asarray(y_true)
        y_scores = np.asarray(y_scores)
        n_samples = len(y_true)
        values = []
        for _ in range(n_bootstrap):
            idx = rng.choice(n_samples, n_samples, replace=True)
            value = metric_func(y_true[idx], y_scores[idx])
            if not np.isnan(value):
                values.append(value)
        if not values:
            return (np.nan, np.nan, np.nan)
        return (np.mean(values),
                np.percentile(values, 2.5),
                np.percentile(values, 97.5))

    def _print_screening_80(screen_data):
        """Print the operating point reaching >= 80% TPR on a UK hold-out set and its per-1,000 yield."""
        disease_name = _display_name(screen_data['disease'])
        pred_year = screen_data['pred_year']
        n_patients = screen_data['n_patients']
        n_cases = screen_data['n_cases']

        fpr, tpr, thresholds = roc_curve(screen_data['y_test'], screen_data['y_prob'])
        target_tpr = 0.80
        idx = np.argmax(tpr >= target_tpr)  # first (lowest-FPR) point reaching the target

        print(f"\n{pred_year}-year {disease_name} prediction:")
        if not tpr[idx] >= target_tpr:
            print(f"  Cannot achieve 80% detection rate with available data")
            print(f"  Maximum achievable TPR: {max(tpr)*100:.1f}%")
            return

        threshold_80, fpr_80, tpr_80 = thresholds[idx], fpr[idx], tpr[idx]
        prevalence = n_cases / n_patients

        # Out of 1,000 screened: cases make up `prevalence`, non-cases the rest.
        detected_per_1000 = 1000 * tpr_80 * prevalence
        false_pos_per_1000 = 1000 * fpr_80 * (1 - prevalence)
        flagged_per_1000 = detected_per_1000 + false_pos_per_1000

        print(f"  Threshold for 80% detection: {threshold_80:.3f}")
        print(f"  Sensitivity (TPR): {tpr_80*100:.1f}%")
        print(f"  FPR: {fpr_80*100:.1f}%")
        print(f"  Prevalence in test set: {prevalence*100:.2f}% ({n_cases}/{n_patients})")
        print(f"  ")
        print(f"  Per 1,000 patients screened:")
        print(f"    - Patients testing positive: {flagged_per_1000:.1f}")
        print(f"    - {disease_name} patients detected: {detected_per_1000:.1f}")
        print(f"    - False positives: {false_pos_per_1000:.1f}")
        print(f"  ")
        print(f"  Number needed to screen to detect 80% of {disease_name} cases:")
        if detected_per_1000 > 0:
            nns = 1000 / detected_per_1000
            print(f"    - {nns:.1f} patients per {disease_name} case detected")
        else:
            print(f"    - Cannot calculate (no cases detected)")

    results = []
    screening_results = []  # UK hold-out predictions kept for the screening summary

    for disease in diseases:
        for pred_year in prediction_years:
            print(f"\n{'='*60}")
            print(f"Processing {disease} - {pred_year} years prediction")
            print(f"{'='*60}")

            row_data = {
                'Disease': _display_name(disease),
                'Prediction up to year': f'{pred_year} years'
            }

            # Evaluate on both UK (internal) and FR (external)
            for country in ['UK', 'FR']:
                print(f"\n--- Evaluating on {country} ---")

                try:
                    if country == 'UK':
                        # Internal validation: 75/25 split of the UK cohort.
                        X, y = _load_cohort('UK', disease, pred_year, report_removed=True)
                        X_train, X_test, y_train, y_test = train_test_split(
                            X, y, test_size=0.25, random_state=SEED)
                    else:
                        # External validation: fit on all UK patients, test on FR.
                        print("Training on UK data...")
                        X_train, y_train = _load_cohort('UK', disease, pred_year)
                        print("Testing on FR data...")
                        X_test, y_test = _load_cohort('FR', disease, pred_year)

                    print(f"Train set: {len(X_train)} patients ({y_train.sum()} cases)")
                    print(f"Test set: {len(X_test)} patients ({y_test.sum()} cases)")

                    model = LogisticRegression(class_weight='balanced', random_state=SEED, max_iter=1_000)
                    model.fit(X_train, y_train)
                    y_prob = model.predict_proba(X_test)[:, 1]

                    if country == 'UK':
                        screening_results.append({
                            'disease': disease,
                            'pred_year': pred_year,
                            'y_test': y_test,
                            'y_prob': y_prob,
                            'n_patients': len(X_test),
                            'n_cases': y_test.sum()
                        })

                    auc_score, lower_ci, upper_ci = get_auc_with_ci(y_prob, y_test, ci=0.95)
                    auc_str = f"{auc_score:.2f} ({lower_ci:.2f}-{upper_ci:.2f})"

                    print(f"{country} Results:")
                    print(f"  AUC: {auc_str}")

                    if country == 'UK':
                        # Operating-point metrics are only reported for the UK hold-out set.
                        det_rate_mean, det_rate_lower, det_rate_upper = bootstrap_metric(
                            y_test, y_prob, detection_rate_at_fpr)
                        fpr_mean, fpr_lower, fpr_upper = bootstrap_metric(
                            y_test, y_prob, fpr_at_detection_rate)

                        det_rate_str = f"{det_rate_mean:.1f}% ({det_rate_lower:.1f}%-{det_rate_upper:.1f}%)"
                        fpr_str = f"{fpr_mean:.1f}% ({fpr_lower:.1f}%-{fpr_upper:.1f}%)"

                        row_data['Test AUC on UK'] = auc_str
                        row_data['Detection rate for 5% FPR'] = det_rate_str
                        row_data['FPR for 50% detection rate'] = fpr_str

                        print(f"  Detection rate at 5% FPR: {det_rate_str}")
                        print(f"  FPR at 50% detection: {fpr_str}")
                    else:
                        row_data['Test AUC on FR'] = auc_str

                except Exception as e:
                    # Keep the sweep going; the failed evaluation shows up as 'N/A'.
                    print(f"Error processing {country}: {str(e)}")
                    if country == 'UK':
                        row_data['Test AUC on UK'] = 'N/A'
                        row_data['Detection rate for 5% FPR'] = 'N/A'
                        row_data['FPR for 50% detection rate'] = 'N/A'
                    else:
                        row_data['Test AUC on FR'] = 'N/A'

            results.append(row_data)

    # Passing `columns` fixes the order and also works when `results` is empty.
    column_order = ['Disease', 'Prediction up to year', 'Test AUC on UK',
                    'Detection rate for 5% FPR', 'FPR for 50% detection rate', 'Test AUC on FR']
    df = pd.DataFrame(results, columns=column_order)

    if screening_results:
        print("\n" + "="*100)
        print("SCREENING REQUIREMENTS FOR 80% DETECTION RATE")
        print("="*100)
        for screen_data in screening_results:
            _print_screening_80(screen_data)

    return df

# Generate the performance table
# Evaluation settings for the screening-oriented performance table.
screen_table_kwargs = {
    'diseases': ['all_dementias', 'alzheimer'],
    'prediction_years': [2, 5, 10],
    'age': 65,
    'include_charlson_bmi': True,
    'exclude_mci_baseline': False,
    'include_deceased_as_zeros': True,
    'n_bootstrap': 1000,
}

print("Generating comprehensive performance table...")
print("This may take several minutes due to bootstrap calculations...")
performance_table = generate_performance_table_screen(**screen_table_kwargs)

print(f"\n{'=' * 100}")
print("COMPREHENSIVE PERFORMANCE TABLE")
print('=' * 100)
display(performance_table)

# Persist the table next to the other results.
output_file = os.path.join(OUTPUT_DIR, 'performance_table.csv')
performance_table.to_csv(output_file, index=False)
print(f"\nTable saved to: {output_file}")

In [None]:
def generate_performance_table_screen(diseases=['all_dementias', 'alzheimer'],
                                      prediction_years=[2, 5, 10],
                                      age=65,
                                      include_charlson_bmi=True,
                                      exclude_mci_baseline=False,
                                      include_deceased_as_zeros=True,
                                      n_bootstrap=1000):
    """
    Generate a comprehensive performance table for different diseases and prediction horizons.

    For every (disease, horizon) pair a ``LogisticRegression`` with
    ``class_weight='balanced'`` is fitted and evaluated twice:

    * UK: fitted on 75% of the UK cohort, evaluated on the remaining 25%.
    * FR: fitted on the whole UK cohort, evaluated on the FR cohort.

    After the table is built, two summaries are printed for every UK hold-out
    evaluation: the operating point reaching 80% detection (with per-1,000
    screening yield), and the case yield among the top 1% highest scores.

    NOTE(review): this notebook defines this function more than once; a later
    cell's definition shadows this one once that cell is executed.

    Parameters
    ----------
    diseases : list
        Diseases to evaluate ('all_dementias' is displayed as 'Dementia').
    prediction_years : list
        Prediction horizons in years.
    age : int
        Age threshold (65 or 70).
    include_charlson_bmi : bool
        Whether to include BMI and CHARLSON features.
    exclude_mci_baseline : bool
        Whether to exclude patients with MCI at baseline (the
        'mci_at_baseline' feature is then dropped).
    include_deceased_as_zeros : bool
        Whether to include deceased patients as negatives.
    n_bootstrap : int
        Number of bootstrap resamples for the detection-rate / FPR intervals.

    Returns
    -------
    pd.DataFrame
        One row per (disease, horizon) with columns 'Disease',
        'Prediction up to year', 'Test AUC on UK', 'Detection rate for 5% FPR',
        'FPR for 50% detection rate' and 'Test AUC on FR'. Cells of an
        evaluation that raised contain 'N/A'.

    Notes
    -----
    * 'Detection rate for 5% FPR' is the best TPR among ROC operating points
      whose FPR is at most 5%, so the reported TPR is attainable at that FPR.
    * Both bootstrap metrics are reported as the mean over resamples with a
      2.5th-97.5th percentile interval. Resamples containing only one class
      (ROC undefined) are skipped rather than counted as 0.
    * Per-1,000 screening figures weight FPR by (1 - prevalence) and TPR by
      prevalence: FPR is a rate among non-cases, TPR a rate among cases.
    """

    def _display_name(disease):
        """Human-readable disease label used in the table and printouts."""
        return 'Dementia' if disease == 'all_dementias' else disease.capitalize()

    def _load_cohort(country, disease, pred_year, report_removed=False):
        """Load one cohort, optionally drop baseline-MCI patients, fill NaNs with 0."""
        X, y = get_data(country=country, age=age, disease=disease,
                        include_charlson_bmi=include_charlson_bmi,
                        pred_up_to_year=pred_year,
                        include_deceased_as_zeros=include_deceased_as_zeros)
        if exclude_mci_baseline and 'mci_at_baseline' in X.columns:
            mask = X['mci_at_baseline'] == 0
            X = X[mask].drop('mci_at_baseline', axis=1)
            y = y[mask]
            if report_removed:
                print(f'Removed {(~mask).sum()} MCI patients at baseline')
        # Filling before a train/test split equals filling each part afterwards:
        # fillna(0) is element-wise and the (unstratified) split ignores values.
        return X.fillna(0), y

    def _has_both_classes(y_true):
        """ROC quantities are undefined unless both outcome classes are present."""
        return np.unique(np.asarray(y_true)).size >= 2

    def detection_rate_at_fpr(y_true, y_scores, target_fpr=0.05):
        """Best TPR (in %) over ROC points with FPR <= target_fpr; NaN if undefined."""
        if not _has_both_classes(y_true):
            return np.nan
        # Keep every threshold so no attainable operating point is skipped.
        fpr, tpr, _ = roc_curve(y_true, y_scores, drop_intermediate=False)
        # fpr[0] == 0 (the "flag nobody" point), so the selection is never empty.
        return tpr[fpr <= target_fpr].max() * 100

    def fpr_at_detection_rate(y_true, y_scores, target_tpr=0.5):
        """Smallest FPR (in %) at which TPR reaches target_tpr; NaN if undefined."""
        if not _has_both_classes(y_true):
            return np.nan
        fpr, tpr, _ = roc_curve(y_true, y_scores, drop_intermediate=False)
        # fpr is non-decreasing and tpr ends at 1, so the first hit has the lowest FPR.
        return fpr[np.argmax(tpr >= target_tpr)] * 100

    def bootstrap_metric(y_true, y_scores, metric_func):
        """Mean and 2.5th/97.5th percentiles of metric_func over bootstrap resamples."""
        # A private RandomState seeded with SEED draws the same indices as
        # np.random.seed(SEED) + np.random.choice, without resetting global state.
        rng = np.random.RandomState(SEED)
        y_true = np.asarray(y_true)
        y_scores = np.asarray(y_scores)
        n_samples = len(y_true)
        values = []
        for _ in range(n_bootstrap):
            idx = rng.choice(n_samples, n_samples, replace=True)
            value = metric_func(y_true[idx], y_scores[idx])
            if not np.isnan(value):
                values.append(value)
        if not values:
            return (np.nan, np.nan, np.nan)
        return (np.mean(values),
                np.percentile(values, 2.5),
                np.percentile(values, 97.5))

    def _print_screening_80(screen_data):
        """Print the operating point reaching >= 80% TPR on a UK hold-out set and its per-1,000 yield."""
        disease_name = _display_name(screen_data['disease'])
        pred_year = screen_data['pred_year']
        n_patients = screen_data['n_patients']
        n_cases = screen_data['n_cases']

        fpr, tpr, thresholds = roc_curve(screen_data['y_test'], screen_data['y_prob'])
        target_tpr = 0.80
        idx = np.argmax(tpr >= target_tpr)  # first (lowest-FPR) point reaching the target

        print(f"\n{pred_year}-year {disease_name} prediction:")
        if not tpr[idx] >= target_tpr:
            print(f"  Cannot achieve 80% detection rate with available data")
            print(f"  Maximum achievable TPR: {max(tpr)*100:.1f}%")
            return

        threshold_80, fpr_80, tpr_80 = thresholds[idx], fpr[idx], tpr[idx]
        prevalence = n_cases / n_patients

        # Out of 1,000 screened: cases make up `prevalence`, non-cases the rest.
        detected_per_1000 = 1000 * tpr_80 * prevalence
        false_pos_per_1000 = 1000 * fpr_80 * (1 - prevalence)
        flagged_per_1000 = detected_per_1000 + false_pos_per_1000

        print(f"  Threshold for 80% detection: {threshold_80:.3f}")
        print(f"  Sensitivity (TPR): {tpr_80*100:.1f}%")
        print(f"  FPR: {fpr_80*100:.1f}%")
        print(f"  Prevalence in test set: {prevalence*100:.2f}% ({n_cases}/{n_patients})")
        print(f"  ")
        print(f"  Per 1,000 patients screened:")
        print(f"    - Patients testing positive: {flagged_per_1000:.1f}")
        print(f"    - {disease_name} patients detected: {detected_per_1000:.1f}")
        print(f"    - False positives: {false_pos_per_1000:.1f}")
        print(f"  ")
        print(f"  Number needed to screen to detect 80% of {disease_name} cases:")
        if detected_per_1000 > 0:
            nns = 1000 / detected_per_1000
            print(f"    - {nns:.1f} patients per {disease_name} case detected")
        else:
            print(f"    - Cannot calculate (no cases detected)")

    def _print_top_1pct(screen_data):
        """Print how many cases fall among the 1% of hold-out patients with the highest scores."""
        disease_name = _display_name(screen_data['disease'])
        pred_year = screen_data['pred_year']
        n_patients = screen_data['n_patients']
        n_cases = screen_data['n_cases']

        y_prob = np.asarray(screen_data['y_prob'])
        order = np.argsort(y_prob)[::-1]  # highest score first
        top_1_percent = max(1, int(0.01 * n_patients))  # at least one patient
        top_1_percent_labels = np.asarray(screen_data['y_test'])[order][:top_1_percent]
        top_1_percent_scores = y_prob[order][:top_1_percent]

        n_disease_in_top1 = top_1_percent_labels.sum()
        precision_top1 = (n_disease_in_top1 / top_1_percent) * 100
        recall_top1 = (n_disease_in_top1 / n_cases) * 100 if n_cases > 0 else 0

        # Enrichment: PPV in the top 1% relative to random selection (= prevalence).
        overall_prevalence = n_cases / n_patients
        enrichment_factor = (precision_top1 / 100) / overall_prevalence if overall_prevalence > 0 else 0

        print(f"\n{pred_year}-year {disease_name} prediction - Top 1% analysis:")
        print(f"  Total patients in test set: {n_patients:,}")
        print(f"  Total {disease_name.lower()} cases: {n_cases:,} ({overall_prevalence*100:.2f}%)")
        print(f"  ")
        print(f"  Top 1% highest scores ({top_1_percent:,} patients):")
        print(f"    - {disease_name} cases found: {n_disease_in_top1}")
        print(f"    - Precision (PPV): {precision_top1:.1f}%")
        print(f"    - Sensitivity captured: {recall_top1:.1f}%")
        print(f"    - Enrichment factor: {enrichment_factor:.1f}x")
        print(f"    - Score range: {top_1_percent_scores.min():.3f} - {top_1_percent_scores.max():.3f}")
        print(f"  ")
        print(f"  Clinical interpretation:")
        if n_disease_in_top1 > 0:
            patients_per_case = top_1_percent / n_disease_in_top1
            print(f"    - Need to screen {patients_per_case:.1f} high-risk patients to find 1 {disease_name.lower()} case")
            print(f"    - In 1000 patients, top 10 would contain ~{(n_disease_in_top1/top_1_percent)*10:.1f} {disease_name.lower()} cases")
        else:
            print(f"    - No {disease_name.lower()} cases found in top 1% - model may need improvement")

    results = []
    screening_results = []  # UK hold-out predictions kept for the screening summaries

    for disease in diseases:
        for pred_year in prediction_years:
            print(f"\n{'='*60}")
            print(f"Processing {disease} - {pred_year} years prediction")
            print(f"{'='*60}")

            row_data = {
                'Disease': _display_name(disease),
                'Prediction up to year': f'{pred_year} years'
            }

            # Evaluate on both UK (internal) and FR (external)
            for country in ['UK', 'FR']:
                print(f"\n--- Evaluating on {country} ---")

                try:
                    if country == 'UK':
                        # Internal validation: 75/25 split of the UK cohort.
                        X, y = _load_cohort('UK', disease, pred_year, report_removed=True)
                        X_train, X_test, y_train, y_test = train_test_split(
                            X, y, test_size=0.25, random_state=SEED)
                    else:
                        # External validation: fit on all UK patients, test on FR.
                        print("Training on UK data...")
                        X_train, y_train = _load_cohort('UK', disease, pred_year)
                        print("Testing on FR data...")
                        X_test, y_test = _load_cohort('FR', disease, pred_year)

                    print(f"Train set: {len(X_train)} patients ({y_train.sum()} cases)")
                    print(f"Test set: {len(X_test)} patients ({y_test.sum()} cases)")

                    model = LogisticRegression(class_weight='balanced', random_state=SEED, max_iter=1_000)
                    model.fit(X_train, y_train)
                    y_prob = model.predict_proba(X_test)[:, 1]

                    if country == 'UK':
                        screening_results.append({
                            'disease': disease,
                            'pred_year': pred_year,
                            'y_test': y_test,
                            'y_prob': y_prob,
                            'n_patients': len(X_test),
                            'n_cases': y_test.sum()
                        })

                    auc_score, lower_ci, upper_ci = get_auc_with_ci(y_prob, y_test, ci=0.95)
                    auc_str = f"{auc_score:.2f} ({lower_ci:.2f}-{upper_ci:.2f})"

                    print(f"{country} Results:")
                    print(f"  AUC: {auc_str}")

                    if country == 'UK':
                        # Operating-point metrics are only reported for the UK hold-out set.
                        det_rate_mean, det_rate_lower, det_rate_upper = bootstrap_metric(
                            y_test, y_prob, detection_rate_at_fpr)
                        fpr_mean, fpr_lower, fpr_upper = bootstrap_metric(
                            y_test, y_prob, fpr_at_detection_rate)

                        det_rate_str = f"{det_rate_mean:.1f}% ({det_rate_lower:.1f}%-{det_rate_upper:.1f}%)"
                        fpr_str = f"{fpr_mean:.1f}% ({fpr_lower:.1f}%-{fpr_upper:.1f}%)"

                        row_data['Test AUC on UK'] = auc_str
                        row_data['Detection rate for 5% FPR'] = det_rate_str
                        row_data['FPR for 50% detection rate'] = fpr_str

                        print(f"  Detection rate at 5% FPR: {det_rate_str}")
                        print(f"  FPR at 50% detection: {fpr_str}")
                    else:
                        row_data['Test AUC on FR'] = auc_str

                except Exception as e:
                    # Keep the sweep going; the failed evaluation shows up as 'N/A'.
                    print(f"Error processing {country}: {str(e)}")
                    if country == 'UK':
                        row_data['Test AUC on UK'] = 'N/A'
                        row_data['Detection rate for 5% FPR'] = 'N/A'
                        row_data['FPR for 50% detection rate'] = 'N/A'
                    else:
                        row_data['Test AUC on FR'] = 'N/A'

            results.append(row_data)

    # Passing `columns` fixes the order and also works when `results` is empty.
    column_order = ['Disease', 'Prediction up to year', 'Test AUC on UK',
                    'Detection rate for 5% FPR', 'FPR for 50% detection rate', 'Test AUC on FR']
    df = pd.DataFrame(results, columns=column_order)

    if screening_results:
        print("\n" + "="*100)
        print("SCREENING REQUIREMENTS FOR 80% DETECTION RATE")
        print("="*100)
        for screen_data in screening_results:
            _print_screening_80(screen_data)

        print("\n" + "="*100)
        print("ANALYSIS: TOP 1% PATIENTS WITH HIGHEST PREDICTION SCORES")
        print("="*100)
        for screen_data in screening_results:
            _print_top_1pct(screen_data)

    return df

# Run the screening-oriented performance table (including the top-1% analysis).
screen_table_kwargs = {
    'diseases': ['all_dementias', 'alzheimer'],
    'prediction_years': [2, 5, 10],
    'age': 65,
    'include_charlson_bmi': True,
    'exclude_mci_baseline': False,
    'include_deceased_as_zeros': True,
    'n_bootstrap': 1000,
}

print("Generating comprehensive performance table...")
print("This may take several minutes due to bootstrap calculations...")
performance_table = generate_performance_table_screen(**screen_table_kwargs)

print(f"\n{'=' * 100}")
print("COMPREHENSIVE PERFORMANCE TABLE")
print('=' * 100)
display(performance_table)

# Persist the table next to the other results.
output_file = os.path.join(OUTPUT_DIR, 'performance_table.csv')
performance_table.to_csv(output_file, index=False)
print(f"\nTable saved to: {output_file}")

In [None]:
def _screen_disease_label(disease):
    """Human-readable disease name used in tables and printouts."""
    return 'Dementia' if disease == 'all_dementias' else disease.capitalize()


def _screen_load_xy(country, *, age, disease, include_charlson_bmi, pred_up_to_year,
                    include_deceased_as_zeros, exclude_mci_baseline, report_removed=False):
    """
    Load features/labels for one country and task via ``get_data``.

    When ``exclude_mci_baseline`` is set and the ``mci_at_baseline`` column is
    present, rows with ``mci_at_baseline != 0`` are dropped together with the
    column itself. Missing feature values are filled with 0.
    """
    X, y = get_data(country=country, age=age, disease=disease,
                    include_charlson_bmi=include_charlson_bmi,
                    pred_up_to_year=pred_up_to_year,
                    include_deceased_as_zeros=include_deceased_as_zeros)

    if exclude_mci_baseline and 'mci_at_baseline' in X.columns:
        keep = X['mci_at_baseline'] == 0
        X = X[keep].drop('mci_at_baseline', axis=1)
        y = y[keep]
        if report_removed:
            print(f'Removed {(~keep).sum()} MCI patients at baseline')

    # fillna(0) is element-wise, so filling before a train/test split is
    # equivalent to filling each split separately.
    return X.fillna(0), y


def _screen_bootstrap_metric(y_true, y_scores, metric_func, n_bootstrap=1000, seed=None):
    """
    Percentile-bootstrap 95% confidence interval for a score-based metric.

    Returns (metric on the full sample, 2.5th percentile, 97.5th percentile).
    Resamples containing a single class are skipped because ROC-based metrics
    are undefined for them. A local ``Generator`` is used so the global NumPy
    RNG state is not modified. Returns three NaNs when nothing can be computed.
    """
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)
    n_samples = len(y_true)
    if n_samples == 0:
        return (np.nan, np.nan, np.nan)

    def _safe_metric(labels, scores):
        # Degenerate inputs make the ROC helpers raise ValueError; treat as missing.
        try:
            return float(metric_func(labels, scores))
        except ValueError:
            return np.nan

    rng = np.random.default_rng(seed)
    boot_values = []
    for _ in range(n_bootstrap):
        indices = rng.integers(0, n_samples, size=n_samples)
        y_boot = y_true[indices]
        if np.unique(y_boot).size < 2:
            continue
        value = _safe_metric(y_boot, y_scores[indices])
        if not np.isnan(value):
            boot_values.append(value)

    if not boot_values:
        return (np.nan, np.nan, np.nan)

    point = _safe_metric(y_true, y_scores) if np.unique(y_true).size >= 2 else np.nan
    return (point,
            np.percentile(boot_values, 2.5),
            np.percentile(boot_values, 97.5))


def _screen_detection_rate_at_fpr(y_true, y_scores, target_fpr=0.05):
    """Highest sensitivity (%) among ROC operating points whose FPR is <= ``target_fpr``."""
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    # fpr[0] == 0, so the mask is non-empty whenever negatives exist. Taking the
    # first point with FPR >= target would overshoot the FPR when the curve jumps.
    return tpr[fpr <= target_fpr].max() * 100


def _screen_fpr_at_detection_rate(y_true, y_scores, target_tpr=0.5):
    """Lowest FPR (%) among ROC operating points whose sensitivity is >= ``target_tpr``."""
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    return fpr[tpr >= target_tpr].min() * 100


def _screen_print_80pct_requirements(screening_results):
    """Print, per UK hold-out task, the operating point reaching >= 80% sensitivity and its per-1,000 yield."""
    target_tpr = 0.80
    print("\n" + "="*100)
    print("SCREENING REQUIREMENTS FOR 80% DETECTION RATE")
    print("="*100)

    for screen_data in screening_results:
        disease_name = _screen_disease_label(screen_data['disease'])
        pred_year = screen_data['pred_year']
        n_patients = screen_data['n_patients']
        n_cases = screen_data['n_cases']

        fpr, tpr, thresholds = roc_curve(screen_data['y_test'], screen_data['y_prob'])
        # First ROC point reaching the target sensitivity = lowest FPR achieving it.
        idx = np.argmax(tpr >= target_tpr)

        print(f"\n{pred_year}-year {disease_name} prediction:")
        if not (tpr[idx] >= target_tpr):
            print(f"  Cannot achieve 80% detection rate with available data")
            print(f"  Maximum achievable TPR: {max(tpr)*100:.1f}%")
            continue

        threshold_80 = thresholds[idx]
        fpr_80 = fpr[idx]
        tpr_80 = tpr[idx]
        prevalence = n_cases / n_patients

        # Per 1,000 screened: true positives come only from the prevalent cases,
        # false positives only from the (1 - prevalence) non-cases.
        detected_per_1000 = 1000 * tpr_80 * prevalence
        false_pos_per_1000 = 1000 * fpr_80 * (1 - prevalence)
        positive_per_1000 = detected_per_1000 + false_pos_per_1000

        print(f"  Threshold for 80% detection: {threshold_80:.3f}")
        print(f"  Sensitivity (TPR): {tpr_80*100:.1f}%")
        print(f"  FPR: {fpr_80*100:.1f}%")
        print(f"  Prevalence in test set: {prevalence*100:.2f}% ({n_cases}/{n_patients})")
        print(f"  ")
        print(f"  Per 1,000 patients screened:")
        print(f"    - Patients testing positive: {positive_per_1000:.1f}")
        print(f"    - {disease_name} patients detected: {detected_per_1000:.1f}")
        print(f"    - False positives: {false_pos_per_1000:.1f}")
        print(f"  ")
        print(f"  Number needed to screen to detect 80% of {disease_name} cases:")
        if detected_per_1000 > 0:
            nns = 1000 / detected_per_1000
            print(f"    - {nns:.1f} patients per {disease_name} case detected")
        else:
            print(f"    - Cannot calculate (no cases detected)")


def _screen_blood_test_enrichment(screening_results):
    """
    Compare blood-testing the top-scored 1% of UK hold-out patients against a
    random 1%; print the comparison, display it, and save it (transposed) to
    ``OUTPUT_DIR/blood_test_enrichment_table.csv``.
    """
    print("\n" + "="*120)
    print("BLOOD TEST ENRICHMENT ANALYSIS: Algorithm vs Random Selection")
    print("Scenario: GP selects 1% of patients for expensive blood test")
    print("="*120)

    enrichment_table_data = []
    enrichment_values = []       # unrounded, for the summary average
    cost_reduction_values = []   # unrounded, NaN where undefined

    for screen_data in screening_results:
        disease_name = _screen_disease_label(screen_data['disease'])
        pred_year = screen_data['pred_year']
        labels = np.asarray(screen_data['y_test'])
        y_prob = np.asarray(screen_data['y_prob'])
        n_patients = screen_data['n_patients']
        n_cases = screen_data['n_cases']

        overall_prevalence = n_cases / n_patients
        top_1_percent = max(1, int(0.01 * n_patients))  # test budget, at least one patient

        # Random selection: expected number of cases among the budgeted patients.
        random_cases_expected = overall_prevalence * top_1_percent

        # Algorithm selection: the highest-scored patients up to the budget.
        top_indices = np.argsort(y_prob)[::-1][:top_1_percent]
        algorithm_cases_found = int(labels[top_indices].sum())

        enrichment_factor = algorithm_cases_found / random_cases_expected if random_cases_expected > 0 else 0
        random_precision = (random_cases_expected / top_1_percent) * 100
        algorithm_precision = (algorithm_cases_found / top_1_percent) * 100

        cost_per_case_random = top_1_percent / random_cases_expected if random_cases_expected > 0 else float('inf')
        cost_per_case_algorithm = top_1_percent / algorithm_cases_found if algorithm_cases_found > 0 else float('inf')
        # Relative reduction in tests per case found. Undefined (NaN) when either
        # strategy finds no cases, instead of -inf/NaN arithmetic on infinities.
        if random_cases_expected > 0 and algorithm_cases_found > 0:
            cost_reduction = (1 - cost_per_case_algorithm / cost_per_case_random) * 100
        else:
            cost_reduction = float('nan')

        enrichment_values.append(enrichment_factor)
        cost_reduction_values.append(cost_reduction)
        enrichment_table_data.append({
            'Disease': disease_name,
            'Prediction Horizon': f'{pred_year} years',
            'Population Prevalence': f'{overall_prevalence*100:.2f}%',
            'Random Selection (1%)': f'{random_cases_expected:.1f} cases',
            'Algorithm Selection (1%)': f'{algorithm_cases_found} cases',
            'Enrichment Factor': f'{enrichment_factor:.1f}x',
            'Random Precision': f'{random_precision:.1f}%',
            'Algorithm Precision': f'{algorithm_precision:.1f}%',
            'Blood Tests per Case (Random)': f'{cost_per_case_random:.1f}',
            'Blood Tests per Case (Algorithm)': f'{cost_per_case_algorithm:.1f}',
            'Cost Reduction': f'{cost_reduction:.1f}%'
        })

        print(f"\n{pred_year}-year {disease_name} prediction:")
        print(f"  Population: {n_patients:,} patients, {n_cases} cases ({overall_prevalence*100:.2f}% prevalence)")
        print(f"  Blood test budget: 1% = {top_1_percent} tests")
        print(f"  ")
        print(f"  RANDOM SELECTION (baseline):")
        print(f"    - Expected cases found: {random_cases_expected:.1f}")
        print(f"    - Precision: {random_precision:.1f}%")
        print(f"    - Blood tests per case: {cost_per_case_random:.1f}")
        print(f"  ")
        print(f"  ALGORITHM SELECTION:")
        print(f"    - Actual cases found: {algorithm_cases_found}")
        print(f"    - Precision: {algorithm_precision:.1f}%")
        print(f"    - Blood tests per case: {cost_per_case_algorithm:.1f}")
        print(f"  ")
        print(f"  IMPROVEMENT:")
        print(f"    - Enrichment factor: {enrichment_factor:.1f}x")
        print(f"    - Cost reduction: {cost_reduction:.1f}%")
        print(f"    - Additional cases found: {algorithm_cases_found - random_cases_expected:.1f}")

    enrichment_df = pd.DataFrame(enrichment_table_data)

    print(f"\n{'='*150}")
    print("BLOOD TEST ENRICHMENT COMPARISON TABLE")
    print(f"{'='*150}")
    display(enrichment_df)

    # Saved transposed: one row per metric, one column per task.
    enrichment_output_file = os.path.join(OUTPUT_DIR, 'blood_test_enrichment_table.csv')
    (enrichment_df.T
        .reset_index()
        .rename(columns={'index': 'Metric'})
        .to_csv(enrichment_output_file, index=False))
    print(f"\nEnrichment table saved to: {enrichment_output_file}")

    print(f"\n{'='*100}")
    print("SUMMARY: Blood Test Cost-Effectiveness")
    print(f"{'='*100}")

    # pandas mean skips NaN, i.e. tasks where the cost reduction is undefined.
    avg_enrichment = pd.Series(enrichment_values, dtype=float).mean()
    avg_cost_reduction = pd.Series(cost_reduction_values, dtype=float).mean()

    print(f"Average enrichment factor across all tasks: {avg_enrichment:.1f}x")
    print(f"Average cost reduction: {avg_cost_reduction:.1f}%")
    print(f"")
    print(f"Clinical Impact:")
    print(f"  - Algorithm identifies {avg_enrichment:.1f}x more cases than random selection")
    print(f"  - Reduces blood tests needed per case detected by {avg_cost_reduction:.1f}%")
    print(f"  - Enables precision medicine approach for expensive diagnostic tests")


def generate_performance_table_screen(diseases=('all_dementias', 'alzheimer'),
                                      prediction_years=(2, 5, 10),
                                      age=65,
                                      include_charlson_bmi=True,
                                      exclude_mci_baseline=False,
                                      include_deceased_as_zeros=True,
                                      n_bootstrap=1000):
    """
    Generate a comprehensive performance table for different diseases and prediction horizons.

    For every (disease, horizon) pair a logistic regression with
    ``class_weight='balanced'`` is fitted and evaluated twice:

    * UK: trained/tested on a 75/25 split of the UK data (``random_state=SEED``);
      reports AUC, detection rate at 5% FPR and FPR at 50% detection rate
      (the latter two with percentile-bootstrap 95% CIs).
    * FR: trained on all UK data and tested on the FR data; reports AUC only.

    The UK hold-out predictions are then reused to print the operating point
    for 80% sensitivity and a "top 1% blood test" enrichment analysis, whose
    table is saved to ``OUTPUT_DIR/blood_test_enrichment_table.csv``.

    Parameters
    ----------
    diseases : sequence of str
        Diseases to evaluate (passed to ``get_data`` as ``disease``).
    prediction_years : sequence of int
        Prediction horizons in years.
    age : int
        Age threshold (65 or 70).
    include_charlson_bmi : bool
        Whether to include BMI and CHARLSON features.
    exclude_mci_baseline : bool
        Whether to exclude patients with MCI at baseline (the column is then dropped).
    include_deceased_as_zeros : bool
        Whether to include deceased patients as negatives.
    n_bootstrap : int
        Number of bootstrap resamples for the detection-rate / FPR confidence intervals.

    Returns
    -------
    pd.DataFrame
        One row per (disease, horizon) with columns 'Disease', 'Prediction up to year',
        'Test AUC on UK', 'Detection rate for 5% FPR', 'FPR for 50% detection rate',
        'Test AUC on FR'. A country's cells are 'N/A' when its evaluation raised.
    """
    load_kwargs = dict(age=age,
                       include_charlson_bmi=include_charlson_bmi,
                       include_deceased_as_zeros=include_deceased_as_zeros,
                       exclude_mci_baseline=exclude_mci_baseline)

    results = []
    screening_results = []  # UK hold-out predictions, reused by the screening analyses

    for disease in diseases:
        for pred_year in prediction_years:
            print(f"\n{'='*60}")
            print(f"Processing {disease} - {pred_year} years prediction")
            print(f"{'='*60}")

            row_data = {
                'Disease': _screen_disease_label(disease),
                'Prediction up to year': f'{pred_year} years'
            }

            # UK = internal hold-out; FR = external validation of the UK-trained model.
            for country in ['UK', 'FR']:
                print(f"\n--- Evaluating on {country} ---")

                try:
                    if country == 'UK':
                        X, y = _screen_load_xy('UK', disease=disease, pred_up_to_year=pred_year,
                                               report_removed=True, **load_kwargs)
                        X_train, X_test, y_train, y_test = train_test_split(
                            X, y, test_size=0.25, random_state=SEED)
                    else:
                        print("Training on UK data...")
                        X_train, y_train = _screen_load_xy('UK', disease=disease,
                                                           pred_up_to_year=pred_year, **load_kwargs)
                        print("Testing on FR data...")
                        X_test, y_test = _screen_load_xy('FR', disease=disease,
                                                         pred_up_to_year=pred_year, **load_kwargs)

                    print(f"Train set: {len(X_train)} patients ({y_train.sum()} cases)")
                    print(f"Test set: {len(X_test)} patients ({y_test.sum()} cases)")

                    model = LogisticRegression(class_weight='balanced', random_state=SEED, max_iter=1_000)
                    model.fit(X_train, y_train)
                    y_prob = model.predict_proba(X_test)[:, 1]

                    if country == 'UK':
                        screening_results.append({
                            'disease': disease,
                            'pred_year': pred_year,
                            'y_test': y_test,
                            'y_prob': y_prob,
                            'n_patients': len(X_test),
                            'n_cases': y_test.sum()
                        })

                    auc_score, lower_ci, upper_ci = get_auc_with_ci(y_prob, y_test, ci=0.95)
                    auc_str = f"{auc_score:.2f} ({lower_ci:.2f}-{upper_ci:.2f})"

                    print(f"{country} Results:")
                    print(f"  AUC: {auc_str}")

                    if country == 'UK':
                        # Bootstrap CIs are only reported for UK, so only computed there.
                        det_rate_mean, det_rate_lower, det_rate_upper = _screen_bootstrap_metric(
                            y_test, y_prob, _screen_detection_rate_at_fpr,
                            n_bootstrap=n_bootstrap, seed=SEED)
                        fpr_mean, fpr_lower, fpr_upper = _screen_bootstrap_metric(
                            y_test, y_prob, _screen_fpr_at_detection_rate,
                            n_bootstrap=n_bootstrap, seed=SEED)

                        det_rate_str = f"{det_rate_mean:.1f}% ({det_rate_lower:.1f}%-{det_rate_upper:.1f}%)"
                        fpr_str = f"{fpr_mean:.1f}% ({fpr_lower:.1f}%-{fpr_upper:.1f}%)"

                        row_data['Test AUC on UK'] = auc_str
                        row_data['Detection rate for 5% FPR'] = det_rate_str
                        row_data['FPR for 50% detection rate'] = fpr_str

                        print(f"  Detection rate at 5% FPR: {det_rate_str}")
                        print(f"  FPR at 50% detection: {fpr_str}")
                    else:
                        row_data['Test AUC on FR'] = auc_str

                except Exception as e:
                    # Deliberate best-effort: one failing country must not abort the table.
                    print(f"Error processing {country}: {str(e)}")
                    if country == 'UK':
                        row_data['Test AUC on UK'] = 'N/A'
                        row_data['Detection rate for 5% FPR'] = 'N/A'
                        row_data['FPR for 50% detection rate'] = 'N/A'
                    else:
                        row_data['Test AUC on FR'] = 'N/A'

            results.append(row_data)

    column_order = ['Disease', 'Prediction up to year', 'Test AUC on UK',
                    'Detection rate for 5% FPR', 'FPR for 50% detection rate', 'Test AUC on FR']
    # Passing `columns` fixes the order and also works when `results` is empty.
    df = pd.DataFrame(results, columns=column_order)

    if screening_results:
        _screen_print_80pct_requirements(screening_results)
        _screen_blood_test_enrichment(screening_results)

    return df

# Build the UK/FR performance table (dementia and Alzheimer's, age 65, 2/5/10-year
# horizons, 1000 bootstrap resamples) with the generate_performance_table_screen
# defined in this cell, which also prints the screening and enrichment analyses.
print("Generating comprehensive performance table...",
      "This will take several minutes due to bootstrap calculations...",
      sep="\n")

performance_table = generate_performance_table_screen(
    diseases=['all_dementias', 'alzheimer'],
    prediction_years=[2, 5, 10],
    age=65,
    include_charlson_bmi=True,
    exclude_mci_baseline=False,
    include_deceased_as_zeros=True,
    n_bootstrap=1000,
)

print(f"\n{'=' * 100}", "COMPREHENSIVE PERFORMANCE TABLE", '=' * 100, sep="\n")
display(performance_table)

# Persist as CSV (without the index column) next to the other results.
output_file = os.path.join(OUTPUT_DIR, 'performance_table.csv')
performance_table.to_csv(output_file, index=False)
print(f"\nTable saved to: {output_file}")

In [None]:
def generate_performance_table(diseases=['all_dementias', 'alzheimer'], 
                                             prediction_years=[2, 5, 10], 
                                             age=65, 
                                             include_charlson_bmi=True,
                                             exclude_mci_baseline=False,
                                             include_deceased_as_zeros=True,
                                             n_bootstrap=1000):
           """
           Generate a comprehensive performance table for different diseases and prediction horizons.
           
           Parameters:
           -----------
           diseases : list
               List of diseases to evaluate
           prediction_years : list
               List of prediction horizons in years
           age : int
               Age threshold (65 or 70)
           include_charlson_bmi : bool
               Whether to include BMI and CHARLSON features
           exclude_mci_baseline : bool
               Whether to exclude patients with MCI at baseline
           include_deceased_as_zeros : bool
               Whether to include deceased patients as negatives
           n_bootstrap : int
               Number of bootstrap samples for confidence intervals
               
           Returns:
           --------
           pd.DataFrame
               Performance table with all metrics
           """

           results = []
           screening_results = []  # Store data for screening analysis

           for disease in diseases:
               for pred_year in prediction_years:
                   print(f"\n{'='*60}")
                   print(f"Processing {disease} - {pred_year} years prediction")
                   print(f"{'='*60}")

                   row_data = {
                       'Disease': 'Dementia' if disease == 'all_dementias' else disease.capitalize(),
                       'Prediction up to year': f'{pred_year} years'
                   }

                   # Evaluate on both UK and FR
                   for country in ['UK', 'FR']:
                       print(f"\n--- Evaluating on {country} ---")

                       try:
                           # Load data
                           if country == 'UK':
                               # For UK: train and test on the same data (with train/val split)
                               X, y = get_data(country='UK', age=age, disease=disease,
                                               include_charlson_bmi=include_charlson_bmi,
                                               pred_up_to_year=pred_year,
                                               include_deceased_as_zeros=include_deceased_as_zeros)

                               # Exclude MCI if requested
                               if exclude_mci_baseline and 'mci_at_baseline' in X.columns:
                                   mask = X['mci_at_baseline'] == 0
                                   X = X[mask]
                                   y = y[mask]
                                   X = X.drop('mci_at_baseline', axis=1)
                                   print(f'Removed {(~mask).sum()} MCI patients at baseline')

                               # Split data
                               X_train, X_test, y_train, y_test = split(X, y)
                               X_train = X_train.fillna(0)
                               X_test = X_test.fillna(0)

                           else:  # FR
                               # For FR: train on UK, test on FR
                               print("Training on UK data...")
                               X_train, y_train = get_data(country='UK', age=age, disease=disease,
                                                           include_charlson_bmi=include_charlson_bmi,
                                                           pred_up_to_year=pred_year,
                                                           include_deceased_as_zeros=include_deceased_as_zeros)

                               if exclude_mci_baseline and 'mci_at_baseline' in X_train.columns:
                                   mask_train = X_train['mci_at_baseline'] == 0
                                   X_train = X_train[mask_train]
                                   y_train = y_train[mask_train]
                                   X_train = X_train.drop('mci_at_baseline', axis=1)

                               X_train = X_train.fillna(0)

                               print("Testing on FR data...")
                               X_test, y_test = get_data(country='FR', age=age, disease=disease,
                                                       include_charlson_bmi=include_charlson_bmi,
                                                       pred_up_to_year=pred_year,
                                                       include_deceased_as_zeros=include_deceased_as_zeros)

                               if exclude_mci_baseline and 'mci_at_baseline' in X_test.columns:
                                   mask_test = X_test['mci_at_baseline'] == 0
                                   X_test = X_test[mask_test]
                                   y_test = y_test[mask_test]
                                   X_test = X_test.drop('mci_at_baseline', axis=1)

                               X_test = X_test.fillna(0)

                           print(f"Train set: {len(X_train)} patients ({y_train.sum()} cases)")
                           print(f"Test set: {len(X_test)} patients ({y_test.sum()} cases)")

                           # Train model
                           model = LogisticRegression(class_weight='balanced', random_state=SEED, max_iter=1_000)
                           model.fit(X_train, y_train)

                           # Get predictions
                           y_prob = model.predict_proba(X_test)[:, 1]

                           # Store data for screening analysis (for ALL diseases on UK validation set)
                           if country == 'UK':
                               screening_results.append({
                                   'disease': disease,
                                   'pred_year': pred_year,
                                   'y_test': y_test,
                                   'y_prob': y_prob,
                                   'n_patients': len(y_test),
                                   'n_cases': y_test.sum()
                               })

                           # Calculate ROC AUC with CI
                           roc_auc_score, roc_lower_ci, roc_upper_ci = get_auc_with_ci(y_prob, y_test, ci=0.95)

                           # Calculate Brier score for calibration
                           from sklearn.metrics import brier_score_loss
                           brier_score = brier_score_loss(y_test, y_prob)

                           # Calculate detection rate for 5% FPR
                           def bootstrap_metric(y_true, y_scores, metric_func, n_bootstrap=n_bootstrap):
                               """Bootstrap confidence intervals for custom metrics"""
                               np.random.seed(SEED)
                               bootstrap_values = []
                               n_samples = len(y_true)

                               for _ in range(n_bootstrap):
                                   # Bootstrap sample
                                   indices = np.random.choice(n_samples, n_samples, replace=True)
                                   y_boot = y_true.iloc[indices] if hasattr(y_true, 'iloc') else y_true[indices]
                                   scores_boot = y_scores[indices]

                                   try:
                                       value = metric_func(y_boot, scores_boot)
                                       if not np.isnan(value):
                                           bootstrap_values.append(value)
                                   except:
                                       continue

                               if len(bootstrap_values) > 0:
                                   return (np.mean(bootstrap_values),
                                           np.percentile(bootstrap_values, 2.5),
                                           np.percentile(bootstrap_values, 97.5))
                               else:
                                   return (np.nan, np.nan, np.nan)

                           # Detection rate at 5% FPR
                           def detection_rate_at_5pct_fpr(y_true, y_scores):
                               fpr, tpr, thresholds = roc_curve(y_true, y_scores)
                               # Find threshold for 5% FPR
                               target_fpr = 0.05
                               idx = np.argmax(fpr >= target_fpr)
                               if idx > 0:
                                   return tpr[idx] * 100  # Convert to percentage
                               return 0

                           # FPR at 50% detection rate
                           def fpr_at_50pct_detection(y_true, y_scores):
                               fpr, tpr, thresholds = roc_curve(y_true, y_scores)
                               # Find threshold for 50% TPR
                               target_tpr = 0.5
                               idx = np.argmax(tpr >= target_tpr)
                               if idx < len(fpr):
                                   return fpr[idx] * 100  # Convert to percentage
                               return 100

                           # Calculate metrics with bootstrap CIs
                           det_rate_mean, det_rate_lower, det_rate_upper = bootstrap_metric(
                               y_test, y_prob, detection_rate_at_5pct_fpr)

                           fpr_mean, fpr_lower, fpr_upper = bootstrap_metric(
                               y_test, y_prob, fpr_at_50pct_detection)

                           # Format results
                           roc_auc_str = f"{roc_auc_score:.2f} ({roc_lower_ci:.2f}-{roc_upper_ci:.2f})"
                           brier_str = f"{brier_score:.3f}"
                           det_rate_str = f"{det_rate_mean:.1f}% ({det_rate_lower:.1f}%-{det_rate_upper:.1f}%)"
                           fpr_str = f"{fpr_mean:.1f}% ({fpr_lower:.1f}%-{fpr_upper:.1f}%)"

                           if country == 'UK':
                               row_data['ROC AUC on UK'] = roc_auc_str
                               row_data['Brier Score on UK'] = brier_str
                               row_data['Detection rate for 5% FPR'] = det_rate_str
                               row_data['FPR for 50% detection rate'] = fpr_str

                               # Calculate precision of top 1% and lift based on actual dataset prevalence
                               dataset_prevalence = y_test.sum() / len(y_test)  # Actual prevalence in test set

                               n_patients = len(y_test)
                               top_1_percent_size = max(1, int(0.01 * n_patients))  # At least 1 patient

                               # Sort patients by prediction score (highest first)
                               sorted_indices = np.argsort(y_prob)[::-1]
                               y_sorted = y_test.iloc[sorted_indices] if hasattr(y_test, 'iloc') else y_test[sorted_indices]

                               # Get top 1% patients
                               top_1_percent_patients = y_sorted[:top_1_percent_size]
                               top_1_percent_cases = top_1_percent_patients.sum()

                               # Calculate precision of top 1%
                               precision_top_1_percent = top_1_percent_cases / top_1_percent_size if top_1_percent_size > 0 else 0

                               # Calculate lift compared to actual dataset prevalence
                               lift = precision_top_1_percent / dataset_prevalence if dataset_prevalence > 0 else 0

                               row_data['Precision top 1%'] = f"{precision_top_1_percent*100:.1f}%"
                               row_data['Lift (vs dataset prevalence)'] = f"{lift:.1f}x"

                           else:
                               row_data['ROC AUC on FR'] = roc_auc_str
                               row_data['Brier Score on FR'] = brier_str

                           print(f"{country} Results:")
                           print(f"  ROC AUC: {roc_auc_str}")
                           print(f"  Brier Score: {brier_str}")
                           if country == 'UK':
                               print(f"  Detection rate at 5% FPR: {det_rate_str}")
                               print(f"  FPR at 50% detection: {fpr_str}")

                       except Exception as e:
                           print(f"Error processing {country}: {str(e)}")
                           if country == 'UK':
                               row_data['ROC AUC on UK'] = 'N/A'
                               row_data['Brier Score on UK'] = 'N/A'
                               row_data['Detection rate for 5% FPR'] = 'N/A'
                               row_data['FPR for 50% detection rate'] = 'N/A'
                               row_data['Precision top 1%'] = 'N/A'
                               row_data['Lift (vs dataset prevalence)'] = 'N/A'
                           else:
                               row_data['ROC AUC on FR'] = 'N/A'
                               row_data['Brier Score on FR'] = 'N/A'

                   results.append(row_data)

           # Create DataFrame
           df = pd.DataFrame(results)

           # Reorder columns to match the desired format
           column_order = ['Disease', 'Prediction up to year', 'ROC AUC on UK', 'Brier Score on UK',
                           'Detection rate for 5% FPR', 'FPR for 50% detection rate',
                           'ROC AUC on FR', 'Brier Score on FR',
                           'Precision top 1%', 'Lift (vs dataset prevalence)']
           df = df[column_order]

           # Calculate screening requirements for 80% detection rate for ALL diseases
           if screening_results:  # Only if we have screening data
               print("\n" + "="*100)
               print("SCREENING REQUIREMENTS FOR 80% DETECTION RATE")
               print("="*100)

               for screen_data in screening_results:
                   disease = screen_data['disease']
                   pred_year = screen_data['pred_year']
                   y_test = screen_data['y_test']
                   y_prob = screen_data['y_prob']
                   n_patients = screen_data['n_patients']
                   n_cases = screen_data['n_cases']

                   # Calculate threshold for 80% detection rate
                   fpr, tpr, thresholds = roc_curve(y_test, y_prob)
                   target_tpr = 0.80

                   # Find the index where TPR >= 80%
                   idx = np.argmax(tpr >= target_tpr)

                   disease_name = 'Dementia' if disease == 'all_dementias' else disease.capitalize()

                   if idx < len(thresholds) and tpr[idx] >= target_tpr:
                       threshold_80 = thresholds[idx]
                       fpr_80 = fpr[idx]
                       tpr_80 = tpr[idx]

                       # Calculate how many patients need to be screened
                       prevalence = n_cases / n_patients

                       # Number of patients needed to screen per case detected at 80% sensitivity
                       patients_to_screen_per_case = 1 / (tpr_80 * prevalence) if (tpr_80 * prevalence) > 0 else float('inf')

                       print(f"\n{disease_name} - {pred_year}-year prediction:")
                       print(f"  Threshold for 80% detection: {threshold_80:.3f}")
                       print(f"  Sensitivity (TPR): {tpr_80*100:.1f}%")
                       print(f"  FPR: {fpr_80*100:.1f}%")
                       print(f"  Prevalence in test set: {prevalence*100:.2f}% ({n_cases}/{n_patients})")
                       print(f"  Number needed to screen: {patients_to_screen_per_case:.0f} patients per case detected")
                   else:
                       print(f"\n{disease_name} - {pred_year}-year prediction:")
                       print(f"  Cannot achieve 80% detection rate with available data")
                       print(f"  Maximum achievable TPR: {max(tpr)*100:.1f}%")

           return df



# Build the UK/FR performance table. The bootstrap confidence intervals
# (n_bootstrap resamples per metric) make this cell slow to run.
print("Generating comprehensive performance table...")
print("This will take several minutes due to bootstrap calculations...")

PERFORMANCE_TABLE_SETTINGS = dict(
    diseases=['all_dementias', 'alzheimer'],
    prediction_years=[2, 5, 10],
    age=65,
    include_charlson_bmi=True,
    exclude_mci_baseline=False,
    include_deceased_as_zeros=True,
    n_bootstrap=1000,
)
performance_table = generate_performance_table(**PERFORMANCE_TABLE_SETTINGS)

banner = "=" * 100
print("\n" + banner)
print("COMPREHENSIVE PERFORMANCE TABLE")
print(banner)
display(performance_table)

# Persist the table as CSV in the results directory.
output_file = os.path.join(OUTPUT_DIR, 'performance_table.csv')
performance_table.to_csv(output_file, index=False)
print(f"\nTable saved to: {output_file}")