In [None]:
import os
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

from evaluate import join_scores
from process_ldopa import build_metadata

In [None]:
def report_score(scores):
    """Summarise a collection of scores as ``"mean [± std]"``.

    Args:
        scores: Object exposing ``.mean()`` and ``.std()`` (for example a
            pandas Series of per-participant scores).

    Returns:
        str: The mean and standard deviation, each formatted to three
        decimal places, e.g. ``"0.812 [± 0.045]"``.
    """
    centre = scores.mean()
    spread = scores.std()
    return "{:.3f} [\u00B1 {:.3f}]".format(centre, spread)

In [None]:
# Load the pickled score table; its index holds the participant ids.
my_scores = pd.read_pickle("outputs/predictions/all_scores.pkl")

# Tag every row with its dataset: ids containing "wrist" are labelled OxWalk,
# everything else Ldopa.
my_scores["source"] = list(
    map(lambda pid: "OxWalk" if "wrist" in pid else "Ldopa", my_scores.index)
)

# Average the scores within each dataset and persist the summary table.
scores = my_scores.groupby(["source"]).mean()
scores.to_csv("outputs/predictions/performance_table.csv")

In [None]:
# Collect the score columns and the MeanUPDRS column, aligned on the index.
score_df = join_scores("outputs/predictions")
metadata = build_metadata(processeddir="data/Ldopa_Processed")["MeanUPDRS"]

df = pd.concat([score_df, metadata], axis=1)

# For the single-dataset-trained models, build a "test_all" column: rows with
# no MeanUPDRS value take the test_OXWALK score, all others the test_LDOPA one.
updrs_missing = pd.isna(df["MeanUPDRS"])
for source in ("LDOPA", "OXWALK"):
    for model in ("rf", "ssl"):
        stem = f"scores_{model}_train_{source}_test_"
        df[stem + "all"] = np.where(
            updrs_missing,
            df[stem + "OXWALK"],
            df[stem + "LDOPA"],
        )

# Raw column name -> display name used in the plots and tables.
cols = {
    "scores_rf_train_all_test_all": "Combined Trained RF",
    "scores_rf_train_LDOPA_test_all": "Ldopa Trained RF",
    "scores_rf_train_OXWALK_test_all": "OxWalk Trained RF",
    "scores_ssl_train_all_test_all": "Combined Trained SSL",
    "scores_ssl_train_LDOPA_test_all": "Ldopa Trained SSL",
    "scores_ssl_train_OXWALK_test_all": "OxWalk Trained SSL",
}

# Keep only the renamed model columns plus MeanUPDRS, in display order.
df = df.rename(columns=cols)
df = df[[*cols.values(), "MeanUPDRS"]]

In [None]:
# Melt the DataFrame for plotting: one row per (participant, model) pair,
# keeping the original index as a "Participant" column.
dfm = df.melt(
    "MeanUPDRS", var_name="Models", value_name="F1 score", ignore_index=False
).reset_index(names="Participant")
# Drop rows lacking either a MeanUPDRS value or an F1 score.
dfm = dfm.dropna().reset_index(drop=True)

# Bin MeanUPDRS into quartiles of the melted rows and keep the edges qcut
# actually used. The labels are built from those edges so they always describe
# the real bins; quantiles taken from the per-participant table can differ
# from the quantiles of the melted, NaN-filtered rows.
quartile_bins, quartile_edges = pd.qcut(dfm['MeanUPDRS'], q=4, retbins=True)

# The three interior edges (25%, 50%, 75%), rounded for display.
cuts = [round(float(edge), 1) for edge in quartile_edges[1:-1]]

quartile_labels = [
    f"0-{cuts[0]}",
    f"{cuts[0]}-{cuts[1]}",
    f"{cuts[1]}-{cuts[2]}",
    f"{cuts[2]}+",
]

# Swap the interval categories for readable labels; category order is kept.
dfm['MeanUPDRS_Quantiles'] = quartile_bins.cat.rename_categories(quartile_labels)

