In [None]:
"""
Expected data path and required files: 

DATA_DIR
├── UK_65
│    ├── "dataset.csv"
│    ├── "to_censure.pkl"
│    ├── "inactive_ids.pkl"
│    └── "stats.pkl"
│
├── UK_70
├── FR_65
└── FR_70
"""

# Input / output directories (relative paths)
DATA_DIR = '../../../data/datasets'
OUTPUT_DIR = '../../../data/results'

# CIF x-axis horizon: from 31 December END_YEAR up to the data extraction date
EXTRACTION_DATE = '2023-01-01'
END_YEAR = 2010

In [None]:
# Outcomes analysed below. Other disease labels that can be re-enabled by adding them
# here: 'parkinson', 'vascular_dementias', 'mci', 'alcohol_dementias',
# 'frontotemporal_dementias', 'other_dementias', 'parkinson_dementias'.
DISEASES_OF_INTEREST = ['alzheimer', 'all_dementias']

In [None]:
import os

import pandas as pd
from glob import glob
import pickle
import ast

import numpy as np

from typing import Dict, Set, List, Tuple
from IPython.display import display

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

from lifelines import AalenJohansenFitter, CoxPHFitter
from lifelines.statistics import logrank_test

# Plot like R
import matplotlib as mpl
mpl.style.use('ggplot')

# sns.set() resets the colour palette and the font rcParams to seaborn's defaults,
# so it must run BEFORE the explicit palette / font choices, otherwise they are
# silently overwritten.
sns.set(style='ticks', palette='pastel')

plt.rcParams['font.family'] = 'sans-serif'
# DejaVu Sans ships with matplotlib: fallback when Helvetica is not installed
plt.rcParams['font.sans-serif'] = ['Helvetica', 'DejaVu Sans']
plt.rcParams['font.size'] = 12

# sns.despine()

In [None]:
# x-axis horizon of the CIF plots, in days: from 31 Dec END_YEAR to the extraction date
XLIM = (pd.to_datetime(EXTRACTION_DATE) - pd.to_datetime(f'{END_YEAR}-12-31')).days
nb_years = int(XLIM / 365.25)               # complete years within the horizon
YEARS = np.arange(0, nb_years + 1)          # yearly tick labels 0..nb_years
YEARS_TICKS = (YEARS * 365.25).astype(int)  # matching tick positions, in days

In [None]:
def load_dataset(country, age, base_dir=DATA_DIR):
    """Load one cohort stored under `base_dir/<country>_<age>/`.

    Returns
    -------
    (dataset, to_censure, inactive_ids, stats)
        `dataset` is 'dataset.csv' as a DataFrame whose 'diseases' column is parsed
        back from its string representation; the other three are the unpickled
        contents of 'to_censure.pkl', 'inactive_ids.pkl' and 'stats.pkl'.
    """
    cohort_dir = os.path.join(base_dir, country + "_" + str(age))

    dataset = pd.read_csv(os.path.join(cohort_dir, "dataset.csv"))
    # read_csv yields the 'diseases' tuples as strings; literal_eval rebuilds them safely
    dataset['diseases'] = dataset['diseases'].apply(ast.literal_eval)

    # NOTE: pickle.load can execute arbitrary code - only use on trusted, locally produced files
    unpickled = []
    for file_name in ("to_censure.pkl", "inactive_ids.pkl", "stats.pkl"):
        with open(os.path.join(cohort_dir, file_name), "rb") as handle:
            unpickled.append(pickle.load(handle))
    to_censure, inactive_ids, stats = unpickled

    return dataset, to_censure, inactive_ids, stats

