# Autoreload setup for debugging

In [1]:
# Reload edited project modules (e.g. `src.*`) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Load packages

In [20]:
import pickle
from sklearn.model_selection import RepeatedStratifiedKFold
import numpy as np
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
import pandas as pd
from src.utils.configs import read_parse_config
from src.pt.hyper_opt import train_hyper_opt
import optuna
import pathlib
import os
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
import distinctipy
import matplotlib.colors as mcolors
import matplotlib.patheffects as pe
from plottable import ColumnDefinition, Table
import yaml

# Load data

In [7]:
# Global seed shared by data splitting, model training and hyperparameter search.
seed = 42

# Every path below is anchored at the notebook's working directory.
path_root = pathlib.Path.cwd()
path_plots, path_data = (
    f"{path_root}/plots",
    f"{path_root}/data/immuno-regression",
)

# Feature table: one column per immunomarker, each column listing that marker's input features.
df_feats = pd.read_excel(f"{path_data}/features.xlsx")
imms = df_feats.columns.to_list()

# Main dataset with immunomarker targets, their features, Status and Age.
df = pd.read_excel(f"{path_data}/data.xlsx")

# Generate stratification

In [None]:
# Outer (test) and inner (validation) repeated stratified K-fold settings.
tst_n_splits, tst_n_repeats, tst_random_state = 5, 5, seed
val_n_splits, val_n_repeats, val_random_state = 4, 4, seed

# Row indices of each Status group; splits are generated separately per group
# so each fold draws samples from both groups.
stratify_cat_parts_all = {
    status: df.index[df['Status'] == status].values
    for status in ('Control', 'ESRD')
}

# Assign every Control/ESRD row to "trn_val" or "tst" for each outer split.
# Within each Status group, Age is discretised into `num_bins_all` equal-width bins
# (range padded by 10% on both sides) and the bin labels drive the stratification.
num_bins_all = 5

# With an integer random_state, every .split() call re-seeds its own RNG, so one
# instance reused for both Status groups yields the same splits as a fresh one.
k_fold_all = RepeatedStratifiedKFold(
    n_splits=tst_n_splits,
    n_repeats=tst_n_repeats,
    random_state=tst_random_state
)

for ids_all in stratify_cat_parts_all.values():
    trgt_all = df.loc[ids_all, 'Age'].values
    ptp_all = np.ptp(trgt_all)
    bins_all = np.linspace(np.min(trgt_all) - 0.1 * ptp_all, np.max(trgt_all) + 0.1 * ptp_all, num_bins_all + 1)
    binned_all = np.digitize(trgt_all, bins_all) - 1

    # Stratified splitters ignore `groups`, so only the bin labels are passed.
    splits_all = k_fold_all.split(X=ids_all, y=binned_all)

    for split_id, (ids_trn_val, ids_tst) in enumerate(splits_all):
        df.loc[ids_all[ids_trn_val], f"Split_{split_id}"] = "trn_val"
        df.loc[ids_all[ids_tst], f"Split_{split_id}"] = "tst"

# For every outer split, build inner train/validation folds on its trn_val rows
# (stratified per Status group by binned Age, like the outer splits) and collect:
#   'test' / 'train_validation' : df index labels of the outer split,
#   'trains' / 'validations'    : df index labels per inner fold,
#   'cv_indexes'                : per-fold (train, validation) POSITIONS within 'train_validation'.
num_bins_trnval = 5

# Integer random_state => each .split() call re-seeds, so reusing one instance is
# equivalent to constructing a new one per split / Status group.
k_fold_trnval = RepeatedStratifiedKFold(
    n_splits=val_n_splits,
    n_repeats=val_n_repeats,
    random_state=val_random_state
)
n_folds = val_n_splits * val_n_repeats

