### Imports & Paths

In [None]:
import os
from pathlib import Path

# Move the working directory three levels up from where the notebook was started,
# so that the relative paths used further down resolve from there.
# NOTE(review): presumably this lands on the repository root - confirm. Re-running
# this cell moves up another three levels each time.
notebook_dir = Path.cwd()
os.chdir(notebook_dir.parents[2])
os.getcwd()

target = "mth"  # mth, pce

In [None]:
import torch
import torch.nn as nn
import numpy as np
import scipy
import plotly.graph_objects as go

from torch.utils.data import DataLoader
from tqdm import tqdm
from captum.attr import GradientShap, IntegratedGradients, GuidedBackprop, GuidedGradCam
from captum.metrics import sensitivity_max, infidelity
from os.path import join
from plotly.subplots import make_subplots

from data.perovskite_dataset import PerovskiteDataset1d
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from data.augmentations.perov_1d import normalize
from base_model import seed_worker

# The preprocessed data is read from a "preprocessed" folder below the working directory.
data_dir = os.getcwd() + "/preprocessed"

# Pick the checkpoint folder and file that belong to the selected target.
# NOTE(review): both folders are absolute paths on one user's machine.
if target != "pce":
    checkpoint_dir = "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/mT_checkpoints/checkpoints"
    checkpoint_file = "mT_1D_RN152_full-epoch=999-val_MAE=0.000-train_MAE=40.332.ckpt"
else:
    checkpoint_dir = (
        "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints"
    )
    checkpoint_file = "1D-epoch=999-val_MAE=0.000-train_MAE=0.490.ckpt"

path_to_checkpoint = join(checkpoint_dir, checkpoint_file)


### Model Init

In [None]:
#### 1D Model (no border)
# NOTE(review): the heading says "no border" but every no_border flag below is False - confirm.

# The two supported targets differ only in the label column and in whether the
# target is normalised, so derive both switches once.
is_pce = target == "pce"
label_name = "PCE_mean" if is_pce else "meanThickness"

hypparams = dict(
    dataset="Perov_1d",
    dims=1,
    bottleneck=False,
    name="ResNet152",
    data_dir=data_dir,
    no_border=False,
    resnet_dropout=0.0,
    stochastic_depth=0.0,
    norm_target=is_pce,
    target=label_name,
)

# Restore the trained network from the checkpoint and switch it to evaluation mode.
model = ResNet.load_from_checkpoint(
    path_to_checkpoint, block=BasicBlock, num_blocks=[4, 13, 55, 4], num_classes=1, hypparams=hypparams
)

print("Loaded")
model.eval()


def _make_dataset(**split_kwargs):
    """Build a 1D dataset that uses the loaded model's normalisation statistics and scaler.

    Any keyword arguments are forwarded unchanged to ``PerovskiteDataset1d``.
    """
    return PerovskiteDataset1d(
        data_dir,
        transform=normalize(model.train_mean, model.train_std),
        scaler=model.scaler,
        no_border=False,
        return_unscaled=not is_pce,
        label=label_name,
        **split_kwargs,
    )


# The test split is requested explicitly; the second dataset relies on the dataset's defaults.
test_set = _make_dataset(fold=None, split="test", val=False)
train_set = _make_dataset()

batch_size = 100

# Both datasets are pooled into one shuffled loader.
combined_set = torch.utils.data.ConcatDataset([train_set, test_set])
loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
)
loader = DataLoader(combined_set, **loader_options)


In [None]:
# Select batch
# Draw one batch from the shuffled loader; all attribution runs below work on it.
batch = next(iter(loader))

# Model predictions for the whole batch, computed without tracking gradients.
with torch.no_grad():
    y_batch = model.predict(batch).flatten()

# Keep only the first element of the batch (presumably the input tensor - confirm).
x_batch = batch[0]

# Init perturbation function for infidelity metric

std_noise = 0.01


def perturb_fn(inputs):
    """Draw Gaussian noise and apply it to the inputs (used for the infidelity metric).

    Parameters:
    ----------
    inputs: torch.Tensor
        tensor to perturb

    Returns:
    -------
    noise: torch.Tensor
        zero-mean Gaussian noise with standard deviation ``std_noise`` and the
        shape of ``inputs``
    perturbed: torch.Tensor
        ``inputs - noise``
    """
    # Sampling stays on NumPy's RNG so existing seeds give the same noise. The noise is
    # then placed on the inputs' device (and floating dtype) so the subtraction below
    # also works for inputs that do not live on the CPU as float32.
    noise_dtype = inputs.dtype if inputs.is_floating_point() else torch.float32
    noise = torch.as_tensor(
        np.random.normal(0, std_noise, inputs.shape),
        dtype=noise_dtype,
        device=inputs.device,
    )
    return noise, inputs - noise


