In [1]:
import os
import sys
from pathlib import Path

# Work from the directory three levels above the kernel's starting directory so
# the relative "./xai/..." paths and the project-local imports (data, models,
# xai, base_model) resolve. Assumes the kernel starts in this notebook's own
# folder — TODO confirm. Guarded so that re-running this cell in the same kernel
# does not climb further up the tree or add a duplicate sys.path entry.
if "REPO_ROOT" not in globals():
    REPO_ROOT = Path(os.getcwd()).parents[2]
    os.chdir(REPO_ROOT)
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

import torch
import torch.nn as nns
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from data.perovskite_dataset import PerovskiteDataset1d

from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from base_model import seed_worker
from os.path import join
from xai.utils.eval_methods import VisionSensitivityN, VisionInsertionDeletion

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
## Change npz data for different dimensions here ##
dim = "1D" # 1D, 2D_time, 2D_image, 3D
target = "mth" # pce, mth

# The clipping below treats every target other than "pce" as "mth", so reject a
# typo up front instead of silently taking the mth branch.
assert target in ("pce", "mth"), f"unknown target {target!r}; expected 'pce' or 'mth'"

# np.load on an .npz returns an NpzFile that keeps the archive open; the `with`
# blocks close it once the arrays have been read into memory.
# arr_0..arr_4 are positional names (saved without keywords); their meaning is
# presumably fixed by the evaluation script that wrote them — verify there.
with np.load(f"./xai/results/{target}_eval_{dim}_results.npz") as data:
    sensN = data['arr_0']
    ins_abc = data['arr_1']
    del_abc = data['arr_2']
    infid = data['arr_3']
    sens = data['arr_4']

# Random-attribution baseline for the same target/dimension.
with np.load(f"./xai/results/{target}_eval_{dim}_rand.npz") as data:
    sensN_rand = data['arr_0']
    ins_abc_rand = data['arr_1']
    del_abc_rand = data['arr_2']

# x-axis values N for the Sensitivity-N panel. logspace receives
# int(2.7 / 0.4) == 6 points, so consecutive exponents are 2.7 / 5 = 0.54 apart
# (not log_n_ticks), and dtype=int truncates each value. These presumably must
# match the N values used when sensN was computed — TODO confirm before changing.
log_n_max = 2.7
log_n_ticks = 0.4
n_list = np.logspace(0, log_n_max, int(log_n_max / log_n_ticks), base=10.0, dtype=int)

# Per-(target, dim) upper caps on the box-plot data, applied only to what gets
# plotted. The limits are magic numbers (presumably hand-tuned so a few outliers
# do not stretch the axes); the commented lines are alternative caps that are
# currently disabled.
if target == "pce":
    if dim == "1D":
        ins_abc = ins_abc.clip(None, 1)
        #infid = infid.clip(None,0.008)
        #sens = sens.clip(None,8)
    elif dim == "2D_time":
        pass
        #infid = infid.clip(None,0.0015)
    elif dim == "2D_image":
        ins_abc = ins_abc.clip(None, 0.3)
        del_abc = del_abc.clip(None, 1.0)
        #sens = sens.clip(None,3)
        #infid = infid.clip(None,0.0006)
    elif dim == "3D":
        ins_abc = ins_abc.clip(None, 0.1)
        #infid = infid.clip(None, 0.0001)
else:
    if dim == "1D":
        pass
        #infid = infid.clip(None,50)
        #sens = sens.clip(None,10)
    elif dim == "2D_time":
        ins_abc = ins_abc.clip(None, 1)
        del_abc = del_abc.clip(None, 6)
        #sens = sens.clip(None,6)
        #infid = infid.clip(None,2)
    elif dim == "2D_image":
        ins_abc = ins_abc.clip(None, 25)
        del_abc = del_abc.clip(None, 80)
        #sens = sens.clip(None,3)
        #infid = infid.clip(None,6)

In [25]:
## Plot ##
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def format_title(title, subtitle=None, subtitle_font_size=14):
    """Build an HTML title string for a Plotly figure.

    The main title is wrapped in ``<b>`` tags. If ``subtitle`` is truthy it is
    appended after a ``<br>`` inside a ``<span>`` whose font size is
    ``subtitle_font_size`` pixels; otherwise only the bold title is returned.
    """
    bold_title = f"<b>{title}</b>"
    if subtitle:
        styled_subtitle = (
            f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
        )
        return bold_title + "<br>" + styled_subtitle
    return bold_title

