In [1]:
import os
from pathlib import Path

# Move from the notebook's folder three directory levels up to the repository root,
# presumably so the project imports below (data, models, base_model) and the
# relative output paths (xai/images/...) resolve. The root is remembered in the
# kernel so re-running this cell does not keep climbing the directory tree.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path(os.getcwd()).parents[2]
os.chdir(PROJECT_ROOT)

# Prediction target of the model to explain: "mth" (mean thickness) or "pce".
target = "mth"
if target not in ("mth", "pce"):
    raise ValueError(f"target must be 'mth' or 'pce', got {target!r}")

In [2]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

# Data and checkpoint locations. Each directory can be overridden with an
# environment variable so the notebook runs on other machines; the defaults are
# the original local paths.
data_dir = os.environ.get(
    "PEROVSKITE_DATA_DIR",
    "/home/l727n/Projects/Applied Projects/ml_perovskite/preprocessed",
)

if target == "pce":
    checkpoint_dir = os.environ.get(
        "PEROVSKITE_PCE_CHECKPOINT_DIR",
        "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints",
    )
    checkpoint_name = "1D-epoch=999-val_MAE=0.000-train_MAE=0.490.ckpt"
else:
    checkpoint_dir = os.environ.get(
        "PEROVSKITE_MTH_CHECKPOINT_DIR",
        "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/mT_checkpoints/checkpoints",
    )
    checkpoint_name = "mT_1D_RN152_full-epoch=999-val_MAE=0.000-train_MAE=40.332.ckpt"

# Trained 1D model checkpoint for the selected target.
path_to_checkpoint = join(checkpoint_dir, checkpoint_name)



  from .autonotebook import tqdm as notebook_tqdm


# Load the model and compute six different attribution methods, each evaluated with two metrics (infidelity and sensitivity)

In [3]:
#### 1D Model

# Target-dependent settings shared by the model hyperparameters and the dataset.
is_pce = target == "pce"
label = "PCE_mean" if is_pce else "meanThickness"

hypparams = {
    "dataset": "Perov_1d",
    "dims": 1,
    "bottleneck": False,
    "name": "ResNet152",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "norm_target": is_pce,
    "target": label,
}

# Architecture arguments passed alongside the checkpoint weights.
model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    block=BasicBlock,
    num_blocks=[4, 13, 55, 4],
    num_classes=1,
    hypparams=hypparams,
)
print("Loaded")
model.eval()

# Dataset normalised with the statistics and scaler stored on the loaded model.
dataset = PerovskiteDataset1d(
    data_dir,
    transform=normalize(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled=not is_pce,
    label=label,
)

batch_size = 100
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
)


tensor([0.2697, 0.0191, 0.0057, 0.0216]) tensor([0.1589, 0.0106, 0.0030, 0.0145])
Loaded


In [4]:
# Select observation: index n within the first batch of the (unshuffled) loader
n = 1

x_batch = next(iter(loader))
x = x_batch[0][n]

with torch.no_grad():
    y_batch = model.predict(x_batch).flatten()

# Model prediction for the selected observation, rounded to two decimals
y_n = y_batch[n].detach().numpy()
y = float(np.round(y_n, 2))



In [65]:
### Only for graphic: the four raw input channels of the selected observation,
### drawn without axes, ticks or legend

import plotly.express as px
import plotly.graph_objects as go

color = ["#0059A0", "#5F3893", "#FF8777", "#E1462C"]

fig = go.Figure()
for i, c in enumerate(color):
    fig.add_traces(go.Scatter(y=x[i], marker_color=c, line=dict(width=4)))

fig.update_layout(template="plotly_white", height=300, width=300, showlegend=False)

for update_axes in (fig.update_yaxes, fig.update_xaxes):
    update_axes(title=None, showticklabels=False)

fig.show()

In [5]:
# Init perturbation function for the infidelity metric

# Standard deviation of the Gaussian noise used to perturb the inputs.
std_noise = 0.01


def perturb_fn(inputs):
    """Gaussian perturbation function for ``captum.metrics.infidelity``.

    Draws i.i.d. noise ``N(0, std_noise)`` (``std_noise`` is read at call time)
    with the same shape as ``inputs``.

    Returns:
        ``(noise, inputs - noise)``: the perturbations and the perturbed inputs.

    The noise is placed on the same device as ``inputs`` and uses its floating
    dtype (float32 for non-floating inputs), so the subtraction also works for
    GPU or float64 inputs.
    """
    noise_dtype = inputs.dtype if inputs.is_floating_point() else torch.float32
    noise = torch.as_tensor(
        np.random.normal(0, std_noise, tuple(inputs.shape)),
        dtype=noise_dtype,
        device=inputs.device,
    )
    return noise, inputs - noise


In [6]:
# Compute attribution via expected gradients (captum's GradientShap, using the
# whole first batch as the baseline distribution)

from captum.attr import GradientShap
from captum.metrics import sensitivity_max, infidelity

x_obs = x_batch[0][n].unsqueeze(0)  # selected observation with a batch dim of 1

