In [2]:
import os

# Move from the notebook's directory to its parent so the relative paths used
# later ("./xai/results/...", "xai/images/...") resolve.
# The starting directory is remembered the first time the cell runs, so
# re-running the cell is idempotent instead of climbing one more level per run.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))
os.getcwd()


'/home/l727n/Projects/Applied Projects/ml_perovskite'

In [3]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

# Absolute input locations on the author's machine.
data_dir = "/home/l727n/Projects/Applied Projects/ml_perovskite/preprocessed"
checkpoint_dir = (
    "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/"
    "KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints"
)

# Checkpoint file that ResNet.load_from_checkpoint restores further down.
path_to_checkpoint = join(
    checkpoint_dir,
    "1D-epoch=999-val_MAE=0.000-train_MAE=0.490.ckpt",
)



IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html



In [29]:
# Pre-computed evaluation results for the 1D model. Members arr_0 ... arr_4 are
# indexed per attribution method (axis 0) by the plotting cell below.
d = np.load("./xai/results/eval_1D_results.npz")

# x axis of the Sensitivity-N curve: int(log_n_max / log_n_ticks) values spaced
# log-uniformly from 10**0 to 10**log_n_max, truncated to integers.
# NOTE(review): this grid must match the one used to produce eval_1D_results.npz.
log_n_max = 2.7
log_n_ticks = 0.4
n_list = np.logspace(
    0,
    log_n_max,
    num=int(log_n_max / log_n_ticks),
    base=10.0,
    dtype=int,
)

In [101]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def format_title(title, subtitle=None, subtitle_font_size=14):
    """Build an HTML plot title.

    Returns ``title`` in bold. If ``subtitle`` is truthy, it is appended on a
    second line inside a ``<span>`` with font size ``subtitle_font_size`` px.
    """
    title = f"<b>{title}</b>"
    if not subtitle:
        return title
    subtitle = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
    return f"{title}<br>{subtitle}"


def _mean_and_ci95(samples):
    """Return the mean over axis 0 and the half-width of its 95% confidence interval.

    Uses the normal approximation 1.96 * std / sqrt(n), with n = number of rows.
    The standard error of a mean scales with sqrt(n), not n.
    """
    mean = samples.mean(0)
    half_width = 1.960 * np.std(samples, axis=0) / np.sqrt(samples.shape[0])
    return mean, half_width


# (label, marker colour, CI fill colour) per attribution method. The list position
# is the index into axis 0 of every result array stored in ``d``.
eval_methods = [
    ("EG", "#042940", "rgba(4,41,64,0.2)"),
    ("IG", "#005C53", "rgba(0,92,83,0.2)"),
    ("GBC", "#9FC131", "rgba(159,193,49,0.2)"),
    ("GGC", "#DBF227", "rgba(219,242,39,0.2)"),
]

fig = make_subplots(
    rows=1,
    cols=5,
    subplot_titles=(
        format_title("Importance", "Sensitivity-N |" + "\u2197" + "|"),
        format_title("", "Insertion " + "\u2197"),
        format_title("", "Deletion " + "\u2197"),
        format_title("Robustness", "Infidelity " + "\u2198"),
        format_title("", "Sensitivity " + "\u2198"),
    ),
)

# Panel 1: Sensitivity-N curve (mean over samples) with a shaded 95% CI band.
# NpzFile members are loaded lazily on every lookup, so read arr_0 only once.
sens_n = d["arr_0"]
for method_idx, (label, color, fill_color) in enumerate(eval_methods):
    mean, half_width = _mean_and_ci95(sens_n[method_idx])
    fig.add_trace(
        go.Scatter(
            y=mean,
            x=n_list,
            name=label,
            marker_color=color,
            showlegend=False,
            mode="lines+markers",
        ),
        row=1,
        col=1,
    )
    # An invisible upper bound, then the lower bound filled "tonexty" (towards
    # the previously added trace), shades the band between the two.
    fig.add_traces(
        [
            go.Scatter(
                x=n_list,
                y=mean + half_width,
                mode="lines",
                line_color="rgba(0,0,0,0)",
                showlegend=False,
            ),
            go.Scatter(
                x=n_list,
                y=mean - half_width,
                mode="lines",
                line_color="rgba(0,0,0,0)",
                name="95% confidence interval",
                showlegend=False,
                fill="tonexty",
                fillcolor=fill_color,
            ),
        ],
        rows=[1, 1],
        cols=[1, 1],
    )

