In [None]:
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from treeffuser import Treeffuser

from sklearn.ensemble import (
    HistGradientBoostingClassifier,
    HistGradientBoostingRegressor,
)
from sklearn.model_selection import cross_val_predict, LeaveOneOut
from sklearn.metrics import ConfusionMatrixDisplay, mean_squared_error, roc_auc_score

from data import load_amp, load_uke, GENERATIVE_COLUMNS

# Load the AMP cohort, requesting one sampled measurement per subject.
amp_dataset = load_amp(
    "../data/updrs_amp_all.csv",
    sample_one_measurement_per_subject=True,
)
X_amp, covariates_amp = amp_dataset

# Show the score columns that the generative model is trained on.
X_amp[GENERATIVE_COLUMNS]

In [None]:
# Load the UKE cohort (four aligned tables).
uke_dataset = load_uke("../data/pdq_uke_new.csv")
x_uke, covariates_uke, extra_data_uke, y_uke = uke_dataset

# Keep only the rows whose PDQ value is present in both x_uke and y_uke,
# and apply the same row filter to every table so they stay aligned.
valid_measurements = x_uke["PDQ"].notna() & y_uke["PDQ"].notna()
x_uke, covariates_uke, extra_data_uke, y_uke = (
    table[valid_measurements]
    for table in (x_uke, covariates_uke, extra_data_uke, y_uke)
)

x_uke

In [None]:
# Stack the covariates of both cohorts into one long table with one row per
# (subject, covariate) pair, tagged with the cohort it came from.
agg_covariates = pd.concat(
    [
        cohort_covariates.reset_index(names="Subject")
        .melt(id_vars="Subject", var_name="Covariate", value_name="Value")
        .assign(Cohort=cohort)
        for cohort, cohort_covariates in (("UKE", covariates_uke), ("AMP", covariates_amp))
    ],
    ignore_index=True,
)

with sns.axes_style("whitegrid"):
    sns.set_context("paper")

    covariates = agg_covariates.groupby("Covariate")
    # squeeze=False keeps `axes` an array even when there is a single covariate;
    # otherwise plt.subplots returns a bare Axes and the zip below fails.
    fig, axes = plt.subplots(
        1, len(covariates), figsize=(5 * len(covariates), 5), squeeze=False
    )
    # One panel per covariate; cohorts are normalized separately (common_norm=False)
    # so their distributions are comparable despite different cohort sizes.
    for (covariate, covariate_data), ax in zip(covariates, axes.ravel()):
        sns.histplot(
            covariate_data,
            x="Value",
            hue="Cohort",
            discrete=True,
            stat="probability",
            common_norm=False,
            ax=ax,
        )
        ax.set_xlabel(covariate)

## Train the model

In [10]:
def preprocess_covariates(covariates):
    """Return a float-only copy of ``covariates`` suitable as model input.

    * ``Sex`` is encoded as ``Male -> 0.0`` and ``Female -> 1.0``; any other
      label becomes NaN.
    * ``Education`` (categorical) is replaced by its category codes, with the
      missing-value code ``-1`` turned into NaN.

    The function is idempotent: columns that are already numeric are left
    untouched, so re-running the calling cell on already-encoded data neither
    wipes ``Sex`` to NaN nor fails on the ``.cat`` accessor.

    Args:
        covariates: DataFrame containing at least ``Sex`` and ``Education``.
            The input is not modified.

    Returns:
        A new DataFrame with every column cast to ``float``.
    """
    covariates = covariates.copy()

    sex = covariates["Sex"]
    if not pd.api.types.is_numeric_dtype(sex):
        covariates["Sex"] = sex.map({"Male": 0.0, "Female": 1.0})

    education = covariates["Education"]
    if isinstance(education.dtype, pd.CategoricalDtype):
        # cat.codes uses -1 for missing categories; expose those as NaN instead.
        covariates["Education"] = education.cat.codes.astype(float).replace(-1, np.nan)

    return covariates.astype(float)

# Replace the raw covariates of both cohorts with their numeric encodings.
covariates_uke, covariates_amp = map(
    preprocess_covariates, (covariates_uke, covariates_amp)
)

