# Debugging autoreload

In [None]:
# Enable IPython's autoreload extension (mode 2) so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load packages

In [34]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
from scipy import stats
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
import itertools
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.pt.hyper_opt import train_hyper_opt
from src.utils.hash import dict_hash
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from plottable import ColumnDefinition, Table
from plottable.plots import bar
from plottable.cmap import normed_cmap, centered_cmap
import optuna
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.cm
import matplotlib as mpl
from statsmodels.stats.multitest import multipletests
import re
import datetime
from collections import Counter
from matplotlib.ticker import MaxNLocator
from itertools import chain

def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a colour over a background colour with the given opacity.

    Args:
        rgb: Foreground colour channels.
        bg_rgb: Background colour channels, paired positionally with ``rgb``.
        alpha: Weight of the foreground; the background gets ``1 - alpha``.

    Returns:
        A list with one blended value per channel pair (extra channels of the
        longer input are ignored, as with ``zip``).
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended

def form_bar(base):
    """Build a formatter that renders a fraction as ``"<count>/<base>"``.

    Args:
        base: Denominator; the formatted value is multiplied by it and rounded.

    Returns:
        A one-argument function mapping ``x`` to ``f"{round(x * base)}/{base}"``.
    """
    def formatter(x):
        numerator = int(round(x * base))
        return f'{numerator}/{base}'
    return formatter


# Load data

In [None]:
# Root folder with the source spreadsheets; later cells build their paths from it.
path = "D:/YandexDisk/Work/bbd/fmba"

# Earlier export, kept commented out for reference:
# data_raw = pd.read_excel(f"{path}/2024-08-30 Пример 1000 за 2023 г в формате широких данных.xlsx", index_col=0)
data = pd.read_excel(
    f"{path}/2024-10-14 1147 за 2023 г в формате широких данных.xlsx",
    index_col=0,
)

# Age in years at a fixed reference date (stored in the 'date_now' column).
data['дата рождения'] = pd.to_datetime(data['дата рождения'])
data['date_now'] = pd.to_datetime("2024-10-10")
age_in_days = (data['date_now'] - data['дата рождения']) / np.timedelta64(1, 'D')
data['Age'] = age_in_days / 365.25

# Keep only the rows for which an age could be computed.
data = data.loc[data['Age'].notna()]

# Healthy / sick participant lists, indexed by their first column.
df_hlty = pd.read_excel(f"{path}/здоровые_бпд.xlsx", index_col=0)
df_sick = pd.read_excel(f"{path}/больные_бпд.xlsx", index_col=0)

# IDs present in both lists; they are excluded from both pools in the sample-selection cell.
ids_in_both = df_hlty.index.intersection(df_sick.index)
ids_hlty_sick = ids_in_both.tolist()
print('ids_hlty_sick:', ids_hlty_sick, sep='\n')

# Biomaterial inventory, indexed by participant ID.
df_inventory = pd.read_excel(f"{path}/Опись биоматериала. Отправка 28.10.2024_selected.xlsx", index_col='ID')
# Drop rows without an ID, then keep the first row of every repeated ID.
# (`df.loc[df.index.drop_duplicates(), :]` does NOT deduplicate: label lookup
# returns every row that carries the label, so repeated IDs survived.)
df_inventory = df_inventory[df_inventory.index.notnull()]
df_inventory = df_inventory[~df_inventory.index.duplicated(keep='first')]

# Supervisors list, cleaned the same way.
df_supervisors = pd.read_excel(f"{path}/Руководители.xlsx", index_col='ID')
df_supervisors = df_supervisors[df_supervisors.index.notnull()]
df_supervisors = df_supervisors[~df_supervisors.index.duplicated(keep='first')]

# IDs from the healthy / sick lists that are absent from the main table.
found_hlty = df_hlty.index.intersection(data.index)
found_sick = df_sick.index.intersection(data.index)
missed_hlty = set(df_hlty.index).difference(found_hlty)
missed_sick = set(df_sick.index).difference(found_sick)

# Rows of the main table whose physical-hazard field equals the ionizing-radiation value.
is_rad_for_inventory = data['Текущая основная вредность - Физические факторы'] == 'Ионизирующие излученияК, радиоактивные веществаК;'

