In [None]:
%load_ext dotenv
%dotenv

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import seaborn as sns
import torch
from drn import DRNExplainer, split_and_preprocess

from generate_synthetic_dataset import generate_synthetic_gaussian

# Cap the number of threads PyTorch uses for CPU intra-op parallelism at one.
# NOTE(review): presumably to avoid CPU oversubscription / keep runs comparable — confirm.
torch.set_num_threads(1)

In [None]:
# Global matplotlib defaults: resolution used by savefig and tick-label font sizes.
plt.rcParams.update(
    {
        "savefig.dpi": 300,
        "xtick.labelsize": 15,
        "ytick.labelsize": 15,
    }
)

In [None]:
# Directory the fitted models are read from, and directory the figures are written to.
MODEL_DIR = Path("models") / "reg"
PLOT_DIR = Path("plots") / "reg"

# Create the figure directory (and any missing parents); a no-op if it already exists.
PLOT_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
# Draw the synthetic Gaussian dataset (40,000 requested), then split/preprocess it.
features, target, means, dispersion = generate_synthetic_gaussian(40_000)

# split_and_preprocess returns a 13-tuple; unpacked here group by group.
# NOTE(review): the two list arguments look like the numeric (X_1, X_2) and
# categorical (none) feature names — confirm against the drn documentation.
(
    x_train, x_val, x_test,
    y_train, y_val, y_test,
    x_train_raw, x_val_raw, x_test_raw,
    num_features, cat_features, all_categories,
    ct,
) = split_and_preprocess(
    features,
    target,
    ["X_1", "X_2"],
    [],
    seed=0,
    num_standard=False,
)

In [None]:
# Wrap the underlying `.values` arrays of every split in torch tensors.
X_train, X_val, X_test = (
    torch.Tensor(split.values) for split in (x_train, x_val, x_test)
)
Y_train, Y_val, Y_test = (
    torch.Tensor(split.values) for split in (y_train, y_val, y_test)
)

# Pair inputs with targets for the training and validation splits.
train_dataset, val_dataset = (
    torch.utils.data.TensorDataset(inputs, targets)
    for inputs, targets in ((X_train, Y_train), (X_val, Y_val))
)

In [None]:
def _load_model(filename):
    """Load one saved model object from MODEL_DIR.

    `weights_only=False` makes torch.load unpickle arbitrary Python objects,
    so this must only ever be pointed at trusted, locally produced files.
    """
    return torch.load(MODEL_DIR / filename, weights_only=False)


# Baseline GLM plus the DRN variants compared in the figures below.
glm = _load_model("glm.pkl")
drn_no_penalty = _load_model("drn_no_penalty.pkl")
drn_kl_penalty = _load_model("drn_kl_penalty.pkl")
drn_dv_large_penalty = _load_model("drn_dv_large_penalty.pkl")
drn_everything = _load_model("drn_everything.pkl")

In [None]:
# Feature values of the single observation that the figures below explain.
x_1 = 0.5
x_2 = 0.5

# One-row frame with the two feature columns, in the order X_1, X_2.
instance = pd.DataFrame(np.array([[x_1, x_2]]), columns=["X_1", "X_2"])

# Location and scale of the reference normal density overlaid on the plots.
true_mean = x_2 - x_1
true_scale = (x_1**2 + x_2**2) * 0.5

In [None]:
# Two-panel figure: DRN without penalty (left) vs. DRN with KL penalty (right),
# each overlaid with the reference normal density for the chosen instance.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 3.5))

# Calculate the PDF with 'true_mean' and 'true_scale' parameters
true_y_grid = np.linspace(-1.25, 1.25, 1000)
true_densities = scipy.stats.norm.pdf(true_y_grid, loc=true_mean, scale=true_scale)

# Build one explainer per model and draw it on its own panel. Everything is
# done through the explicit Axes object: `plt.legend()` would act on pyplot's
# *current* axes, which is not guaranteed to be the panel just drawn, so the
# "True" curve could end up missing from that panel's legend.
explainers = []
for ax, model in ((ax1, drn_no_penalty), (ax2, drn_kl_penalty)):
    explainer = DRNExplainer(
        model,
        glm,
        model.cutpoints,
        x_train_raw,
        cat_features,
        all_categories,
        ct,
    )
    explainer.plot_adjustment_factors(
        instance=instance,
        num_interpolations=1_000,
        plot_adjustments_labels=False,
        axes=ax,
        x_range=(-1.25, 1.25),
        y_range=(0, 2),
        plot_title="",
        plot_y_label="$f(y|\\boldsymbol{X}=(0.5, 0.5)^{\\top})$",
    )

    # Reference density, drawn behind the model curves (zorder=-1).
    ax.plot(true_y_grid, true_densities, color="red", label="True", lw=2, zorder=-1)
    # Rebuild this panel's legend so that it includes the "True" entry.
    ax.legend()
    explainers.append(explainer)

# Keep the explainers available under their original names.
drn_no_penalty_exp, drn_kl_penalty_exp = explainers

# Remove the top/right spines on every axes of this figure.
sns.despine(fig=fig)

fig.tight_layout()
fig.savefig(PLOT_DIR / "DRN KL Penalty.png");

In [None]:
# Two-panel figure: DRN with large DV penalty (left) vs. DRN with everything
# (right), each overlaid with the reference normal density for the instance.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 3.5))

# Calculate the PDF with 'true_mean' and 'true_scale' parameters
true_y_grid = np.linspace(-1.25, 1.25, 1000)
true_densities = scipy.stats.norm.pdf(true_y_grid, loc=true_mean, scale=true_scale)

# Build one explainer per model and draw it on its own panel. Everything is
# done through the explicit Axes object: `plt.legend()` would act on pyplot's
# *current* axes, which is not guaranteed to be the panel just drawn, so the
# "True" curve could end up missing from that panel's legend.
# `drn_explainer` is rebound on each pass and, as before, ends up referring to
# the explainer of `drn_everything`.
for ax, model in ((ax1, drn_dv_large_penalty), (ax2, drn_everything)):
    drn_explainer = DRNExplainer(
        model,
        glm,
        model.cutpoints,
        x_train_raw,
        cat_features,
        all_categories,
        ct,
    )
    drn_explainer.plot_adjustment_factors(
        instance=instance,
        num_interpolations=1_000,
        plot_adjustments_labels=False,
        axes=ax,
        x_range=(-1.25, 1.25),
        y_range=(0, 2),
        plot_title="",
        plot_y_label="$f(y|\\boldsymbol{X}=(0.5, 0.5)^{\\top})$",
    )

    # Reference density, drawn behind the model curves (zorder=-1).
    ax.plot(true_y_grid, true_densities, color="red", label="True", lw=2, zorder=-1)
    # Rebuild this panel's legend so that it includes the "True" entry.
    ax.legend()

# Remove the top/right spines on every axes of this figure.
sns.despine(fig=fig)

fig.tight_layout()
fig.savefig(PLOT_DIR / "DRN DV Penalty.png");