In [1]:
import os
import sys
from pathlib import Path

# Re-root the session three directory levels above the kernel's starting
# directory (parents[2]) and expose that directory on sys.path so the
# project-local imports below (data, models, base_model, xai) resolve.
# Presumably this is the repository root -- TODO confirm against the repo layout.
#
# The starting directory is remembered in _NOTEBOOK_DIR so the cell is
# idempotent: computing the target from the *current* directory would climb
# three more levels on every re-run of this cell.
if "_NOTEBOOK_DIR" not in globals():
    _NOTEBOOK_DIR = Path(os.getcwd())
os.chdir(_NOTEBOOK_DIR.parents[2])
# Avoid piling up duplicate entries when the cell is executed more than once.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from data.perovskite_dataset import PerovskiteDataset1d

from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from base_model import seed_worker
from os.path import join
from xai.utils.eval_methods import VisionSensitivityN, VisionInsertionDeletion

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
## Change npz data for different dimensions here ##
dim = "2D_time" # 1D, 2D_time, 2D_image, 3D

# np.load on an .npz returns a lazily-read archive that keeps its file handle
# open; reading every array inside the context manager materialises them and
# then closes the archive.
with np.load("./xai/results/eval_" + dim + "_results.npz") as data:
    sensN = data['arr_0']    # plotted as the Sensitivity-N curves
    ins_abc = data['arr_1']  # plotted as the Insertion box plots
    del_abc = data['arr_2']  # plotted as the Deletion box plots
    infid = data['arr_3']    # plotted as the Infidelity box plots
    sens = data['arr_4']     # plotted as the Sensitivity box plots

# x-axis values for the Sensitivity-N panel: int(log_n_max / log_n_ticks)
# log-spaced points between 10**0 and 10**log_n_max, truncated to integers.
log_n_max = 2.7
log_n_ticks = 0.4
n_list = np.logspace(0, log_n_max, int(log_n_max / log_n_ticks), base=10.0, dtype=int)

# Changes per Dimension
# Upper clipping bounds applied to individual metrics, keyed by `dim`.
# Presumably these cap outliers so the box plots stay readable -- TODO confirm.
# A dimension without an entry (or with an empty one) is left unclipped.
clip_upper = {
    "1D": {"ins_abc": 1, "infid": 0.008, "sens": 8},
    "2D_time": {"infid": 0.0015},
    "2D_image": {"ins_abc": 0.3, "sens": 3, "infid": 0.0006},
    "3D": {},
}
bounds = clip_upper.get(dim, {})
if "ins_abc" in bounds:
    ins_abc = ins_abc.clip(None, bounds["ins_abc"])
if "infid" in bounds:
    infid = infid.clip(None, bounds["infid"])
if "sens" in bounds:
    sens = sens.clip(None, bounds["sens"])

In [11]:
## Plot ##
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def format_title(title, subtitle=None, subtitle_font_size=14):
    """Build an HTML title string: a bold heading with an optional subtitle.

    Args:
        title: Heading text; always wrapped in ``<b>...</b>``.
        subtitle: Optional second line. When falsy (``None`` or ``""``) only
            the bold heading is returned.
        subtitle_font_size: Font size in pixels for the subtitle span.

    Returns:
        The bold heading, followed -- when a subtitle is given -- by ``<br>``
        and the subtitle wrapped in a ``<span>`` with the requested font size.
    """
    heading = f"<b>{title}</b>"
    if subtitle:
        sub_line = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
        return f"{heading}<br>{sub_line}"
    return heading


# One row of five panels. Each panel title is a (bold heading, subtitle) pair;
# an empty heading leaves only the subtitle visible. The arrow appended to each
# subtitle is either U+2197 (north-east) or U+2198 (south-east).
arrow_up = "\u2197"
arrow_down = "\u2198"
panel_headers = (
    ("Importance", "Sensitivity-N " + arrow_up),
    ("", "Insertion " + arrow_up),
    ("", "Deletion " + arrow_up),
    ("Robustness", "Sensitivity " + arrow_down),
    ("", "Infidelity " + arrow_down),
)

fig = make_subplots(
    rows=1,
    cols=5,
    subplot_titles=tuple(
        format_title(heading, subheading) for heading, subheading in panel_headers
    ),
)