samples = {}
for split_id in range(tst_n_splits * tst_n_repeats):
    split_col = f"Split_{split_id}"
    is_trnval = df[split_col] == "trn_val"
    samples[split_id] = {
        'test': df.index[df[split_col] == "tst"].values,
        'train_validation': df.index[is_trnval].values,
        'trains': {},
        'validations': {},
    }

    stratify_cat_parts_trnval = {
        status: df.index[(df['Status'] == status) & is_trnval].values
        for status in ('Control', 'ESRD')
    }

    for ids_trnval in stratify_cat_parts_trnval.values():
        trgt_trnval = df.loc[ids_trnval, 'Age'].values
        ptp_trnval = np.ptp(trgt_trnval)
        bins_trnval = np.linspace(np.min(trgt_trnval) - 0.1 * ptp_trnval, np.max(trgt_trnval) + 0.1 * ptp_trnval, num_bins_trnval + 1)
        binned_trnval = np.digitize(trgt_trnval, bins_trnval) - 1

        # Stratified splitters ignore `groups`, so only the bin labels are passed.
        splits_trnval = k_fold_trnval.split(X=ids_trnval, y=binned_trnval)

        for fold_id, (ids_trn, ids_val) in enumerate(splits_trnval):
            fold_col = f"{split_col}_Fold_{fold_id}"
            df.loc[ids_trnval[ids_trn], fold_col] = "trn"
            df.loc[ids_trnval[ids_val], fold_col] = "val"

    for fold_id in range(n_folds):
        fold_col = f"{split_col}_Fold_{fold_id}"
        samples[split_id]['trains'][fold_id] = df.index[df[fold_col] == "trn"].values
        samples[split_id]['validations'][fold_id] = df.index[df[fold_col] == "val"].values

    # Positions (not labels) inside 'train_validation', reusing the fold arrays above
    # instead of re-masking the whole DataFrame for every fold.
    ids_trnval_all = samples[split_id]['train_validation']
    samples[split_id]['cv_indexes'] = [
        (
            np.flatnonzero(np.isin(ids_trnval_all, samples[split_id]['trains'][i])),
            np.flatnonzero(np.isin(ids_trnval_all, samples[split_id]['validations'][i])),
        )
        for i in range(n_folds)
    ]
    
# Sanity check for every outer split and inner fold:
#   * train, validation and test index sets are pairwise disjoint;
#   * train and validation together cover exactly the split's train_validation rows
#     (also catches stale Split_*_Fold_* columns left over from an earlier run).
for split_id in range(tst_n_splits * tst_n_repeats):
    test_set = set(samples[split_id]['test'])
    trnval_set = set(samples[split_id]['train_validation'])
    for fold_id in range(val_n_splits * val_n_repeats):
        train_set = set(samples[split_id]['trains'][fold_id])
        validation_set = set(samples[split_id]['validations'][fold_id])

        intxns = {
            'train_validation': train_set & validation_set,
            'validation_test': validation_set & test_set,
            'train_test': train_set & test_set,
        }

        for intxn_name, intxn_samples in intxns.items():
            if len(intxn_samples) > 0:
                raise ValueError(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

        if train_set | validation_set != trnval_set:
            raise ValueError(f"Train and validation do not cover train_validation for {split_id} Split and {fold_id} Fold!")

# Persist the stratification so later sessions can reload it instead of recomputing.
pathlib.Path(f"{path_data}/stratification.pickle").write_bytes(
    pickle.dumps(samples, protocol=pickle.HIGHEST_PROTOCOL)
)

## Load stratification

In [3]:
# Reload the stratification written by the cell above. Unpickling can execute
# arbitrary code, so only load this file from a trusted source.
samples = pickle.loads(pathlib.Path(f"{path_data}/stratification.pickle").read_bytes())

# Load PyTorch Tabular configs

## Load Data, Trainer, Optimizer configs

In [4]:
path_configs = f"{path_root}/configs/immuno-regression"

data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)

# Notebook-level overrides applied on top of the trainer settings loaded from YAML.
trainer_overrides = {
    'seed': seed,
    'checkpoints': 'valid_loss',
    'load_best': True,
    'auto_lr_find': False,
}
for key, value in trainer_overrides.items():
    trainer_config[key] = value

optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

