In [None]:
%load_ext dotenv
%dotenv

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from drn import DRNExplainer, crps, rmse, split_and_preprocess

from analysis_utils import (
    calibration_plot,
    crps_wilcoxon_test,
    nll_wilcoxon_test,
    ql90_wilcoxon_test,
    quantile_losses_raw,
    quantile_points,
    quantile_residuals_plots,
    rmse_wilcoxon_test,
)

# Restrict PyTorch to a single thread for CPU intra-op parallelism.
# NOTE(review): presumably to avoid thread oversubscription / keep runs comparable — confirm.
torch.set_num_threads(1)

In [None]:
# Notebook-wide matplotlib defaults: export resolution and tick-label sizes.
plt.rcParams.update(
    {
        "savefig.dpi": 300,
        "xtick.labelsize": 15,
        "ytick.labelsize": 15,
    }
)

In [None]:
# Where the fitted models are read from and where figures are written to.
MODEL_DIR = Path("models") / "real"
PLOT_DIR = Path("plots") / "real"
# Create the figure directory (with parents) up front; a no-op when it already exists.
PLOT_DIR.mkdir(exist_ok=True, parents=True)

# Section 6: Real Data

## Sec 6.1: Data Preprocessing

In [None]:
# Load the raw records and keep only the rows with a strictly positive claim amount.
csv_file_path = "freMPL1.csv"
df = pd.read_csv(csv_file_path)
claims = df[df["ClaimAmount"].gt(0)]

In [None]:
# Density estimate of the positive claim amounts; the x-axis is limited to [0, 30000].
claims["ClaimAmount"].plot.density(color="green", xlim=(0, 30000))
# Setting the title with a larger font size
plt.title("Empirical Density of Truncated Claims", fontsize=20)

# Setting the labels for x and y axes with larger font sizes
# (the y-axis label is deliberately left empty).
plt.xlabel("Claim Amount ($)", fontsize=15)
plt.ylabel("", fontsize=15)
# Tick labels on this figure use size 10 instead of the rcParams default set earlier (15).
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)

# Written at the "savefig.dpi" configured in rcParams.
plt.savefig(PLOT_DIR / "Empirical Density (Real).png")

In [None]:
# Scaling: express the claim amounts in thousands.
target = claims["ClaimAmount"] / 1000
features = claims.drop("ClaimAmount", axis=1)
features = features.drop(
    ["RecordBeg", "RecordEnd", "ClaimInd", "Garage"], axis=1
)  # Drop garage due to missing values

# Convert "VehAge" categories to numeric.
# Labels that are not keys of this mapping become NaN under `Series.map`.
features["VehAge"] = features["VehAge"].map(
    {
        "0": 0,
        "1": 1,
        "2": 2,
        "3": 3,
        "4": 4,
        "5": 5,
        "6-7": 6,
        "8-9": 8,
        "10+": 11,
    }
)
feature_names = features.columns

# Encode every distinct "VehMaxSpeed" label by its 1-based position in the
# sorted array returned by `np.unique`.
speed_ranges = list(np.unique(features["VehMaxSpeed"]))
speed_series = pd.Series(speed_ranges)
mapping = {speed_range: i + 1 for i, speed_range in enumerate(speed_ranges)}
features["VehMaxSpeed"] = features["VehMaxSpeed"].map(mapping)

# Keep the first run of digits in each "SocioCateg" label as an integer code.
# The pattern is a raw string so that `\d` reaches the regex engine without
# being parsed as an (invalid) Python string escape, and `expand=False` makes
# `extract` return a Series instead of a one-column DataFrame.
features["SocioCateg"] = (
    features["SocioCateg"].str.extract(r"(\d+)", expand=False).astype(int)
)

# Columns treated as categorical; every remaining column is treated as numeric.
cat_features = [
    "HasKmLimit",
    "Gender",
    "MariStat",
    "VehUsage",
    "VehBody",
    "VehPrice",
    "VehEngine",
    "VehEnergy",
    "VehClass",
    "SocioCateg",
]

num_features = [feature for feature in features.columns if feature not in cat_features]

