### Imports & Paths

In [None]:
import os
from pathlib import Path

# Move three directory levels up from where the notebook was started
# (presumably the project root) so the relative paths used later resolve.
os.chdir(Path.cwd().parents[2])
os.getcwd()

# Regression target whose model is explained: "pce" or "mth".
target = "pce"  # mth, pce


In [None]:
import torch
import torch.nn as nn
import numpy as np
import scipy
import plotly.graph_objects as go
import scipy.stats as ss

from torch.utils.data import DataLoader
from tqdm import tqdm
from captum.attr import GradientShap, IntegratedGradients, GuidedBackprop, GuidedGradCam
from captum.metrics import sensitivity_max, infidelity
from os.path import join
from plotly.subplots import make_subplots

from data.perovskite_dataset import PerovskiteDataset2d_time
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from data.augmentations.perov_2d import normalize as normalize_2d
from base_model import seed_worker

# The pre-processed data is read from a folder below the working directory.
data_dir = f"{os.getcwd()}/preprocessed"

# Choose the checkpoint that belongs to the selected target. Every value other
# than "pce" falls through to the second (mean thickness) checkpoint.
if target != "pce":
    checkpoint_dir = "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/mT_checkpoints/checkpoints"
    checkpoint_name = "mT_2Dtime_RN18_full3-epoch=999-val_MAE=0.000-train_MAE=36.879.ckpt"
else:
    checkpoint_dir = (
        "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints"
    )
    checkpoint_name = "2D_time-epoch=999-val_MAE=0.000-train_MAE=0.725.ckpt"

path_to_checkpoint = join(checkpoint_dir, checkpoint_name)


### Model Init

In [None]:
#### 2D Model

# Hyper-parameters passed to the checkpoint loader; target normalisation is
# only switched on for the PCE target.
hypparams = dict(
    dataset="Perov_time_2d",
    dims=2,
    bottleneck=False,
    name="ResNet18",
    data_dir=data_dir,
    no_border=False,
    resnet_dropout=0.0,
    norm_target=target == "pce",
    target="PCE_mean" if target == "pce" else "meanThickness",
)

# Restore the trained network (BasicBlock with [2, 2, 2, 2] blocks and a
# single output) from the selected checkpoint.
model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    block=BasicBlock,
    num_blocks=[2, 2, 2, 2],
    num_classes=1,
    hypparams=hypparams,
)

print("Loaded")
model.eval()

# Build the dataset with the normalisation statistics and scaler stored on the
# model; unscaled labels are only requested for the non-PCE target.
dataset = PerovskiteDataset2d_time(
    data_dir,
    transform=normalize_2d(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled=target != "pce",
    label=hypparams["target"],
)

batch_size = 100

# shuffle=False keeps the batch order fixed between runs.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
)


In [None]:
# Index (within the first batch) of the observation that gets explained.
n = 1

# Standard deviation of the Gaussian noise drawn in `perturb_fn`.
std_noise = 0.01

# Only the first batch is used; x_batch[0] presumably holds the input images.
x_batch = next(iter(loader))
x = x_batch[0][n]

# Predict the whole batch without tracking gradients and flatten to 1-D.
with torch.no_grad():
    y_batch = model.predict(x_batch).flatten()

# Prediction for the selected observation, rounded to two decimals.
y = float(np.round(y_batch[n].detach().numpy(), 2))


def perturb_fn(inputs):
    """Return Gaussian noise and the correspondingly perturbed inputs.

    The noise has zero mean and standard deviation ``std_noise`` (read from
    module scope) and the same shape as *inputs*. The result is the pair
    ``(noise, inputs - noise)``, which is passed to Captum's ``infidelity``.
    """
    gaussian = np.random.normal(loc=0, scale=std_noise, size=inputs.shape)
    # NumPy samples in float64; cast to float32 to match the model inputs.
    noise = torch.tensor(gaussian).float()
    perturbed = inputs - noise
    return noise, perturbed


### Local Attribution Computation and Evaluation

#### Expected Gradients

In [None]:
# Expected Gradients through Captum's GradientShap: the selected observation
# is explained against the whole first batch as baseline distribution.
gradient_shap = GradientShap(model)
attr_eg = gradient_shap.attribute(
    x.unsqueeze(0),
    baselines=x_batch[0],
    n_samples=100,
    target=0,
)

# Quality metrics of the explanation (lower is better for both).
infid_eg = infidelity(model, perturb_fn, x.unsqueeze(0), attr_eg)
sens_eg = sensitivity_max(
    gradient_shap.attribute,
    x.unsqueeze(0),
    baselines=x_batch[0],
    target=0,
)


