In [None]:
# Pinned dependency: seaborn 0.13. To install it into this kernel's environment, uncomment:
# %pip install seaborn==0.13

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os, re
# Global matplotlib style for every figure in this notebook:
#   - pdf.fonttype 42 embeds fonts in saved PDFs as TrueType (Type 42) rather than Type 3;
#   - all text is set in Arial (matplotlib warns and falls back if Arial is not installed).
plt.rcParams.update({
    'pdf.fonttype': 42,
    'font.family': 'Arial',
})

# Import model predictions

In [None]:
# Input directories, each holding a pred_vs_true.csv. Per the directory names,
# the male model is a Lasso model and the female model a LightGBM model.
ipt_DIR_male = "../../../2_model_construction/lasso/out/MAPE/0.99/Lasso/male/tsfresh/"
ipt_DIR_female = "../../../2_model_construction/lightGBM/out/MAPE/0.99/LGBM/female/both/"

# Output directories. Create both up front so later writes (e.g. the
# Figure 4A PDF saved into outDIR_figure) do not fail on a fresh checkout.
outDIR_figure = "../figure/"
outDIR = "../out/"
os.makedirs(outDIR_figure, exist_ok=True)
os.makedirs(outDIR, exist_ok=True)


In [None]:
# Read pred_vs_true.csv from the male and female model directories, keeping
# file columns 1, 2 and 5. pandas resolves index_col=0 within the usecols
# subset, so the first kept column becomes the index.
# Then stack the two frames (male rows first).
df_pred_male, df_pred_female = (
    pd.read_csv(os.path.join(pred_dir, "pred_vs_true.csv"), usecols=[1, 2, 5], index_col=0)
    for pred_dir in (ipt_DIR_male, ipt_DIR_female)
)
# NOTE(review): df_pred is not used by the cells shown here; confirm it is needed.
df_pred = pd.concat([df_pred_male, df_pred_female])

# Import systemic parameters

In [None]:
# Systemic parameters, indexed by "group.cmp". Figure 4A matches this index
# against df_match["group.cmp"] and reads the Age / Predicted_age columns.
systemic_params_csv = "../data/systemic_params.csv"
df = pd.read_csv(
    systemic_params_csv,
    index_col="group.cmp",
)

# Import propensity-score matching results

In [None]:
# Propensity-score matching results, read from the configured output directory
# (presumably produced by a separate matching step). Figure 4A uses:
#   "group.cmp" - sample key matched against df.index
#   "pred_real" - matched class flag (0 -> predicted younger, 1 -> predicted older)
df_match = pd.read_csv(os.path.join(outDIR, "pred_vs_true_matched_male_and_female.csv"))

# Fail fast here rather than with a KeyError deep inside the plotting cells.
missing_match_cols = {"group.cmp", "pred_real"} - set(df_match.columns)
assert not missing_match_cols, f"matching results lack columns: {sorted(missing_match_cols)}"

### Figure 4A: colour coding by propensity-score matching

In [None]:
# Column names and class labels used throughout Figure 4A.
new_col = "Class"        # final label, including the " (matched)" variants
new_col_2 = "Class_2"    # threshold-only label
excluded_class = "Excluded"
younger_class_name = "Model-predicted younger"
older_class_name = "Model-predicted older"

# Threshold classification on the predicted/actual age ratio:
#   predicted < 0.9 x actual -> younger
#   predicted > 1.1 x actual -> older
#   anything else (including missing values) -> excluded
# "older" is listed first so it wins if both conditions ever held, matching
# an assignment order in which the older label is written last.
is_younger = df.Age * 0.9 > df.Predicted_age
is_older = df.Age * 1.1 < df.Predicted_age
df[new_col_2] = np.select(
    [is_older, is_younger],
    [older_class_name, younger_class_name],
    default=excluded_class,
)

# Start the final label from the threshold label, then relabel samples that
# appear in the matching results (pred_real 0 -> younger, 1 -> older).
df[new_col] = df[new_col_2]
for pred_real_value, class_name in ((0, younger_class_name), (1, older_class_name)):
    matched_ids = df_match.loc[df_match["pred_real"] == pred_real_value, "group.cmp"]
    df.loc[df.index.isin(matched_ids), new_col] = class_name + " (matched)"

In [None]:
# Unmatched classes. Figure 4A draws these semi-transparently, and they alone
# populate its legend.
transp_class = [
    excluded_class,
    younger_class_name,
    older_class_name,
]

In [None]:
# Figure 4A: actual vs. predicted age, coloured by class.
# Unmatched samples are drawn faintly. Propensity-score-matched samples are
# drawn opaque on top, in the same colour as their class.

# Class label -> colour. A class and its "(matched)" variant share a tab10
# colour; transparency is what distinguishes them.
tab10 = sns.color_palette("tab10")
color_mapping = {
    excluded_class: "gray",
    younger_class_name: tab10[0],
    older_class_name: tab10[1],
    younger_class_name + " (matched)": tab10[0],
    older_class_name + " (matched)": tab10[1],
}

# x values for the threshold lines (actual age 35..74)
l = np.arange(35, 75, 1)

# Set the figure size.
fig, ax = plt.subplots(figsize=(5, 4))

# Rows whose final class is one of the unmatched classes.
is_unmatched = df[new_col].isin(transp_class)

# Unmatched samples: semi-transparent; these provide the legend entries.
# Marker size is left at the seaborn default. `sizes=` only takes effect together
# with a `size` variable; pass `s=` to change the marker size.
sns.scatterplot(x=df.loc[is_unmatched, "Age"],
                y=df.loc[is_unmatched, "Predicted_age"],
                hue=df.loc[is_unmatched, new_col],
                palette=color_mapping, alpha=0.1, ax=ax)

# Matched samples: opaque, drawn on top, with no legend entries of their own.
sns.scatterplot(x=df.loc[~is_unmatched, "Age"],
                y=df.loc[~is_unmatched, "Predicted_age"],
                hue=df.loc[~is_unmatched, new_col],
                palette=color_mapping, legend=False, ax=ax)

# Dashed lines at predicted = 0.9 x actual and 1.1 x actual. These ratios must
# stay in sync with the thresholds used to build the Class_2 labels.
ax.plot(l, l * 0.9, c="black", alpha=0.5, linestyle="--")
ax.plot(l, l * 1.1, c="black", alpha=0.5, linestyle="--")

ax.set_xlabel("Actual age", fontsize=14)
ax.set_ylabel("Predicted age", fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=14)
ax.grid()

# Draw legend markers fully opaque rather than as faint as the alpha=0.1 points.
handles, labels = ax.get_legend_handles_labels()
for handle in handles:
    handle.set_alpha(1)
ax.legend(handles=handles, labels=labels, title='', fontsize=12)
fig.tight_layout()

fig.savefig(os.path.join(outDIR_figure, 'color_TrueAge_vs_predAge.pdf'), bbox_inches="tight")