# Learning-rate finder settings, forwarded to train_hyper_opt.
lr_find_min_lr, lr_find_max_lr = 1e-8, 10
lr_find_num_training = 512
lr_find_mode = "exponential"
lr_find_early_stop_threshold = 8.0

## Load default model configs

In [5]:
# Model architectures to train, each mapped to its PyTorch Tabular config class.
# The default config for architecture `X` is read from `models/XConfig.yaml`.
arch_config_classes = {
    'DANet': DANetConfig,
    'FTTransformer': FTTransformerConfig,
    'GANDALF': GANDALFConfig,
    'TabNetModel': TabNetModelConfig,
    'CategoryEmbeddingModel': CategoryEmbeddingModelConfig,
}
models_archs = list(arch_config_classes)
configs_models_default = {
    arch: read_parse_config(f"{path_configs}/models/{arch}Config.yaml", config_cls)
    for arch, config_cls in arch_config_classes.items()
}

# Train immunomarker models

## Optuna params

In [6]:
# TPE sampler settings
n_trials = 512
opt_seed = seed
n_startup_trials = 128
n_ei_candidates = 16

# Optimization objectives: one Optuna direction per (part, metric) pair, ordered
# part-major (presumably matching the order train_hyper_opt returns its values — confirm).
# NOTE(review): 'test' is among the optimized parts, so test-set metrics steer
# hyperparameter selection (test-set leakage) — confirm this is intended.
opt_parts = ['test', 'validation']
opt_metrics = [('mean_absolute_error', 'minimize'), ('pearson_corrcoef', 'maximize')]
opt_directions = [
    direction
    for _opt_part in opt_parts
    for _metric_name, direction in opt_metrics
]

## Cross-validation and Hyperparameter optimization training

In [None]:
# For every immunomarker, run an Optuna hyperparameter search for each model
# architecture on every (outer split, inner fold) pair and save all trials to Excel.
for imm in imms:

    feats = df_feats[imm].to_list()
    cols = feats + [imm]
    data_config['target'] = [imm]
    data_config['continuous_cols'] = feats

    path_ckpts = f"{path_root}/logs/Immunomarkers/{imm}"
    pathlib.Path(path_ckpts).mkdir(parents=True, exist_ok=True)
    trainer_config['checkpoints_path'] = path_ckpts

    # One trials DataFrame per (split, fold, architecture) study.
    dfs_results = []

    # Loop for train-validation/test splits
    for split_id, split_dict in samples.items():
        # The test set depends only on the outer split, so slice it once per split.
        test = df.loc[split_dict['test'], cols]

        # Loop for train/validation folds
        for fold_id in split_dict['trains']:
            train = df.loc[split_dict['trains'][fold_id], cols]
            validation = df.loc[split_dict['validations'][fold_id], cols]

            # Loop for models archs
            for m_arch in models_archs:

                tabular_model_default = TabularModel(
                    data_config=data_config,
                    model_config=configs_models_default[m_arch],
                    optimizer_config=optimizer_config,
                    trainer_config=trainer_config,
                    verbose=False,
                )
                datamodule = tabular_model_default.prepare_dataloader(train=train, validation=validation, seed=seed)

                # Filled by train_hyper_opt with one record per trial.
                trials_results = []
                study = optuna.create_study(
                    study_name=f"{imm}_{split_id}_{fold_id}_{m_arch}",
                    sampler=optuna.samplers.TPESampler(
                        n_startup_trials=n_startup_trials,
                        n_ei_candidates=n_ei_candidates,
                        seed=opt_seed,
                    ),
                    directions=opt_directions
                )
                # study.optimize returns before the loop advances, so the lambda's
                # late-bound references to m_arch/train/validation/test are safe.
                study.optimize(
                    func=lambda trial: train_hyper_opt(
                        trial=trial,
                        trials_results=trials_results,
                        opt_metrics=opt_metrics,
                        opt_parts=opt_parts,
                        model_config_default=configs_models_default[m_arch],
                        data_config_default=data_config,
                        optimizer_config_default=optimizer_config,
                        trainer_config_default=trainer_config,
                        experiment_config_default=None,
                        train=train,
                        validation=validation,
                        test=test,
                        datamodule=datamodule,
                        min_lr=lr_find_min_lr,
                        max_lr=lr_find_max_lr,
                        num_training=lr_find_num_training,
                        mode=lr_find_mode,
                        early_stop_threshold=lr_find_early_stop_threshold
                    ),
                    n_trials=n_trials,
                    show_progress_bar=False
                )
                # NOTE(review): rows are tagged with split/fold only; presumably
                # train_hyper_opt records the architecture itself — confirm.
                df_trials = pd.DataFrame(trials_results)
                df_trials['split_id'] = split_id
                df_trials['fold_id'] = fold_id
                dfs_results.append(df_trials)

    # Resulting DataFrame for the immunomarker. ignore_index gives each row a unique
    # label instead of repeating 0..n-1 for every study in the exported sheet.
    df_results = pd.concat(dfs_results, ignore_index=True)
    df_results = df_results.sort_values(by=['test_loss'], ascending=[True])
    df_results.to_excel(f"{path_ckpts}/results.xlsx")

