In [None]:
%%capture output
!pip install lightning-uq-box
!pip install SALib
!pip install uncertainty-toolbox

# Install library for learning deep UQ baselines.
!git clone https://github.com/uncertainty-toolbox/simple-uq
!pip install -e ./simple-uq
%mv simple-uq/simple_uq .

In [None]:
from collections.abc import Callable

import gpytorch
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap

import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm
from lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer, OneHotEncoder, LabelEncoder, StandardScaler
from torch.utils.data import DataLoader, TensorDataset

# from .utils import collate_fn_tensordataset

import math

def convert_float64(X):
    """Return ``X`` cast to ``np.float64`` (used as a FunctionTransformer step).

    Works for anything exposing ``.astype`` (numpy arrays, pandas objects); the container
    type is preserved.
    """
    float_dtype = np.float64
    return X.astype(float_dtype)

import os
import tempfile
from functools import partial

import torch.nn as nn
from laplace import Laplace
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import EarlyStopping

from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import MVERegression, DeterministicRegression, LaplaceRegression, NLL, BNN_VI_ELBO_Regression, BNN_VI_Regression
from lightning_uq_box.viz_utils import (
    plot_calibration_uq_toolbox,
    plot_predictions_regression,
    plot_toy_regression_data,
    plot_training_metrics,
)

from SALib.sample import saltelli
from SALib.analyze import sobol
import seaborn as sns

import uncertainty_toolbox as uct

plt.rcParams["figure.figsize"] = [14, 5]

%load_ext autoreload
%autoreload 2

In [None]:
# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.

"""Utility functions for datamodules."""

def collate_fn_tensordataset(batch):
    """Turn a list of ``(input, target)`` tensor pairs into the batch dict the framework expects.

    Args:
        batch: Sequence of pairs, e.g. as yielded by ``TensorDataset``.

    Returns:
        dict with ``"input"`` and ``"target"`` keys, each holding the per-sample tensors
        stacked along a new leading dimension.
    """
    return {
        key: torch.stack([sample[position] for sample in batch])
        for key, position in (("input", 0), ("target", 1))
    }

In [None]:
# Mount Google Drive so the St. Petersburg property CSV can be read below.
from google.colab import drive

mount_point = '/content/drive'
drive.mount(mount_point)

## Data Processing

In [None]:
st_pete_property_df = pd.read_csv('/content/drive/MyDrive/why_people_stay/zillow_parcel_census_coast_flood_st_pete.csv', index_col=0)

In [None]:
st_pete_property_df.columns

In [None]:
st_pete_property_df = st_pete_property_df[['zpid', 'streetAddress', 'zipcode', 'city', 'state',
                                           'latitude', 'longitude', 'price', 'bathrooms', 'bedrooms', 'livingArea', 'POOL',
                                           'EVAC_ZONE', 'GEOID', 'hhinc_k',
                                           'flood_risk', 'MC_std', 'dist2coast']].dropna().reset_index(drop=True)

In [None]:
len(st_pete_property_df)

In [None]:
st_pete_property_df['lg_price'] = np.log10(st_pete_property_df.price)
st_pete_property_df['dist2coast_km'] = st_pete_property_df['dist2coast']/1000

In [None]:
st_pete_property_df.EVAC_ZONE.value_counts()

In [None]:
def merge_zone(value):
    """Collapse an evacuation-zone label into 'A', 'NON EVAC' or 'Other'.

    'A' and 'NON EVAC' are kept as-is; every other value (including missing ones)
    becomes 'Other'.
    """
    for kept_label in ('A', 'NON EVAC'):
        if value == kept_label:
            return kept_label
    return 'Other'

st_pete_property_df['EVAC_MERGE'] = st_pete_property_df.EVAC_ZONE.apply(merge_zone)

In [None]:
other_df, test_df = train_test_split(st_pete_property_df, test_size=0.1, random_state=42)
train_df, val_df = train_test_split(other_df, test_size=0.1, random_state=42)

In [None]:
cols = ['lg_price', 'bathrooms', 'livingArea', 'POOL', 'hhinc_k', 'flood_risk', 'dist2coast_km', 'latitude', 'longitude']

In [None]:
train_mod = train_df[cols]
val_mod = val_df[cols]
test_mod = test_df[cols]

