In [2]:
import os

# Move from the notebook's folder up one level (presumably the project root, where
# `data`, `models` and `base_model` are importable and "xai/images/..." resolves).
# The starting folder is remembered in NOTEBOOK_DIR so that re-running this cell
# always lands in the same directory instead of climbing further up each time.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, ".."))
os.getcwd()

'/home/l727n/Projects/ml_perovskite'

In [3]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

# Locations of the preprocessed data and of the trained checkpoints. The defaults are
# the original author's machine paths; set PEROVSKITE_DATA_DIR / PEROVSKITE_CHECKPOINT_DIR
# to run the notebook elsewhere without editing it.
data_dir = os.environ.get(
    "PEROVSKITE_DATA_DIR", "/home/l727n/Projects/ml_perovskite/preprocessed"
)
checkpoint_dir = os.environ.get(
    "PEROVSKITE_CHECKPOINT_DIR",
    "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints",
)

# Checkpoint of the 2D image model explained below.
path_to_checkpoint = join(
    checkpoint_dir, "2D-epoch=999-val_MAE=0.000-train_MAE=0.289.ckpt"
)

# Load the model and compute four attribution methods, each evaluated with two metrics (infidelity and max-sensitivity)

In [4]:
#### 2D Model
# Restore the trained ResNet-18 style network (BasicBlock, [2, 2, 2, 2], one output)
# and build a non-shuffled loader over the preprocessed 2D dataset.

hypparams = dict(
    dataset="Perov_2d",
    dims=2,
    bottleneck=False,
    name="ResNet18",
    data_dir=data_dir,
    no_border=False,
    resnet_dropout=0.0,
)

model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    hypparams=hypparams,
    num_classes=1,
    num_blocks=[2, 2, 2, 2],
    block=BasicBlock,
)

print("Loaded")
model.eval()

# Inputs are normalized with the training statistics and the target scaler
# that are stored on the restored model.
dataset = PerovskiteDataset2d(
    data_dir,
    scaler=model.scaler,
    transform=normalize_2d(model.train_mean, model.train_std),
    no_border=False,
)

batch_size = 256

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
    worker_init_fn=seed_worker,
)

[0.23169845 0.00265788 0.00174048 0.00421168] [3.4151509e-02 3.0193795e-04 9.2120092e-05 9.2122407e-04]
Loaded


In [420]:
# Select observation
n = 2  # position of the explained observation inside the first batch

# First batch of the (unshuffled) loader; x_batch[0] holds the model inputs.
x_batch = next(iter(loader))
x = x_batch[0][n]

# Predict the whole batch once without tracking gradients, then keep the
# prediction of observation n, rounded to two decimals for the plot titles.
with torch.no_grad():
    y_batch = model.predict(x_batch)
    y_batch = y_batch.flatten()

y = float(np.round(y_batch[n].detach().numpy(), 2))

In [421]:
# Init perturbation function for infidelity metric

std_noise = 0.01  # standard deviation of the Gaussian input perturbation


def perturb_fn(inputs):
    """Draw Gaussian noise for Captum's ``infidelity`` metric.

    Args:
        inputs: Input tensor to perturb.

    Returns:
        ``(noise, inputs - noise)``: the perturbation and the perturbed input,
        the pair ``captum.metrics.infidelity`` expects from ``perturb_func``.
    """
    # randn_like keeps dtype and device of `inputs`; the previous np.random + .float()
    # always produced a float32 CPU tensor, which fails for inputs on a GPU.
    noise = torch.randn_like(inputs) * std_noise
    return noise, inputs - noise

In [422]:
# Compute Attribution via expected gradients (Captum's GradientShap)

from captum.attr import GradientShap
from captum.metrics import sensitivity_max, infidelity

# Explained observation with a leading batch dimension of 1.
x_input = x_batch[0][n].unsqueeze(0)

# Explainer settings shared by the attribution AND the sensitivity metric, so the
# reported sensitivity belongs to the same explainer that produced attr_eg
# (without them sensitivity_max would fall back to GradientShap's defaults).
# The whole loaded batch serves as the baseline distribution.
eg_kwargs = dict(n_samples=100, stdevs=0.001, baselines=x_batch[0], target=0)

gradient_shap = GradientShap(model)
attr_eg = gradient_shap.attribute(x_input, **eg_kwargs)

infid_eg = infidelity(model, perturb_fn, x_input, attr_eg, target=0)  # lower is better
# NOTE: each of sensitivity_max's perturbed inputs is explained with n_samples=100,
# so this is noticeably slower than with the defaults.
sens_eg = sensitivity_max(gradient_shap.attribute, x_input, **eg_kwargs)  # lower is better

In [423]:
# Integrated Gradients