# Inventory IDs that also appear in each reference list.
ids_inventory_intxn = {
    '1000+ List (with Age)': df_inventory.index.intersection(data.index).values,
    'Healthy': df_inventory.index.intersection(df_hlty.index).values,
    'Sick': df_inventory.index.intersection(df_sick.index).values,
    'Radiation': df_inventory.index.intersection(data.index[is_rad_for_inventory]).values,
    'No radiation': df_inventory.index.intersection(data.index[~is_rad_for_inventory]).values,
}

# Overlap sizes go to a summary sheet; per-row 0/1 membership flags go to the inventory itself.
df_inventory_dist = pd.DataFrame()
df_inventory_dist.at['Total', 'Count'] = df_inventory.shape[0]
for g, g_overlap in ids_inventory_intxn.items():
    df_inventory_dist.at[g, 'Count'] = len(g_overlap)
    df_inventory[g] = 0
    df_inventory.loc[g_overlap, g] = 1
df_inventory_dist.to_excel(f"{path}/distribution_Опись.xlsx", index_label='Опись')
df_inventory.to_excel(f"{path}/Опись_intxn.xlsx")

# Rows of the main table whose physical-hazard field equals the ionizing-radiation value.
is_rad_for_supervisors = data['Текущая основная вредность - Физические факторы'] == 'Ионизирующие излученияК, радиоактивные веществаК;'

# Supervisor IDs that also appear in each reference list (including the inventory).
ids_supervisors_intxn = {
    '1000+ List (with Age)': df_supervisors.index.intersection(data.index).values,
    'Healthy': df_supervisors.index.intersection(df_hlty.index).values,
    'Sick': df_supervisors.index.intersection(df_sick.index).values,
    'Radiation': df_supervisors.index.intersection(data.index[is_rad_for_supervisors]).values,
    'No radiation': df_supervisors.index.intersection(data.index[~is_rad_for_supervisors]).values,
    'Опись': df_supervisors.index.intersection(df_inventory.index).values,
}

# Overlap sizes go to a summary sheet; per-row 0/1 membership flags go to the supervisors table.
df_supervisors_dist = pd.DataFrame()
df_supervisors_dist.at['Total', 'Count'] = df_supervisors.shape[0]
for g, g_overlap in ids_supervisors_intxn.items():
    df_supervisors_dist.at[g, 'Count'] = len(g_overlap)
    df_supervisors[g] = 0
    df_supervisors.loc[g_overlap, g] = 1
df_supervisors_dist.to_excel(f"{path}/distribution_Руководители.xlsx", index_label='Руководители')
df_supervisors.to_excel(f"{path}/Руководители_intxn.xlsx")

# Healthy / sick IDs restricted to participants present in the main table.
# NOTE(review): the 'Heathy' spelling is kept on purpose — it is a dict key and
# is also embedded in the figure file names written by the group-analysis cell.
groups_ids = {}
for status_label, df_status in (('Heathy', df_hlty), ('Sick', df_sick)):
    groups_ids[status_label] = df_status.index.intersection(data.index).values

for group in groups_ids:
    ids = groups_ids[group]
    print(group, len(ids), sep=': ')

# Columns holding ';'-separated disease codes, one column per specialist.
cols_diseases = [
    'невропатолог - код_заболевания',
    'отоларинголог - код_заболевания',
    'офтальмолог - код_заболевания',
    'дерматолог - код_заболевания',
    'хирург - код_заболевания',
    'терапевт - код_заболевания',
]

# For every column: one distinct colour per code, assigned in descending order
# of how often the code occurs in the whole table.
cols_diseases_colors = {}
for col_disease in cols_diseases:
    codes_per_row = data[col_disease].dropna().str.split(';').values
    statuses = np.concatenate(codes_per_row)
    statuses = statuses[statuses != '']  # drop empty pieces produced by the split
    statuses_counter = Counter(statuses)
    df_statuses_counter = pd.DataFrame.from_dict(statuses_counter, orient='index', columns=['Count'])
    df_statuses_counter = df_statuses_counter.sort_values('Count', ascending=False)
    # White and black are passed as colours the generated palette must stay away from.
    excluded_colors = [mcolors.hex2color(mcolors.CSS4_COLORS[name]) for name in ('white', 'black')]
    colors = distinctipy.get_colors(df_statuses_counter.shape[0], excluded_colors, rng=1337)
    cols_diseases_colors[col_disease] = dict(zip(df_statuses_counter.index.values, colors))