### Global Attribution Computation and Evaluation

#### Expected Gradients

In [None]:
# GradientShap with the whole batch as baselines.
method = GradientShap(model)
attr_sum, infid_sum, sens_sum = [], [], []

for n in tqdm(range(x_batch.shape[0])):
    # Single-sample batch that is attributed and scored in this iteration.
    sample = x_batch[n].unsqueeze(0)

    attr = method.attribute(sample, n_samples=80, stdevs=0.001, baselines=x_batch, target=0)

    # The absolute attribution goes into the global mean; the metric gets the signed attribution.
    attr_sum.append(attr.abs())
    infid_sum.append(infidelity(model, perturb_fn, sample, attr))
    sens_sum.append(sensitivity_max(method.attribute, sample, target=0, baselines=x_batch))

# Average over all samples of the batch.
attr_eg = torch.cat(attr_sum).mean(dim=0)
infid_eg = torch.Tensor(infid_sum).mean()
sens_eg = torch.Tensor(sens_sum).mean()


#### Integrated Gradients

In [None]:
# Integrated Gradients against an all-zero baseline of the sample's shape.
method = IntegratedGradients(model)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    # Single-sample batch and its zero baseline, shared by the attribution and the metrics.
    sample = x_batch[n].unsqueeze(0)
    zero_baseline = sample * 0

    # target=0 matches the sensitivity call below, so both use the same explanation
    # settings. The convergence delta is not requested because nothing reads it.
    attr = method.attribute(sample, baselines=zero_baseline, target=0)

    # The absolute attribution goes into the global mean; the metric gets the signed attribution.
    attr_sum.append(attr.abs())

    infid_sum.append(infidelity(model, perturb_fn, sample, attr))
    sens_sum.append(sensitivity_max(method.attribute, sample, target=0, baselines=zero_baseline))

# Average over all samples of the batch.
attr_ig = torch.cat(attr_sum).mean(dim=0)
infid_ig = torch.Tensor(infid_sum).mean()
sens_ig = torch.Tensor(sens_sum).mean()


#### Guided Backprop

In [None]:
# Guided Backpropagation attributions for every sample of the batch.
method = GuidedBackprop(model)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    # Single-sample batch that is attributed and scored in this iteration.
    sample = x_batch[n].unsqueeze(0)

    attr = method.attribute(sample, target=0)

    # The absolute attribution goes into the global mean; the metric gets the signed attribution.
    attr_sum.append(attr.abs())

    infid_sum.append(infidelity(model, perturb_fn, sample, attr))
    # Pass the same target as the attribution call above so the sensitivity is
    # measured for the identically configured explanation function.
    sens_sum.append(sensitivity_max(method.attribute, sample, target=0))

# Average over all samples of the batch.
attr_gbp = torch.cat(attr_sum).mean(dim=0)
infid_gbp = torch.Tensor(infid_sum).mean()
sens_gbp = torch.Tensor(sens_sum).mean()


#### Guided GradCAM

In [None]:
# Guided GradCAM attributions, using the model's ``layer4`` as the GradCAM layer.
method = GuidedGradCam(model, model.layer4)
attr_sum = []
infid_sum = []
sens_sum = []

for n in tqdm(range(x_batch.shape[0])):
    # Single-sample batch that is attributed and scored in this iteration.
    sample = x_batch[n].unsqueeze(0)

    attr = method.attribute(sample, target=0)

    # Detached absolute attribution goes into the global mean; the metric gets the signed attribution.
    attr_sum.append(attr.detach().abs())

    infid_sum.append(infidelity(model, perturb_fn, sample, attr))
    # Pass the same target as the attribution call above so the sensitivity is
    # measured for the identically configured explanation function.
    sens_sum.append(sensitivity_max(method.attribute, sample, target=0))

# Average over all samples of the batch.
attr_ggc = torch.cat(attr_sum).mean(dim=0)
infid_ggc = torch.Tensor(infid_sum).mean()
sens_ggc = torch.Tensor(sens_sum).mean()


## Attribution Visualization

