In [1]:
import os
from pathlib import Path

os.chdir(Path(os.getcwd()).parents[2])
os.getcwd()

target = "mth" # mth, pce

In [2]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

data_dir = "/home/l727n/Projects/Applied Projects/ml_perovskite/preprocessed"

if target == "pce":
    checkpoint_dir = "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints"

    path_to_checkpoint = join(
        checkpoint_dir, "1D-epoch=999-val_MAE=0.000-train_MAE=0.490.ckpt"
    )
else:
    checkpoint_dir = "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/mT_checkpoints/checkpoints"

    path_to_checkpoint = join(
        checkpoint_dir, "mT_1D_RN152_full-epoch=999-val_MAE=0.000-train_MAE=40.332.ckpt"
    )


  from .autonotebook import tqdm as notebook_tqdm


# Import of model and computation of six different attribution methods with two evaluation metrics per method

In [3]:
#### 1D Model (no border)

hypparams = {
    "dataset": "Perov_1d",
    "dims": 1,
    "bottleneck": False,
    "name": "ResNet152",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "norm_target": True if target == "pce" else False,
    "target": "PCE_mean" if target == "pce" else "meanThickness"
}

model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    block=BasicBlock,
    num_blocks=[4, 13, 55, 4],
    num_classes=1,
    hypparams=hypparams,
)

print("Loaded")
model.eval()

dataset = PerovskiteDataset1d(
    data_dir,
    transform=normalize(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled= False if target == "pce" else True,
    label="PCE_mean" if target == "pce" else "meanThickness",
)

batch_size = 100

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
)


tensor([0.2697, 0.0191, 0.0057, 0.0216]) tensor([0.1589, 0.0106, 0.0030, 0.0145])
Loaded


In [4]:
# Select batch
x_batch = next(iter(loader))

with torch.no_grad():
    y_batch = model.predict(x_batch).flatten()

x_batch = x_batch[0]


In [5]:
# Init pertubation function for infidelity metric
from tqdm import tqdm

std_noise = 0.01


def perturb_fn(inputs):
    noise = torch.tensor(np.random.normal(0, std_noise, inputs.shape)).float()
    return noise, inputs - noise


In [6]:
from captum.attr import GradientShap
from captum.metrics import sensitivity_max, infidelity

method = GradientShap(model)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    attr = method.attribute(
        x_batch[n].unsqueeze(0),
        n_samples=80,
        stdevs=0.001,
        baselines=x_batch,
        target=0,
    )

    attr_sum.append(attr.abs())

    infid_sum.append(infidelity(model, perturb_fn, x_batch[n].unsqueeze(0), attr))
    sens_sum.append(
        sensitivity_max(
            method.attribute, x_batch[n].unsqueeze(0), target=0, baselines=x_batch
        )
    )

attr_eg = torch.cat(attr_sum).mean(dim=0)
infid_eg = torch.Tensor(infid_sum).mean()
sens_eg = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [10:46<00:00,  6.47s/it]


In [7]:
from captum.attr import IntegratedGradients

method = IntegratedGradients(model)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    attr, delta = method.attribute(
        x_batch[n].unsqueeze(0),
        baselines=x_batch[n].unsqueeze(0) * 0,
        return_convergence_delta=True,
    )

    attr_sum.append(attr.abs())

    infid_sum.append(infidelity(model, perturb_fn, x_batch[n].unsqueeze(0), attr))
    sens_sum.append(
        sensitivity_max(
            method.attribute,
            x_batch[n].unsqueeze(0),
            target=0,
            baselines=x_batch[n].unsqueeze(0) * 0,
        )
    )

attr_ig = torch.cat(attr_sum).mean(dim=0)
infid_ig = torch.Tensor(infid_sum).mean()
sens_ig = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [1:07:26<00:00, 40.46s/it]


In [8]:
from captum.attr import GuidedBackprop

method = GuidedBackprop(model)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    attr = method.attribute(x_batch[n].unsqueeze(0), target=0)

    attr_sum.append(attr.abs())

    infid_sum.append(infidelity(model, perturb_fn, x_batch[n].unsqueeze(0), attr))
    sens_sum.append(sensitivity_max(method.attribute, x_batch[n].unsqueeze(0)))

attr_gbp = torch.cat(attr_sum).mean(dim=0)
infid_gbp = torch.Tensor(infid_sum).mean()
sens_gbp = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [01:33<00:00,  1.07it/s]


In [9]:
from captum.attr import GuidedGradCam