In [None]:
# Split the data into train/validation/test parts and preprocess the features.
# The helper's thirteen return values are unpacked in the order it yields them;
# note that `num_features` and `cat_features` are rebound to the returned lists.
(
    x_train, x_val, x_test,
    y_train, y_val, y_test,
    x_train_raw, x_val_raw, x_test_raw,
    num_features, cat_features,
    all_categories,
    ct,
) = split_and_preprocess(features, target, num_features, cat_features, seed=0)

# Cell output: (max, median) of the response for the train, validation and test splits.
tuple(
    summary(split)
    for split in (y_train, y_val, y_test)
    for summary in (np.max, np.median)
)

In [None]:
# Tensor copies of the preprocessed features (X_*) and of the responses (Y_*).
X_train, X_val, X_test = (
    torch.Tensor(frame.values) for frame in (x_train, x_val, x_test)
)
Y_train, Y_val, Y_test = (
    torch.Tensor(series.values) for series in (y_train, y_val, y_test)
)

# Paired (features, response) datasets for the training and validation splits.
train_dataset, val_dataset = (
    torch.utils.data.TensorDataset(inputs, targets)
    for inputs, targets in ((X_train, Y_train), (X_val, Y_val))
)

In [None]:
# Load the five fitted models from MODEL_DIR ("<name>.pkl").
# `weights_only=False` unpickles arbitrary Python objects, so these files must
# come from a trusted source.
glm, cann, mdn, ddr, drn = (
    torch.load(MODEL_DIR / f"{stem}.pkl", weights_only=False)
    for stem in ("glm", "cann", "mdn", "ddr", "drn")
)

## Sec 6.3: Evaluation

In [None]:
# Display names and the matching model objects, in the same order.
names = ["GLM", "CANN", "MDN", "DDR", "DRN"]
models = [glm, cann, mdn, ddr, drn]

print("Generating distributional forecasts")
dists_train, dists_val, dists_test = {}, {}, {}

# One forecast object per model and per split, keyed by model name.
for name, model in zip(names, models):
    print(f"- {name}")
    for dists, inputs in (
        (dists_train, X_train),
        (dists_val, X_val),
        (dists_test, X_test),
    ):
        dists[name] = model.distributions(inputs)

print("Calculating CDF over a grid")
GRID_SIZE = 3000  # Increase this to get more accurate CRPS estimates
# Evaluation grid from 0 to 110% of the largest training response, as a column.
grid = torch.linspace(0, np.max(y_train) * 1.1, GRID_SIZE)[:, None]

cdfs_train, cdfs_val, cdfs_test = {}, {}, {}

# Evaluate every forecast's CDF at each grid point.
for name in names:
    print(f"- {name}")
    for cdfs, dists in (
        (cdfs_train, dists_train),
        (cdfs_val, dists_val),
        (cdfs_test, dists_test),
    ):
        cdfs[name] = dists[name].cdf(grid)

### NLL

In [None]:
print("Calculating negative loglikelihoods")
nlls_train, nlls_val, nlls_test = {}, {}, {}

# Negated mean log-density of the observed responses under each model's forecast.
for name in names:
    for nlls, dists, targets in (
        (nlls_train, dists_train, Y_train),
        (nlls_val, dists_val, Y_val),
        (nlls_test, dists_test, Y_test),
    ):
        nlls[name] = -dists[name].log_prob(targets).mean()


# Report the per-model values split by split, to four decimals.
for df_name, nll_dict in (
    ("training", nlls_train),
    ("val", nlls_val),
    ("test", nlls_test),
):
    print(f"NLL on {df_name} set")
    for name in names:
        print(f"{name}: {nll_dict[name].mean():.4f}")
    print("-------------------------------")

In [None]:
# Run the NLL Wilcoxon test helper on the validation forecasts, then on the test forecasts;
# the last argument is the dataset label passed to the helper.
nll_wilcoxon_test(dists_val, Y_val, "Validation")
nll_wilcoxon_test(dists_test, Y_test, "Test")

### CRPS

In [None]:
print("Calculating CRPS")
# Drop the singleton axis the grid was created with; `crps` is given the squeezed grid.
grid = grid.squeeze()
crps_train, crps_val, crps_test = {}, {}, {}

