# BART model

### Import the necessary libraries

In [None]:
from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
from sklearn.preprocessing import StandardScaler
import math
from scipy.stats import norm


from sklearn.model_selection import train_test_split

# IPython magic: set the inline matplotlib backend's figure format to "retina".
%config InlineBackend.figure_format = "retina"

# Record the PyMC version in the notebook output.
print(f"Running on PyMC v{pm.__version__}")

# Apply the "arviz-darkgrid" style to the figures.
az.style.use("arviz-darkgrid")

### Prepare the data

In [None]:
# Load the two input tables.
df_SGA_env = pd.read_csv('env_info.csv')
df_SGA_gag = pd.read_csv('gag_info.csv')

# Metadata columns present in both tables; they double as the merge keys below.
exclude_cols = ['ID', 'TSI', 'multiple', 'art', 'art_start_days']

# Map every env column named "m<digits>" to its numeric suffix, then split the
# columns at 7186: strictly below goes to `cols_inf`, the rest to `cols_sup`.
env_suffix = {
    col: int(col[1:])
    for col in df_SGA_env.columns
    if col.startswith("m") and col[1:].isdigit()
}
cols_inf = [col for col, suffix in env_suffix.items() if suffix < 7186]
cols_sup = [col for col, suffix in env_suffix.items() if suffix >= 7186]

# Row-wise means: over every non-metadata column for gag, and over each of the
# two column groups for env.
gag_features = df_SGA_gag.drop(columns=exclude_cols)
df_SGA_gag['mean_gag'] = gag_features.mean(axis=1)
df_SGA_env['mean_env_inf'] = df_SGA_env[cols_inf].mean(axis=1)
df_SGA_env['mean_env_sup'] = df_SGA_env[cols_sup].mean(axis=1)

# Keep only the metadata plus the summary columns of each table.
df_SGA_gag_small = df_SGA_gag[exclude_cols + ['mean_gag']]
df_SGA_env_small = df_SGA_env[exclude_cols + ['mean_env_inf', 'mean_env_sup']]

# Inner-join the two summaries on the metadata columns and replace missing
# values with 0.
df_SGA = df_SGA_gag_small.merge(df_SGA_env_small, on=exclude_cols, how='inner')
df_SGA = df_SGA.fillna(0)
df_SGA.head()

In [None]:
# Rows whose (ID, TSI) pair is listed in id_test.csv form the test set; every
# other row goes to the training set.
df_ID = pd.read_csv('id_test.csv')
df_SGA_filter = df_SGA.merge(df_ID, on=['ID', 'TSI'], how='left', indicator=True)

# '_merge' == 'left_only' marks rows with no match in df_ID.
is_train = df_SGA_filter['_merge'] == 'left_only'
df_SGA_train = df_SGA_filter.loc[is_train].drop(columns=['_merge'])
df_SGA_test = df_SGA_filter.loc[~is_train].drop(columns=['_merge'])

# Predictor columns, in the order used for both design matrices.
feature_cols = ['multiple', 'art', 'art_start_days', 'mean_gag', 'mean_env_inf', 'mean_env_sup']

X_train = df_SGA_train[feature_cols].to_numpy()
Y_train = df_SGA_train['TSI'].to_numpy()

X_test = df_SGA_test[feature_cols].to_numpy()
Y_test = df_SGA_test['TSI'].to_numpy()

df_SGA_test

### Model

In [None]:
# BART regression fitted on the square-root scale of TSI, with a Normal
# likelihood on the original scale.
with pm.Model() as SQRT_regression:
    # Register the design matrix as mutable data so that later cells can swap
    # it with pm.set_data.
    X_shared = pm.MutableData("X", X_train)
    # Prior
    # BART is given sqrt(Y_train) as its response, so `sqrt_mu` lives on the
    # square-root scale.
    sqrt_mu = pmb.BART("sqrt_mu", X=X_shared, Y=np.sqrt(Y_train), m=200, beta=2.5)
    # Squaring maps the BART output back to the scale of Y_train.
    mu = pm.Deterministic("mu", sqrt_mu**2)
    sigma = pm.Exponential("sigma", 1)
    # Likelihood
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=Y_train)

    trace = pm.sample(tune=500, draws=1500, random_seed=4242, target_accept=0.85)  # target_accept=0.95
    # Draw y_obs and mu for the data currently in the model (X_train); with
    # extend_inferencedata=True the result is also attached to `trace`.
    posterior_predictive_oos_regression_train = pm.sample_posterior_predictive(
        trace, var_names=["y_obs", "mu"],
        extend_inferencedata=True, predictions=True, random_seed=123
    )

