In [4]:
import pandas as pd
import json
from pathlib import Path
import shap
import numpy as np
from shap.plots import colors
import torch
import matplotlib.pyplot as plt

In [33]:
# Flatten every results JSON in double_path/ into a long-format CSV written next
# to it (same stem, ".csv" suffix).
for json_path in Path("double_path").glob("*.json"):
    print(json_path)
    with open(json_path, "r") as handle:
        raw = json.load(handle)

    wide = pd.json_normalize(raw)
    # Explode all columns in lock-step (pandas requires equal-length lists per
    # row), then melt everything except gt_uncertainty into variable/value rows.
    long_form = wide.explode(wide.columns.to_list()).melt(id_vars="gt_uncertainty")

    # "variable" is split once, at its first ".", into metric and targets.
    long_form[['metric', 'targets']] = long_form['variable'].str.split('.', n=1, expand=True)
    # "targets" is split at its second "_": the first two "_"-separated tokens
    # become `heatmap`, the remainder `mask` (NaN when there is no second "_").
    long_form[['heatmap', 'mask']] = long_form['targets'].str.extract(r'^(.*?_[^_]*?)_(.*)$')

    csv_path = json_path.with_suffix(".csv")
    long_form.drop(columns=["variable", "targets"]).to_csv(csv_path, index=False)

double_path/mnist_plus_infoshap_500_double_path_extended2.json
double_path/mnist_plus_shap_500_double_path_extended2.json
double_path/mnist_plus_lrp_zennit_double_path_a1_b0_extended2_zero_bias.json
double_path/mnist_plus_lrp_zennit_double_path_a1_b0_extended2.json
double_path/mnist_plus_ig_double_path_extended2.json
double_path/localization_results_combined_mnist_clue.json
double_path/mnist_plus_gradcam_double_path_extended2.json


In [77]:
# Export the mean and variance attribution maps of sample 1520 as grayscale PNGs.
# Both tensors are squeezed: plt.imsave only accepts (M, N), (M, N, 3) or
# (M, N, 4) arrays, and `.unsqueeze(-1)` on a 2-D map yields (M, N, 1), which it
# rejects. map_location="cpu" lets tensors saved from a GPU reach .numpy();
# .detach() is a no-op unless the tensor was saved with requires_grad=True.
img1 = (
    torch.load("lrp/attribution_mean_1520_a1_b0_extended2_input.pt", map_location="cpu")
    .detach()
    .squeeze()
    .numpy()
)
img2 = (
    torch.load("lrp/attribution_variance_1520_a1_b0_extended2_input.pt", map_location="cpu")
    .detach()
    .squeeze()
    .numpy()
)

plt.imsave("lrp/attribution_mean_1520_a1_b0_extended2_input.png", img1, cmap="gray")
plt.imsave("lrp/attribution_variance_1520_a1_b0_extended2_input.png", img2, cmap="gray")

In [None]:
# Export the mean attribution map of sample 587 as a grayscale PNG.
# map_location="cpu" keeps .numpy() working when the tensor was saved from a GPU
# (.numpy() only accepts CPU tensors).
img = torch.load(
    "lrp/attribution_mean_587_a1_b0_extended2_input.pt", map_location="cpu"
).squeeze().numpy()

plt.imsave("lrp/attribution_mean_587_a1_b0_extended2_input.png", img, cmap="gray")

In [51]:
from PIL import Image

# Write a 250x250 copy of every PNG in lrp/ next to it as "<stem>_250.png".
RESIZED_SUFFIX = "_250"

# Snapshot the listing before writing: the loop creates new PNGs in the same
# directory, and files created during a lazy glob may or may not be yielded.
# Already-resized copies are skipped so re-running the cell does not produce
# "*_250_250.png" files.
source_pngs = sorted(
    png for png in Path("lrp").glob("*.png") if not png.stem.endswith(RESIZED_SUFFIX)
)
for path in source_pngs:
    print(path)
    # The context manager closes the source file once the resized copy is saved.
    with Image.open(path) as image:
        resized = image.resize((250, 250))
        resized.save(path.with_name(path.stem + RESIZED_SUFFIX + ".png"))


lrp/attribution_mean_1520_a1_b0_extended2.png
lrp/attribution_variance_1520_a1_b0_extended2_input.png
lrp/attribution_mean_1520_a1_b0_extended2_input.png
lrp/attribution_variance_1520_a1_b0_extended2.png


In [31]:
# Index of the sample whose attribution tensors are loaded and plotted below.
i = 85_922

