In [None]:
#!pip install seaborn==0.11.2
#!pip install statannotations==0.5.0

In [None]:
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns

In [None]:
from statannotations.Annotator import Annotator
from scipy import stats

In [None]:
# Figure export settings: font type 42 for PDF output and Arial as the
# font family for every figure drawn below.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'font.family': 'Arial',
})

# Comparison of raw data for features 
The graphs are drawn in the same format as Figure 4.

## data import

In [None]:
# Full feature table; rows are indexed by the "group.cmp" column.
X = pd.read_csv(
    "../../1_data_processing/processed_data/X.csv",
    index_col="group.cmp",
)

In [None]:
# SHAP output table, indexed by the same "group.cmp" key as X.
shap_male_dir = "../out/shap_male.csv"
df_shap = pd.read_csv(
    shap_male_dir,
    index_col="group.cmp",
)

## data preprocessing

In [None]:
# Positional index of the first column named "pred_real" in the SHAP table;
# the columns after it are used as the feature columns in the next cell.
colnum_pred_real = np.flatnonzero(df_shap.columns == "pred_real")[0]

In [None]:
# Raw feature values restricted to the rows of the SHAP table and to the
# columns that follow "pred_real" in it.
X_selected = X.loc[
    df_shap.index,
    df_shap.columns[colnum_pred_real + 1:],
]

In [None]:
# Attach the "pred_real" column to the raw feature values, matching rows on
# the shared index.
df_tmp1 = (
    df_shap["pred_real"]
    .to_frame()
    .merge(X_selected, left_index=True, right_index=True)
)

In [None]:
# Turn the 1/0 codes of "pred_real" into readable class names and expose the
# column under the name stored in `new_col`.
new_col = "Class"
df_tmp1 = (
    df_tmp1
    .assign(pred_real=df_tmp1["pred_real"].replace({1: "Pred_old", 0: "Pred_young"}))
    .rename(columns={'pred_real': new_col})
)

In [None]:
# Output directory for the figures and the summary table written below.
# `os` comes from the notebook's import cell; without that import this cell
# fails with NameError on a freshly started kernel.
outDIR = "../figure/"
outDIR_th = os.path.join(outDIR, "matched")
# exist_ok=True makes the cell safe to re-run when the directory already exists.
os.makedirs(outDIR_th, exist_ok=True)

## Re-labeling and calculating mean ± SD

In [None]:
# Column names used for the class labels.
new_col, new_col_2 = "Class", "Class_2"

# Values that appear in the class column, plus the fallback label given to
# rows that belong to neither predicted class.
younger_class_name, older_class_name = "Pred_young", "Pred_old"
excluded_class = "Excluded"

In [None]:
# Feature columns to summarise: every numeric column except the class label.
# They are selected explicitly so that this cell can be re-run after the
# string column `short_label` has been added below, and so that the result
# does not depend on how the installed pandas version treats non-numeric
# columns in groupby reductions (older versions drop them, newer ones raise).
feature_cols = [
    col for col in df_tmp1.columns
    if col != new_col and pd.api.types.is_numeric_dtype(df_tmp1[col])
]
grouped = df_tmp1.groupby(new_col)

# Per-class mean and standard deviation, transposed to features x classes.
score_all_mean = grouped[feature_cols].mean().T
score_all_sd = grouped[feature_cols].std().T
score_all_mean_str = score_all_mean.round(2).astype(str)
score_all_sd_str = score_all_sd.round(2).astype(str)
# "mean±sd" strings used in the exported summary table
score_all_str = score_all_mean_str + "±" + score_all_sd_str

# Short labels for plotting: the two predicted classes become "Younger" and
# "Older"; every other row (including a missing class) gets `excluded_class`.
short_label = "Model-predicted"
df_tmp1[short_label] = (
    df_tmp1[new_col]
    .map({younger_class_name: "Younger", older_class_name: "Older"})
    .fillna(excluded_class)
)