In [None]:
# Fit the generative model on the AMP cohort: covariates -> GENERATIVE_COLUMNS.
amp_features = covariates_amp.to_numpy().astype(np.float32)
amp_targets = X_amp[GENERATIVE_COLUMNS].to_numpy().astype(np.float32)

model = Treeffuser(seed=42)
model.fit(amp_features, amp_targets)

In [None]:
def generate_samples(model, x, x_covariates, n_samples: int = 50, seed: int = 42) -> pd.DataFrame:
    """Draw samples from ``model`` and pair them with the observed scores in ``x``.

    Args:
        model: Object exposing ``sample(x_covariates, n_samples=..., seed=...)``.
            Its result is indexed here as ``[sample, subject, score]``.
        x: Observed scores, one row per subject (the index is used as the
            subject ID) and one column per score. Missing entries are skipped.
        x_covariates: Model inputs, one row per row of ``x`` in the same order.
        n_samples: Number of samples requested per subject.
        seed: Seed forwarded to ``model.sample``.

    Returns:
        Long-format frame with the columns ``Prediction``, ``Ground truth``,
        ``Score``, ``Subject`` and ``Sample ID``: one row per
        (subject, score, sample). If every entry of ``x`` is missing, an empty
        frame with these columns is returned.

    Raises:
        ValueError: If the model returns samples for a different number of
            subjects than ``x`` has rows (the pairing would be misaligned).
    """
    columns = ["Prediction", "Ground truth", "Score", "Subject", "Sample ID"]

    # Sampled values are clipped to the range [0, 100].
    prediction = np.clip(
        model.sample(x_covariates, n_samples=n_samples, seed=seed), 0, 100
    )
    if prediction.shape[1] != len(x):
        raise ValueError(
            f"model returned samples for {prediction.shape[1]} subjects, "
            f"but x has {len(x)} rows"
        )

    sample_ids = np.arange(n_samples)  # identical for every block, so build it once
    generated_samples = []
    for row, (subject_id, ground_truth) in enumerate(x.iterrows()):
        for column, (key, value) in enumerate(ground_truth.items()):
            if pd.isna(value):
                continue

            generated_samples.append(
                pd.DataFrame(
                    {
                        "Prediction": prediction[:, row, column],
                        "Ground truth": value,
                        "Score": key,
                        "Subject": subject_id,
                        "Sample ID": sample_ids,
                    }
                )
            )

    if not generated_samples:
        # pd.concat raises on an empty list; return a well-formed empty frame instead.
        return pd.DataFrame(columns=columns)

    return pd.concat(generated_samples, ignore_index=True)


# Draw 500 samples per UKE subject from `model`, conditioned on the UKE covariates.
uke_model_inputs = covariates_uke.to_numpy().astype(np.float32)
uke_samples = generate_samples(
    model,
    x_uke[GENERATIVE_COLUMNS],
    uke_model_inputs,
    n_samples=500,
)

In [None]:
# Box plot of the sampled predictions per subject, one panel per score.
with sns.axes_style("whitegrid"):
    sns.set_context("paper")
    sns.catplot(
        data=uke_samples,
        kind="box",
        x="Subject",
        y="Prediction",
        col="Score",
        col_wrap=3,
    )

In [None]:
# Inspect the sampled "UPDRS I" distribution for the subject with ID 0.
is_updrs_i = uke_samples["Score"] == "UPDRS I"
is_subject_zero = uke_samples["Subject"] == 0
selection = uke_samples[is_updrs_i & is_subject_zero]

with sns.axes_style("whitegrid"):
    sns.set_context("paper")
    sns.histplot(
        data=selection,
        x="Prediction",
        binwidth=1,
        kde=True,
        discrete=True,
        stat="probability",
    )

# Summary statistics of the selected samples.
selected_predictions = selection["Prediction"]
print(selected_predictions.mean())
print(selected_predictions.median())