# Explainer settings. They are shared with sensitivity_max so the sensitivity
# metric evaluates the same configuration that produced attr_eg (previously it
# silently used GradientShap's defaults n_samples=5, stdevs=0.0).
eg_kwargs = dict(n_samples=100, stdevs=0.001, baselines=x_batch[0], target=0)

gradient_shap = GradientShap(model)
attr_eg = gradient_shap.attribute(x_obs, **eg_kwargs)

infid_eg = infidelity(model, perturb_fn, x_obs, attr_eg)
# max_examples_per_batch=1 evaluates one perturbed input (i.e. n_samples model
# passes) at a time, bounding memory to that of the attribution call above.
sens_eg = sensitivity_max(
    gradient_shap.attribute, x_obs, max_examples_per_batch=1, **eg_kwargs
)  # lower is better



In [7]:
# Integrated Gradients with an all-zero baseline

from captum.attr import IntegratedGradients

x_obs = x_batch[0][n].unsqueeze(0)  # selected observation with a batch dim of 1
zero_baseline = x_obs * 0

ig = IntegratedGradients(model)
attr_ig, delta = ig.attribute(
    x_obs, baselines=zero_baseline, return_convergence_delta=True
)

infid_ig = infidelity(model, perturb_fn, x_obs, attr_ig)
sens_ig = sensitivity_max(ig.attribute, x_obs, target=0, baselines=zero_baseline)



In [8]:
# LIME

from captum.attr import Lime
from torch import Tensor
from captum._utils.models.linear_model import SkLearnLinearModel

# LIME hyperparameters: regularisation strength of the Lasso surrogate model and
# width of the exponential similarity kernel.
a = 0.01
kernel_width = 7.0


def similarity_kernel(
    original_input, perturbed_input, perturbed_interpretable_input, **kwargs
) -> Tensor:
    """Exponential similarity kernel on the L2 distance between two inputs.

    Returns ``exp(-||original_input - perturbed_input||^2 / kernel_width^2)``,
    where ``kernel_width`` is the module-level value defined above (it is not
    taken from ``kwargs``). ``perturbed_interpretable_input`` and ``kwargs`` are
    unused; they are part of captum's ``similarity_func`` signature.
    """
    # Either input may arrive wrapped in a tuple; only the first tensor is compared.
    if isinstance(original_input, tuple):
        original_input = original_input[0]
    if isinstance(perturbed_input, tuple):
        perturbed_input = perturbed_input[0]

    l2_dist = torch.norm(original_input - perturbed_input)
    return torch.exp(-(l2_dist ** 2) / (kernel_width ** 2))


# LIME with a Lasso surrogate and the exponential similarity kernel above
lime = Lime(
    model,
    SkLearnLinearModel("linear_model.Lasso", alpha=a),
    similarity_func=similarity_kernel,
)

x_obs = x_batch[0][n].unsqueeze(0)  # selected observation with a batch dim of 1
# 500 perturbed samples, evaluated 20 per model call; shared with the sensitivity metric
lime_kwargs = dict(n_samples=500, perturbations_per_eval=20)

attr_lime = lime.attribute(inputs=x_obs, **lime_kwargs)

infid_lime = infidelity(model, perturb_fn, x_obs, attr_lime)
sens_lime = sensitivity_max(lime.attribute, x_obs, **lime_kwargs)



  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


In [9]:
# Kernel SHAP smoothed with a NoiseTunnel (SmoothGrad-style averaging)

from captum.attr import KernelShap, NoiseTunnel

# Std of the Gaussian noise the tunnel adds to the input.
# NOTE(review): inputs are normalised by `normalize(train_mean, train_std)`;
# 8.0 looks large relative to that — confirm this is intended.
stdev = 8.0

ks = KernelShap(model)
nt = NoiseTunnel(ks)

x_obs = x_batch[0][n].unsqueeze(0)  # selected observation with a batch dim of 1
nt_kwargs = dict(n_samples=100, nt_samples=10, stdevs=stdev, nt_type="smoothgrad")

attr_ks = nt.attribute(x_obs, target=0, **nt_kwargs)

infid_ks = infidelity(model, perturb_fn, x_obs, attr_ks)
sens_ks = sensitivity_max(nt.attribute, x_obs, **nt_kwargs)





In [10]:
# Guided Backpropagation

from captum.attr import GuidedBackprop

x_obs = x_batch[0][n].unsqueeze(0)  # selected observation with a batch dim of 1

gbp = GuidedBackprop(model)
attr_gbp = gbp.attribute(x_obs, target=0)

infid_gbp = infidelity(model, perturb_fn, x_obs, attr_gbp)
sens_gbp = sensitivity_max(gbp.attribute, x_obs)




In [11]:
# Guided GradCAM, with the model's first convolution (model.conv1) as GradCAM layer

from captum.attr import GuidedGradCam

x_obs = x_batch[0][n].unsqueeze(0)  # selected observation with a batch dim of 1

ggc = GuidedGradCam(model, model.conv1)
attr_ggc = ggc.attribute(x_obs, target=0).detach()

