# Debugging autoreload

In [None]:
# Load IPython's autoreload extension and switch it to mode 2 so that edited
# modules are re-imported automatically before code is executed (see IPython docs).
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
from scipy import stats
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
import itertools
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.pt.hyper_opt import train_hyper_opt
from src.utils.hash import dict_hash
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from plottable import ColumnDefinition, Table
from plottable.plots import bar
from plottable.cmap import normed_cmap, centered_cmap
import optuna
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.cm
import matplotlib as mpl
from statsmodels.stats.multitest import multipletests
import re
from itertools import chain
from pathlib import Path
import requests
from sklearn.decomposition import PCA
from sklearn.random_projection import GaussianRandomProjection, SparseRandomProjection
from sklearn.manifold import MDS, Isomap, TSNE
import missingno as msno
from collections import Counter
import functools
from sklearn.cluster import DBSCAN, HDBSCAN
from regression_bias_corrector import LinearBiasCorrector


def conjunction(conditions):
    """Combine all ``conditions`` with an element-wise logical AND.

    The items are folded left to right with ``np.logical_and``. A single
    item is returned unchanged; an empty iterable raises ``TypeError``.
    """
    remaining = iter(conditions)
    try:
        combined = next(remaining)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for condition in remaining:
        combined = np.logical_and(combined, condition)
    return combined


def disjunction(conditions):
    """Combine all ``conditions`` with an element-wise logical OR.

    The items are folded left to right with ``np.logical_or``. A single
    item is returned unchanged; an empty iterable raises ``TypeError``.
    """
    remaining = iter(conditions)
    try:
        combined = next(remaining)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for condition in remaining:
        combined = np.logical_or(combined, condition)
    return combined


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend colour ``rgb`` over background ``bg_rgb`` with opacity ``alpha``.

    Each channel is ``alpha * foreground + (1 - alpha) * background``.
    Returns a list with one value per channel pair (pairs are formed with
    ``zip``, so the shorter input determines the length).
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended

def form_bar(base):
    """Return a formatter that renders a fraction ``x`` as ``"<count>/<base>"``.

    ``<count>`` is ``x * base`` rounded with the builtin ``round`` and cast
    to ``int``.
    """
    def formatter(x):
        count = int(round(x * base))
        return '{}/{}'.format(count, base)
    return formatter


# Process data

In [None]:
# Root directory with the project data.
path = "E:/YandexDisk/Work/bbd/mriya"

# Parameters sheet: indexed by feature name, used below for its
# 'analysis_type' and 'data_type' columns.
df_params = pd.read_excel(f"{path}/Испытуемые Яндекс.xlsx", sheet_name='Parameters', index_col=0)
# Translate the analysis-type labels. The result is assigned back to the column:
# calling `.replace(..., inplace=True)` on `df_params['analysis_type']` is a chained
# in-place operation that acts on a temporary object and leaves `df_params`
# unchanged when pandas Copy-on-Write is active.
df_params['analysis_type'] = df_params['analysis_type'].replace(
    {
        'Sphygmocardiography': 'Сфигмография',
        'Echocardiography': 'Эхокардиография',
        'ECG': 'Электрокардиография',
        'BP': 'Биохимический анализ крови',
        'CBC': 'Общий анализ крови',
        'Anthropometry': 'Антропометрия',
    }
)

# Blood and Heart sheets of the same workbook; the second column is used as the index.
xlsx_mriya = f"{path}/09_samples_from_mriya/Испытуемые Мрия.xlsx"
df_blood = pd.read_excel(xlsx_mriya, sheet_name='Blood', index_col=1)
df_heart = pd.read_excel(xlsx_mriya, sheet_name='Heart', index_col=1)
for df_sheet in (df_blood, df_heart):
    df_sheet.index = df_sheet.index.astype(str)
    # Both sheets get the same datetime conversion (previously only Heart did),
    # so the subtraction below does not depend on how Excel stored the dates.
    df_sheet['sample_date'] = pd.to_datetime(df_sheet['sample_date'])
    df_sheet['birthday'] = pd.to_datetime(df_sheet['birthday'])
    # Age in years at sampling time: elapsed days / 365.25.
    df_sheet.insert(
        3,
        'Age',
        (df_sheet['sample_date'] - df_sheet['birthday']) / np.timedelta64(1, 'D') / 365.25,
    )