# Panels 2-5: per-sample distributions (Insertion, Deletion, Infidelity,
# Sensitivity) as box plots. Only the last panel contributes legend entries, so
# every method appears in the legend exactly once.
box_panels = [
    # (npz member, subplot column, squeeze singleton dimensions?)
    ("arr_1", 2, False),
    ("arr_2", 3, False),
    ("arr_3", 4, True),
    ("arr_4", 5, True),
]
for key, panel_col, squeeze in box_panels:
    panel_values = d[key]
    for method_idx, (label, color, _) in enumerate(eval_methods):
        y_vals = panel_values[method_idx]
        if squeeze:
            y_vals = y_vals.squeeze()
        box = go.Box(y=y_vals, name=label, marker_color=color)
        if panel_col != 5:
            box.showlegend = False
        fig.add_trace(box, row=1, col=panel_col)

fig.update_xaxes(title="N", row=1, col=1)
fig.update_xaxes(title="Methods", row=1, col=2)

fig.update_yaxes(title="Pearson Correlation", row=1, col=1)
fig.update_yaxes(title="Area Between Curves", rangemode="tozero", row=1, col=2)
fig.update_yaxes(rangemode="tozero", row=1, col=3)
fig.update_yaxes(title="Score", rangemode="tozero", row=1, col=4)
fig.update_yaxes(rangemode="tozero", row=1, col=5)

fig.update_layout(
    title=format_title(
        "Attribution Evaluation",
        "Perovskite 1D Model (n = " + str(sens_n[3].shape[0]) + ")",
    ),
    legend_title=None,
    legend={"traceorder": "normal"},
    title_y=0.96,
    title_x=0.035,
    template="plotly_white",
    height=400,
    width=1800,
)

fig.write_image("xai/images/1D/1D_evaluation.png", scale=2)

fig.show()

# Import of the model and computation of four attribution methods (EG, IG, Guided Backprop, Guided GradCAM) with two evaluation metrics (infidelity, sensitivity) per method

In [3]:
#### 1D Model
# NOTE(review): this cell was headed "no border", but both the hyperparameters and
# the dataset below use no_border=False; confirm which variant the checkpoint uses.

hypparams = {
    "dataset": "Perov_1d",
    "dims": 1,
    "bottleneck": False,
    "name": "ResNet152",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
}

model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    block=BasicBlock,
    num_blocks=[4, 13, 55, 4],
    num_classes=1,
    hypparams=hypparams,
)

print("Loaded")
model.eval()

# Inputs are normalised with the model's train_mean / train_std and scaled with its scaler.
dataset = PerovskiteDataset1d(
    data_dir,
    transform=normalize(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
)

batch_size = 100

# shuffle=True means the batch drawn in the next cell is random. A seeded
# generator makes that choice (and the worker seeds derived from it)
# reproducible across kernel restarts.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
    generator=torch.Generator().manual_seed(0),
)


Load data
tensor([0.2697, 0.0191, 0.0057, 0.0216]) tensor([0.1589, 0.0106, 0.0030, 0.0145])
Loaded


In [4]:
# Draw one shuffled batch and record the model's predictions for it.
raw_batch = next(iter(loader))

with torch.no_grad():
    y_batch = model.predict(raw_batch).flatten()

# Keep element 0 of the batch for the attribution methods below.
# Presumably the input tensor; TODO confirm against PerovskiteDataset1d.
x_batch = raw_batch[0]


In [5]:
# Initialize the perturbation function for the infidelity metric
from tqdm import tqdm

std_noise = 0.01


