In [1]:
import os
from pathlib import Path

os.chdir(Path(os.getcwd()).parents[2])
os.getcwd()

target = "mth" # mth, pce

In [2]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

data_dir = "/home/l727n/Projects/Applied Projects/ml_perovskite/preprocessed"

if target == "pce":
    checkpoint_dir = "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints"

    path_to_checkpoint = join(
        checkpoint_dir, "2D_time-epoch=999-val_MAE=0.000-train_MAE=0.725.ckpt"
    )
else:
    checkpoint_dir = "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/mT_checkpoints/checkpoints"

    path_to_checkpoint = join(
        checkpoint_dir, "mT_2Dtime_RN18_full3-epoch=999-val_MAE=0.000-train_MAE=36.879.ckpt" 
    )

  from .autonotebook import tqdm as notebook_tqdm


# Import of model and computation of six different attribution methods with two evaluation metrics per method

In [3]:
#### 2D Model

hypparams = {
    "dataset": "Perov_time_2d",
    "dims": 2,
    "bottleneck": False,
    "name": "ResNet18",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "norm_target": True if target == "pce" else False,
    "target": "PCE_mean" if target == "pce" else "meanThickness"
}

model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    block=BasicBlock,
    num_blocks=[2, 2, 2, 2],
    num_classes=1,
    hypparams=hypparams,
)

print("Loaded")
model.eval()

dataset = PerovskiteDataset2d_time(
    data_dir,
    transform=normalize_2d(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled= False if target == "pce" else True,
    label="PCE_mean" if target == "pce" else "meanThickness",
)

batch_size = 100

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
)

[0.26849303 0.01902202 0.00568256 0.02159704] [0.16095641 0.01068098 0.00295246 0.01455085]
Loaded


In [4]:
# Select observation
n = 1

x_batch = next(iter(loader))

with torch.no_grad():
    y_batch = model.predict(x_batch).flatten()

x_batch = x_batch[0]

y = float(np.round(y_batch[n].detach().numpy(), 2))



In [5]:
# Init pertubation function for infidelity metric
from tqdm import tqdm

std_noise = 0.1


def perturb_fn(inputs):
    noise = torch.tensor(np.random.normal(0, std_noise, inputs.shape)).float()
    return noise, inputs - noise


In [6]:
# Expected Gradients

from captum.attr import GradientShap
from captum.metrics import sensitivity_max, infidelity

method = GradientShap(model)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    attr = method.attribute(
        x_batch[n].unsqueeze(0),
        n_samples=80,
        stdevs=0.001,
        baselines=x_batch,
        target=0,
    )

    attr_sum.append(attr.abs())

    infid_sum.append(infidelity(model, perturb_fn, x_batch[n].unsqueeze(0), attr))
    sens_sum.append(
        sensitivity_max(
            method.attribute, x_batch[n].unsqueeze(0), target=0, baselines=x_batch
        )
    )

attr_eg = torch.cat(attr_sum).mean(dim=0)
infid_eg = torch.Tensor(infid_sum).mean()
sens_eg = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [12:41<00:00,  7.61s/it]


In [7]:
# Integrated Gradients

from captum.attr import IntegratedGradients

method = IntegratedGradients(model)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    attr, delta = method.attribute(
        x_batch[n].unsqueeze(0),
        baselines=x_batch[n].unsqueeze(0) * 0,
        return_convergence_delta=True,
    )

    attr_sum.append(attr.abs())

    infid_sum.append(infidelity(model, perturb_fn, x_batch[n].unsqueeze(0), attr))
    sens_sum.append(
        sensitivity_max(
            method.attribute,
            x_batch[n].unsqueeze(0),
            target=0,
            baselines=x_batch[n].unsqueeze(0) * 0,
        )
    )

attr_ig = torch.cat(attr_sum).mean(dim=0)
infid_ig = torch.Tensor(infid_sum).mean()
sens_ig = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [31:37<00:00, 18.98s/it]


