# Enable autoreload (for debugging)

In [None]:
# Load IPython's autoreload extension; mode 2 reloads imported modules before each cell runs,
# presumably so edits to the project helpers (src.utils.configs, src.pt.hyper_opt) take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
import pickle
from sklearn.model_selection import RepeatedStratifiedKFold
import numpy as np
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
import pandas as pd
from src.utils.configs import read_parse_config
from src.pt.hyper_opt import train_hyper_opt
import pathlib
from tqdm import tqdm
import optuna
import os

# Load data

In [None]:
# Global seed reused by the CV splitters, the trainer and the Optuna sampler.
seed = 42

# All paths are relative to the notebook's working directory.
path_root = pathlib.Path(os.getcwd())
path_plots = f"{path_root}/plots"
path_data = f"{path_root}/data/age-regression"

# Feature table: its index lists the immunomarker columns used as model inputs.
df_feats = pd.read_excel(f"{path_data}/features.xlsx", index_col=0)
imms = df_feats.index.to_list()

# Full dataset; only 'Control' samples are used for age regression.
df = pd.read_excel(f"{path_data}/data.xlsx")
# .copy() makes df_controls an independent frame. Later cells add Split_* columns to it,
# which on a boolean-filtered slice of `df` would raise SettingWithCopyWarning.
df_controls = df[df['Status'] == 'Control'].copy()

# Generate stratification

In [None]:
# Outer level: repeated stratified K-fold -> (tst_n_splits * tst_n_repeats) train_validation/test splits.
tst_n_splits = 5
tst_n_repeats = 5
tst_random_state = seed

# Inner level: repeated stratified K-fold of each train_validation part
# -> (val_n_splits * val_n_repeats) train/validation folds per split.
val_n_splits = 4
val_n_repeats = 4
val_random_state = seed

# Quantile edges used to bin 'Age' into stratification classes (up to 5 bins; duplicate edges dropped).
quantiles = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

n_tst_splits_total = tst_n_splits * tst_n_repeats
n_val_folds_total = val_n_splits * val_n_repeats

# Each GSE is split on its own, so every split/fold keeps each GSE's age distribution.
# NOTE(review): the bookkeeping below assumes df_controls.index is unique (true for the
# default RangeIndex produced by pd.read_excel without index_col).
gse_count = df_controls['GSE'].value_counts()
stratify_cat_parts_all = {
    gse: df_controls.index[df_controls['GSE'] == gse].values
    for gse in gse_count.index
}

# --- Outer split: mark each sample "trn_val" or "tst" in column Split_<split_id> ---
for part_all, ids_all in (pbar := tqdm(stratify_cat_parts_all.items())):
    pbar.set_description(f"Processing {part_all} ({len(ids_all)})")
    quantiles_all = pd.qcut(df_controls.loc[ids_all, 'Age'].values, quantiles, labels=False, duplicates='drop')
    _, counts_all = np.unique(quantiles_all, return_counts=True)

    # NOTE(review): scikit-learn's StratifiedKFold only fails when *every* class has fewer
    # than n_splits members; this threshold uses len(quantiles) (6) rather than
    # tst_n_splits (5), which is stricter. Confirm which one is intended.
    if max(counts_all) >= len(quantiles):
        k_fold_all = RepeatedStratifiedKFold(
            n_splits=tst_n_splits,
            n_repeats=tst_n_repeats,
            random_state=tst_random_state,
        )
        # `groups` is ignored by RepeatedStratifiedKFold.split, so it is not passed.
        for split_id, (ids_trn_val, ids_tst) in enumerate(k_fold_all.split(X=ids_all, y=quantiles_all)):
            # ids_trn_val / ids_tst are positions into ids_all, which holds index labels.
            df_controls.loc[ids_all[ids_trn_val], f"Split_{split_id}"] = "trn_val"
            df_controls.loc[ids_all[ids_tst], f"Split_{split_id}"] = "tst"
    else:
        # Too small to stratify: the whole GSE goes into the test part of every split.
        for split_id in range(n_tst_splits_total):
            df_controls.loc[ids_all, f"Split_{split_id}"] = "tst"

# Fold columns start as copies of their split column. Test samples keep "tst", and the
# "trn_val" entries are overwritten with "trn"/"val" below. Adding all of them with one
# concat avoids pandas' "highly fragmented DataFrame" PerformanceWarning from 400 inserts.
df_controls = pd.concat(
    [
        df_controls,
        pd.DataFrame({
            f"Split_{split_id}_Fold_{fold_id}": df_controls[f"Split_{split_id}"]
            for split_id in range(n_tst_splits_total)
            for fold_id in range(n_val_folds_total)
        }),
    ],
    axis=1,
)

