# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
import warnings

# Suppress every Python warning for the remainder of the run.
warnings.filterwarnings('ignore')

# ==================== Data Loading ====================
df = pd.read_csv("./data/roc_data.csv")

# Predictor sets: the protein panel alone, and the panel plus covariates
proteins = ['APOM', 'GDF15', 'GHRL', 'LGALS4']
covariates = ['age', 'sex', 'race', 'edu', 'Townsend.index', 'BMI']
all_features = [*proteins, *covariates]

# Complete-case analysis: keep only rows with no missing value in any needed column
outcome_cols = ['Dementia_type', 'Dementia_interval', 'APOEe4_carrier']
data = df[all_features + outcome_cols].dropna().copy()

# Convert the diagnosis interval to years (presumably recorded in days -- confirm
# against the data dictionary).
data['years_to_dx'] = data['Dementia_interval'].div(365.25)
# Rows with Dementia_type == 0 get a fixed placeholder of 20 years, which lies
# beyond the 1-15 year horizons analysed below.
data.loc[data['Dementia_type'].eq(0), 'years_to_dx'] = 20

print(f"Total sample size: {len(data)}")
for carrier_status in (0, 1):
    print(f"APOEe4_carrier = {carrier_status}: {data['APOEe4_carrier'].eq(carrier_status).sum()}")
print(f"Number of dementia cases: {data['Dementia_type'].eq(1).sum()}")

# ==================== Function to compute time-dependent AUC ====================
def compute_dynamic_auc(subset_data, feature_cols, years_before=np.arange(1, 16),
                        n_boot=2000, seed=42):
    """
    Compute time-dependent AUC and its bootstrap confidence intervals.

    One logistic regression is fitted on ``subset_data[feature_cols]`` against
    ``Dementia_type``; the second column of ``predict_proba`` is then used as a
    fixed risk score for every horizon.

    NOTE: the score is evaluated on the same rows the model was fitted on, so the
    reported AUCs are apparent (in-sample) estimates.

    For each horizon ``t`` in ``years_before`` the rows are split into:

    * cases    -- ``Dementia_type == 1`` and ``years_to_dx <= t``
    * controls -- ``Dementia_type == 0`` or ``years_to_dx > t``
    * neither  -- any other row (e.g. a different ``Dementia_type`` code diagnosed
      within ``t`` years); these are excluded from that horizon instead of being
      silently scored as controls.

    Parameters
    ----------
    subset_data : pandas.DataFrame
        Must contain ``feature_cols`` plus ``Dementia_type`` and ``years_to_dx``.
    feature_cols : list of str
        Predictor columns for the logistic regression.
    years_before : iterable of numbers
        Horizons (in years) at which the AUC is evaluated.
    n_boot : int
        Number of bootstrap resamples per horizon.
    seed : int
        Seed for the bootstrap random generator.

    Returns
    -------
    aucs, low, high : numpy.ndarray
        AUC and 2.5th / 97.5th bootstrap percentiles per horizon. Entries are NaN
        when a horizon has fewer than 10 cases or 10 controls; the interval is
        also NaN when 100 or fewer bootstrap resamples contained both classes.
    n_cases_list, n_controls_list : list
        Number of cases and controls at each horizon.
    """
    X = subset_data[feature_cols]
    y = subset_data['Dementia_type']

    # Fit once; the same risk score is reused for every horizon
    model = LogisticRegression(max_iter=1000, random_state=42)
    model.fit(X, y)
    risk_score = model.predict_proba(X)[:, 1]

    years_to_dx = subset_data['years_to_dx'].values
    dementia_type = subset_data['Dementia_type'].values

    aucs = []
    cis = []
    n_cases_list = []
    n_controls_list = []
    rng = np.random.default_rng(seed)

    for t in years_before:
        # Define cases and controls; the two masks are mutually exclusive
        case = (dementia_type == 1) & (years_to_dx <= t)
        control = (dementia_type == 0) | (years_to_dx > t)

        # Evaluate only rows that are a case or a control at this horizon
        eligible = case | control
        y_bin = case[eligible].astype(int)
        y_score = risk_score[eligible]

        # Check sample size
        n_case = y_bin.sum()
        n_control = (1 - y_bin).sum()
        n_cases_list.append(n_case)
        n_controls_list.append(n_control)

        if n_case < 10 or n_control < 10:
            aucs.append(np.nan)
            cis.append((np.nan, np.nan))
            continue

        aucs.append(roc_auc_score(y_bin, y_score))

        # Percentile bootstrap over the evaluated rows
        n_eval = len(y_bin)
        boot_aucs = []
        for _ in range(n_boot):
            idx = rng.choice(n_eval, n_eval, replace=True)
            resampled = y_bin[idx]
            # AUC is undefined when a resample contains a single class
            if len(np.unique(resampled)) < 2:
                continue
            boot_aucs.append(roc_auc_score(resampled, y_score[idx]))

        # Require a minimum number of usable resamples before reporting an interval
        if len(boot_aucs) > 100:
            cis.append((np.percentile(boot_aucs, 2.5), np.percentile(boot_aucs, 97.5)))
        else:
            cis.append((np.nan, np.nan))

    aucs = np.array(aucs)
    low = np.array([ci[0] for ci in cis])
    high = np.array([ci[1] for ci in cis])

    return aucs, low, high, n_cases_list, n_controls_list