In [8]:
# Guided Backprob

from captum.attr import GuidedBackprop

method = GuidedBackprop(model)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    attr = method.attribute(x_batch[n].unsqueeze(0), target=0)

    attr_sum.append(attr.abs())

    infid_sum.append(infidelity(model, perturb_fn, x_batch[n].unsqueeze(0), attr))
    sens_sum.append(sensitivity_max(method.attribute, x_batch[n].unsqueeze(0)))

attr_gbp = torch.cat(attr_sum).mean(dim=0)
infid_gbp = torch.Tensor(infid_sum).mean()
sens_gbp = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [00:39<00:00,  2.55it/s]


In [9]:
# Guided GradCAM

from captum.attr import GuidedGradCam

method = GuidedGradCam(model, model.conv1)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    attr = method.attribute(x_batch[n].unsqueeze(0), target=0)

    attr_sum.append(attr.detach().abs())

    infid_sum.append(infidelity(model, perturb_fn, x_batch[n].unsqueeze(0), attr))
    sens_sum.append(sensitivity_max(method.attribute, x_batch[n].unsqueeze(0)))

attr_ggc = torch.cat(attr_sum).mean(dim=0)
infid_ggc = torch.Tensor(infid_sum).mean()
sens_ggc = torch.Tensor(sens_sum).mean()


100%|██████████| 100/100 [01:09<00:00,  1.43it/s]


In [9]:
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots


def format_title(title, subtitle=None, subtitle_font_size=12):
    title = f"<b>{title}</b>"
    if not subtitle:
        return title
    subtitle = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
    return f"{title}<br>{subtitle}"
x = x_batch.mean(0)

fig = make_subplots(
    rows=2,
    cols=4,
    subplot_titles=(
        format_title("", "ND"),
        format_title("", "LP725"),
        format_title("", "LP780"),
        format_title("", "SP775"),
        format_title("", " "),
        format_title("", " "),
        format_title("", " "),
        format_title("", " "),
    ),
)

colors = [(0, "#FFFFFF"), (1, "#5F3893")]

fig.add_trace(
    go.Heatmap(z=x.numpy()[0], colorscale="gray", showscale=False), row=1, col=1
)
fig.add_trace(
    go.Heatmap(z=np.abs(attr_eg[0]), colorscale=colors, showscale=False), row=2, col=1
)

fig.add_trace(
    go.Heatmap(z=x.numpy()[1], colorscale="gray", showscale=False), row=1, col=2
)
fig.add_trace(
    go.Heatmap(z=np.abs(attr_eg[1]), colorscale=colors, showscale=False), row=2, col=2
)

fig.add_trace(
    go.Heatmap(z=x.numpy()[2], colorscale="gray", showscale=False), row=1, col=3
)
fig.add_trace(
    go.Heatmap(z=np.abs(attr_eg[2]), colorscale=colors, showscale=False), row=2, col=3
)

fig.add_trace(
    go.Heatmap(z=x.numpy()[3], colorscale="gray", showscale=False), row=1, col=4
)
fig.add_trace(
    go.Heatmap(z=np.abs(attr_eg[3]), colorscale=colors, showscale=False), row=2, col=4
)

fig.update_yaxes(showticklabels=False)
fig.update_xaxes(showticklabels=False)


fig.update_layout(
    title_y=0.95,
    title_x=0.01,
    height=500,
    width=700,
)

fig.update_xaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)
fig.update_yaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)

fig.write_image("xai/images/" + target + "/2D_time/2D_global_paper.png", scale=2)

fig.show()

In [10]:
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots


def format_title(title, subtitle=None, font_size=14, subtitle_font_size=12):
    title = f'<span style="font-size: {font_size}px;"><b>{title}</b></span>'
    if not subtitle:
        return title
    subtitle = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
    return f"{title}<br>{subtitle}"


colors = [(0, "#ffffff"), (0.3, "#ffffff"), (1, "#005C53")]