from captum.attr import IntegratedGradients

# Explained observation with a leading batch dimension of 1.
x_input = x_batch[0][n].unsqueeze(0)
# All-zero reference input as the integration baseline.
ig_baselines = torch.zeros_like(x_input)

ig = IntegratedGradients(model)
# target=0 is passed to every call so the attribution and both metrics
# explain the same model output.
attr_ig, delta = ig.attribute(
    x_input, baselines=ig_baselines, target=0, return_convergence_delta=True
)

infid_ig = infidelity(model, perturb_fn, x_input, attr_ig, target=0)  # lower is better
sens_ig = sensitivity_max(ig.attribute, x_input, target=0, baselines=ig_baselines)  # lower is better

In [424]:
# Guided Backprop

from captum.attr import GuidedBackprop

# Explained observation with a leading batch dimension of 1.
x_input = x_batch[0][n].unsqueeze(0)

gbp = GuidedBackprop(model)
# target=0 is passed to every call so the attribution and both metrics
# explain the same model output.
attr_gbp = gbp.attribute(x_input, target=0)

infid_gbp = infidelity(model, perturb_fn, x_input, attr_gbp, target=0)  # lower is better
sens_gbp = sensitivity_max(gbp.attribute, x_input, target=0)  # lower is better


Input Tensor 0 did not already require gradients, required_grads has been set automatically.


Setting backward hooks on ReLU activations.The hooks will be removed after the attribution is finished



In [425]:
# Guided GradCAM

from captum.attr import GuidedGradCam

# Explained observation with a leading batch dimension of 1.
x_input = x_batch[0][n].unsqueeze(0)

# The GradCAM part is computed at the model's first convolution (model.conv1).
ggc = GuidedGradCam(model, model.conv1)
# target=0 is passed to every call so the attribution and both metrics
# explain the same model output. Detached because the plotting cells call .numpy().
attr_ggc = ggc.attribute(x_input, target=0).detach()

infid_ggc = infidelity(model, perturb_fn, x_input, attr_ggc, target=0)  # lower is better
sens_ggc = sensitivity_max(ggc.attribute, x_input, target=0)  # lower is better

# Visualization of single methods

In [426]:
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots


def format_title(title, subtitle=None, subtitle_font_size=14):
    """Return Plotly rich text: ``title`` in bold, optionally followed by a smaller subtitle line.

    Args:
        title: Main text, wrapped in ``<b>`` (may be empty).
        subtitle: Optional second line; omitted when falsy.
        subtitle_font_size: Subtitle font size in px.

    Returns:
        The formatted title string.
    """
    title = f'<b>{title}</b>'
    if not subtitle:
        return title
    subtitle = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
    return f'{title}<br>{subtitle}'


# Diverging colour scale for attribution maps. Together with zmid=0 below,
# negative attributions are red, values near zero white and positive ones green.
colors = [(0, "#F00B48"), (0.40, "#ffffff"), (0.60, "#ffffff"), (1, "#00BE34")]

# Input channels of the 2D model, in tensor channel order.
channel_names = ["ND", "LP725", "LP780", "SP775"]

# 2x4 grid: every channel gets an (input, attribution) pair of panels;
# subplot titles are listed in row-major order.
subplot_titles = []
for channel_name in channel_names:
    subplot_titles += [format_title("", channel_name), format_title("", "Attribution")]

fig = make_subplots(rows=2, cols=4, subplot_titles=subplot_titles)

x_np = x.numpy()
attr_np = attr_eg.numpy().squeeze(0)  # drop the batch dimension
for ch in range(len(channel_names)):
    # Two channels per row: input in column 2*(ch % 2) + 1, attribution right next to it.
    row = ch // 2 + 1
    col = 2 * (ch % 2) + 1
    fig.add_trace(go.Heatmap(z=x_np[ch], colorscale="gray", showscale=False), row=row, col=col)
    # zmid=0 centres the colour scale on zero so the sign of each attribution is readable.
    fig.add_trace(
        go.Heatmap(z=attr_np[ch], colorscale=colors, zmid=0, showscale=False),
        row=row,
        col=col + 1,
    )

fig.update_yaxes(autorange='reversed', showticklabels=False)
fig.update_xaxes(showticklabels=False)

# str() of the numpy scalar keeps numpy's short float representation.
infid_txt = str(np.round(infid_eg.numpy(), 4)[0])
sens_txt = str(np.round(sens_eg.numpy(), 4)[0])
fig.update_layout(
    title=format_title(
        "Perovskite 2D Image Model",
        f"Predicted PCE: {y} / Method: Expected Gradients / Infidelity = {infid_txt}"
        f" (\u03C3(\u03B5) = {std_noise}) / Sensitivity = {sens_txt}",
    ),
    title_y=0.95,
    title_x=0.087,
    height=600,
    width=900,
)

