In [1]:
import os
from pathlib import Path

# Move from the notebook's folder to the project root (three levels up) so the
# project-relative imports and "xai/images/..." output paths below resolve.
# The starting folder is remembered on the first run, which keeps this cell
# idempotent: re-running it no longer climbs another three levels.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = Path(os.getcwd())
os.chdir(NOTEBOOK_DIR.parents[2])

target = "mth"  # "mth" (meanThickness model) or "pce" (PCE_mean model)
# Every later branch treats "anything but 'pce'" as "mth"; fail loudly on typos.
assert target in ("mth", "pce"), f"unknown target {target!r}; expected 'mth' or 'pce'"

In [2]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

data_dir = "/home/l727n/Projects/Applied Projects/ml_perovskite/preprocessed"

# Pick the checkpoint folder and file of the 3D SlowFast model for the chosen target.
checkpoint_dir, _checkpoint_file = (
    (
        "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints",
        "3D-epoch=999-val_MAE=0.000-train_MAE=0.360.ckpt",
    )
    if target == "pce"
    else (
        "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/mT_checkpoints/checkpoints",
        "mT_3D_SF_full-epoch=999-val_MAE=0.000-train_MAE=20.877.ckpt",
    )
)
path_to_checkpoint = join(checkpoint_dir, _checkpoint_file)

  from .autonotebook import tqdm as notebook_tqdm


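# Model import and computation of four attribution methods (Expected Gradients via GradientShap, Integrated Gradients, Guided Backprop, Guided GradCAM), each evaluated with two metrics (infidelity and sensitivity)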

In [3]:
#### 3D Model

# Hyper-parameters handed to SlowFast.load_from_checkpoint; target normalisation
# is enabled only for the PCE model.
hypparams = {
    "dataset": "Perov_3d",
    "dims": 3,
    "bottleneck": False,
    "name": "SlowFast",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "norm_target": target == "pce",
    "target": "PCE_mean" if target == "pce" else "meanThickness",
}

model = SlowFast.load_from_checkpoint(path_to_checkpoint, num_classes=1, hypparams=hypparams)
print("Loaded")
model.eval()

# Same normalisation statistics and target scaler the model was trained with;
# unscaled labels are requested only for the mean-thickness target.
dataset = PerovskiteDataset3d(
    data_dir,
    transform=normalize_3d(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled=target != "pce",
    label=hypparams["target"],
)

batch_size = 256

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
)


tensor([0.2687, 0.0188, 0.0056, 0.0215]) tensor([0.1645, 0.0109, 0.0030, 0.0147])
Loaded


In [4]:
# Index (within the first batch) of the observation whose prediction is explained.
n = 1

x_batch = next(iter(loader))

with torch.no_grad():
    y_batch = torch.flatten(model.predict(x_batch))

# Keep only element 0 of the batch (the input tensor) — assumes the loader yields
# (inputs, labels, ...) — TODO confirm against PerovskiteDataset3d.
x_batch = x_batch[0]

# Prediction for observation n, rounded for the figure titles.
y = float(np.round(y_batch[n].detach().numpy(), decimals=2))



In [5]:
# Perturbation function for the infidelity metric
std_noise = 0.1  # standard deviation of the Gaussian input perturbation


def perturb_fn(inputs, *, std=None):
    """Gaussian perturbation for ``captum.metrics.infidelity``.

    Draws noise ~ N(0, std) with the shape of ``inputs`` and returns
    ``(noise, inputs - noise)``: the perturbation and the perturbed input.

    Args:
        inputs: tensor to perturb.
        std: noise standard deviation (keyword-only, so Captum can never bind a
            positional argument to it); defaults to module-level ``std_noise``.

    Returns:
        ``(noise, perturbed_inputs)``. ``noise`` lives on the device of
        ``inputs`` and has its floating dtype (float32 for non-float inputs), so
        the subtraction works for GPU / float64 inputs as well.
    """
    if std is None:
        std = std_noise
    dtype = inputs.dtype if inputs.is_floating_point() else torch.float32
    noise = torch.from_numpy(np.random.normal(0, std, inputs.shape)).to(
        device=inputs.device, dtype=dtype
    )
    return noise, inputs - noise


In [6]:
# Compute Attribution via expected gradients

from captum.attr import GradientShap
from captum.metrics import sensitivity_max, infidelity

# Expected Gradients via Captum's GradientShap, with the whole first batch
# serving as the baseline distribution.
_x_obs = x_batch[n].unsqueeze(0)

gradient_shap = GradientShap(model)
attr_eg = gradient_shap.attribute(
    _x_obs,
    n_samples=200,
    stdevs=0.001,
    baselines=x_batch,
    target=0,
)

# Both metrics feed the method subtitles of the per-frame figure further down,
# so they must be computed here.
infid_eg = infidelity(model, perturb_fn, _x_obs, attr_eg)
sens_eg = sensitivity_max(
    gradient_shap.attribute, _x_obs, target=0, baselines=x_batch
)  # lower is better



In [None]:
# Integrated Gradients

from captum.attr import IntegratedGradients

ig = IntegratedGradients(model)

# Explain the selected observation against an all-zero baseline of the same shape.
_x_obs = x_batch[n].unsqueeze(0)
_zero_baseline = _x_obs * 0

attr_ig, delta = ig.attribute(
    _x_obs, baselines=_zero_baseline, return_convergence_delta=True
)

# Evaluation: infidelity under perturb_fn noise, and max-sensitivity of IG itself.
infid_ig = infidelity(model, perturb_fn, _x_obs, attr_ig)
sens_ig = sensitivity_max(ig.attribute, _x_obs, target=0, baselines=_zero_baseline)



In [7]:
# Guided Backprob

from captum.attr import GuidedBackprop
from captum.metrics import sensitivity_max, infidelity

gbp = GuidedBackprop(model)
_x_obs = x_batch[n].unsqueeze(0)

attr_gbp = gbp.attribute(_x_obs, target=0)

# Evaluation: infidelity under perturb_fn noise, then max-sensitivity (in that order).
infid_gbp, sens_gbp = (
    infidelity(model, perturb_fn, _x_obs, attr_gbp),
    sensitivity_max(gbp.attribute, _x_obs),
)




In [39]:
# Guided GradCAM

from captum.attr import GuidedGradCam

# GradCAM layer: conv of multipathway block 1 in the first SlowFast block
# (presumably the fast-pathway stem — confirm against models.slowfast).
_gradcam_layer = model.model.blocks[0].multipathway_blocks[1].conv
ggc = GuidedGradCam(model, _gradcam_layer)

_x_obs = x_batch[n].unsqueeze(0)
attr_ggc = ggc.attribute(_x_obs, target=0).detach()

infid_ggc = infidelity(model, perturb_fn, _x_obs, attr_ggc)
sens_ggc = sensitivity_max(ggc.attribute, _x_obs)


# Visualization of single methods

In [None]:
import scipy.stats as ss
import plotly.figure_factory as ff

def _attr_to_numpy(attr):
    """Return ``attr`` squeezed and as a NumPy array.

    Accepts a torch tensor (detached and moved to CPU first, so tensors that
    require grad or live on a GPU do not make ``.numpy()`` fail) or an array
    that an earlier run of this cell already converted.
    """
    if torch.is_tensor(attr):
        return attr.detach().cpu().squeeze().numpy()
    return np.squeeze(np.asarray(attr))


def _standardise_and_clip(attr, quantile=0.9996):
    """Z-score, clip to ±q, z-score again and clip to the same ±q.

    ``q`` is the ``quantile`` of the first z-scored map; it is returned because
    the plots use it as a symmetric colour limit. Re-running this on an already
    processed map standardises it again.
    """
    z = ss.zscore(_attr_to_numpy(attr), axis=None)
    q = np.quantile(z, quantile)
    z = ss.zscore(np.clip(z, -q, q), axis=None)
    return np.clip(z, -q, q), q


def _raw_with_quantile(attr, quantile=0.9996):
    """Return the raw map as an array plus its ``quantile`` (colour limit only, no clipping)."""
    arr = _attr_to_numpy(attr)
    return arr, np.quantile(arr, quantile)


# PCE attributions are standardised and outlier-clipped; mean-thickness ones stay raw.
_prepare_attr = _standardise_and_clip if target == "pce" else _raw_with_quantile

attr_eg, q_eg = _prepare_attr(attr_eg)
attr_ig, q_ig = _prepare_attr(attr_ig)
attr_gbp, q_gbp = _prepare_attr(attr_gbp)
attr_ggc, q_ggc = _prepare_attr(attr_ggc)


In [None]:
# Normalization for Bar charts


def NormalizeData(data):
    """Min-max scale ``data`` to the range [0, 1].

    Args:
        data: array-like of numbers.

    Returns:
        Float array of the same shape. Constant input (max == min) has no range
        to scale by, so all zeros are returned instead of the NaNs of 0/0.
    """
    data = np.asarray(data)
    lo = np.min(data)
    span = np.max(data) - lo
    if span == 0:
        return np.zeros_like(data, dtype=float)
    return (data - lo) / span


def _channel_importance(channel):
    """Per-frame importance of one input channel, combined over the four methods.

    For each method, the absolute attribution of ``channel`` is summed over axes
    1 and 2 (presumably the spatial axes) and min-max normalised. The four
    curves are added and normalised once more.
    """
    eg, ig_, gbp_, ggc_ = (
        NormalizeData(np.abs(method_attr[channel]).sum((1, 2)))
        for method_attr in (attr_eg, attr_ig, attr_gbp, attr_ggc)
    )
    return NormalizeData(eg + ig_ + gbp_ + ggc_)


# Channel order 0..3 matches the rows of the per-frame figure: ND, LP725, LP780, SP775.
nd_bar, lp725_bar, lp780_bar, sp775_bar = (_channel_importance(c) for c in range(4))



In [8]:
import scipy.stats as ss
# Attribution shown in the paper figure: Expected Gradients (GradientShap),
# clipped symmetrically at its 99.96th percentile. It is stored under its own
# name so the Guided Backprop result in `attr_gbp` (used by the per-frame
# figure) is not overwritten. Works whether `attr_eg` is still a tensor or was
# already converted to an array by the post-processing cell.
if torch.is_tensor(attr_eg):
    attr_paper = attr_eg.detach().cpu().numpy()
else:
    attr_paper = np.asarray(attr_eg)
attr_paper = np.squeeze(attr_paper)
q_paper = np.quantile(attr_paper, 0.9996)
attr_paper = np.clip(attr_paper, -q_paper, q_paper)

In [11]:
def format_title(title, subtitle=None, font_size=16, subtitle_font_size=14):
    """Build a Plotly rich-text title: bold ``title``, optionally followed by ``subtitle``.

    A falsy ``subtitle`` (``None`` or ``""``) yields only the bold title span;
    otherwise the subtitle span follows on a new line (``<br>``).
    """
    heading = '<span style="font-size: {}px;"><b>{}</b></span>'.format(font_size, title)
    if subtitle:
        sub = '<span style="font-size: {}px;">{}</span>'.format(subtitle_font_size, subtitle)
        return heading + "<br>" + sub
    return heading

from plotly.subplots import make_subplots
import plotly.graph_objects as go

cd = ["#E1462C", "#0059A0", "#5F3893", "#FF8777", "#0A2C6E", "#CEDEEB"]
frames = [0, 6, 12, 30]  # frame indices shown as heatmaps and highlighted in the bars
wl = 2  # input channel shown (index 2, the LP780 channel in the bar aggregation)

# Expected-Gradients (GradientShap) attribution of the explained observation,
# clipped symmetrically at its 99.96th percentile. Derived here from `attr_eg`
# (tensor or array) so this figure does not rely on another cell having
# overwritten `attr_gbp`.
if torch.is_tensor(attr_eg):
    attr_paper = attr_eg.detach().cpu().numpy()
else:
    attr_paper = np.asarray(attr_eg)
attr_paper = np.squeeze(attr_paper)
q_paper = np.quantile(attr_paper, 0.9996)
attr_paper = np.clip(attr_paper, -q_paper, q_paper)

fig = make_subplots(
    rows=2,
    cols=6,
    specs=[
        [{}, {}, {}, {}, {"colspan": 2}, None],
        [{}, {}, {}, {}, {"colspan": 2}, None],
    ],
    subplot_titles=(
        format_title(" ", "Frame " + str(frames[0]), font_size=14),
        format_title(" ", "Frame " + str(frames[1]), font_size=14),
        format_title(" ", "Frame " + str(frames[2]), font_size=14),
        format_title("", "Frame " + str(frames[3]), font_size=14),
        None, None, None, None, None, None, None, None,
    ),
)

colors = [(0, "#E1462C"), (0.22, "#ffffff"), (0.75, "#ffffff"), (1, "#0059A0")]

# One bar per frame of the attribution map; frames shown as heatmaps are highlighted.
bars = [cd[5]] * attr_paper.shape[1]
for f in frames:
    bars[f] = "#E3AF5F"

# The input image must come from the same observation the attributions explain.
obs = x_batch[n]
for idx, f in enumerate(frames):
    fig.add_trace(
        go.Heatmap(z=obs[wl][f], colorscale="gray", showscale=False), row=1, col=idx + 1
    )

for idx, f in enumerate(frames):
    fig.add_trace(
        go.Heatmap(z=attr_paper[wl][f], colorscale=colors, showscale=False),
        row=2,
        col=idx + 1,
    )

fig.add_trace(
    go.Bar(
        y=np.abs(attr_paper[wl]).sum((1, 2)),
        marker_color=bars,
        opacity=1,
        showlegend=False,
        marker_line_width=0,
    ),
    row=2,
    col=5,
)

# Framed, tick-less axes for the four heatmap columns (plotly columns are 1-based).
for c in range(1, 5):
    for update_axes in (fig.update_yaxes, fig.update_xaxes):
        update_axes(
            showticklabels=False,
            showline=True,
            linewidth=0.5,
            linecolor="grey",
            mirror=True,
            col=c,
        )

tick_font = dict(size=16, family="Helvetica", color="rgb(0,0,0)")
fig.update_yaxes(tickfont=tick_font, col=5)
fig.update_xaxes(
    tickvals=[0, frames[0], frames[1], frames[2], frames[3], 35],
    ticktext=[0, 0, 120, 240, 600, 719],
    tickfont=tick_font,
    col=5,
)
fig.update_yaxes(col=1, row=1, title="Original Image")
fig.update_yaxes(col=1, row=2, title="Attribution")

fig.update_layout(
    template="plotly_white",
    title_x=0.07,
    height=530,
    width=1000,
)

out_dir = os.path.join("xai", "images", target, "3D")
os.makedirs(out_dir, exist_ok=True)  # write_image does not create missing folders
fig.write_image(os.path.join(out_dir, "3D_paper.png"), scale=2)

fig.show()

In [25]:
from PIL import Image
import os
import glob
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

def format_title(title, subtitle=None, font_size=16, subtitle_font_size=14):
    """Build a Plotly rich-text title: bold ``title``, optionally followed by ``subtitle``.

    A falsy ``subtitle`` (``None`` or ``""``) yields only the bold title span;
    otherwise the subtitle span follows on a new line (``<br>``).
    """
    heading = '<span style="font-size: {}px;"><b>{}</b></span>'.format(font_size, title)
    if subtitle:
        sub = '<span style="font-size: {}px;">{}</span>'.format(subtitle_font_size, subtitle)
        return heading + "<br>" + sub
    return heading

colors = [(0, "#F00B48"), (0.45, "#ffffff"), (0.55, "#ffffff"), (1, "#00BE34")]

obs = x_batch[n]  # the observation all attributions above were computed for
N = obs.shape[1]  # number of frames: one figure (and one GIF frame) per frame


def _metric_subtitle(infid, sens):
    """Return "(infidelity, sensitivity)" with both values rounded to 4 decimals."""
    return (
        "("
        + str(*np.round(infid.numpy(), 4))
        + ", "
        + str(*np.round(sens.numpy(), 4))
        + ")"
    )


# (column title, attribution map, symmetric colour limit, infidelity, sensitivity)
methods = [
    ("Expected Grad.", attr_eg, q_eg, infid_eg, sens_eg),
    ("Integrated Grad.", attr_ig, q_ig, infid_ig, sens_ig),
    ("Guided Backprop", attr_gbp, q_gbp, infid_gbp, sens_gbp),
    ("Guided GradCAM", attr_ggc, q_ggc, infid_ggc, sens_ggc),
]
# Aggregated per-frame importance of input channels 0..3 (rows ND, LP725, LP780, SP775).
channel_bars = [nd_bar, lp725_bar, lp780_bar, sp775_bar]

# Subplot titles are identical for every frame: the first row names the columns,
# the first cell of each later row names its channel (6 columns per row).
subplot_titles = [format_title("ND", "Original Image", font_size=14)]
subplot_titles += [
    format_title(name, _metric_subtitle(infid, sens), font_size=14)
    for name, _, _, infid, sens in methods
]
subplot_titles.append(format_title("Aggregated Attr.", "Abs. Sum", font_size=14))
for channel_name in ("LP725", "LP780", "SP775"):
    subplot_titles += [format_title(channel_name, None, font_size=14)] + [None] * 5

if target == "pce":
    subtitle = "Predicted PCE: "
else:
    subtitle = "Predicted Mean Thickness: "

axis_style = dict(
    showticklabels=False, showline=True, linewidth=0.5, linecolor="grey", mirror=True
)

out_dir = os.path.join("xai", "images", target, "3D")
os.makedirs(out_dir, exist_ok=True)  # write_image does not create missing folders

# Paths written by this run, in frame order; used both for the GIF and for cleanup,
# so stale files from earlier runs are never picked up.
frame_paths = []
for i in range(N):
    fig1 = make_subplots(
        rows=4, cols=6, vertical_spacing=0.1, subplot_titles=subplot_titles
    )

    # Column 1: the input frame of each channel.
    for ch in range(4):
        fig1.add_trace(
            go.Heatmap(z=obs[ch][i], colorscale="gray", showscale=False),
            row=ch + 1,
            col=1,
        )

    # Columns 2-5: one attribution method per column, colour range fixed per method.
    for col, (_, attr, q, _, _) in enumerate(methods, start=2):
        for ch in range(4):
            fig1.add_trace(
                go.Heatmap(
                    z=attr[ch][i], colorscale=colors, showscale=False, zmin=-q, zmax=q
                ),
                row=ch + 1,
                col=col,
            )

    # Column 6: aggregated importance over all frames; only the current frame is red
    # (compare by index — comparing values would also colour equal-valued bars).
    for ch, bar in enumerate(channel_bars):
        fig1.add_trace(
            go.Bar(
                y=bar,
                marker_color=np.where(np.arange(len(bar)) == i, "red", "#042940"),
                opacity=0.5,
                showlegend=False,
                marker_line_width=0,
            ),
            row=ch + 1,
            col=6,
        )

    for col in range(1, 6):
        fig1.update_yaxes(col=col, **axis_style)
        fig1.update_xaxes(col=col, **axis_style)
    fig1.update_yaxes(showticklabels=False, col=6)

    fig1.update_layout(
        title=format_title(
            "Method & Wavelength Comparison",
            "Perovskite 3D Video Model / "
            + subtitle
            + str(np.round(y, 2))
            + " / (Infidelity (\u03C3(\u03B5) = "
            + str(std_noise)
            + ")"
            + f", Sensitivity) / Frame {i}",
        ),
        title_y=0.98,
        template="plotly_white",
        title_x=0.07,
        height=1000,
        width=1200,
    )

    # Zero-padded names keep lexical and frame order identical.
    frame_path = os.path.join(out_dir, f"frame_{i:03d}.png")
    fig1.write_image(frame_path, scale=2)
    frame_paths.append(frame_path)

# Assemble the frames into a looping GIF, closing every image handle afterwards.
frame_images = [Image.open(p) for p in frame_paths]
try:
    frame_images[0].save(
        fp=os.path.join(out_dir, "3D_attr.gif"),
        format="GIF",
        append_images=frame_images[1:],
        save_all=True,
        duration=400,
        loop=0,
    )
finally:
    for frame_image in frame_images:
        frame_image.close()

for frame_path in frame_paths:
    os.remove(frame_path)
