In [None]:
from dotenv import load_dotenv, find_dotenv

# Call load_dotenv() outside the assert: assert statements are removed when Python
# runs with -O, which would otherwise skip loading the .env file altogether.
dotenv_loaded = load_dotenv(find_dotenv(usecwd=False))
assert dotenv_loaded, "The .env file was not loaded."

import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from drn import DRNExplainer, crps, rmse, split_and_preprocess
from skopt.plots import plot_objective

from analysis_utils import (
    calibration_plot,
    crps_wilcoxon_test,
    generate_latex_table,
    nll_wilcoxon_test,
    ql90_wilcoxon_test,
    quantile_losses_raw,
    quantile_points,
    quantile_residuals_plots,
    rmse_wilcoxon_test,
)
from generate_synthetic_dataset import generate_synthetic_gamma_lognormal

# Limit PyTorch to a single thread for intra-op CPU parallelism.
torch.set_num_threads(1)

In [None]:
# Notebook-wide matplotlib defaults: resolution of saved figures and tick-label sizes.
plt.rcParams.update(
    {
        "savefig.dpi": 300,
        "xtick.labelsize": 15,
        "ytick.labelsize": 15,
    }
)

In [None]:
# MODEL_DIR: where the saved model files are read from.
# PLOT_DIR: where the figures produced by this notebook are written.
MODEL_DIR, PLOT_DIR = Path("models/synth"), Path("plots/synth")

# Create the figure directory (with any missing parents); no error if it already exists.
PLOT_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# Generate the synthetic dataset. 20000 is the only argument given to the generator
# (presumably the number of observations — confirm in generate_synthetic_dataset).
# NOTE(review): no seed is passed here; verify the generator is seeded internally
# if the results must be reproducible.
features, target, means, dispersion = generate_synthetic_gamma_lognormal(20000)

In [None]:
# Split the generated data into train / validation / test parts and preprocess it.
# "X_1" and "X_2" are passed as the third argument and an empty list as the fourth
# (presumably the numerical and categorical feature names — confirm against
# drn.split_and_preprocess).
(
    x_train, x_val, x_test,
    y_train, y_val, y_test,
    x_train_raw, x_val_raw, x_test_raw,
    num_features, cat_features, all_categories, ct,
) = split_and_preprocess(
    features,
    target,
    ["X_1", "X_2"],
    [],
    seed=42,
    num_standard=False,
)

# Show the training features as the cell output.
x_train

In [None]:
# Tensor versions of every split returned by split_and_preprocess
# (capitalised names hold the torch tensors).
X_train, Y_train, X_val, Y_val, X_test, Y_test = (
    torch.Tensor(split.values)
    for split in (x_train, y_train, x_val, y_val, x_test, y_test)
)

# Paired (inputs, targets) datasets for the training and validation splits.
train_dataset, val_dataset = (
    torch.utils.data.TensorDataset(inputs, targets)
    for inputs, targets in ((X_train, Y_train), (X_val, Y_val))
)

In [None]:
# Load the five saved models from MODEL_DIR, in the order GLM, CANN, MDN, DDR, DRN.
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects, so
# these files must come from a trusted source.
glm, cann, mdn, ddr, drn = (
    torch.load(MODEL_DIR / filename, weights_only=False)
    for filename in ("glm.pkl", "cann.pkl", "mdn.pkl", "ddr.pkl", "drn.pkl")
)

## Sec 5.3.1: Evaluation

In [None]:
# Display names and the matching model objects, kept in the same order.
names = ["GLM", "CANN", "MDN", "DDR", "DRN"]
models = [glm, cann, mdn, ddr, drn]

print("Generating distributional forecasts")
# One dict per split, keyed by model name, holding the result of model.distributions().
dists_train, dists_val, dists_test = {}, {}, {}

for name, model in zip(names, models):
    print(f"- {name}")
    for store, inputs in (
        (dists_train, X_train),
        (dists_val, X_val),
        (dists_test, X_test),
    ):
        store[name] = model.distributions(inputs)