# CRPS from the pre-computed CDF values of every model on every split.
for name in names:
    for scores, targets, cdfs in (
        (crps_train, Y_train, cdfs_train),
        (crps_val, Y_val, cdfs_val),
        (crps_test, Y_test, cdfs_test),
    ):
        scores[name] = crps(targets, grid, cdfs[name])

# Report the per-model mean CRPS split by split, to four decimals.
for df_name, crps_dict in (
    ("training", crps_train),
    ("val", crps_val),
    ("test", crps_test),
):
    print(f"CRPS on {df_name} set")
    for name in names:
        print(f"{name}: {crps_dict[name].mean():.4f}")
    print("------------------------------")

In [None]:
# Run the CRPS Wilcoxon test helper on the validation CDFs, then on the test CDFs.
# The second call is labelled "Test": it receives the test-set CDFs and responses,
# matching the labels used by the NLL and RMSE Wilcoxon cells.
crps_wilcoxon_test(cdfs_val, Y_val, grid, "Validation")
crps_wilcoxon_test(cdfs_test, Y_test, grid, "Test")

### RMSE

In [None]:
rmse_train, rmse_val, rmse_test = {}, {}, {}

# RMSE between the observed responses and each forecast's mean.
# NOTE(review): the responses passed here are the `y_*` objects returned by
# `split_and_preprocess`, whereas the other metrics use the `Y_*` tensors —
# confirm that `rmse` is meant to receive them.
for name in names:
    for scores, observed, dists in (
        (rmse_train, y_train, dists_train),
        (rmse_val, y_val, dists_val),
        (rmse_test, y_test, dists_test),
    ):
        scores[name] = rmse(observed, dists[name].mean)

# Report the per-model RMSE split by split, to four decimals.
for df_name, rmse_dict in (
    ("training", rmse_train),
    ("validation", rmse_val),
    ("test", rmse_test),
):
    print(f"RMSE on {df_name} set")
    for name in names:
        print(f"{name}: {rmse_dict[name].mean():.4f}")
    print("-------------------------------")

In [None]:
# Run the RMSE Wilcoxon test helper on the validation forecasts, then on the test forecasts;
# the last argument is the dataset label passed to the helper.
rmse_wilcoxon_test(dists_val, Y_val, "Validation")
rmse_wilcoxon_test(dists_test, Y_test, "Test")

### 90% Quantile Loss

In [None]:
ql_90_train = {}
ql_90_val = {}
ql_90_test = {}

# Upper bound handed to the helper as `u`: the training maximum plus three times
# the training range. It does not depend on the split or the model, so it is
# computed once. (Presumably `l`/`u` bracket the quantile search — confirm.)
quantile_upper_bound = np.max(y_train) + 3 * (np.max(y_train) - np.min(y_train))

# The loop variables are deliberately NOT called `features`/`response`: reusing
# `features` would overwrite the feature DataFrame built during preprocessing
# and break a later re-run of the split cell.
for split_inputs, split_response, dataset_name, ql_dict in zip(
    [X_train, X_val, X_test],
    [y_train, y_val, y_test],
    ["Training", "Validation", "Test"],
    [ql_90_train, ql_90_val, ql_90_test],
):
    print(f"{dataset_name} Dataset Quantile Loss(es)")
    for model, model_name in zip(models, names):
        # TODO from PL: ED to check - this originally didn't have "raw"
        ql_dict[model_name] = quantile_losses_raw(
            0.9,
            model,
            model_name,
            split_inputs,
            split_response,
            max_iter=1000,
            tolerance=1e-4,
            # Fresh bound tensors per call, exactly as before, in case the
            # helper modifies them in place.
            l=torch.Tensor([0]),
            u=torch.Tensor([quantile_upper_bound]),
        )
    print("----------------------")

In [None]:
# `models` is rebound here as a tuple holding the same five models, in the same
# order as `names` (it was created as a list in the evaluation cell).
models = (glm, cann, mdn, ddr, drn)
# Run the 90% quantile-loss Wilcoxon test helper on the validation and test sets;
# `y_train` is passed to both calls.
ql90_wilcoxon_test(models, X_val, Y_val, y_train, "Validation")
ql90_wilcoxon_test(models, X_test, Y_test, y_train, "Test")