if target == "pce":
    title = "Attribution Evaluation PCE"
else:
    title = "Attribution Evaluation Mean Thickness"


fig = make_subplots(
    rows=1,
    cols=5
)

cd = ["#E1462C", "#0059A0", "#5F3893", "#FF8777", "#0A2C6E", "#CEDEEB"]

# Attribution methods in plotting order: (index into the per-method result
# arrays, display name, colour). Index i of sensN / ins_abc / del_abc / sens /
# infid belongs to the method on row i; colours are assigned per method, hence
# the non-sequential cd indices.
METHODS = [
    (0, "EG", cd[1]),
    (1, "IG", cd[2]),
    (2, "GBP", cd[0]),
    (3, "GGC", cd[3]),
]


def _add_method_boxes(fig, values, col, showlegend=False, squeeze=False):
    """Add one box trace per method (in METHODS order) to row 1, column ``col``.

    ``values[i]`` holds the scores of method ``i``. With ``squeeze=True``
    singleton axes are dropped first (used for the sens / infid arrays).
    ``showlegend`` controls whether these boxes get legend entries.
    """
    for idx, name, color in METHODS:
        y = values[idx].squeeze() if squeeze else values[idx]
        fig.add_trace(
            go.Box(y=y, name=name, marker_color=color, showlegend=showlegend),
            row=1,
            col=col,
        )


# Panel 1: Sensitivity-N — mean over axis 0 of each method's sensN slice against
# n_list, plus the random-attribution baseline as a dashed grey curve.
for idx, name, color in METHODS:
    fig.add_trace(
        go.Scatter(
            y=sensN[idx].mean(0),
            x=n_list,
            name=name,
            marker_color=color,
            showlegend=False,
            mode="lines+markers",
        ),
        row=1,
        col=1,
    )
fig.add_trace(
    go.Scatter(
        y=sensN_rand,
        x=n_list,
        name="Random Baseline",
        marker_color="grey",
        showlegend=False,
        mode="lines+markers",
        line_dash="dash",
    ),
    row=1,
    col=1,
)

# Panels 2-3: insertion / deletion area-between-curves per method, with the
# random baseline as a dashed grey horizontal line. `line_color` sets the stroke;
# `fillcolor` (used before) only applies to closed shapes and was ignored.
_add_method_boxes(fig, ins_abc, col=2)
fig.add_hline(y=ins_abc_rand, line_color="grey", line_dash="dash", row=1, col=2)

_add_method_boxes(fig, del_abc, col=3)
fig.add_hline(y=del_abc_rand, line_color="grey", line_dash="dash", row=1, col=3)

# Panels 4-5: sens and infid per method. Only panel 4 contributes legend
# entries, so each method appears exactly once in the legend.
_add_method_boxes(fig, sens, col=4, showlegend=True, squeeze=True)
_add_method_boxes(fig, infid, col=5, squeeze=True)


fig.update_xaxes(title="N", row=1, col=1)
fig.update_xaxes(title="Methods", row=1, col=2)

fig.update_yaxes(title="Pearson Correlation", row=1, col=1)
fig.update_yaxes(title="Area Between Curves", rangemode="tozero", row=1, col=2)
fig.update_yaxes(rangemode="tozero", row=1, col=3)
# Log-scaled panels. `rangemode` is omitted: Plotly applies it only to linear
# axes. On a log axis `range` is given in log10 units (here 10**-3.9 .. 10**4.1).
fig.update_yaxes(title=" ", type="log", tickvals=[0.1, 1, 10, 100], row=1, col=4)
fig.update_yaxes(type="log", range=[-3.9, 4.1], row=1, col=5)

fig.update_layout(
    # title=format_title(
    #     title,
    #     "Perovskite "+ dim +" Model (n = " + str(sensN[3].shape[0]) + ")",
    # ),
    legend_title=None,
    legend={"traceorder": "normal"},
    title_y=0.96,
    title_x=0.035,
    template="plotly_white",
    height=400,
    width=1600,
)

out_path = f"xai/images/{target}/{dim}/{dim}_eval.png"
# Create the per-target/per-dim folder so the export does not fail on a new one.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.write_image(out_path, scale=2)

fig.show()