# ==================== Function to create AUC result table ====================
def create_auc_table(years, aucs_protein, low_p, high_p, aucs_full, low_f, high_f,
                     n_cases_p, n_controls_p, population_name):
    """
    Assemble one row per time point with the AUC results of both models.

    The returned DataFrame holds the population label, year, case/control
    counts, and the AUC with its CI bounds for the protein-only and the full
    model. Two extra columns, ``Max_Protein`` and ``Max_Full``, are empty
    strings except for a '★' on the row with the highest non-missing AUC of
    the respective model.
    """
    records = [
        {
            'Population': population_name,
            'Year': yr,
            'N_Cases': n_cases_p[pos],
            'N_Controls': n_controls_p[pos],
            'AUC_Protein': aucs_protein[pos],
            'CI_Low_Protein': low_p[pos],
            'CI_High_Protein': high_p[pos],
            'AUC_Full': aucs_full[pos],
            'CI_Low_Full': low_f[pos],
            'CI_High_Full': high_f[pos],
        }
        for pos, yr in enumerate(years)
    ]
    table = pd.DataFrame(records)

    # Flag the best time point of each model, ignoring missing AUCs
    for auc_col, flag_col in (('AUC_Protein', 'Max_Protein'), ('AUC_Full', 'Max_Full')):
        scores = table[auc_col]
        observed = scores[scores.notna()]
        table[flag_col] = ''
        if not observed.empty:
            table.loc[observed.idxmax(), flag_col] = '★'

    return table

def format_auc_ci(auc, low, high):
    """Render an AUC and its interval as 'AUC (low-high)' with 3 decimals; 'N/A' if the AUC is missing."""
    return "N/A" if pd.isna(auc) else "{:.3f} ({:.3f}-{:.3f})".format(auc, low, high)

def create_summary_table(df_table):
    """
    Build the display version of an AUC result table.

    Works on a copy of ``df_table``. For each model the AUC and its CI bounds
    are rendered with ``format_auc_ci`` and the model's max-marker column is
    appended to the text. Only the population, year, counts and the two
    formatted model columns are returned.
    """
    def _labeller(auc_col, low_col, high_col, flag_col):
        # Bind the column names now so each row function reads the right model
        def _label(row):
            return format_auc_ci(row[auc_col], row[low_col], row[high_col]) + row[flag_col]
        return _label

    # (output column, AUC column, lower CI, upper CI, max-marker column)
    column_specs = (
        ('Protein_Model', 'AUC_Protein', 'CI_Low_Protein', 'CI_High_Protein', 'Max_Protein'),
        ('Full_Model', 'AUC_Full', 'CI_Low_Full', 'CI_High_Full', 'Max_Full'),
    )

    formatted = df_table.copy()
    for target, *source_cols in column_specs:
        formatted[target] = formatted.apply(_labeller(*source_cols), axis=1)

    keep = ['Population', 'Year', 'N_Cases', 'N_Controls', 'Protein_Model', 'Full_Model']
    return formatted[keep]