In [None]:
# Display the rows of the sick list whose ID repeats an earlier row.
df_sick.loc[df_sick.index.duplicated()]

In [None]:
# Display the rows of the healthy list whose ID repeats an earlier row.
df_hlty.loc[df_hlty.index.duplicated()]

# Select samples

In [None]:
# Output folder for the sample-selection artefacts; created up front so the
# savefig / to_excel calls further down cannot fail on a missing directory.
path_save = f"{path}/02_select_samples"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# Healthy / sick pools without the IDs that appear in both lists.
ids_hlty = df_hlty.drop(ids_hlty_sick).index.values
ids_sick = df_sick.drop(ids_hlty_sick).index.values
ids_inv = df_inventory.index.values
ids_spv = df_supervisors.index.values

# Split the main table by whether the physical-hazard field equals the
# ionizing-radiation value; rows with a missing field land in the "no radiation" pool.
is_rad = data['Текущая основная вредность - Физические факторы'] == 'Ионизирующие излученияК, радиоактивные веществаК;'
ids_rad = data.index[is_rad].values
ids_norad = data.index[~is_rad].values

# ID pools as sets, combined below into the four status x exposure groups.
inv_set = set(ids_inv)
hlty_set = set(ids_hlty)
sick_set = set(ids_sick)
rad_set = set(ids_rad)
norad_set = set(ids_norad)

# A participant belongs to a group only if it is also listed in the inventory.
groups = {
    'Heathy with Radiation': inv_set & hlty_set & rad_set,
    'Sick with Radiation': inv_set & sick_set & rad_set,
    'Heathy without Radiation': inv_set & hlty_set & norad_set,
    'Sick without Radiation': inv_set & sick_set & norad_set,
}

# Histogram colour per group.
groups_colors = dict(zip(groups, ('crimson', 'dodgerblue', 'lawngreen', 'darkorchid')))

# 2x2 count tables (status x exposure): all members, and supervisors only.
count_cells = (
    ('Healthy', 'Radiation', 'Heathy with Radiation'),
    ('Sick', 'Radiation', 'Sick with Radiation'),
    ('Healthy', 'No radiation', 'Heathy without Radiation'),
    ('Sick', 'No radiation', 'Sick without Radiation'),
)
spv_set = set(ids_spv)
df_count_all = pd.DataFrame()
df_count_spv = pd.DataFrame()
for count_row, count_col, count_group in count_cells:
    df_count_all.at[count_row, count_col] = len(groups[count_group])
    df_count_spv.at[count_row, count_col] = len(groups[count_group] & spv_set)

# Figure layout and the shared 5-year age bins (5..115).
n_rows, n_cols = 2, 2
fig_width, fig_height = 10, 8
hist_bins = np.linspace(5, 115, 23)

# One age histogram per group on a 2x2 grid with shared axes.
sns.set_theme(style='ticks')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=True, sharex=True)
for g_id, (g, g_ids) in enumerate(groups.items()):
    row_id, col_id = divmod(g_id, n_cols)
    ax_group = axs[row_id, col_id]
    n_spv_in_group = len(groups[g].intersection(set(ids_spv)))

    histplot = sns.histplot(
        data=data.loc[list(g_ids)],
        x="Age",
        bins=hist_bins,
        color=groups_colors[g],
        edgecolor='k',
        linewidth=1,
        ax=ax_group
    )
    ax_group.set(xlim=(15, 80))
    ax_group.set_title(f"{g} (Total {len(g_ids)}, Supervisors {n_spv_in_group})")
fig.tight_layout()
fig.savefig(f"{path_save}/hist_age.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_save}/hist_age.pdf", bbox_inches='tight')
plt.close(fig)

# How many non-supervisor participants to draw per group: the radiation groups
# are topped up to 51 members (supervisors are always kept); the no-radiation
# groups are taken in full, so nothing is drawn for them.
needed_samples = {
    'Heathy with Radiation': 51 - len(groups['Heathy with Radiation'].intersection(set(ids_spv))),
    'Sick with Radiation': 51 - len(groups['Sick with Radiation'].intersection(set(ids_spv))),
    'Heathy without Radiation': 0,
    'Sick without Radiation': 0,
}