In [None]:
def calculate_representations(samples: pd.DataFrame):
    """Express each observed score as the share of model samples exceeding it.

    For every (``Subject``, ``Score``) group in ``samples``, the fraction of
    ``Prediction`` values strictly greater than the group's first
    ``Ground truth`` value is computed and rounded to two decimals.

    Returns:
        A frame indexed by ``Subject`` with one column per ``Score``;
        combinations absent from ``samples`` are NaN.
    """
    records = []
    for (subject, score), group in samples.groupby(["Subject", "Score"]):
        observed = group["Ground truth"].iloc[0]
        n_above = (group["Prediction"] > observed).sum()
        records.append((subject, score, round(n_above / len(group), 2)))

    long_table = pd.DataFrame.from_records(
        records, columns=["Subject", "Score", "Quantity"]
    )
    # The pivot already yields the Subject index and the Score column axis.
    return long_table.pivot(index="Subject", columns="Score", values="Quantity")

# Position of every observed UKE score within its sampled distribution.
representations = calculate_representations(samples=uke_samples)
representations

In [None]:
# Normalized PDQ change, (x - y) / (x + y), attached to the representations by index.
pdq_x, pdq_y = x_uke["PDQ"], y_uke["PDQ"]
regscores_wide = representations.assign(
    **{"Normalized PDQ": (pdq_x - pdq_y) / (pdq_x + pdq_y)}
)

# Long format: one row per (subject, score) with the representation in "Quantity".
regscores = regscores_wide.reset_index().melt(
    id_vars=["Subject", "Normalized PDQ"],
    value_name="Quantity",
)

with sns.axes_style("whitegrid"):
    sns.set_context("paper")

    # Robust regression of the normalized PDQ change on each score's representation.
    g = sns.lmplot(
        data=regscores,
        x="Quantity",
        y="Normalized PDQ",
        col="Score",
        col_wrap=3,
        robust=True,
        facet_kws={"sharex": False, "sharey": False},
    )
    g.refline(y=0)

## Predict the effect

In [49]:
def _concat_columns(*parts):
    """Join the given frames/series side by side (column-wise)."""
    return pd.concat(parts, axis=1)


# Building blocks shared by several feature sets.
demographics = covariates_uke[["Age", "Sex", "Time since diagnosis"]]
time_since_diagnosis = covariates_uke[["Time since diagnosis"]]
pdq_x = x_uke["PDQ"]

# Feature sets ("conditions") compared in the cross-validation cells, keyed by
# the label shown in the plots.
CONDITIONS = {
    "demographical data, time, and QoL": _concat_columns(
        demographics, pdq_x, extra_data_uke
    ),
    "demographical data and scores (QoL included)": _concat_columns(
        demographics, x_uke, extra_data_uke
    ),
    "relative scores, time since diagnosis and QoL": _concat_columns(
        representations, pdq_x, time_since_diagnosis
    ),
    "relative scores, time since diagnosis and surgery": _concat_columns(
        representations, extra_data_uke, time_since_diagnosis
    ),
    "relative scores and time": _concat_columns(representations, extra_data_uke),
    "relative scores only": representations,
}

# Leave-one-out CV and the regression target: normalized PDQ change.
cv = LeaveOneOut()
y = (pdq_x - y_uke["PDQ"]) / (pdq_x + y_uke["PDQ"])

In [52]:
# Hyperparameters of the gradient-boosting regressor, kept in one place.
# random_state is fixed because max_features < 1 subsamples features at random;
# without a seed the cross-validated predictions change on every run.
HGB_REGRESSOR_PARAMS = dict(
    learning_rate=0.008377625068989763,
    max_depth=3,
    min_samples_leaf=3,
    max_features=0.5679354907073609,
    random_state=42,
)

# Cross-validated predictions of the target `y` for every feature set in CONDITIONS.
prediction_frames = []
for condition, x in CONDITIONS.items():
    y_pred = cross_val_predict(
        HistGradientBoostingRegressor(**HGB_REGRESSOR_PARAMS),
        x,
        y,
        cv=cv,
    )

    prediction_frames.append(
        pd.DataFrame(
            {"Normalized PDQ": y, "Model prediction": y_pred, "Condition": condition}
        )
    )

predictions = pd.concat(prediction_frames, ignore_index=True)