# ==================== Plotting Function ====================
def plot_dual_auc_curves(years, aucs_protein, low_p, high_p, 
                         aucs_full, low_f, high_f,
                         title, filename_prefix, show_legend=True):
    """
    Plot the time-dependent AUC curves of both models and save the figure.

    Each model is drawn as a line with a shaded band between its lower and
    upper CI bounds. The figure is written to ``./plot/<filename_prefix>.png``
    and ``./plot/<filename_prefix>.pdf``; the output directory is created if
    it does not exist yet. The figure is then shown.

    Parameters
    ----------
    years : sequence
        X-axis values (years before diagnosis).
    aucs_protein, low_p, high_p : sequence
        AUC and CI bounds of the protein-only model, aligned with ``years``.
    aucs_full, low_f, high_f : sequence
        AUC and CI bounds of the protein + covariates model.
    title : str
        Axes title.
    filename_prefix : str
        File name without extension, relative to ``./plot/``.
    show_legend : bool
        Whether to draw the legend.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn.
    """
    import os  # local import keeps this block self-contained

    fig, ax = plt.subplots(figsize=(9, 6))

    # Protein-only model - orange
    ax.plot(years, aucs_protein, color='#D55E00', linewidth=3, label='4-protein panel')
    ax.fill_between(years, low_p, high_p, color='#D55E00', alpha=0.2)

    # Protein + covariates model - blue
    ax.plot(years, aucs_full, color='#0072B2', linewidth=3, label='4-protein + covariates')
    ax.fill_between(years, low_f, high_f, color='#0072B2', alpha=0.2)

    # Reference lines at AUC 0.7 and 0.8
    ax.axhline(0.7, color='gray', linewidth=1, linestyle='-', alpha=0.6)
    ax.axhline(0.8, color='gray', linewidth=0.8, linestyle='--', alpha=0.5)

    # Axis settings
    ax.set_ylim(0.50, 0.94)
    ax.set_xlim(1, 15)
    ax.set_xlabel('Years before dementia diagnosis', fontsize=13, labelpad=10)
    ax.set_ylabel('Time-dependent AUC', fontsize=13, labelpad=10)
    ax.set_title(title, fontsize=14, pad=20)

    # Minimalist style
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.grid(True, axis='y', alpha=0.3, linewidth=0.8)

    if show_legend:
        ax.legend(loc='lower right', fontsize=11, frameon=False)

    fig.tight_layout()

    png_path = f"./plot/{filename_prefix}.png"
    pdf_path = f"./plot/{filename_prefix}.pdf"
    # Create the target directory up front so saving cannot fail with
    # FileNotFoundError after the analysis has already been computed.
    os.makedirs(os.path.dirname(png_path), exist_ok=True)

    # Save through the figure handle rather than pyplot's implicit current figure
    fig.savefig(png_path, dpi=400, bbox_inches='tight')
    fig.savefig(pdf_path, format='pdf', bbox_inches='tight', dpi=600,
                transparent=False, metadata={'Creator': 'matplotlib', 'Producer': None})
    plt.show()

    return fig

# ==================== Main Analysis ====================
years_before = np.arange(1, 16)
all_tables = []  # Store all table results


def _print_section(heading):
    """Print a section heading framed by two 60-character rules."""
    print("\n" + "="*60)
    print(heading)
    print("="*60)


def _run_population_analysis(cohort, population_name, plot_title, filename_prefix):
    """
    Run the protein-only and full model on ``cohort``, then tabulate and plot.

    The result table is appended to ``all_tables``. Returns
    ``(protein_result, full_result, table)`` where the two results are the
    unmodified 5-tuples returned by ``compute_dynamic_auc``.
    """
    protein_result = compute_dynamic_auc(cohort, proteins, years_before)
    full_result = compute_dynamic_auc(cohort, all_features, years_before)

    # First three entries of each result: AUC, lower CI, upper CI
    protein_curve, full_curve = protein_result[:3], full_result[:3]

    # Case/control counts are taken from the protein-only run
    table = create_auc_table(
        years_before, *protein_curve, *full_curve,
        protein_result[3], protein_result[4],
        population_name
    )
    all_tables.append(table)

    plot_dual_auc_curves(years_before, *protein_curve, *full_curve, plot_title, filename_prefix)
    return protein_result, full_result, table


# 1. Overall population analysis
_print_section("1. Overall Population Analysis")

_protein_all, _full_all, table_all = _run_population_analysis(
    data, 'Overall',
    'Predictive performance: Overall population',
    'dynamic_AUC_overall'
)
aucs_protein_all, low_p_all, high_p_all, n_cases_all, n_controls_all = _protein_all
aucs_full_all, low_f_all, high_f_all, _, _ = _full_all

# 2. APOEe4 non-carriers (APOEe4_carrier = 0) subgroup
_print_section("2. APOEε4 Non-carriers (APOEe4_carrier = 0) Analysis")

data_apoe0 = data[data['APOEe4_carrier'] == 0].copy()
print(f"Sample size: {len(data_apoe0)}, Number of cases: {(data_apoe0['Dementia_type']==1).sum()}")