groups_selected = {
    'Heathy with Radiation': [],
    'Sick with Radiation': [],
    'Heathy without Radiation': list(groups['Heathy without Radiation']),
    'Sick without Radiation': list(groups['Sick without Radiation']),
}

# 5-year age bins with a uniform weight per bin.
age_bin_edges = np.linspace(5, 115, 23)
n_age_bins = len(age_bin_edges) - 1
age_prob = np.full(n_age_bins, 1 / n_age_bins)
bin_diff = 5
for g in ['Heathy with Radiation', 'Sick with Radiation']:
    # Candidates are group members that are not supervisors. `.copy()` makes the
    # new column land on an independent frame instead of a selection of `data`.
    # NOTE(review): the candidate order comes from set iteration; if IDs are
    # strings it can vary between sessions and so can the draw — consider sorting.
    data_cands = data.loc[list(groups[g] - set(ids_spv)), :].copy()
    print(data_cands.shape[0])

    # Bin index = floor((age - first edge) / bin width), clipped to the valid range.
    # The previous `np.rint` pointed at the nearest bin *edge* and could run past
    # the last bin; with the uniform weights used here the drawn sample is unchanged.
    age_bin_ids = np.floor((data_cands['Age'].values - age_bin_edges[0]) / bin_diff).astype(int)
    age_bin_ids = np.clip(age_bin_ids, 0, n_age_bins - 1)
    data_cands['Prob Age'] = age_prob[age_bin_ids]

    n_same_age = needed_samples[g]
    print(n_same_age)
    index_selected = data_cands.sample(n=n_same_age, replace=False, weights='Prob Age', random_state=36).index
    if index_selected.is_unique:
        ids_selected = index_selected.to_list()
        groups_selected[g] = ids_selected + list(groups[g].intersection(set(ids_spv)))
        print(len(groups_selected[g]))
    else:
        print("Not unique index")

# Same 2x2 age-histogram grid as before, now for the selected participants.
sns.set_theme(style='ticks')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=True, sharex=True)
for g_id, (g, g_ids) in enumerate(groups_selected.items()):
    row_id, col_id = divmod(g_id, n_cols)
    ax_selected = axs[row_id, col_id]
    n_spv_selected = len(set(groups_selected[g]).intersection(set(ids_spv)))

    histplot = sns.histplot(
        data=data.loc[g_ids],
        x="Age",
        bins=hist_bins,
        color=groups_colors[g],
        edgecolor='k',
        linewidth=1,
        ax=ax_selected
    )
    ax_selected.set(xlim=(15, 80))
    ax_selected.set_title(f"{g} (Total {len(g_ids)}, Supervisors {n_spv_selected})")
    # Counts are small here, so force integer ticks on the y axis.
    ax_selected.yaxis.set_major_locator(MaxNLocator(integer=True))
fig.tight_layout()
fig.savefig(f"{path_save}/hist_age_selected.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_save}/hist_age_selected.pdf", bbox_inches='tight')
plt.close(fig)
        
# Rows of every selected participant, in group order.
selected_ids = list(chain.from_iterable(groups_selected.values()))
data_selected = data.loc[selected_ids, :]

# Leading columns: Age, Status, Radiation flag, Supervisor flag.
col = data_selected.pop("Age")
data_selected.insert(0, col.name, col)

healthy_selected = groups_selected['Heathy with Radiation'] + groups_selected['Heathy without Radiation']
data_selected.insert(1, 'Status', 'Sick')
data_selected.loc[healthy_selected, 'Status'] = 'Healthy'

radiation_selected = groups_selected['Heathy with Radiation'] + groups_selected['Sick with Radiation']
data_selected.insert(2, 'Radiation', 0)
data_selected.loc[radiation_selected, 'Radiation'] = 1

supervisors_selected = data_selected.index.intersection(df_supervisors.index)
data_selected.insert(3, 'Supervisor', 0)
data_selected.loc[supervisors_selected, 'Supervisor'] = 1

# Hazard and disease-code columns follow directly after the four flag columns.
cols_to_front = [
    'Текущая основная вредность - Физические факторы',
    'невропатолог - код_заболевания',
    'отоларинголог - код_заболевания',
    'офтальмолог - код_заболевания',
    'дерматолог - код_заболевания',
    'хирург - код_заболевания',
    'терапевт - код_заболевания',
]
for col_front_id, col_front in enumerate(cols_to_front):
    insert_at = col_front_id + 4
    col = data_selected.pop(col_front)
    data_selected.insert(insert_at, col.name, col)