samples = {}
for split_id in range(n_tst_splits_total):
    split_col = f"Split_{split_id}"
    is_trn_val = df_controls[split_col] == "trn_val"
    samples[split_id] = {
        'test': df_controls.index[df_controls[split_col] == "tst"].values,
        'train_validation': df_controls.index[is_trn_val].values,
        'trains': {},
        'validations': {},
    }

    # --- Inner split: per GSE, mark this split's trn_val samples "trn"/"val" for every fold ---
    for gse in gse_count.index:
        ids_trnval = df_controls.index[(df_controls['GSE'] == gse) & is_trn_val].values
        if len(ids_trnval) == 0:
            continue
        quantiles_trnval = pd.qcut(df_controls.loc[ids_trnval, 'Age'].values, quantiles, labels=False, duplicates='drop')
        # NOTE(review): unlike the outer split there is no small-group guard here;
        # RepeatedStratifiedKFold raises ValueError if every age bin of this GSE has
        # fewer than val_n_splits samples.
        k_fold_trnval = RepeatedStratifiedKFold(
            n_splits=val_n_splits,
            n_repeats=val_n_repeats,
            random_state=val_random_state,
        )
        for fold_id, (ids_trn, ids_val) in enumerate(k_fold_trnval.split(X=ids_trnval, y=quantiles_trnval)):
            fold_col = f"{split_col}_Fold_{fold_id}"
            df_controls.loc[ids_trnval[ids_trn], fold_col] = "trn"
            df_controls.loc[ids_trnval[ids_val], fold_col] = "val"

    cv_indexes = []
    for fold_id in range(n_val_folds_total):
        fold_labels = df_controls[f"{split_col}_Fold_{fold_id}"]
        samples[split_id]['trains'][fold_id] = df_controls.index[fold_labels == "trn"].values
        samples[split_id]['validations'][fold_id] = df_controls.index[fold_labels == "val"].values
        # Positions of this fold's trn / val samples within the split's train_validation
        # subset (presumably consumed as a `cv=`-style iterable; confirm with the consumers).
        fold_labels_trn_val = fold_labels[is_trn_val].to_numpy()
        cv_indexes.append((
            np.where(fold_labels_trn_val == 'trn')[0],
            np.where(fold_labels_trn_val == 'val')[0],
        ))
    samples[split_id]['cv_indexes'] = cv_indexes

# Check that train, validation and test samples never overlap.
for split_id in range(n_tst_splits_total):
    test_set = set(samples[split_id]['test'])
    for fold_id in range(n_val_folds_total):
        train_set = set(samples[split_id]['trains'][fold_id])
        validation_set = set(samples[split_id]['validations'][fold_id])

        intxns = {
            'train_validation': train_set & validation_set,
            'validation_test': validation_set & test_set,
            'train_test': train_set & test_set,
        }

        for intxn_name, intxn_samples in intxns.items():
            if intxn_samples:
                print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