def perturb_fn(inputs):
    """Gaussian perturbation function for ``captum.metrics.infidelity``.

    Draws i.i.d. noise with mean 0 and standard deviation ``std_noise`` from
    numpy's global RNG. Returns the pair ``(noise, inputs - noise)``: the
    perturbation and the perturbed input.

    The noise is placed on ``inputs.device``. Floating-point inputs keep their
    dtype; other inputs get float32 noise. This makes the function work for GPU
    tensors and avoids silently downcasting float64 inputs.
    """
    noise_dtype = inputs.dtype if inputs.is_floating_point() else torch.float32
    noise = torch.as_tensor(
        np.random.normal(0, std_noise, inputs.shape),
        dtype=noise_dtype,
        device=inputs.device,
    )
    return noise, inputs - noise


In [6]:
from captum.attr import GradientShap
from captum.metrics import sensitivity_max, infidelity

# Expected-gradients style attribution via GradientShap, with the whole batch as
# the baseline distribution. For each sample we store |attribution| plus the
# infidelity and max-sensitivity of that attribution.
method = GradientShap(model)
attr_sum, infid_sum, sens_sum = [], [], []

for sample in tqdm(x_batch):
    single = sample.unsqueeze(0)  # restore the leading batch dimension
    attr = method.attribute(
        single,
        n_samples=80,
        stdevs=0.001,
        baselines=x_batch,
        target=0,
    )
    attr_sum.append(attr.abs())
    infid_sum.append(infidelity(model, perturb_fn, single, attr))
    sens_sum.append(
        sensitivity_max(method.attribute, single, target=0, baselines=x_batch)
    )

# Batch-mean |attribution| and batch-mean infidelity / sensitivity scores.
attr_eg = torch.cat(attr_sum).mean(dim=0)
infid_eg = torch.Tensor(infid_sum).mean()
sens_eg = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [11:47<00:00,  7.08s/it]


In [7]:
from captum.attr import IntegratedGradients

# Integrated Gradients against an all-zero baseline. For each sample we store
# |attribution| plus the infidelity and max-sensitivity of that attribution.
method = IntegratedGradients(model)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    x_n = x_batch[n].unsqueeze(0)
    baseline = torch.zeros_like(x_n)

    # target=0 is explicit, as in the sensitivity call below and the EG cell.
    # The convergence delta is not requested: it was unused and only adds
    # computation.
    attr = method.attribute(x_n, baselines=baseline, target=0)

    attr_sum.append(attr.abs())

    infid_sum.append(infidelity(model, perturb_fn, x_n, attr))
    sens_sum.append(
        sensitivity_max(
            method.attribute,
            x_n,
            target=0,
            baselines=baseline,
        )
    )

# Batch-mean |attribution| and batch-mean infidelity / sensitivity scores.
attr_ig = torch.cat(attr_sum).mean(dim=0)
infid_ig = torch.Tensor(infid_sum).mean()
sens_ig = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [1:04:28<00:00, 38.69s/it]


In [8]:
from captum.attr import GuidedBackprop

# Guided Backpropagation. For each sample we store |attribution| plus the
# infidelity and max-sensitivity of that attribution.
method = GuidedBackprop(model)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    x_n = x_batch[n].unsqueeze(0)
    attr = method.attribute(x_n, target=0)

    # Defensive detach so stored attributions never hold an autograd graph.
    attr_sum.append(attr.detach().abs())

    infid_sum.append(infidelity(model, perturb_fn, x_n, attr))
    # Same explicit target as the attribution above and the EG / IG cells.
    sens_sum.append(sensitivity_max(method.attribute, x_n, target=0))

# Batch-mean |attribution| and batch-mean infidelity / sensitivity scores.
attr_gbp = torch.cat(attr_sum).mean(dim=0)
infid_gbp = torch.Tensor(infid_sum).mean()
sens_gbp = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [01:24<00:00,  1.18it/s]


In [9]:
from captum.attr import GuidedGradCam

# Guided GradCAM with the CAM computed at model.conv1.
# NOTE(review): GradCAM is more commonly taken at the last convolutional layer;
# confirm that using conv1 here is intended.
method = GuidedGradCam(model, model.conv1)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    x_n = x_batch[n].unsqueeze(0)
    attr = method.attribute(x_n, target=0)

    attr_sum.append(attr.detach().abs())

    infid_sum.append(infidelity(model, perturb_fn, x_n, attr))
    # Same explicit target as the attribution above and the EG / IG cells.
    sens_sum.append(sensitivity_max(method.attribute, x_n, target=0))