### Table

In [None]:
def generate_latex_table(
    nlls_val,
    crps_val,
    rmse_val,
    ql_90_val,
    nlls_test,
    crps_test,
    rmse_test,
    ql_90_test,
    model_names,
    label_txt="Evaluation Metrics Table",
    caption_txt="Evaluation Metrics Table.",
    scaling_factor=1.0,
):
    """Build the LaTeX source of a table comparing models on validation and test metrics.

    Parameters
    ----------
    nlls_val, crps_val, rmse_val, ql_90_val : mapping
        Validation-set metrics keyed by model name. Each value must expose
        ``.mean()``; that mean is printed with four decimals in the NLL, CRPS,
        RMSE and 90% QL columns respectively.
    nlls_test, crps_test, rmse_test, ql_90_test : mapping
        The same four metrics for the test set.
    model_names : iterable of str
        Models to include, one table row each, in the given order.
    label_txt : str
        Text placed in the table's label.
    caption_txt : str
        Text placed in the table's caption.
    scaling_factor : float
        Factor placed in the scalebox that wraps the tabular.

    Returns
    -------
    str
        The table as a single newline-joined string: a header block, one row
        per model, and the closing rules and environments.

    Raises
    ------
    KeyError
        If a name in ``model_names`` is missing from one of the metric mappings.
    """
    # Every LaTeX backslash is spelled "\\". Sequences such as "\c", "\l", "\s",
    # "\m", "\%" and "\e" are not valid Python escapes; Python keeps the
    # backslash but warns about them, so they are escaped explicitly here.
    # The generated text is unchanged.
    header_row = (
        "\\begin{center}\n"
        + "\\captionof{table}{"
        + f"{caption_txt}"
        + "}\n"
        + "\\label{"
        + f"{label_txt}"
        + "}\n"
        + "\\scalebox{"
        + f"{scaling_factor}"
        + "}{\n"
        + "\\begin{tabular}{l|cccc|cccc}\n\\toprule\n\\toprule\n"
        + "&  \\multicolumn{4}{c}{$\\mathcal{D}_{\\text{Validation}}$}"
        + "& \\multicolumn{4}{c}{ $\\mathcal{D}_{\\text{Test}}$}\\\\ \n"
        + " \\cmidrule{2-5}  \\cmidrule{6-9} $\\text{Model}$ $\\backslash$ $\\text{Metrics}$"
        + " & NLL & CRPS & RMSE & 90\\% QL & NLL & CRPS & RMSE & 90\\% QL \\\\ \\midrule"
    )
    rows = [header_row]

    # One row per model: four validation metrics followed by four test metrics.
    for name in model_names:
        row = (
            f"{name} &  {(nlls_val[name].mean()):.4f}"
            f" &  {(crps_val[name].mean()):.4f} "
            f" & {(rmse_val[name].mean()):.4f} "
            f" & {(ql_90_val[name].mean()):.4f} "
            f" & {(nlls_test[name].mean()):.4f} "
            f" & {(crps_test[name].mean()):.4f} "
            f" & {(rmse_test[name].mean()):.4f} "
            f" & {(ql_90_test[name].mean()):.4f} \\\\ "
        )
        rows.append(row)

    # Close the tabular, the scalebox group and the center environment.
    table = (
        "\n".join(rows)
        + "\n\\bottomrule\n\\bottomrule"
        + "\n\\end{tabular}"
        + "\n}"
        + "\n\\end{center}"
    )
    return table


# Render the validation/test comparison for all five models and print the LaTeX source.
latex_table = generate_latex_table(
    nlls_val=nlls_val,
    crps_val=crps_val,
    rmse_val=rmse_val,
    ql_90_val=ql_90_val,
    nlls_test=nlls_test,
    crps_test=crps_test,
    rmse_test=rmse_test,
    ql_90_test=ql_90_test,
    model_names=names,
    label_txt="Evaluation Metrics",
    caption_txt="Model comparisons based on various evaluation metrics.",
    scaling_factor=0.95,
)
print(latex_table)

### Quantile Residuals