In [54]:
# enable_iterative_imputer must be imported before IterativeImputer.
from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.impute import IterativeImputer
from sklearn.linear_model import HuberRegressor
from sklearn.pipeline import Pipeline
# StandardScaler lives in sklearn.preprocessing; sklearn.discriminant_analysis
# merely happens to import it internally and is not its public home.
from sklearn.preprocessing import StandardScaler

# Cross-validated predictions of `y` from a linear baseline
# (impute -> standardize -> Huber regression) for every feature set in CONDITIONS.
prediction_frames = []
for condition, x in CONDITIONS.items():
    y_pred = cross_val_predict(
        Pipeline(
            [
                ("imputer", IterativeImputer(random_state=42)),
                ("scaler", StandardScaler()),
                ("regressor", HuberRegressor()),
            ]
        ),
        x,
        y,
        cv=cv,
    )

    prediction_frames.append(
        pd.DataFrame(
            {"Normalized PDQ": y, "Model prediction": y_pred, "Condition": condition}
        )
    )

predictions = pd.concat(prediction_frames, ignore_index=True)

In [None]:
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin


class TreeffuserCv(RegressorMixin, BaseEstimator):
    """scikit-learn compatible wrapper that turns Treeffuser into a point regressor.

    The prediction for each row is the mean over ``n_samples`` draws from the
    fitted Treeffuser model. The estimator predicts a continuous target, so it
    uses ``RegressorMixin`` (listed before ``BaseEstimator``, as scikit-learn
    requires for mixins).

    Args:
        seed: Seed used both to build the model and to draw the samples.
        n_samples: Number of draws averaged per prediction.
    """

    def __init__(self, seed: int = 42, n_samples: int = 100):
        self.seed = seed
        self.n_samples = n_samples

    def fit(self, X, y):
        """Fit a fresh Treeffuser model on pandas inputs ``X`` and ``y``; return ``self``."""
        self.model = Treeffuser(seed=self.seed)
        self.model.fit(X.to_numpy().astype(np.float32), y.to_numpy().astype(np.float32))
        # scikit-learn convention: fit returns the estimator to allow chaining.
        return self

    def predict(self, X):
        """Return the per-row mean of the sampled values (averaged over axis 0)."""
        samples = self.model.sample(
            X.to_numpy().astype(np.float32), n_samples=self.n_samples, seed=self.seed
        )
        return samples.mean(axis=0)

# Leave-one-out predictions of the normalized PDQ change from the relative
# scores, using the Treeffuser wrapper as the regressor.
cv = LeaveOneOut()
y = (x_uke["PDQ"] - y_uke["PDQ"]) / (x_uke["PDQ"] + y_uke["PDQ"])

y_pred = cross_val_predict(
    TreeffuserCv(seed=42),
    representations,
    y,
    cv=cv,
)

predictions = pd.DataFrame({
    "Normalized PDQ": y,
    # ravel: a no-op for 1-D output, and it also accepts an (n, 1) array in case
    # the sampler keeps a trailing target axis (TODO confirm Treeffuser's shape).
    "Model prediction": np.ravel(y_pred),
    # The estimator here is Treeffuser, not a linear model.
    "Condition": "Treeffuser on relative scores only",
})

In [None]:
def annotate(data, **kws):
    """Decorate the current axes for one facet of the prediction plot.

    Shades the area below and above y = 0 in two different colors without
    changing the axis limits, sets the x label from the facet's first
    ``Condition`` value, and puts the facet's mean squared error in the title.

    Args:
        data: Rows of one facet with the columns ``Condition``,
            ``Normalized PDQ`` and ``Model prediction``.
        **kws: Accepted for compatibility with the caller; unused.
    """
    ax = plt.gca()

    below_zero_color = (251 / 255, 212 / 255, 183 / 255)
    above_zero_color = (188 / 255, 220 / 255, 190 / 255)

    # Capture the limits first so they can be restored after shading.
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()

    bands = (
        ([max(y_min, 0), y_min], below_zero_color),
        ([0.0, y_max], above_zero_color),
    )
    for y_span, color in bands:
        # A very low zorder keeps the shading behind the plotted data.
        ax.fill_betweenx(y=y_span, x1=x_min, x2=x_max, color=color, zorder=-100)

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)

    condition = data["Condition"].iloc[0]
    ax.set_xlabel(f"Model prediction with {condition}")

    mse = mean_squared_error(data["Normalized PDQ"], data["Model prediction"])
    ax.set_title(f"MSE: {mse:.3f}")