In [None]:
print("Calculating CDF over a grid")
GRID_SIZE = 3000  # Increase this to get more accurate CRPS estimates

# Evaluation grid from 0 up to 110% of the largest training response, with a trailing
# singleton dimension added by unsqueeze(-1).
grid_upper = np.max(y_train) * 1.1
grid = torch.linspace(0, grid_upper, GRID_SIZE).unsqueeze(-1)

# One dict per split, keyed by model name, holding each forecast's CDF on the grid.
cdfs_train, cdfs_val, cdfs_test = {}, {}, {}

for name in names:
    print(f"- {name}")
    for cdf_store, dist_store in (
        (cdfs_train, dists_train),
        (cdfs_val, dists_val),
        (cdfs_test, dists_test),
    ):
        cdf_store[name] = dist_store[name].cdf(grid)

### NLL

In [None]:
print("Calculating negative loglikelihoods")
# Mean negative log-probability of the observed targets under each model's forecast.
nlls_train, nlls_val, nlls_test = {}, {}, {}

for name in names:
    nlls_train[name] = -dists_train[name].log_prob(Y_train).mean()
    nlls_val[name] = -dists_val[name].log_prob(Y_val).mean()
    nlls_test[name] = -dists_test[name].log_prob(Y_test).mean()

# Report the values split by split.
for df_name, nll_dict in (
    ("training", nlls_train),
    ("val", nlls_val),
    ("test", nlls_test),
):
    print(f"NLL on {df_name} set")
    for name in names:
        print(f"{name}: {nll_dict[name]:.4f}")
    print("-------------------------------")

In [None]:
# Run analysis_utils.nll_wilcoxon_test on the validation and test forecasts;
# the last argument is the dataset label handed to the helper.
nll_wilcoxon_test(dists_val, Y_val, "Validation")
nll_wilcoxon_test(dists_test, Y_test, "Test")

### CRPS

In [None]:
print("Calculating CRPS")
# Remove size-1 dimensions from the grid (it was created with unsqueeze(-1)).
grid = grid.squeeze()

# Per-model CRPS values for each split, computed from the gridded CDFs.
crps_train, crps_val, crps_test = {}, {}, {}

for name in names:
    crps_train[name] = crps(Y_train, grid, cdfs_train[name])
    crps_val[name] = crps(Y_val, grid, cdfs_val[name])
    crps_test[name] = crps(Y_test, grid, cdfs_test[name])

# Report the mean CRPS of every model, split by split.
for df_name, crps_dict in (
    ("training", crps_train),
    ("val", crps_val),
    ("test", crps_test),
):
    print(f"CRPS on {df_name} set")
    for name in names:
        print(f"{name}: {crps_dict[name].mean():.4f}")
    print("------------------------------")

In [None]:
# Run analysis_utils.crps_wilcoxon_test on the validation and test CDFs.
# The last argument is the dataset label; the test-set call previously carried the
# "Validation" label by mistake, which mislabelled its output.
crps_wilcoxon_test(cdfs_val, Y_val, grid, "Validation")
crps_wilcoxon_test(cdfs_test, Y_test, grid, "Test")

### RMSE

In [None]:
# RMSE between the observed targets and each forecast's mean, per split.
# The targets used here are y_train / y_val / y_test as returned by
# split_and_preprocess, not the tensor copies.
rmse_train, rmse_val, rmse_test = {}, {}, {}

for name in names:
    means_train = dists_train[name].mean
    means_val = dists_val[name].mean
    means_test = dists_test[name].mean

    rmse_train[name] = rmse(y_train, means_train)
    rmse_val[name] = rmse(y_val, means_val)
    rmse_test[name] = rmse(y_test, means_test)

