In [None]:
import shap
import pandas as pd
import os
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import numpy as np
import pickle
from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr
from models.tabular.widedeep.ft_transformer import WDFTTransformerModel
import gradio as gr

In [None]:
# Column names selected from the data file and passed to SHAP as feature names.
feats = ['CXCL9', 'CCL22', 'IL6', 'PDGFB', 'CD40LG', 'IL27', 'VEGFA', 'CSF1', 'PDGFA', 'CXCL10']

# All artefact paths are resolved against the current working directory.
root_dir = Path(os.getcwd())
fn_model = f"{root_dir}/data/model.ckpt"
fn_shap = f"{root_dir}/data/shap.pickle"

# Restore the model from its checkpoint, then call eval() and freeze() on it
# before it is used anywhere below.
model = WDFTTransformerModel.load_from_checkpoint(checkpoint_path=fn_model)
model.eval()
model.freeze()


def predict_func(x):
    """Run the module-level ``model`` on ``x`` and return the result as a NumPy array.

    ``x`` is indexed as ``x[:, []]``, so it must support two-dimensional
    NumPy-style indexing. The batch passed to the model has three entries:
    ``'all'`` and ``'continuous'`` (both a float32 copy of ``x``) and
    ``'categorical'`` (an int32 array with zero columns).
    """
    # Two independent float32 copies are made, one per key.
    batch = {key: torch.from_numpy(np.float32(x)) for key in ('all', 'continuous')}
    # Selecting with an empty column list yields an array with no columns.
    batch['categorical'] = torch.from_numpy(np.int32(x[:, []]))
    prediction = model(batch)
    return prediction.cpu().detach().numpy()


# NOTE(security): pickle.load can execute arbitrary code while loading;
# only open a shap.pickle that comes from a trusted source.
with open(fn_shap, 'rb') as handle:
    shap_dict = pickle.load(handle)
values_train = shap_dict['values_train']
shap_values_train = shap_dict['shap_values_train']
explainer = shap_dict['explainer']

# Keep only the model features plus the 'Age' column.
df = pd.read_excel(f"{root_dir}/data/data.xlsx", index_col=0)
df = df.loc[:, feats + ['Age']]

# The model output is flattened so that each row receives one predicted value.
df['SImAge'] = model(torch.from_numpy(df.loc[:, feats].values)).cpu().detach().numpy().ravel()
df['SImAge acceleration'] = df['SImAge'] - df['Age']
df['|SImAge acceleration|'] = df['SImAge acceleration'].abs()

df_res = df[['SImAge acceleration', '|SImAge acceleration|']]

mae = mean_absolute_error(df['Age'].values, df['SImAge'].values)
rho = pearsonr(df['Age'].values, df['SImAge'].values).statistic

# savefig does not create missing directories, so make sure out/ exists.
os.makedirs(f'{root_dir}/out', exist_ok=True)

# Predicted vs. chronological age, with the y = x diagonal for reference.
sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(4, 4))
scatter = sns.scatterplot(
    data=df,
    x="Age",
    y="SImAge",
    linewidth=0.1,
    alpha=0.75,
    edgecolor="k",
    s=40,
    color='blue',
    ax=ax
)
bisect = sns.lineplot(
    x=[0, 120],
    y=[0, 120],
    linestyle='--',
    color='black',
    linewidth=1.0,
    ax=ax
)
ax.set_xlim(0, 120)
ax.set_ylim(0, 120)
fig.savefig(f'{root_dir}/out/scatter.png', bbox_inches='tight')

# Distribution of (predicted - chronological) age.
sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(2, 4))
sns.violinplot(
    data=df,
    y='SImAge acceleration',
    density_norm='width',
    color='blue',
    saturation=0.75,
    ax=ax,
)
fig.savefig(f'{root_dir}/out/violin.png', bbox_inches='tight')

# show=False keeps the SHAP figure open so that savefig captures it;
# the figure is displayed explicitly only after it has been written.
shap.summary_plot(
    shap_values=shap_values_train.values,
    features=values_train.values,
    feature_names=feats,
    max_display=len(feats),
    plot_type="violin",
    show=False,
)
plt.savefig(f'{root_dir}/out/shap.png', bbox_inches='tight')
plt.show()