# Observed vs. cross-validated predicted normalized PDQ, one facet per condition.
with sns.axes_style("white"):
    sns.set_context("paper")

    grid = sns.lmplot(
        data=predictions,
        x="Model prediction",
        y="Normalized PDQ",
        col="Condition",
        robust=True,
    )
    # Reference lines at zero on both axes, then per-facet shading, label and MSE.
    grid.refline(x=0, y=0)
    grid.map_dataframe(annotate)


### Estimate the stochastic component

In [None]:
# Per-seed MSE values read from a precomputed CSV file (indexed by "Seed").
stochastic_component = pd.read_csv("mse.csv", index_col="Seed")

# Empirical CDF of the MSE, one curve per condition.
sns.ecdfplot(data=stochastic_component, x="MSE", hue="Condition")

# Mean MSE over all rows, rounded to three decimals.
mean_mse = stochastic_component["MSE"].mean()
round(mean_mse, 3)

## Classify the effect

In [15]:
# A subject counts as "improved" when PDQ drops by at least this many points.
# NOTE(review): presumably a minimal clinically important difference — confirm source.
IMPROVEMENT_THRESHOLD = 4.72

# Hyperparameters of the gradient-boosting classifier, kept in one place.
# random_state is fixed because max_features < 1 subsamples features at random;
# without a seed the cross-validated predictions change on every run.
HGB_CLASSIFIER_PARAMS = dict(
    learning_rate=0.008377625068989763,
    max_depth=3,
    min_samples_leaf=3,
    max_features=0.5679354907073609,
    random_state=42,
)

cv = LeaveOneOut()
y = (x_uke["PDQ"] - y_uke["PDQ"]) >= IMPROVEMENT_THRESHOLD

prediction_frames = []
for condition, x in CONDITIONS.items():
    # Ask for class probabilities so that a ranking metric (ROC AUC) can be
    # computed later; columns follow the sorted classes, i.e. [False, True].
    proba = cross_val_predict(
        HistGradientBoostingClassifier(**HGB_CLASSIFIER_PARAMS),
        x,
        y,
        cv=cv,
        method="predict_proba",
    )

    prediction_frames.append(
        pd.DataFrame(
            {
                "Ground truth": y,
                # Same decision as predict(): the class with the larger
                # probability, ties going to the first class (False).
                "Prediction": proba[:, 1] > proba[:, 0],
                "Probability": proba[:, 1],
                "Condition": condition,
            }
        )
    )

predictions = pd.concat(prediction_frames, ignore_index=True)

In [None]:
# One confusion matrix per condition, titled with the condition and its ROC AUC.
with sns.axes_style("white"):
    sns.set_context("paper")

    condition_groups = predictions.groupby("Condition")
    # squeeze=False keeps `axes` an array even when there is a single condition.
    fig, axes = plt.subplots(1, len(condition_groups), figsize=(12, 4), squeeze=False)
    for i, ((condition, data), ax) in enumerate(zip(condition_groups, axes.ravel())):
        ConfusionMatrixDisplay.from_predictions(
            y_true=data["Ground truth"],
            y_pred=data["Prediction"],
            # Fix the label order explicitly so the two display labels always
            # match, even if a condition happens to contain a single class.
            labels=[False, True],
            ax=ax,
            colorbar=False,
            cmap="Reds",
            display_labels=["No improvement", "Improvement"],
        )

        # ROC AUC is a ranking metric: use class probabilities when the frame
        # provides a "Probability" column. With hard labels only, the value
        # reduces to balanced accuracy.
        if "Probability" in data.columns:
            auc_scores = data["Probability"]
        else:
            auc_scores = data["Prediction"]
        ax.set_title(
            f"{condition}\nAUC: {roc_auc_score(data['Ground truth'], auc_scores):.2f}"
        )

        # Only the left-most panel keeps its y-axis label and tick labels.
        if i > 0:
            ax.set_ylabel("")
            ax.set_yticklabels(["", ""])