# Batch-mean |attribution| and batch-mean infidelity / sensitivity scores.
attr_ggc = torch.cat(attr_sum).mean(dim=0)
infid_ggc = torch.Tensor(infid_sum).mean()
sens_ggc = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [02:38<00:00,  1.58s/it]


## Visualization of global abs. attribution

In [10]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def format_title(title, subtitle=None, subtitle_font_size=14):
    """Return ``title`` in bold, optionally followed by a smaller ``subtitle`` line."""
    bold = f"<b>{title}</b>"
    if not subtitle:
        return bold
    small = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
    return f"{bold}<br>{small}"


# Batch-averaged input; x[c] is drawn as a grey reference curve for channel c.
x = x_batch.mean(0)

channel_names = ["ND", "LP725", "LP780", "SP775"]

# (subplot label, mean |attribution|, infidelity, sensitivity, bar colour, (row, col))
global_panels = [
    ("Expected Gradients", attr_eg, infid_eg, sens_eg, "#042940", (1, 1)),
    ("Integrated Gradients", attr_ig, infid_ig, sens_ig, "#005C53", (1, 2)),
    ("Guided Backprop", attr_gbp, infid_gbp, sens_gbp, "#9FC131", (2, 1)),
    ("Guided GradCAM", attr_ggc, infid_ggc, sens_ggc, "#DBF227", (2, 2)),
]

panel_titles = []
for label, _, infid, sens, _, _ in global_panels:
    scores = (
        " ("
        + str(np.round(infid.numpy(), 4))
        + ", "
        + str(np.round(sens.numpy(), 4))
        + ")"
    )
    panel_titles.append(format_title("", label + scores))

fig = make_subplots(
    rows=2,
    cols=2,
    specs=[[{"secondary_y": True}, {"secondary_y": True}] for _ in range(2)],
    subplot_titles=tuple(panel_titles),
)

for label, attribution, _, _, bar_color, (row, col) in global_panels:
    # Grey reference: batch-mean input of every channel on the primary y axis.
    for channel, channel_name in enumerate(channel_names):
        fig.add_trace(
            go.Scatter(
                y=x[channel],
                name=channel_name,
                marker_color="grey",
                opacity=0.3,
                showlegend=False,
            ),
            row=row,
            col=col,
        )
    # Squeezed mean |attribution| summed over axis 0 (assumed to be the channel
    # axis; TODO confirm), drawn as bars on the secondary y axis.
    fig.add_trace(
        go.Bar(
            y=attribution.squeeze().sum(axis=0),
            marker_color=bar_color,
            opacity=0.5,
            showlegend=False,
            marker_line_width=0,
        ),
        secondary_y=True,
        row=row,
        col=col,
    )

fig.update_yaxes(title=None)
fig.update_yaxes(title_text=None, secondary_y=True)
fig.update_xaxes(title=None)

# Without secondary_y, update_yaxes hits both y axes of a subplot, so the
# secondary title of column 1 is cleared again right afterwards.
for row in (1, 2):
    fig.update_yaxes(title="Intensity", row=row, col=1)
    fig.update_yaxes(title_text=None, secondary_y=True, row=row, col=1)
    fig.update_yaxes(title_text="Attribution", secondary_y=True, row=row, col=2)
for col in (1, 2):
    fig.update_xaxes(title="Timesteps", row=2, col=col)

fig.update_layout(
    title=format_title(
        "Global Attribution: Mean abs. Attribution (n = " + str(batch_size) + ")",
        "Perovskite 1D Model / (mean Infidelity, mean Sensitivity)",
    ),
    legend_title=None,
    title_y=0.965,
    title_x=0.035,
    template="plotly_white",
    height=800,
    width=2000,
)

fig.write_image("xai/images/1D/1D_cmp_global.png", scale=2)

fig.show()


