In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from evaluate import join_scores
from process_ldopa import build_metadata

In [None]:
def report_score(scores):
    """Format the mean and standard deviation of ``scores`` as one string.

    Parameters
    ----------
    scores : object exposing ``mean()`` and ``std()``
        Both results are rendered with the ``.3f`` format spec.

    Returns
    -------
    str
        ``"<mean> [± <std>]"`` with both numbers shown to three decimals.
    """
    centre = format(scores.mean(), ".3f")
    spread = format(scores.std(), ".3f")
    return centre + " [\u00B1 " + spread + "]"

In [None]:
# Load the saved score table; its index entries are treated as id strings below.
# NOTE(review): pickle files should only be loaded from a trusted location.
my_scores = pd.read_pickle("outputs/predictions/all_scores.pkl")

# Tag each row with its dataset: an id containing "wrist" is labelled OxWalk,
# every other id is labelled Ldopa.
my_scores["source"] = list(
    map(lambda pid: "OxWalk" if "wrist" in pid else "Ldopa", my_scores.index)
)

# Average every column within each dataset and write the summary table to disk.
scores = my_scores.groupby(by=["source"]).mean()
scores.to_csv("outputs/predictions/performance_table.csv")

In [None]:
# Score table produced by join_scores, plus the "MeanUPDRS" column of the
# metadata table (presumably per-participant clinical scores -- confirm in
# process_ldopa.build_metadata).
score_df = join_scores("outputs/predictions")
metadata = build_metadata(processeddir="data/Ldopa_Processed")["MeanUPDRS"]

# Column-wise concat aligns both inputs on their index; rows that have no
# metadata entry end up with NaN in "MeanUPDRS".
df = pd.concat([score_df, metadata], axis=1)

# Rows without a MeanUPDRS value take their score from the *_test_OXWALK
# column, all others from the *_test_LDOPA column. The mask does not change
# inside the loop, so compute it once.
no_updrs = pd.isna(df["MeanUPDRS"])
for source in ["LDOPA", "OXWALK"]:
    for model in ["rf", "ssl"]:
        df[f"scores_{model}_train_{source}_test_all"] = np.where(
            no_updrs,
            df[f"scores_{model}_train_{source}_test_OXWALK"],
            df[f"scores_{model}_train_{source}_test_LDOPA"],
        )

# Raw score column -> label shown in the figure legends.
cols = {
    "scores_rf_train_all_test_all": "Combined Trained RF",
    "scores_rf_train_LDOPA_test_all": "Ldopa Trained RF",
    "scores_rf_train_OXWALK_test_all": "OxWalk Trained RF",
    "scores_ssl_train_OXWALK_test_all": "1: OxWalk (healthy)",
    "scores_ssl_train_LDOPA_test_all": "2: MJFF-LR (PD)",
    "scores_ssl_train_all_test_all": "3: OxWalk (healthy) + MJFF-LR (PD)",
}

# Keep only the renamed score columns plus MeanUPDRS.
df = df.rename(columns=cols)[list(cols.values()) + ["MeanUPDRS"]]

# Replace missing MeanUPDRS with the sentinel -1. This is done by assigning a
# new column rather than `df[col].fillna(..., inplace=True)`: that chained
# in-place form acts on a temporary object and leaves `df` unchanged when
# pandas Copy-on-Write is active.
df = df.assign(MeanUPDRS=df["MeanUPDRS"].fillna(-1))

In [None]:
# Reshape to long format: one row per (original row, model column) pair. The
# original index is kept through the melt and then exposed as "Participant".
legend_name = 'Model trained on'

dfm = df.melt(
    id_vars="MeanUPDRS",
    var_name=legend_name,
    value_name="F1 score",
    ignore_index=False,
)
dfm = dfm.reset_index(names="Participant")

# Discard rows containing any NaN, then renumber the rows from 0.
dfm = dfm.dropna()
dfm = dfm.reset_index(drop=True)

# pd.cut intervals are right-closed: (-5, 0], (0, 15], (15, 20], (20, 25],
# (25, inf]. Any MeanUPDRS value <= 0 (including a -1 placeholder) therefore
# lands in the first, "OxWalk" bin.
bins = [-5, 0, 15, 20, 25, float('inf')]
labels = [
    "OxWalk<br>(Healthy)",
    "MJFF-LR<br>(0-15)",
    "MJFF-LR<br>(15-20)",
    "MJFF-LR<br>(20-25)",
    "MJFF-LR<br>(25+)",
]

dfm['Population'] = pd.cut(dfm['MeanUPDRS'], bins, labels=labels)

In [None]:
# Wrap the MeanUPDRS data in a DataFrame so a "Population" column can sit next
# to it. pd.DataFrame accepts a Series or an existing DataFrame, so re-running
# this cell is safe. A DataFrame has no `name`; the "MeanUPDRS" column label
# used below must already be present on the input (presumably it is the Series
# name from the earlier column selection -- verify if this cell is moved).
metadata = pd.DataFrame(metadata)

# Right-closed intervals: (0, 15], (15, 20], (20, 25], (25, inf].
# A score of exactly 0 falls outside every bin and gets NaN.
bins = [0, 15, 20, 25, float('inf')]
labels = ["Least\nsevere", "Less\nsevere", "More\nsevere", "Most\nsevere"]