duplicated_ids = data_selected.index[data_selected.index.duplicated()].unique().to_list()
print(f"Duplicated indexes: {duplicated_ids}")
data_selected.to_excel(f"{path_save}/data_selected.xlsx")

In [None]:
# NOTE(review): this cell looks like a leftover from another notebook —
# `age_counts`, `df_imm`, `ids_to` and `ids_from` are not defined in the visible
# part of this file, so it raises NameError unless an unseen cell defines them.
# It also rebinds `age_prob`, which the "Select samples" cell defines.

# Per-bin weight: `age_counts` divided by the number of `ids_to` ages.
age_prob = age_counts / len(df_imm.loc[ids_to, 'Age'].values)
# Each `ids_from` row gets the weight found at its age-derived position;
# the position is obtained with np.rint (round to nearest), not a floor.
df_imm.loc[ids_from, 'Prob Age (IPOP)'] = age_prob[np.rint((df_imm.loc[ids_from, 'Age'].values - age_bin_edges[0]) / (bin_diff + 0.0001)).astype(int)]

# Weighted draw of 200 `ids_from` rows without replacement, fixed seed.
n_same_age = 200
index_from_to = df_imm.loc[ids_from, :].sample(n=n_same_age, replace=False, weights='Prob Age (IPOP)', random_state=1337).index
if index_from_to.is_unique:
    ids_from_to = index_from_to.values
else:
    # `ids_from_to` is left unset (or stale) in this branch.
    print("Not unique index")

# NaNs analysis

In [None]:
# Overall share of missing cells in the main table, in percent.
nan_mask = data.isna()
nan_pct = nan_mask.sum().sum() / data.size * 100
print(nan_pct)

# Per-feature missing-value summary, most incomplete features first.
# NOTE(review): `path_save` is whatever the most recently run cell assigned.
nan_feats = nan_mask.sum(axis=0).to_frame(name="Number of NaNs")
nan_feats["% of NaNs"] = nan_feats["Number of NaNs"] / data.shape[0] * 100
nan_feats["Number of not-NaNs"] = data.notna().sum(axis=0)
nan_feats = nan_feats.sort_values("% of NaNs", ascending=False)
nan_feats.to_excel(f"{path_save}/nan_feats.xlsx", index_label="Features")

# Healthy and Sick groups analysis 

In [None]:
# Output folder for the per-group figures; created up front so savefig cannot
# fail on a missing directory.
path_save = f"{path}/01_test_data"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# Colour per sex label. The legacy cells of this notebook filter the 'пол'
# column on Cyrillic 'М' / 'Ж'; the previous map only knew 'F' for women, so a
# 'Ж' value raised KeyError. Latin spellings are kept as harmless aliases.
sex_colors = {'М': 'dodgerblue', 'M': 'dodgerblue', 'Ж': 'crimson', 'F': 'crimson'}