In [None]:
# Compute the quantile points on the test set from the gridded CDFs, plot them, save the figure.
test_quantile_points = quantile_points(cdfs_test, y_test, grid)
quantile_residuals_plots(test_quantile_points)
plt.savefig(PLOT_DIR / "Quantile Residuals Plot Real.png");

### Calibration

In [None]:
# Calibration plot built from the test-set CDFs on the grid; the figure is then saved.
calibration_plot(cdfs_test, y_test, grid)
plt.savefig(PLOT_DIR / "Calibration Plot Real.png");

## Sec 6.4: Interpretability

### Sec 6.4.1: Local Interpretability

#### (a) Extreme Case

In [None]:
# Forward each model over the test set once and reuse the predicted means below,
# so that the selected indices and the displayed predictions come from the same pass.
drn_means = drn.distributions(X_test).mean
glm_means = glm.distributions(X_test).mean
means_diff = drn_means - glm_means

# Find the top 3 values (largest DRN-minus-GLM mean differences) and their indices
values, indices = torch.topk(means_diff.view(-1), 3, sorted=True)
multi_indices = np.unravel_index(indices.numpy(), means_diff.shape)

print(multi_indices, means_diff[multi_indices])
# Positions along the first axis, ordered from the largest difference downwards.
idx_first = multi_indices[0][0]
idx_second = multi_indices[0][1]
idx_third = multi_indices[0][2]
# Cell output: observed response, DRN mean and GLM mean for the three selected instances.
y_test.values[multi_indices], drn_means[multi_indices], glm_means[multi_indices],

In [None]:
# Explainer constructed from the DRN and GLM models, the DRN cutpoints, the raw
# training features and the category / column-transformer metadata.
drn_explainer = DRNExplainer(
    drn, glm, drn.cutpoints, x_train_raw, cat_features, all_categories,
    column_transformer=ct,
)

# Explain the instance stored in `idx_second` by the selection cell.
idx = idx_second
instance_rows = slice(idx, idx + 1)  # one-row slice: keeps the DataFrame / tensor dimensions
drn_explainer.plot_dp_adjustment_shap(
    instance_raw=x_test_raw.iloc[instance_rows],
    observation=Y_test[instance_rows],
    # What is explained and how.
    dist_property="Mean",
    adjustment=True,
    method="Kernel",
    nsamples_background_fraction=0.5,
    top_K_features=5,
    # Figure layout.
    x_range=(0.0, 50.0),
    y_range=(0.0, 0.75),
    figsize=(7, 7),
    density_transparency=0.5,
    labelling_gap=0.1,
    shap_fontsize=15,
    plot_title="Explaining a Large Mean Adjustment",
    legend_loc="upper left",
)

plt.savefig(PLOT_DIR / "(Real) Mean Adjustment SHAP.png");

In [None]:
# Explainer constructed from the DRN and GLM models, the DRN cutpoints, the raw
# training features and the category / column-transformer metadata.
drn_explainer = DRNExplainer(
    drn, glm, drn.cutpoints, x_train_raw, cat_features, all_categories,
    column_transformer=ct,
)

# Same instance as the adjustment plot, but with `adjustment=False`.
idx = idx_second
instance_rows = slice(idx, idx + 1)  # one-row slice: keeps the DataFrame / tensor dimensions
drn_explainer.plot_dp_adjustment_shap(
    instance_raw=x_test_raw.iloc[instance_rows],
    observation=Y_test[instance_rows],
    # What is explained and how.
    dist_property="Mean",
    adjustment=False,
    method="Kernel",
    nsamples_background_fraction=0.5,
    top_K_features=5,
    # Figure layout.
    x_range=(0.0, 50.0),
    y_range=(0.0, 0.75),
    figsize=(7, 7),
    density_transparency=0.5,
    labelling_gap=0.1,
    shap_fontsize=15,
    plot_title="Explaining a Large Mean Prediction",
    legend_loc="upper left",
)

plt.savefig(PLOT_DIR / "(Real) Mean Explanation SHAP.png");

#### (b) Average Case

In [None]:
# Predicted means on the full test set, computed once and reused below.
drn_means = drn.distributions(X_test).mean
glm_means = glm.distributions(X_test).mean

