# Enable autoreload for debugging

In [None]:
import optuna  # also re-imported in the package-loading cell below
# IPython autoreload extension; mode 2 reloads all modules before each cell runs, so
# edits to the local `src.*` helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
from src.pt.hyper_opt import train_hyper_opt
import optuna
import pathlib
import os
from omegaconf import OmegaConf
from tqdm import tqdm
from sklearn.impute import KNNImputer

import warnings

# Silence known-noisy, harmless messages from the training/inference stack.
# Each pattern is registered in the same order as before, so filter precedence is unchanged.
for _pattern in (
    ".*does not have many workers.*",
    ".*exists and is not empty.*",
    ".*is smaller than the logging interval Trainer.*",
):
    warnings.filterwarnings("ignore", _pattern)
del _pattern


# Load immunomarker models

In [None]:
# Immunomarkers to model: the index of the feature-selection table.
feats_imm = pd.read_excel("D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/059_imm_data_selection/feats_selected.xlsx", index_col=0).index.values

epi_data_type = 'no_harm'
# Alternatives: 'origin' 'imp_source(imm)_method(knn)_params(5)' 'imp_source(imm)_method(miceforest)_params(2)'
imm_data_type = 'imp_source(imm)_method(knn)_params(5)'

selection_method = 'mrmr'  # alternatives: 'f_regression' 'spearman' 'mrmr'
n_feats = 100

path_imm = f"D:/YandexDisk/Work/bbd/immunology/003_EpImAge/{imm_data_type}/{epi_data_type}/{selection_method}_{n_feats}"
path_save = f"{path_imm}/EpImAge"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# Per-immunomarker best model: columns 'model' (candidate family) and 'directory' (run folder),
# plus the stored train/validation/test metrics used by the model check at the end.
df_models = pd.read_excel(f"{path_imm}/best_models_v3.xlsx", index_col=0)

imm_epi_feats = {}  # immunomarker -> list of CpGs its model consumes
imm_models = {}     # immunomarker -> loaded TabularModel
for imm in (pbar := tqdm(feats_imm)):
    pbar.set_description(f"Processing {imm}")
    imm_epi_feats[imm] = pd.read_excel(f"{path_imm}/{imm}/feats_con.xlsx", index_col=0).index.values.tolist()
    # load_model receives the run directory (the folder that holds model.ckpt).
    imm_dir_model = f"{path_imm}/{imm}/pytorch_tabular/candidates/{df_models.at[imm, 'model']}/{df_models.at[imm, 'directory']}"
    imm_models[imm] = TabularModel.load_model(imm_dir_model)

# Union of CpGs used by any model. Sorted because set iteration order for strings depends on
# per-process hash randomization; a fixed order keeps column layouts and saved files reproducible.
# set().union(...) also works when there are no immunomarkers (set.union(*[]) raises TypeError).
feats_epi_cmn = sorted(set().union(*imm_epi_feats.values()))
print(f"Number of CpGs: {len(feats_epi_cmn)}")

# Load epigenetics data

In [None]:
path_epi = "D:/YandexDisk/Work/bbd/immunology/003_EpImAge/epi"
feats_pheno = ['Age', 'Sex', 'Status', 'Tissue']

# Per-GSE 'Status' value whose samples are dropped from the phenotype table.
gse_status_excluded = {
    'GSE53740': 'Unknown',
    'GSE87648': 'HS',
}


def _read_gse(gse_dir, gse):
    """Return ``(df_betas, df_pheno)`` for one GSE directory.

    GSEUNN keeps its files in an ``epi_data_type`` sub-folder and indexes the phenotype
    table by its ``'index'`` column; every other GSE keeps ``betas.pkl``/``pheno.csv`` at
    the top level with the index in the first column. For GSEs listed in
    ``gse_status_excluded``, phenotype rows with that ``Status`` value are dropped.
    """
    if gse == 'GSEUNN':
        df_betas = pd.read_pickle(f"{gse_dir}/{epi_data_type}/betas.pkl")
        df_pheno = pd.read_csv(f"{gse_dir}/{epi_data_type}/pheno.csv", index_col='index')
    else:
        df_betas = pd.read_pickle(f"{gse_dir}/betas.pkl")
        df_pheno = pd.read_csv(f"{gse_dir}/pheno.csv", index_col=0)
    if gse in gse_status_excluded:
        df_pheno = df_pheno.drop(df_pheno.index[df_pheno['Status'] == gse_status_excluded[gse]])
    return df_betas, df_pheno


feats_epi_cmn_set = set(feats_epi_cmn)  # hoisted: reused for every GSE
# os.scandir yields entries in arbitrary order; sort so the row order of df is reproducible.
gpls = sorted(f.name for f in os.scandir(path_epi) if f.is_dir())
# NOTE(review): keyed by GSE only - a GSE present under two GPLs would overwrite its count.
gse_missed_cpgs = {}
dfs = []
for gpl in gpls:
    print(gpl)
    gses = sorted(f.name for f in os.scandir(f"{path_epi}/{gpl}") if f.is_dir())
    for gse in (pbar := tqdm(gses)):
        pbar.set_description(f"Processing {gse}")
        df_betas, df_pheno = _read_gse(f"{path_epi}/{gpl}/{gse}", gse)
        betas_cols = set(df_betas.columns)
        gse_missed_cpgs[gse] = len(feats_epi_cmn_set - betas_cols)
        # CpGs this GSE actually has, kept in feats_epi_cmn order (deterministic columns).
        exist_cpgs = [cpg for cpg in feats_epi_cmn if cpg in betas_cols]
        # Inner join on sample index: keep samples present in both phenotype and betas.
        df_gse = pd.merge(df_pheno.loc[:, feats_pheno], df_betas.loc[:, exist_cpgs], left_index=True, right_index=True)
        df_gse.insert(0, 'GPL', gpl)
        df_gse.insert(0, 'GSE', gse)
        dfs.append(df_gse)