x = x_batch.mean(0)

fig = make_subplots(
    rows=4,
    cols=5,
    vertical_spacing=0.1,
    subplot_titles=(
        format_title("ND", "Mean Original", font_size=12),
        format_title(
            "Expected Grad.",
            "("
            + str(np.round(infid_eg.numpy(), 4))
            + ", "
            + str(np.round(sens_eg.numpy(), 4))
            + ")",
            font_size=12,
        ),
        format_title(
            "Integrated Grad.",
            "("
            + str(np.round(infid_ig.numpy(), 4))
            + ", "
            + str(np.round(sens_ig.numpy(), 4))
            + ")",
            font_size=12,
        ),
        format_title(
            "Guided Backprob",
            "("
            + str(np.round(infid_gbp.numpy(), 4))
            + ", "
            + str(np.round(sens_gbp.numpy(), 4))
            + ")",
            font_size=12,
        ),
        format_title(
            "Guided GradCAM",
            "("
            + str(np.round(infid_ggc.numpy(), 4))
            + ", "
            + str(np.round(sens_ggc.numpy(), 4))
            + ")",
            font_size=12,
        ),
        format_title("LP725", None, font_size=12),
        None,
        None,
        None,
        None,
        format_title("LP780", None, font_size=12),
        None,
        None,
        None,
        None,
        format_title("SP775", None, font_size=12),
        None,
        None,
        None,
        None,
    ),
)

fig.add_trace(
    go.Heatmap(z=x.numpy()[0], colorscale="gray", showscale=False), row=1, col=1
)
fig.add_trace(
    go.Heatmap(z=attr_eg[0], colorscale=colors, showscale=False), row=1, col=2
)
fig.add_trace(
    go.Heatmap(z=attr_ig[0], colorscale=colors, showscale=False), row=1, col=3
)
fig.add_trace(
    go.Heatmap(z=attr_gbp[0], colorscale=colors, showscale=False), row=1, col=4
)
fig.add_trace(
    go.Heatmap(z=attr_ggc[0], colorscale=colors, showscale=False), row=1, col=5
)

fig.add_trace(
    go.Heatmap(z=x.numpy()[1], colorscale="gray", showscale=False), row=2, col=1
)
fig.add_trace(
    go.Heatmap(z=attr_eg[1], colorscale=colors, showscale=False), row=2, col=2
)
fig.add_trace(
    go.Heatmap(z=attr_ig[1], colorscale=colors, showscale=False), row=2, col=3
)
fig.add_trace(
    go.Heatmap(z=attr_gbp[1], colorscale=colors, showscale=False), row=2, col=4
)
fig.add_trace(
    go.Heatmap(z=attr_ggc[1], colorscale=colors, showscale=False), row=2, col=5
)

fig.add_trace(
    go.Heatmap(z=x.numpy()[2], colorscale="gray", showscale=False), row=3, col=1
)
fig.add_trace(
    go.Heatmap(z=attr_eg[2], colorscale=colors, showscale=False), row=3, col=2
)
fig.add_trace(
    go.Heatmap(z=attr_ig[2], colorscale=colors, showscale=False), row=3, col=3
)
fig.add_trace(
    go.Heatmap(z=attr_gbp[2], colorscale=colors, showscale=False), row=3, col=4
)
fig.add_trace(
    go.Heatmap(z=attr_ggc[2], colorscale=colors, showscale=False), row=3, col=5
)

fig.add_trace(
    go.Heatmap(z=x.numpy()[3], colorscale="gray", showscale=False), row=4, col=1
)
fig.add_trace(
    go.Heatmap(z=attr_eg[3], colorscale=colors, showscale=False), row=4, col=2
)
fig.add_trace(
    go.Heatmap(z=attr_ig[3], colorscale=colors, showscale=False), row=4, col=3
)
fig.add_trace(
    go.Heatmap(z=attr_gbp[3], colorscale=colors, showscale=False), row=4, col=4
)
fig.add_trace(
    go.Heatmap(z=attr_ggc[3], colorscale=colors, showscale=False), row=4, col=5
)