# Persist the split/fold assignment so later runs can reload it instead of recomputing.
with open(f"{path_data}/stratification.pickle", 'wb') as handle:
    pickle.dump(samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

## Load stratification

In [None]:
# Reload the split/fold assignment pickled by the stratification cell.
# Only unpickle this file if it was produced by this notebook (pickle can execute code).
samples = pickle.loads(pathlib.Path(f"{path_data}/stratification.pickle").read_bytes())

# Load PyTorch Tabular configs

## Load Data, Trainer, Optimizer configs

In [None]:
# Learning-rate-finder settings, forwarded to train_hyper_opt in the training loop.
lr_find_min_lr = 1e-8
lr_find_max_lr = 10
lr_find_num_training = 512
lr_find_mode = "exponential"
lr_find_early_stop_threshold = 8.0

# Config files are read from configs/age-regression; checkpoints and results go to logs/Age.
path_configs = f"{path_root}/configs/age-regression"
path_ckpts = f"{path_root}/logs/Age"
os.makedirs(path_ckpts, exist_ok=True)

# Data: 'Age' is the regression target, the immunomarker columns are the continuous inputs.
data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
data_config['target'] = ['Age']
data_config['continuous_cols'] = imms

# Trainer overrides. Checkpointing tracks 'valid_loss' and the best checkpoint is reloaded.
# auto_lr_find is off; the LR search appears to be driven by the lr_find_* values passed
# to train_hyper_opt instead (confirm in src.pt.hyper_opt).
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
for _option, _setting in (
    ('seed', seed),
    ('checkpoints', 'valid_loss'),
    ('load_best', True),
    ('auto_lr_find', False),
    ('checkpoints_path', path_ckpts),
):
    trainer_config[_option] = _setting

# Optimizer config is used as read from YAML.
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

## Load default Models configs

In [None]:
# Architectures trained in the CV loop, in training order.
models_archs = ['CategoryEmbeddingModel', 'DANet', 'FTTransformer', 'GANDALF', 'TabNetModel']

# Default (pre-optimization) config per architecture, read from models/<arch>Config.yaml.
configs_models_default = {
    arch: read_parse_config(f"{path_configs}/models/{arch}Config.yaml", config_cls)
    for arch, config_cls in (
        ('DANet', DANetConfig),
        ('FTTransformer', FTTransformerConfig),
        ('GANDALF', GANDALFConfig),
        ('TabNetModel', TabNetModelConfig),
        ('CategoryEmbeddingModel', CategoryEmbeddingModelConfig),
    )
}

# Train immunomarker models

## Optuna params

In [None]:
# TPE sampler settings (the trailing comments presumably hold the full-run values).
n_trials = 4  # 512
n_startup_trials = 2  # 128
n_ei_candidates = 1  # 16
opt_seed = seed

# Optimization objectives: one per (part, metric) pair, ordered part-major.
# Each objective's direction is the second element of its metric pair.
# NOTE(review): 'test' is an optimization target here, so the held-out test set takes
# part in hyperparameter selection; confirm this is intended.
opt_parts = ['test', 'validation']
opt_metrics = [('mean_absolute_error', 'minimize')]
opt_directions = [f"{direction}" for _ in opt_parts for _metric, direction in opt_metrics]

## Cross-validation and hyperparameter-optimization training

In [None]:
# Dataframes with results: one per (split, fold, model architecture) Optuna study.
dfs_results = []
# Model input columns (immunomarkers) plus the regression target.
cols_model = imms + ['Age']

# Loop for train-validation/test splits
for split_id, split_dict in samples.items():
    # The test part depends only on the split, so it is built once per split.
    test = df_controls.loc[split_dict['test'], cols_model]

    # Loop for train/validation folds
    for fold_id in split_dict['trains']:
        train = df_controls.loc[split_dict['trains'][fold_id], cols_model]
        validation = df_controls.loc[split_dict['validations'][fold_id], cols_model]

        # Loop for models archs
        for m_arch in models_archs:
            model_config_default = configs_models_default[m_arch]

            # Default model, used here only to build the datamodule shared by all trials of this study.
            tabular_model_default = TabularModel(
                data_config=data_config,
                model_config=model_config_default,
                optimizer_config=optimizer_config,
                trainer_config=trainer_config,
                verbose=False,
            )
            datamodule = tabular_model_default.prepare_dataloader(train=train, validation=validation, seed=seed)

            # Presumably filled by train_hyper_opt with one record per trial (confirm in src.pt.hyper_opt).
            trials_results = []
            study = optuna.create_study(
                study_name=f"Age_{split_id}_{fold_id}_{m_arch}",
                sampler=optuna.samplers.TPESampler(
                    n_startup_trials=n_startup_trials,
                    n_ei_candidates=n_ei_candidates,
                    seed=opt_seed,
                ),
                directions=opt_directions,
            )
            # study.optimize runs its trials sequentially here (default n_jobs=1), so the lambda's
            # free variables (train, m_arch, ...) still hold this iteration's values when called.
            study.optimize(
                func=lambda trial: train_hyper_opt(
                    trial=trial,
                    trials_results=trials_results,
                    opt_metrics=opt_metrics,
                    opt_parts=opt_parts,
                    model_config_default=model_config_default,
                    data_config_default=data_config,
                    optimizer_config_default=optimizer_config,
                    trainer_config_default=trainer_config,
                    experiment_config_default=None,
                    train=train,
                    validation=validation,
                    test=test,
                    datamodule=datamodule,
                    min_lr=lr_find_min_lr,
                    max_lr=lr_find_max_lr,
                    num_training=lr_find_num_training,
                    mode=lr_find_mode,
                    early_stop_threshold=lr_find_early_stop_threshold,
                ),
                n_trials=n_trials,
                show_progress_bar=False,
            )
            df_trials = pd.DataFrame(trials_results)
            df_trials['split_id'] = split_id
            df_trials['fold_id'] = fold_id
            # Record the architecture; without it, rows from different models cannot be told
            # apart in df_results (unless train_hyper_opt already stores it, which is not visible here).
            df_trials['model_arch'] = m_arch
            dfs_results.append(df_trials)

# Resulting Dataframe
df_results = pd.concat(dfs_results)
# NOTE(review): opt_parts includes 'test' and results are ranked by test loss, so the held-out
# test set takes part in model selection (information leak). Consider ranking by validation loss.
df_results.sort_values(by=['test_loss'], ascending=[True], inplace=True)
df_results.to_excel(f"{path_ckpts}/results.xlsx")