# Category order on the x-axis; rows labelled `excluded_class` are not plotted.
order = ["Younger", "Older"]

## draw graphs and export stats

In [None]:
def _feature_p_value(values, classes, younger_name, older_name):
    """Return the p-value for one feature compared between the two classes.

    Parameters
    ----------
    values : pandas.Series
        Feature values, one per row.
    classes : pandas.Series
        Class label of each row, in the same row order as ``values``.
    younger_name, older_name : str
        The two class labels that are compared.

    Returns
    -------
    float
        Chi-square p-value of the class x value contingency table when every
        non-missing value is 0 or 1, otherwise the two-sided Mann-Whitney U
        p-value. NaN when either class has no non-missing observation, since
        neither test can be computed in that case.
    """
    # Pair by position so that a non-unique index cannot misalign the rows.
    paired = pd.DataFrame({"cls": classes.to_numpy(), "val": values.to_numpy()}).dropna()
    younger = paired.loc[paired["cls"] == younger_name, "val"]
    older = paired.loc[paired["cls"] == older_name, "val"]
    if younger.empty or older.empty:
        return np.nan

    if np.isin(values.dropna().unique(), [0, 1]).all():
        crossed = pd.crosstab(paired["cls"], paired["val"])
        return stats.chi2_contingency(crossed)[1]
    return stats.mannwhitneyu(younger, older, alternative="two-sided")[1]


# Dictionary to store the p-value of each feature
tstats = {}

# One figure and one p-value per numeric column of the dataframe
for feature in df_tmp1.columns:
    if not pd.api.types.is_numeric_dtype(df_tmp1[feature]):
        continue

    # Strip characters that are unsafe in file names
    file_stem = re.sub(r'[\\/:*?"<>|\^\$\{\}\(\) ]+', '', feature)
    fig, ax = plt.subplots(figsize=(2, 3))

    # Individual observations
    sns.stripplot(x=short_label, y=feature, data=df_tmp1, jitter=0.1, size=2, alpha=0.5,
                  linewidth=.1, order=order, ax=ax)

    # Point estimate with error bars per group
    sns.pointplot(x=short_label, y=feature, data=df_tmp1, join=False, capsize=0.1,
                  color='black', scale=0.5, order=order, ax=ax)

    # Use a uniform line width for the lines drawn so far
    for line in ax.lines:
        line.set_linewidth(1)

    # Annotate only when both plotted groups have data and neither is constant.
    # Only the groups listed in `order` are checked, so rows outside the two
    # plotted classes cannot suppress or trigger the annotation.
    plotted_groups = [
        df_tmp1.loc[df_tmp1[short_label] == label, feature].dropna()
        for label in order
    ]
    if all(len(group) > 0 and group.var() != 0 for group in plotted_groups):
        pairs = [("Younger", "Older")]
        annotator = Annotator(ax, pairs, data=df_tmp1, x=short_label, y=feature, order=order)
        annotator.configure(test='Mann-Whitney', text_format='star', loc='inside',
                            comparisons_correction=None)
        annotator.apply_and_annotate()

    # Add margin to the x-axis
    ax.margins(x=0.25)
    # Save the figure, then release it so figures do not accumulate in memory
    fig.savefig(f"{outDIR_th}/{file_stem}_vs_predAge.pdf", bbox_inches="tight")
    plt.close(fig)

    # Chi-square test for binary (0/1) columns, Mann-Whitney U test otherwise
    tstats[feature] = _feature_p_value(
        df_tmp1[feature], df_tmp1[new_col], younger_class_name, older_class_name
    )

# Put the "mean±sd" strings and the per-feature p-values side by side,
# write the table to disk, and show the rows with p < 0.05.
df_summary = pd.concat(
    [score_all_str, pd.Series(tstats, name="p_values")],
    axis=1,
)
df_summary.to_csv(f"{outDIR_th}/score_all_str.csv")
print(df_summary.loc[df_summary["p_values"] < 0.05])