method = GuidedGradCam(model, model.conv1)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    attr = method.attribute(x_batch[n].unsqueeze(0), target=0)

    attr_sum.append(attr.detach().abs())

    infid_sum.append(infidelity(model, perturb_fn, x_batch[n].unsqueeze(0), attr))
    sens_sum.append(sensitivity_max(method.attribute, x_batch[n].unsqueeze(0)))

attr_ggc = torch.cat(attr_sum).mean(dim=0)
infid_ggc = torch.Tensor(infid_sum).mean()
sens_ggc = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [02:48<00:00,  1.69s/it]


## Visualization of global abs. attribution

In [10]:
import plotly.graph_objects as go


def format_title(title, subtitle=None, subtitle_font_size=14):
    title = f"<b>{title}</b>"
    if not subtitle:
        return title
    subtitle = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
    return f"{title}<br>{subtitle}"


from plotly.subplots import make_subplots

x = x_batch.mean(0)

fig = make_subplots(
    rows=2,
    cols=2,
    specs=[
        [{"secondary_y": True}, {"secondary_y": True}],
        [{"secondary_y": True}, {"secondary_y": True}],
    ],
    subplot_titles=(
        format_title(
            "",
            "Expected Gradients ("
            + str(np.round(infid_eg.numpy(), 4))
            + ", "
            + str(np.round(sens_eg.numpy(), 4))
            + ")",
        ),
        format_title(
            "",
            "Integrated Gradients ("
            + str(np.round(infid_ig.numpy(), 4))
            + ", "
            + str(np.round(sens_ig.numpy(), 4))
            + ")",
        ),
        format_title(
            "",
            "Guided Backprop ("
            + str(np.round(infid_gbp.numpy(), 4))
            + ", "
            + str(np.round(sens_gbp.numpy(), 4))
            + ")",
        ),
        format_title(
            "",
            "Guided GradCAM ("
            + str(np.round(infid_ggc.numpy(), 4))
            + ", "
            + str(np.round(sens_ggc.numpy(), 4))
            + ")",
        ),
    ),
)