def random_split(df, percentage_split = 0.3, random_state=42):
    """
    Stratified random split of a cohort into a Cox part and a Prediction part.

    Stratification is on the 'diseases' column (empty sequence -> class 0,
    non-empty -> class 1), so both parts keep the original proportion of
    patients with at least one disease.

    Parameters
    ----------
    df : DataFrame
        Dataset to split (must contain a 'diseases' column of sequences).
    percentage_split : float
        Fraction of each class assigned to the Cox part (default 0.3); the
        remaining (1 - percentage_split) goes to the Prediction part.
    random_state : int
        Seed of every shuffle, for reproducibility.

    Returns
    -------
    df_cox, df_pred : tuple of DataFrames
        Shuffled Cox and Prediction parts, each with a fresh RangeIndex.
    """
    df = df.copy()
    # temporary stratification label: 1 if the patient has at least one disease
    df['target'] = df['diseases'].apply(lambda x: 0 if len(x) == 0 else 1)

    # Shuffle each class separately, then cut it at the same fraction
    cox_parts, pred_parts = [], []
    for cls in (0, 1):
        shuffled = df[df['target'] == cls].copy().sample(frac=1, random_state=random_state)
        split_idx = int(len(shuffled) * percentage_split)
        cox_parts.append(shuffled.iloc[:split_idx].copy())
        pred_parts.append(shuffled.iloc[split_idx:].copy())

    # Merge the classes, shuffle again and drop the temporary label
    df_cox = (pd.concat(cox_parts, ignore_index=True)
              .sample(frac=1, random_state=random_state)
              .reset_index(drop=True)
              .drop('target', axis=1))
    df_pred = (pd.concat(pred_parts, ignore_index=True)
               .sample(frac=1, random_state=random_state)
               .reset_index(drop=True)
               .drop('target', axis=1))

    def pct(num, den):
        # percentage that tolerates an empty denominator (empty input or empty split)
        return num / den * 100 if den else 0.0

    # Statistics
    n_total = len(df)
    n_class_1_total = df['target'].sum()
    n_class_0_total = n_total - n_class_1_total

    n_cox_1 = df_cox['diseases'].apply(lambda x: 0 if len(x) == 0 else 1).sum()
    n_cox_0 = len(df_cox) - n_cox_1
    n_pred_1 = df_pred['diseases'].apply(lambda x: 0 if len(x) == 0 else 1).sum()
    n_pred_0 = len(df_pred) - n_pred_1

    # labels derived from the actual split fraction (previously hard-coded to 30/70)
    cox_share = percentage_split * 100
    pred_share = (1 - percentage_split) * 100

    print(f"Split stratifié {cox_share:.0f}/{pred_share:.0f} basé sur colonne 'diseases':")
    print(f"Dataset original: {n_total:,} patients ({n_class_1_total:,} avec maladies [{pct(n_class_1_total, n_total):.1f}%], {n_class_0_total:,} sans maladie)")
    print(f"Cox split ({cox_share:.0f}%): {len(df_cox):,} patients ({n_cox_1:,} avec maladies [{pct(n_cox_1, len(df_cox)):.1f}%], {n_cox_0:,} sans maladie)")
    print(f"Prediction split ({pred_share:.0f}%): {len(df_pred):,} patients ({n_pred_1:,} avec maladies [{pct(n_pred_1, len(df_pred)):.1f}%], {n_pred_0:,} sans maladie)")

    # Realised proportions (differ slightly from the target because of int() truncation)
    print(f"Ratios réels: Cox {pct(len(df_cox), n_total):.1f}%, Prediction {pct(len(df_pred), n_total):.1f}%")

    return df_cox, df_pred

In [None]:
# Apply the stratified Cox / Prediction split to every cohort
print("=== RANDOM SPLIT 30/70 ===")
DATASETS_COX = {65: {}, 70: {}}    # cohorts used for the Cox models
DATASETS_PRED = {65: {}, 70: {}}   # held-out cohorts for the Prediction notebook

# lists (not sets): deterministic iteration order, so the log and dict order are reproducible
for age in [65, 70]:
    for country in ['FR', 'UK']:
        df_full, *_ = load_dataset(country, age)
        df_cox, df_pred = random_split(df_full, random_state=5)

        DATASETS_COX[age][country] = df_cox
        DATASETS_PRED[age][country] = df_pred

        print(f"{country}_{age}: {len(df_full)} total -> {len(df_cox)} cox + {len(df_pred)} pred")

# Save the Prediction split for the Prediction notebook
os.makedirs(OUTPUT_DIR, exist_ok=True)
with open(os.path.join(OUTPUT_DIR, 'datasets_prediction.pkl'), 'wb') as f:
    pickle.dump(DATASETS_PRED, f)

In [None]:
from functools import reduce

# Columns from position 9 on (medications) that are present in all four cohorts
MED_SUPPORT = reduce(set.intersection,
                     [set(DATASETS_COX[age][country].columns[9:]) for country in ['FR', 'UK'] for age in [65, 70]])
# 'avg. CHARLSON' apparently falls in that column range but is a covariate, not a
# medication; discard() tolerates its absence and sorting makes the order reproducible
MED_SUPPORT.discard('avg. CHARLSON')
MED_SUPPORT = sorted(MED_SUPPORT)
print(MED_SUPPORT)