In [None]:
numerical_columns = ['bathrooms', 'livingArea', 'hhinc_k', 'flood_risk', 'dist2coast_km', 'latitude', 'longitude']
numerical_pipeline = make_pipeline(
    FunctionTransformer(func=convert_float64, validate=False),
    StandardScaler()
)

categorical_columns = ['POOL']
categorical_pipeline = make_pipeline(
    OneHotEncoder(categories="auto"),
)

preprocessor = ColumnTransformer(
    [
        ("numerical_preprocessing", numerical_pipeline, numerical_columns),
        ("categorical_preprocessing", categorical_pipeline, categorical_columns),
    ],
)

In [None]:
train_mod.columns[1:]

In [None]:
X_train = train_mod[train_mod.columns[1:]]
Y_train = train_mod[train_mod.columns[0]]
X_val = val_mod[val_mod.columns[1:]]
Y_val = val_mod[val_mod.columns[0]]
X_test = test_mod[test_mod.columns[1:]]
Y_test = test_mod[test_mod.columns[0]]

In [None]:
# Fit the preprocessor on the training features so the one-hot levels are known.
preprocessor.fit(X_train)

# Numerical columns keep their names and order through the scaler.
numerical_feature_names = numerical_columns

# One-hot column names come from the fitted OneHotEncoder step.
categorical_feature_names = (
    preprocessor.named_transformers_['categorical_preprocessing']
    .named_steps['onehotencoder']
    .get_feature_names_out(categorical_columns)
)

# Every column the ColumnTransformer emits: numerical first, then one-hot.
transformed_feature_names = np.concatenate([numerical_feature_names, categorical_feature_names])

# CustomMultivariateDatamodule removes column -2 of the transformed features
# (np.delete(..., -2, axis=1)), so the networks -- and therefore the Sobol problem and the
# sensitivity plots -- see one column fewer. Keep the names aligned with the model inputs.
all_feature_names = np.delete(transformed_feature_names, -2)

print("Feature Names After Transformation:")
print(transformed_feature_names)
print("Model Input Feature Names (after the datamodule's column drop):")
print(all_feature_names)


In [None]:
# def generate_multivariate_y(x):
#     """Custom function to generate dependent variable with noise."""
#     noise = np.random.normal(scale=0.1, size=(x.shape[0],))  # Add some noise
#     y = 3 * x[:, 0] + 2 * x[:, 1] - x[:, 2] + np.sin(x[:, 3]) - 0.5 * x[:, 4] ** 2 + 0.1 * x[:, 5] + noise
#     return y


# class CustomMultivariateDatamodule(LightningDataModule):
#     """Implement Dataset with 7 independent variables and 1 dependent variable."""

#     def __init__(
#         self,
#         n_points: int = 500,
#         batch_size: int = 100,
#         test_fraction: float = 0.1,
#         val_fraction: float = 0.1,
#         calib_fraction: float = 0.4,
#         noise_seed: int = 42,
#         split_seed: int = 42,
#     ) -> None:
#         """Define a multivariate regression dataset.

#         Split `n_points` data points into train, validation, and test sets.

#         Args:
#             n_points: Number of data points to generate.
#             batch_size: Batch size for data loader.
#             test_fraction: Fraction of n_points for test set.
#             val_fraction: Fraction of n_points for validation set.
#             calib_fraction: Fraction of n_points for calibration set.
#             noise_seed: Random seed for data generation.
#             split_seed: Random seed for train/test/val split.
#         """
#         super().__init__()

#         np.random.seed(noise_seed)
#         self.batch_size = batch_size

#         # Generate independent variables (X) and dependent variable (Y)
#         x = np.random.uniform(-5, 5, size=(n_points, 7))
#         y = generate_multivariate_y(x)

#         # full dataset
#         self.X_all = x
#         self.Y_all = y[:, None]  # Make Y a 2D array

#         # Split data into train and held-out IID test
#         X_other, self.X_test, Y_other, self.Y_test = train_test_split(
#             self.X_all, self.Y_all, test_size=test_fraction, random_state=split_seed
#         )

#         # Split train data into train and validation
#         self.X_train, self.X_val, self.Y_train, self.Y_val = train_test_split(
#             X_other,
#             Y_other,
#             test_size=val_fraction / (1 - test_fraction),
#             random_state=split_seed,
#         )