infid_ggc = infidelity(model, perturb_fn, x_obs, attr_ggc)
sens_ggc = sensitivity_max(ggc.attribute, x_obs)


# Visualization of the individual methods and comparison of methods via summed and absolute values

In [21]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def format_title(title, subtitle=None, font_size=16, subtitle_font_size=14):
    """Return an HTML title: bold ``title`` plus an optional smaller ``subtitle``.

    If ``subtitle`` is falsy, only the bold title span is returned; otherwise the
    subtitle span is appended after a ``<br>`` line break.
    """
    title = f'<span style="font-size: {font_size}px;"><b>{title}</b></span>'
    if not subtitle:
        return title
    subtitle = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
    return f"{title}<br>{subtitle}"


subtitle = "Predicted PCE: " if target == "pce" else "Predicted Mean Thickness: "

# Input channels (rows of x) and their line colours in the overview panel.
channels = ["ND", "LP725", "LP780", "SP775"]
channel_colors = ["#042940", "#005C53", "#9FC131", "#DBF227"]

# NOTE(review): assumes attr_eg squeezes to the same (channel, timestep) layout as x.
attr = attr_eg.squeeze()
total_attr = attr.sum(axis=0)  # summed over channels: one value per timestep

fig = make_subplots(
    rows=2,
    cols=4,
    specs=[
        [{"secondary_y": True, "colspan": 4}, None, None, None],
        [{"secondary_y": True} for _ in channels],
    ],
    subplot_titles=(
        format_title("", "Total Attribution"),
        *(format_title("", f"Attribution {ch}") for ch in channels),
    ),
)