### Convergence diagnostic

In [None]:
# Convergence diagnostics for the `mu` variable, using pymc-bart's plotting helper.
pmb.plot_convergence(trace, var_name="mu")

In [None]:
# Trace plots for the variables stored in `trace`. ArviZ is called directly,
# as in the other plotting cells, rather than through PyMC's re-export.
az.plot_trace(trace)

In [None]:
# Pair plot (scatter panels plus marginals) of sigma and five selected entries of mu.
mu_indices = [11, 22, 33, 44, 55]
az.plot_pair(
    trace,
    var_names=["mu", "sigma"],
    coords={"mu_dim_0": mu_indices},
    kind="scatter",
    marginals=True,
)
plt.show()

### Plot the predictions on the train dataset

In [None]:
with SQRT_regression:
    pm.set_data({"X": X_train})
    # Only `mu` is read below, so it is the only variable requested.
    posterior_predictive_train = pm.sample_posterior_predictive(
        trace, var_names=["mu"], random_seed=123, predictions=True
    )


y_pred_train_samples = posterior_predictive_train.predictions["mu"].values  # (chains, draws, n_obs_train)

# Posterior mean and 95% highest-density interval (HDI) of mu per observation.
y_pred_train_mean = y_pred_train_samples.mean(axis=(0, 1))
# The 3-D array is passed as is: ArviZ reads it as (chain, draw, n_obs), whereas
# a flattened 2-D array triggers a FutureWarning about its interpretation.
# The result has shape (n_obs_train, 2) and is transposed to (2, n_obs_train).
y_pred_train_hdi = az.hdi(y_pred_train_samples, hdi_prob=0.95).T

# Keep only the points whose true and predicted TSI are both at most 1000.
mask = (Y_train <= 1000) & (y_pred_train_mean <= 1000)
Y_train_filt = Y_train[mask]
y_pred_train_mean_filt = y_pred_train_mean[mask]
y_pred_train_hdi_filt = y_pred_train_hdi[:, mask]

# Plot predicted mean against true value, with the HDI as asymmetric error bars.
plt.figure(figsize=(7, 7))
plt.scatter(Y_train_filt, y_pred_train_mean_filt, alpha=0.6, label="Predicted mean (train)")
plt.errorbar(
    Y_train_filt,
    y_pred_train_mean_filt,
    # errorbar expects distances from the point, not the interval bounds themselves.
    yerr=[y_pred_train_mean_filt - y_pred_train_hdi_filt[0],
          y_pred_train_hdi_filt[1] - y_pred_train_mean_filt],
    fmt='o', alpha=0.3, color="grey", capsize=2
)
# Identity line: points on it are predicted exactly.
plt.plot([Y_train_filt.min(), Y_train_filt.max()],
         [Y_train_filt.min(), Y_train_filt.max()],
         "r--", lw=2, label="Perfect prediction")

plt.xlabel("True values (Train)")
# The bars are posterior HDIs, not frequentist confidence intervals.
plt.ylabel("Predicted TSI with 95% HDI")
plt.legend()
plt.title("")
plt.show()

### Variable importance plot

In [None]:
# Display names for the six predictors, in the column order of X_train.
feature_names = ['Multiple', 'Art', 'Art start days', 'pdist gag', 'pdist env inf', 'pdist env sup']
X_train_df = pd.DataFrame(X_train, columns=feature_names)
# Variable importance of the BART component, computed on the training predictors.
vi_results = pmb.compute_variable_importance(trace, sqrt_mu, X_train_df)


ax = pmb.plot_variable_importance(vi_results, figsize=(7.9,4.5))
# Rotate the tick labels so the feature names do not overlap.
plt.xticks(rotation=30, fontsize=9)

