# Interactive sensitivity analysis

## Google Colab

If you are running this notebook in Google Colab, install the [pyro](https://pyro.ai/) package and download the trained model parameters by uncommenting and running the commands in the cells below.

In [None]:
# install dependencies
# !pip3 install pyro-ppl

In [None]:
# download model parameters
# !wget https://github.com/BIG-MAP/sa_p2d_sei_interactive/raw/main/sgpr_params_sei.p
# !wget https://github.com/BIG-MAP/sa_p2d_sei_interactive/raw/main/sgpr_params_icl.p

## Import dependencies

In [None]:
from IPython.display import display

import matplotlib.pyplot as plt
import torch
import pyro
import pyro.contrib.gp as gp
import ipywidgets as widgets
from matplotlib.ticker import FormatStrFormatter

# Seed pyro's RNG so the randomly sampled inputs below are reproducible.
# NOTE(review): pyro.set_rng_seed is expected to seed torch as well — confirm.
pyro.set_rng_seed(0)
# Report library versions (one per line) to help reproduce results.
print(f"torch version: {torch.__version__}\npyro version: {pyro.__version__}")

# Floating-point tensors created from here on (e.g. x_nominal) default to float64.
torch.set_default_dtype(torch.float64)

## Load params

In [None]:
# Model input names (with units), in the order used to assemble input tensors;
# the plots refer to them as x1 ... x15.
features = [
    "i_app (A)", "rp_pos (m)", "Eeq_side (V)",
    "kappa_film (S/m)", "epsl_pos", "Dl_elect (m^2/s)",
    "Ds_pos (m^2/s)", "i0ref_pos (A/m^2)", "E_min (V)",
    "i0_SEI (A/m^2)", "csmax_pos (mol/m^3)", "cl_0 (mol/m^3)",
    "t_plus", "i0ref_metal (A/m^2)", "sigma_pos",
]

# The two available targets; each has its own trained parameter file.
SEI = "SEI thickness (m)"
ICL = "Irreversible charge loss (%)"

# Select the target to analyse (set to ICL for irreversible charge loss).
target = SEI  # alternative: target = ICL

In [None]:
# Load the trained sparse-GP parameters for the selected target into pyro's
# global parameter store, discarding anything loaded before.
# NOTE: pyro's param-store loading is pickle-based (torch.load); only load
# parameter files from a trusted source.
pyro.clear_param_store()

if target == SEI:
    pyro.get_param_store().load("sgpr_params_sei.p")
elif target == ICL:
    pyro.get_param_store().load("sgpr_params_icl.p")
else:
    # Fail here with a clear message instead of leaving the store empty and
    # hitting an opaque KeyError when the model is built.
    raise ValueError(f"Unknown target {target!r}; expected {SEI!r} or {ICL!r}")

params = pyro.get_param_store()
params.keys()

In [None]:
# Nominal value of every input in physical (un-normalised) units, one entry
# per name in `features` and in the same order. The explorers below start
# their sliders at these values.
x_nominal = torch.tensor((
    1.3,       # i_app (A)
    5.5e-06,   # rp_pos (m)
    0.4,       # Eeq_side (V)
    0.00024,   # kappa_film (S/m)
    0.3,       # epsl_pos
    3.75e-10,  # Dl_elect (m^2/s)
    3.6e-14,   # Ds_pos (m^2/s)
    0.96,      # i0ref_pos (A/m^2)
    0.05,      # E_min (V)
    4.5e-07,   # i0_SEI (A/m^2)
    31500.0,   # csmax_pos (mol/m^3)
    1150.0,    # cl_0 (mol/m^3)
    0.363,     # t_plus
    100.0,     # i0ref_metal (A/m^2)
    100.0,     # sigma_pos
))

## Set up the model

In [None]:
# Rebuild the trained sparse GP from the loaded parameter store: an RBF
# kernel over every input column, the stored training data, the stored
# inducing points (Xu) and the stored observation noise.
kernel = gp.kernels.RBF(
    input_dim=params["data.x_train"].shape[1],
    variance=params["kernel.variance"],
    lengthscale=params["kernel.lengthscale"],
)
model = gp.models.SparseGPRegression(
    params["data.x_train"],
    params["data.y_train"],
    kernel,
    Xu=params["Xu"],
    noise=params["noise"],
)

In [None]:
class SensitivityAnalysisAutograd:
    """Gradient-based sensitivity analysis of a probabilistic regression model.

    Calling an instance evaluates ``model`` at the inputs ``x``. It returns the
    predictive mean and variance, mapped back to output units with
    ``y_scale``/``y_offset``, together with the absolute gradients of each
    prediction with respect to every input column, computed with autograd.

    Args:
        model: callable invoked as ``model(x, full_cov=False, noiseless=False)``
            that returns ``(mean, var)`` in normalised output units, e.g. the
            pyro sparse GP built in this notebook.
        y_scale: factor mapping normalised outputs to output units. The
            variance is scaled by ``y_scale**2``.
        y_offset: offset added to the scaled mean.
    """

    def __init__(self, model, y_scale=1.0, y_offset=0.0):
        self.model = model
        self.y_scale = y_scale
        self.y_offset = y_offset

    def __call__(self, x, reduce=None):
        """Predict at ``x`` and compute input sensitivities.

        Args:
            x: input tensor. Rows are treated as samples and columns as input
                features; reductions run over ``dim=0``. ``x`` is not modified.
            reduce: ``None`` for per-row sensitivities ``|dy/dx|``. ``"sum"``
                or ``"mean"`` gives the square root of the summed / averaged
                squared gradients over rows (one value per input column).

        Returns:
            ``(y_mean, y_var, s_mean, s_var)``: the de-normalised predictive
            mean and variance, and the sensitivities of each to the inputs.
            ``y_mean``/``y_var`` are not detached.

        Raises:
            ValueError: if ``reduce`` is not ``None``, ``"sum"`` or ``"mean"``.
        """
        if reduce not in (None, "sum", "mean"):
            raise ValueError(f"reduce must be None, 'sum' or 'mean', got {reduce!r}")

        # Differentiate with respect to a fresh leaf instead of toggling
        # requires_grad on the caller's tensor. This also works for non-leaf
        # inputs and cannot leave the flag set if the model raises.
        x = x.detach().requires_grad_(True)

        # One forward pass yields both mean and variance. Its graph is retained
        # after the first gradient so the second gradient can reuse it.
        # torch.autograd.grad returns gradients directly (nothing accumulates
        # into .grad), so no zero_grad() is needed.
        y_mean, y_var = self.model(x, full_cov=False, noiseless=False)
        y_mean = y_mean * self.y_scale + self.y_offset
        y_var = y_var * self.y_scale**2
        # Differentiating the sum gives per-row gradients, assuming each row's
        # prediction depends only on that row (full_cov=False predictions).
        g_mean = torch.autograd.grad(y_mean.sum(), x, retain_graph=True)[0]
        g_var = torch.autograd.grad(y_var.sum(), x)[0]

        if reduce == "sum":
            return y_mean, y_var, torch.sqrt(torch.sum(g_mean**2, dim=0)), torch.sqrt(torch.sum(g_var**2, dim=0))
        if reduce == "mean":
            return y_mean, y_var, torch.sqrt(torch.mean(g_mean**2, dim=0)), torch.sqrt(torch.mean(g_var**2, dim=0))
        return y_mean, y_var, g_mean.abs(), g_var.abs()

# Wrap the GP so that predictions and sensitivities are reported in output
# units, using the normalisation constants stored with the parameters.
saa = SensitivityAnalysisAutograd(
    model,
    y_scale=params["norm.y_scale"],
    y_offset=params["norm.y_offset"],
)

## Interactive exploration

### Interactive global

In [None]:
def create_predict_and_plot_global(saa, features, n_sample=1000, normalise=False, figsize=(12,7)):
    """Build a widget callback that plots global (sample-averaged) sensitivities.

    The returned ``predict_and_plot(**x_dict)`` expects one ``(low, high)``
    pair per feature name, as produced by the range sliders. It draws
    ``n_sample`` inputs uniformly from that box, evaluates ``saa`` with
    ``reduce="mean"`` and shows bar charts of the sensitivity of the predicted
    mean (left) and variance (right) to every input.

    Args:
        saa: ``SensitivityAnalysisAutograd``-like callable.
        features: input names, in the order of the model's input columns.
        n_sample: number of uniform samples drawn per update.
        normalise: if True, scale each bar chart so its bars sum to one.
        figsize: matplotlib figure size.
    """
    # Bar positions and tick labels do not depend on the slider state.
    positions = list(range(len(features)))
    tick_labels = ["$x_{" + str(i + 1) + "}$: " + f"{f}" for i, f in enumerate(features)]

    def _normalised(s):
        # Scale to unit sum; leave an all-zero vector unchanged instead of
        # producing NaNs.
        total = s.sum()
        return s / total if total > 0 else s

    def predict_and_plot(**x_dict):
        x_min = torch.tensor([x_dict[f][0] for f in features])
        x_max = torch.tensor([x_dict[f][1] for f in features])

        # Sample inputs uniformly within the selected box.
        X = torch.distributions.Uniform(x_min, x_max).sample((n_sample,))

        # Only the sample-aggregated sensitivities are plotted.
        _, _, s_mean, s_var = saa(X, reduce="mean")
        s_mean, s_var = s_mean.detach(), s_var.detach()

        if normalise:
            s_mean = _normalised(s_mean)
            s_var = _normalised(s_var)

        plt.figure(figsize=figsize)
        # Raw strings keep the LaTeX backslashes without invalid-escape warnings.
        panels = (
            (121, s_mean, "C0", "s_mean", r"$s_*(\mu)$"),
            (122, s_var, "C1", "s_var", r"$s_*(\sigma^2)$"),
        )
        for position, values, color, label, ylabel in panels:
            plt.subplot(position)
            plt.bar(positions, values.numpy(), color=color, label=label)
            if normalise:
                plt.ylim((0, 1.0))
            plt.xticks(positions, tick_labels, rotation=90)
            plt.ylabel(ylabel)
            plt.grid(axis="y")
        plt.tight_layout()
        plt.show()

    return predict_and_plot


def interactive_global(on_change, features, x_min, x_max, n_steps=20, readout_format=".2f"):
    """Display the global-sensitivity explorer: one range slider per input.

    Each slider spans ``[x_min[i], x_max[i]]`` in ``n_steps`` steps. Whenever a
    range changes, ``on_change`` is called with one ``(low, high)`` keyword
    argument per feature name, and its output is shown next to the sliders.

    Args:
        on_change: callback, e.g. from ``create_predict_and_plot_global``.
        features: input names; also used as the callback's keyword names.
        x_min, x_max: per-feature slider bounds.
        n_steps: number of slider steps across each range.
        readout_format: format of the slider readouts. Two decimals display
            steps of 1/20 (0.05) on a unit range without rounding them away.
    """
    sliders = {}
    for i, f in enumerate(features):
        # Plain floats keep the widget traits independent of the input type
        # (tensor, numpy array or list).
        lo, hi = float(x_min[i]), float(x_max[i])
        sliders[f] = widgets.FloatRangeSlider(
            value=[lo, hi],
            min=lo,
            max=hi,
            step=(hi - lo) / n_steps,
            description=f"x{i+1}: {f}",
            readout_format=readout_format,
        )
    # Sliders on the left, callback output on the right.
    out = widgets.interactive_output(on_change, sliders)
    controls = widgets.VBox(list(sliders.values()))
    display(widgets.HBox([controls, out]))

In [None]:
# Callback for the global view: 5000 uniform samples per update, with each
# sensitivity chart normalised to sum to one.
on_change = create_predict_and_plot_global(saa, features, n_sample=5000, normalise=True)

# One range slider per input over the normalised domain [0, 1].
interactive_global(
    on_change,
    features,
    x_min=torch.zeros(len(features)),
    x_max=torch.ones(len(features)),
)

### Interactive 1D

In [None]:
def create_predict_and_plot_1d(saa, features, target, x_min, x_max, x_offset=0.0, x_scale=1.0, y_lim=None, n_points=100, figsize=(12,7)):
    """Build a widget callback that plots a 1D slice through the model.

    The returned ``predict_and_plot(d, **x_dict)`` takes a 1-based input index
    ``d`` and the current (normalised) value of every feature. It sweeps input
    ``d`` over ``[x_min[d], x_max[d]]`` on ``n_points`` points, keeping the
    other inputs fixed, and plots four panels:
    - the predicted mean with a 2-sigma band;
    - the 2-sigma uncertainty;
    - the sensitivity of the mean to input ``d``;
    - the sensitivity of the variance to input ``d``.
    A dashed line marks the selected point in each panel.

    Args:
        saa: ``SensitivityAnalysisAutograd``-like callable.
        features: input names, in the order of the model's input columns.
        target: name of the predicted quantity, used in the mean panel's label.
        x_min, x_max: per-feature sweep bounds in normalised units.
        x_offset, x_scale: scalar or per-feature de-normalisation. Plotted x
            values are ``x * x_scale + x_offset``.
        y_lim: optional y-limits of the mean panel.
        n_points: number of grid points along input ``d``.
        figsize: matplotlib figure size.
    """
    n_features = len(features)
    x_min = torch.as_tensor(x_min)
    x_max = torch.as_tensor(x_max)
    # Broadcast the de-normalisation to one entry per feature so that scalar
    # defaults can be indexed by feature position, like per-feature tensors.
    x_offset = torch.as_tensor(x_offset).detach().expand(n_features)
    x_scale = torch.as_tensor(x_scale).detach().expand(n_features)

    def predict_and_plot(d, **x_dict):
        d = d - 1  # the dimension slider is 1-based
        x = torch.tensor([x_dict[f] for f in features])

        # Grid that sweeps input d over its range; other inputs stay fixed.
        xd = torch.linspace(x_min[d], x_max[d], n_points)
        X = x.repeat(n_points, 1)
        X[:, d] = xd

        # Predictions and sensitivities at the selected point and along the grid.
        mean0, var0, s_mean0, s_var0 = (t.detach() for t in saa(x.unsqueeze(0)))
        mean, var, s_mean, s_var = (t.detach() for t in saa(X))
        std0 = var0.sqrt()
        std = var.sqrt()

        # De-normalise x for plotting.
        x_min_plot = x_min * x_scale + x_offset
        x_max_plot = x_max * x_scale + x_offset
        x_d = (x[d] * x_scale[d] + x_offset[d]).item()
        xd_np = (xd * x_scale[d] + x_offset[d]).numpy()
        xmargin = (x_max_plot[d] - x_min_plot[d]) * 0.005
        xlim = ((x_min_plot[d] - xmargin).item(), (x_max_plot[d] + xmargin).item())
        # Every panel uses the same 1-based label as the sliders.
        xlabel = "$x_{" + str(d + 1) + "}$: " + f"{features[d]}"

        def _finish_panel(marker_label, ylabel, ylim):
            # Dashed marker at the selected point plus the shared axis styling.
            plt.axvline(x_d, color="k", linewidth=1, linestyle="--", label=marker_label)
            plt.xlim(xlim)
            if ylim is not None:
                plt.ylim(ylim)
            plt.xlabel(xlabel)
            plt.ylabel(ylabel)
            plt.grid()
            plt.legend(loc=1)

        mean_np = mean.numpy()
        two_std_np = 2.0 * std.numpy()
        plt.figure(figsize=figsize)

        # Mean prediction with a +-2 sigma band.
        plt.subplot(221)
        plt.title("Mean prediction with uncertainty")
        plt.plot(xd_np, mean_np, label=r"$\mu\pm2\sigma$")
        plt.fill_between(xd_np, mean_np - two_std_np, mean_np + two_std_np, color="C0", alpha=0.3)
        _finish_panel(f"{x_d:.2f}, {mean0.item():.2f} ({std0.item():.2f})", f"y: log {target}", y_lim)

        # Uncertainty (2 sigma).
        plt.subplot(222)
        plt.title("Uncertainty prediction")
        plt.plot(xd_np, two_std_np, label=r"$2\sigma$")
        _finish_panel(f"{x_d:.2f}, {2 * std0.item():.2f}", "Uncertainty", (0, 1))

        # Sensitivity of the mean to input d.
        plt.subplot(223)
        plt.title("Sensitivity of mean prediction")
        plt.plot(xd_np, s_mean[:, d].numpy(), label=r"$s(\mu)$")
        _finish_panel(f"{x_d:.2f}, {s_mean0[0, d].item():.2f}", "Sensitivity", (0, 10))

        # Sensitivity of the variance to input d.
        plt.subplot(224)
        plt.title("Sensitivity of uncertainty prediction")
        plt.plot(xd_np, s_var[:, d].numpy(), label=r"$s(\sigma^2)$")
        _finish_panel(f"{x_d:.2f}, {s_var0[0, d].item():.2f}", "Sensitivity", (0, 0.3))

        plt.tight_layout()
        plt.show()

    return predict_and_plot


def interactive_1d(on_change, features, x_min, x_max, x_init, n_steps=20, readout_format=".2f"):
    """Display the 1D explorer: a dimension selector plus one slider per input.

    The integer slider ``d`` (1-based) picks the swept input. Each float
    slider starts at ``x_init[i]`` and spans ``[x_min[i], x_max[i]]`` in
    ``n_steps`` steps. ``on_change`` receives ``d`` and every feature value as
    keyword arguments; its output is shown next to the sliders.

    Args:
        on_change: callback, e.g. from ``create_predict_and_plot_1d``.
        features: input names; also used as the callback's keyword names.
        x_min, x_max: per-feature slider bounds.
        x_init: per-feature initial slider values.
        n_steps: number of slider steps across each range.
        readout_format: format of the slider readouts. Two decimals display
            steps of 1/20 (0.05) on a unit range without rounding them away.
    """
    sliders = {"d": widgets.IntSlider(value=1, min=1, max=len(features), description="dim")}
    for i, f in enumerate(features):
        # Plain floats keep the widget traits independent of the input type.
        lo, hi = float(x_min[i]), float(x_max[i])
        sliders[f] = widgets.FloatSlider(
            value=float(x_init[i]),
            min=lo,
            max=hi,
            step=(hi - lo) / n_steps,
            description=f"x{i+1}: {f}",
            readout_format=readout_format,
        )
    # Sliders on the left, callback output on the right.
    out = widgets.interactive_output(on_change, sliders)
    controls = widgets.VBox(list(sliders.values()))
    display(widgets.HBox([controls, out]))

In [None]:
# Callback for the 1D view. Inputs are normalised, so the stored offset/scale
# are passed to label the x-axes in physical units. The mean panel's y-range
# is the de-normalised span between the stored data.y_min and data.y_max.
on_change_1d = create_predict_and_plot_1d(
    saa,
    features,
    target,
    x_min=torch.zeros(len(features)),
    x_max=torch.ones(len(features)),
    x_offset=params["norm.x_offset"].detach(),
    x_scale=params["norm.x_scale"].detach(),
    y_lim=tuple(
        (params[key] * params["norm.y_scale"] + params["norm.y_offset"]).item()
        for key in ("data.y_min", "data.y_max")
    ),
)

# Sliders start at the nominal point, mapped into normalised units.
interactive_1d(
    on_change_1d,
    features,
    x_min=torch.zeros(len(features)),
    x_max=torch.ones(len(features)),
    x_init=(x_nominal - params["norm.x_offset"]) / params["norm.x_scale"],
)

### Interactive 2D

In [None]:
# setup on_change function
def create_predict_and_plot_2d(saa, features, target, x_min, x_max, x_offset=0.0, x_scale=1.0, y_lim=None, s_lim=10.0, n_points=100, n_levels=21, figsize=(12,10)):
    """Build a widget callback that plots a 2D slice through the model.

    The returned ``predict_and_plot(d0, d1, **x_dict)`` takes two 1-based
    input indices and the current (normalised) value of every feature. It
    sweeps inputs ``d0`` and ``d1`` over an ``n_points x n_points`` grid within
    ``[x_min, x_max]``, keeping the other inputs fixed, and draws filled
    contours of four quantities:
    - the predicted mean;
    - the 2-sigma uncertainty;
    - the summed sensitivity of the mean to the two inputs;
    - the summed sensitivity of the variance to the two inputs.
    Dashed cross-hairs mark the selected point.

    Args:
        saa: ``SensitivityAnalysisAutograd``-like callable.
        features: input names, in the order of the model's input columns.
        target: name of the predicted quantity, used in the mean colour bar.
        x_min, x_max: per-feature sweep bounds in normalised units.
        x_offset, x_scale: scalar or per-feature de-normalisation. Plotted
            values are ``x * x_scale + x_offset``.
        y_lim: optional ``(low, high)`` contour range of the mean. Defaults to
            the range of the predicted mean on the grid.
        s_lim: upper contour level of the mean-sensitivity panel.
        n_points: grid points per swept input.
        n_levels: number of contour levels in every panel.
        figsize: matplotlib figure size.
    """
    n_features = len(features)
    x_min = torch.as_tensor(x_min)
    x_max = torch.as_tensor(x_max)
    # Broadcast the de-normalisation to one entry per feature so that scalar
    # defaults can be indexed by feature position, like per-feature tensors.
    x_offset = torch.as_tensor(x_offset).detach().expand(n_features)
    x_scale = torch.as_tensor(x_scale).detach().expand(n_features)

    def predict_and_plot(d0, d1, **x_dict):
        d0, d1 = d0 - 1, d1 - 1  # the dimension sliders are 1-based
        if d0 == d1:
            # A 2D slice needs two distinct inputs. Otherwise the grid would
            # overwrite the same column twice and double-count sensitivities.
            print("Select two different dimensions for the 2D view.")
            return
        x = torch.tensor([x_dict[f] for f in features])

        # Regular grid over inputs d0 and d1; all other inputs stay fixed.
        xd0 = torch.linspace(x_min[d0], x_max[d0], n_points)
        xd1 = torch.linspace(x_min[d1], x_max[d1], n_points)
        grid_xd0, grid_xd1 = torch.meshgrid(xd0, xd1, indexing="ij")
        X = x.repeat(n_points**2, 1)
        X[:, d0] = grid_xd0.reshape(-1)
        X[:, d1] = grid_xd1.reshape(-1)

        # Predictions and sensitivities at the selected point and on the grid.
        mean0, var0, s_mean0, s_var0 = (t.detach() for t in saa(x.unsqueeze(0)))
        mean, var, s_mean, s_var = (t.detach() for t in saa(X))
        std0 = var0.sqrt()
        std = var.sqrt()

        # Combined sensitivity to the two plotted inputs.
        s_mean0_d = (s_mean0[:, d0] + s_mean0[:, d1]).item()
        s_var0_d = (s_var0[:, d0] + s_var0[:, d1]).item()
        s_mean_d = s_mean[:, d0] + s_mean[:, d1]
        s_var_d = s_var[:, d0] + s_var[:, d1]

        # De-normalise x for plotting.
        x_min_plot = x_min * x_scale + x_offset
        x_max_plot = x_max * x_scale + x_offset
        x0 = (x[d0] * x_scale[d0] + x_offset[d0]).item()
        x1 = (x[d1] * x_scale[d1] + x_offset[d1]).item()
        gx0 = (grid_xd0 * x_scale[d0] + x_offset[d0]).numpy()
        gx1 = (grid_xd1 * x_scale[d1] + x_offset[d1]).numpy()
        margin0 = ((x_max_plot[d0] - x_min_plot[d0]) * 0.005).item()
        margin1 = ((x_max_plot[d1] - x_min_plot[d1]) * 0.005).item()
        # Widen the axes if the selected point lies outside the swept range.
        xlim = (min(x_min_plot[d0].item(), x0 - margin0), max(x_max_plot[d0].item(), x0 + margin0))
        ylim = (min(x_min_plot[d1].item(), x1 - margin1), max(x_max_plot[d1].item(), x1 + margin1))
        xlabel = "$x_{" + str(d0 + 1) + "}$: " + f"{features[d0]}"
        ylabel = "$x_{" + str(d1 + 1) + "}$: " + f"{features[d1]}"

        def _panel(position, title, values, levels, marker_label, cbar_label):
            # One filled-contour panel with the selected point marked by dashed
            # cross-hairs. extend="both" colours values outside `levels`
            # instead of leaving them blank (matplotlib's default).
            ax = plt.subplot(position)
            plt.title(title)
            plt.contourf(gx0, gx1, values.reshape(n_points, n_points).numpy(),
                         levels=levels, cmap="plasma", extend="both")
            plt.axvline(x0, color="k", linewidth=1, linestyle="--", label=marker_label)
            plt.axhline(x1, color="k", linewidth=1, linestyle="--")
            plt.xlabel(xlabel)
            plt.ylabel(ylabel)
            cbar = plt.colorbar(shrink=0.9)
            cbar.set_label(cbar_label)
            ax.yaxis.set_major_formatter(FormatStrFormatter("%6.2f"))
            plt.legend(loc=4)
            plt.xlim(xlim)
            plt.ylim(ylim)

        if y_lim is None:
            mean_levels = torch.linspace(mean.min().item(), mean.max().item(), n_levels).numpy()
        else:
            mean_levels = torch.linspace(y_lim[0], y_lim[1], n_levels).numpy()

        plt.figure(figsize=figsize)
        _panel(221, "Mean prediction", mean, mean_levels,
               f"{mean0.item():.4f} ({std0.item():.4f})", rf"$\mu$: log {target}")
        _panel(222, "Uncertainty prediction", 2 * std, torch.linspace(0, 1.0, n_levels).numpy(),
               f"{std0.item()*2:.4f}", r"Uncertainty: $2\sigma$")
        _panel(223, "Sensitivity of mean prediction", s_mean_d, torch.linspace(0, s_lim, n_levels).numpy(),
               f"{s_mean0_d:.4f}", r"Sensitivity: $s(\mu)$")
        _panel(224, "Sensitivity of uncertainty prediction", s_var_d, torch.linspace(0, 0.25, n_levels).numpy(),
               f"{s_var0_d:.4f}", r"Sensitivity: $s(\sigma^2)$")
        plt.tight_layout()
        plt.show()

    return predict_and_plot


def interactive_2d(on_change, features, x_min, x_max, x_init, n_steps=20, readout_format=".2f"):
    """Display the 2D explorer: two dimension selectors plus one slider per input.

    The integer sliders ``d0``/``d1`` (1-based) pick the swept inputs. Each
    float slider starts at ``x_init[i]`` and spans ``[x_min[i], x_max[i]]`` in
    ``n_steps`` steps. ``on_change`` receives ``d0``, ``d1`` and every feature
    value as keyword arguments; its output is shown next to the sliders.

    Args:
        on_change: callback, e.g. from ``create_predict_and_plot_2d``.
        features: input names; also used as the callback's keyword names.
        x_min, x_max: per-feature slider bounds.
        x_init: per-feature initial slider values.
        n_steps: number of slider steps across each range.
        readout_format: format of the slider readouts. Two decimals display
            steps of 1/20 (0.05) on a unit range without rounding them away.
    """
    n_features = len(features)
    sliders = {
        "d0": widgets.IntSlider(value=1, min=1, max=n_features, description="dim 1"),
        "d1": widgets.IntSlider(value=2, min=1, max=n_features, description="dim 2"),
    }
    for i, f in enumerate(features):
        # Plain floats keep the widget traits independent of the input type.
        lo, hi = float(x_min[i]), float(x_max[i])
        sliders[f] = widgets.FloatSlider(
            value=float(x_init[i]),
            min=lo,
            max=hi,
            step=(hi - lo) / n_steps,
            description=f"x{i+1}: {f}",
            readout_format=readout_format,
        )
    # Sliders on the left, callback output on the right.
    out = widgets.interactive_output(on_change, sliders)
    controls = widgets.VBox(list(sliders.values()))
    display(widgets.HBox([controls, out]))

In [None]:
# Callback for the 2D view. Inputs are normalised, so the stored offset/scale
# are passed to label the axes in physical units. The mean contours span the
# de-normalised range between the stored data.y_min and data.y_max, and the
# mean-sensitivity contours go up to 15.
on_change_2d = create_predict_and_plot_2d(
    saa,
    features,
    target,
    x_min=torch.zeros(len(features)),
    x_max=torch.ones(len(features)),
    x_offset=params["norm.x_offset"].detach(),
    x_scale=params["norm.x_scale"].detach(),
    y_lim=tuple(
        (params[key] * params["norm.y_scale"] + params["norm.y_offset"]).item()
        for key in ("data.y_min", "data.y_max")
    ),
    s_lim=15.0,
)

# Sliders start at the nominal point, mapped into normalised units.
interactive_2d(
    on_change_2d,
    features,
    x_min=torch.zeros(len(features)),
    x_max=torch.ones(len(features)),
    x_init=(x_nominal - params["norm.x_offset"]) / params["norm.x_scale"],
)