fig.update_yaxes(showticklabels=False)
fig.update_xaxes(showticklabels=False)

if target == "pce":
    subtitle = "Perovskite 1D Model / Target: PCE /(mean Infidelity, mean Sensitivity)"
else:
    subtitle = "Perovskite 1D Model / Target: Mean Thickness / (mean Infidelity, mean Sensitivity)"

fig.update_layout(
    title=format_title(
        "Global Attribution: Mean abs. Attribution (n = "
        + str(batch_size)
        + ") per Wavelength",
        subtitle,
    ),
    title_y=0.97,
    title_x=0.1,
    height=800,
    width=800,
)

fig.update_xaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)
fig.update_yaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)

fig.write_image("xai/images/" + target + "/2D_time/2D_cmp_global_wl.png", scale=2)

fig.show()


In [18]:
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots


def format_title(title, subtitle=None, font_size=14, subtitle_font_size=12):
    title = f'<span style="font-size: {font_size}px;"><b>{title}</b></span>'
    if not subtitle:
        return title
    subtitle = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
    return f"{title}<br>{subtitle}"


colors = [(0, "#ffffff"), (0.25, "#ffffff"), (1, "#005C53")]

x = x_batch.mean(0)

fig = make_subplots(
    rows=1,
    cols=5,
    vertical_spacing=0.1,
    subplot_titles=(
        format_title("Mean Image", "Original", font_size=12),
        format_title(
            "Expected Grad.",
            "("
            + str(np.round(infid_eg.numpy(), 4))
            + ", "
            + str(np.round(sens_eg.numpy(), 4))
            + ")",
            font_size=12,
        ),
        format_title(
            "Integrated Grad.",
            "("
            + str(np.round(infid_ig.numpy(), 4))
            + ", "
            + str(np.round(sens_ig.numpy(), 4))
            + ")",
            font_size=12,
        ),
        format_title(
            "Guided Backprob",
            "("
            + str(np.round(infid_gbp.numpy(), 4))
            + ", "
            + str(np.round(sens_gbp.numpy(), 4))
            + ")",
            font_size=12,
        ),
        format_title(
            "Guided GradCAM",
            "("
            + str(np.round(infid_ggc.numpy(), 4))
            + ", "
            + str(np.round(sens_ggc.numpy(), 4))
            + ")",
            font_size=12,
        ),
    ),
)

fig.add_trace(
    go.Heatmap(z=x.numpy().sum(0), colorscale="gray", showscale=False), row=1, col=1
)
fig.add_trace(
    go.Heatmap(z=attr_eg.sum(0), colorscale=colors, showscale=False), row=1, col=2
)
fig.add_trace(
    go.Heatmap(z=attr_ig.sum(0), colorscale=colors, showscale=False), row=1, col=3
)
fig.add_trace(
    go.Heatmap(z=attr_gbp.sum(0), colorscale=colors, showscale=False), row=1, col=4
)
fig.add_trace(
    go.Heatmap(z=attr_ggc.sum(0), colorscale=colors, showscale=False), row=1, col=5
)

fig.update_yaxes(showticklabels=False)
fig.update_xaxes(showticklabels=False)

if target == "pce":
    subtitle = "Perovskite 1D Model / Target: PCE /(mean Infidelity, mean Sensitivity)"
else:
    subtitle = "Perovskite 1D Model / Target: Mean Thickness / (mean Infidelity, mean Sensitivity)"

fig.update_layout(
    title=format_title(
        "Global Attribution: Mean abs. Attribution (n = " + str(batch_size) + ")",
        subtitle,
    ),
    title_y=0.90,
    title_x=0.1,
    height=300,
    width=800,
)

fig.update_xaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)
fig.update_yaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)

fig.write_image("xai/images/" + target + "/2D_time/2D_cmp_global.png", scale=2)

fig.show()