# Report the values split by split.
for df_name, rmse_dict in (
    ("training", rmse_train),
    ("validation", rmse_val),
    ("test", rmse_test),
):
    print(f"RMSE on {df_name} set")
    for name in names:
        print(f"{name}: {rmse_dict[name].mean():.4f}")
    print("-------------------------------")

In [None]:
# Run analysis_utils.rmse_wilcoxon_test on the validation and test forecasts;
# the last argument is the dataset label handed to the helper.
rmse_wilcoxon_test(dists_val, Y_val, "Validation")
rmse_wilcoxon_test(dists_test, Y_test, "Test")

### 90% Quantile Loss

In [None]:
# 90% quantile loss of every model on each split, keyed by model name.
ql_90_train = {}
ql_90_val = {}
ql_90_test = {}

# Upper end of the interval passed to quantile_losses_raw as `u`. It depends only on
# the training response, so it is computed once instead of on every call.
ql_upper = np.max(y_train) + 3 * (np.max(y_train) - np.min(y_train))

# The loop variables are deliberately not named `features`: that name holds the
# generated dataset, and rebinding it here would break re-running the
# split_and_preprocess cell afterwards.
for x_split, y_split, dataset_name, ql_dict in zip(
    [X_train, X_val, X_test],
    [y_train, y_val, y_test],
    ["Training", "Validation", "Test"],
    [ql_90_train, ql_90_val, ql_90_test],
):
    print(f"{dataset_name} Dataset Quantile Loss(es)")
    for model, model_name in zip(models, names):
        # TODO from PL: ED to check - this originally didn't have "raw"
        ql_dict[model_name] = quantile_losses_raw(
            0.9,
            model,
            model_name,
            x_split,
            y_split,
            max_iter=1000,
            tolerance=1e-4,
            # Fresh bound tensors on every call, in case the helper modifies them in place.
            l=torch.Tensor([0]),
            u=torch.Tensor([ql_upper]),
        )
    print(f"----------------------")

In [None]:
# NOTE(review): this rebinds `models` from the list defined earlier to a tuple with the
# same members — confirm whether ql90_wilcoxon_test actually requires a tuple.
models = (glm, cann, mdn, ddr, drn)
# Run analysis_utils.ql90_wilcoxon_test on the validation and test splits; y_train is
# passed to both calls and the last argument is the dataset label.
ql90_wilcoxon_test(models, X_val, Y_val, y_train, "Validation")
ql90_wilcoxon_test(models, X_test, Y_test, y_train, "Test")

### Table

In [None]:
# Metric dictionaries in the positional order generate_latex_table receives them:
# NLL, CRPS, RMSE, 90% quantile loss — validation first, then test.
val_metrics = (nlls_val, crps_val, rmse_val, ql_90_val)
test_metrics = (nlls_test, crps_test, rmse_test, ql_90_test)

latex_table = generate_latex_table(
    *val_metrics,
    *test_metrics,
    names,
    label_txt="Evaluation Metrics",
    caption_txt="Model comparisons based on various evaluation metrics.",
    scaling_factor=0.95,
)
print(latex_table)

### Quantile Residuals

In [None]:
# Compute the quantile points on the test split, plot them, then save and show the figure.
test_quantile_points = quantile_points(cdfs_test, y_test, grid)
quantile_residuals_plots(test_quantile_points)
plt.savefig(PLOT_DIR / "Quantile Residuals Plot Synthetic.png")
plt.show()

### Calibration

In [None]:
# Draw the calibration plot from the test-split CDFs and save it to PLOT_DIR.
calibration_plot(cdfs_test, y_test, grid)
# The trailing semicolon suppresses the cell's text output.
plt.savefig(PLOT_DIR / "Calibration Plot Synthetic.png");

## Sec 5.3.2: Interpretability

In [None]:
# Investigated instance: the (X_1, X_2) feature values explained in the cells below.
x_1, x_2 = 0.1, 0.1

### Local: Density Plot and Kernel SHAP