# Row 1: total attribution as red/green bars (secondary y) over all input channels.
fig.add_trace(
    go.Bar(
        y=total_attr,
        name="Attribution",
        marker_color=np.where(total_attr < 0, "red", "green"),
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
)
for i, (ch, ch_color) in enumerate(zip(channels, channel_colors)):
    fig.add_trace(go.Scatter(y=x[i], name=ch, marker_color=ch_color))

fig.update_yaxes(title="Intensity")
fig.update_yaxes(title_text="Attribution", secondary_y=True)
fig.update_xaxes(title="Timestep")

# Row 2: one panel per channel with its attribution over the greyed-out signal.
for i, ch in enumerate(channels):
    col = i + 1
    fig.add_trace(
        go.Scatter(y=x[i], name=ch, marker_color="grey", opacity=0.5, showlegend=False),
        row=2,
        col=col,
    )
    fig.add_trace(
        go.Bar(
            y=attr[i],
            marker_color=np.where(attr[i] < 0, "red", "green"),
            opacity=0.5,
            showlegend=False,
            marker_line_width=0,
        ),
        secondary_y=True,
        row=2,
        col=col,
    )
    fig.update_yaxes(title=None, row=2, col=col)
    fig.update_yaxes(title_text=None, secondary_y=True, row=2, col=col)
    fig.update_xaxes(title=None, row=2, col=col)

infid_str = str(*np.round(infid_eg.numpy(), 4))
sens_str = str(*np.round(sens_eg.numpy(), 4))
fig.update_layout(
    title=format_title(
        "Perovskite 1D Model",
        f"{subtitle}{np.round(y, 2)!s} / Method: Expected Gradients"
        f" / Infidelity = {infid_str} (\u03C3(\u03B5) = {std_noise!s})"
        f" / Sensitivity = {sens_str}",
    ),
    legend_title="Wavelength",
    title_y=0.96,
    title_x=0.035,
    template="plotly_white",
    height=500,
    width=1200,
)

# Create the output folder if needed; write_image does not create directories.
out_dir = os.path.join("xai", "images", target, "1D")
os.makedirs(out_dir, exist_ok=True)
fig.write_image(os.path.join(out_dir, "1D_eg.png"), scale=2)

fig.show()


In [22]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Reuses `format_title` defined in the Expected Gradients plotting cell above.

subtitle = "Predicted PCE: " if target == "pce" else "Predicted Mean Thickness: "

# Input channels (rows of x) and their line colours in the overview panel.
channels = ["ND", "LP725", "LP780", "SP775"]
channel_colors = ["#042940", "#005C53", "#9FC131", "#DBF227"]

# NOTE(review): assumes attr_ig squeezes to the same (channel, timestep) layout as x.
attr = attr_ig.squeeze()
total_attr = attr.sum(axis=0)  # summed over channels: one value per timestep

fig = make_subplots(
    rows=2,
    cols=4,
    specs=[
        [{"secondary_y": True, "colspan": 4}, None, None, None],
        [{"secondary_y": True} for _ in channels],
    ],
    subplot_titles=(
        format_title("", "Total Attribution"),
        *(format_title("", f"Attribution {ch}") for ch in channels),
    ),
)

# Row 1: total attribution as red/green bars (secondary y) over all input channels.
fig.add_trace(
    go.Bar(
        y=total_attr,
        name="Attribution",
        marker_color=np.where(total_attr < 0, "red", "green"),
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
)
for i, (ch, ch_color) in enumerate(zip(channels, channel_colors)):
    fig.add_trace(go.Scatter(y=x[i], name=ch, marker_color=ch_color))

fig.update_yaxes(title="Intensity")
fig.update_yaxes(title_text="Attribution", secondary_y=True)
fig.update_xaxes(title="Timestep")

# Row 2: one panel per channel with its attribution over the greyed-out signal.
for i, ch in enumerate(channels):
    col = i + 1
    fig.add_trace(
        go.Scatter(y=x[i], name=ch, marker_color="grey", opacity=0.5, showlegend=False),
        row=2,
        col=col,
    )
    fig.add_trace(
        go.Bar(
            y=attr[i],
            marker_color=np.where(attr[i] < 0, "red", "green"),
            opacity=0.5,
            showlegend=False,
            marker_line_width=0,
        ),
        secondary_y=True,
        row=2,
        col=col,
    )
    fig.update_yaxes(title=None, row=2, col=col)
    fig.update_yaxes(title_text=None, secondary_y=True, row=2, col=col)
    fig.update_xaxes(title=None, row=2, col=col)

infid_str = str(*np.round(infid_ig.numpy(), 4))
sens_str = str(*np.round(sens_ig.numpy(), 4))
fig.update_layout(
    title=format_title(
        "Perovskite 1D Model",
        f"{subtitle}{np.round(y, 2)!s} / Method: Integrated Gradients"
        f" / Infidelity = {infid_str} (\u03C3(\u03B5) = {std_noise!s})"
        f" / Sensitivity = {sens_str}",
    ),
    legend_title="Wavelength",
    title_y=0.96,
    title_x=0.035,
    template="plotly_white",
    height=500,
    width=1200,
)

# Create the output folder if needed; write_image does not create directories.
out_dir = os.path.join("xai", "images", target, "1D")
os.makedirs(out_dir, exist_ok=True)
fig.write_image(os.path.join(out_dir, "1D_ig.png"), scale=2)

fig.show()


In [23]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Reuses `format_title` defined in the Expected Gradients plotting cell above.

subtitle = "Predicted PCE: " if target == "pce" else "Predicted Mean Thickness: "

# Input channels (rows of x) and their line colours in the overview panel.
channels = ["ND", "LP725", "LP780", "SP775"]
channel_colors = ["#042940", "#005C53", "#9FC131", "#DBF227"]

# NOTE(review): assumes attr_lime squeezes to the same (channel, timestep) layout as x.
attr = attr_lime.squeeze()
total_attr = attr.sum(axis=0)  # summed over channels: one value per timestep

fig = make_subplots(
    rows=2,
    cols=4,
    specs=[
        [{"secondary_y": True, "colspan": 4}, None, None, None],
        [{"secondary_y": True} for _ in channels],
    ],
    subplot_titles=(
        format_title("", "Total Attribution"),
        *(format_title("", f"Attribution {ch}") for ch in channels),
    ),
)

# Row 1: total attribution as red/green bars (secondary y) over all input channels.
fig.add_trace(
    go.Bar(
        y=total_attr,
        name="Attribution",
        marker_color=np.where(total_attr < 0, "red", "green"),
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
)
for i, (ch, ch_color) in enumerate(zip(channels, channel_colors)):
    fig.add_trace(go.Scatter(y=x[i], name=ch, marker_color=ch_color))

fig.update_yaxes(title="Intensity")
fig.update_yaxes(title_text="Attribution", secondary_y=True)
fig.update_xaxes(title="Timestep")

# Row 2: one panel per channel with its attribution over the greyed-out signal.
for i, ch in enumerate(channels):
    col = i + 1
    fig.add_trace(
        go.Scatter(y=x[i], name=ch, marker_color="grey", opacity=0.5, showlegend=False),
        row=2,
        col=col,
    )
    fig.add_trace(
        go.Bar(
            y=attr[i],
            marker_color=np.where(attr[i] < 0, "red", "green"),
            opacity=0.5,
            showlegend=False,
            marker_line_width=0,
        ),
        secondary_y=True,
        row=2,
        col=col,
    )
    fig.update_yaxes(title=None, row=2, col=col)
    fig.update_yaxes(title_text=None, secondary_y=True, row=2, col=col)
    fig.update_xaxes(title=None, row=2, col=col)

infid_str = str(*np.round(infid_lime.numpy(), 4))
sens_str = str(*np.round(sens_lime.numpy(), 4))
fig.update_layout(
    title=format_title(
        "Perovskite 1D Model",
        f"{subtitle}{np.round(y, 2)!s} / Method: Lime"
        f" (\u237A = {a!s}, exp. kernel width = {kernel_width!s})"
        f" / Infidelity = {infid_str} (\u03C3(\u03B5) = {std_noise!s})"
        f" / Sensitivity = {sens_str}",
    ),
    legend_title="Wavelength",
    title_y=0.96,
    title_x=0.035,
    template="plotly_white",
    height=500,
    width=1200,
)

# Create the output folder if needed; write_image does not create directories.
out_dir = os.path.join("xai", "images", target, "1D")
os.makedirs(out_dir, exist_ok=True)
fig.write_image(os.path.join(out_dir, "1D_lime.png"), scale=2)

fig.show()


In [24]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Reuses `format_title` defined in the Expected Gradients plotting cell above.

subtitle = "Predicted PCE: " if target == "pce" else "Predicted Mean Thickness: "

# Input channels (rows of x) and their line colours in the overview panel.
channels = ["ND", "LP725", "LP780", "SP775"]
channel_colors = ["#042940", "#005C53", "#9FC131", "#DBF227"]

# NOTE(review): assumes attr_ks squeezes to the same (channel, timestep) layout as x.
attr = attr_ks.squeeze()
total_attr = attr.sum(axis=0)  # summed over channels: one value per timestep

fig = make_subplots(
    rows=2,
    cols=4,
    specs=[
        [{"secondary_y": True, "colspan": 4}, None, None, None],
        [{"secondary_y": True} for _ in channels],
    ],
    subplot_titles=(
        format_title("", "Total Attribution"),
        *(format_title("", f"Attribution {ch}") for ch in channels),
    ),
)

# Row 1: total attribution as red/green bars (secondary y) over all input channels.
fig.add_trace(
    go.Bar(
        y=total_attr,
        name="Attribution",
        marker_color=np.where(total_attr < 0, "red", "green"),
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
)
for i, (ch, ch_color) in enumerate(zip(channels, channel_colors)):
    fig.add_trace(go.Scatter(y=x[i], name=ch, marker_color=ch_color))

fig.update_yaxes(title="Intensity")
fig.update_yaxes(title_text="Attribution", secondary_y=True)
fig.update_xaxes(title="Timestep")

# Row 2: one panel per channel with its attribution over the greyed-out signal.
for i, ch in enumerate(channels):
    col = i + 1
    fig.add_trace(
        go.Scatter(y=x[i], name=ch, marker_color="grey", opacity=0.5, showlegend=False),
        row=2,
        col=col,
    )
    fig.add_trace(
        go.Bar(
            y=attr[i],
            marker_color=np.where(attr[i] < 0, "red", "green"),
            opacity=0.5,
            showlegend=False,
            marker_line_width=0,
        ),
        secondary_y=True,
        row=2,
        col=col,
    )
    fig.update_yaxes(title=None, row=2, col=col)
    fig.update_yaxes(title_text=None, secondary_y=True, row=2, col=col)
    fig.update_xaxes(title=None, row=2, col=col)

infid_str = str(*np.round(infid_ks.numpy(), 4))
sens_str = str(*np.round(sens_ks.numpy(), 4))
fig.update_layout(
    title=format_title(
        "Perovskite 1D Model",
        f"{subtitle}{np.round(y, 2)!s} / Method: Kernel SHAP with Noise Tunnel"
        f" (\u03C3 = {stdev!s})"
        f" / Infidelity = {infid_str} (\u03C3(\u03B5) = {std_noise!s})"
        f" / Sensitivity = {sens_str}",
    ),
    legend_title="Wavelength",
    title_y=0.96,
    title_x=0.035,
    template="plotly_white",
    height=500,
    width=1200,
)

# Create the output folder if needed; write_image does not create directories.
out_dir = os.path.join("xai", "images", target, "1D")
os.makedirs(out_dir, exist_ok=True)
fig.write_image(os.path.join(out_dir, "1D_ks.png"), scale=2)

fig.show()


In [25]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Reuses `format_title` defined in the Expected Gradients plotting cell above.

subtitle = "Predicted PCE: " if target == "pce" else "Predicted Mean Thickness: "

# Input channels (rows of x) and their line colours in the overview panel.
channels = ["ND", "LP725", "LP780", "SP775"]
channel_colors = ["#042940", "#005C53", "#9FC131", "#DBF227"]

# NOTE(review): assumes attr_gbp squeezes to the same (channel, timestep) layout as x.
attr = attr_gbp.squeeze()
total_attr = attr.sum(axis=0)  # summed over channels: one value per timestep

fig = make_subplots(
    rows=2,
    cols=4,
    specs=[
        [{"secondary_y": True, "colspan": 4}, None, None, None],
        [{"secondary_y": True} for _ in channels],
    ],
    subplot_titles=(
        format_title("", "Total Attribution"),
        *(format_title("", f"Attribution {ch}") for ch in channels),
    ),
)

# Row 1: total attribution as red/green bars (secondary y) over all input channels.
fig.add_trace(
    go.Bar(
        y=total_attr,
        name="Attribution",
        marker_color=np.where(total_attr < 0, "red", "green"),
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
)
for i, (ch, ch_color) in enumerate(zip(channels, channel_colors)):
    fig.add_trace(go.Scatter(y=x[i], name=ch, marker_color=ch_color))

fig.update_yaxes(title="Intensity")
fig.update_yaxes(title_text="Attribution", secondary_y=True)
fig.update_xaxes(title="Timestep")

# Row 2: one panel per channel with its attribution over the greyed-out signal.
for i, ch in enumerate(channels):
    col = i + 1
    fig.add_trace(
        go.Scatter(y=x[i], name=ch, marker_color="grey", opacity=0.5, showlegend=False),
        row=2,
        col=col,
    )
    fig.add_trace(
        go.Bar(
            y=attr[i],
            marker_color=np.where(attr[i] < 0, "red", "green"),
            opacity=0.5,
            showlegend=False,
            marker_line_width=0,
        ),
        secondary_y=True,
        row=2,
        col=col,
    )
    fig.update_yaxes(title=None, row=2, col=col)
    fig.update_yaxes(title_text=None, secondary_y=True, row=2, col=col)
    fig.update_xaxes(title=None, row=2, col=col)

infid_str = str(*np.round(infid_gbp.numpy(), 4))
sens_str = str(*np.round(sens_gbp.numpy(), 4))
fig.update_layout(
    title=format_title(
        "Perovskite 1D Model",
        f"{subtitle}{np.round(y, 2)!s} / Method: Guided Backprop"
        f" / Infidelity = {infid_str} (\u03C3(\u03B5) = {std_noise!s})"
        f" / Sensitivity = {sens_str}",
    ),
    legend_title="Wavelength",
    title_y=0.96,
    title_x=0.035,
    template="plotly_white",
    height=500,
    width=1200,
)

# Create the output folder if needed; write_image does not create directories.
out_dir = os.path.join("xai", "images", target, "1D")
os.makedirs(out_dir, exist_ok=True)
fig.write_image(os.path.join(out_dir, "1D_gbp.png"), scale=2)

fig.show()


In [26]:
import plotly.graph_objects as go

# Reuses `format_title` defined in the Expected Gradients plotting cell above.

subtitle = "Predicted PCE: " if target == "pce" else "Predicted Mean Thickness: "


from plotly.subplots import make_subplots

# Guided GradCAM detail figure.
# Top row: one wide panel with all input channels plus the channel-summed
# attribution. Bottom row: one panel per input channel with that channel's
# attribution. Every panel has a secondary y-axis for the attribution bars.
ggc_channel_names = ["ND", "LP725", "LP780", "SP775"]
ggc_channel_colors = ["#042940", "#005C53", "#9FC131", "#DBF227"]

# Indexed as [channel] below; presumably (channels, timesteps) after squeeze — TODO confirm.
ggc_per_channel = attr_ggc.squeeze()
ggc_total = ggc_per_channel.sum(axis=0)

fig = make_subplots(
    rows=2,
    cols=4,
    specs=[
        [{"secondary_y": True, "colspan": 4}, None, None, None],
        [
            {"secondary_y": True},
            {"secondary_y": True},
            {"secondary_y": True},
            {"secondary_y": True},
        ],
    ],
    subplot_titles=(format_title("", "Total Attribution"),)
    + tuple(format_title("", "Attribution " + ch_name) for ch_name in ggc_channel_names),
)

# Top panel: summed attribution as bars (red < 0, green otherwise) on the
# secondary axis, then every raw channel as a coloured line.
fig.add_trace(
    go.Bar(
        y=ggc_total,
        name="Attribution",
        marker_color=np.where(ggc_total < 0, "red", "green"),
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
)
for ch_idx, (ch_name, ch_color) in enumerate(zip(ggc_channel_names, ggc_channel_colors)):
    fig.add_trace(go.Scatter(y=x[ch_idx], name=ch_name, marker_color=ch_color))

# Axis titles are set globally here; the bottom panels clear theirs in the loop below.
fig.update_yaxes(title="Intensity")
fig.update_yaxes(title_text="Attribution", secondary_y=True)
fig.update_xaxes(title="Timestep")

# Bottom row: panel ch_idx + 1 shows channel ch_idx as a grey line with its own
# per-timestep attribution bars.
for ch_idx, ch_name in enumerate(ggc_channel_names):
    panel_col = ch_idx + 1
    fig.add_trace(
        go.Scatter(
            y=x[ch_idx], name=ch_name, marker_color="grey", opacity=0.5, showlegend=False
        ),
        row=2,
        col=panel_col,
    )
    fig.add_trace(
        go.Bar(
            y=ggc_per_channel[ch_idx],
            marker_color=np.where(ggc_per_channel[ch_idx] < 0, "red", "green"),
            opacity=0.5,
            showlegend=False,
            marker_line_width=0,
        ),
        secondary_y=True,
        row=2,
        col=panel_col,
    )
    fig.update_yaxes(title=None, row=2, col=panel_col)
    fig.update_yaxes(title_text=None, secondary_y=True, row=2, col=panel_col)
    fig.update_xaxes(title=None, row=2, col=panel_col)

# Figure title: prediction, method name, infidelity (with the noise std used,
# shown as sigma(epsilon)) and sensitivity.
fig.update_layout(
    title=format_title(
        "Perovskite 1D Model",
        subtitle
        + str(np.round(y, 2))
        + " / Method: Guided GradCAM / Infidelity = "
        + str(*np.round(infid_ggc.numpy(), 4))
        + " ("
        + "\u03C3"
        + "("
        + "\u03B5"
        + ") = "
        + str(std_noise)
        + ")"
        + " / Sensitivity = "
        + str(*np.round(sens_ggc.numpy(), 4)),
    ),
    legend_title="Wavelength",
    title_y=0.96,
    title_x=0.035,
    template="plotly_white",
    height=500,
    width=1200,
)

# Create the export directory on first use so write_image does not fail.
ggc_out_path = Path("xai/images") / target / "1D" / "1D_ggc.png"
ggc_out_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(str(ggc_out_path), scale=2)

fig.show()


In [27]:
import plotly.graph_objects as go


def format_title(title, subtitle=None, font_size=16, subtitle_font_size=14):
    """Return a Plotly title string with a bold heading and optional subtitle.

    ``title`` is rendered bold at ``font_size`` px. A truthy ``subtitle`` is
    appended after ``<br>`` at ``subtitle_font_size`` px.
    """
    heading = f'<span style="font-size: {font_size}px;"><b>{title}</b></span>'
    if not subtitle:
        return heading
    return (
        heading
        + "<br>"
        + f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
    )

# Subtitle prefix for the comparison figure; the predicted value is appended later.
subtitle = (
    "Perovskite 1D Model / Predicted PCE: "
    if target == "pce"
    else "Perovskite 1D Model / Predicted Mean Thickness: "
)


from plotly.subplots import make_subplots

# Method comparison, signed attribution summed over input channels.
# One panel per method on a 2x3 grid (row-major). Each entry is
# (display label, attribution tensor, infidelity, sensitivity).
sum_panels = [
    ("Expected Gradients", attr_eg, infid_eg, sens_eg),
    ("Integrated Gradients", attr_ig, infid_ig, sens_ig),
    ("LIME", attr_lime, infid_lime, sens_lime),
    ("Kernel SHAP w/ Noise Tunnel", attr_ks, infid_ks, sens_ks),
    ("Guided Backprop", attr_gbp, infid_gbp, sens_gbp),
    ("Guided GradCAM", attr_ggc, infid_ggc, sens_ggc),
]
sum_input_channels = ["ND", "LP725", "LP780", "SP775"]
sum_grid_cols = 3

fig = make_subplots(
    rows=2,
    cols=sum_grid_cols,
    specs=[
        [{"secondary_y": True}, {"secondary_y": True}, {"secondary_y": True}],
        [{"secondary_y": True}, {"secondary_y": True}, {"secondary_y": True}],
    ],
    # Each subplot title reads "<method> (<infidelity>, <sensitivity>)".
    subplot_titles=tuple(
        format_title(
            "",
            panel_label
            + " ("
            + str(*np.round(panel_infid.numpy(), 4))
            + ", "
            + str(*np.round(panel_sens.numpy(), 4))
            + ")",
        )
        for panel_label, _panel_attr, panel_infid, panel_sens in sum_panels
    ),
)

for panel_idx, panel in enumerate(sum_panels):
    panel_attr = panel[1]
    panel_row, panel_col = divmod(panel_idx, sum_grid_cols)
    panel_row += 1
    panel_col += 1

    # Faint grey copies of every input channel as context.
    for ch_idx, ch_name in enumerate(sum_input_channels):
        fig.add_trace(
            go.Scatter(
                y=x[ch_idx], name=ch_name, marker_color="grey", opacity=0.3, showlegend=False
            ),
            row=panel_row,
            col=panel_col,
        )

    # Signed attribution summed over channels (red < 0, green otherwise),
    # drawn on the secondary y-axis.
    panel_summed = panel_attr.squeeze().sum(axis=0)
    fig.add_trace(
        go.Bar(
            y=panel_summed,
            marker_color=np.where(panel_summed < 0, "red", "green"),
            opacity=0.5,
            showlegend=False,
            marker_line_width=0,
        ),
        secondary_y=True,
        row=panel_row,
        col=panel_col,
    )

# Clear all axis titles, then label only the outer edges of the grid.
fig.update_yaxes(title=None)
fig.update_yaxes(title_text=None, secondary_y=True)
fig.update_xaxes(title=None)

fig.update_yaxes(title="Intensity", row=1, col=1)
fig.update_yaxes(title_text=None, secondary_y=True, row=1, col=1)
fig.update_yaxes(title_text="Attribution", secondary_y=True, row=1, col=3)
fig.update_yaxes(title="Intensity", row=2, col=1)
fig.update_yaxes(title_text=None, secondary_y=True, row=2, col=1)
fig.update_yaxes(title_text="Attribution", secondary_y=True, row=2, col=3)
fig.update_xaxes(title="Timesteps", row=2, col=2)

fig.update_layout(
    title=format_title(
        "Method Comparison: Summed Total Attribution",
        subtitle
        + str(np.round(y, 2))
        + " / (Infidelity, Sensitivity)",
    ),
    legend_title=None,
    title_y=0.96,
    # Match the title placement used by the other figures in this section.
    title_x=0.035,
    template="plotly_white",
    height=500,
    width=1200,
)

# Create the export directory on first use so write_image does not fail.
cmp_sum_out_path = Path("xai/images") / target / "1D" / "1D_cmp_sum.png"
cmp_sum_out_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(str(cmp_sum_out_path), scale=2)

fig.show()


In [28]:
import plotly.graph_objects as go


def format_title(title, subtitle=None, font_size=16, subtitle_font_size=12):
    """Compose an HTML title for Plotly: bold heading plus optional subtitle.

    ``title`` goes in a bold span of ``font_size`` px. When ``subtitle`` is
    truthy it is placed on a new line (``<br>``) in a span of
    ``subtitle_font_size`` px; this variant defaults that size to 12.
    """
    parts = [f'<span style="font-size: {font_size}px;"><b>{title}</b></span>']
    if subtitle:
        parts.append(
            f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
        )
    return "<br>".join(parts)

# Subtitle prefix for the comparison figure; the predicted value is appended later.
subtitle = (
    "Perovskite 1D Model / Predicted PCE: "
    if target == "pce"
    else "Perovskite 1D Model / Predicted Mean Thickness: "
)


from plotly.subplots import make_subplots

# Method comparison, absolute attribution summed over input channels.
# One panel per method on a 2x3 grid (row-major). Each entry is
# (display label, attribution tensor, infidelity, sensitivity, bar colour).
# NOTE: LIME and Kernel SHAP share "#D6D58E" (unchanged from the original figure).
abs_panels = [
    ("Expected Gradients", attr_eg, infid_eg, sens_eg, "#042940"),
    ("Integrated Gradients", attr_ig, infid_ig, sens_ig, "#005C53"),
    ("LIME", attr_lime, infid_lime, sens_lime, "#D6D58E"),
    ("Kernel SHAP w/ Noise Tunnel", attr_ks, infid_ks, sens_ks, "#D6D58E"),
    ("Guided Backprop", attr_gbp, infid_gbp, sens_gbp, "#9FC131"),
    ("Guided GradCAM", attr_ggc, infid_ggc, sens_ggc, "#DBF227"),
]
abs_input_channels = ["ND", "LP725", "LP780", "SP775"]
abs_grid_cols = 3

fig = make_subplots(
    rows=2,
    cols=abs_grid_cols,
    specs=[
        [{"secondary_y": True}, {"secondary_y": True}, {"secondary_y": True}],
        [{"secondary_y": True}, {"secondary_y": True}, {"secondary_y": True}],
    ],
    # Each subplot title reads "<method> (<infidelity>, <sensitivity>)".
    subplot_titles=tuple(
        format_title(
            "",
            panel[0]
            + " ("
            + str(*np.round(panel[2].numpy(), 4))
            + ", "
            + str(*np.round(panel[3].numpy(), 4))
            + ")",
        )
        for panel in abs_panels
    ),
)

for panel_idx, panel in enumerate(abs_panels):
    panel_attr, panel_color = panel[1], panel[4]
    panel_row, panel_col = divmod(panel_idx, abs_grid_cols)
    panel_row += 1
    panel_col += 1

    # Faint grey copies of every input channel as context.
    for ch_idx, ch_name in enumerate(abs_input_channels):
        fig.add_trace(
            go.Scatter(
                y=x[ch_idx], name=ch_name, marker_color="grey", opacity=0.3, showlegend=False
            ),
            row=panel_row,
            col=panel_col,
        )

    # Magnitude of attribution summed over channels, in the method's colour,
    # drawn on the secondary y-axis.
    fig.add_trace(
        go.Bar(
            y=panel_attr.squeeze().abs().sum(axis=0),
            marker_color=panel_color,
            opacity=0.5,
            showlegend=False,
            marker_line_width=0,
        ),
        secondary_y=True,
        row=panel_row,
        col=panel_col,
    )

# Clear all axis titles, then label only the outer edges of the grid.
fig.update_yaxes(title=None)
fig.update_yaxes(title_text=None, secondary_y=True)
fig.update_xaxes(title=None)

fig.update_yaxes(title="Intensity", row=1, col=1)
fig.update_yaxes(title_text=None, secondary_y=True, row=1, col=1)
fig.update_yaxes(title_text="Attribution", secondary_y=True, row=1, col=3)
fig.update_yaxes(title="Intensity", row=2, col=1)
fig.update_yaxes(title_text=None, secondary_y=True, row=2, col=1)
fig.update_yaxes(title_text="Attribution", secondary_y=True, row=2, col=3)
fig.update_xaxes(title="Timesteps", row=2, col=2)

fig.update_layout(
    title=format_title(
        "Method Comparison: Absolute Total Attribution",
        subtitle
        + str(np.round(y, 2))
        + " / (Infidelity, Sensitivity)",
    ),
    legend_title=None,
    title_y=0.96,
    title_x=0.035,
    template="plotly_white",
    height=500,
    width=1200,
)

# Create the export directory on first use so write_image does not fail.
cmp_abs_out_path = Path("xai/images") / target / "1D" / "1D_cmp_abs.png"
cmp_abs_out_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(str(cmp_abs_out_path), scale=2)

fig.show()


: 