# Local explanation for a single sample.
trgt_id = "tst_ctrl_025"
shap_values_trgt = explainer.shap_values(df.loc[trgt_id, feats].values)
base_value = explainer.expected_value[0]

# Same save-before-show ordering as for the summary plot.
shap.plots.waterfall(
    shap.Explanation(
        values=shap_values_trgt,
        base_values=base_value,
        data=df.loc[trgt_id, feats].values,
        feature_names=feats
    ),
    max_display=len(feats),
    show=False,
)
plt.savefig(f'{root_dir}/out/waterfall_{trgt_id}.png', bbox_inches='tight')
plt.show()

# Features sorted by decreasing absolute SHAP value for this sample.
order = np.argsort(-np.abs(shap_values_trgt))
locally_ordered_feats = [feats[i] for i in order]

In [None]:
def predict_func(x):
    """Feed ``x`` to the module-level ``model`` and return a NumPy array.

    This definition is identical in behaviour to the ``predict_func`` defined
    in the earlier exploratory cell; running this cell rebinds the name.

    ``x`` must support ``x[:, []]`` (two-dimensional NumPy-style indexing).
    """
    as_tensor = torch.from_numpy
    batch = {
        'all': as_tensor(np.float32(x)),
        'continuous': as_tensor(np.float32(x)),
        # An empty column selection: the categorical entry has no columns.
        'categorical': as_tensor(np.int32(x[:, []])),
    }
    output = model(batch)
    return output.cpu().detach().numpy()


def load_model():
    """Load the model checkpoint from ``<cwd>/data/model.ckpt``.

    Returns the loaded ``WDFTTransformerModel`` after ``eval()`` and
    ``freeze()`` have been called on it.
    """
    checkpoint = f"{Path(os.getcwd())}/data/model.ckpt"
    net = WDFTTransformerModel.load_from_checkpoint(checkpoint_path=checkpoint)
    net.eval()
    net.freeze()
    return net


def load_feats():
    """Return a new list with the ten feature (cytokine) column names, in model order."""
    return [
        'CXCL9', 'CCL22', 'IL6', 'PDGFB', 'CD40LG',
        'IL27', 'VEGFA', 'CSF1', 'PDGFA', 'CXCL10',
    ]


def load_shap():
    """Read ``<cwd>/data/shap.pickle`` and return its three SHAP artefacts.

    Returns a tuple ``(values_train, shap_values_train, explainer)`` taken
    from the same-named keys of the pickled dictionary. A missing key raises
    ``KeyError``.

    Note: the file is read with ``pickle.load``, so it must be trusted.
    """
    shap_path = f"{Path(os.getcwd())}/data/shap.pickle"
    with open(shap_path, 'rb') as fh:
        payload = pickle.load(fh)
    return tuple(payload[key] for key in ('values_train', 'shap_values_train', 'explainer'))