In [263]:
# Load the LRP mean/variance attribution maps and the matching input image for
# sample `i`; every tensor is mapped onto the CPU so .numpy() can convert it.
lrp_mean = torch.load(
    f"lrp/mnist_plus_lrp_zennit_double_path_a1_b0_extended2_mean_{i}.pt",
    map_location="cpu",
).numpy()
lrp_var = torch.load(
    f"lrp/mnist_plus_lrp_zennit_double_path_a1_b0_extended2_variance_{i}.pt",
    map_location="cpu",
).numpy()
img1 = (
    torch.load(
        f"lrp/mnist_plus_lrp_zennit_double_path_a1_b0_extended2_input_{i}.pt",
        map_location="cpu",
    )
    .squeeze()
    .numpy()
)

In [264]:
# Overlay the LRP variance and mean heatmaps on the input image and save each as
# a borderless 600x600 px PNG (1 inch at 600 dpi): lrp_var_{i}.png, lrp_mean_{i}.png.
for kind, heatmap in (("var", lrp_var), ("mean", lrp_mean)):
    # Symmetric color limits at the 99.9th percentile of |attribution| so a few
    # extreme pixels do not wash out the rest; NaNs are ignored.
    limit = np.nanpercentile(np.abs(heatmap), 99.9)
    # An all-zero (or all-NaN) map gives a limit of 0 (or NaN). That leaves the
    # norm without a usable range: with vmin == vmax matplotlib maps every pixel
    # to the colormap's low end. Fall back to a unit range so such a map renders
    # at the colormap midpoint (zero attribution).
    if not np.isfinite(limit) or limit == 0:
        limit = 1.0

    fig = plt.figure(figsize=(1, 1))
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(img1, cmap="gray", alpha=0.8)
    ax.imshow(heatmap, cmap=colors.red_transparent_blue, vmin=-limit, vmax=limit, alpha=0.6)
    fig.savefig(f"lrp_{kind}_{i}.png", dpi=600)
    plt.close(fig)

In [265]:
# Load the SHAP mean/variance attribution maps and the matching input image for
# sample `i`. Every tensor is mapped onto the CPU. Without map_location, a tensor
# saved from a GPU loads back onto the GPU (or fails to load when CUDA is
# unavailable), and .numpy() only accepts CPU tensors.
shap_mean = torch.load(
    f"shap/mnist_plus_shap_500_double_path_extended2_mean_{i}.pt",
    map_location=torch.device('cpu'),
).numpy()
shap_var = torch.load(
    f"shap/mnist_plus_shap_500_double_path_extended2_variance_{i}.pt",
    map_location=torch.device('cpu'),
).numpy()
img1 = torch.load(
    f"shap/mnist_plus_shap_500_double_path_extended2_input_{i}.pt",
    map_location=torch.device('cpu'),
).squeeze().numpy()

In [266]:
# Overlay the SHAP variance heatmap on the input image and save it as a
# borderless 600x600 px PNG (1 inch at 600 dpi).
# NOTE(review): unlike the LRP and IG cells, no mean heatmap is rendered for
# SHAP — confirm this is intentional.

# Symmetric color limits at the 99.9th percentile of |attribution|; NaNs ignored.
max_val_var = np.nanpercentile(np.abs(shap_var), 99.9)
# An all-zero (or all-NaN) map gives a limit of 0 (or NaN), which leaves the norm
# without a usable range. Fall back to a unit range so the map renders at the
# colormap midpoint.
if not np.isfinite(max_val_var) or max_val_var == 0:
    max_val_var = 1.0

fig = plt.figure(figsize=(1, 1))
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(img1, cmap="gray", alpha=0.8)
ax.imshow(shap_var, cmap=colors.red_transparent_blue, vmin=-max_val_var, vmax=max_val_var, alpha=0.6)
fig.savefig(f"shap_var_{i}.png", dpi=600)
plt.close(fig)

In [47]:
# Load the Integrated Gradients mean/variance maps and the matching input image
# for sample `i` (the "_new" file variants). Tensors are mapped onto the CPU and
# detached from any autograd graph before conversion to numpy.
ig_mean = (
    torch.load(f"ig/mnist_plus_ig_double_path_extended2_mean_{i}_new.pt", map_location="cpu")
    .detach()
    .numpy()
)
ig_var = (
    torch.load(f"ig/mnist_plus_ig_double_path_extended2_variance_{i}_new.pt", map_location="cpu")
    .detach()
    .numpy()
)
img1 = (
    torch.load(f"ig/mnist_plus_ig_double_path_extended2_input_{i}_new.pt", map_location="cpu")
    .detach()
    .squeeze()
    .numpy()
)