In [None]:
# Build the explainer from the DRN, the GLM, the DRN cutpoints, the raw training
# features and the preprocessing objects returned by split_and_preprocess.
drn_explainer = DRNExplainer(
    drn,
    glm,
    drn.cutpoints,
    x_train_raw,
    cat_features,
    all_categories,
    ct,
)

# Single-row frame holding the investigated instance.
instance = pd.DataFrame(
    np.array([x_1, x_2]).reshape(1, 2),
    columns=["X_1", "X_2"],
)

# Plot the adjustment factors for that instance and save the figure.
drn_explainer.plot_adjustment_factors(
    instance=instance,
    num_interpolations=1000,
    plot_adjustments_labels=False,
    x_range=(0, 6),
    synthetic_data=generate_synthetic_gamma_lognormal,
)

plt.savefig(PLOT_DIR / "(0.1, 0.1) Density Plot.png")

In [None]:
# Fresh explainer for this figure (same inputs as in the other interpretability cells).
drn_explainer = DRNExplainer(
    drn, glm, drn.cutpoints, x_train_raw, cat_features, all_categories,
    column_transformer=ct,
)

# Single-row frame holding the investigated instance.
instance_raw = pd.DataFrame(
    np.array([x_1, x_2]).reshape(1, 2),
    columns=["X_1", "X_2"],
)

# Kernel SHAP explanation of the adjustment to the mean.
drn_explainer.plot_dp_adjustment_shap(
    instance_raw=instance_raw,
    # What is explained and how
    dist_property="Mean",
    method="Kernel",
    nsamples_background_fraction=0.5,
    top_K_features=3,
    adjustment=True,
    observation=True,
    synthetic_data=generate_synthetic_gamma_lognormal,
    # Figure layout
    x_range=(2.14, 2.23),
    y_range=(0.0, 0.75),
    labelling_gap=0.12,
    density_transparency=0.9,
    shap_fontsize=15,
    figsize=(7, 7),
    legend_loc="upper left",
    plot_title="Mean Adjustment Explanation",
)

plt.savefig(PLOT_DIR / "(0.1, 0.1) Mean Adjustment Plot.png")

In [None]:
# Fresh explainer for this figure (same inputs as in the other interpretability cells).
drn_explainer = DRNExplainer(
    drn, glm, drn.cutpoints, x_train_raw, cat_features, all_categories,
    column_transformer=ct,
)

# Single-row frame holding the investigated instance.
instance = pd.DataFrame(
    np.array([x_1, x_2]).reshape(1, 2),
    columns=["X_1", "X_2"],
)

# Bounds passed as `quantile_bounds`: the first DRN cutpoint and twice the last one.
quantile_lower = torch.Tensor([drn.cutpoints[0]])
quantile_upper = torch.Tensor([drn.cutpoints[-1] * 2])

# Zoomed-in CDF plot explaining the adjustment to the 90% quantile.
# NOTE(review): nsamples_background_fraction is 0.005 here but 0.05 in the other
# 90% quantile cdf_plot cell — confirm the difference is intentional.
drn_explainer.cdf_plot(
    instance=instance,
    # What is explained and how
    dist_property="90% Quantile",
    method="Kernel",
    nsamples_background_fraction=0.005,
    top_K_features=3,
    quantile_bounds=(quantile_lower, quantile_upper),
    adjustment=True,
    synthetic_data=generate_synthetic_gamma_lognormal,
    # Figure layout
    x_range=(3.4, 3.58),
    y_range=(0.87, 0.93),
    labelling_gap=0.15,
    density_transparency=0.9,
    shap_fontsize=15,
    figsize=(7, 7),
    plot_title="90% Quantile Adjustment Explanation",
)

plt.savefig(PLOT_DIR / "(0.1, 0.1) Quantile Adjustment Plot.png")

In [None]:
# Fresh explainer for this figure (same inputs as in the other interpretability cells).
drn_explainer = DRNExplainer(
    drn, glm, drn.cutpoints, x_train_raw, cat_features, all_categories,
    column_transformer=ct,
)

