In [None]:
import pandas as pd
import os
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import numpy as np
from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr
from models.tabular.widedeep.ft_transformer import WDFTTransformerModel

In [None]:
# --- Paths and model --------------------------------------------------------
# All inputs are resolved relative to the directory the notebook is started from.
root_dir = Path(os.getcwd())

fn_model = f"{root_dir}/data/model.ckpt"

# Restore the trained network from its checkpoint. It is only used for
# prediction in this notebook, so switch it to evaluation mode and freeze it.
model = WDFTTransformerModel.load_from_checkpoint(checkpoint_path=fn_model)
model.eval()
model.freeze()

# Input columns fed to the model, in the order they are passed to it.
# NOTE(review): presumably this must match the feature order used at training
# time -- confirm against the training configuration.
feats = [
    'CXCL9',
    'CCL22',
    'IL6',
    'PDGFB',
    'CD40LG',
    'IL27',
    'VEGFA',
    'CSF1',
    'PDGFA',
    'CXCL10'
]

# --- Test data and predictions ----------------------------------------------
df = pd.read_excel(f"{root_dir}/data/data_test.xlsx", index_col=0)
# .copy(): the new columns below are written to an independent frame rather
# than to a slice of the loaded one (avoids pandas' SettingWithCopyWarning).
df = df.loc[:, feats + ['Age']].copy()

df['SImAge'] = model(torch.from_numpy(df.loc[:, feats].values)).cpu().detach().numpy().ravel()
# Acceleration = predicted age minus the 'Age' column; positive values mean
# the model predicts an age above the recorded one.
df['SImAge acceleration'] = df['SImAge'] - df['Age']
df['|SImAge acceleration|'] = df['SImAge acceleration'].abs()

df_res = df[['SImAge acceleration', '|SImAge acceleration|']]

# Agreement between the 'Age' column and the prediction.
mae = mean_absolute_error(df['Age'].values, df['SImAge'].values)
rho = pearsonr(df['Age'].values, df['SImAge'].values).statistic
# Show the metrics; previously they were computed but never reported.
print(f'MAE: {mae}, Pearson Rho: {rho}')

# --- Figure 1: predicted vs. recorded age -----------------------------------
sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(4, 4))
scatter = sns.scatterplot(
    data=df,
    x="Age",
    y="SImAge",
    linewidth=0.1,
    alpha=0.75,
    edgecolor="k",
    s=40,
    color='blue',
    ax=ax
)
# Dashed identity line: points on it have SImAge equal to Age.
bisect = sns.lineplot(
    x=[0, 120],
    y=[0, 120],
    linestyle='--',
    color='black',
    linewidth=1.0,
    ax=ax
)
ax.set_xlim(0, 120)
ax.set_ylim(0, 120)
# Save through the figure handle so the target does not depend on pyplot's
# notion of the "current" figure.
fig.savefig('scatter.png', bbox_inches='tight')


# --- Figure 2: distribution of the age acceleration -------------------------
sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(2, 4))
sns.violinplot(
    data=df,
    y='SImAge acceleration',
    scale='width',
    color='blue',
    saturation=0.75,
    ax=ax,  # draw on the axes created above instead of the implicit current axes
)
fig.savefig('violin.png', bbox_inches='tight')

In [None]:
def predict(input):
    """Score an uploaded table with the loaded model and export the results.

    The table must contain every column listed in ``feats`` plus ``Age``; its
    first column is used as the index. The function writes three files into
    ``{root_dir}/out`` (created if missing): the per-row age acceleration as
    ``output.xlsx``, a predicted-vs-recorded scatter plot as ``scatter.svg``
    and a violin plot of the acceleration as ``violin.svg``.

    Parameters
    ----------
    input : str, os.PathLike or file-like
        The uploaded table. A ``.csv`` path is read with ``pandas.read_csv``;
        anything else is handed to ``pandas.read_excel``. An object exposing
        a string ``.name`` attribute is read from that path.

    Returns
    -------
    list
        ``[summary, table_path, scatter_path, violin_path]`` where ``summary``
        is a string with the MAE and Pearson correlation between ``Age`` and
        the prediction, and the remaining items are paths of the written files.

    Raises
    ------
    KeyError
        If a required column is missing from the table.
    """
    # Depending on the Gradio version the upload arrives either as a path
    # string or as a temp-file object carrying the path in `.name`.
    source = input
    if isinstance(getattr(input, 'name', None), str):
        source = input.name

    # The upload widget accepts .csv as well, which read_excel cannot parse.
    is_csv = isinstance(source, (str, os.PathLike)) and Path(source).suffix.lower() == '.csv'
    reader = pd.read_csv if is_csv else pd.read_excel
    df = reader(source, index_col=0)
    # .copy(): the columns added below go to an independent frame rather than
    # a slice of the loaded one (avoids pandas' SettingWithCopyWarning).
    df = df.loc[:, feats + ['Age']].copy()

    df['SImAge'] = model(torch.from_numpy(df.loc[:, feats].values)).cpu().detach().numpy().ravel()
    df['SImAge acceleration'] = df['SImAge'] - df['Age']

    # Make sure the output directory exists before anything is written to it.
    out_dir = f'{root_dir}/out'
    os.makedirs(out_dir, exist_ok=True)
    fn_table = f'{out_dir}/output.xlsx'
    fn_scatter = f'{out_dir}/scatter.svg'
    fn_violin = f'{out_dir}/violin.svg'

    df[['SImAge acceleration']].to_excel(fn_table)

    mae = mean_absolute_error(df['Age'].values, df['SImAge'].values)
    # pearsonr needs at least two samples; report NaN for a single-row upload
    # instead of failing after the table has already been written.
    if len(df) >= 2:
        rho = pearsonr(df['Age'].values, df['SImAge'].values).statistic
    else:
        rho = float('nan')

    sns.set_theme(style='whitegrid')

    # Predicted vs. recorded age, with a dashed identity line for reference.
    fig, ax = plt.subplots(figsize=(4, 4))
    sns.scatterplot(
        data=df,
        x="Age",
        y="SImAge",
        linewidth=0.1,
        alpha=0.75,
        edgecolor="k",
        s=40,
        color='blue',
        ax=ax
    )
    sns.lineplot(
        x=[0, 120],
        y=[0, 120],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_xlim(0, 120)
    ax.set_ylim(0, 120)
    fig.savefig(fn_scatter, bbox_inches='tight')
    # Close explicitly: this function runs once per request in a long-lived
    # process, and pyplot keeps every figure alive until it is closed.
    plt.close(fig)

    # Distribution of the age acceleration.
    fig, ax = plt.subplots(figsize=(2, 4))
    sns.violinplot(
        data=df,
        y='SImAge acceleration',
        scale='width',
        color='blue',
        saturation=0.75,
        ax=ax,
    )
    fig.savefig(fn_violin, bbox_inches='tight')
    plt.close(fig)

    return [f'MAE: {mae}, Pearson Rho: {rho}', fn_table, fn_scatter, fn_violin]

In [None]:
import gradio as gr

# Input widget: a single uploaded spreadsheet.
upload_box = gr.File(file_count='single', file_types=['.xlsx', '.csv'])

# Output widgets, in the order `predict` returns its results:
# metrics text, downloadable table, scatter plot, violin plot.
result_views = [
    gr.Text(),
    gr.File(file_types=['.xlsx'], interactive=False),
    gr.Image(),
    gr.Image(),
]

demo = gr.Interface(
    fn=predict,
    inputs=upload_box,
    outputs=result_views,
    allow_flagging='never',
)
# share=True asks Gradio for a public link in addition to the local server.
demo.launch(share=True)