In [None]:
# Constant column so every marker is drawn at the same size (see size_max).
dfm['size'] = 1

# Create the scatter plot: F1 score against MeanUPDRS, coloured by model,
# with an OLS trendline requested from plotly express.
fig = px.scatter(dfm, x="MeanUPDRS", y="F1 score", color="Models", hover_name="Participant",
                 size="size", trendline="ols", size_max=10, height=500, width=900,
                 labels={"F1 score": "F1 Score", "MeanUPDRS": "PD Severity"})

# Set axis and plot titles
fig.update_layout(
    title="Comparison of Model Performance based on PD Severity",
    xaxis_title="PD Severity (MeanUPDRS)",
    yaxis_title="F1 Score",
    legend_title="Model - Tested on:"
)

fig.update_traces(line=dict(width=3))

# Per-model summary table: rows without a MeanUPDRS value are reported under
# "OxWalk", rows with one under "LDopa". It is not drawn in this figure; the
# summary-table cell further down renders it.
model_cols = [col for col in df.columns if 'MeanUPDRS' not in col]
no_updrs = df["MeanUPDRS"].isna()
has_updrs = df["MeanUPDRS"].notna()

legend_table = pd.DataFrame(
    [
        {
            "OxWalk": report_score(df.loc[no_updrs, col]),
            "LDopa": report_score(df.loc[has_updrs, col]),
        }
        for col in model_cols
    ],
    index=model_cols,
)

# Options for the "download plot as PNG" button in the figure toolbar.
config = {
  'toImageButtonOptions': {
    'format': 'png',
    'filename': 'custom_image',
    'height': 500,
    'width': 900,
    'scale': 6 
  }
}

# Show the plot, then save it as a standalone HTML file.
fig.show(config=config)
os.makedirs("outputs/plots", exist_ok=True)

fig.write_html(os.path.join("outputs/plots", "scatter_plotly.html"))

In [None]:
# Create the box plot: F1 score distribution per MeanUPDRS quartile and model.
fig = px.box(dfm, x='MeanUPDRS_Quantiles', y='F1 score', color='Models',
             labels={'F1 score': 'F1 Score', 'MeanUPDRS_Quantiles': 'PD Severity (UPDRS Part III Score)'},
             title='Box Plot of Model Performance for PD Severity Quartiles',
             height=500, width=800)

# Order the x-axis by the quartile categories themselves (lowest to highest).
# Sorting the label strings would be lexicographic, e.g. "15.2-24.0" would be
# placed before "8.5-15.2", so the categorical's own order is used instead.
quartile_order = list(dfm['MeanUPDRS_Quantiles'].cat.categories)
fig.update_xaxes(categoryorder='array', categoryarray=quartile_order)

# Options for the "download plot as PNG" button in the figure toolbar.
config = {
  'toImageButtonOptions': {
    'format': 'png',
    'filename': 'custom_image',
    'height': 500,
    'width': 900,
    'scale': 6 
  }
}

# Show the plot, then save it as a standalone HTML file.
fig.show(config=config)

os.makedirs("outputs/plots", exist_ok=True)
fig.write_html(os.path.join("outputs/plots", "box_plotly.html"))

In [None]:
# Render the per-model summary (legend_table) as a standalone plotly table:
# one column of model names followed by the OxWalk and LDopa score strings.
table_trace = go.Table(
    header={"values": ['Models - Tested on:', 'OxWalk', 'LDopa']},
    cells={
        "values": [
            legend_table.index,
            legend_table['OxWalk'],
            legend_table['LDopa'],
        ],
        "height": 25,
    },
    columnwidth=[3, 2, 2],
)

fig = go.Figure(data=[table_trace])
fig.update_layout(title="Comparison of Model Performance", width=600, height=400)

# Show the table; the toolbar image export is upscaled by a factor of 5.
fig.show(config={'toImageButtonOptions': {'scale': 5}})