# Single-row frame holding the investigated instance.
instance = pd.DataFrame(
    np.array([x_1, x_2]).reshape(1, 2),
    columns=["X_1", "X_2"],
)

# Bounds passed as `quantile_bounds`: the first DRN cutpoint and twice the last one.
quantile_lower = torch.Tensor([drn.cutpoints[0]])
quantile_upper = torch.Tensor([drn.cutpoints[-1] * 2])

# Plot CDF for 90% Quantile Explanation (full range, adjustment and baseline switched off)
drn_explainer.cdf_plot(
    instance=instance,
    # What is explained and how
    dist_property="90% Quantile",
    method="Kernel",
    nsamples_background_fraction=0.05,
    top_K_features=3,
    quantile_bounds=(quantile_lower, quantile_upper),
    adjustment=False,
    plot_baseline=False,
    synthetic_data=generate_synthetic_gamma_lognormal,
    # Figure layout
    x_range=(0, 8),
    y_range=(0.0, 1.0),
    labelling_gap=0.16,
    density_transparency=0.9,
    shap_fontsize=15,
    figsize=(7, 7),
    plot_title="90% Quantile Explanation",
)

plt.savefig(PLOT_DIR / "(0.1, 0.1) Quantile Explanation Plot.png")

### Global: SHAP Dependence

In [None]:
# Explainer used for the global SHAP plots.
explainer_args = (
    drn,
    glm,
    drn.cutpoints,
    x_train_raw,
    cat_features,
    all_categories,
    ct,
)
drn_explainer = DRNExplainer(*explainer_args)

# Kernel SHAP values for the DRN model, computed on the raw test features.
# distributional_property can be changed to 'XX% Quantile'.
kernel_shap_drn = drn_explainer.kernel_shap(
    explaining_data=x_test_raw,
    distributional_property="Mean",
    nsamples_background_fraction=0.5,
    adjustment=True,
    glm_output=True,
)

In [None]:
# Global importance plot over all features (numerical followed by categorical).
feature_names = num_features + cat_features
kernel_shap_drn.global_importance_plot(feature_names, output="drn")
plt.savefig(PLOT_DIR / "(Synthetic) SHAP Importance Mean.png");

In [None]:
# Beeswarm plot over all features (numerical followed by categorical).
feature_names = num_features + cat_features
kernel_shap_drn.beeswarm_plot(feature_names, output="drn")
plt.savefig(PLOT_DIR / "(Synthetic) SHAP Beeswarm Mean.png");

In [None]:
# SHAP dependence plot for the feature pair ("X_1", "X_2"), saved to PLOT_DIR.
kernel_shap_drn.shap_dependence_plot(("X_1", "X_2"), output="drn")
# The trailing semicolon suppresses the cell's text output.
plt.savefig(PLOT_DIR / "(Synthetic) SHAP Dependence Mean.png");

In [None]:
# Plot the objective from each saved hyperparameter-search result.
# names[1:] skips the first entry ("GLM"); no *_hp.pkl file is read for it.
for name in names[1:]:
    # SECURITY: pickle.load can execute arbitrary code — only load result files that
    # were produced by this project's own tuning runs.
    with open(MODEL_DIR / f"{name.lower()}_hp.pkl", "rb") as f:
        res_hp = pickle.load(f)

    # The file is only needed for unpickling, so plotting runs after it is closed.
    # Smaller fonts apply to these figures only; rc_context restores the defaults.
    with plt.rc_context(
        {
            "xtick.labelsize": "x-small",
            "ytick.labelsize": "x-small",
            "axes.labelsize": "x-small",
            "axes.titlesize": "x-small",
        }
    ):
        plot_objective(res_hp)
        plt.savefig(PLOT_DIR / f"{name.lower()}_hp.png", bbox_inches="tight")