#### Integrated Gradients

In [None]:
# Integrated Gradients for the selected observation against an all-zero
# baseline of the same shape.
ig = IntegratedGradients(model)
attr_ig, delta = ig.attribute(
    x.unsqueeze(0),
    baselines=x.unsqueeze(0) * 0,
    # Same target as in the sensitivity evaluation below and in the other
    # attribution methods, so all explanations refer to the same output.
    target=0,
    return_convergence_delta=True,
)

# Quality metrics of the explanation (lower is better for both).
infid_ig = infidelity(model, perturb_fn, x.unsqueeze(0), attr_ig)
sens_ig = sensitivity_max(
    ig.attribute,
    x.unsqueeze(0),
    target=0,
    baselines=x.unsqueeze(0) * 0,
)


#### Guided Backprop

In [None]:
# Guided Backpropagation attribution for the selected observation.
gbp = GuidedBackprop(model)
attr_gbp = gbp.attribute(x.unsqueeze(0), target=0)

# Quality metrics of the explanation (lower is better for both).
infid_gbp = infidelity(model, perturb_fn, x.unsqueeze(0), attr_gbp)
# Extra keyword arguments are forwarded to `gbp.attribute`; pass the same
# target as above so sensitivity is measured for the reported explanation.
sens_gbp = sensitivity_max(gbp.attribute, x.unsqueeze(0), target=0)


#### Guided GradCAM

In [None]:
# Guided GradCAM attribution with `model.conv1` as the GradCAM layer.
# NOTE(review): confirm that the first convolution is the intended layer.
ggc = GuidedGradCam(model, model.conv1)
attr_ggc = ggc.attribute(x.unsqueeze(0), target=0)
# Detach so the result can be converted to NumPy further down.
attr_ggc = attr_ggc.detach()

# Quality metrics of the explanation (lower is better for both).
infid_ggc = infidelity(model, perturb_fn, x.unsqueeze(0), attr_ggc)
# Extra keyword arguments are forwarded to `ggc.attribute`; pass the same
# target as above so sensitivity is measured for the reported explanation.
sens_ggc = sensitivity_max(ggc.attribute, x.unsqueeze(0), target=0)


## Attribution Visualization

#### Preprocessing

In [None]:
def _standardize_and_clip(attr, quantile=0.9996):
    """Z-score *attr*, clip it symmetrically, then repeat both steps once.

    The clipping bound is the given quantile of the first z-scored array and
    is reused unchanged for the second clip. Returns the processed NumPy
    array together with that bound.
    """
    scores = ss.zscore(attr.squeeze().numpy(), axis=None)
    bound = np.quantile(scores, quantile)
    scores = ss.zscore(np.clip(scores, -bound, bound), axis=None)
    return np.clip(scores, -bound, bound), bound


if target == "pce":
    # Standardise and clip each attribution map for the PCE target.
    attr_eg, q_eg = _standardize_and_clip(attr_eg)
    attr_ig, q_ig = _standardize_and_clip(attr_ig)
    attr_gbp, q_gbp = _standardize_and_clip(attr_gbp)
    attr_ggc, q_ggc = _standardize_and_clip(attr_ggc)
else:
    # Other targets: only drop the batch dimension and convert to NumPy.
    attr_eg, attr_ig, attr_gbp, attr_ggc = (
        raw.squeeze().numpy() for raw in (attr_eg, attr_ig, attr_gbp, attr_ggc)
    )


In [None]:
def format_title(title, subtitle=None, subtitle_font_size=12):
    """Build an HTML title string for a Plotly figure or subplot.

    The title is rendered bold. When *subtitle* is truthy it is appended on a
    new line inside a ``<span>`` with *subtitle_font_size* pixels; otherwise
    only the bold title is returned.
    """
    heading = f"<b>{title}</b>"
    if subtitle:
        sub_line = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
        return f"{heading}<br>{sub_line}"
    return heading


# 2 x 4 grid: the top row carries one title per input channel (presumably the
# four imaging filters), the bottom row gets blank placeholder titles.
channel_names = ("ND", "LP725", "LP780", "SP775")
fig = make_subplots(
    rows=2,
    cols=4,
    subplot_titles=tuple(format_title("", label) for label in channel_names + (" ",) * 4),
)

# White-to-purple colour scale for the attribution heatmaps.
colors = [(0, "#FFFFFF"), (1, "#5F3893")]

# One column per channel: input image on top, absolute attribution below.
for channel in range(4):
    column = channel + 1
    fig.add_trace(go.Heatmap(z=x.numpy()[channel], colorscale="gray", showscale=False), row=1, col=column)
    fig.add_trace(go.Heatmap(z=np.abs(attr_eg[channel]), colorscale=colors, showscale=False), row=2, col=column)