# Plotting

## Correlation heatmap

In [9]:
# For every immunomarker: absolute Pearson correlation between the marker and each
# of its selected CpG features, sorted descending (column k = k-th strongest CpG).
# The number of columns follows the feature table instead of a hard-coded 100.
n_top_cpgs = df_feats.shape[0]
df_epi_imm_corr = pd.DataFrame(index=imms, columns=list(range(1, n_top_cpgs + 1)))
for imm in imms:
    corrs = [
        abs(stats.pearsonr(df[imm], df[cpg], alternative='two-sided').statistic)
        for cpg in df_feats[imm]
    ]
    df_epi_imm_corr.loc[imm, :] = sorted(corrs, reverse=True)

df_fig = df_epi_imm_corr.astype(float)
sns.set_theme(style='ticks', font_scale=1.0)
fig, ax = plt.subplots(figsize=(30, 10))
sns.heatmap(
    df_fig,
    annot=False,
    cmap='hot',
    linewidth=0.1,
    linecolor='black',
    cbar_kws={
        'orientation': 'horizontal',
        'location': 'top',
        'fraction': 0.05,
        'pad': 0.025,
        'aspect': 60
    },
    ax=ax
)
ax.set_xlabel('Top CpGs')
ax.set_ylabel('')

# seaborn appends the colorbar as the figure's last axes.
cbar_ax = fig.axes[-1]
cbar_ax.set_title(r"|Pearson $\rho$|", fontsize='large')
cbar_ax.tick_params(labelsize='large')
for spine in cbar_ax.spines.values():
    spine.set_linewidth(1)

pathlib.Path(path_plots).mkdir(parents=True, exist_ok=True)
fig.savefig(f"{path_plots}/imms_cpgs_correlation.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_plots}/imms_cpgs_correlation.pdf", bbox_inches='tight')
plt.close(fig)

## Model results

In [None]:
# Grid layout: immunomarkers are laid out `n_cols` per group, and every group uses
# three rows (metrics table, prediction density/scatter, spacer). The number of
# groups follows len(imms) instead of a hard-coded 4 (which capped it at 32 markers).
n_cols = 8
n_groups = max(1, (len(imms) + n_cols - 1) // n_cols)
n_rows = n_groups * 3
fig_height = 5 * n_groups
fig_width = 35

imm_colors = distinctipy.get_colors(n_colors=len(imms), exclude_colors=[mcolors.hex2color(mcolors.CSS4_COLORS['gray'])], rng=seed)

sns.set_theme(style='ticks')
fig, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(fig_width, fig_height),
    height_ratios=[0.2, 0.8, 0.2] * n_groups,
    gridspec_kw={'wspace': 0.35, 'hspace': 0.05},
    sharey=False,
    sharex=False,
)