In [3]:
def predict(input):
    """Score an uploaded table with the model and produce the summary outputs.

    Parameters
    ----------
    input
        The uploaded file: either a path, or an object exposing the path as
        ``.name``. Files whose name ends in ``.csv`` are read with
        ``pandas.read_csv``; everything else is read with
        ``pandas.read_excel``. The first column is used as the index, and the
        table must contain the feature columns from ``load_feats()`` plus
        ``'Age'``.

    Returns
    -------
    list
        Eight items, in this order: a metrics string, the path of
        ``out/output.xlsx``, a list of ``(image path, caption)`` pairs, an
        update for the sample dropdown (choices = table index), and four
        ``gr.update(visible=True)`` values.

    Side effects
    ------------
    Writes ``df.xlsx``, ``output.xlsx``, ``scatter.svg``, ``violin.svg`` and
    ``shap.svg`` into ``<cwd>/out`` (created if missing). ``explain`` later
    reads ``out/df.xlsx``.

    Raises
    ------
    ValueError
        If no file was provided.
    """
    if input is None:
        raise ValueError('No input file provided')

    root_dir = Path(os.getcwd())
    out_dir = f'{root_dir}/out'
    # Neither to_excel nor savefig creates missing directories.
    os.makedirs(out_dir, exist_ok=True)

    model = load_model()
    feats = load_feats()
    values_train, shap_values_train, _ = load_shap()

    # The upload widget accepts both .xlsx and .csv, so pick the reader by extension.
    input_path = getattr(input, 'name', input)
    if str(input_path).lower().endswith('.csv'):
        df = pd.read_csv(input_path, index_col=0)
    else:
        df = pd.read_excel(input_path, index_col=0)
    df = df.loc[:, feats + ['Age']]

    # The model output is flattened so that each row receives one predicted value.
    df['SImAge'] = model(torch.from_numpy(df.loc[:, feats].values)).cpu().detach().numpy().ravel()
    df['SImAge acceleration'] = df['SImAge'] - df['Age']
    df.to_excel(f'{out_dir}/df.xlsx')

    fn_output = f'{out_dir}/output.xlsx'
    df_res = df[['SImAge acceleration']]
    df_res.to_excel(fn_output)

    mae = mean_absolute_error(df['Age'].values, df['SImAge'].values)
    rho = pearsonr(df['Age'].values, df['SImAge'].values).statistic

    # Predicted vs. chronological age, with the y = x diagonal for reference.
    fn_scatter = f'{out_dir}/scatter.svg'
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4, 4))
    sns.scatterplot(
        data=df,
        x="Age",
        y="SImAge",
        linewidth=0.1,
        alpha=0.75,
        edgecolor="k",
        s=40,
        color='blue',
        ax=ax
    )
    sns.lineplot(
        x=[0, 120],
        y=[0, 120],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_xlim(0, 120)
    ax.set_ylim(0, 120)
    fig.savefig(fn_scatter, bbox_inches='tight')
    # Close each figure once saved so repeated calls do not accumulate open figures.
    plt.close(fig)

    # Distribution of (predicted - chronological) age.
    fn_violin = f'{out_dir}/violin.svg'
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(2, 4))
    sns.violinplot(
        data=df,
        y='SImAge acceleration',
        # density_norm is the current name of the deprecated `scale` argument,
        # and is what the exploratory cell of this notebook already uses.
        density_norm='width',
        color='blue',
        saturation=0.75,
        ax=ax,
    )
    fig.savefig(fn_violin, bbox_inches='tight')
    plt.close(fig)

    # show=False keeps the SHAP figure open so that savefig captures it.
    fn_shap_plot = f'{out_dir}/shap.svg'
    shap.summary_plot(
        shap_values=shap_values_train.values,
        features=values_train.values,
        feature_names=feats,
        max_display=len(feats),
        plot_type="violin",
        show=False,
    )
    plt.savefig(fn_shap_plot, bbox_inches='tight')
    plt.close()

    sample_ids = list(df.index.values)
    return [f'MAE: {mae}, Pearson Rho: {rho}',
            fn_output,
            [(fn_scatter, 'Scatter'), (fn_violin, 'Violin'), (fn_shap_plot, 'SHAP')],
            gr.update(choices=sample_ids, value=sample_ids[0], interactive=True, visible=True),
            gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]


def explain(input):
    """Produce a SHAP waterfall plot for one sample of the last scored table.

    Parameters
    ----------
    input
        Index label of the sample to explain; it is looked up in
        ``<cwd>/out/df.xlsx``, the file written by ``predict``.

    Returns
    -------
    list
        Two items: a string with the sample's ``Age`` and ``SImAge`` values,
        and the path of the saved waterfall image
        (``<cwd>/out/waterfall_<input>.svg``).
    """
    root_dir = Path(os.getcwd())

    feats = load_feats()
    _, _, explainer = load_shap()

    df = pd.read_excel(f'{root_dir}/out/df.xlsx', index_col=0)

    trgt_id = input
    shap_values_trgt = explainer.shap_values(df.loc[trgt_id, feats].values)
    base_value = explainer.expected_value[0]

    # show=False keeps the figure open so that savefig below captures it.
    shap.plots.waterfall(
        shap.Explanation(
            values=shap_values_trgt,
            base_values=base_value,
            data=df.loc[trgt_id, feats].values,
            feature_names=feats
        ),
        max_display=len(feats),
        show=False,
    )
    fn_waterfall = f'{root_dir}/out/waterfall_{trgt_id}.svg'
    plt.savefig(fn_waterfall, bbox_inches='tight')
    # Close the figure once saved so repeated calls do not accumulate open figures.
    plt.close()

    age = df.loc[trgt_id, ['Age']].values[0]
    simage = df.loc[trgt_id, ['SImAge']].values[0]

    return [f'Real age: {age}, SImAge: {simage}',
            fn_waterfall]