for group, ids in groups_ids.items():

    # Independent copy: the sex column is relabelled below and must not touch `data`.
    df_group = data.loc[ids, :].copy()
    print(df_group.shape[0])

    # Legend labels carry the group size per sex, e.g. "<label> (<count>)".
    hue_counts = df_group['пол'].value_counts()
    hue_replace = {x: f"{x} ({y})" for x, y in hue_counts.items()}
    # Unknown labels get a neutral colour instead of aborting the whole loop.
    hue_colors = {hue_replace[x]: sex_colors.get(x, 'gray') for x in hue_counts.index}
    # Assign the result back; an in-place replace on a column selection is a
    # chained assignment and is not guaranteed to modify the frame.
    df_group['пол'] = df_group['пол'].replace(hue_replace)

    hist_bins = np.linspace(5, 115, 23)

    # Figure 1: age histogram split by sex.
    sns.set_theme(style='ticks')
    fig, ax = plt.subplots(figsize=(6, 3.5))
    histplot = sns.histplot(
        data=df_group,
        bins=hist_bins,
        edgecolor='k',
        linewidth=1,
        x="Age",
        hue='пол',
        palette=hue_colors,
        ax=ax
    )
    histplot.set(xlim=(0, 120))
    histplot.set_title(group)
    fig.savefig(f"{path_save}/age_hist_{group}.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_save}/age_hist_{group}.pdf", bbox_inches='tight')
    plt.close(fig)

    # Figure 2: the 50 most frequent disease codes per specialist column.
    sns.set_theme(style='ticks')
    fig, axs = plt.subplots(1, len(cols_diseases), figsize=(30, 15), gridspec_kw={'wspace':0.4}, sharey=False, sharex=False)
    axs = np.atleast_1d(axs)  # keep indexing valid even for a single column

    for col_disease_id, col_disease in enumerate(cols_diseases):
        ax_disease = axs[col_disease_id]
        ax_disease.set_title(col_disease, fontsize='large')

        codes_per_row = df_group[col_disease].dropna().str.split(';').values
        if len(codes_per_row) == 0:
            # Nobody in this group has a code in this column: np.concatenate
            # would raise on an empty input, so leave the panel empty.
            continue
        statuses = np.concatenate(codes_per_row)
        statuses = statuses[statuses != '']
        statuses_counter = Counter(statuses)
        df_statuses_counter = pd.DataFrame.from_dict(statuses_counter, orient='index', columns=['Count'])
        df_statuses_counter.sort_values(['Count'], ascending=[False], inplace=True)

        # `.copy()` so that adding the 'Status' column does not write into a slice.
        df_statuses_counter = df_statuses_counter.head(50).copy()
        df_statuses_counter['Status'] = df_statuses_counter.index.values
        barplot = sns.barplot(
            data=df_statuses_counter,
            x='Count',
            y='Status',
            hue='Status',
            palette=cols_diseases_colors[col_disease],
            edgecolor='black',
            dodge=False,
            ax=ax_disease
        )
        for container in barplot.containers:
            barplot.bar_label(container, label_type='edge', fmt='%.d', fontsize=12, padding=2.0)
        ax_disease.set_ylabel('')
        # get_legend() returns None when no legend was drawn.
        legend = ax_disease.get_legend()
        if legend is not None:
            legend.remove()
    fig.savefig(f"{path_save}/barplot_icd_{group}.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_save}/barplot_icd_{group}.pdf", bbox_inches='tight')
    plt.close(fig)

# Legacy

## Age and sex analysis

In [None]:
# Legacy: age at 2024-01-01, stored in the lower-case 'age' column.
# Note that this overwrites the 'date_now' column set when the data was loaded.
data['дата рождения'] = pd.to_datetime(data['дата рождения'])
data['date_now'] = pd.to_datetime("2024-01-01")
legacy_age_days = (data['date_now'] - data['дата рождения']) / np.timedelta64(1, 'D')
data['age'] = legacy_age_days / 365.25

# Stacked one-year-bin histogram of age by sex.
hp = sns.histplot(data=data, x='age', hue="пол", binwidth=1, multiple="stack")
figure = hp.get_figure()
figure.savefig(f'{path_save}/age_hist.png')

In [None]:
# Legacy: age histogram for rows whose physical-hazard field equals the ionizing-radiation value.
is_vred = data['Текущая основная вредность - Физические факторы'] == 'Ионизирующие излученияК, радиоактивные веществаК;'
data_vred = data[is_vred]
ax_vred = sns.histplot(data=data_vred, x='age', hue="пол", binwidth=1, multiple="stack")
# `hpv` holds the title Text returned by set_title (as before), not the Axes.
hpv = ax_vred.set_title("Ионизирующие излучения")
figure = hpv.get_figure()
figure.savefig(f'{path_save}/ion_age_hist.png')

In [None]:
# Legacy: age histogram for all remaining rows (field differs from the ionizing-radiation value).
is_no_vred = data['Текущая основная вредность - Физические факторы'] != 'Ионизирующие излученияК, радиоактивные веществаК;'
data_no_vred = data[is_no_vred]
ax_no_vred = sns.histplot(data=data_no_vred, x='age', hue="пол", binwidth=1, multiple="stack")
# `hpv` holds the title Text returned by set_title (as before), not the Axes.
hpv = ax_no_vred.set_title("Без ионизирующих излучений")
figure = hpv.get_figure()
figure.savefig(f'{path_save}/no_ion_age_hist.png')