_protein_apoe0, _full_apoe0, table_apoe0 = _run_population_analysis(
    data_apoe0, 'APOEe4_Non-carriers',
    'Predictive performance: APOEε4 non-carriers',
    'dynamic_AUC_APOEe4_noncarrier'
)
aucs_protein_apoe0, low_p_apoe0, high_p_apoe0, n_cases_apoe0, n_controls_apoe0 = _protein_apoe0
aucs_full_apoe0, low_f_apoe0, high_f_apoe0, _, _ = _full_apoe0

# 3. APOEe4 carriers (APOEe4_carrier = 1) subgroup
_print_section("3. APOEε4 Carriers (APOEe4_carrier = 1) Analysis")

data_apoe1 = data[data['APOEe4_carrier'] == 1].copy()
print(f"Sample size: {len(data_apoe1)}, Number of cases: {(data_apoe1['Dementia_type']==1).sum()}")

_protein_apoe1, _full_apoe1, table_apoe1 = _run_population_analysis(
    data_apoe1, 'APOEe4_Carriers',
    'Predictive performance: APOEε4 carriers',
    'dynamic_AUC_APOEe4_carrier'
)
aucs_protein_apoe1, low_p_apoe1, high_p_apoe1, n_cases_apoe1, n_controls_apoe1 = _protein_apoe1
aucs_full_apoe1, low_f_apoe1, high_f_apoe1, _, _ = _full_apoe1

# ==================== Combine and Export Tables ====================
for banner_line in ("\n" + "="*60, "4. Exporting AUC Result Tables", "="*60):
    print(banner_line)

# Stack the per-population tables into one long table
combined_table = pd.concat(all_tables, ignore_index=True)

# Save raw detailed table
raw_csv_path = './data/dynamic_AUC_results_raw.csv'
combined_table.to_csv(raw_csv_path, index=False)
print(f"\nRaw detailed table saved: {raw_csv_path}")

# Create and save the formatted summary table
summary_csv_path = './data/dynamic_AUC_results_summary.csv'
summary_table = create_summary_table(combined_table)
summary_table.to_csv(summary_csv_path, index=False)
print(f"Formatted summary table saved: {summary_csv_path}")

# Display summary table in console
wide_rule = "="*80
thin_rule = "-"*80
print("\n" + wide_rule)
print("                    Time-dependent AUC Summary Table (★ = maximum AUC)")
print(wide_rule)

# One left-aligned, fixed-width template shared by the header and the data rows
row_template = "{:<6}{:<10}{:<12}{:<28}{:<28}"

for pop in ['Overall', 'APOEe4_Non-carriers', 'APOEe4_Carriers']:
    pop_data = summary_table[summary_table['Population'] == pop]
    print(f"\n【{pop}】")
    print(thin_rule)
    print(row_template.format('Year', 'N_Cases', 'N_Controls', 'Protein Model', 'Full Model'))
    print(thin_rule)
    for _, row in pop_data.iterrows():
        print(row_template.format(
            row['Year'], row['N_Cases'], row['N_Controls'],
            row['Protein_Model'], row['Full_Model']
        ))

# ==================== Maximum AUC Summary Table ====================
for banner_line in ("\n" + "="*60, "5. Maximum AUC Summary", "="*60):
    print(banner_line)

max_auc_summary = []

# (display label, AUC column, lower-CI column, upper-CI column) for each model;
# the protein panel comes first so its row precedes the full model's row.
model_columns = [
    ('4-Protein Panel', 'AUC_Protein', 'CI_Low_Protein', 'CI_High_Protein'),
    ('4-Protein + Covariates', 'AUC_Full', 'CI_Low_Full', 'CI_High_Full'),
]

for pop in ['Overall', 'APOEe4_Non-carriers', 'APOEe4_Carriers']:
    pop_data = combined_table[combined_table['Population'] == pop]

    for model_label, auc_col, low_col, high_col in model_columns:
        observed = pop_data[auc_col].notna()
        if not observed.any():
            # No usable AUC for this model in this population
            continue

        # Row of the time point with the highest non-missing AUC
        best_row = pop_data.loc[pop_data.loc[observed, auc_col].idxmax()]
        max_auc_summary.append({
            'Population': pop,
            'Model': model_label,
            'Best_Year': int(best_row['Year']),
            'Max_AUC': best_row[auc_col],
            'CI_Low': best_row[low_col],
            'CI_High': best_row[high_col],
            'N_Cases': int(best_row['N_Cases']),
            'N_Controls': int(best_row['N_Controls'])
        })


def _auc_with_ci(record):
    """Render one summary row as 'AUC (low-high)' with three decimals."""
    return f"{record['Max_AUC']:.3f} ({record['CI_Low']:.3f}-{record['CI_High']:.3f})"