# Attribution methods in plotting order:
#   (legend label, line/marker colour, translucent fill for the confidence band).
# A method's position in this tuple is also its index along axis 0 of every
# result array (sensN, ins_abc, del_abc, sens, infid).
methods = (
    ("EG", "#042940", "rgba(4,41,64,0.2)"),
    ("IG", "#005C53", "rgba(0,92,83,0.2)"),
    ("GBC", "#9FC131", "rgba(159,193,49,0.2)"),
    ("GGC", "#DBF227", "rgba(219,242,39,0.2)"),
)

# --- Column 1: mean Sensitivity-N curve per method with a 95% confidence band ---
for method_idx, (label, color, band_color) in enumerate(methods):
    scores = sensN[method_idx]
    mean_curve = scores.mean(0)
    # Normal-approximation 95% interval: mean +/- 1.960 * standard error, where
    # the standard error is std / sqrt(n) and n is the number of rows along
    # axis 0. Dividing by n instead of sqrt(n) would draw a band that is too
    # narrow by a factor of sqrt(n).
    half_width = 1.960 * np.std(scores, axis=0) / np.sqrt(scores.shape[0])

    fig.add_trace(
        go.Scatter(
            y=mean_curve,
            x=n_list,
            name=label,
            marker_color=color,
            showlegend=False,
            mode="lines+markers",
        ),
        row=1,
        col=1,
    )

    # Two invisible boundary lines; the lower one uses fill="tonexty" to shade
    # up to the trace added just before it, so the upper bound must come first.
    fig.add_traces(
        [
            go.Scatter(
                x=n_list,
                y=mean_curve + half_width,
                mode="lines",
                line_color="rgba(0,0,0,0)",
                showlegend=False,
            ),
            go.Scatter(
                x=n_list,
                y=mean_curve - half_width,
                mode="lines",
                line_color="rgba(0,0,0,0)",
                name="95% confidence interval",
                showlegend=False,
                fill="tonexty",
                fillcolor=band_color,
            ),
        ],
        rows=1,
        cols=1,
    )

# --- Columns 2-5: one box per method and metric ---
# (subplot column, per-method values, squeeze singleton axes?, extra Box kwargs)
# Only the Sensitivity panel (column 4) leaves `showlegend` at its default, so
# it is the single source of legend entries for the whole figure.
box_panels = (
    (2, ins_abc, False, {"showlegend": False}),
    (3, del_abc, False, {"showlegend": False}),
    (4, sens, True, {}),
    (5, infid, True, {"showlegend": False}),
)

for panel_col, values, squeeze, extra_kwargs in box_panels:
    for method_idx, (label, color, _band_color) in enumerate(methods):
        y_values = values[method_idx].squeeze() if squeeze else values[method_idx]
        fig.add_trace(
            go.Box(y=y_values, name=label, marker_color=color, **extra_kwargs),
            row=1,
            col=panel_col,
        )


# Axis labelling, keyed by subplot column (the figure has a single row).
# Only the first two panels get an x-axis title.
x_axis_titles = {1: "N", 2: "Methods"}
for panel_col, x_title in x_axis_titles.items():
    fig.update_xaxes(title=x_title, row=1, col=panel_col)

# Per-panel y-axis settings; panels 3 and 5 get no y-axis title of their own.
# Every box-plot panel is anchored at zero via rangemode="tozero".
y_axis_settings = {
    1: {"title": "Pearson Correlation"},
    2: {"title": "Area Between Curves", "rangemode": "tozero"},
    3: {"rangemode": "tozero"},
    4: {"title": "Score", "rangemode": "tozero"},
    5: {"rangemode": "tozero"},
}
for panel_col, settings in y_axis_settings.items():
    fig.update_yaxes(row=1, col=panel_col, **settings)

# Figure-level title: bold heading plus a subtitle naming the model dimension
# and the sample count, read from axis 0 of the last method's Sensitivity-N
# scores.
n_samples = sensN[3].shape[0]
figure_title = format_title(
    "Attribution Evaluation",
    f"Perovskite {dim} Model (n = {n_samples})",
)

fig.update_layout(
    title=figure_title,
    legend_title=None,
    legend=dict(traceorder="normal"),
    title_y=0.96,
    title_x=0.035,
    template="plotly_white",
    height=400,
    width=1600,
)

# Export a PNG at 2x scale into a per-dimension folder, then render inline.
# The folder depends on `dim`, so create it on demand; otherwise the export
# fails for any dimension whose folder has not been created yet.
image_dir = "xai/images/" + dim
os.makedirs(image_dir, exist_ok=True)
fig.write_image(image_dir + "/" + dim + "_eval.png", scale=2)

fig.show()