fig.update_xaxes(showline=True, linewidth=0.5, linecolor='grey', mirror=True)
fig.update_yaxes(showline=True, linewidth=0.5, linecolor='grey', mirror=True)

fig.write_image("xai/images/2D_image/2D_eg.png", scale=2)

fig.show()

In [427]:
# Integrated Gradients: each input channel next to its attribution map.
# Reuses go, make_subplots, format_title and colors from the Expected Gradients
# visualization cell above instead of redefining them.

channel_names = ["ND", "LP725", "LP780", "SP775"]  # tensor channel order

# 2x4 grid of (input, attribution) pairs; titles in row-major order.
subplot_titles = []
for channel_name in channel_names:
    subplot_titles += [format_title("", channel_name), format_title("", "Attribution")]

fig = make_subplots(rows=2, cols=4, subplot_titles=subplot_titles)

x_np = x.numpy()
attr_np = attr_ig.numpy().squeeze(0)  # drop the batch dimension
for ch in range(len(channel_names)):
    # Two channels per row: input in column 2*(ch % 2) + 1, attribution right next to it.
    row = ch // 2 + 1
    col = 2 * (ch % 2) + 1
    fig.add_trace(go.Heatmap(z=x_np[ch], colorscale="gray", showscale=False), row=row, col=col)
    # zmid=0 centres the diverging colour scale on zero attribution.
    fig.add_trace(
        go.Heatmap(z=attr_np[ch], colorscale=colors, zmid=0, showscale=False),
        row=row,
        col=col + 1,
    )

fig.update_yaxes(autorange='reversed', showticklabels=False)
fig.update_xaxes(showticklabels=False)

# str() of the numpy scalar keeps numpy's short float representation.
infid_txt = str(np.round(infid_ig.numpy(), 4)[0])
sens_txt = str(np.round(sens_ig.numpy(), 4)[0])
fig.update_layout(
    title=format_title(
        "Perovskite 2D Image Model",
        f"Predicted PCE: {y} / Method: Integrated Gradients / Infidelity = {infid_txt}"
        f" (\u03C3(\u03B5) = {std_noise}) / Sensitivity = {sens_txt}",
    ),
    title_y=0.95,
    title_x=0.087,
    height=600,
    width=900,
)

fig.update_xaxes(showline=True, linewidth=0.5, linecolor='grey', mirror=True)
fig.update_yaxes(showline=True, linewidth=0.5, linecolor='grey', mirror=True)

fig.write_image("xai/images/2D_image/2D_ig.png", scale=2)

fig.show()

In [428]:
# Guided Backprop: each input channel next to its attribution map.
# Reuses go, make_subplots, format_title and colors from the Expected Gradients
# visualization cell above instead of redefining them.

channel_names = ["ND", "LP725", "LP780", "SP775"]  # tensor channel order

# 2x4 grid of (input, attribution) pairs; titles in row-major order.
subplot_titles = []
for channel_name in channel_names:
    subplot_titles += [format_title("", channel_name), format_title("", "Attribution")]

fig = make_subplots(rows=2, cols=4, subplot_titles=subplot_titles)

x_np = x.numpy()
attr_np = attr_gbp.numpy().squeeze(0)  # drop the batch dimension
for ch in range(len(channel_names)):
    # Two channels per row: input in column 2*(ch % 2) + 1, attribution right next to it.
    row = ch // 2 + 1
    col = 2 * (ch % 2) + 1
    fig.add_trace(go.Heatmap(z=x_np[ch], colorscale="gray", showscale=False), row=row, col=col)
    # zmid=0 centres the diverging colour scale on zero attribution.
    fig.add_trace(
        go.Heatmap(z=attr_np[ch], colorscale=colors, zmid=0, showscale=False),
        row=row,
        col=col + 1,
    )

fig.update_yaxes(autorange='reversed', showticklabels=False)
fig.update_xaxes(showticklabels=False)

# str() of the numpy scalar keeps numpy's short float representation.
infid_txt = str(np.round(infid_gbp.numpy(), 4)[0])
sens_txt = str(np.round(sens_gbp.numpy(), 4)[0])
fig.update_layout(
    title=format_title(
        "Perovskite 2D Image Model",
        f"Predicted PCE: {y} / Method: Guided Backprop / Infidelity = {infid_txt}"
        f" (\u03C3(\u03B5) = {std_noise}) / Sensitivity = {sens_txt}",
    ),
    title_y=0.95,
    title_x=0.087,
    height=600,
    width=900,
)