#         # Split validation data into validation and calibration (for conformal)
#         self.X_val, self.X_calib, self.Y_val, self.Y_calib = train_test_split(
#             self.X_val, self.Y_val, test_size=calib_fraction, random_state=split_seed
#         )

#         # Fit scalers on train data
#         scalers = dict(
#             X=StandardScaler().fit(self.X_train), Y=StandardScaler().fit(self.Y_train)
#         )

#         # Apply scaling to all splits, convert to torch tensors
#         for xy in ["X", "Y"]:
#             for arr_type in ["train", "test", "val", "calib", "all"]:
#                 arr_name = f"{xy}_{arr_type}"
#                 setattr(
#                     self,
#                     arr_name,
#                     self._n2t(scalers[xy].transform(getattr(self, arr_name))),
#                 )

#     @staticmethod
#     def _n2t(x):
#         return torch.from_numpy(x).type(torch.float32)

#     def train_dataloader(self) -> DataLoader:
#         """Return train dataloader."""
#         return DataLoader(
#             TensorDataset(self.X_train, self.Y_train),
#             batch_size=self.batch_size,
#             shuffle=True,
#             collate_fn=collate_fn_tensordataset,
#         )

#     def val_dataloader(self) -> DataLoader:
#         """Return val dataloader."""
#         return DataLoader(
#             TensorDataset(self.X_val, self.Y_val),
#             batch_size=self.batch_size,
#             collate_fn=collate_fn_tensordataset,
#         )

#     def calib_dataloader(self) -> DataLoader:
#         """Return calibration dataloader."""
#         return DataLoader(
#             TensorDataset(self.X_calib, self.Y_calib),
#             batch_size=self.batch_size,
#             collate_fn=collate_fn_tensordataset,
#         )

#     def test_dataloader(self) -> DataLoader:
#         """Return test dataloader."""
#         return DataLoader(
#             TensorDataset(self.X_test, self.Y_test),
#             batch_size=self.batch_size,
#             collate_fn=collate_fn_tensordataset,
#         )