# Table row label -> column in metrics.xlsx; table columns are the data parts.
metric_rows = {
    'MAE': 'mean_absolute_error',
    fr"Pearson $\mathbf{{\rho}}$": 'pearson_corrcoef',
}
metric_parts = ['Train', 'Validation', 'Test']
path_models = f"{path_root}/models/Immunomarkers"

for imm_id, imm in enumerate(imms):
    imm_color = imm_colors[imm_id]
    path_imm = f"{path_models}/{imm}"
    imm_metrics = pd.read_excel(f"{path_imm}/metrics.xlsx", index_col=0)
    with open(f"{path_imm}/config.yml") as f:
        imm_config = yaml.safe_load(f)
    # The target column may be stored as "<imm>_log"; rename it to "<imm>" for plotting.
    imm_df = pd.read_excel(f"{path_imm}/df.xlsx", index_col=0).rename(columns={f"{imm}_log": imm})

    row_id, col_id = divmod(imm_id, n_cols)
    row_id_table = row_id * 3
    row_id_scatter = row_id_table + 1
    row_id_empty = row_id_table + 2

    # NOTE(review): axis limits come from df[imm], while plotted values come from imm_df
    # (possibly a log-transformed column, see rename above) — confirm the scales match.
    q01 = df[imm].quantile(0.01)
    q99 = df[imm].quantile(0.99)

    df_metrics = pd.DataFrame(index=list(metric_rows), columns=metric_parts)
    for row_name, metric_col in metric_rows.items():
        for part_name in metric_parts:
            df_metrics.at[row_name, part_name] = f"{imm_metrics.at[part_name, metric_col]:0.3f}"

    group_title = fr"$\mathbf{{{imm}}}$"
    col_defs = [
        ColumnDefinition(
            name="index",
            title=imm_config['_model_name'].replace('Model', ''),
            textprops={"ha": "center", "weight": "bold"},
            width=2.5,
            group=group_title,
        ),
        ColumnDefinition(
            name="Train",
            textprops={"ha": "left"},
            width=1.5,
            border="left",
            group=group_title,
        ),
        ColumnDefinition(
            name="Validation",
            textprops={"ha": "left"},
            width=1.5,
            group=group_title,
        ),
        ColumnDefinition(
            name="Test",
            textprops={"ha": "left"},
            width=1.5,
            group=group_title,
        ),
    ]

    Table(
        df_metrics,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=axs[row_id_table, col_id],
        textprops={"fontsize": 8},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=metric_parts)

    # Non-test samples as a density, test samples as individual points.
    scatter_ax = axs[row_id_scatter, col_id]
    sns.kdeplot(
        data=imm_df.loc[imm_df['Group'] != 'Test', :],
        x=imm,
        y='Prediction',
        fill=True,
        cbar=False,
        color=imm_color,
        cut=0,
        legend=False,
        ax=scatter_ax
    )
    sns.scatterplot(
        data=imm_df.loc[imm_df['Group'] == 'Test', :],
        x=imm,
        y="Prediction",
        linewidth=0.5,
        alpha=0.8,
        edgecolor="k",
        s=35,
        color=imm_color,
        ax=scatter_ax,
    )
    scatter_ax.axline((0, 0), slope=1, color="black", linestyle=":")
    scatter_ax.set_xlim(q01, q99)
    scatter_ax.set_ylim(q01, q99)
    scatter_ax.set_xlabel(imm, color=imm_color, path_effects=[pe.withStroke(linewidth=1.0, foreground="black")])

    axs[row_id_empty, col_id].axis('off')

# Hide every axes of grid cells left unused in the last group.
for unused_id in range(len(imms), n_groups * n_cols):
    row_id, col_id = divmod(unused_id, n_cols)
    for row_offset in range(3):
        axs[row_id * 3 + row_offset, col_id].axis('off')

fig.tight_layout()
pathlib.Path(path_plots).mkdir(parents=True, exist_ok=True)
fig.savefig(f"{path_plots}/imms_models.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_plots}/imms_models.pdf", bbox_inches='tight')
plt.close(fig)