In [12]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def format_title(title, subtitle=None, subtitle_font_size=14):
    """Build an HTML plot title.

    Returns ``title`` in bold. If ``subtitle`` is truthy, it is appended on a
    second line inside a ``<span>`` with font size ``subtitle_font_size`` px.
    """
    title = f"<b>{title}</b>"
    if not subtitle:
        return title
    subtitle = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
    return f"{title}<br>{subtitle}"


def _method_heading(label, infid, sens):
    """Return ``"<label> (<infidelity>, <sensitivity>)"`` with scores rounded to 4 decimals."""
    # str() of the rounded numpy value keeps the short repr (an f-string would
    # format a float32 scalar via Python float and print spurious digits).
    return (
        label
        + " ("
        + str(np.round(infid.numpy(), 4))
        + ", "
        + str(np.round(sens.numpy(), 4))
        + ")"
    )


# Batch-averaged input; x[c] is drawn as a grey reference curve for channel c.
x = x_batch.mean(0)

channel_names = ["ND", "LP725", "LP780", "SP775"]

# One subplot row per method: (label, mean |attribution|, infidelity, sensitivity, bar colour).
wavelength_rows = [
    ("Expected Gradients", attr_eg, infid_eg, sens_eg, "#042940"),
    ("Integrated Gradients", attr_ig, infid_ig, sens_ig, "#005C53"),
    ("Guided Backprop", attr_gbp, infid_gbp, sens_gbp, "#9FC131"),
    ("Guided GradCAM", attr_ggc, infid_ggc, sens_ggc, "#DBF227"),
]

# Subplot titles are row-major. The first row carries the channel names as
# subtitles; later rows only label their first column with the method heading.
subplot_titles = []
for row_idx, (label, _, infid, sens, _) in enumerate(wavelength_rows):
    heading = _method_heading(label, infid, sens)
    if row_idx == 0:
        subplot_titles.append(format_title(heading, channel_names[0]))
        subplot_titles.extend(format_title("", name) for name in channel_names[1:])
    else:
        subplot_titles.append(format_title(heading, " "))
        subplot_titles.extend([None] * (len(channel_names) - 1))

fig = make_subplots(
    rows=len(wavelength_rows),
    cols=len(channel_names),
    specs=[
        [{"secondary_y": True} for _ in channel_names] for _ in wavelength_rows
    ],
    subplot_titles=tuple(subplot_titles),
)

for row_idx, (_, attribution, _, _, bar_color) in enumerate(wavelength_rows, start=1):
    per_channel = attribution.squeeze()
    for channel, channel_name in enumerate(channel_names):
        col_idx = channel + 1
        # Grey reference: batch-mean input of this channel (primary y axis).
        fig.add_trace(
            go.Scatter(
                y=x[channel],
                name=channel_name,
                marker_color="grey",
                opacity=0.3,
                showlegend=False,
            ),
            row=row_idx,
            col=col_idx,
        )
        # Mean |attribution| of this channel as bars (secondary y axis).
        fig.add_trace(
            go.Bar(
                y=per_channel[channel],
                marker_color=bar_color,
                opacity=0.5,
                showlegend=False,
                marker_line_width=0,
            ),
            secondary_y=True,
            row=row_idx,
            col=col_idx,
        )

fig.update_yaxes(title=None)
fig.update_yaxes(title_text=None, secondary_y=True)
fig.update_xaxes(title=None)

# Without secondary_y, update_yaxes hits both y axes of the subplot, so the
# secondary title of (1, 1) is cleared again right afterwards.
fig.update_yaxes(title="Intensity", row=1, col=1)
fig.update_yaxes(title_text=None, secondary_y=True, row=1, col=1)
fig.update_yaxes(title_text="Attribution", secondary_y=True, row=1, col=4)
fig.update_xaxes(title="Timesteps", row=4, col=1)

fig.update_layout(
    title=format_title(
        "Global Attribution: Mean abs. Attribution (n = "
        + str(batch_size)
        + ") per Wavelength",
        "Perovskite 1D Model / (mean Infidelity, mean Sensitivity)",
    ),
    legend_title=None,
    title_y=0.97,
    title_x=0.5,
    template="plotly_white",
    height=1000,
    width=2400,
)

fig.write_image("xai/images/1D/1D_cmp_global_wl.png", scale=2)

fig.show()
