# Sensitivity analysis with GP regression model

Now that we are familiar with the data and have a trained GP regression model, we can proceed to the actual sensitivity analysis.

## Dependencies

As in the previous notebooks, we start by importing all dependencies.

If you are in Colab, you need to install the [pyro](https://pyro.ai/) package by uncommenting and running the line `!pip3 install pyro-ppl` below before proceeding.

In [None]:
# install dependencies
# !pip3 install pyro-ppl

In [None]:
# imports
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import numpy as np
import pandas as pd
import torch
import pyro
import pyro.contrib.gp as gp

pyro.set_rng_seed(0)
print(f"torch version: {torch.__version__}")
print(f"pyro version: {pyro.__version__}")

## Load the dataset and model parameters

We can load the dataset directly from the GitHub URL.
Alternatively, the dataset can be loaded from a local file.

In [None]:
# load dataset
dataset_path = "https://raw.githubusercontent.com/BIG-MAP/sensitivity_analysis_tutorial/main/data/p2d_sei_10k.csv"
# dataset_path = "data/p2d_sei_10k.csv"  # local
df = pd.read_csv(dataset_path, index_col=0)

# store the names of the features and the name of the target variable
features = df.columns[:15].tolist()  # use input parameters as features
target = "SEI_thickness(m)"  # primary target
# target = "Capacity loss (%)"  # secondary target

We also need to load the trained model parameters that we saved in the previous notebook. 

If you are running this notebook in Colab, you need to make the parameter file available in the working directory by uploading it in the Files section on the left.

In [None]:
pyro.clear_param_store()

if target == "SEI_thickness(m)":
    pyro.get_param_store().load("sgpr_params_sei.p")
if target == "Capacity loss (%)":
    pyro.get_param_store().load("sgpr_params_cap.p")

params = pyro.get_param_store()
params.keys()

## Set up the model

Set up the model with the trained parameters.

In [None]:
kernel = gp.kernels.RBF(input_dim=params["data.x_train"].shape[1], variance=params["kernel.variance"], lengthscale=params["kernel.lengthscale"])
model = gp.models.SparseGPRegression(params["data.x_train"], params["data.y_train"], kernel, Xu=params["Xu"], noise=params["noise"])

## Global sensitivity analysis

Here we compute the average sensitivity of each input parameter $j$ using the validation dataset.
The sensitivity is the root mean square, over the $N$ data points, of the partial derivative of the predicted output $f(\mathbf{x}_n)$ with respect to the input $x_{n,j}$:

$$
s_j^f = \sqrt{ \frac{1}{N} \sum_{n=1}^N \Big( \frac{\partial f(\mathbf{x}_n)}{\partial x_{n,j}} \Big)^2 }
$$
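
As a sanity check of the formula, the following sketch applies it to a toy function with known derivatives instead of the GP model. The function `f` and the random inputs are assumptions made purely for illustration.

```python
import torch

torch.manual_seed(0)

def f(X):
    # toy function: depends strongly on x0, weakly on x1, and not at all on x2
    return 3.0 * X[:, 0] + 0.5 * X[:, 1] ** 2

# N = 1000 random inputs with D = 3 features, uniform in [0, 1)
X = torch.rand(1000, 3, requires_grad=True)

# each output depends only on its own row of X, so the gradient of the
# summed outputs contains the per-sample partial derivatives
grad = torch.autograd.grad(f(X).sum(), X)[0]  # shape (N, D)

# root mean square of the partial derivatives over the N samples
s = torch.sqrt(torch.mean(grad**2, dim=0))  # shape (D,)
print(s)
```

The first entry equals 3 exactly (constant derivative), the second is close to $\sqrt{1/3}$ (the derivative is $x_1$ itself), and the third is 0 because the output never depends on $x_2$.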

In [None]:
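def sa_autograd(model, X, reduce=None):
    """Sensitivity analysis of GP regression model with automatic differentiation.

    Args:
        model: Gaussian process regression model
        X (tensor): Input data (design matrix) of shape (N, D)
        reduce (string): Reduction over the data points: "sum", "mean" or None (no reduction).

    Returns:
        Tuple (mean, var, s_mean, s_var): the predicted mean and variance and the
        sensitivities of each. The sensitivities have shape (D,) if reduced, else (N, D).
    """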
    X.requires_grad = True
    # compute gradient of the mean prediction
    model.zero_grad()
    mean, _ = model(X, full_cov=False, noiseless=False)
    gmean = torch.autograd.grad(mean.sum(), X)[0]
    # compute gradient of the variance prediction
    model.zero_grad()
    _, var = model(X, full_cov=False, noiseless=False)
    gvar = torch.autograd.grad(var.sum(), X)[0]
    X.requires_grad = False
    if reduce == "sum":
        return mean, var, torch.sqrt(torch.sum(gmean**2, dim=0)), torch.sqrt(torch.sum(gvar**2, dim=0))
    elif reduce == "mean":
        return mean, var, torch.sqrt(torch.mean(gmean**2, dim=0)), torch.sqrt(torch.mean(gvar**2, dim=0))
    else:
        return mean, var, torch.sqrt(gmean**2), torch.sqrt(gvar**2)

In [None]:
def plot_sensitivity_bar(s_mean, s_var, features=None, normalise=False):
    features = list(range(len(s_mean))) if features is None else features
    
    # normalise
    if normalise:
        s_mean = s_mean / s_mean.sum()
        s_var = s_var / s_var.sum()

    plt.figure(figsize=(6,3))
    plt.title("average sensitivities of the mean prediction")
    plt.bar(range(len(features)), s_mean)
    plt.xticks(range(len(features)), [f"x{i}: {f}" for i,f in enumerate(features)], rotation=90)
    plt.xlabel("Feature"); plt.ylabel("Sensitivity")
    plt.show()

    plt.figure(figsize=(6,3))
    plt.title("average sensitivities of the variance prediction")
    plt.bar(range(len(features)), s_var, color="C1")
    plt.xticks(range(len(features)), [f"x{i}: {f}" for i,f in enumerate(features)], rotation=90)
    plt.xlabel("Feature"); plt.ylabel("Sensitivity")
    plt.show()

In [None]:
_, _, s_mean, s_var = sa_autograd(model, params["data.x_valid"], reduce="mean")

plot_sensitivity_bar(s_mean, s_var, features, normalise=True)

The sensitivities are normalised so they sum to 1 as we are mainly interested in the relative sensitivities.

Notice how only a few of the input parameters seem to have high average sensitivity and thus be important.

If you made note of any particular input parameters during the initial data exploration, how do they compare to the sensitivities?
Do the inputs you noticed correspond to the most important inputs found by the sensitivity analysis?

If you did the optional analysis of the Bayesian linear model, how do the results compare?

Here we used the validation dataset to compute the sensitivities. 
We could also have sampled new inputs in the appropriate range and used that for the sensitivity analysis (since we do not need to know the true outputs in this analysis). 
However, since we know the validation data is sampled at random, we would expect to get very similar results.
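
A minimal sketch of that alternative, assuming each input is sampled independently and uniformly between its lower and upper bound. The helper `sample_uniform_inputs` and the bounds below are hypothetical. In this notebook the bounds would presumably come from `params["data.x_min"]` and `params["data.x_max"]`.

```python
import torch

def sample_uniform_inputs(x_min, x_max, n_samples, seed=0):
    """Draw n_samples inputs uniformly at random within [x_min, x_max] per feature."""
    generator = torch.Generator().manual_seed(seed)
    u = torch.rand(n_samples, len(x_min), generator=generator)  # uniform in [0, 1)
    return x_min + (x_max - x_min) * u

# hypothetical bounds for three inputs
x_min = torch.tensor([0.0, -1.0, 10.0])
x_max = torch.tensor([1.0, 1.0, 20.0])

X_new = sample_uniform_inputs(x_min, x_max, n_samples=500)
print(X_new.shape)
```

The sampled inputs could then be passed to `sa_autograd` in place of the validation inputs, provided they have the same dtype and scaling as the data the model was trained on.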

If you are familiar with automatic relevance determination (ARD), you can try to compute feature importances based on ARD, defined as the inverse of the kernel length scale parameters (available in `params["kernel.lengthscale"]`), and compare the result with the global sensitivity analysis above.
Note that [ARD has been shown to overestimate the importance of nonlinear features](http://proceedings.mlr.press/v89/paananen19a/paananen19a.pdf).
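
The ARD computation could look roughly like the sketch below. The length scale values are made up for illustration, and the sketch assumes the kernel was trained with one length scale per input.

```python
import numpy as np

# hypothetical length scales, one per input (stand-in for the trained values)
lengthscale = np.array([0.5, 2.0, 10.0, 1.0])

# ARD relevance: a short length scale means the function varies quickly along that input
relevance = 1.0 / lengthscale
relevance = relevance / relevance.sum()  # normalise so the relevances sum to 1

# rank the inputs from most to least relevant
ranking = np.argsort(-relevance)
print(relevance)
print(ranking)
```

In the notebook, the hypothetical array would be replaced by something like `params["kernel.lengthscale"].detach().numpy()`, and the normalised relevances can be compared with the normalised `s_mean` values in a bar plot.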

## Local sensitivity analysis

Looking at the sensitivities averaged over the data is useful for identifying the most important inputs.
But we might get a better understanding of the data by considering the predictions and sensitivities along the entire range of variation of each input (while keeping all other inputs fixed at their nominal values).
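
The construction of such a one-at-a-time sweep can be sketched on a toy function with a known derivative. The function `f`, the nominal point and the sweep range are assumptions chosen for illustration. The helper functions in this notebook apply the same idea to the GP model.

```python
import torch

def f(X):
    # toy function: nonlinear in x0, linear in x1
    return torch.sin(X[:, 0]) + 0.1 * X[:, 1]

x_nominal = torch.tensor([0.0, 0.5])  # hypothetical nominal point
d, n_points = 0, 5  # sweep input 0 over 5 grid points

# copy the nominal point n_points times, then vary only column d
X = x_nominal.repeat(n_points, 1)
X[:, d] = torch.linspace(-1.0, 1.0, n_points)
X.requires_grad_(True)

# absolute partial derivative with respect to input d at each grid point
y = f(X)
local_sens = torch.autograd.grad(y.sum(), X)[0][:, d].abs()
print(local_sens)
```

For this toy function the local sensitivity along input 0 is $|\cos(x_0)|$, so it peaks at the nominal value $x_0 = 0$ and decreases towards the ends of the range.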

In [None]:
# helper functions
def predict_sa(x):
    return sa_autograd(model, x, reduce=None)

def predict_and_plot_1d(d, predict_sa, features, target, x_min, x_max, x_nominal, y_lim=None, n_points=100, figsize=(12,3)):
    # create inputs
    x = x_nominal
    X = x.repeat(n_points, 1)
    xd = torch.linspace(x_min[d], x_max[d], n_points)
    X[:,d] = xd
    # predict point
    mean0, var0, s_mean0, s_var0 = predict_sa(x.unsqueeze(0))
    mean0, var0, s_mean0, s_var0 = mean0.detach(), var0.detach(), s_mean0.detach(), s_var0.detach()
    std0 = var0.sqrt()
    # predict grid
    mean, var, s_mean, s_var = predict_sa(X)
    mean, var, s_mean, s_var = mean.detach(), var.detach(), s_mean.detach(), s_var.detach()
    std = var.sqrt().detach()
    # plot
    plt.figure(figsize=figsize)
    # plot mean prediction with uncertainty
    plt.subplot(121)
    plt.title("mean prediction with uncertainty (2*std)")
    plt.plot(xd.numpy(), mean.numpy())
    plt.fill_between(xd.numpy(), (mean.numpy() - 2.0 * std.numpy()), (mean.numpy() + 2.0 * std.numpy()), color='C0', alpha=0.3)
    plt.axvline(x[d].numpy(), color="k", linewidth=1, label=f"{mean0.item():.4f} ({std0.item():.4f})")
    plt.xlim((x_min[d], x_max[d]))
    if y_lim is not None:
        plt.ylim(y_lim)
    plt.xlabel(f"x{d}: {features[d]}")
    plt.ylabel(f"log y: {target}")
    plt.grid()
    plt.legend(loc=4)
    # plot sensitivity of mean prediction
    plt.subplot(122)
    plt.title("sensitivity of mean prediction")
    plt.plot(xd.numpy(), s_mean[:, d].numpy())
    plt.axvline(x[d].numpy(), color="k", linewidth=1, label=f"{s_mean0[:,d].item():.4f}")
    plt.xlim((x_min[d], x_max[d]))
    plt.ylim((0,5))
    plt.xlabel(f"x{d}: {features[d]}")
    plt.ylabel("sensitivity")
    plt.grid()
    plt.legend(loc=4)

In [None]:
for d in range(len(features)):
    predict_and_plot_1d(
        d,
        predict_sa,
        features,
        target,
        x_min=params["data.x_min"].detach().numpy(),
        x_max=params["data.x_max"].detach().numpy(),
        x_nominal=params["data.x_train"][0].detach(),  # the first training point corresponds to the nominal values
        y_lim=(params["data.y_min"].item(), params["data.y_max"].item()),
    )

Some of the prediction curves are almost entirely flat because changing the value of the corresponding input barely changes the output.
These correspond to the inputs with low average sensitivity that we identified above.

You may also notice that some inputs affect the output along their entire range, while others only affect it within a specific range of values (for example only at high or low values).

For each of the important inputs, try to characterise the effect they have on the output:
 * Is it linear or nonlinear?
 * Is it sensitive along its entire range of values or not?

Rather than looking at the inputs in one dimension, we can also plot two inputs against each other in two dimensions.

In [None]:
def predict_and_plot_2d(d0, d1, predict_sa, features, target, x_min, x_max, x_nominal, y_lim=None, n_points=100, n_levels=21, figsize=(12,10)):
    # create inputs
    x = x_nominal
    X = x.repeat(n_points**2, 1)
    # setup grid
    xd0 = torch.linspace(x_min[d0], x_max[d0], n_points)
    xd1 = torch.linspace(x_min[d1], x_max[d1], n_points)
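    # "ij" is the historical default; passing it explicitly avoids a warning (requires torch >= 1.10)
    grid_xd0, grid_xd1 = torch.meshgrid(xd0, xd1, indexing="ij")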
    X[:,d0] = grid_xd0.reshape(len(X))
    X[:,d1] = grid_xd1.reshape(len(X))
    # predict point
    mean0, var0, s_mean0, s_var0 = predict_sa(x.unsqueeze(0))
    mean0, var0, s_mean0, s_var0 = mean0.detach(), var0.detach(), s_mean0.detach(), s_var0.detach()
    std0 = var0.sqrt()
    # predict grid
    mean, var, s_mean, s_var = predict_sa(X)
    mean, var, s_mean, s_var = mean.detach(), var.detach(), s_mean.detach(), s_var.detach()
    std = var.sqrt()

    s_mean0_d = (s_mean0[:, d0] + s_mean0[:, d1]).item()
    s_var0_d = (s_var0[:, d0] + s_var0[:, d1]).item()

    s_mean_d = (s_mean[:, d0] + s_mean[:, d1]).reshape(n_points, n_points)
    s_var_d = (s_var[:, d0] + s_var[:, d1]).reshape(n_points, n_points)

    plt.figure(figsize=figsize)
    # plot mean prediction
    ax = plt.subplot(221)
    plt.title("mean prediction of log y")
    if y_lim is None:
        levels = torch.linspace(mean.min().item(), mean.max().item(), n_levels).numpy()
    else:
        levels = torch.linspace(y_lim[0], y_lim[1], n_levels).numpy()
    plt.contourf(grid_xd0.numpy(), grid_xd1.numpy(), mean.reshape(n_points, n_points).numpy(), levels=levels, cmap="plasma")
    plt.axvline(x[d0].numpy(), color="k", linewidth=1, label=f"{mean0.item():.4f} ({std0.item():.4f})")
    plt.axhline(x[d1].numpy(), color="k", linewidth=1)
    plt.xlabel(f"x{d0}: {features[d0]}"); plt.ylabel(f"x{d1}: {features[d1]}")
    plt.colorbar(shrink=0.9)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%6.2f'))
    plt.legend(loc=4)
    # plot uncertainty
    ax = plt.subplot(222)
    plt.title("uncertainty (2*std)")
    levels = torch.linspace(0, 1.0, 21).numpy()
    plt.contourf(grid_xd0.numpy(), grid_xd1.numpy(), 2*std.reshape(n_points, n_points).numpy(), levels=levels, cmap="plasma")
    plt.axvline(x[d0].numpy(), color="k", linewidth=1, label=f"{std0.item()*2:.4f}")
    plt.axhline(x[d1].numpy(), color="k", linewidth=1)
    plt.xlabel(f"x{d0}: {features[d0]}"); plt.ylabel(f"x{d1}: {features[d1]}")
    plt.colorbar(shrink=0.9)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%6.2f'))
    plt.legend(loc=4)
    # plot sensitivity of mean prediction
    ax = plt.subplot(223)
    plt.title("sensitivity of mean prediction")
    levels = torch.linspace(0, 5.0, 21).numpy()
    plt.contourf(grid_xd0.numpy(), grid_xd1.numpy(), s_mean_d.numpy(), levels=levels, cmap="plasma")
    plt.axvline(x[d0].numpy(), color="k", linewidth=1, label=f"{s_mean0_d:.4f}")
    plt.axhline(x[d1].numpy(), color="k", linewidth=1)
    plt.xlabel(f"x{d0}: {features[d0]}"); plt.ylabel(f"x{d1}: {features[d1]}")
    plt.colorbar(shrink=0.9)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%6.2f'))
    plt.legend(loc=4)
    # plot sensitivity of uncertainty prediction
    ax = plt.subplot(224)
    plt.title("sensitivity of uncertainty prediction")
    levels = torch.linspace(0, 0.25, 21).numpy()
    plt.contourf(grid_xd0.numpy(), grid_xd1.numpy(), s_var_d.numpy(), levels=levels, cmap="plasma")
    plt.axvline(x[d0].numpy(), color="k", linewidth=1, label=f"{s_var0_d:.4f}")
    plt.axhline(x[d1].numpy(), color="k", linewidth=1)
    plt.xlabel(f"x{d0}: {features[d0]}"); plt.ylabel(f"x{d1}: {features[d1]}")
    plt.colorbar(shrink=0.9)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%6.2f'))
    plt.legend(loc=4)
    plt.tight_layout()
    plt.show()

In [None]:
predict_and_plot_2d(
    0, 2,  # <-- change the input dimensions that are plotted here
    predict_sa,
    features,
    target,
    x_min=params["data.x_min"].detach().numpy(),
    x_max=params["data.x_max"].detach().numpy(),
    x_nominal=params["data.x_train"][0].detach(),  # the first training point corresponds to the nominal values
    y_lim=(params["data.y_min"].item(), params["data.y_max"].item()),
)

Here we plotted input 0 against input 2.
You can change the inputs that are plotted in the code above.
How about for example inputs 8 and 9?

These figures can reveal interesting properties of the data.
However, even when plotting two inputs against each other along their entire ranges of values, we still need to assume fixed values for all the other inputs.
But sensitive inputs may interact: the effect of one sensitive input can depend on the values of the others.
Unfortunately, it is difficult to visualise such effects for high-dimensional problems like this one.
In the next notebooks we will try to mitigate this and make exploring the results of the sensitivity analysis more intuitive by creating interactive plots.

As always, we should be aware of the assumptions we made in the analysis and keep them in mind when interpreting the results.
* The validity of the results depends on how well the model fits the data.
* In this example we performed the analysis with respect to the log-transformed output. Care should be taken when back-transforming the results to the original scale: the transformation is nonlinear, so the predictive distribution is no longer Gaussian.
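
To illustrate the last point, the sketch below summarises a back-transformed prediction. It assumes the natural logarithm was used and that the Gaussian prediction applies directly to $\log y$ with no further scaling. In that case $y$ follows a log-normal distribution, whose mean differs from its median. The helper `lognormal_summary` and the numbers are hypothetical.

```python
import math

def lognormal_summary(mu, sigma):
    """Summaries of y = exp(z), where z ~ Normal(mu, sigma^2)."""
    median = math.exp(mu)
    mean = math.exp(mu + 0.5 * sigma**2)
    # central 95% interval: exp is monotonic, so quantiles map directly
    lower = math.exp(mu - 1.96 * sigma)
    upper = math.exp(mu + 1.96 * sigma)
    return median, mean, (lower, upper)

# hypothetical predicted mean and standard deviation of log y
median, mean, interval = lognormal_summary(mu=0.0, sigma=0.5)
print(median, mean, interval)
```

Exponentiating the predicted mean gives the median of $y$, not its mean, and the interval on the original scale is asymmetric around both.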