def set_invisible():
    """Return a tuple of four Gradio updates, each one setting ``visible=False``."""
    return tuple(gr.update(visible=False) for _ in range(4))

In [9]:
# NOTE: gradio is already imported in the notebook's first cell; this repeat is redundant but harmless.
import gradio as gr

# Custom stylesheet passed to gr.Blocks: centres <h2> headings.
css = """
h2 {
    text-align: center;
    display:block;
}
"""

# Build the UI. Components are created inside nested Row/Column contexts,
# so the order of the statements below determines the page layout.
with gr.Blocks(css=css, theme=gr.themes.Soft(), title='Small Immunological Age (SImAge)') as app:
    # Page heading.
    gr.Markdown(
                """
                <h2>Calculate your immunological age using SImAge model</h2>
                """
            )
    # Top row: upload form on the left, global results on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                ### Submit immunology data
                The file should contain chronological age ("Age" column) and immunology data for the following 10 cytokines:
                
                CXCL9, CCL22, IL6, PDGFB, CD40LG, IL27, VEGFA, CSF1, PDGFA, CXCL10
                """
            )
            input_file = gr.File(label='Input file', file_count='single', file_types=['.xlsx', '.csv'], height=500)
            submit_button = gr.Button("Submit data", variant="primary")
        with gr.Column():
            with gr.Row():
                output_text = gr.Text(label='Main metrics')
                output_file = gr.File(label='Result file', file_types=['.xlsx'], height=100, interactive=False)
            with gr.Row():
                gallery = gr.Gallery(label='Result Figures', object_fit='cover', columns=3, rows=1)
                # The string literal below is a no-op expression statement; it holds
                # disabled gr.Image components that the gallery above replaces.
                '''
                scatter_image = gr.Image(label='Scatter', height=300, scale=1)
                violin_image = gr.Image(label='Violin', height=300, scale=1)
                shap_image = gr.Image(label='SHAP', height=300, scale=1)
                '''
    # Bottom row: per-sample explanation. These components are created with
    # visible=False; predict returns updates that set them visible.
    with gr.Row():
        gr.Markdown(
                """
                ### Local explainability
                """
            )
        with gr.Column():
            input_shap = gr.Dropdown(label='Choose a sample', visible=False)
            shap_button = gr.Button("Get explanation", variant="primary", visible=False)
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=1):
                    shap_local = gr.Text(label='Main metrics', visible=False)
                    # NOTE(review): predict makes shap_cyto visible, but no handler wired
                    # below lists it as an output that receives a value.
                    shap_cyto = gr.Text(label='Most important cytokines', visible=False)
                with gr.Column(scale=2):
                    shap_waterfall = gr.Image(label='Waterfall', visible=False)
    # predict returns eight values, matched positionally to the eight outputs here.
    submit_button.click(fn=predict,
                        inputs=[input_file],
                        outputs=[output_text, output_file, gallery, input_shap, shap_button, shap_local, shap_cyto, shap_waterfall]
                        )
    # explain returns two values: the metrics text and the waterfall image path.
    shap_button.click(fn=explain,
                      inputs=[input_shap],
                      outputs=[shap_local, shap_waterfall]
                      )
    gr.Markdown(
        """
        Reference:
        
        Kalyakulina, A., Yusipov, I., Kondakova, E., Bacalini, M. G., Franceschi, C., Vedunova, M., & Ivanchenko, M. (2023). [Small immunological clocks identified by deep learning and gradient boosting](https://www.frontiersin.org/journals/immunology/articles/10.3389/fimmu.2023.1177611/full). Frontiers in Immunology, 14, 1177611.
        """
    )
# Start the server; the captured cell output shows it serving on a local URL
# until the kernel received a keyboard interrupt.
app.launch(debug=True)

Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


Keyboard interruption in main thread... closing server.