In [None]:
#### Paper Vis ####
color = ["#E1462C", "#0059A0", "#5F3893", "#FF8777", "#0A2C6E", "#CEDEEB"]
filter = ["ND", "LP725", "LP780", "SP775"]

# Tick label font shared by the y-axes and the x-axis.
tick_font = dict(size=14, family="Helvetica", color="rgb(0,0,0)")

# Single panel with a secondary y-axis for the attribution curve.
fig = make_subplots(rows=1, cols=1, specs=[[{"secondary_y": True}]])

# Expected-Gradients attribution, summed over the first axis of the squeezed
# tensor and drawn as a stacked area on the secondary axis.
attribution_trace = go.Scatter(
    y=np.abs(attr_eg.squeeze().sum(axis=0)),
    name="Attribution",
    marker_color=color[2],
    showlegend=False,
    stackgroup="one",
)
fig.add_trace(attribution_trace, secondary_y=True)

# Batch mean of the inputs; one grey line per entry of `filter`.
x = x_batch.mean(0)

for i, filter_name in enumerate(filter):
    # The first line is dotted, the remaining ones are solid.
    line_style = dict(width=2.5, dash="dot") if i == 0 else dict(width=2.5)
    fig.add_trace(
        go.Scatter(
            y=x[i],
            name=filter_name,
            opacity=0.5,
            marker_color="grey",
            line=line_style,
            showlegend=False,
        )
    )

# This first call has no secondary_y filter; the second call then overrides
# title, range and ticks of the secondary axis.
fig.update_yaxes(title="Intensity", showticklabels=True, range=[-2.05, 2.05], tickfont=tick_font)
fig.update_yaxes(
    zeroline=False,
    title_text=" ",
    showticklabels=True,
    showgrid=False,
    secondary_y=True,
    range=[0, 0.24],
    tickvals=[0, 0.06, 0.12, 0.18, 0.24],
)
fig.update_xaxes(zeroline=False, title="Timesteps", showgrid=False, tickfont=tick_font)

fig.update_layout(
    showlegend=False,
    bargap=0,
    bargroupgap=0,
    legend_title=" ",
    template="plotly_white",
    height=400,
    width=700,
)

# Export the figure into the folder of the selected target, then display it.
fig.write_image(f"xai/images/{target}/1D/1D_paper.png", scale=2)

fig.show()


In [None]:
#### Large Overview Figure ####


def format_title(title, subtitle=None, subtitle_font_size=14):
    """Build an HTML figure title with an optional smaller subtitle line.

    Parameters:
    ----------
    title: str
        main title, wrapped in bold tags
    subtitle: str or None
        optional second line; when empty or None only the bold title is returned
    subtitle_font_size: int
        font size of the subtitle in pixels

    Returns:
    -------
    str
        the bold title, followed by a line break and the styled subtitle if one is given
    """
    heading = f"<b>{title}</b>"
    if subtitle:
        sub_line = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
        return f"{heading}<br>{sub_line}"
    return heading


filter = ["ND", "LP725", "LP780", "SP775"]

# Batch mean of the inputs; drawn as grey background lines in every subplot.
x = x_batch.mean(0)

# One entry per subplot in row-major order:
# (display name, mean infidelity, mean sensitivity, mean abs. attribution, bar colour)
methods = [
    ("Expected Gradients", infid_eg, sens_eg, attr_eg, "#042940"),
    ("Integrated Gradients", infid_ig, sens_ig, attr_ig, "#005C53"),
    ("Guided Backprop", infid_gbp, sens_gbp, attr_gbp, "#9FC131"),
    ("Guided GradCAM", infid_ggc, sens_ggc, attr_ggc, "#DBF227"),
]

# 1-based (row, col) position of each subplot, in the same order as `methods`.
grid = [(row, col) for row in (1, 2) for col in (1, 2)]

# Each subplot title reads "<method> (<mean infidelity>, <mean sensitivity>)", rounded to 4 digits.
fig = make_subplots(
    rows=2,
    cols=2,
    specs=[[{"secondary_y": True} for _ in range(2)] for _ in range(2)],
    subplot_titles=tuple(
        format_title(
            "",
            name + " (" + str(np.round(infid.numpy(), 4)) + ", " + str(np.round(sens.numpy(), 4)) + ")",
        )
        for name, infid, sens, _, _ in methods
    ),
)