fig.update_xaxes(showline=True, linewidth=0.5, linecolor='grey', mirror=True)
fig.update_yaxes(showline=True, linewidth=0.5, linecolor='grey', mirror=True)

fig.write_image("xai/images/2D_image/2D_gbp.png", scale=2)

fig.show()

In [429]:
# Guided GradCAM: each input channel next to its attribution map.
# Reuses go, make_subplots, format_title and colors from the Expected Gradients
# visualization cell above instead of redefining them.

channel_names = ["ND", "LP725", "LP780", "SP775"]  # tensor channel order

# 2x4 grid of (input, attribution) pairs; titles in row-major order.
subplot_titles = []
for channel_name in channel_names:
    subplot_titles += [format_title("", channel_name), format_title("", "Attribution")]

fig = make_subplots(rows=2, cols=4, subplot_titles=subplot_titles)

x_np = x.numpy()
attr_np = attr_ggc.numpy().squeeze(0)  # drop the batch dimension
for ch in range(len(channel_names)):
    # Two channels per row: input in column 2*(ch % 2) + 1, attribution right next to it.
    row = ch // 2 + 1
    col = 2 * (ch % 2) + 1
    fig.add_trace(go.Heatmap(z=x_np[ch], colorscale="gray", showscale=False), row=row, col=col)
    # zmid=0 centres the diverging colour scale on zero attribution.
    fig.add_trace(
        go.Heatmap(z=attr_np[ch], colorscale=colors, zmid=0, showscale=False),
        row=row,
        col=col + 1,
    )

fig.update_yaxes(autorange='reversed', showticklabels=False)
fig.update_xaxes(showticklabels=False)

# str() of the numpy scalar keeps numpy's short float representation.
infid_txt = str(np.round(infid_ggc.numpy(), 4)[0])
sens_txt = str(np.round(sens_ggc.numpy(), 4)[0])
fig.update_layout(
    title=format_title(
        "Perovskite 2D Image Model",
        f"Predicted PCE: {y} / Method: Guided GradCAM / Infidelity = {infid_txt}"
        f" (\u03C3(\u03B5) = {std_noise}) / Sensitivity = {sens_txt}",
    ),
    title_y=0.95,
    title_x=0.087,
    height=600,
    width=900,
)

fig.update_xaxes(showline=True, linewidth=0.5, linecolor='grey', mirror=True)
fig.update_yaxes(showline=True, linewidth=0.5, linecolor='grey', mirror=True)

fig.write_image("xai/images/2D_image/2D_ggc.png", scale=2)

fig.show()

In [430]:
# Method & wavelength comparison: one row per input channel, showing the original
# image followed by one attribution map per method.
# Reuses go, make_subplots, format_title and colors from the Expected Gradients
# visualization cell above instead of redefining them.

channel_names = ["ND", "LP725", "LP780", "SP775"]  # tensor channel order
methods = [
    ("Expected Gradients", attr_eg),
    ("Integrated Gradients", attr_ig),
    ("Guided Backprop", attr_gbp),
    ("Guided GradCAM", attr_ggc),
]

# Titles in row-major order: "<channel> / Original", then one title per method.
subplot_titles = []
for channel_name in channel_names:
    subplot_titles.append(format_title(channel_name, "Original"))
    subplot_titles += [format_title("", method_name) for method_name, _ in methods]

fig = make_subplots(
    rows=len(channel_names),
    cols=1 + len(methods),
    vertical_spacing=0.1,
    subplot_titles=subplot_titles,
)

x_np = x.numpy()
attr_maps = [attr.numpy().squeeze(0) for _, attr in methods]  # drop the batch dimension
for ch in range(len(channel_names)):
    row = ch + 1
    fig.add_trace(go.Heatmap(z=x_np[ch], colorscale="gray", showscale=False), row=row, col=1)
    for col, attr_np in enumerate(attr_maps, start=2):
        # zmid=0 centres the diverging colour scale on zero attribution.
        fig.add_trace(
            go.Heatmap(z=attr_np[ch], colorscale=colors, zmid=0, showscale=False),
            row=row,
            col=col,
        )

fig.update_yaxes(autorange='reversed', showticklabels=False)
fig.update_xaxes(showticklabels=False)

fig.update_layout(
    title=format_title(
        "Method & Wavelength Comparison",
        f"Perovskite 2D Image Model (Backflow patches excluded) / Predicted PCE: {y}",
    ),
    title_y=0.98,
    title_x=0.08,
    height=1000,
    width=1000,
)

fig.update_xaxes(showline=True, linewidth=0.5, linecolor='grey', mirror=True)
fig.update_yaxes(showline=True, linewidth=0.5, linecolor='grey', mirror=True)

fig.write_image("xai/images/2D_image/2D_cmp.png", scale=2)

fig.show()