plt.tight_layout()
plt.show()

### Evaluation on the test dataset

In [None]:
# Swap the test predictors into the model and draw `mu` (and `sigma`) for them.
with SQRT_regression:
    pm.set_data({"X": X_test})  # Replace the data
    # NOTE(review): `sigma` is a free variable already present in `trace`;
    # listing it in var_names may cause it to be re-drawn from its prior rather
    # than read from the posterior — confirm against the PyMC documentation.
    posterior_predictive_test = pm.sample_posterior_predictive(
        trace, var_names=["mu", "sigma"], random_seed=123, predictions=True
    )

In [None]:
from scipy.special import logsumexp
from sklearn.metrics import r2_score, mean_squared_error

# Posterior draws of the mean TSI for every test observation.
mu_samples = posterior_predictive_test.predictions["mu"].values  # (chains, draws, n_obs_test)
y_pred_samples = mu_samples
y_pred_mean = y_pred_samples.mean(axis=(0, 1))

# Point-prediction accuracy of the posterior mean.
r2 = r2_score(Y_test, y_pred_mean)
rmse = np.sqrt(mean_squared_error(Y_test, y_pred_mean))

print("R²:", r2)
print("RMSE:", rmse)

# Posterior draws of sigma, with a trailing axis so they broadcast against mu.
sigma_samples = trace.posterior["sigma"].values  # (chains, draws)
sigma_samples = sigma_samples[..., None]

# Log-density of each test observation under each posterior draw.
log_likelihood = norm.logpdf(Y_test[None, None, :], loc=mu_samples, scale=sigma_samples)

# Log pointwise predictive density: for each observation, average the density
# (not the log-density) over all draws, take the log, then sum over
# observations. logsumexp performs the average in log space for stability.
n_draws = log_likelihood.shape[0] * log_likelihood.shape[1]
lppd_pointwise = logsumexp(log_likelihood, axis=(0, 1)) - np.log(n_draws)
elpd = lppd_pointwise.sum()

print("ELPD:", elpd)

### Predictions on the test dataset

In [None]:
# Posterior draws of the mean for the test set, from the predictions computed
# for X_test.
mu_samples = posterior_predictive_test.predictions["mu"].values  # (chains, draws, n_obs_test)
# Posterior draws of sigma are read from the trace, as in the evaluation cell.
sigma_samples = trace.posterior["sigma"].values  # (chains, draws)

# Broadcast sigma to match mu
sigma_samples = sigma_samples[..., None]  # (chains, draws, 1)

# Simulate y_obs manually: one Normal draw per posterior sample and observation.
rng = np.random.default_rng(4242)
y_pred_ppc = rng.normal(mu_samples, sigma_samples)

# Predictive mean
y_pred_mean = y_pred_ppc.mean(axis=(0, 1))

# 95% predictive interval, computed as the 2.5th and 97.5th percentiles
# (an equal-tailed interval, despite the `hdi` in the variable name).
y_pred_hdi = np.percentile(y_pred_ppc, [2.5, 97.5], axis=(0, 1))

In [None]:
# Distances from the predictive mean to each interval bound, which is the form
# errorbar expects for asymmetric bars.
lower_err = y_pred_mean - y_pred_hdi[0]
upper_err = y_pred_hdi[1] - y_pred_mean
# End points of the identity line, spanning the observed test values.
tsi_range = [Y_test.min(), Y_test.max()]

plt.figure(figsize=(6,6))
plt.scatter(Y_test, y_pred_mean, alpha=0.6, label="Predicted mean")

# Predictive intervals drawn as error bars around each predicted mean.
plt.errorbar(
    Y_test,
    y_pred_mean,
    yerr=[lower_err, upper_err],
    fmt="o",
    alpha=0.3,
    color="gray",
)

# Points on this line are predicted exactly.
plt.plot(tsi_range, tsi_range, "r--", lw=2, label="Perfect prediction")

plt.xlabel("True TSI")
plt.ylabel("Predicted TSI")
plt.legend()
plt.title("")
plt.show()