max_auc_df = pd.DataFrame(max_auc_summary)
max_auc_df['AUC_with_CI'] = max_auc_df.apply(_auc_with_ci, axis=1)

# Save max AUC summary
max_csv_path = './data/dynamic_AUC_max_summary.csv'
max_auc_df.to_csv(max_csv_path, index=False)
print(f"\nMaximum AUC summary table saved: {max_csv_path}")

# Print max AUC summary
summary_rule = "-"*90
summary_template = "{:<25}{:<25}{:<12}{:<25}"
print("\n" + summary_rule)
print(summary_template.format('Population', 'Model', 'Best Year', 'Max AUC (95% CI)'))
print(summary_rule)
for _, row in max_auc_df.iterrows():
    print(summary_template.format(
        row['Population'], row['Model'], row['Best_Year'], row['AUC_with_CI']
    ))
print(summary_rule)

# ==================== Combined Plot ====================
for banner_line in ("\n" + "="*60, "6. Generating Combined Plot", "="*60):
    print(banner_line)

fig, axes = plt.subplots(1, 3, figsize=(18, 5.5))

# One entry per panel: protein AUC/CI, full-model AUC/CI, panel title, panel letter
plot_params = [
    (aucs_protein_all, low_p_all, high_p_all, aucs_full_all, low_f_all, high_f_all, 
     'Overall population', 'A'),
    (aucs_protein_apoe0, low_p_apoe0, high_p_apoe0, aucs_full_apoe0, low_f_apoe0, high_f_apoe0, 
     'APOEε4 non-carriers', 'B'),
    (aucs_protein_apoe1, low_p_apoe1, high_p_apoe1, aucs_full_apoe1, low_f_apoe1, high_f_apoe1, 
     'APOEε4 carriers', 'C')
]

# (colour, legend label) of the protein-only and the full model, in drawing order
series_styles = (
    ('#D55E00', '4-protein panel'),
    ('#0072B2', '4-protein + covariates'),
)

for ax, panel_spec in zip(axes, plot_params):
    *curves, title, panel = panel_spec

    # curves holds two consecutive (AUC, lower CI, upper CI) triplets
    for (colour, label), start in zip(series_styles, (0, 3)):
        auc_line, ci_low, ci_high = curves[start:start + 3]
        ax.plot(years_before, auc_line, color=colour, linewidth=2.5, label=label)
        ax.fill_between(years_before, ci_low, ci_high, color=colour, alpha=0.2)

    # Reference lines at AUC 0.7 (solid) and 0.8 (dashed)
    for level, width, dash, opacity in ((0.7, 1, '-', 0.6), (0.8, 0.8, '--', 0.5)):
        ax.axhline(level, color='gray', linewidth=width, linestyle=dash, alpha=opacity)

    ax.set_ylim(0.50, 0.94)
    ax.set_xlim(1, 15)
    ax.set_xlabel('Years before diagnosis', fontsize=12)
    ax.set_ylabel('Time-dependent AUC', fontsize=12)
    ax.set_title(f'{panel}. {title}', fontsize=13, pad=10, loc='left', fontweight='bold')

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis='both', which='major', labelsize=11)
    ax.grid(True, axis='y', alpha=0.3, linewidth=0.8)

    # The legend is drawn on the first panel only
    if panel == 'A':
        ax.legend(loc='lower right', fontsize=10, frameon=False)

plt.tight_layout()
plt.savefig("./plot/dynamic_AUC_combined_panel.png", dpi=400, bbox_inches='tight')
plt.savefig("./plot/dynamic_AUC_combined_panel.pdf", format='pdf', bbox_inches='tight', dpi=600,
            transparent=False, metadata={'Creator': 'matplotlib', 'Producer': None})
plt.show()

# Closing banner followed by the list of files written by this script
closing_lines = (
    "\n" + "="*60,
    "All analyses completed!",
    "="*60,
    "\nGenerated files:",
    "  - dynamic_AUC_results_raw.csv          (raw detailed results)",
    "  - dynamic_AUC_results_summary.csv      (formatted summary table)",
    "  - dynamic_AUC_max_summary.csv          (maximum AUC summary)",
    "  - dynamic_AUC_overall.png/pdf",
    "  - dynamic_AUC_APOEe4_noncarrier.png/pdf",
    "  - dynamic_AUC_APOEe4_carrier.png/pdf",
    "  - dynamic_AUC_combined_panel.png/pdf",
)
for closing_line in closing_lines:
    print(closing_line)