In [None]:
# Number of medication classes shared by the four cohorts (also the Bonferroni factor
# passed as `correction` to get_cox_analysis below)
len(MED_SUPPORT)

# Dataset formatting

In [None]:
def preprocess(df, disease, min_disease_time:bool=False):
    """
    Map a raw cohort dataset to a competing-risks survival table for `disease`.

    Expected input columns:
        - person_id
        - diseases  (sequence of (disease_name, time) pairs)
        - gender_code
        - person_state_code
        - duration (days)
        - avg. Alcohol (glasses/day)
        - avg. Tobaco (cigarettes/day)
        - avg. BMI
        - avg. CHARLSON
        - one hot encoded medication columns (every column from position 9 on)
    and optionally 'mci_at_baseline'.

    NOTE: `df` is modified in place (columns are popped / added): pass a copy.

    Output format:
    +------+-------+-------+ ... -------+
    | time | event | cov_1 | ...  cov_n |
    +------+-------+-------+ ... -------+

    where
        - event = 1 if the patient has `disease`, 2 if the patient died without it
          (person_state_code in {'D', 'S', 'P'}), 0 otherwise (inactive / censored)
        - time = time to event: the disease time for patients with the disease (the
          earliest time over *all* their diseases when `min_disease_time` is True),
          'duration (days)' otherwise
    The covariates are sex, CCI (continuous + 3 categories), BMI (continuous +
    5 categories), alcohol and smoking indicators, then the medication columns.
    Category indicators are NaN where the underlying measurement is missing.
    """
    med_cols = df.columns[9:].to_list()

    # MCI status at baseline, if provided (missing values counted as 0)
    if 'mci_at_baseline' in df.columns:
        mci_baseline = df['mci_at_baseline'].fillna(0).astype(int)
    else:
        print("⚠️  Colonne 'mci_at_baseline' non trouvée - création avec valeurs NaN")
        mci_baseline = pd.Series([np.nan] * len(df), index=df.index)
    #-----------------------------------------------------------------------------------------------
    ## Transform variables
    # Alcohol: at least 2 glasses/day
    drink_alcohol = lambda x: int(x >= 2) if np.isfinite(x) else np.nan
    df['drink ≥ 2 glasses/day'] = df.pop('avg. Alcohol (glasses/day)').apply(drink_alcohol)

    # Smoking: any tobacco use
    smoke = lambda x: int(x > 0) if np.isfinite(x) else np.nan
    df['smoke'] = df.pop('avg. Tobaco (cigarettes/day)').apply(smoke)

    def categorize(col, indicator):
        # 0/1 indicator, kept NaN where `col` itself is missing
        return df[col].where(df[col].isna(), indicator.astype(int))

    # BMI categories
    # NOTE(review): CDC adult categories use < 18.5 for underweight; the < 20 cut-off here differs - confirm intended
    bmi = df['avg. BMI']
    df['underweight'] = categorize('avg. BMI', bmi < 20)
    df['normal'] = categorize('avg. BMI', bmi.between(20, 25, inclusive='left'))
    df['overweight'] = categorize('avg. BMI', bmi.between(25, 30, inclusive='left'))
    df['obese'] = categorize('avg. BMI', bmi.between(30, 35, inclusive='left'))
    df['extremely obese'] = categorize('avg. BMI', bmi >= 35)

    # CHARLSON categories: low < 4, medium in [4, 5], high > 5.
    # The medium interval is closed on both ends so that CCI == 5 gets a category.
    cci = df['avg. CHARLSON']
    df['low CCI'] = categorize('avg. CHARLSON', cci < 4)
    df['medium CCI'] = categorize('avg. CHARLSON', cci.between(4, 5, inclusive='both'))
    df['high CCI'] = categorize('avg. CHARLSON', cci > 5)

    # female
    df['is female'] = df.pop('gender_code').eq('F').astype(int)

    # MCI at baseline
    df['mci_at_baseline'] = mci_baseline

    non_med_cols = [
        'is female',
        'avg. CHARLSON',
            'low CCI', 'medium CCI', 'high CCI',
        'avg. BMI',
            'underweight', 'normal', 'overweight', 'obese', 'extremely obese',
        'drink ≥ 2 glasses/day',
        'smoke'
    ]
    #-----------------------------------------------------------------------------------------------
    # Event of interest: `disease` appears among the patient's diagnoses
    # (astype(bool) keeps a boolean mask even for an empty frame)
    event = df['diseases'].apply(lambda x: any(d == disease for d, _ in x)).astype(bool)

    if min_disease_time:
        # earliest time over all recorded diseases, not only `disease`
        get_time = lambda x: min(t for _, t in x)
    else:
        # time of the first record of `disease`
        get_time = lambda x: next(t for d, t in x if d == disease)

    # Non-mutating replacement of the follow-up duration by the disease time
    time = df['duration (days)'].mask(event, df.loc[event, 'diseases'].apply(get_time))

    # Competing risk: death without the disease is coded 2
    died = df['person_state_code'].isin({'D', 'S', 'P'})
    competing_event = np.where(died & ~event, 2, event.astype(int))

    final_df = pd.DataFrame({
        'time': time,
        'event': competing_event,
    })

    return pd.concat([final_df, df[non_med_cols + med_cols]], axis=1)