In [None]:
class CustomMultivariateDatamodule(LightningDataModule):
    """Serve the pre-split St. Petersburg property data as tensors for lightning-uq-box.

    The feature/target splits prepared earlier in the notebook are wrapped as-is, and the
    validation split is further divided into validation and calibration parts. Features
    go through the column-transformer preprocessor and targets through a
    ``StandardScaler``; both are fitted on the training split only. Afterwards every split
    is a float32 tensor attribute (``X_train``, ``Y_train``, ..., ``X_calib``, ``Y_calib``).
    """

    def __init__(
        self,
        batch_size: int = 100,
        test_fraction: float = 0.1,
        val_fraction: float = 0.1,
        calib_fraction: float = 0.4,
        split_seed: int = 42,
        x_train=None,
        y_train=None,
        x_val=None,
        y_val=None,
        x_test=None,
        y_test=None,
        feature_preprocessor=None,
        drop_column_index=-2,
    ) -> None:
        """Prepare, scale and tensorize all splits.

        Args:
            batch_size: Batch size for every data loader.
            test_fraction: Unused; kept for signature compatibility (the test split is
                supplied, not generated here).
            val_fraction: Unused; kept for signature compatibility.
            calib_fraction: Fraction of the validation split moved into the calibration
                split (for conformal methods).
            split_seed: Random seed for the validation/calibration split.
            x_train, y_train, x_val, y_val, x_test, y_test: Feature frames and target
                vectors per split. ``None`` (default) falls back to the notebook globals
                ``X_train``, ``Y_train``, ``X_val``, ``Y_val``, ``X_test``, ``Y_test``.
                Pass them explicitly when re-instantiating after later cells have rebound
                those globals to tensors.
            feature_preprocessor: sklearn transformer for the features, fitted here on the
                training split; ``None`` uses the notebook-level ``preprocessor``.
            drop_column_index: Column removed from the transformed features with
                ``np.delete`` (default ``-2``; presumably one of the two POOL one-hot
                columns -- TODO confirm POOL is binary). ``None`` keeps every column.
        """
        super().__init__()
        self.batch_size = batch_size

        # Only read the notebook globals for inputs that were not passed explicitly.
        if x_train is None:
            x_train = X_train
        if y_train is None:
            y_train = Y_train
        if x_val is None:
            x_val = X_val
        if y_val is None:
            y_val = Y_val
        if x_test is None:
            x_test = X_test
        if y_test is None:
            y_test = Y_test
        if feature_preprocessor is None:
            feature_preprocessor = preprocessor

        self.X_train = x_train
        self.Y_train = self._as_column(y_train)
        self.X_val = x_val
        self.Y_val = self._as_column(y_val)
        self.X_test = x_test
        self.Y_test = self._as_column(y_test)

        # Split validation data into validation and calibration (for conformal)
        self.X_val, self.X_calib, self.Y_val, self.Y_calib = train_test_split(
            self.X_val, self.Y_val, test_size=calib_fraction, random_state=split_seed
        )

        # Fit on the training split only so no val/test statistics leak into the scaling.
        # Kept on the instance so predictions can be mapped back to target units.
        self.scalers = dict(
            X=feature_preprocessor.fit(self.X_train),
            Y=StandardScaler().fit(self.Y_train),
        )

        # Apply scaling to all splits, convert to torch tensors
        for arr_type in ["train", "test", "val", "calib"]:
            features = self.scalers["X"].transform(getattr(self, f"X_{arr_type}"))
            if drop_column_index is not None:
                features = np.delete(features, drop_column_index, axis=1)
            setattr(self, f"X_{arr_type}", self._n2t(features))

            targets = self.scalers["Y"].transform(getattr(self, f"Y_{arr_type}"))
            setattr(self, f"Y_{arr_type}", self._n2t(targets))

    @staticmethod
    def _as_column(y):
        """Return ``y`` as a 2-D ``(n, 1)`` numpy array (StandardScaler needs 2-D input)."""
        return np.asarray(y).reshape(-1, 1)

    @staticmethod
    def _n2t(x):
        """Convert a numpy array to a float32 tensor."""
        return torch.from_numpy(x).type(torch.float32)

    def _loader(self, x, y, shuffle: bool = False) -> DataLoader:
        """Wrap a feature/target tensor pair in a loader yielding {'input', 'target'} dicts."""
        return DataLoader(
            TensorDataset(x, y),
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn_tensordataset,
        )

    def train_dataloader(self) -> DataLoader:
        """Return train dataloader."""
        assert isinstance(self.X_train, torch.Tensor), "X_train is not a Tensor"
        assert isinstance(self.Y_train, torch.Tensor), "Y_train is not a Tensor"
        assert self.X_train.size(0) == self.Y_train.size(0), "Size mismatch in train data"
        return self._loader(self.X_train, self.Y_train, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return val dataloader."""
        return self._loader(self.X_val, self.Y_val)

    def calib_dataloader(self) -> DataLoader:
        """Return calibration dataloader."""
        return self._loader(self.X_calib, self.Y_calib)

    def test_dataloader(self) -> DataLoader:
        """Return test dataloader."""
        return self._loader(self.X_test, self.Y_test)

In [None]:
dm = CustomMultivariateDatamodule()

# Pull the transformed tensors and loaders out of the datamodule. This rebinds the
# notebook-level X_train / Y_train / X_test / Y_test (pandas objects until now) to tensors.
X_train, Y_train = dm.X_train, dm.Y_train
X_test, Y_test = dm.X_test, dm.Y_test
train_loader = dm.train_dataloader()
test_loader = dm.test_dataloader()

In [None]:
"""Perform Variance-Based Sensitivity Analysis using Sobol sampling and analysis."""
problem = {
    'num_vars': X_train.shape[1],
    'names': all_feature_names.tolist(),
    'bounds': [[X_train[:,i].numpy().min(), X_train[:,i].numpy().max()] for i in range(5)]
}

param_values = saltelli.sample(problem, 1000)

In [None]:
def enforce_one_hot_last_n_columns(samples, n, threshold=0.5):
    """
    Snap the last `n` columns of every row back to a valid categorical encoding.

    Saltelli sampling draws these columns as continuous values, which the model never
    saw during training.

    - ``n >= 2``: the columns form a full one-hot group; the largest value becomes 1 and
      the rest 0 (ties go to the first column, as with ``np.argmax``).
    - ``n == 1``: a single dummy column (the reference level was dropped), whose valid
      values are 0 and 1; it becomes 1 where the sample is ``>= threshold``. Taking the
      argmax of a single column would set every row to 1 and make the feature constant.
    - ``n <= 0``: nothing is categorical; a float copy is returned unchanged.

    Args:
        samples (numpy.ndarray): 2-D array of shape (n_samples, n_columns).
        n (int): The number of trailing columns to treat as categorical.
        threshold (float): Cut-off used when ``n == 1`` (default 0.5, the midpoint of 0/1).

    Returns:
        numpy.ndarray: A new float array; the input is not modified.
    """
    result = np.array(samples, dtype=float, copy=True)
    if n <= 0 or result.size == 0:
        return result

    categorical = result[:, -n:]
    if n == 1:
        result[:, -1] = (categorical[:, 0] >= threshold).astype(float)
    else:
        winners = np.argmax(categorical, axis=1)
        one_hot = np.zeros_like(categorical)
        one_hot[np.arange(one_hot.shape[0]), winners] = 1.0
        result[:, -n:] = one_hot
    return result

# Example usage:
# Assume param_values is already generated and has mixed columns
param_values = enforce_one_hot_last_n_columns(param_values, n=1)

# Verify the enforcement (each row should have exactly one '1' in the last 8 columns)
print(param_values[:5])


In [None]:
param_values = torch.from_numpy(param_values).type(torch.float32)

In [None]:
param_values.shape

## Mean Variance Estimation

In [None]:
network = MLP(n_inputs=13, n_hidden=[50, 50, 50], n_outputs=2, activation_fn=nn.Tanh())
network

In [None]:
mve_model = MVERegression(
    model=network, optimizer=partial(torch.optim.Adam, lr=1e-3), burnin_epochs=5
)

In [None]:
my_temp_dir = '/content/mve'

In [None]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    max_epochs=250,  # number of epochs we want to train
    accelerator="cpu",  # use distributed training
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)

In [None]:
trainer.fit(mve_model, dm)

In [None]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE","trainR2","trainMAE"]
)

In [None]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["val_loss", "valRMSE","valR2","valMAE"]
)

In [None]:
preds = mve_model.predict_step(X_gtext)

fig = plot_predictions_regression(
    X_train[:,3],
    Y_train,
    X_gtext[:,3],
    Y_gtext,
    preds["pred"],
    preds["pred_uct"].squeeze(-1),
    aleatoric=preds["aleatoric_uct"],
    title="Mean Variance Estimation Network",
    show_bands=False,
)

In [None]:
preds = mve_model.predict_step(X_test)

fig = plot_calibration_uq_toolbox(
    preds["pred"].cpu().numpy(),
    preds["pred_uct"].cpu().numpy(),
    Y_test.cpu().numpy(),
    X_test[:,3].cpu().numpy(),
)

## Gaussian Process Regression

In [None]:
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP regressor: constant prior mean and a scaled RBF covariance."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # Prior mean: one learnable constant.
        self.mean_module = gpytorch.means.ConstantMean()
        # Prior covariance: a learnable output scale times an RBF kernel
        # (a single shared lengthscale, since no ard_num_dims is given).
        rbf = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Return the multivariate-normal prior over the latent function at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

In [None]:
# Gaussian observation-noise likelihood; its noise level is learned with the kernel.
likelihood = gpytorch.likelihoods.GaussianLikelihood()

# The exact GP conditions on the whole training set; targets are squeezed from (N, 1).
gp_model = ExactGPModel(X_train.squeeze(), Y_train.squeeze(), likelihood)

# Training mode: hyperparameters are fitted by maximizing the marginal likelihood.
gp_model.train()
likelihood.train()

# gp_model.parameters() already includes the GaussianLikelihood parameters.
optimizer = torch.optim.Adam(gp_model.parameters(), lr=1e-2)

# "Loss" for GPs: the negative exact marginal log likelihood.
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp_model)

progress = tqdm(range(100))
for _ in progress:
    optimizer.zero_grad()
    loss = -mll(gp_model(X_train.squeeze()), Y_train.squeeze())
    loss.backward()
    optimizer.step()
    progress.set_postfix(loss=f"{loss.detach().cpu().item()}")

In [None]:
gp_model.eval()
likelihood.eval()

with torch.no_grad():
    gp_preds = gp_model(X_gtext.cpu())

gp_mean = gp_preds.mean.detach().cpu().numpy()
gp_var = gp_preds.variance.detach().cpu().numpy()
gp_covar = gp_preds.covariance_matrix.detach().cpu().numpy()

In [None]:
fig = plot_predictions_regression(
    X_train[:,3],
    Y_train,
    X_gtext[:,3],
    Y_gtext,
    gp_mean[:, None],
    np.sqrt(gp_var),
    epistemic=np.sqrt(gp_var),
    title="Gaussian Process",
    show_bands=False,
)

In [None]:
with torch.no_grad():
    gp_preds = gp_model(X_test.cpu())

gp_mean = gp_preds.mean.detach().cpu().numpy()
gp_var = gp_preds.variance.detach().cpu().numpy()
gp_covar = gp_preds.covariance_matrix.detach().cpu().numpy()

fig = plot_calibration_uq_toolbox(
    gp_mean, np.sqrt(gp_var), Y_test.cpu().numpy(), X_test[:,3].cpu().numpy()
)

## Laplace Approximation

In [None]:
network = MLP(n_inputs=13, n_hidden=[50, 50], n_outputs=1, activation_fn=nn.Tanh())
network

In [None]:
deterministic_model = DeterministicRegression(
    model=network,
    optimizer=partial(torch.optim.Adam, lr=1e-2),
    loss_fn=torch.nn.MSELoss(),
)

In [None]:
my_temp_dir = '/content/laplace'

In [None]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    max_epochs=100,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
)

In [None]:
trainer.fit(deterministic_model, dm)

In [None]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE","trainR2","trainMAE"]
)

In [None]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["val_loss", "valRMSE","valR2","valMAE"]
)

In [None]:
la = Laplace(
    deterministic_model.model,
    "regression",
    subset_of_weights="last_layer",
    hessian_structure="full",
    sigma_noise=0.4,
)


laplace_model = LaplaceRegression(laplace_model=la, tune_prior_precision=True)

trainer = Trainer(default_root_dir=my_temp_dir)

In [None]:
trainer.test(laplace_model, dm)

In [None]:
preds = laplace_model.predict_step(X_test)

fig = plot_predictions_regression(
    X_train[:,3],
    Y_train,
    X_test[:,3],
    Y_test,
    preds["pred"],
    preds["pred_uct"],
    epistemic=preds["epistemic_uct"],
    aleatoric=preds["aleatoric_uct"],
    title="Laplace Approximation",
)

In [None]:
preds = laplace_model.predict_step(X_test)
fig = plot_calibration_uq_toolbox(
    preds["pred"].cpu().numpy(),
    preds["pred_uct"].numpy(),
    Y_test.cpu().numpy(),
    X_test[:,3].cpu().numpy(),
)

In [None]:
sample_preds = laplace_model.predict_step(param_values)

In [None]:
Si = sobol.analyze(problem, sample_preds["pred"].cpu().numpy())

### Plot Variance Based Sensitivity Analysis

In [None]:
def plot_sensitivity_indices(Si, feature_names):
    """Draw side-by-side bar charts of the Sobol first-order (S1) and total (ST) indices.

    Args:
        Si: Result of SALib ``sobol.analyze``; its 'S1' and 'ST' entries are plotted.
        feature_names: Tick labels, one per input variable.
    """
    positions = range(len(Si['S1']))
    panels = [
        (Si['S1'], 'First-order Sensitivity Indices', 'S1'),
        (Si['ST'], 'Total Sensitivity Indices', 'ST'),
    ]

    plt.figure(figsize=(12, 6))
    for panel_number, (values, panel_title, panel_ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_number)
        plt.bar(positions, values)
        plt.xticks(positions, feature_names, rotation=90)
        plt.title(panel_title)
        plt.ylabel(panel_ylabel)

    plt.tight_layout()
    plt.show()

In [None]:
import random

def plot_feature_variance_and_interactions(X, y, Si, feature_names, sample_fraction=0.5, top_n=None, seed=None):
    """
    For every input feature, plot the model output against that feature and the feature's
    second-order Sobol interactions with all other features.

    Args:
        X: Sample matrix (n_samples, n_features); numpy array or CPU tensor.
        y: Model outputs for the rows of X; flattened, so a trailing (N, 1) axis is fine.
        Si (dict): SALib ``sobol.analyze`` result; its 'S2' entry is read.
        feature_names (list): One name per column of X.
        sample_fraction (float): Fraction of rows shown in the scatter plots, drawn without
            replacement to keep them readable (default is 50%).
        top_n (int | None): Show only the ``top_n`` strongest interactions per feature;
            ``None`` (default) shows all of them.
        seed (int | None): Seed for the row subsampling; ``None`` draws a fresh subset.
    """
    X = np.asarray(X)
    y = np.asarray(y).reshape(-1)
    s2 = np.asarray(Si['S2'])

    # One shared subsample for all features, so every scatter shows the same rows.
    total_samples = X.shape[0]
    sample_size = min(total_samples, max(1, int(sample_fraction * total_samples)))
    rng = np.random.default_rng(seed)
    sample_indices = rng.choice(total_samples, size=sample_size, replace=False)

    # pyplot.get_cmap is stable across matplotlib versions (matplotlib.cm.get_cmap was removed).
    cmap = plt.get_cmap('coolwarm_r')  # '_r' inverts the colormap

    for i, feature in enumerate(feature_names):
        fig, ax = plt.subplots(1, 2, figsize=(14, 6))

        # Left: model output against this feature for the subsampled rows.
        sns.scatterplot(x=X[sample_indices, i], y=y[sample_indices], ax=ax[0])
        ax[0].set_title(f'Variance Contribution of {feature}', fontsize=14)
        ax[0].set_xlabel(feature, fontsize=12)
        ax[0].set_ylabel('lg_price', fontsize=12)
        ax[0].tick_params(axis='x', rotation=45)
        ax[0].tick_params(axis='y', labelsize=10)

        # Right: second-order interaction strength with every other feature.
        interaction_strengths = []
        for j, other_feature in enumerate(feature_names):
            if i == j:
                continue
            if s2.ndim > 1:
                # SALib only fills the upper triangle (row < column) of S2; the lower
                # triangle and diagonal are NaN, so look each pair up in canonical order.
                strength = float(np.nan_to_num(s2[min(i, j), max(i, j)]))
            else:
                strength = 0.0
            interaction_strengths.append((other_feature, strength))

        interaction_strengths.sort(key=lambda item: abs(item[1]), reverse=True)
        top_interactions = interaction_strengths if top_n is None else interaction_strengths[:top_n]

        interaction_features = [name for name, _ in top_interactions]
        interaction_values = [value for _, value in top_interactions]
        n_bars = len(interaction_features)
        colors = [cmap(k / n_bars) for k in range(n_bars)]

        ax[1].bar(interaction_features, interaction_values, color=colors)
        ax[1].set_title(f'Strongest Interactions with {feature}', fontsize=14)
        ax[1].set_xlabel('Feature', fontsize=12)
        ax[1].set_ylabel('Interaction Strength', fontsize=12)
        ax[1].tick_params(axis='x', rotation=45, labelsize=10)
        ax[1].tick_params(axis='y', labelsize=10)

        plt.tight_layout()
        plt.show()


### Sensitivity Analysis Results (Laplace Approximation)

In [None]:
plot_sensitivity_indices(Si, all_feature_names)

In [None]:
plot_feature_variance_and_interactions(param_values, sample_preds["pred"].cpu().numpy(), Si, all_feature_names, sample_fraction=0.1)

## Bayes By Backprop - Mean Field Variational Inference

In [None]:
network = MLP(n_inputs=13, n_hidden=[50, 50], n_outputs=2, activation_fn=nn.ReLU())
network

In [None]:
bbp_model = BNN_VI_ELBO_Regression(
    network,
    optimizer=partial(torch.optim.Adam, lr=3e-3),
    criterion=NLL(),
    stochastic_module_names=[-1],
    num_mc_samples_train=10,
    num_mc_samples_test=25,
    burnin_epochs=20,
)

In [None]:
my_temp_dir = '/content/bbp'

In [None]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    max_epochs=150,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=20,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)

In [None]:
trainer.fit(bbp_model, dm)

In [None]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE","trainR2","trainMAE"]
)

In [None]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["val_loss", "valRMSE","valR2","valMAE"]
)

In [None]:
preds = bbp_model.predict_step(X_gtext)

fig = plot_predictions_regression(
    X_train[:,3],
    Y_train,
    X_gtext[:,3],
    Y_gtext,
    preds["pred"].squeeze(-1),
    preds["pred_uct"],
    epistemic=preds["epistemic_uct"],
    aleatoric=preds["aleatoric_uct"],
    title="Bayes By Backprop MFVI",
    show_bands=False,
)

In [None]:
preds = bbp_model.predict_step(X_test)
fig = plot_calibration_uq_toolbox(
    preds["pred"].cpu().numpy(),
    preds["pred_uct"].cpu().numpy(),
    Y_test.cpu().numpy(),
    X_test[:,3].cpu().numpy(),
)

## Bayesian Neural Network with Variational Inference and Energy Loss

In [None]:
network = MLP(n_inputs=8, n_hidden=[50, 50], n_outputs=1, activation_fn=nn.Tanh())
network

In [None]:
bnn_vi_model = BNN_VI_Regression(
    network,
    optimizer=partial(torch.optim.Adam, lr=1e-2),
    n_mc_samples_train=10,
    n_mc_samples_test=50,
    output_noise_scale=1.3,
    prior_mu=0.0,
    prior_sigma=1.0,
    posterior_mu_init=0.0,
    posterior_rho_init=-6.0,
    alpha=1e-03,
    bayesian_layer_type="reparameterization",
    stochastic_module_names=[-1],
)

In [None]:
my_temp_dir = '/content/bnn_vi'

In [None]:
logger = CSVLogger(my_temp_dir)

# Define early stopping callback
early_stop_callback = EarlyStopping(
    monitor='valRMSE',  # You can change this to another metric like 'val_accuracy'
    patience=10,  # Number of epochs with no improvement before stopping
    verbose=False,  # Print when early stopping is triggered
    mode='min',  # 'min' means we are looking for a decrease in the metric (e.g., val_loss)
)

trainer = Trainer(
    max_epochs=200,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=True,
    enable_progress_bar=False,
    limit_val_batches=1.0,  # full validation runs
    default_root_dir=my_temp_dir,
    callbacks=[early_stop_callback],
)

In [None]:
trainer.fit(bnn_vi_model, dm)

In [None]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE","trainR2","trainMAE"]
)

In [None]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["val_loss", "valRMSE","valR2","valMAE"]
)

In [None]:
preds = bnn_vi_model.predict_step(X_test)

fig = plot_predictions_regression(
    X_train[:,3],
    Y_train,
    X_test[:,3],
    Y_test,
    preds["pred"],
    preds["pred_uct"],
    epistemic=preds["epistemic_uct"],
    show_bands=False,
    title="Bayesian NN with Alpha Divergence Loss",
)

In [None]:
preds = bnn_vi_model.predict_step(X_test)
fig = plot_calibration_uq_toolbox(
    preds["pred"].cpu().numpy(),
    preds["pred_uct"].cpu().numpy(),
    Y_test.cpu().numpy(),
    X_test[:,3].cpu().numpy(),
)

In [None]:
pred_mean = preds["pred"].cpu().numpy().squeeze()
pred_std = preds["pred_uct"].cpu().numpy()
te_y = np.squeeze(Y_test.cpu().numpy())

In [None]:
# Plot adversarial group calibration
uct.viz.plot_adversarial_group_calibration(preds["pred"].cpu().numpy().squeeze(), preds["pred_uct"].cpu().numpy(), np.squeeze(Y_test.cpu().numpy()))

In [None]:
uct.metrics.get_all_metrics(preds["pred"].cpu().numpy().squeeze(), preds["pred_uct"].cpu().numpy(), np.squeeze(Y_test.cpu().numpy()))

In [None]:
test_df['pred_mean'] = pred_mean
test_df['pred_std'] = pred_std
test_df['residual'] = pred_mean - te_y

In [None]:
test_df.to_csv('test_df_flood_with_loc.csv', index=False)

In [None]:
sample_preds = bnn_vi_model.predict_step(param_values)

In [None]:
Si = sobol.analyze(problem, sample_preds["pred"].cpu().numpy().squeeze())

In [None]:
plot_sensitivity_indices(Si, all_feature_names)

In [None]:
plot_feature_variance_and_interactions(param_values, sample_preds["pred"].cpu().numpy().squeeze(), Si, all_feature_names, sample_fraction=0.1)