## Diseases statistics

In [None]:
# Participants without a therapist code are labelled 'Healthy'. The result is
# assigned back to the frame: an in-place replace on a column selection is a
# chained assignment and is not guaranteed to modify `data`. The cells below
# split this column without dropping missing values, so the fill must take effect.
data['терапевт - код_заболевания'] = data['терапевт - код_заболевания'].fillna('Healthy')

# Boolean masks shared by the subset definitions. Rows with a missing hazard
# field count as "no radiation", exactly as with the former `!=` comparison.
is_ion = data['Текущая основная вредность - Физические факторы'] == 'Ионизирующие излученияК, радиоактивные веществаК;'
is_female = data['пол'] == 'Ж'
is_male = data['пол'] == 'М'

# Subset title -> participant IDs; titles are used as panel headings.
subsets = {
    'Все данные': data.index.values,
    'Женщины': data.index[is_female].values,
    'Мужчины': data.index[is_male].values,
    'Ионизирующие излучения': data.index[is_ion].values,
    'Ионизирующие излучения\nЖенщины': data.index[is_ion & is_female].values,
    'Ионизирующие излучения\nМужчины': data.index[is_ion & is_male].values,
    'Нет излучения': data.index[~is_ion].values,
    'Нет излучения\nЖенщины': data.index[~is_ion & is_female].values,
    'Нет излучения\nМужчины': data.index[~is_ion & is_male].values,
}

In [None]:
# Legacy: one distinct colour per therapist code, assigned in descending order of frequency.
therapist_codes = data['терапевт - код_заболевания'].str.split(';').values
statuses = np.concatenate(therapist_codes)
statuses = statuses[statuses != '']  # drop empty pieces produced by the split
statuses_counter = Counter(statuses)
df_statuses_counter = pd.DataFrame.from_dict(statuses_counter, orient='index', columns=['Count'])
df_statuses_counter = df_statuses_counter.sort_values('Count', ascending=False)
# White and black are passed as colours the generated palette must stay away from.
colors_to_avoid = [mcolors.hex2color(mcolors.CSS4_COLORS[name]) for name in ('white', 'black')]
colors = distinctipy.get_colors(df_statuses_counter.shape[0], colors_to_avoid, rng=1337)
colors_statuses = dict(zip(df_statuses_counter.index.values, colors))

In [None]:
# Legacy: one horizontal bar chart per subset with its 50 most frequent therapist codes.
sns.set_theme(style='ticks')
# One panel per subset (previously hard-coded to 9, the current number of subsets).
fig, axs = plt.subplots(1, len(subsets), figsize=(30, 20), gridspec_kw={'wspace':0.4},sharey=False, sharex=False)
axs = np.atleast_1d(axs)  # keep indexing valid even for a single subset

for subset_id, (subset, subset_ids) in enumerate(subsets.items()):
    df_data_subset = data.loc[subset_ids, :]
    print(f"{subset}: {len(df_data_subset)}")
    statuses = np.concatenate(df_data_subset['терапевт - код_заболевания'].str.split(';').values)
    statuses = statuses[statuses != '']
    statuses_counter = Counter(statuses)
    df_statuses_counter = pd.DataFrame.from_dict(statuses_counter, orient='index', columns=['Count'])
    df_statuses_counter.sort_values(['Count'], ascending=[False], inplace=True)

    # `.copy()` so that adding the 'Status' column does not write into a slice.
    df_fig = df_statuses_counter.head(50).copy()
    df_fig['Status'] = df_fig.index.values
    barplot = sns.barplot(
        data=df_fig,
        x='Count',
        y='Status',
        hue='Status',
        palette=colors_statuses,
        edgecolor='black',
        dodge=False,
        ax=axs[subset_id]
    )
    for container in barplot.containers:
        barplot.bar_label(container, label_type='edge', fmt='%.d', fontsize=12, padding=2.0)
    axs[subset_id].set_title(subset, fontsize='large')
    axs[subset_id].set_ylabel('')
    # get_legend() returns None when no legend was drawn.
    legend = axs[subset_id].get_legend()
    if legend is not None:
        legend.remove()
fig.savefig(f"{path_save}/barplot_icd.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_save}/barplot_icd.pdf", bbox_inches='tight')
plt.close(fig)