In [None]:
%%time
# Preprocess every (cohort, disease) pair of the Cox split.
# Lists (not sets) keep the iteration and dict-key order deterministic.
names = [f'{country}_{age}' for age in [65, 70] for country in ['FR', 'UK']]
preprocessed_df = {disease: {name: {} for name in names} for disease in DISEASES_OF_INTEREST}

for age in [65, 70]:
    for country in ['FR', 'UK']:
        df = DATASETS_COX[age][country]
        for disease in DISEASES_OF_INTEREST:
            # preprocess pops/adds columns on its input, hence the copy;
            # min_disease_time=True uses the earliest time among all of the patient's diseases
            preprocessed_df[disease][f'{country}_{age}'] = preprocess(df.copy(), disease, min_disease_time=True)

# Medication analysis

In [None]:
# Prefix of every per-covariate CIF panel title; the covariate name is appended to it
CIF_TITLE = 'Comparative CIF: Influence of '

In [None]:
def _cox_effect(data, col, alpha):
    """Fit one Cox model on `data` (binary 'event') and return (p, (lower, HR, upper)) for `col`."""
    summary = CoxPHFitter(alpha=alpha).fit(data, duration_col='time', event_col='event').summary
    # lifelines names the CI columns after the confidence level (e.g. 'exp(coef) lower 95%')
    # using its own number formatting; matching on the prefix works for any alpha
    lower_col = next(c for c in summary.columns if c.startswith('exp(coef) lower'))
    upper_col = next(c for c in summary.columns if c.startswith('exp(coef) upper'))
    return (summary.loc[col, 'p'],
            (summary.loc[col, lower_col], summary.loc[col, 'exp(coef)'], summary.loc[col, upper_col]))


def fit_cox(df, col:str, adj_cols:List[str], alpha:float):
    """
    Fit the two cause-specific Cox models (disease and death) for covariate `col`.

    Parameters
    ----------
    df : DataFrame
        Must contain 'time', 'event' (0 censored, 1 disease, 2 death), `col` and `adj_cols`.
        Not modified.
    col : str
        Covariate whose effect is reported.
    adj_cols : list of str
        Additional covariates to adjust for (may be empty).
    alpha : float
        Significance level; confidence intervals are at level 1 - alpha.

    Returns
    -------
    (p_value, (ci_lower, hazard_ratio, ci_upper),
     death_p_value, (death_ci_lower, death_hazard_ratio, death_ci_upper))
    """
    cols = ['time', 'event'] + [col] + adj_cols
    data = df[cols]

    # Disease model: event == 1 is the event; death (2) and censoring (0) are censored
    p_value, hazard = _cox_effect(data.assign(event=np.where(df['event'] == 1, 1, 0)), col, alpha)

    # Death model: event == 2 is the event; the disease (1) is censored
    death_p_value, death_hazard = _cox_effect(data.assign(event=np.where(df['event'] == 2, 1, 0)), col, alpha)

    return p_value, hazard, death_p_value, death_hazard


def get_better_ax(ax, title):
    """Apply the shared CIF-panel styling to `ax`: title, legend, yearly x ticks, open frame."""
    ax.set_title(title)
    ax.legend(loc='upper left')

    # x axis: yearly ticks first, then clamp the view (set_xticks may widen the limits)
    ax.set_xlabel('Time (years)')
    ax.set_xticks(YEARS_TICKS, labels=YEARS)
    ax.set_xlim(0, XLIM)

    # open frame: hide the top/right spines and their ticks
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.tick_params(top=False, right=False)

    # y axis
    ax.set_ylabel('Cumulative incidence')
    ax.grid(axis='y', alpha=0.5)
    