# Outer merge on the index; overlapping columns keep the Blood values under the
# plain name and the Heart values under the '_heart' suffix.
suffixes = ('', '_heart')
df = pd.merge(df_blood, df_heart, left_index=True, right_index=True, how='outer', suffixes=suffixes)
# Columns present in both sheets (referenced by their unsuffixed names).
cols_cmn = df_blood.columns.intersection(df_heart.columns).to_list()

# For every analysis type: the merged columns listed under that type in df_params.
cols_types = df_params['analysis_type'].dropna().unique()
cols_sets = {
    x: df.columns.intersection(df_params.index[df_params['analysis_type'] == x]).to_list()
    for x in cols_types
}
df = df.loc[:, cols_cmn + list(chain.from_iterable(cols_sets.values()))]
# NOTE(review): the unsuffixed 'Age' comes from the Blood sheet, so subjects that
# appear only in the Heart sheet have NaN here and are dropped - confirm this is intended.
df = df[df['Age'].notna()]

df.to_excel(f"{path}/09_samples_from_mriya/data_mriya.xlsx")

# Missing-value report: one Excel sheet per features group, one row per feature,
# sorted from the most complete feature to the least complete one.
with pd.ExcelWriter(f"{path}/09_samples_from_mriya/nans.xlsx", engine='xlsxwriter') as writer:
    for col_set, cols in cols_sets.items():
        data = df.loc[:, cols]
        n_missing = data.isna().sum(axis=0)
        nan_feats = pd.DataFrame(
            {
                "Number of NaNs": n_missing,
                "% of NaNs": n_missing / data.shape[0] * 100,
                "Number of not-NaNs": data.notna().sum(axis=0),
            }
        )
        nan_feats = nan_feats.sort_values(["% of NaNs"], ascending=[True])
        nan_feats.to_excel(writer, sheet_name=col_set)
        
# Per-feature summary for numeric features: missing-value counts and Pearson
# correlation with Age.
is_numeric_param = df_params['data_type'].isin(['decimal', 'integer'])
feats_cnt = df_params.index[is_numeric_param].intersection(list(chain.from_iterable(cols_sets.values()))).to_list()

df_feats = pd.DataFrame(index=feats_cnt)

nan_feats = df[df_feats.index.to_list()].isna().sum(axis=0).to_frame(name="Number of NaNs")
df_feats.loc[df_feats.index, "Number of NaNs"] = nan_feats.loc[df_feats.index, "Number of NaNs"]
df_feats["% of NaNs"] = nan_feats["Number of NaNs"] / df.shape[0] * 100
df_feats["Number of not-NaNs"] = df[df_feats.index.to_list()].notna().sum(axis=0)
df_feats.sort_values(["% of NaNs"], ascending=[True], inplace=True)

# Correlation defaults to 0.0; it is only computed when, after dropping rows with
# a missing Age or feature value, more than one row and more than one distinct
# feature value remain.
rho_col = r"Pearson $\rho$"
df_feats[rho_col] = 0.0
for f in df_feats.index:
    df_tmp = df.loc[:, ['Age', f]].dropna(axis=0, how='any')
    if df_tmp.shape[0] <= 1 or df_tmp[f].nunique() <= 1:
        continue
    rho, _ = stats.pearsonr(df_tmp.loc[:, 'Age'].values, df_tmp.loc[:, f].values)
    df_feats.at[f, rho_col] = rho

df_feats.to_excel(f"{path}/09_samples_from_mriya/feats_with_metrics.xlsx", index_label="Features")

# Horizontal bar chart: number of non-missing records per feature, with each bar
# coloured by the feature's Pearson correlation with Age.
rho_col = r"Pearson $\rho$"
df_fig = df_feats.copy()
df_fig['Features'] = df_fig.index
f_cmap = sns.color_palette("coolwarm", as_cmap=True)