metadata["Population"] = pd.cut(metadata["MeanUPDRS"], bins, labels=labels)

In [None]:
# Set the base font size BEFORE creating the figure. Matplotlib reads rcParams
# when artists are created, so updating them after plt.subplots() can leave
# parts of this figure at the previous size on a fresh kernel while a second
# run of the cell looks different. The setting stays global, as before.
plt.rcParams.update({'font.size': 14})
fig, ax = plt.subplots(figsize=(10, 4), dpi=1000)

# Box plot of MeanUPDRS per severity group, coloured from the Dark2 palette.
with sns.color_palette("Dark2"):
    sns.boxplot(data=metadata, y="Population", x="MeanUPDRS",
                ax=ax)

# Overlay the individual values as semi-transparent black markers.
sns.stripplot(data=metadata, y="Population", x="MeanUPDRS",
              ax=ax, color='black', alpha=0.3, size=10)

ax.set_ylabel("MJFF-LR subpopulation")
ax.set_xlabel("Mean UPDRS Part III score")
ax.set_xlim(0, 35)
plt.show()

In [None]:
# Keep only the rows whose model label does not contain the substring "Trained".
keep_row = ['Trained' not in model_label for model_label in dfm[legend_name]]
df_ssl = dfm[keep_row]

In [None]:
# Independent (deep) copy of the long-format table, so changes made to
# df_pres do not touch dfm.
df_pres = dfm.copy(deep=True)

In [None]:
# Build the presentation table from dfm each time this cell runs. Removing a
# category that is already gone raises ValueError, so editing df_pres in place
# would make the cell fail on a second run.
# Rows that carried the removed "OxWalk" category get NaN in Population; the
# remaining categories are shortened to their score range.
pres_population = (
    dfm["Population"]
    .cat.remove_categories('OxWalk<br>(Healthy)')
    .cat.rename_categories({
        "MJFF-LR<br>(0-15)": "0-15",
        "MJFF-LR<br>(15-20)": "15-20",
        "MJFF-LR<br>(20-25)": "20-25",
        "MJFF-LR<br>(25+)": "25+"
    })
)

# assign() returns a new frame, so dfm itself is left untouched.
df_pres = dfm.assign(Population=pres_population)

In [None]:
# Restrict the table to the rows of the model labelled "1: OxWalk (healthy)".
trained_on_oxwalk = df_pres["Model trained on"] == "1: OxWalk (healthy)"
df_pres = df_pres[trained_on_oxwalk]

In [None]:
# Drop rows with any NaN. Rebinding the name instead of using inplace=True
# avoids an in-place edit on a frame obtained by boolean filtering, which
# can raise SettingWithCopyWarning.
df_pres = df_pres.dropna()

In [None]:
# F1 score per Population group for the rows currently held in df_pres.
fig, ax = plt.subplots(figsize=(6, 4), dpi=1000)
with sns.color_palette("Dark2"):
    sns.boxplot(
        x="Population",
        y="F1 score",
        data=df_pres,
        width=0.3,
        ax=ax,
    )
# Replace the x-axis label, then fix the visible F1 range.
ax.set_xlabel("MJFF-LR subpopulation")
ax.set_ylim(0.45, 1.02)

In [None]:
# Number of long-format rows in each Population bin.
dfm["Population"].value_counts()

In [None]:
# Exclude every row whose model label contains "RF".
mentions_rf = dfm["Model trained on"].str.contains("RF")
dfm_pres = dfm[~mentions_rf]

In [None]:
# Swap the "<br>" range labels for newline-separated severity labels.
# assign() builds a new frame instead of writing a column into the filtered
# slice, which avoids SettingWithCopyWarning. Mapping keys that are not
# current categories are ignored by rename_categories, so re-running is safe.
dfm_pres = dfm_pres.assign(
    Population=dfm_pres["Population"].cat.rename_categories({
        "MJFF-LR<br>(0-15)": "MJFF-LR\n(Least severe)",
        "MJFF-LR<br>(15-20)": "MJFF-LR\n(Less severe)",
        "MJFF-LR<br>(20-25)": "MJFF-LR\n(More severe)",
        "MJFF-LR<br>(25+)": "MJFF-LR\n(Most severe)",
        "OxWalk<br>(Healthy)": "OxWalk\n(Healthy)"
    })
)

In [None]:
# Number of rows per Population label in the presentation subset.
dfm_pres["Population"].value_counts()

In [None]:
# Grouped box plot: F1 score per test population, one box per training set.
fig, ax = plt.subplots(figsize=(10, 6), dpi=1000)
with sns.color_palette("Set1"):
    sns.boxplot(data=dfm_pres, x="Population", y="F1 score", hue="Model trained on", ax=ax)
ax.set_xlabel("Testing Population")

# Dashed red line at x=0.5; with categories drawn at integer positions this
# sits between the first group and the remaining ones.
ax.axvline(0.5, color='red', linestyle='--')
ax.set_ylim(0.2, 1.01)

# savefig does not create missing folders, so make sure the target exists
# before writing the file.
os.makedirs("outputs/plots", exist_ok=True)
fig.savefig("outputs/plots/final_perf.png")