def plot_CIF(alpha, TA, EA, labelA, TB, EB, labelB, ax, title):
    """Overlay the Aalen-Johansen CIFs (event of interest = 1) of groups A and B on `ax`.

    Confidence bands use level 1 - `alpha`; the panel title is CIF_TITLE + `title`.
    """
    fitter = AalenJohansenFitter(alpha=alpha)
    groups = ((TA, EA, labelA), (TB, EB, labelB))
    for durations, events, group_label in groups:
        fitter.fit(durations, events, event_of_interest=1, label=group_label)
        fitter.plot(ax=ax)
    get_better_ax(ax, CIF_TITLE + title)


def get_cox_analysis(df, med_names:List[str], alpha:float=0.05, correction:int=1, title:str=None)->Tuple[Dict, ...]:
    """
    Plot comparative CIFs and fit cause-specific Cox models for sex, CCI, BMI and each medication.

    Draws one figure with a CIF panel per covariate (2 columns). For every covariate it
    fits the disease and death Cox models of `fit_cox`; medications are additionally
    fitted adjusted for sex, CCI and BMI on the patients with all three available.

    Parameters
    ----------
    df : DataFrame
        Output of `preprocess` ('time', 'event' in {0, 1, 2} and the covariate columns).
    med_names : list of str
        Medication columns to analyse, processed in sorted order (the list is not modified).
    alpha : float
        Family-wise significance level.
    correction : int
        Bonferroni factor: models and CIF bands use alpha / correction and p-values are
        multiplied by `correction` (capped at 1).
    title : str, optional
        Appended to the figure's super-title.

    Returns
    -------
    tuple of 8 dicts keyed by covariate / medication name:
        (p_values, hazard_ratios, death_p_values, death_hazard_ratios,
         adjusted_p_values, adjusted_hazard_ratios,
         death_adjusted_p_values, death_adjusted_hazard_ratios)
    Hazard ratios are (lower, HR, upper) tuples; the adjusted dicts only contain medications.
    """
    ### Setting
    alpha_corrected = alpha / correction

    def bonferroni(p):
        # Bonferroni-adjusted p-values are probabilities: cap them at 1
        return min(p * correction, 1.0)

    def record(name, fit, stores):
        # Store one fit_cox result: scaled p-values and raw (lower, HR, upper) tuples
        p, hr, death_p, death_hr = fit
        p_store, hr_store, death_p_store, death_hr_store = stores
        p_store[name], hr_store[name] = bonferroni(p), hr
        death_p_store[name], death_hr_store[name] = bonferroni(death_p), death_hr

    def plot_levels(data, labels, ax, panel_title):
        # One Aalen-Johansen CIF per 0/1 indicator column in `labels`, on a single panel
        ajd = AalenJohansenFitter(alpha=alpha_corrected)
        for label in labels:
            in_level = data[label].astype(bool)
            ajd.fit(data.loc[in_level, 'time'], data.loc[in_level, 'event'], event_of_interest=1, label=label)
            ajd.plot(ax=ax)
        get_better_ax(ax, CIF_TITLE + panel_title)

    # Drop duplicated columns first (e.g. a covariate also present among the
    # medication columns) so that every df[col] below is a Series
    df = df.loc[:, ~df.columns.duplicated()]

    # covariates of the adjusted models, and the patients for which all are known
    col_adj = ['is female', 'avg. CHARLSON', 'avg. BMI']
    all_not_na = df[col_adj].notna().all(axis=1)

    T, E = df['time'], df['event']
    med_names = sorted(med_names)  # sorted copy: the caller's list is left untouched

    # Figure: 3 covariate panels + one per medication, on 2 columns
    n_cols = 2
    n_panels = len(med_names) + 3
    n_rows = (n_panels + n_cols - 1) // n_cols  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows))
    axes = axes.ravel()

    ### Models
    p_values, hazard_ratios = {}, {}
    death_p_values, death_hazard_ratios = {}, {}
    adjusted_p_values, adjusted_hazard_ratios = {}, {}
    death_adjusted_p_values, death_adjusted_hazard_ratios = {}, {}
    unadjusted = (p_values, hazard_ratios, death_p_values, death_hazard_ratios)
    adjusted = (adjusted_p_values, adjusted_hazard_ratios, death_adjusted_p_values, death_adjusted_hazard_ratios)

    ## Sex
    is_female = df['is female'].astype(bool)
    plot_CIF(alpha=alpha_corrected,
             TA=T[is_female], EA=E[is_female], labelA="Female",
             TB=T[~is_female], EB=E[~is_female], labelB="Male",
             ax=axes[0], title="gender")
    record('is female', fit_cox(df.copy(), alpha=alpha_corrected, col='is female', adj_cols=[]), unadjusted)

    ## CHARLSON: CIF per category, Cox on the continuous score (patients with a known CCI)
    df_charlson = df.loc[df['avg. CHARLSON'].notna()]
    plot_levels(df_charlson, ['low CCI', 'medium CCI', 'high CCI'], axes[1], "CCI")
    record('avg. CHARLSON',
           fit_cox(df_charlson.copy(), alpha=alpha_corrected, col='avg. CHARLSON', adj_cols=[]),
           unadjusted)

    ## BMI: CIF per category, Cox on the continuous value (patients with a known BMI)
    df_bmi = df.loc[df['avg. BMI'].notna()]
    plot_levels(df_bmi, ['underweight', 'normal', 'overweight', 'obese', 'extremely obese'], axes[2], "BMI")
    record('avg. BMI', fit_cox(df_bmi.copy(), alpha=alpha_corrected, col='avg. BMI', adj_cols=[]), unadjusted)

    ## Medications
    for med_name, ax in zip(med_names, axes[3:]):
        took = df[med_name].astype(bool)
        plot_CIF(alpha=alpha_corrected,
                 TA=T[took], EA=E[took], labelA="Took "+med_name,
                 TB=T[~took], EB=E[~took], labelB="Didn't take "+med_name,
                 ax=ax, title=med_name)

        # unadjusted
        record(med_name, fit_cox(df.copy(), alpha=alpha_corrected, col=med_name, adj_cols=[]), unadjusted)
        # adjusted for sex, BMI and CHARLSON
        record(med_name,
               fit_cox(df.loc[all_not_na].copy(), alpha=alpha_corrected, col=med_name, adj_cols=col_adj),
               adjusted)

    ## Hide unused panels
    for ax in axes[n_panels:]:
        ax.set_axis_off()

    ## Title
    plt.subplots_adjust(top=0.9)
    suptitle = "Comparative Cumulative Incidence"
    if title:
        suptitle += " for " + title
    fig.suptitle(suptitle, fontsize=16, y=1.005)

    plt.tight_layout()
    plt.show()

    return (p_values, hazard_ratios,
            death_p_values, death_hazard_ratios,
            adjusted_p_values, adjusted_hazard_ratios,
            death_adjusted_p_values, death_adjusted_hazard_ratios)