# Calculate the mean differences between DRN and GLM predictions
means_diff = drn_means - glm_means

# Find indices where the absolute change relative to the GLM mean is strictly
# between 40% and 80%
relative_change = torch.abs(means_diff) / glm_means
valid_indices = (0.4 < relative_change) & (0.8 > relative_change)

# Filter X_test, x_test_raw, and y_test based on these valid indices
X_test_new = X_test[valid_indices]
x_test_raw_new = x_test_raw.iloc[valid_indices.numpy()]
y_test_new = Y_test[valid_indices]

# Recalculate the GLM means and the mean differences with the filtered dataset
glm_means_new = glm.distributions(X_test_new).mean
means_diff_new = drn.distributions(X_test_new).mean - glm_means_new

# Find instances where the GLM predictions are close to the actual y values
y_diff = torch.abs(y_test_new - glm_means_new)
close_y_indices = y_diff < 0.3 * torch.abs(
    y_test_new
)  # Assuming 30% closeness threshold

# Filter further based on the closeness of GLM predictions to y_test
X_test_final = X_test_new[close_y_indices]
means_diff_final = means_diff_new[close_y_indices]
y_test_final = y_test_new[close_y_indices]
glm_means_final = glm_means_new[close_y_indices]
drn_means_final = drn.distributions(X_test_final).mean

# Ensure we have enough data points after filtering
if len(means_diff_final) >= 4:
    # Find the four smallest absolute mean differences in the filtered dataset
    values, indices = torch.topk(torch.abs(means_diff_final).view(-1), 4, largest=False)

    # Map the selected positions back to row positions in the full test set
    original_indices = valid_indices.nonzero(as_tuple=True)[0][close_y_indices][indices]

    idx_first = original_indices[0].item()
    idx_second = original_indices[1].item()
    idx_third = original_indices[2].item()
    idx_forth = original_indices[3].item()

    # Output the results for the first two selected instances
    print("Original Indices:", idx_first, idx_second)
    print("Actual y values:", y_test.values[idx_first], y_test.values[idx_second])
    print(
        "DRN mean predictions:",
        drn_means[idx_first].item(),
        drn_means[idx_second].item(),
    )
    print(
        "GLM mean predictions:",
        glm_means[idx_first].item(),
        glm_means[idx_second].item(),
    )
else:
    # NOTE: idx_first .. idx_forth are not assigned on this branch, so any
    # values set by an earlier cell remain in place for the cells that follow.
    print("Not enough data points meet the criteria.")

In [None]:
# Explainer constructed from the DRN and GLM models, the DRN cutpoints, the raw
# training features and the category / column-transformer metadata.
drn_explainer = DRNExplainer(
    drn, glm, drn.cutpoints, x_train_raw, cat_features, all_categories,
    column_transformer=ct,
)

# Explain the instance stored in `idx_first` by the selection cell.
idx = idx_first
instance_rows = slice(idx, idx + 1)  # one-row slice: keeps the DataFrame / tensor dimensions
drn_explainer.plot_dp_adjustment_shap(
    instance_raw=x_test_raw.iloc[instance_rows],
    observation=Y_test[instance_rows],
    # What is explained and how.
    dist_property="Mean",
    adjustment=True,
    method="Kernel",
    nsamples_background_fraction=0.5,
    top_K_features=5,
    # Figure layout.
    x_range=(0.0, 6.5),
    y_range=(0.0, 2.5),
    figsize=(7, 7),
    density_transparency=0.5,
    labelling_gap=0.1,
    shap_fontsize=15,
    plot_title="Explaining an Average Mean Adjustment",
    legend_loc="upper left",
)

plt.savefig(PLOT_DIR / "(Real) Average Mean Adjustment SHAP.png");

#### (c) Mild Case

In [None]:
# Predicted means on the full test set, computed once and reused below.
drn_means = drn.distributions(X_test).mean
glm_means = glm.distributions(X_test).mean
means_diff = drn_means - glm_means

# Keep instances whose absolute change relative to the GLM mean is strictly
# between 10% and 30%.
relative_change = torch.abs(means_diff) / glm_means
valid_indices = (0.1 < relative_change) & (0.3 > relative_change)