# Grey input curves first (all subplots), then the attribution bars, so the
# trace order of the figure stays the same as before.
for row, col in grid:
    for i in range(4):
        fig.add_trace(
            go.Scatter(y=x[i], name=filter[i], marker_color="grey", opacity=0.3, showlegend=False),
            row=row,
            col=col,
        )

for (row, col), (_, _, _, attr, bar_color) in zip(grid, methods):
    fig.add_trace(
        go.Bar(
            y=attr.squeeze().sum(axis=0),
            marker_color=bar_color,
            opacity=0.5,
            showlegend=False,
            marker_line_width=0,
        ),
        secondary_y=True,
        row=row,
        col=col,
    )

# Axis titles: "Intensity" on the primary axis of the left column, "Attribution"
# on the secondary axis of the right column. secondary_y=False is required for the
# left column: without it the selection also covers the secondary axis of that
# subplot, which would then be labelled "Intensity" as well.
fig.update_yaxes(title_text=None, secondary_y=True)
fig.update_yaxes(title="Intensity", secondary_y=False, row=1, col=1)
fig.update_yaxes(title_text="Attribution", secondary_y=True, row=1, col=2)
fig.update_yaxes(title="Intensity", secondary_y=False, row=2, col=1)
fig.update_yaxes(title_text="Attribution", secondary_y=True, row=2, col=2)

# Only the bottom row carries an x-axis title.
fig.update_xaxes(title=None)
fig.update_xaxes(title="Timesteps", row=2, col=1)
fig.update_xaxes(title="Timesteps", row=2, col=2)

if target == "pce":
    subtitle = "Perovskite 1D Model / Target: PCE /(mean Infidelity, mean Sensitivity)"
else:
    subtitle = "Perovskite 1D Model / Target: Mean Thickness / (mean Infidelity, mean Sensitivity)"

fig.update_layout(
    title=format_title(
        "Global Attribution: Mean abs. Attribution (n = " + str(batch_size) + ")",
        subtitle=subtitle,
    ),
    legend_title=None,
    title_y=0.965,
    title_x=0.035,
    template="plotly_white",
    height=500,
    width=1000,
)

# Export the figure into the folder of the selected target, then display it.
fig.write_image("xai/images/" + target + "/1D/1D_cmp_global.png", scale=2)

fig.show()


#### Filter Importance Plot

In [None]:
def _significance_symbol(pvalue):
    """Map a p-value to the marker text drawn above a bracket."""
    # Thresholds as used by this notebook: >= 0.1 "ns", >= 0.05 "*", >= 0.01 "**", below "***".
    # NOTE(review): these are more lenient than the 0.05 / 0.01 / 0.001 levels often used - confirm.
    if pvalue >= 0.1:
        return "ns"
    if pvalue >= 0.05:
        return "*"
    if pvalue >= 0.01:
        return "**"
    return "***"


def add_p_value_annotation(fig, array_columns, p_value, subplot=None, _format=None):
    """Draws significance brackets with a marker text between pairs of columns of a figure

    The p-values are NOT computed here; they are supplied by the caller.

    Parameters:
    ----------
    fig: figure
        plotly figure (bar or box plot) to annotate
    array_columns: np.array
        array of which columns to connect with a bracket
        e.g.: [[0,1], [1,2]] connects column 0 with 1 and 1 with 2
    p_value: sequence of float
        one p-value per entry of array_columns, in the same order
    subplot: None or int
        specifies if the figures has subplots and what subplot to add the notation to
    _format: None or dict
        format characteristics for the lines; supported keys are "interline",
        "text_height" and "color". Missing keys fall back to
        interline=0.07, text_height=1.07, color="black"

    Returns:
    -------
    fig: figure
        figure with the added notation
    """
    # Build a fresh settings dict on every call (no shared mutable default) and
    # let callers override only the keys they care about.
    fmt = {"interline": 0.07, "text_height": 1.07, "color": "black"}
    if _format:
        fmt.update(_format)

    # Axis references: the first (or only) subplot uses "x"/"y", later ones "x2"/"y2", ...
    subplot_str = str(subplot) if subplot and subplot != 1 else ""
    xref = "x" + subplot_str
    # "domain" makes the y coordinates fractions of the axis height, so values
    # above 1 are placed above the plotting area.
    yref = "y" + subplot_str + " domain"
    line_style = dict(color=fmt["color"], width=2)

    for index, column_pair in enumerate(array_columns):
        left, right = column_pair[0], column_pair[1]

        # Each bracket is stacked one "interline" step above the previous one.
        y_low = 1.01 + index * fmt["interline"]
        y_high = 1.02 + index * fmt["interline"]

        symbol = _significance_symbol(p_value[index])

        # Bracket segments in drawing order: left tick, horizontal bar, right tick.
        segments = (
            (left, y_low + 0.01, left, y_high - 0.03),
            (left, y_high, right, y_high),
            (right, y_low + 0.01, right, y_high - 0.03),
        )
        for x0, y0, x1, y1 in segments:
            fig.add_shape(
                type="line",
                xref=xref,
                yref=yref,
                x0=x0,
                y0=y0,
                x1=x1,
                y1=y1,
                line=line_style,
            )

        # Marker text centred above the bracket; for bars the column number maps
        # directly to the x position 0, 1, 2...
        fig.add_annotation(
            dict(
                font=dict(color=fmt["color"], size=14),
                x=(left + right) / 2,
                y=y_high * fmt["text_height"],
                showarrow=False,
                text=symbol,
                textangle=0,
                xref=xref,
                yref=yref,
            )
        )
    return fig