df_gse_missed_cpgs = pd.DataFrame.from_dict(gse_missed_cpgs, orient='index', columns=['Missed CpGs'])
df_gse_missed_cpgs.to_excel(f"{path_save}/gse_missed_cpgs.xlsx", index=True, index_label='GSE')

# CpGs missing from a GSE become NaN columns for its samples; verify_integrity rejects
# duplicate sample IDs across datasets.
df = pd.concat(dfs, verify_integrity=True)

# Impute missing values

In [None]:
# KNN-impute missing values jointly over every selected CpG plus Age, pooling all
# concatenated GSEs into one matrix.
# NOTE(review): features are not scaled before KNNImputer, so Age presumably dominates the
# nan-euclidean distance relative to beta values - confirm this is intended.
n_neighbors = 5
X = df[feats_epi_cmn + ['Age']].to_numpy()
print(f'Missing before imputation: {np.isnan(X).sum()}')
imputer = KNNImputer(n_neighbors=n_neighbors)
X_imptd = imputer.fit_transform(X)
print(f'Missing after imputation: {np.isnan(X_imptd).sum()}')

In [None]:
# Write the imputed matrix back into the same columns, in the order used to build X.
cols_imputed = feats_epi_cmn + ['Age']
df.loc[:, cols_imputed] = X_imptd

# Calculate immunomarkers

In [None]:
# Each immunomarker model sees only its own CpG subset. Its output is stored as `<imm>_log`,
# and the exponentiated value is stored under the immunomarker's own name.
# NOTE(review): assumes predict() returns a single prediction column aligned to df's index - confirm.
for imm in (pbar := tqdm(feats_imm)):
    pbar.set_description(f"Processing {imm}")
    col_log = f"{imm}_log"
    model_inputs = df.loc[:, imm_epi_feats[imm]]
    df[col_log] = imm_models[imm].predict(model_inputs)
    df[imm] = np.exp(df[col_log])

In [None]:
# Compact Excel export (dataset ids, phenotypes, immunomarkers and their log values),
# plus a pickle with every column of df (CpGs included).
cols_export = ['GPL', 'GSE'] + feats_pheno + list(feats_imm)
cols_export += [f"{imm}_log" for imm in feats_imm]
df[cols_export].to_excel(f"{path_save}/data.xlsx")
df.to_pickle(f"{path_save}/data_full.pkl")

# Check models on GSEUNN (recompute metrics on the saved splits and compare with the stored ones)

In [None]:
tst_n_splits = 5
tst_n_repeats = 5
tst_random_state = 1337
tst_split_id = 5

val_n_splits = 4
val_n_repeats = 2
val_random_state = 1337
val_fold_id = 5

fn_samples = f"samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_val({val_random_state}_{val_n_splits}_{val_n_repeats})"
# NOTE: pickle.load can execute arbitrary code - only load split files produced by this project.
with open(f"D:/YandexDisk/Work/bbd/immunology/003_EpImAge/{fn_samples}.pickle", 'rb') as handle:
    samples = pickle.load(handle)

# Sanity check: train / validation / test must be pairwise disjoint for every split and fold.
for split_id in range(tst_n_splits * tst_n_repeats):
    test_samples = set(samples[split_id]['test'])  # same for every fold of this split
    for fold_id in range(val_n_splits * val_n_repeats):
        train_samples = set(samples[split_id]['trains'][fold_id])
        validation_samples = set(samples[split_id]['validations'][fold_id])

        intxns = {
            'train_validation': train_samples & validation_samples,
            'validation_test': validation_samples & test_samples,
            'train_test': train_samples & test_samples,
        }

        for intxn_name, intxn_samples in intxns.items():
            if intxn_samples:
                print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

split_dict = samples[tst_split_id]
part_samples = {
    'train': split_dict['trains'][val_fold_id],
    'validation': split_dict['validations'][val_fold_id],
    'test': split_dict['test'],
}
# short metric name -> (suffix of the stored column in df_models, torchmetrics function)
metric_specs = {
    'mae': ('mean_absolute_error', mean_absolute_error),
    'rho': ('pearson_corrcoef', pearson_corrcoef),
}

# Compare stored metrics ("before") with metrics recomputed from this notebook's predictions
# ("after"). Columns are created in the same order as the original hand-written version.
df_models_check = pd.DataFrame(index=feats_imm)
for imm in (pbar := tqdm(feats_imm)):
    pbar.set_description(f"Processing {imm}")
    data_imm = pd.read_excel(f"{path_imm}/{imm}/data.xlsx", index_col=0)
    col_log = f"{imm}_log"

    y_real = {part: torch.from_numpy(data_imm.loc[ids, col_log].values) for part, ids in part_samples.items()}
    y_pred = {part: torch.from_numpy(df.loc[ids, col_log].values) for part, ids in part_samples.items()}

    for short, (stored_suffix, metric_fn) in metric_specs.items():
        for part in part_samples:
            df_models_check.at[imm, f"{part}_{short}_before"] = df_models.at[imm, f"{part}_{stored_suffix}"]
        for part in part_samples:
            # .item() stores a plain float instead of a 0-d numpy array.
            df_models_check.at[imm, f"{part}_{short}_after"] = metric_fn(y_pred[part], y_real[part]).item()

for short in metric_specs:
    for part in part_samples:
        df_models_check[f"{part}_{short}_diff"] = df_models_check[f"{part}_{short}_after"] - df_models_check[f"{part}_{short}_before"]

df_models_check.to_excel(f"{path_save}/models_check.xlsx")