In [None]:
from results import TemporalValidation
from data_monitoring import DataDrift
import mlflow
from pathlib import Path
import pandas as pd

# %%
# Output locations, resolved relative to the notebook's working directory:
# figures next to the notebook, results in the sibling "chirurgie-desire" project.
figure_folder = Path.cwd().joinpath("figures")
result_folder = Path.cwd().parent.joinpath("chirurgie-desire", "results")

# %%
# MLflow runs are read from / written to a SQLite database inside the results folder.
mlflow.set_tracking_uri("sqlite:///{}/mlruns.db".format(result_folder))


In [None]:
from load_data import CATEGORICAL_FEATURES_RENAMED, RENAME_LABELS_DICT, X_test, X_train, date_split, y_test, y_train

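## 1. Temporal validation

Track the monthly AUC and Brier score of the model(s) on the test set. Compare them with control limits set 3 standard deviations from the validation mean: a lower limit for AUC and an upper limit for Brier.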

In [None]:
# Define performance limits from the model's validation scores.
# NOTE(review): these figures are hard-coded; record where they come from
# (presumably cross-validation on the development data — confirm).
validation_score_mean = {"mean_ROC": 0.82, "std_ROC": 0.02, "mean_Brier": 0.158, "std_Brier": 0.01}

# Half-width of the control band, in standard deviations (3-sigma rule).
N_SD = 3

# Columns that TemporalValidation reads from the evaluation frame.
TARGET_COLUMN = "safe_discharge"
DATE_COLUMN = "admission_start_time"

mean_AUC = validation_score_mean["mean_ROC"]
AUC_SD = validation_score_mean["std_ROC"]
# Higher AUC is better, so the band only needs a lower limit.
auc_lower_limit = mean_AUC - (N_SD * AUC_SD)

mean_brier = validation_score_mean["mean_Brier"]
brier_SD = validation_score_mean["std_Brier"]
# Lower Brier score is better, so the band only needs an upper limit.
brier_upper_limit = mean_brier + (N_SD * brier_SD)

# No earlier cell defines `models` (it is not among the `load_data` imports),
# so on a fresh kernel this cell would die with a bare NameError.
# Fail fast with an actionable message instead.
# TODO: load the trained model(s) explicitly in this notebook, e.g. from the
# MLflow tracking store configured at the top.
if "models" not in globals():
    raise NameError(
        "`models` is not defined: load the trained model(s) before running the temporal validation."
    )

# Inner join on the index: features and labels must be aligned row-for-row,
# otherwise rows are silently dropped (or duplicated) here.
test_data = pd.merge(X_test, y_test, left_index=True, right_index=True)
assert len(test_data) == len(X_test), (
    "X_test and y_test indices are not aligned; the merge dropped or duplicated rows"
)
assert {TARGET_COLUMN, DATE_COLUMN} <= set(test_data.columns), (
    "test data is missing the target or date column"
)

temporal_validation = TemporalValidation(
    models=models,
    target_column=TARGET_COLUMN,
    date_column=DATE_COLUMN,
    # NOTE(review): presumably a window of one month (no smoothing) — confirm.
    rolling_window=1,
)
performance_df = temporal_validation.calculate_monthly_performance(test_data)

fig = temporal_validation.plot_monthly_performance(
    title="Model Performance Over Time",
    auc_lower_limit=auc_lower_limit,
    brier_upper_limit=brier_upper_limit,
    plot_auc=True,
    plot_brier=True,
)
if fig:
    fig.show()

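## 2. Data drift monitoring

Compare the test features with the training features, excluding `Hospital`. The comparison is done feature by feature (univariate) and jointly (multivariate), and the features that raise alerts are plotted.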

In [None]:
# Give the features their display labels so the drift reports are readable.
X_train_renamed = X_train.rename(columns=RENAME_LABELS_DICT)
X_test_renamed = X_test.rename(columns=RENAME_LABELS_DICT)


def _show_if_present(plot):
    """Render a drift plot, skipping it when the monitor returned nothing (falsy)."""
    if plot:
        plot.show()


# Build the monitor on the training features, fit it, then run it on the test features.
# "Hospital" is excluded from both the reference and the analysed data.
# NOTE(review): the reason for excluding it is not documented here — confirm.
drift_monitor = DataDrift(
    X_train=X_train_renamed.drop(columns=["Hospital"]),
    categorical_features=CATEGORICAL_FEATURES_RENAMED,
    timestamp_column="admission_start_time",
)
drift_monitor.fit()
drift_monitor.calculate(X_test_renamed.drop(columns=["Hospital"]))

# Per-feature (univariate) drift: report, then plot only the alerted features.
univariate_alerts = drift_monitor.get_univariate_alerts()
if not univariate_alerts:
    print("No univariate drift detected.")
else:
    print("Univariate Drift Alerts for the following features:")
    print(univariate_alerts)
    drift_plot = drift_monitor.plot_univariate_drift(kind="drift", column_names=univariate_alerts)
    _show_if_present(drift_plot)
    dist_plot = drift_monitor.plot_univariate_drift(kind="distribution", column_names=univariate_alerts)
    _show_if_present(dist_plot)

# Joint (multivariate) drift; the alerts object exposes `.empty` (DataFrame-like).
multivariate_alerts = drift_monitor.get_multivariate_alerts()
if multivariate_alerts.empty:
    print("\nNo multivariate drift detected.")
else:
    print("\nMultivariate Drift Alerts:")
    print(multivariate_alerts)
    multi_drift_plot = drift_monitor.plot_multivariate_drift()
    _show_if_present(multi_drift_plot)