X_test_new = X_test[valid_indices]
x_test_raw_new = x_test_raw.iloc[valid_indices.numpy()]
y_test_new = Y_test[valid_indices]

# GLM means and mean differences recomputed on the filtered instances.
glm_means_new = glm.distributions(X_test_new).mean
means_diff_new = drn.distributions(X_test_new).mean - glm_means_new

# Keep only instances whose GLM mean is close to the observed response.
y_diff = torch.abs(y_test_new - glm_means_new)
close_y_indices = y_diff < 0.2 * torch.abs(
    y_test_new
)  # Assuming 20% closeness threshold

X_test_final = X_test_new[close_y_indices]
means_diff_final = means_diff_new[close_y_indices]
y_test_final = y_test_new[close_y_indices]
glm_means_final = glm_means_new[close_y_indices]
drn_means_final = drn.distributions(X_test_final).mean

if len(means_diff_final) >= 4:
    # Four smallest absolute mean differences among the remaining instances,
    # mapped back to row positions in the full test set.
    values, indices = torch.topk(torch.abs(means_diff_final).view(-1), 4, largest=False)
    original_indices = valid_indices.nonzero(as_tuple=True)[0][close_y_indices][indices]

    idx_first = original_indices[0].item()
    idx_second = original_indices[1].item()
    idx_third = original_indices[2].item()
    idx_forth = original_indices[3].item()

    # Output the results for the first two selected instances
    print("Original Indices:", idx_first, idx_second)
    print("Actual y values:", y_test.values[idx_first], y_test.values[idx_second])
    print(
        "DRN mean predictions:",
        drn_means[idx_first].item(),
        drn_means[idx_second].item(),
    )
    print(
        "GLM mean predictions:",
        glm_means[idx_first].item(),
        glm_means[idx_second].item(),
    )
else:
    # NOTE: idx_first .. idx_forth are not assigned on this branch, so any
    # values set by an earlier cell remain in place for the cells that follow.
    print("Not enough data points meet the criteria.")

In [None]:
# Explainer constructed from the DRN and GLM models, the DRN cutpoints, the raw
# training features and the category / column-transformer metadata.
drn_explainer = DRNExplainer(
    drn, glm, drn.cutpoints, x_train_raw, cat_features, all_categories,
    column_transformer=ct,
)

# Explain the instance stored in `idx_second` by the selection cell.
idx = idx_second
instance_rows = slice(idx, idx + 1)  # one-row slice: keeps the DataFrame / tensor dimensions
drn_explainer.plot_dp_adjustment_shap(
    instance_raw=x_test_raw.iloc[instance_rows],
    observation=Y_test[instance_rows],
    # What is explained and how.
    dist_property="Mean",
    adjustment=True,
    method="Kernel",
    nsamples_background_fraction=0.5,
    top_K_features=5,
    # Figure layout.
    x_range=(0.0, 6.0),
    y_range=(0.0, 2.0),
    figsize=(7, 7),
    density_transparency=0.5,
    labelling_gap=0.1,
    shap_fontsize=15,
    plot_title="Explaining a Mild Mean Adjustment",
    legend_loc="upper left",
)

plt.savefig(PLOT_DIR / "(Real) Mild Mean Adjustment SHAP.png");

#### CDF Plot

In [None]:
# Explainer constructed from the DRN and GLM models, the DRN cutpoints, the raw
# training features and the category / column-transformer metadata.
drn_explainer = DRNExplainer(
    drn, glm, drn.cutpoints, x_train_raw, cat_features, all_categories,
    column_transformer=ct,
)

# Uses whichever `idx_first` the most recently run selection cell assigned.
idx = idx_first
instance_rows = slice(idx, idx + 1)  # one-row slice: keeps the DataFrame dimensions
# Quantile bounds: the first and the last DRN cutpoint, each as a one-element tensor.
cutpoint_bounds = tuple(torch.Tensor([drn.cutpoints[end]]) for end in (0, -1))