In [None]:
%%time
import warnings
# lifelines warns on every fit when event times are tied; expected with day-resolution times
warnings.filterwarnings("ignore", message="Tied event times were detected.")

rename = lambda name: ' '.join(name.split('_')).capitalize()

# lists (not sets): deterministic iteration and dict-key order
names = [f'{country}_{age}' for age in [65, 70] for country in ['FR', 'UK']]

def _empty_store():
    # {disease: {cohort name: {}}} container filled by the loop below
    return {disease: {name: {} for name in names} for disease in DISEASES_OF_INTEREST}

## Unadjusted
p_values, hazard_ratios = _empty_store(), _empty_store()                               # event
death_p_values, death_hazard_ratios = _empty_store(), _empty_store()                   # death
## Adjusted
adjusted_p_values, adjusted_hazard_ratios = _empty_store(), _empty_store()             # event
death_adjusted_p_values, death_adjusted_hazard_ratios = _empty_store(), _empty_store() # death

# same order as the 8-tuple returned by get_cox_analysis
result_stores = (p_values, hazard_ratios,
                 death_p_values, death_hazard_ratios,
                 adjusted_p_values, adjusted_hazard_ratios,
                 death_adjusted_p_values, death_adjusted_hazard_ratios)

for age in [65, 70]:
    for country in ['FR', 'UK']:
        cohort = f'{country}_{age}'
        for disease in DISEASES_OF_INTEREST:
            label = "MCI" if disease == "mci" else rename(disease)
            results = get_cox_analysis(preprocessed_df[disease][cohort].copy(),
                                       med_names=MED_SUPPORT,
                                       alpha=0.05,
                                       correction=len(MED_SUPPORT),
                                       title=f'{cohort} ({label})')
            for store, result in zip(result_stores, results):
                store[disease][cohort] = result