# TwoSlopeNorm requires vmin < vcenter < vmax and raises ValueError otherwise.
# That happens whenever the correlations do not strictly straddle zero (e.g. the
# smallest value is the default 0.0), so a missing side is mirrored from the other
# one. When both signs are present the limits are the plain min/max as before.
rho_min = df_fig[rho_col].min()
rho_max = df_fig[rho_col].max()
rho_span = max(abs(rho_min), abs(rho_max))
if not np.isfinite(rho_span) or rho_span == 0.0:
    rho_span = 1.0
if not rho_min < 0.0:
    rho_min = -rho_span
if not rho_max > 0.0:
    rho_max = rho_span
f_norm = mcolors.TwoSlopeNorm(vcenter=0.0, vmin=rho_min, vmax=rho_max)
# Palette keyed by correlation value, because the correlation column is used as hue.
f_colors = {cval: f_cmap(f_norm(cval)) for cval in df_fig[rho_col]}

sns.set_theme(style='ticks')
fig, ax = plt.subplots(figsize=(10, 30), layout='constrained')
barplot = sns.barplot(
    data=df_fig,
    x='Number of not-NaNs',
    y='Features',
    hue=rho_col,
    edgecolor='black',
    palette=f_colors,
    dodge=False,
    ax=ax
)
# Print the record count at the end of every bar.
for container in barplot.containers:
    barplot.bar_label(container, label_type='edge', color='gray', fmt='%d', fontsize=8, padding=4.0)
ax.set_ylabel('')
ax.set(yticklabels=df_fig.index.to_list())
# The per-value hue legend is replaced by a continuous colorbar below.
ax.get_legend().remove()
ax.xaxis.tick_top()
ax.xaxis.set_label_position('top')
ax.set_xlabel('Количество записей not-NaN')  # "Number of not-NaN records"
sm = plt.cm.ScalarMappable(cmap=f_cmap, norm=f_norm)
sm.set_array([])
cbar = barplot.figure.colorbar(sm, orientation="horizontal")
cbar.set_label("Корреляция с возрастом")  # "Correlation with age"
plt.savefig(f"{path}/09_samples_from_mriya/feats_nans_and_age_correlation.pdf", bbox_inches='tight')
plt.savefig(f"{path}/09_samples_from_mriya/feats_nans_and_age_correlation.png", bbox_inches='tight', dpi=200)
plt.close(fig)

# Models inference

In [None]:
# Locations of the project data and of the trained models.
path = "E:/YandexDisk/Work/bbd/mriya"
path_models = path + "/models/oct2025"

expl_type = 'current'

# Name of the target column.
feat_trgt = 'Age'

# One row per features set:
# (features set, model sub-directory, likelihood weight, colour).
feats_sets_table = [
    ('Эхокардиография', 'DANet/2', 0.78, 'darkcyan'),
    ('Сфигмография', 'DANet/101', 0.85, 'mediumorchid'),
    ('Биохимический анализ крови', 'DANet/2', 0.6, 'goldenrod'),
    ('Электрокардиография', 'DANet/28', 0.65, 'dodgerblue'),
    ('Антропометрия', 'DANet/148', 0.5, 'chartreuse'),
    ('Общий анализ крови', 'DANet/394', 0.4, 'crimson'),
    ('Все', 'DANet/283', 1.0, 'gray'),
]
# Per-set lookups, all in the order of the table above.
models_feats_sets = {row[0]: row[1] for row in feats_sets_table}
likelihood_feats_sets = {row[0]: row[2] for row in feats_sets_table}
colors_feats_sets = {row[0]: row[3] for row in feats_sets_table}

# Samples to run the models on; the index is forced to string.
df = pd.read_excel(f"{path}/09_samples_from_mriya/data_mriya.xlsx", index_col=0)
df.index = df.index.astype(str)

# Accumulators, initialised with zeros for every sample.
df['BioAge Acceleration'] = 0.0
df['BioAge'] = 0.0

