# Interactive sensitivity analysis with GP regression model

In this notebook we create an interactive version of the sensitivity analysis we developed in the previous step of the tutorial.
You can use this to explore the predictions and sensitivities when varying the inputs to the model.
In particular, you can see how the behaviour of certain important inputs is affected when changing the values of other important inputs.

This notebook contains a lot of code; however, most of it is for creating the interactive plots. The code used for the sensitivity analysis is the same as we have seen earlier in the tutorial.

## Dependencies

As in the previous notebooks, we start by importing all dependencies.

If you are in Colab, you need to install the [pyro](https://pyro.ai/) package by uncommenting and running the line `!pip3 install pyro-ppl` below before proceeding.

In [None]:
# install dependencies
# !pip3 install pyro-ppl

In [None]:
# imports
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import numpy as np
import pandas as pd
import torch
import pyro
import pyro.contrib.gp as gp
import ipywidgets as widgets
from matplotlib.ticker import FormatStrFormatter

pyro.set_rng_seed(0)  # seed the random number generators so sampled inputs are reproducible
for lib_name, lib in (("torch", torch), ("pyro", pyro)):
    print(f"{lib_name} version: {lib.__version__}")

## Load the dataset and model parameters

We can load the dataset directly from the GitHub URL.
Alternatively, the dataset can be loaded from a local file.

In [None]:
# load dataset (from GitHub by default; swap in the local path to work offline)
dataset_path = "https://raw.githubusercontent.com/BIG-MAP/sensitivity_analysis_tutorial/main/data/p2d_sei_10k.csv"
# dataset_path = "data/p2d_sei_10k.csv"  # local
df = pd.read_csv(dataset_path, index_col=0)

# the first 15 columns hold the input parameters used as features
features = list(df.columns[:15])
target = "SEI_thickness(m)"  # primary target
# target = "Capacity loss (%)"  # secondary target
# target = "Capacity loss (%)"  # secondary target

We also need to load the trained model parameters that we saved previously. 

If you are running this notebook in Colab, you need to make the parameter file (`sgpr_params_sei.p` or `sgpr_params_cap.p`, depending on the target) available in the working directory by uploading it via the Files panel on the left.

In [None]:
# the trained parameters are stored in one file per target
param_files = {
    "SEI_thickness(m)": "sgpr_params_sei.p",
    "Capacity loss (%)": "sgpr_params_cap.p",
}
if target not in param_files:
    # fail here rather than continuing with an empty parameter store (which would only
    # surface later as a KeyError when the model is built)
    raise ValueError(f"no saved model parameters for target {target!r}; expected one of {list(param_files)}")

pyro.clear_param_store()
# NOTE(review): ParamStoreDict.load deserializes via torch.load (pickle-based) - only load trusted files
pyro.get_param_store().load(param_files[target])

params = pyro.get_param_store()
params.keys()

## Set up the model

Set up the model with the trained parameters.

In [None]:
# RBF kernel initialised with the trained hyperparameters; one input dimension per column of x_train
kernel = gp.kernels.RBF(
    input_dim=params["data.x_train"].shape[1],
    variance=params["kernel.variance"],
    lengthscale=params["kernel.lengthscale"],
)
# sparse GP regression on the stored training data, inducing points and noise level
model = gp.models.SparseGPRegression(
    params["data.x_train"],
    params["data.y_train"],
    kernel,
    Xu=params["Xu"],
    noise=params["noise"],
)

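## Interactive sensitivity analysis

We first define a helper that returns the GP predictions together with their sensitivities: the magnitudes of the gradients of the predicted mean and variance with respect to each input, computed with automatic differentiation. The interactive plots below are built on top of it.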

In [None]:
def sa_autograd(model, X, reduce=None):
    """Sensitivity analysis of GP regression model with automatic differentiation.

    The sensitivity of a prediction to an input is the magnitude of the gradient of
    that prediction (predictive mean or predictive variance) with respect to the input.

    Args:
        model: Gaussian process regression model, called as
            ``model(X, full_cov=False, noiseless=False)`` and returning ``(mean, var)``.
        X (tensor): Input data (design matrix), one row per input point. It is not
            modified; gradients are taken with respect to a detached copy.
        reduce (string): method used to reduce the sensitivity over the rows of ``X``:
            ``"sum"`` (square root of the summed squared gradients), ``"mean"``
            (root mean square of the gradients), or ``None`` / ``"none"`` (no
            reduction; element-wise absolute gradients).

    Returns:
        tuple: ``(mean, var, s_mean, s_var)``. ``mean`` and ``var`` are the model
        predictions (still attached to the autograd graph); ``s_mean`` and ``s_var``
        are the sensitivities of the mean and variance predictions. Without reduction
        they have the shape of ``X``; with reduction, one entry per column of ``X``.

    Raises:
        ValueError: if ``reduce`` is not one of the supported values.
    """
    if reduce not in (None, "none", "sum", "mean"):
        raise ValueError(f"reduce must be 'sum', 'mean', 'none' or None, got {reduce!r}")
    # differentiate w.r.t. a fresh leaf so the caller's tensor (and its requires_grad flag)
    # is left untouched; this also works when X is itself a non-leaf tensor
    X = X.detach().requires_grad_(True)
    model.zero_grad()
    # a single forward pass yields both outputs; keep the graph alive for the second gradient
    mean, var = model(X, full_cov=False, noiseless=False)
    gmean = torch.autograd.grad(mean.sum(), X, retain_graph=True)[0]
    gvar = torch.autograd.grad(var.sum(), X)[0]
    if reduce == "sum":
        return mean, var, gmean.pow(2).sum(dim=0).sqrt(), gvar.pow(2).sum(dim=0).sqrt()
    if reduce == "mean":
        return mean, var, gmean.pow(2).mean(dim=0).sqrt(), gvar.pow(2).mean(dim=0).sqrt()
    return mean, var, gmean.abs(), gvar.abs()

In [None]:
def predict_sa(x, reduce=None):
    """Run ``sa_autograd`` on the notebook's trained GP ``model`` at inputs ``x``.

    Binding the model here lets the plotting callbacks pass only the inputs and the
    reduction mode; the result is exactly what ``sa_autograd`` returns.
    """
    return sa_autograd(model=model, X=x, reduce=reduce)

### Interactive sensitivity analysis: global

Here you can experiment with how the average sensitivities are affected by the range we consider for each input.

In [None]:
def create_predict_and_plot_global(predict_sa, features, n_sample=1000, normalise=False, figsize=(12,6), normalised_ylim=(0, 0.8)):
    """Build the slider callback for the global (range-averaged) sensitivity plot.

    Args:
        predict_sa: callable ``predict_sa(X, reduce=...)`` returning
            ``(mean, var, s_mean, s_var)``.
        features (list of str): input feature names, in column order.
        n_sample (int): number of input points drawn uniformly within the selected
            ranges on every update.
        normalise (bool): if True, rescale each sensitivity vector to sum to one.
        figsize (tuple): matplotlib figure size.
        normalised_ylim (tuple): y-axis limits of both panels when ``normalise`` is True.

    Returns:
        callable: ``predict_and_plot(**x_dict)``, where ``x_dict[f]`` is the
        ``(low, high)`` range selected for feature ``f``.
    """
    positions = range(len(features))
    tick_labels = [f"x{i}: {f}" for i, f in enumerate(features)]

    def _bar_panel(ax, values, color, label):
        """Draw one bar chart of per-feature sensitivities on ``ax``."""
        ax.bar(positions, values.numpy(), color=color, label=label)
        if normalise:
            ax.set_ylim(normalised_ylim)
        ax.set_xticks(list(positions))
        ax.set_xticklabels(tick_labels, rotation=90)
        ax.set_xlabel("Input feature")
        ax.set_ylabel("Sensitivity")
        ax.legend()
        ax.grid(axis='y')

    def predict_and_plot(**x_dict):
        x_min = torch.tensor([x_dict[f][0] for f in features])
        x_max = torch.tensor([x_dict[f][1] for f in features])
        # sample inputs uniformly within the currently selected ranges
        X = torch.distributions.Uniform(x_min, x_max).sample((n_sample,))
        # only the aggregated sensitivities are plotted; the predictions themselves are not needed
        _, _, s_mean, s_var = predict_sa(X, reduce="mean")
        s_mean, s_var = s_mean.detach(), s_var.detach()
        if normalise:
            s_mean = s_mean / s_mean.sum()
            s_var = s_var / s_var.sum()
        fig, (ax_mean, ax_var) = plt.subplots(1, 2, figsize=figsize)
        _bar_panel(ax_mean, s_mean, "C0", "s_mean")  # sensitivity of mean prediction
        _bar_panel(ax_var, s_var, "C1", "s_var")  # sensitivity of var prediction
        fig.tight_layout()
        plt.show()
    return predict_and_plot


def interactive_global(on_change, features, x_min, x_max, n_steps=20):
    """Display one range slider per feature beside the live output of ``on_change``.

    Each slider starts at the full ``[x_min[i], x_max[i]]`` range and moves in
    ``n_steps`` increments. ``widgets.interactive_output`` re-runs ``on_change`` with
    the slider values as keyword arguments keyed by feature name.
    """
    def make_range_slider(idx, name):
        lo, hi = x_min[idx], x_max[idx]
        return widgets.FloatRangeSlider(
            value=[lo, hi],
            min=lo,
            max=hi,
            step=(hi - lo) / n_steps,
            description=f"x{idx}: {name}",
            readout_format=".1f",
        )

    sliders = {name: make_range_slider(idx, name) for idx, name in enumerate(features)}
    output = widgets.interactive_output(on_change, sliders)
    display(widgets.HBox([widgets.VBox(list(sliders.values())), output]))


# on change function
on_change_global = create_predict_and_plot_global(
    predict_sa,
    features,
    n_sample=5000,
    normalise=True,
)

# setup ui
interactive_global(
    on_change_global,
    features,
    x_min=params["data.x_min"].detach().numpy(),
    x_max=params["data.x_max"].detach().numpy(),
)

### Interactive sensitivity analysis: 1D

Here you can see how the behaviour of certain important inputs along their range of variation is affected when changing the values of other important inputs.

In [None]:
def create_predict_and_plot_1d(predict_sa, features, target, x_min, x_max, y_lim=None, n_points=100, figsize=(12,7),
                               unc_lim=(0, 1), s_mean_lim=(0, 5), s_var_lim=(0, 0.3)):
    """Build the slider callback for the 1D interactive sensitivity plot.

    The callback sweeps input ``d`` over ``[x_min[d], x_max[d]]`` while holding every
    other input at its slider value. It draws four panels: the mean prediction with
    a +/- 2 std band, the 2 std uncertainty, and the sensitivities of the mean and
    variance predictions to input ``d``. A vertical line marks the current slider
    value of input ``d`` and is labelled with the value at that point.

    Args:
        predict_sa: callable ``predict_sa(X)`` returning ``(mean, var, s_mean, s_var)``
            with unreduced (per-point, per-feature) sensitivities.
        features (list of str): input feature names, in column order.
        target (str): name of the predicted quantity, used in the y label.
        x_min, x_max: per-feature lower/upper bounds of the sweep.
        y_lim (tuple or None): y limits of the mean panel; autoscaled if None.
        n_points (int): number of grid points along the swept input.
        figsize (tuple): matplotlib figure size.
        unc_lim, s_mean_lim, s_var_lim (tuple or None): y limits of the uncertainty,
            mean-sensitivity and variance-sensitivity panels; autoscaled if None.

    Returns:
        callable: ``predict_and_plot(d, **x_dict)``.
    """
    def _panel(ax, xd, curve, x_d, x_lim, val_lim, marker_label, xlabel, ylabel, title, band=None):
        """Draw a curve over the sweep, an optional shaded band and the marker line."""
        ax.set_title(title)
        ax.plot(xd, curve)
        if band is not None:
            ax.fill_between(xd, band[0], band[1], color='C0', alpha=0.3)
        ax.axvline(x_d, color="k", linewidth=1, label=marker_label)
        ax.set_xlim(x_lim)
        if val_lim is not None:
            ax.set_ylim(val_lim)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.grid()
        ax.legend(loc=4)

    def predict_and_plot(d, **x_dict):
        # current slider values of all inputs
        x = torch.tensor([x_dict[f] for f in features])
        # grid along input d; all other inputs stay at their slider values
        xd = torch.linspace(float(x_min[d]), float(x_max[d]), n_points)
        X = x.repeat(n_points, 1)
        X[:, d] = xd
        # predictions and sensitivities at the selected point and along the grid
        mean0, var0, s_mean0, s_var0 = (t.detach() for t in predict_sa(x.unsqueeze(0)))
        mean, var, s_mean, s_var = (t.detach() for t in predict_sa(X))
        std0, std = var0.sqrt(), var.sqrt()

        xd_np, mean_np, std_np = xd.numpy(), mean.numpy(), std.numpy()
        x_d = x[d].item()
        x_lim = (x_min[d], x_max[d])
        xlabel = f"x{d}: {features[d]}"

        fig, axes = plt.subplots(2, 2, figsize=figsize)
        _panel(axes[0, 0], xd_np, mean_np, x_d, x_lim, y_lim,
               f"{mean0.item():.4f} ({std0.item():.4f})", xlabel, f"log y: {target}",
               "mean prediction with uncertainty (2*std)",
               band=(mean_np - 2.0 * std_np, mean_np + 2.0 * std_np))
        _panel(axes[0, 1], xd_np, 2 * std_np, x_d, x_lim, unc_lim,
               f"{2 * std0.item():.4f}", xlabel, "uncertainty",
               "uncertainty prediction (2*std)")
        _panel(axes[1, 0], xd_np, s_mean[:, d].numpy(), x_d, x_lim, s_mean_lim,
               f"{s_mean0[:,d].item():.4f}", xlabel, "sensitivity",
               "sensitivity of mean prediction")
        _panel(axes[1, 1], xd_np, s_var[:, d].numpy(), x_d, x_lim, s_var_lim,
               f"{s_var0[:,d].item():.4f}", xlabel, "sensitivity",
               "sensitivity of uncertainty prediction")
        fig.tight_layout()
        plt.show()
    return predict_and_plot


def interactive_1d(on_change, features, x_min, x_max, x_init, n_steps=20):
    """Display the 1D controls beside the live output of ``on_change``.

    The controls are an integer selector ``d`` for the swept input followed by one
    value slider per feature, starting at ``x_init[i]`` and moving in ``n_steps``
    increments across ``[x_min[i], x_max[i]]``. ``widgets.interactive_output``
    re-runs ``on_change`` with all slider values as keyword arguments.
    """
    def make_value_slider(idx, name):
        lo, hi = x_min[idx], x_max[idx]
        return widgets.FloatSlider(
            value=x_init[idx],
            min=lo,
            max=hi,
            step=(hi - lo) / n_steps,
            description=f"x{idx}: {name}",
            readout_format=".1f",
        )

    sliders = {"d": widgets.IntSlider(value=0, min=0, max=len(features)-1, description="dim")}
    for idx, name in enumerate(features):
        sliders[name] = make_value_slider(idx, name)
    output = widgets.interactive_output(on_change, sliders)
    display(widgets.HBox([widgets.VBox(list(sliders.values())), output]))


# on change function
on_change_1d = create_predict_and_plot_1d(
    predict_sa,
    features,
    target,
    x_min=params["data.x_min"].detach().numpy(),
    x_max=params["data.x_max"].detach().numpy(),
    y_lim=(params["data.y_min"].item(), params["data.y_max"].item()),
)

# setup ui
interactive_1d(
    on_change_1d,
    features,
    x_min=params["data.x_min"].detach().numpy(),
    x_max=params["data.x_max"].detach().numpy(),
    x_init=params["data.x_train"][0],
)

### Interactive sensitivity analysis: 2D

This lets you plot two inputs at a time to see how their combined behaviour is affected when changing the values of other important inputs.

In [None]:
# setup on_change function
def create_predict_and_plot_2d(predict_sa, features, target, x_min, x_max, y_lim=None, n_points=100, n_levels=21, figsize=(12,10),
                               unc_lim=(0, 1.0), s_mean_lim=(0, 5.0), s_var_lim=(0, 0.25)):
    """Build the slider callback for the 2D interactive sensitivity plot.

    The callback sweeps inputs ``d0`` and ``d1`` over an ``n_points x n_points`` grid
    spanning their ``[x_min, x_max]`` ranges, holding every other input at its
    slider value. It draws four filled-contour panels:
    - the mean prediction;
    - the 2 std uncertainty;
    - the combined sensitivity of the mean prediction to the two swept inputs;
    - the combined sensitivity of the variance prediction to the two swept inputs.
    The combined sensitivity is the sum of the two per-input sensitivities.
    Crosshairs mark the current slider values and are labelled with the value at
    that point.

    Args:
        predict_sa: callable ``predict_sa(X)`` returning ``(mean, var, s_mean, s_var)``
            with unreduced (per-point, per-feature) sensitivities.
        features (list of str): input feature names, in column order.
        target (str): name of the predicted quantity (currently unused in the plot).
        x_min, x_max: per-feature lower/upper bounds of the sweep.
        y_lim, unc_lim, s_mean_lim, s_var_lim (tuple or None): contour level ranges of
            the mean, uncertainty, mean-sensitivity and variance-sensitivity panels;
            the data range is used when None.
        n_points (int): grid points per swept input.
        n_levels (int): number of contour levels in every panel.
        figsize (tuple): matplotlib figure size.

    Returns:
        callable: ``predict_and_plot(d0, d1, **x_dict)``. When ``d0 == d1`` it prints
        a message and draws nothing.
    """
    def _levels(values, lim):
        """``n_levels`` contour levels spanning ``lim``, or the data range if ``lim`` is None."""
        lo, hi = (values.min().item(), values.max().item()) if lim is None else lim
        return torch.linspace(lo, hi, n_levels).numpy()

    def _contour_panel(ax, grid0, grid1, values, levels, x0, x1, marker_label, xlabel, ylabel, title):
        """Filled contour of ``values`` with a colorbar and a crosshair at ``(x0, x1)``."""
        ax.set_title(title)
        filled = ax.contourf(grid0, grid1, values, levels=levels, cmap="plasma")
        ax.axvline(x0, color="k", linewidth=1, label=marker_label)
        ax.axhline(x1, color="k", linewidth=1)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.figure.colorbar(filled, ax=ax, shrink=0.9)
        ax.yaxis.set_major_formatter(FormatStrFormatter('%6.2f'))
        ax.legend(loc=4)

    def predict_and_plot(d0, d1, **x_dict):
        if d0 == d1:
            # with equal dimensions the second grid would overwrite the first and the
            # summed sensitivity would count the same input twice
            print("Select two different dimensions: dim 0 and dim 1 are equal.")
            return
        x = torch.tensor([x_dict[f] for f in features])
        xd0 = torch.linspace(float(x_min[d0]), float(x_max[d0]), n_points)
        xd1 = torch.linspace(float(x_min[d1]), float(x_max[d1]), n_points)
        # "ij" indexing: grid_xd0[i, j] = xd0[i], grid_xd1[i, j] = xd1[j]
        grid_xd0, grid_xd1 = torch.meshgrid(xd0, xd1, indexing="ij")
        X = x.repeat(n_points ** 2, 1)
        X[:, d0] = grid_xd0.reshape(-1)
        X[:, d1] = grid_xd1.reshape(-1)
        # predictions and sensitivities at the selected point and on the grid
        mean0, var0, s_mean0, s_var0 = (t.detach() for t in predict_sa(x.unsqueeze(0)))
        mean, var, s_mean, s_var = (t.detach() for t in predict_sa(X))
        std0, std = var0.sqrt(), var.sqrt()

        grid_shape = (n_points, n_points)
        mean_grid = mean.reshape(grid_shape)
        unc_grid = 2 * std.reshape(grid_shape)
        s_mean_d = (s_mean[:, d0] + s_mean[:, d1]).reshape(grid_shape)
        s_var_d = (s_var[:, d0] + s_var[:, d1]).reshape(grid_shape)
        s_mean0_d = (s_mean0[:, d0] + s_mean0[:, d1]).item()
        s_var0_d = (s_var0[:, d0] + s_var0[:, d1]).item()

        g0, g1 = grid_xd0.numpy(), grid_xd1.numpy()
        x0, x1 = x[d0].item(), x[d1].item()
        xlabel, ylabel = f"x{d0}: {features[d0]}", f"x{d1}: {features[d1]}"

        fig, axes = plt.subplots(2, 2, figsize=figsize)
        _contour_panel(axes[0, 0], g0, g1, mean_grid.numpy(), _levels(mean, y_lim), x0, x1,
                       f"{mean0.item():.4f} ({std0.item():.4f})", xlabel, ylabel,
                       "mean prediction of log y")
        _contour_panel(axes[0, 1], g0, g1, unc_grid.numpy(), _levels(unc_grid, unc_lim), x0, x1,
                       f"{std0.item()*2:.4f}", xlabel, ylabel, "uncertainty (2*std)")
        _contour_panel(axes[1, 0], g0, g1, s_mean_d.numpy(), _levels(s_mean_d, s_mean_lim), x0, x1,
                       f"{s_mean0_d:.4f}", xlabel, ylabel, "sensitivity of mean prediction")
        _contour_panel(axes[1, 1], g0, g1, s_var_d.numpy(), _levels(s_var_d, s_var_lim), x0, x1,
                       f"{s_var0_d:.4f}", xlabel, ylabel, "sensitivity of uncertainty prediction")
        fig.tight_layout()
        plt.show()
    return predict_and_plot


def interactive_2d(on_change, features, x_min, x_max, x_init, n_steps=20):
    """Display the 2D controls beside the live output of ``on_change``.

    The controls are two integer selectors ``d0``/``d1`` for the swept inputs
    (defaulting to 0 and 1), followed by one value slider per feature. Each value
    slider starts at ``x_init[i]`` and moves in ``n_steps`` increments across
    ``[x_min[i], x_max[i]]``. ``widgets.interactive_output`` re-runs ``on_change``
    with all slider values as keyword arguments.
    """
    last_dim = len(features) - 1
    sliders = {
        "d0": widgets.IntSlider(value=0, min=0, max=last_dim, description="dim 0"),
        "d1": widgets.IntSlider(value=1, min=0, max=last_dim, description="dim 1"),
    }
    for idx, name in enumerate(features):
        lo, hi = x_min[idx], x_max[idx]
        sliders[name] = widgets.FloatSlider(
            value=x_init[idx],
            min=lo,
            max=hi,
            step=(hi - lo) / n_steps,
            description=f"x{idx}: {name}",
            readout_format=".1f",
        )
    output = widgets.interactive_output(on_change, sliders)
    display(widgets.HBox([widgets.VBox(list(sliders.values())), output]))


# callback: sweeps the two selected inputs and redraws the four contour panels
on_change_2d = create_predict_and_plot_2d(
    predict_sa=predict_sa,
    features=features,
    target=target,
    x_min=params["data.x_min"].detach().numpy(),
    x_max=params["data.x_max"].detach().numpy(),
    y_lim=(params["data.y_min"].item(), params["data.y_max"].item()),
)

# sliders start at the first training point
interactive_2d(
    on_change=on_change_2d,
    features=features,
    x_min=params["data.x_min"].detach().numpy(),
    x_max=params["data.x_max"].detach().numpy(),
    x_init=params["data.x_train"][0],
)