In [None]:
# Raw 8-tuple returned by the last get_cox_analysis call of the loop above
# (i.e. the last cohort / disease processed)
results

Interpret the following FR_65 results with caution:
- Alzheimer: J05, M04
- All dementias: J05

## Results

### Construct

In [None]:
def create_p_values_table(p_values, adjusted:bool=False, death:bool=False):
    """One p-value column per cohort ('FR_65', 'UK_65', 'FR_70', 'UK_70'); rows are the covariates.

    Column names read '<cohort>:[death ][adjusted ]p-value' (split on ':' by fusion_tables).
    """
    qualifier = ('death ' if death else '') + ('adjusted ' if adjusted else '')
    columns = []
    for age in [65, 70]:
        for country in ['FR', 'UK']:
            cohort = f'{country}_{age}'
            per_covariate = p_values[cohort]
            columns.append(pd.DataFrame({f"{cohort}:{qualifier}p-value": list(per_covariate.values())},
                                        index=list(per_covariate.keys())))
    return pd.concat(columns, axis=1)

def create_hazard_table(hazard_ratios, adjusted:bool=False, death:bool=False):
    """One 'HR (lower - upper)' text column per cohort; rows are the covariates.

    Column names read '<cohort>:[death ][adjusted ]hazard ratios (95% CI)'.
    NOTE(review): '95% CI' is fixed label text; the intervals themselves are computed at
    the level used by fit_cox (1 - alpha/correction when called from get_cox_analysis).
    """
    qualifier = ('death ' if death else '') + ('adjusted ' if adjusted else '')
    columns = []
    for age in [65, 70]:
        for country in ['FR', 'UK']:
            cohort = f'{country}_{age}'
            per_covariate = hazard_ratios[cohort]
            formatted = [f"{hr:.2f} ({low:.2f} - {high:.2f})" for low, hr, high in per_covariate.values()]
            columns.append(pd.DataFrame({f"{cohort}:{qualifier}hazard ratios (95% CI)": formatted},
                                        index=list(per_covariate.keys())))
    return pd.concat(columns, axis=1)

In [None]:
def fusion_tables(values:List[pd.DataFrame], med_names:List[str]=None):
    """
    Merge per-cohort result tables into one table with (cohort, statistic) MultiIndex columns.

    Parameters
    ----------
    values : list of DataFrame
        Tables from create_p_values_table / create_hazard_table, whose column names
        have the form '<cohort>:<statistic>'.
    med_names : list of str, optional
        Medication rows to keep (shown sorted, after the three baseline covariates).
        Defaults to the notebook-level MED_SUPPORT.

    Returns
    -------
    DataFrame with rows ['is female', 'avg. CHARLSON', 'avg. BMI', *sorted medications]
    and columns ordered cohort -> (unadjusted, adjusted) -> (disease, death) -> (HR, p-value).
    """
    if med_names is None:
        med_names = MED_SUPPORT

    df = pd.concat(values, axis=1)

    # '<cohort>:<statistic>' -> ('<cohort>', '<statistic>'); maxsplit keeps exactly two levels
    df.columns = pd.MultiIndex.from_tuples([tuple(name.split(':', 1)) for name in df.columns])

    column_order = [
        (f'{country}_{age}', f'{state}{correction}{name}')
        for age in [65, 70] for country in ['FR', 'UK']
        for correction in ['', 'adjusted ']
        for state in ['', 'death ']
        for name in ['hazard ratios (95% CI)', 'p-value']
    ]
    index_order = ['is female', 'avg. CHARLSON', 'avg. BMI'] + sorted(med_names)

    return df.loc[index_order, column_order]

In [None]:
# One fused table per disease. For each (adjusted, death) variant, the p-value table is
# followed by the hazard-ratio table; fusion_tables then reorders the columns.
res = {}
for disease in DISEASES_OF_INTEREST:
    variants = [
        # (p-value store, hazard-ratio store, adjusted, death)
        (p_values, hazard_ratios, False, False),
        (adjusted_p_values, adjusted_hazard_ratios, True, False),
        (death_p_values, death_hazard_ratios, False, True),
        (death_adjusted_p_values, death_adjusted_hazard_ratios, True, True),
    ]
    tables = []
    for p_store, hr_store, is_adjusted, is_death in variants:
        tables.append(create_p_values_table(p_store[disease], adjusted=is_adjusted, death=is_death))
        tables.append(create_hazard_table(hr_store[disease], adjusted=is_adjusted, death=is_death))
    res[disease] = fusion_tables(tables)