# For every features set: load the selected model and its saved artefacts, predict
# the target for the new samples, apply the linear bias correction, draw the
# diagnostic figures and accumulate the weighted errors into 'BioAge Acceleration'.
for feats_set in models_feats_sets:
    # Artefacts stored next to the model selected for this features set.
    path_set = f"{path_models}/{feats_set}"
    path_model = f"{path_set}/models/{models_feats_sets[feats_set]}"
    data = pd.read_excel(f"{path_set}/data.xlsx", index_col=0)
    feats = pd.read_excel(f"{path_set}/feats.xlsx", index_col=0)
    results = pd.read_excel(f"{path_model}/df.xlsx", index_col=0)
    metrics = pd.read_excel(f"{path_model}/metrics.xlsx", index_col=0)
    df_shap = pd.read_excel(f"{path_model}/explanation.xlsx", index_col=0)
    model = TabularModel.load_model(path_model)

    # The corrector is fitted on the 'Train' rows only: target values vs. raw predictions.
    is_train = results['Group'] == 'Train'
    corrector = LinearBiasCorrector()
    corrector.fit(results.loc[is_train, feat_trgt].values, results.loc[is_train, 'Prediction'].values)

    # Raw and corrected predictions for the new samples, plus their errors
    # (prediction minus target).
    df[f'Prediction {feats_set}'] = model.predict(df)
    df[f'Prediction Corrected {feats_set}'] = corrector.predict(df[f'Prediction {feats_set}'])
    df[f'Error {feats_set}'] = df[f'Prediction {feats_set}'] - df[feat_trgt]
    df[f'Error Corrected {feats_set}'] = df[f'Prediction Corrected {feats_set}'] - df[feat_trgt]

    if df.loc[:, f'Prediction {feats_set}'].shape[0] > 0:
        # Figure layout: metrics table on top, prediction scatter (left) and
        # error distribution (right) below.
        fig, axs = plt.subplot_mosaic(
            [
                ['table', 'table'],
                ['scatter', 'violin'],
            ],
            figsize=(9, 6),
            height_ratios=[1, 4],
            width_ratios=[3, 1.5],
            gridspec_kw={
                "wspace": 0.35,
                "hspace": 0.05,
            },
        )

        # Metrics table: saved "unbiased" metrics for each data part, as text.
        parts = ['Train', 'Validation', 'Test']
        rho_label = fr"Pearson $\mathbf{{\rho}}$"
        df_table = pd.DataFrame(index=['MAE', rho_label, "Bias"], columns=parts)
        for part in parts:
            df_table.at['MAE', part] = f"{metrics.at[part, 'mean_absolute_error_unbiased']:0.3f}"
            df_table.at[rho_label, part] = f"{metrics.at[part, 'pearson_corrcoef_unbiased']:0.3f}"
            df_table.at["Bias", part] = f"{metrics.at[part, 'bias_unbiased']:0.3f}"

        col_defs = [
            ColumnDefinition(
                name="index",
                title='',
                textprops={"ha": "center", "weight": "bold"},
                width=2.5,
            ),
            ColumnDefinition(
                name="Train",
                textprops={"ha": "left"},
                width=1.5,
                border="left",
            ),
            ColumnDefinition(
                name="Validation",
                textprops={"ha": "left"},
                width=1.5,
            ),
            ColumnDefinition(
                name="Test",
                textprops={"ha": "left"},
                width=1.5,
            )
        ]
        table = Table(
            df_table,
            column_definitions=col_defs,
            row_dividers=True,
            footer_divider=False,
            ax=axs['table'],
            textprops={"fontsize": 8},
            row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
            col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
            column_border_kw={"linewidth": 1, "linestyle": "-"},
        ).autoset_fontcolors(colnames=parts)
        axs['table'].set_title(feats_set, fontsize='large')

        # Common axis range over the reference data and the new samples.
        # NaN-aware reductions are required: the prediction columns may hold NaN
        # (they are dropped for plotting below), and the builtin min()/max() give
        # order-dependent results, possibly NaN, on such data.
        xy_vals = np.concatenate(
            [
                results[[feat_trgt, 'Prediction Unbiased']].values.flatten(),
                df[[feat_trgt, f'Prediction Corrected {feats_set}']].values.flatten(),
            ]
        ).astype(float)
        xy_min = np.nanmin(xy_vals)
        xy_max = np.nanmax(xy_vals)
        xy_ptp = xy_max - xy_min
        xy_lim = [xy_min - 0.15 * xy_ptp, xy_max + 0.15 * xy_ptp]

        # Reference samples (saved model results) in the background.
        sns.scatterplot(
            data=results,
            x=feat_trgt,
            y="Prediction Unbiased",
            linewidth=0.5,
            alpha=0.8,
            edgecolor="k",
            s=20,
            color='lightsteelblue',
            ax=axs['scatter'],
        )
        # New samples with both a target and a corrected prediction, highlighted.
        df_fig = df.dropna(subset=[feat_trgt, f'Prediction Corrected {feats_set}'])
        sns.scatterplot(
            data=df_fig,
            x=feat_trgt,
            y=f'Prediction Corrected {feats_set}',
            linewidth=1.0,
            alpha=1.0,
            edgecolor="k",
            s=50,
            color='crimson',
            ax=axs['scatter'],
        )
        # Identity line: prediction equals target.
        sns.lineplot(
            x=xy_lim,
            y=xy_lim,
            linestyle='--',
            color='black',
            linewidth=1.0,
            ax=axs['scatter']
        )
        axs['scatter'].set_xlim(xy_lim[0], xy_lim[1])
        axs['scatter'].set_ylim(xy_lim[0], xy_lim[1])
        axs['scatter'].set_ylabel("Биологический возраст")  # "Biological age"
        axs['scatter'].set_xlabel("Возраст")  # "Age"

        # Error distribution of the reference samples with the new samples on top.
        sns.violinplot(
            data=results,
            x=[0] * results.shape[0],
            y='Error Unbiased',
            color=make_rgb_transparent(mcolors.to_rgb('lightsteelblue'), (1, 1, 1), 0.5),
            density_norm='width',
            saturation=0.75,
            linewidth=1.0,
            ax=axs['violin'],
            legend=False,
        )
        sns.swarmplot(
            data=df_fig,
            x=[0] * df_fig.shape[0],
            y=f'Error Corrected {feats_set}',
            color='crimson',
            linewidth=1.0,
            ax=axs['violin'],
            size=8,
            legend=False,
        )
        axs['violin'].set_ylabel('Возрастная акселерация')  # "Age acceleration"
        axs['violin'].set_xlabel('')
        axs['violin'].set(xticklabels=[])
        axs['violin'].set(xticks=[])

        fig.savefig(f"{path}/09_samples_from_mriya/model_{feats_set}.png", bbox_inches='tight', dpi=200)
        fig.savefig(f"{path}/09_samples_from_mriya/model_{feats_set}.pdf", bbox_inches='tight')
        plt.close(fig)

    # Pearson correlation of every model feature with the target in the saved
    # data (pairwise-complete rows), ordered by absolute value.
    feats_in = feats.index.values
    df_corr = pd.DataFrame(index=feats_in, columns=['rho'])
    for f in tqdm(feats_in):
        df_tmp = data.loc[:, [feat_trgt, f]].dropna(axis=0, how='any')
        if df_tmp.shape[0] > 1:
            vals_1 = df_tmp.loc[:, feat_trgt].values
            vals_2 = df_tmp.loc[:, f].values
            df_corr.at[f, 'rho'], _ = stats.pearsonr(vals_1, vals_2)
    df_corr.dropna(axis=0, how='any', inplace=True)
    df_corr.insert(1, "abs(rho)", df_corr['rho'].abs())
    df_corr.sort_values(["abs(rho)"], ascending=[False], inplace=True)
    df_corr = df_corr.apply(pd.to_numeric)

    # Grid of feature-vs-target scatter plots, one panel per feature.
    n_cols = 6
    n_rows = int(np.ceil(len(feats_in) / n_cols))
    sns.set_theme(style='ticks')
    # squeeze=False keeps `axs` two-dimensional even for a single row, so one code
    # path serves every grid size.
    fig, axs = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(n_cols * 3.0, n_rows * 2.5),
        gridspec_kw={'wspace': 0.10, 'hspace': 0.05},
        sharex=True,
        layout='constrained',
        squeeze=False,
    )
    axs_flat = axs.ravel()  # row-major: panel `feat_id` sits at divmod(feat_id, n_cols)
    for feat_id, feat in enumerate(df_corr.index.values):
        ax_feat = axs_flat[feat_id]
        # Saved data in the background, new samples highlighted on top.
        sns.scatterplot(
            data=data,
            x=feat_trgt,
            y=feat,
            color='lightsteelblue',
            linewidth=0.5,
            alpha=0.75,
            edgecolor="k",
            s=10,
            ax=ax_feat
        )
        sns.scatterplot(
            data=df,
            x=feat_trgt,
            y=feat,
            color='crimson',
            linewidth=1.0,
            alpha=1.0,
            edgecolor="k",
            s=35,
            ax=ax_feat
        )
        ax_feat.set_title(fr"Pearson $\rho$: {df_corr.loc[feat, 'rho']:0.3f}")
        # Shrink the label font for long feature names (capped at 13).
        y_label_fontsize = min(15 / (len(feat) / 20), 13)
        ax_feat.set_ylabel(feat, fontsize=y_label_fontsize)
        ax_feat.xaxis.set_tick_params(which='both', labelbottom=True)
    # Hide every panel that received no feature. Counting from the plotted features
    # (df_corr) also hides panels of features dropped for lack of data.
    for ax_unused in axs_flat[len(df_corr.index):]:
        ax_unused.axis('off')
    fig.savefig(f"{path}/09_samples_from_mriya/feats_{feats_set}.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path}/09_samples_from_mriya/feats_{feats_set}.pdf", bbox_inches='tight')
    plt.close(fig)

    # Correlation of every feature with the target on the full saved data.
    # The frame is indexed by the feature names (`feats.index`); passing the
    # `feats` DataFrame itself as the index was a bug.
    feats_corr = pd.DataFrame(index=feats.index, columns=['Correlation'])
    for f in feats.index.values:
        feats_corr.at[f, 'Correlation'], _ = stats.pearsonr(data.loc[:, f].values, data.loc[:, feat_trgt].values)
    feats_corr = feats_corr.apply(pd.to_numeric)

    # Weight of this features set: test correlation of the model, times the
    # strongest absolute feature/target correlation, times the set's likelihood.
    rho = metrics.at['Test', 'pearson_corrcoef_unbiased'] * feats_corr['Correlation'].abs().max() * likelihood_feats_sets[feats_set]
    mae = metrics.at['Test', 'mean_absolute_error_unbiased']
    curr_threshold = rho * mae
    print(f'MAE: {mae}, rho: {rho}, threshold (rho*MAE): {curr_threshold}')

    # Every set except 'Все' ("All") adds its weighted corrected error to the total
    # acceleration of the samples for which that error is available.
    if feats_set != 'Все':
        ids_not_na = df.index[df[f'Error Corrected {feats_set}'].notna()]
        df.loc[ids_not_na, f'BioAge Acceleration {feats_set}'] = df.loc[ids_not_na, f'Error Corrected {feats_set}'] * rho
        df.loc[ids_not_na, 'BioAge Acceleration'] += df.loc[ids_not_na, f'BioAge Acceleration {feats_set}']

# Final estimate: target value plus the accumulated acceleration.
df['BioAge'] = df[feat_trgt] + df['BioAge Acceleration']

df.to_excel(f"{path}/09_samples_from_mriya/result_mriya.xlsx")

In [None]:
# Show the resulting 'BioAge' values (target plus accumulated 'BioAge Acceleration').
df['BioAge'].values