In [None]:
label = "mth"
color = ["#E1462C", "#FF8777", "#0059A0", "#5F3893"]
filter = ["ND", "LP725", "LP780", "SP775"]
fig = go.Figure()

# Settings shared by both branches.
bar_text_font = dict(size=16, family="Helvetica", color="rgb(0,0,0)")
bracket_format = dict(interline=0.09, text_height=1.08, color="black")

# NOTE(review): the "pce" branch labels pairwise brackets with one-way ANOVA p-values
# (over 4 and over 3 groups), while the other branch uses pairwise t-tests with
# equal_var=False - confirm this difference is intended.
if label != "pce":
    attr_mth = np.load("xai/results/eg_abs_global_mth.npy")

    # Row-wise sum (bar height) and standard deviation (shown in the bar text).
    totals = attr_mth.sum(axis=1)
    spreads = attr_mth.std(axis=1)
    bar_text = [str(int(i)) + " ± " + str(np.round(j, 3)) for i, j in zip(np.round(totals), spreads)]

    fig.add_traces(
        go.Bar(
            x=filter,
            y=totals,
            text=bar_text,
            textposition="outside",
            textfont=bar_text_font,
            marker_color=color,
        )
    )

    fig.update_yaxes(zerolinewidth=4, range=[0, 450])

    # p-values of the two-sample tests for row 1 vs. row 3 and row 0 vs. row 3.
    t1 = scipy.stats.ttest_ind(a=attr_mth[1], b=attr_mth[3], equal_var=False)[1]
    t2 = scipy.stats.ttest_ind(a=attr_mth[0], b=attr_mth[3], equal_var=False)[1]
    fig = add_p_value_annotation(
        fig,
        array_columns=[[1, 3], [0, 3]],
        p_value=[t1, t2],
        _format=bracket_format,
    )

else:
    attr_pce = np.load("xai/results/eg_abs_global_pce.npy")

    # Row-wise sum (bar height) and standard deviation (shown in the bar text).
    totals = attr_pce.sum(axis=1)
    spreads = attr_pce.std(axis=1)
    bar_text = [str(i) + " ± " + str(np.round(j, 3)) for i, j in zip(np.round(totals, 2), spreads)]

    fig.add_traces(
        go.Bar(
            x=filter,
            y=totals,
            text=bar_text,
            textposition="outside",
            textfont=bar_text_font,
            marker_color=color,
        )
    )

    fig.update_yaxes(zerolinewidth=4, range=[0, 4.5])

    # p-values of the one-way ANOVA over all four rows and over rows 1-3.
    f1 = scipy.stats.f_oneway(attr_pce[0], attr_pce[1], attr_pce[2], attr_pce[3], axis=0)[1]
    f2 = scipy.stats.f_oneway(attr_pce[1], attr_pce[2], attr_pce[3], axis=0)[1]
    fig = add_p_value_annotation(
        fig,
        array_columns=[[0, 3], [1, 3]],
        p_value=[f1, f2],
        _format=bracket_format,
    )

fig.update_layout(
    showlegend=False,
    legend_title=" ",
    template="plotly_white",
    height=400,
    width=500,
    font=dict(family="Helvetica", color="#000000", size=14),
)

# Export the figure into the folder named after `label`, then display it.
fig.write_image(f"xai/images/{label}/1D/1D_global_filter.png", scale=2)

fig.show()
