In [1]:
import multiprocessing as mp
from pathlib import Path
import numpy as np
from sklearn.metrics import f1_score, accuracy_score, auc, roc_curve, mean_absolute_error, r2_score, roc_auc_score
from sklearn.preprocessing import label_binarize
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import re
from lifelines.utils import concordance_index
import os
import yaml

In [2]:
base_path = os.path.join(os.getcwd().split('pretrain_comparison/evaluation')[0], 'pretrain_comparison/output/results')
files_list = list(Path(base_path).rglob('*.npy'))

In [3]:
def calculate_metrics_per_condition(times, presence, hazard_ratios):
    """
    Calculate metrics for each medical condition with optional bootstrapping.
    
    Parameters
    ----------
    times : np.ndarray
        Time-to-event data, shape (n_samples,) or (n_samples, n_conditions)
    presence : np.ndarray
        Binary indicator of event occurrence, same shape as times
    hazard_ratios : np.ndarray
        Predicted hazard ratios, same shape as times
    
    Returns
    -------
    metrics : list of dicts
        Each dict contains c_index and n_events
    """
    # Ensure 2D shape
    if times.ndim == 1:
        times = times[:, np.newaxis]
        presence = presence[:, np.newaxis]
        hazard_ratios = hazard_ratios[:, np.newaxis]


    n_conditions = times.shape[1]
    metrics = []

    for i in range(n_conditions):
        condition_times = times[:, i]
        condition_presence = presence[:, i]
        condition_hazards = hazard_ratios[:, i]

        # Set presence = 0 if time <= 0.001
        condition_presence = np.where(condition_times > 0.001, condition_presence, 0)

        n_events = np.sum(condition_presence)

        try:
            valid_mask = ~(np.isnan(condition_times) | np.isnan(condition_hazards) | np.isnan(condition_presence))

            c_index = concordance_index(
                condition_times[valid_mask],
                -condition_hazards[valid_mask],
                condition_presence[valid_mask]
            )

            c_index_ci = (np.nan, np.nan)  # Placeholder


            metrics.append({
                'c_index': c_index,
                'n_events': n_events
            })
        except Exception as e:
            print(f"Error calculating metrics for condition {i}: {str(e)}")
            metrics.append({
                'c_index': np.nan,
                'n_events': n_events
            })

    return metrics


In [4]:
import numpy as np
import pandas as pd
from lifelines.utils import concordance_index
from sklearn.metrics import roc_auc_score

# Assuming `files_list` is defined and the two functions you've shared are loaded

rows = []
print(f"Found {files_list} files")

num_accepted_files = 0
num_keys = 0
num_diagnosis = 0

for file in files_list:
    if 'death_and_diagnosis_results.npy' not in str(file):
        continue
    num_accepted_files += 1
    
    name = str(file).split('/')[-2]
    print(f"Calculating death and diagnosis C-index for {name}")
    
    results = np.load(file, allow_pickle=True).item()
    if results is None:
        continue

    config_path = os.path.join(os.getcwd().split('pretrain_comparison/evaluation')[0], 'pretrain_comparison/fine_tune/config_fine_tune.yaml')
    config = yaml.safe_load(open(config_path))
    diagnosis_labels = pd.read_csv(os.path.join(config['diagnosis_labels_path'], 'is_event.csv'))
    diagnosis_list = list(diagnosis_labels.columns.values[1:])

    for key in ['death', 'diagnosis']:
        num_keys += 1
        times = results[key]['times']
        presence = results[key]['presence']
        hazards = results[key]['hazards']

        metrics = calculate_metrics_per_condition(times, presence, hazards)
        
        for i, m in enumerate(metrics):
            if len(metrics) != 1 and len(metrics) != 12:
                print(f"Warning: Number of metrics {len(metrics)} does not match number of conditions, that is either 1 or 12")
            num_diagnosis += 1
            diagnosis_name = diagnosis_list[i] if len(metrics) == 12 else 'death'
            row = {
                'name': name,
                'condition_idx': diagnosis_name,
                'c_index': m['c_index'],
                'n_events': m['n_events'],
            }
            rows.append(row)

df = pd.DataFrame(rows, columns=[
    'name', 'condition_idx', 'c_index', 'n_events'
])
# Pivot so that each condition becomes a column with its c_index
df_pivoted = df.pivot_table(index='name', columns='condition_idx', values='c_index')

# Optionally, reset the index if you want 'name' as a column
df_pivoted = df_pivoted.reset_index()
df_pivoted['overall'] = df_pivoted.iloc[:, 1:].mean(axis=1)
df_pivoted['overall diagnosis'] = df_pivoted.iloc[:, 1:-2].mean(axis=1)
df_pivoted.to_csv(os.path.join(base_path, 'death_and_diagnosis_results.csv'), index=False)


Found [PosixPath('/oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/results/CL_pairwise_epochs_36/sleep_stage_and_age_results.npy'), PosixPath('/oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/results/CL_pairwise_epochs_36/ahi_diagnosis_results.npy'), PosixPath('/oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/results/CL_pairwise_epochs_36/death_and_diagnosis_results.npy')] files
Calculating death and diagnosis C-index for CL_pairwise_epochs_36