In [None]:
# Persist the fused result tables so the display section can reload them without refitting
output_file = os.path.join(OUTPUT_DIR, 'res.pkl')
with open(output_file, mode='wb') as handle:
    pickle.dump(res, handle)

### Display

In [None]:
# Reload the saved result tables (trusted local file only: pickle can run arbitrary code)
output_file = os.path.join(OUTPUT_DIR, 'res.pkl')
with open(output_file, mode='rb') as handle:
    res = pickle.load(handle)

In [None]:
threshold = 5e-2  # significance level used to highlight p-value cells


def highlight_significative_cells(row):
    """Light-green background for float cells below `threshold`; other cells (e.g. HR text) unstyled."""
    styles = []
    for value in row:
        is_significant = isinstance(value, float) and value < threshold
        styles.append('background-color: lightgreen' if is_significant else '')
    return styles


def highlight_meds_rows(row):
    """Yellow background on the whole row for the medications of particular interest."""
    style = 'background-color: yellow' if row.name in {'A06', 'A10', 'C09', 'C10', 'N06'} else ''
    return [style] * len(row)

In [None]:
# (cohort, p-value column) pairs of the fused result table, shown in scientific notation
_cohorts = [f'{country}_{age}' for age in [65, 70] for country in ['FR', 'UK']]
_p_value_columns = [f'{state}{correction}p-value'
                    for correction in ['', 'adjusted ']
                    for state in ['', 'death ']]
subset = [(cohort, column) for cohort in _cohorts for column in _p_value_columns]

In [None]:
def display_with_format(df):
    """Render `df` with highlighted medication rows, green significant p-values and p-values in 2-digit scientific notation."""
    styled = (
        df.style
        .apply(highlight_meds_rows, axis=1)
        .apply(highlight_significative_cells, axis=1)
        .format(formatter='{:.2e}', subset=subset)
    )
    display(styled)

In [None]:
display_with_format(res['alzheimer'])
# os.path.join instead of string concatenation: portable path construction
res['alzheimer'].to_csv(os.path.join(OUTPUT_DIR, "results_alzheimer.csv"))

In [None]:
display_with_format(res['all_dementias'])
# os.path.join instead of string concatenation: portable path construction
res['all_dementias'].to_csv(os.path.join(OUTPUT_DIR, "results_all_dementias.csv"))

# Nicer plots

In [None]:
# `get_better_ax` from the Cox-analysis cell is reused unchanged (its duplicate copy was removed).

def plot_CIF(alpha, TA, EA, labelA, TB, EB, labelB, ax, title):
    """Overlay the Aalen-Johansen CIFs (event of interest = 1) of groups A and B on `ax`.

    Same as the Cox-analysis version, but `title` is used verbatim (no CIF_TITLE prefix).
    NOTE: this redefinition shadows the earlier plot_CIF, so any get_cox_analysis call made
    after this cell has run will also use the un-prefixed titles.
    """
    ajd = AalenJohansenFitter(alpha=alpha)
    for T, E, label in [(TA, EA, labelA), (TB, EB, labelB)]:
        ajd.fit(T, E, event_of_interest=1, label=label)
        ajd.plot(ax=ax)
    get_better_ax(ax, title)

In [None]:
%%time
# Bonferroni-corrected alpha for the CIF confidence bands.
# NOTE(review): 43 is hard-coded; presumably the number of tests (e.g. len(MED_SUPPORT)) - confirm
alpha_corrected = 0.05/43
med_to_show = ['A06', 'G04', 'N06', 'C09', 'C10']
cohort = 'UK_70'

n_cols = len(med_to_show); n_rows = len(DISEASES_OF_INTEREST)
# squeeze=False keeps `axes` 2-D, so axes[i, j] also works with a single disease or medication
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), sharey='row', squeeze=False)

for j, med_name in enumerate(med_to_show):
    for i, disease in enumerate(DISEASES_OF_INTEREST):
        cohort_df = preprocessed_df[disease][cohort]
        T, E = cohort_df['time'], cohort_df['event']
        took = cohort_df[med_name].astype(bool)

        plot_CIF(alpha=alpha_corrected,
                 TA=T[took], EA=E[took], labelA="Took "+med_name,
                 TB=T[~took], EB=E[~took], labelB="Didn't take "+med_name,
                 ax=axes[i, j], title=disease+' '+med_name)
plt.tight_layout()
plt.show()