drn_explainer.cdf_plot(
    instance=x_test_raw.iloc[instance_rows],
    # What is explained and how.
    dist_property="90% Quantile",
    quantile_bounds=cutpoint_bounds,
    adjustment=True,
    nsamples_background_fraction=0.2,
    top_K_features=5,
    # Figure layout.
    x_range=(0.0, 8.5),
    y_range=(0.0, 1.0),
    figsize=(7, 7),
    density_transparency=0.9,
    labelling_gap=0.1,
    shap_fontsize=15,
    plot_title="90% Quantile Adjustment Explanation",
)

### Sec 6.4.2: Global Interpretability

In [None]:
# Explainer constructed from the DRN and GLM models, the DRN cutpoints, the raw
# training features and the category metadata. The column transformer is passed
# by keyword, as in every other cell that builds a DRNExplainer.
drn_explainer = DRNExplainer(
    drn,
    glm,
    drn.cutpoints,
    x_train_raw,
    cat_features,
    all_categories,
    column_transformer=ct,
)
# Kernel SHAP results for the "Mean" over the raw test features; the plotting
# cells below read from the returned object.
kernel_shap_drn = drn_explainer.kernel_shap(
    explaining_data=x_test_raw,
    distributional_property="Mean",
    nsamples_background_fraction=0.2,
    adjustment=True,
    glm_output=True,
)

In [None]:
# Global importance plot over the numeric followed by the categorical features (output="drn"),
# laid out tightly and saved to PLOT_DIR.
kernel_shap_drn.global_importance_plot(num_features + cat_features, output="drn")
plt.tight_layout()
plt.savefig(PLOT_DIR / "(Real) Global Importance.png");

In [None]:
# Beeswarm summary plot over the numeric followed by the categorical features (output="drn"),
# laid out tightly and saved to PLOT_DIR.
kernel_shap_drn.beeswarm_plot(num_features + cat_features, output="drn")
plt.tight_layout()
plt.savefig(PLOT_DIR / "(Real) Beeswarm Summary Plot.png");

In [None]:
# SHAP dependence plot for the feature pair ("MariStat", "Exposure") with output="drn".
# NOTE(review): the file name spells the feature "MariState" — confirm this is intended.
kernel_shap_drn.shap_dependence_plot(("MariStat", "Exposure"), output="drn")
plt.tight_layout()
plt.savefig(PLOT_DIR / "(SHAP Dependence) MariState X Exposure.png");

In [None]:
# SHAP dependence plot for the feature pair ("MariStat", "LicAge") with output="drn".
# NOTE(review): the file name spells the feature "MariState" — confirm this is intended.
kernel_shap_drn.shap_dependence_plot(("MariStat", "LicAge"), output="drn")
plt.tight_layout()
plt.savefig(PLOT_DIR / "(SHAP Dependence) MariState X LicAge.png");

In [None]:
# Global importance plot over the numeric followed by the categorical features, this time
# with output="value"; saved under an "Adjustment" file name.
kernel_shap_drn.global_importance_plot(num_features + cat_features, output="value")
plt.savefig(PLOT_DIR / "(Real Adjustment) Global Importance.png");

In [None]:
# Beeswarm summary plot over the numeric followed by the categorical features, this time
# with output="value"; saved under an "Adjustment" file name.
kernel_shap_drn.beeswarm_plot(num_features + cat_features, output="value")
plt.savefig(PLOT_DIR / "(Real Adjustment) Beeswarm Summary Plot.png");

In [None]:
# SHAP dependence plot for the feature pair ("VehEnergy", "Exposure") with output="value".
kernel_shap_drn.shap_dependence_plot(("VehEnergy", "Exposure"), output="value")
plt.savefig(PLOT_DIR / "(SHAP Adjustment) VehEnergy X Exposure.png");

In [None]:
# SHAP dependence plot for the feature pair ("RiskVar", "Exposure") with output="value".
kernel_shap_drn.shap_dependence_plot(("RiskVar", "Exposure"), output="value")
plt.savefig(PLOT_DIR / "(SHAP Adjustment) RiskVar X Exposure.png");

In [None]:
# SHAP dependence plot for the feature pair ("LicAge", "DrivAge") with output="value".
kernel_shap_drn.shap_dependence_plot(("LicAge", "DrivAge"), output="value")
plt.savefig(PLOT_DIR / "(SHAP Adjustment) LicAge X DrivAge.png");