In [48]:
# Overlay the IG variance and mean heatmaps on the input image and save each as
# a borderless 600x600 px PNG (1 inch at 600 dpi): ig_var_{i}.png, ig_mean_{i}.png.
for kind, heatmap in (("var", ig_var), ("mean", ig_mean)):
    # Symmetric color limits at the 99.9th percentile of |attribution| so a few
    # extreme pixels do not wash out the rest; NaNs are ignored.
    limit = np.nanpercentile(np.abs(heatmap), 99.9)
    # An all-zero (or all-NaN) map gives a limit of 0 (or NaN). That leaves the
    # norm without a usable range: with vmin == vmax matplotlib maps every pixel
    # to the colormap's low end. Fall back to a unit range so such a map renders
    # at the colormap midpoint (zero attribution).
    if not np.isfinite(limit) or limit == 0:
        limit = 1.0

    fig = plt.figure(figsize=(1, 1))
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(img1, cmap="gray", alpha=0.8)
    ax.imshow(heatmap, cmap=colors.red_transparent_blue, vmin=-limit, vmax=limit, alpha=0.6)
    fig.savefig(f"ig_{kind}_{i}.png", dpi=600)
    plt.close(fig)

In [269]:
# Load the InfoSHAP mean/variance attribution maps and the matching input image
# for sample `i`; every tensor is mapped onto the CPU so .numpy() can convert it.
infoshap_mean = torch.load(
    f"infoshap/mnist_plus_infoshap_500_double_path_extended2_mean_{i}.pt",
    map_location="cpu",
).numpy()
infoshap_var = torch.load(
    f"infoshap/mnist_plus_infoshap_500_double_path_extended2_variance_{i}.pt",
    map_location="cpu",
).numpy()
img1 = (
    torch.load(
        f"infoshap/mnist_plus_infoshap_500_double_path_extended2_input_{i}.pt",
        map_location="cpu",
    )
    .squeeze()
    .numpy()
)

In [270]:
# Overlay the InfoSHAP variance heatmap on the input image and save it as a
# borderless 600x600 px PNG (1 inch at 600 dpi).
# NOTE(review): unlike the LRP and IG cells, no mean heatmap is rendered for
# InfoSHAP — confirm this is intentional.

# Symmetric color limits at the 99.9th percentile of |attribution|; NaNs ignored.
max_val_var = np.nanpercentile(np.abs(infoshap_var), 99.9)
# An all-zero (or all-NaN) map gives a limit of 0 (or NaN), which leaves the norm
# without a usable range. Fall back to a unit range so the map renders at the
# colormap midpoint.
if not np.isfinite(max_val_var) or max_val_var == 0:
    max_val_var = 1.0

fig = plt.figure(figsize=(1, 1))
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(img1, cmap="gray", alpha=0.8)
ax.imshow(infoshap_var, cmap=colors.red_transparent_blue, vmin=-max_val_var, vmax=max_val_var, alpha=0.6)
fig.savefig(f"infoshap_var_{i}.png", dpi=600)
plt.close(fig)

In [271]:
# Load the CLUE variance heatmaps and keep the first entry along axis 0.
# NOTE(review): the entry is selected by position [0], not by the sample index
# `i` used elsewhere — confirm it corresponds to sample `i`.
clue_heatmaps = np.load("clue/VAR_HEATMAP.npy")
clue_var = clue_heatmaps[0]

In [272]:
# Overlay the CLUE variance heatmap on the input image and save it as a
# borderless 600x600 px PNG (1 inch at 600 dpi).
# NOTE(review): `img1` is not loaded in this cell. It is whatever the most recently
# executed load cell left behind (the InfoSHAP input when the notebook runs top to
# bottom), so this cell depends on execution order — confirm the backdrop is the
# intended image for sample `i`.

# Symmetric color limits at the 99.9th percentile of |attribution|; NaNs ignored.
max_val_var = np.nanpercentile(np.abs(clue_var), 99.9)
# An all-zero (or all-NaN) map gives a limit of 0 (or NaN), which leaves the norm
# without a usable range. Fall back to a unit range so the map renders at the
# colormap midpoint.
if not np.isfinite(max_val_var) or max_val_var == 0:
    max_val_var = 1.0

fig = plt.figure(figsize=(1, 1))
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(img1, cmap="gray", alpha=0.8)
ax.imshow(clue_var, cmap=colors.red_transparent_blue, vmin=-max_val_var, vmax=max_val_var, alpha=0.6)
fig.savefig(f"clue_var_{i}.png", dpi=600)
plt.close(fig)

In [7]:
# Render every tensor in image_mask_examples/ as a borderless 600x600 px
# grayscale PNG (1 inch at 600 dpi) saved next to it under the same stem.
for tensor_path in Path("image_mask_examples").glob("*.pt"):
    picture = torch.load(tensor_path, map_location="cpu").squeeze().numpy()

    fig = plt.figure(figsize=(1, 1))
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    fig.add_axes(ax)
    ax.set_axis_off()
    # Same alpha as the heatmap overlays' backdrop.
    ax.imshow(picture, cmap="gray", alpha=0.8)
    fig.savefig(f"image_mask_examples/{tensor_path.stem}.png", dpi=600)
    plt.close(fig)