fig.update_yaxes(showticklabels=False)
fig.update_yaxes(title="Original Image", row=1, col=1)
fig.update_yaxes(title="Attribution", row=2, col=1)

fig.update_xaxes(showticklabels=False)

# Prefix of the figure subtitle; depends on the explained target.
subtitle = "Predicted PCE: " if target == "pce" else "Predicted Mean Thickness: "

# Both metrics are unpacked from single-element arrays into plain strings.
infid_text = str(*np.round(infid_eg.numpy(), 4))
sens_text = str(*np.round(sens_eg.numpy(), 4))

fig.update_layout(
    title=format_title(
        "Perovskite 2D Time Model",
        f"{subtitle}{np.round(y, 2)!s}"
        f" / Method: Expected Gradients / Infidelity = {infid_text}"
        f" (\u03C3(\u03B5) = {std_noise!s})"
        f" / Sensitivity = {sens_text}",
    ),
    title_y=0.95,
    title_x=0.01,
    height=500,
    width=700,
)

# Thin grey frame around every subplot.
fig.update_xaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)
fig.update_yaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)

fig.write_image(f"xai/images/{target}/2D_time/2D_eg.png", scale=2)

fig.show()


In [None]:
def format_title(title, subtitle=None, font_size=14, subtitle_font_size=12):
    """Build an HTML title string with explicit font sizes.

    The title is rendered bold inside a ``<span>`` of *font_size* pixels.
    When *subtitle* is truthy it follows on a new line inside a ``<span>`` of
    *subtitle_font_size* pixels; otherwise only the title span is returned.
    """
    heading = f'<span style="font-size: {font_size}px;"><b>{title}</b></span>'
    if subtitle:
        sub_line = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
        return f"{heading}<br>{sub_line}"
    return heading


def _metric_pair(infid, sens):
    """Format two single-element metric tensors as "(infidelity, sensitivity)"."""
    return "(" + str(*np.round(infid.numpy(), 4)) + ", " + str(*np.round(sens.numpy(), 4)) + ")"


# Subplot titles in row-major order for the 4 x 5 grid. The first row names
# the input column and the four attribution methods (with their metrics);
# every following row only titles its first column with the channel name.
grid_titles = [
    format_title("ND", "Original", font_size=12),
    format_title("Expected Grad.", _metric_pair(infid_eg, sens_eg), font_size=12),
    format_title("Integrated Grad.", _metric_pair(infid_ig, sens_ig), font_size=12),
    format_title("Guided Backprob", _metric_pair(infid_gbp, sens_gbp), font_size=12),
    format_title("Guided GradCAM", _metric_pair(infid_ggc, sens_ggc), font_size=12),
]
for channel_name in ("LP725", "LP780", "SP775"):
    grid_titles += [format_title(channel_name, None, font_size=12), None, None, None, None]

fig = make_subplots(
    rows=4,
    cols=5,
    vertical_spacing=0.1,
    subplot_titles=tuple(grid_titles),
)

# One row per channel: column 1 shows the input image, columns 2-5 the
# attribution of each method for that same channel.
# The previous nested loop drew all four channels into every row, so each row
# ended up showing the last channel on top instead of the one in its title.
# NOTE(review): unlike the single-method figure, the attributions here are
# plotted signed (no np.abs) - confirm that this is intended.
method_attributions = (attr_eg, attr_ig, attr_gbp, attr_ggc)
for channel in range(4):
    row = channel + 1
    fig.add_trace(go.Heatmap(z=x.numpy()[channel], colorscale="gray", showscale=False), row=row, col=1)
    for col, attr in enumerate(method_attributions, start=2):
        fig.add_trace(go.Heatmap(z=attr[channel], colorscale=colors, showscale=False), row=row, col=col)


fig.update_yaxes(showticklabels=False)
fig.update_xaxes(showticklabels=False)

# NOTE(review): "Comparision" and "Guided Backprob" are misspelled in the
# rendered figure; the strings are left unchanged here.
fig.update_layout(
    title=format_title(
        "Method & Wavelength Comparision",
        f"Perovskite 2D Time Model / {subtitle}{np.round(y, 2)!s}"
        f" / (Infidelity (\u03C3(\u03B5) = {std_noise!s}), Sensitivity)",
    ),
    title_y=0.97,
    title_x=0.1,
    height=800,
    width=800,
)

# Thin grey frame around every subplot.
fig.update_xaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)
fig.update_yaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)

fig.write_image("xai/images/" + target + "/2D_time/2D_cmp.png", scale=2)

fig.show()