fig.add_trace(
    go.Scatter(y=x[0], name="ND", marker_color="grey", opacity=0.3, showlegend=False),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        y=x[1], name="LP725", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        y=x[2], name="LP780", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        y=x[3], name="SP775", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Bar(
        y=attr_eg.squeeze().sum(axis=0),
        marker_color="#042940",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=1,
    col=1,
)

fig.add_trace(
    go.Scatter(y=x[0], name="ND", marker_color="grey", opacity=0.3, showlegend=False),
    row=1,
    col=2,
)
fig.add_trace(
    go.Scatter(
        y=x[1], name="LP725", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Scatter(
        y=x[2], name="LP780", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Scatter(
        y=x[3], name="SP775", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Bar(
        y=attr_ig.squeeze().sum(axis=0),
        marker_color="#005C53",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=1,
    col=2,
)

fig.add_trace(
    go.Scatter(y=x[0], name="ND", marker_color="grey", opacity=0.3, showlegend=False),
    row=2,
    col=1,
)
fig.add_trace(
    go.Scatter(
        y=x[1], name="LP725", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=2,
    col=1,
)
fig.add_trace(
    go.Scatter(
        y=x[2], name="LP780", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=2,
    col=1,
)
fig.add_trace(
    go.Scatter(
        y=x[3], name="SP775", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=2,
    col=1,
)
fig.add_trace(
    go.Bar(
        y=attr_gbp.squeeze().sum(axis=0),
        marker_color="#9FC131",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=2,
    col=1,
)

fig.add_trace(
    go.Scatter(y=x[0], name="ND", marker_color="grey", opacity=0.3, showlegend=False),
    row=2,
    col=2,
)
fig.add_trace(
    go.Scatter(
        y=x[1], name="LP725", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=2,
    col=2,
)
fig.add_trace(
    go.Scatter(
        y=x[2], name="LP780", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=2,
    col=2,
)
fig.add_trace(
    go.Scatter(
        y=x[3], name="SP775", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=2,
    col=2,
)
fig.add_trace(
    go.Bar(
        y=attr_ggc.squeeze().sum(axis=0),
        marker_color="#DBF227",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=2,
    col=2,
)

fig.update_yaxes(title=None)
fig.update_yaxes(title_text=None, secondary_y=True)
fig.update_xaxes(title=None)

fig.update_yaxes(title="Intensity", row=1, col=1)
fig.update_yaxes(title_text=None, secondary_y=True, row=1, col=1)
fig.update_yaxes(title_text="Attribution", secondary_y=True, row=1, col=2)
fig.update_yaxes(title="Intensity", row=2, col=1)
fig.update_yaxes(title_text=None, secondary_y=True, row=2, col=1)
fig.update_yaxes(title_text="Attribution", secondary_y=True, row=2, col=2)
fig.update_xaxes(title="Timesteps", row=2, col=1)
fig.update_xaxes(title="Timesteps", row=2, col=2)

if target == "pce":
    subtitle = "Perovskite 1D Model / Target: PCE /(mean Infidelity, mean Sensitivity)"
else:
    subtitle = "Perovskite 1D Model / Target: Mean Thickness / (mean Infidelity, mean Sensitivity)"

fig.update_layout(
    title=format_title(
        "Global Attribution: Mean abs. Attribution (n = " + str(batch_size) + ")",
        subtitle=subtitle,
    ),
    legend_title=None,
    title_y=0.965,
    title_x=0.035,
    template="plotly_white",
    height=500,
    width=1000,
)

fig.write_image("xai/images/" + target + "/1D/1D_cmp_global.png", scale=2)

fig.show()


In [11]:
import plotly.graph_objects as go


def format_title(title, subtitle=None, font_size=12, subtitle_font_size=10):
    title = f'<span style="font-size: {font_size}px;"><b>{title}</b></span>'
    if not subtitle:
        return title
    subtitle = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
    return f"{title}<br>{subtitle}"


from plotly.subplots import make_subplots

x = x_batch.mean(0)

fig = make_subplots(
    rows=4,
    cols=4,
    specs=[
        [
            {"secondary_y": True},
            {"secondary_y": True},
            {"secondary_y": True},
            {"secondary_y": True},
        ],
        [
            {"secondary_y": True},
            {"secondary_y": True},
            {"secondary_y": True},
            {"secondary_y": True},
        ],
        [
            {"secondary_y": True},
            {"secondary_y": True},
            {"secondary_y": True},
            {"secondary_y": True},
        ],
        [
            {"secondary_y": True},
            {"secondary_y": True},
            {"secondary_y": True},
            {"secondary_y": True},
        ],
    ],
    subplot_titles=(
        format_title(
            "Expected Gradients ("
            + str(np.round(infid_eg.numpy(), 4))
            + ", "
            + str(np.round(sens_eg.numpy(), 4))
            + ")",
            "ND",
        ),
        format_title("", "LP725"),
        format_title("", "LP780"),
        format_title("", "SP775"),
        format_title(
            "Integrated Gradients ("
            + str(np.round(infid_ig.numpy(), 4))
            + ", "
            + str(np.round(sens_ig.numpy(), 4))
            + ")",
            " ",
        ),
        None,
        None,
        None,
        format_title(
            "Guided Backprob ("
            + str(np.round(infid_gbp.numpy(), 4))
            + ", "
            + str(np.round(sens_gbp.numpy(), 4))
            + ")",
            " ",
        ),
        None,
        None,
        None,
        format_title(
            "Guided GradCAM ("
            + str(np.round(infid_ggc.numpy(), 4))
            + ", "
            + str(np.round(sens_ggc.numpy(), 4))
            + ")",
            " ",
        ),
        None,
        None,
        None,
    ),
)


fig.add_trace(
    go.Scatter(y=x[0], name="ND", marker_color="grey", opacity=0.3, showlegend=False),
    row=1,
    col=1,
)
fig.add_trace(
    go.Bar(
        y=attr_eg.squeeze()[0],
        marker_color="#042940",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=1,
    col=1,
)

fig.add_trace(
    go.Scatter(
        y=x[1], name="LP725", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Bar(
        y=attr_eg.squeeze()[1],
        marker_color="#042940",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=1,
    col=2,
)

fig.add_trace(
    go.Scatter(
        y=x[2], name="LP780", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=1,
    col=3,
)
fig.add_trace(
    go.Bar(
        y=attr_eg.squeeze()[2],
        marker_color="#042940",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=1,
    col=3,
)

fig.add_trace(
    go.Scatter(
        y=x[3], name="SP775", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=1,
    col=4,
)
fig.add_trace(
    go.Bar(
        y=attr_eg.squeeze()[3],
        marker_color="#042940",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=1,
    col=4,
)


fig.add_trace(
    go.Scatter(y=x[0], name="ND", marker_color="grey", opacity=0.3, showlegend=False),
    row=2,
    col=1,
)
fig.add_trace(
    go.Bar(
        y=attr_ig.squeeze()[0],
        marker_color="#005C53",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=2,
    col=1,
)

fig.add_trace(
    go.Scatter(
        y=x[1], name="LP725", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=2,
    col=2,
)
fig.add_trace(
    go.Bar(
        y=attr_ig.squeeze()[1],
        marker_color="#005C53",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=2,
    col=2,
)

fig.add_trace(
    go.Scatter(
        y=x[2], name="LP780", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=2,
    col=3,
)
fig.add_trace(
    go.Bar(
        y=attr_ig.squeeze()[2],
        marker_color="#005C53",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=2,
    col=3,
)

fig.add_trace(
    go.Scatter(
        y=x[3], name="SP775", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=2,
    col=4,
)
fig.add_trace(
    go.Bar(
        y=attr_ig.squeeze()[3],
        marker_color="#005C53",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=2,
    col=4,
)


fig.add_trace(
    go.Scatter(y=x[0], name="ND", marker_color="grey", opacity=0.3, showlegend=False),
    row=3,
    col=1,
)
fig.add_trace(
    go.Bar(
        y=attr_gbp.squeeze()[0],
        marker_color="#9FC131",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=3,
    col=1,
)

fig.add_trace(
    go.Scatter(
        y=x[1], name="LP725", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=3,
    col=2,
)
fig.add_trace(
    go.Bar(
        y=attr_gbp.squeeze()[1],
        marker_color="#9FC131",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=3,
    col=2,
)

fig.add_trace(
    go.Scatter(
        y=x[2], name="LP780", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=3,
    col=3,
)
fig.add_trace(
    go.Bar(
        y=attr_gbp.squeeze()[2],
        marker_color="#9FC131",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=3,
    col=3,
)

fig.add_trace(
    go.Scatter(
        y=x[3], name="SP775", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=3,
    col=4,
)
fig.add_trace(
    go.Bar(
        y=attr_gbp.squeeze()[3],
        marker_color="#9FC131",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=3,
    col=4,
)


fig.add_trace(
    go.Scatter(y=x[0], name="ND", marker_color="grey", opacity=0.3, showlegend=False),
    row=4,
    col=1,
)
fig.add_trace(
    go.Bar(
        y=attr_ggc.squeeze()[0],
        marker_color="#DBF227",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=4,
    col=1,
)

fig.add_trace(
    go.Scatter(
        y=x[1], name="LP725", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=4,
    col=2,
)
fig.add_trace(
    go.Bar(
        y=attr_ggc.squeeze()[1],
        marker_color="#DBF227",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=4,
    col=2,
)

fig.add_trace(
    go.Scatter(
        y=x[2], name="LP780", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=4,
    col=3,
)
fig.add_trace(
    go.Bar(
        y=attr_ggc.squeeze()[2],
        marker_color="#DBF227",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=4,
    col=3,
)

fig.add_trace(
    go.Scatter(
        y=x[3], name="SP775", marker_color="grey", opacity=0.3, showlegend=False
    ),
    row=4,
    col=4,
)
fig.add_trace(
    go.Bar(
        y=attr_ggc.squeeze()[3],
        marker_color="#DBF227",
        opacity=0.5,
        showlegend=False,
        marker_line_width=0,
    ),
    secondary_y=True,
    row=4,
    col=4,
)

fig.update_yaxes(title=None)
fig.update_yaxes(title_text=None, secondary_y=True)
fig.update_xaxes(title=None)

fig.update_yaxes(title="Intensity", row=1, col=1)
fig.update_yaxes(title_text=None, secondary_y=True, row=1, col=1)
fig.update_yaxes(title_text="Attribution", secondary_y=True, row=1, col=4)
fig.update_xaxes(title="Timesteps", row=4, col=1)

if target == "pce":
    subtitle = "Perovskite 1D Model / Target: PCE /(mean Infidelity, mean Sensitivity)"
else:
    subtitle = "Perovskite 1D Model / Target: Mean Thickness /(mean Infidelity, mean Sensitivity)"

fig.update_layout(
    title=format_title(
        "Global Attribution: Mean abs. Attribution (n = "
        + str(batch_size)
        + ") per Wavelength",
        subtitle,
    ),
    legend_title=None,
    title_y=0.97,
    title_x=0.5,
    template="plotly_white",
    height=700,
    width=1400,
)

fig.write_image("xai/images/" + target +"/1D/1D_cmp_global_wl.png", scale=2)

fig.show()
