In [1]:
import numpy as np
import pandas as pd
import re
import os.path

from sklearn.tree import plot_tree
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored , concordance_index_ipcw
from sklearn.impute import SimpleImputer
from sksurv.util import Surv
from lifelines.utils import concordance_index
from sklearn.preprocessing import StandardScaler

# Plotly imports (instead of matplotlib)
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [2]:
# Load the training and validation tables; the relative paths are resolved
# against the kernel's current working directory.
DATA_DIR = "../../data"
df_train = pd.read_csv(f"{DATA_DIR}/train_enhanced.csv")
df_val = pd.read_csv(f"{DATA_DIR}/val_enhanced.csv")

In [3]:
# Data exploration: mutation columns, cohort size, outcome balance and follow-up time.
# Every column whose name starts with "Gene_" is treated as one gene feature.
gene_cols = df_train.columns[df_train.columns.str.startswith("Gene_")].tolist()
n_patients = df_train.shape[0]
print(f"Number of genes: {len(gene_cols)}")
print(f"Number of patients: {n_patients}")

status_counts = df_train['OS_STATUS'].value_counts()
print("\nOS_STATUS distribution (1=dead, 0=censored):")
print(status_counts)

followup_summary = df_train['OS_YEARS'].describe()
print("\nOS_YEARS statistics:")
print(followup_summary)

Number of genes: 70
Number of patients: 3173

OS_STATUS distribution (1=dead, 0=censored):
OS_STATUS
True     1600
False    1573
Name: count, dtype: int64

OS_YEARS statistics:
count    3173.000000
mean        2.480713
std         2.588259
min         0.000000
25%         0.652055
50%         1.652055
75%         3.572603
max        22.043836
Name: OS_YEARS, dtype: float64


# Gene Ranking for Survival Prediction

## Statistical methods used:

1. **Log-Rank Test**: Compares survival curves between patients with vs without mutation. This is the standard non-parametric test in survival analysis.

2. **Univariate Cox Model**: Estimates the Hazard Ratio (HR) for each gene. HR > 1 = increased risk of death.

3. **Bonferroni** and **FDR (Benjamini-Hochberg)** corrections: Control for multiple testing (70 tests).

4. **Concordance Index (C-index)**: Measures the discriminative ability of each gene.

5. **Bootstrap**: Estimates the stability of the results.

In [4]:
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.statistics import logrank_test, multivariate_logrank_test
from scipy import stats
from statsmodels.stats.multitest import multipletests
import warnings
warnings.filterwarnings('ignore')

# Minimum number of patients required in BOTH arms (mutated / non-mutated)
# for a gene to be tested at all.
MIN_GROUP_SIZE = 10

# Prepare survival data
y_train = Surv.from_dataframe('OS_STATUS', 'OS_YEARS', df_train)


def _km_median_survival(group):
    """Return the Kaplan-Meier median survival time (years) of one patient group.

    Unlike the median of the observed death times, the Kaplan-Meier estimate
    accounts for censored patients. lifelines reports ``inf`` when the survival
    curve never drops to 0.5 (median not reached). An empty group yields NaN.
    """
    if len(group) == 0:
        return np.nan
    kmf = KaplanMeierFitter()
    kmf.fit(group['OS_YEARS'], event_observed=group['OS_STATUS'])
    return kmf.median_survival_time_


# One dict per analysed gene; turned into a DataFrame in the next cell.
results = []
# (gene, step, exception) for every fit that failed and fell back to neutral values.
fit_failures = []

print("Running analysis for each gene...")
for gene in gene_cols:
    # Groups: mutated (1) vs non-mutated (0)
    group_0 = df_train[df_train[gene] == 0]
    group_1 = df_train[df_train[gene] == 1]

    n_mutated = len(group_1)
    n_not_mutated = len(group_0)

    # Skip genes for which one arm is too small to give a meaningful comparison.
    if n_mutated < MIN_GROUP_SIZE or n_not_mutated < MIN_GROUP_SIZE:
        continue

    # 1. Log-Rank Test
    # `except Exception` (not a bare `except`) so that KeyboardInterrupt still
    # stops the loop; failures are recorded and reported after the loop.
    try:
        lr_result = logrank_test(
            group_1['OS_YEARS'], group_0['OS_YEARS'],
            group_1['OS_STATUS'], group_0['OS_STATUS']
        )
        logrank_pvalue = lr_result.p_value
        logrank_stat = lr_result.test_statistic
    except Exception as exc:
        fit_failures.append((gene, 'log-rank', exc))
        logrank_pvalue = 1.0
        logrank_stat = 0.0

    # 2. Univariate Cox model
    try:
        cox_df = df_train[['OS_YEARS', 'OS_STATUS', gene]].copy()
        cox_df['OS_STATUS'] = cox_df['OS_STATUS'].astype(int)
        cph = CoxPHFitter()
        cph.fit(cox_df, duration_col='OS_YEARS', event_col='OS_STATUS')

        cox_coef = cph.params_[gene]
        hazard_ratio = np.exp(cox_coef)
        cox_pvalue = cph.summary.loc[gene, 'p']
        # The confidence interval is on the log-hazard scale; exponentiate to get HR bounds.
        cox_ci_lower = np.exp(cph.confidence_intervals_.loc[gene, '95% lower-bound'])
        cox_ci_upper = np.exp(cph.confidence_intervals_.loc[gene, '95% upper-bound'])
    except Exception as exc:
        fit_failures.append((gene, 'cox', exc))
        hazard_ratio = 1.0
        cox_pvalue = 1.0
        cox_ci_lower = 1.0
        cox_ci_upper = 1.0
        cox_coef = 0.0

    # 3. C-index for this gene alone
    try:
        c_index = concordance_index(
            df_train['OS_YEARS'],
            -df_train[gene],  # Negative because CI assumes larger value = better prognosis
            df_train['OS_STATUS']
        )
    except Exception as exc:
        fit_failures.append((gene, 'c-index', exc))
        c_index = 0.5

    # Median survival by group (Kaplan-Meier estimate, so censoring is accounted for)
    median_surv_0 = _km_median_survival(group_0)
    median_surv_1 = _km_median_survival(group_1)

    results.append({
        'Gene': gene.replace('Gene_', ''),
        'N_Mutated': n_mutated,
        'N_Not_Mutated': n_not_mutated,
        'Prevalence_%': 100 * n_mutated / len(df_train),
        'LogRank_Stat': logrank_stat,
        'LogRank_pvalue': logrank_pvalue,
        'Cox_Coef': cox_coef,
        'Hazard_Ratio': hazard_ratio,
        'HR_CI_Lower': cox_ci_lower,
        'HR_CI_Upper': cox_ci_upper,
        'Cox_pvalue': cox_pvalue,
        'C_index': c_index,
        'Median_Surv_NoMut': median_surv_0,
        'Median_Surv_Mut': median_surv_1
    })

print(f"Analysis completed for {len(results)} genes")

# Make the fallbacks visible instead of letting them pass as genuine null results.
if fit_failures:
    print(f"WARNING: {len(fit_failures)} fit(s) failed and fell back to neutral values:")
    for failed_gene, step, exc in fit_failures:
        print(f"   - {failed_gene} [{step}]: {type(exc).__name__}: {exc}")

Running analysis for each gene...
Analysis completed for 70 genes


In [5]:
# One row per analysed gene.
df_results = pd.DataFrame(results)
n_tests = len(df_results)
tested_statistics = ('LogRank', 'Cox')

# Multiple testing corrections
# 1. Bonferroni (very conservative): p * number of tests, capped at 1.
for stat_name in tested_statistics:
    raw_p = df_results[f'{stat_name}_pvalue']
    df_results[f'{stat_name}_pvalue_Bonferroni'] = (raw_p * n_tests).clip(upper=1.0)

# 2. Benjamini-Hochberg FDR (less conservative, recommended).
#    multipletests returns (reject, corrected p-values, ...); only the second item is kept.
for stat_name in tested_statistics:
    df_results[f'{stat_name}_FDR'] = multipletests(
        df_results[f'{stat_name}_pvalue'], method='fdr_bh'
    )[1]

# Per-metric scores. -log10(p) turns small p-values into large scores; the clip
# keeps an underflowed p-value of 0 from producing an infinite score.
for stat_name in tested_statistics:
    floored_p = df_results[f'{stat_name}_pvalue'].clip(lower=1e-300)
    df_results[f'Score_{stat_name}'] = -np.log10(floored_p)
# Distance of the C-index from 0.5 (chance level), rescaled to lie between 0 and 1.
df_results['Score_Cindex'] = (df_results['C_index'] - 0.5).abs() * 2

# Weighted composite score: each metric is divided by its maximum over all genes,
# then combined with fixed weights (40% / 40% / 20%).
score_weights = {'Score_LogRank': 0.4, 'Score_Cox': 0.4, 'Score_Cindex': 0.2}
weighted_terms = [
    weight * df_results[score_col] / df_results[score_col].max()
    for score_col, weight in score_weights.items()
]
df_results['Score_Composite'] = weighted_terms[0] + weighted_terms[1] + weighted_terms[2]

print("Results with multiple testing corrections computed")

Results with multiple testing corrections computed


## Final Gene Ranking

Sorted by **composite score** combining:
- Log-Rank p-value (40%): non-parametric test comparing survival curves
- Cox p-value (40%): test based on the semi-parametric Cox proportional hazards model
- C-index (20%): discriminative ability of the gene alone

**Interpretation of the Hazard Ratio (HR):**
- HR > 1: mutation associated with increased risk of death
- HR < 1: mutation associated with better survival
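- HR ≈ 1: little or no association with survival (statistical significance is judged from the confidence interval and p-value, not from the HR alone)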

In [6]:
# Main ranking: best composite score first, with a 1-based index named 'Rank'.
df_ranking = df_results.sort_values('Score_Composite', ascending=False).reset_index(drop=True)
df_ranking.index = pd.RangeIndex(1, len(df_ranking) + 1, name='Rank')

# Columns shown in the top-20 table.
cols_display = ['Gene', 'N_Mutated', 'Prevalence_%', 'Hazard_Ratio', 'HR_CI_Lower', 'HR_CI_Upper',
                'LogRank_pvalue', 'LogRank_FDR', 'Cox_pvalue', 'Cox_FDR', 'C_index', 'Score_Composite']

separator = "=" * 100
print(separator)
print("TOP 20 MOST PREDICTIVE GENES FOR SURVIVAL")
print(separator)

# Display-only copy: fixed decimals for effect sizes, scientific notation for p-values.
decimals_by_column = {
    'Prevalence_%': 1,
    'Hazard_Ratio': 3,
    'HR_CI_Lower': 3,
    'HR_CI_Upper': 3,
    'C_index': 4,
    'Score_Composite': 4,
}
display_df = df_ranking[cols_display].head(20).round(decimals_by_column)
for pvalue_col in ('LogRank_pvalue', 'LogRank_FDR', 'Cox_pvalue', 'Cox_FDR'):
    display_df[pvalue_col] = display_df[pvalue_col].map("{:.2e}".format)

print(display_df.to_string())

TOP 20 MOST PREDICTIVE GENES FOR SURVIVAL
        Gene  N_Mutated  Prevalence_%  Hazard_Ratio  HR_CI_Lower  HR_CI_Upper LogRank_pvalue LogRank_FDR Cox_pvalue   Cox_FDR  C_index  Score_Composite
Rank                                                                                                                                                   
1       TP53        368          11.6         2.744        2.405        3.130       3.83e-55    2.68e-53   6.06e-51  4.24e-49   0.5662           1.0000
2      RUNX1        448          14.1         2.142        1.889        2.429       4.22e-34    1.48e-32   1.28e-32  4.48e-31   0.5479           0.6440
3      ASXL1        859          27.1         1.670        1.504        1.855       2.96e-22    6.90e-21   8.06e-22  1.88e-20   0.5456           0.4640
4      STAG2        288           9.1         1.960        1.688        2.276       2.60e-19    4.54e-18   1.15e-18  2.01e-17   0.5320           0.3762
5      SF3B1        727          22.9         

In [7]:
# Genes whose Cox p-value stays below 0.05 after Benjamini-Hochberg correction.
passes_fdr = df_ranking['Cox_FDR'] < 0.05
significant_genes = df_ranking.loc[passes_fdr]

banner = '=' * 80
print(f"\n{banner}")
print(f"STATISTICALLY SIGNIFICANT GENES (FDR < 0.05): {len(significant_genes)} genes")
print(f"{banner}")

if significant_genes.empty:
    print("No gene reaches the significance threshold after FDR correction.")
else:
    # The index of df_ranking is the 1-based rank, so it doubles as the list number.
    for rank, gene_row in significant_genes.iterrows():
        if gene_row['Hazard_Ratio'] > 1:
            effect = "HIGHER RISK"
        else:
            effect = "PROTECTIVE EFFECT"
        print(f"\n{rank}. {gene_row['Gene']}")
        print(f"   HR = {gene_row['Hazard_Ratio']:.3f} [{gene_row['HR_CI_Lower']:.3f} - {gene_row['HR_CI_Upper']:.3f}]")
        print(f"   FDR = {gene_row['Cox_FDR']:.2e} | C-index = {gene_row['C_index']:.4f}")
        print(f"   Prevalence: {gene_row['Prevalence_%']:.1f}% ({gene_row['N_Mutated']} patients)")
        print(f"   Effect: {effect}")


STATISTICALLY SIGNIFICANT GENES (FDR < 0.05): 25 genes

1. TP53
   HR = 2.744 [2.405 - 3.130]
   FDR = 4.24e-49 | C-index = 0.5662
   Prevalence: 11.6% (368 patients)
   Effect: HIGHER RISK

2. RUNX1
   HR = 2.142 [1.889 - 2.429]
   FDR = 4.48e-31 | C-index = 0.5479
   Prevalence: 14.1% (448 patients)
   Effect: HIGHER RISK

3. ASXL1
   HR = 1.670 [1.504 - 1.855]
   FDR = 1.88e-20 | C-index = 0.5456
   Prevalence: 27.1% (859 patients)
   Effect: HIGHER RISK

4. STAG2
   HR = 1.960 [1.688 - 2.276]
   FDR = 2.01e-17 | C-index = 0.5320
   Prevalence: 9.1% (288 patients)
   Effect: HIGHER RISK

5. SF3B1
   HR = 0.636 [0.562 - 0.720]
   FDR = 9.68e-12 | C-index = 0.4461
   Prevalence: 22.9% (727 patients)
   Effect: PROTECTIVE EFFECT

6. NRAS
   HR = 2.109 [1.749 - 2.544]
   FDR = 7.78e-14 | C-index = 0.5227
   Prevalence: 5.3% (167 patients)
   Effect: HIGHER RISK

7. EZH2
   HR = 1.922 [1.615 - 2.286]
   FDR = 1.95e-12 | C-index = 0.5200
   Prevalence: 6.6% (209 patients)
   Effect: HIGH

## Visualisation: Volcano Plot and Forest Plot (Plotly)

In [8]:
# Volcano Plot with Plotly: effect size (log2 HR) on x, significance (-log10 p) on y.

# Clip p-values before the log transform (same floor as the composite score) so
# that an underflowed p-value of 0 does not turn into +inf.
cox_pvalue_floored = df_ranking['Cox_pvalue'].clip(lower=1e-300)
log_hr = np.log2(df_ranking['Hazard_Ratio'])
neg_log_pvalue = -np.log10(cox_pvalue_floored)

# Colors according to significance:
#   red  = FDR-significant with HR > 1 (higher risk)
#   blue = FDR-significant otherwise (protective)
#   gray = not significant
is_significant = df_ranking['Cox_FDR'] < 0.05
colors = np.select(
    [is_significant & (df_ranking['Hazard_Ratio'] > 1), is_significant],
    ['red', 'blue'],
    default='gray'
).tolist()

# Threshold line for FDR=0.05 (approximate): the largest raw p-value that is still
# FDR-significant. When nothing is significant, fall back to 0.05 / (number of
# tests actually performed) rather than a hard-coded gene count.
if is_significant.any():
    fdr_threshold = cox_pvalue_floored[is_significant].max()
else:
    fdr_threshold = 0.05 / max(len(df_ranking), 1)

fig_volcano = go.Figure()

fig_volcano.add_trace(
    go.Scatter(
        x=log_hr,
        y=neg_log_pvalue,
        mode='markers',
        marker=dict(
            color=colors,
            size=8,
            line=dict(width=0.5, color='black')
        ),
        text=df_ranking['Gene'],
        hovertemplate='Gene: %{text}<br>log2(HR): %{x:.2f}<br>-log10(p): %{y:.2f}<extra></extra>'
    )
)

# Add horizontal line at FDR threshold
fig_volcano.add_hline(
    y=-np.log10(fdr_threshold),
    line_dash='dash',
    line_color='red',
    annotation_text='FDR=0.05 threshold',
    annotation_position='top left'
)

# Add vertical line at log2(HR)=0, i.e. HR = 1 (no effect)
fig_volcano.add_vline(
    x=0,
    line_dash='solid',
    line_color='black'
)

# Annotate most significant genes (FDR < 0.01), reusing the plotted coordinates
# so labels always sit on their markers.
label_mask = df_ranking['Cox_FDR'] < 0.01
for gene_name, x_pos, y_pos in zip(
    df_ranking.loc[label_mask, 'Gene'], log_hr[label_mask], neg_log_pvalue[label_mask]
):
    fig_volcano.add_annotation(
        x=x_pos,
        y=y_pos,
        text=gene_name,
        showarrow=False,
        yshift=8,
        font=dict(size=8)
    )

fig_volcano.update_layout(
    title='Volcano Plot - Gene Impact on Survival',
    xaxis_title='log2(Hazard Ratio)',
    yaxis_title='-log10(p-value)',
)

volcano_path = '../../figures/gene_ranking_volcano.html'
fig_volcano.write_html(volcano_path)
fig_volcano.show()
print(f"\nFigure saved: {volcano_path}")


Figure saved: ../../figures/gene_ranking_volcano.html


In [9]:
# Forest Plot for Top 15 Genes with Plotly
top_n = 15
top_genes = df_ranking.head(top_n).copy()
top_genes = top_genes.iloc[::-1]  # Reverse for display from bottom to top

y_labels = top_genes['Gene'].tolist()
hr = top_genes['Hazard_Ratio'].values
hr_lower = top_genes['HR_CI_Lower'].values
hr_upper = top_genes['HR_CI_Upper'].values

# Error bars (95% CI), expressed as distances from the point estimate
xerr_lower = hr - hr_lower
xerr_upper = hr_upper - hr

# Red marker = HR above 1 (higher risk), blue otherwise (protective)
colors_forest = ['red' if h > 1 else 'blue' for h in hr]

fig_forest = go.Figure()

fig_forest.add_trace(
    go.Scatter(
        x=hr,
        y=y_labels,
        mode='markers',
        marker=dict(
            color=colors_forest,
            size=10,
            line=dict(color='black', width=1)
        ),
        error_x=dict(
            type='data',
            symmetric=False,
            array=xerr_upper,
            arrayminus=xerr_lower,
            thickness=1.5,
            width=3
        ),
        hovertemplate='Gene: %{y}<br>HR: %{x:.2f}<extra></extra>'
    )
)

# Determine x-range for shading: always cover at least [0.1, 5.0], wider if a CI needs it
x_min = float(min(hr_lower.min(), 0.1))
x_max = float(max(hr_upper.max(), 5.0))

# Background zones. They are added with add_vrect (full plot height) BEFORE the
# reference line instead of through update_layout(shapes=[...]): update_layout
# merges its list element-wise into the shapes already on the figure, which
# turned the HR=1 line created by add_vline into the first rectangle.
# Higher risk zone (HR > 1)
fig_forest.add_vrect(
    x0=1,
    x1=x_max,
    fillcolor='rgba(255, 0, 0, 0.1)',
    line_width=0,
    layer='below'
)
# Protective zone (HR < 1)
fig_forest.add_vrect(
    x0=x_min,
    x1=1,
    fillcolor='rgba(0, 0, 255, 0.1)',
    line_width=0,
    layer='below'
)

# Add vertical line at HR=1 (no effect)
fig_forest.add_vline(
    x=1,
    line_dash='dash',
    line_color='black'
)

fig_forest.update_layout(
    title=f'Forest Plot - Top {top_n} Genes',
    xaxis=dict(
        title='Hazard Ratio (95% CI)',
        range=[x_min, x_max]
    ),
    yaxis=dict(
        title='Gene'
    ),
)

forest_path = '../../figures/gene_ranking_forest.html'
fig_forest.write_html(forest_path)
fig_forest.show()
print(f"Figure saved: {forest_path}")

Figure saved: ../../figures/gene_ranking_forest.html


## Bootstrap Validation of Ranking Stability

Method: 200 resamplings with replacement to estimate the variance of the Hazard Ratios and the stability of the ranking.

In [10]:
from tqdm import tqdm
np.random.seed(42)

n_bootstrap = 200
n_samples = len(df_train)
top_genes_list = df_ranking['Gene'].head(20).tolist()

# Store bootstrap results: one hazard ratio and one rank per gene per resample
bootstrap_hr = {gene: [] for gene in top_genes_list}
bootstrap_ranks = {gene: [] for gene in top_genes_list}
# Number of Cox fits that failed across all resamples (reported after the loop)
n_failed_fits = 0

print(f"Bootstrap in progress ({n_bootstrap} iterations)...")
for b in tqdm(range(n_bootstrap)):
    # Resampling: draw n_samples patients with replacement
    idx_boot = np.random.choice(n_samples, size=n_samples, replace=True)
    df_boot = df_train.iloc[idx_boot].reset_index(drop=True)

    hr_dict = {}
    for gene in top_genes_list:
        gene_col = f"Gene_{gene}"
        # `except Exception` (not a bare `except`) so that interrupting the kernel
        # actually stops this long loop instead of being logged as a failed fit.
        try:
            cox_df = df_boot[['OS_YEARS', 'OS_STATUS', gene_col]].copy()
            cox_df['OS_STATUS'] = cox_df['OS_STATUS'].astype(int)
            # NOTE(review): a penalizer is used here but not in the original
            # univariate fits, so bootstrap HRs are not strictly comparable to
            # HR_Original — confirm this is intended.
            cph = CoxPHFitter(penalizer=0.01)
            cph.fit(cox_df, duration_col='OS_YEARS', event_col='OS_STATUS')
            hr = np.exp(cph.params_[gene_col])
            bootstrap_hr[gene].append(hr)
            hr_dict[gene] = hr
        except Exception:
            n_failed_fits += 1
            # NaN keeps the failure out of the HR statistics; HR=1 (no effect)
            # sends the gene to the bottom of this resample's ranking.
            bootstrap_hr[gene].append(np.nan)
            hr_dict[gene] = 1.0

    # Ranking based on effect size (distance to HR=1)
    sorted_genes = sorted(hr_dict.keys(), key=lambda g: abs(np.log(hr_dict[g])), reverse=True)
    for rank, gene in enumerate(sorted_genes, 1):
        bootstrap_ranks[gene].append(rank)

print("Bootstrap completed!")
if n_failed_fits:
    print(f"WARNING: {n_failed_fits} Cox fit(s) failed and were recorded as NaN (ranked as HR=1).")

Bootstrap in progress (200 iterations)...


100%|██████████| 200/200 [15:08<00:00,  4.54s/it]

Bootstrap completed!





In [11]:
# Analysis of bootstrap results: per-gene spread of the hazard ratio and of the rank.
hr_original_lookup = df_ranking.set_index('Gene')['Hazard_Ratio']


def _summarize_bootstrap(gene):
    """Build the summary row of one gene from its bootstrap hazard ratios and ranks."""
    hr_values = np.asarray(bootstrap_hr[gene])
    hr_values = hr_values[~np.isnan(hr_values)]  # drop resamples whose fit failed
    rank_values = np.asarray(bootstrap_ranks[gene])
    hr_low, hr_high = np.percentile(hr_values, [2.5, 97.5])
    rank_spread = rank_values.std()
    return {
        'Gene': gene,
        'HR_Original': hr_original_lookup.at[gene],
        'HR_Bootstrap_Mean': hr_values.mean(),
        'HR_Bootstrap_Std': hr_values.std(),
        'HR_Bootstrap_2.5%': hr_low,
        'HR_Bootstrap_97.5%': hr_high,
        'Rank_Mean': rank_values.mean(),
        'Rank_Std': rank_spread,
        'Rank_Min': rank_values.min(),
        'Rank_Max': rank_values.max(),
        'Stability': 1 - (rank_spread / 10)  # Normalized stability score
    }


bootstrap_summary = [_summarize_bootstrap(gene) for gene in top_genes_list]

# Order genes by their average bootstrap rank and renumber from 1.
df_bootstrap = pd.DataFrame(bootstrap_summary).sort_values('Rank_Mean')
df_bootstrap.index = pd.RangeIndex(1, len(df_bootstrap) + 1, name='Rank')

rule = "=" * 100
print(rule)
print("RANKING STABILITY BY BOOTSTRAP (Top 20)")
print(rule)
print("\nBootstrap 95% confidence interval for Hazard Ratios:")
report_columns = ['Gene', 'HR_Original', 'HR_Bootstrap_Mean', 'HR_Bootstrap_2.5%',
                  'HR_Bootstrap_97.5%', 'Rank_Mean', 'Rank_Std']
print(df_bootstrap[report_columns].round(3).to_string())

RANKING STABILITY BY BOOTSTRAP (Top 20)

Bootstrap 95% confidence interval for Hazard Ratios:
        Gene  HR_Original  HR_Bootstrap_Mean  HR_Bootstrap_2.5%  HR_Bootstrap_97.5%  Rank_Mean  Rank_Std
Rank                                                                                                    
1       TP53        2.744              2.708              2.350               3.172      2.180     1.108
2       FLT3        2.735              2.774              1.860               4.225      2.915     2.796
3      CEBPA        2.333              2.373              1.850               2.998      4.455     2.549
4        MLL        2.362              2.320              1.748               2.953      5.080     3.021
5      RUNX1        2.142              2.125              1.845               2.449      6.505     1.985
6       NRAS        2.109              2.097              1.686               2.602      7.190     2.732
7      GATA2        2.091              2.137              1.539   

## Kaplan-Meier Curves for the Most Significant Genes (Plotly)

In [12]:
# KM curves for the 6 most significant genes
top_6_genes = df_ranking['Gene'].head(6).tolist()

# (value in the gene column, short label, legend name, line color).
# Fixed colors keep wild-type and mutated curves identical across all subplots,
# so the single shared legend is valid for every panel.
km_arms = [
    (0, 'WT', 'Wild-type', '#1f77b4'),
    (1, 'Mut', 'Mutated', '#d62728'),
]


def _km_subplot_title(gene):
    """Return the subplot title for `gene`: its name plus Cox HR and p-value."""
    gene_row = df_ranking[df_ranking['Gene'] == gene].iloc[0]
    return f"{gene}<br>HR={gene_row['Hazard_Ratio']:.2f}, p={gene_row['Cox_pvalue']:.2e}"


# Titles are built up front rather than patched into layout.annotations afterwards.
fig_km = make_subplots(
    rows=2,
    cols=3,
    subplot_titles=[_km_subplot_title(gene) for gene in top_6_genes],
    horizontal_spacing=0.07,
    vertical_spacing=0.15
)

for i, gene in enumerate(top_6_genes):
    row = i // 3 + 1
    col = i % 3 + 1
    gene_col = f"Gene_{gene}"

    for arm_value, arm_label, legend_name, arm_color in km_arms:
        group = df_train[df_train[gene_col] == arm_value]
        arm_description = f'{gene} {arm_label} (n={len(group)})'

        kmf = KaplanMeierFitter()
        kmf.fit(group['OS_YEARS'], group['OS_STATUS'], label=arm_description)
        surv_fn = kmf.survival_function_

        fig_km.add_trace(
            go.Scatter(
                x=surv_fn.index.values,
                y=surv_fn.iloc[:, 0].values,
                mode='lines',
                # The Kaplan-Meier estimate is a step function: hold each value
                # until the next event time instead of interpolating linearly.
                line=dict(color=arm_color, shape='hv'),
                name=legend_name,
                legendgroup=arm_label,  # toggling a legend entry affects every panel
                showlegend=(i == 0),  # Only show once in global legend
                hovertemplate=(
                    f'{arm_description}<br>'
                    'Time: %{x:.2f} years<br>'
                    'Survival: %{y:.3f}<extra></extra>'
                )
            ),
            row=row,
            col=col
        )

    fig_km.update_xaxes(title_text='Time (years)', row=row, col=col, range=[0, 10])
    fig_km.update_yaxes(title_text='Survival probability', row=row, col=col, range=[0, 1])

fig_km.update_layout(
    title='Kaplan-Meier Curves for Top 6 Genes',
    height=800,
    width=1200
)

km_path = '../../figures/kaplan_meier_top_genes.html'
fig_km.write_html(km_path)
fig_km.show()
print(f"\nFigure saved: {km_path}")


Figure saved: ../../figures/kaplan_meier_top_genes.html


## Summary and Export of Results

In [13]:
# Export full ranking (the index, i.e. the rank, is written as the first column)
export_columns = ['Gene', 'N_Mutated', 'Prevalence_%', 'Hazard_Ratio', 'HR_CI_Lower', 'HR_CI_Upper',
                  'LogRank_pvalue', 'LogRank_FDR', 'Cox_pvalue', 'Cox_FDR', 'C_index', 'Score_Composite']
df_export = df_ranking[export_columns].copy()
df_export.to_csv('../../results/gene_ranking_survival.csv', index=True)
print("Results exported to results/gene_ranking_survival.csv")


def _print_gene_lines(genes_df):
    """Print one numbered line (gene, HR with 95% CI, FDR) per row of `genes_df`."""
    for i, (_, row) in enumerate(genes_df.iterrows(), 1):
        print(f"   {i}. {row['Gene']:10s} | HR = {row['Hazard_Ratio']:.2f} [{row['HR_CI_Lower']:.2f}-{row['HR_CI_Upper']:.2f}] | FDR = {row['Cox_FDR']:.2e}")


# Final summary
print("\n" + "="*100)
print("SUMMARY OF THE ANALYSIS")
print("="*100)

# Each filter is computed once and reused for both the counts and the listings.
is_fdr_significant = df_ranking['Cox_FDR'] < 0.05
all_risk_genes = df_ranking[is_fdr_significant & (df_ranking['Hazard_Ratio'] > 1)]
protective_genes = df_ranking[is_fdr_significant & (df_ranking['Hazard_Ratio'] < 1)]

n_significant_bonf = len(df_ranking[df_ranking['Cox_pvalue_Bonferroni'] < 0.05])
n_significant_fdr = int(is_fdr_significant.sum())
n_protective = len(protective_genes)
n_risk = len(all_risk_genes)

# Genes actually tested; candidates with too few patients in one group were skipped.
n_analyzed = len(df_ranking)
n_skipped = len(gene_cols) - n_analyzed

print(f"\n📊 GLOBAL STATISTICS:")
print(f"   • {n_analyzed} genes analyzed")
if n_skipped > 0:
    print(f"   • {n_skipped} genes skipped (too few patients in one group)")
print(f"   • {n_significant_fdr} significant genes (FDR < 0.05)")
print(f"   • {n_significant_bonf} significant genes (Bonferroni < 0.05)")
print(f"   • {n_risk} higher-risk genes (HR > 1)")
print(f"   • {n_protective} protective genes (HR < 1)")

print(f"\n🔴 TOP 5 HIGHER-RISK GENES (HR > 1):")
risk_genes = all_risk_genes.head(5)
_print_gene_lines(risk_genes)

print(f"\n🔵 PROTECTIVE GENES (HR < 1):")
_print_gene_lines(protective_genes)

print(f"\n📈 METHODS USED:")
print("   1. Log-Rank test (non-parametric comparison of survival curves)")
print("   2. Univariate Cox model (estimation of Hazard Ratios)")
print("   3. Benjamini-Hochberg FDR correction (control of false positives)")
# Report the iteration count actually used by the bootstrap cell.
print(f"   4. Bootstrap ({n_bootstrap} iterations) for stability assessment")
print("   5. Composite score combining Log-Rank, Cox, and C-index")

Results exported to results/gene_ranking_survival.csv

SUMMARY OF THE ANALYSIS

📊 GLOBAL STATISTICS:
   • 70 genes analyzed
   • 25 significant genes (FDR < 0.05)
   • 21 significant genes (Bonferroni < 0.05)
   • 24 higher-risk genes (HR > 1)
   • 1 protective genes (HR < 1)

🔴 TOP 5 HIGHER-RISK GENES (HR > 1):
   1. TP53       | HR = 2.74 [2.41-3.13] | FDR = 4.24e-49
   2. RUNX1      | HR = 2.14 [1.89-2.43] | FDR = 4.48e-31
   3. ASXL1      | HR = 1.67 [1.50-1.85] | FDR = 1.88e-20
   4. STAG2      | HR = 1.96 [1.69-2.28] | FDR = 2.01e-17
   5. NRAS       | HR = 2.11 [1.75-2.54] | FDR = 7.78e-14

🔵 PROTECTIVE GENES (HR < 1):
   1. SF3B1      | HR = 0.64 [0.56-0.72] | FDR = 9.68e-12

📈 METHODS USED:
   1. Log-Rank test (non-parametric comparison of survival curves)
   2. Univariate Cox model (estimation of Hazard Ratios)
   3. Benjamini-Hochberg FDR correction (control of false positives)
   4. Bootstrap (200 iterations) for stability assessment
   5. Composite score combining Log-Rank