In [None]:
import os

import torch
from tqdm.auto import trange

from attribench.data import (
    AttributionsDataset,
    GroupedAttributionsDataset,
    HDF5Dataset,
)
from attribench.util import visualize_attributions
from util.datasets import unnormalize

In [None]:
# Configuration: which benchmark run, which dataset and which sample to inspect.
# `res_path` is the run's root directory; it holds one subdirectory per dataset
# containing `samples.h5` and `attributions.h5`.
dataset_name = "ImageNet"
idx = 95  # position of the sample to visualize within the grouped attributions dataset
res_path = "../out/2023_08_22/"

In [None]:
# Load the stored samples for `dataset_name` and wrap the attributions computed on
# them so that each item groups the attributions of all methods for one sample.
ds_path = os.path.join(res_path, dataset_name)
samples_file = os.path.join(ds_path, "samples.h5")
attributions_file = os.path.join(ds_path, "attributions.h5")

samples_dataset = HDF5Dataset(path=samples_file)
attrs_dataset = GroupedAttributionsDataset(
    AttributionsDataset(samples=samples_dataset, path=attributions_file)
)

In [None]:
# Fetch the configured sample together with its per-method attributions and turn
# the image into a plottable tensor.
# The index returned by the dataset goes into `sample_idx` rather than `idx`, so the
# configured `idx` is never overwritten and re-running this cell always shows the
# same sample.
sample_idx, image_normalized, label, attrs = attrs_dataset[idx]

# unnormalize is called on a leading batch dimension of size 1, as the original code
# did (presumably it expects batched input — confirm against util.datasets).
image_batch = torch.tensor(image_normalized).unsqueeze(0)
image_batch = unnormalize(image_batch, dataset_name)

# Drop the batch dim and move channels last; assumes the stored image is
# channels-first (C, H, W), giving (H, W, C) for plotting — TODO confirm.
image = image_batch.squeeze(0).permute(1, 2, 0)

In [None]:
# Plot the attribution maps in `attrs` overlaid on the unnormalized sample image.
fig = visualize_attributions(attributions=attrs, image=image, overlay=True)

In [None]:
# Sanity check: scan every sample of every dataset in the run and report which
# attribution methods produced NaN values.
# All per-sample state stays inside functions, so this cell does not overwrite
# `idx`, `image`, `label`, `attrs`, `ds_path`, `samples_dataset` or `attrs_dataset`,
# which the visualization cells rely on.
DATASET_NAMES = [
    "MNIST", "FashionMNIST", "SVHN", "CIFAR10",
    "CIFAR100", "ImageNet", "Places365", "Caltech256",
]


def load_grouped_attributions(ds_dir):
    """Load ``samples.h5`` and ``attributions.h5`` from ``ds_dir`` as a GroupedAttributionsDataset."""
    samples = HDF5Dataset(path=os.path.join(ds_dir, "samples.h5"))
    return GroupedAttributionsDataset(
        AttributionsDataset(samples=samples, path=os.path.join(ds_dir, "attributions.h5"))
    )


def find_nan_methods(grouped_attrs, desc=None):
    """Return the names of methods whose attributions contain NaN for at least one sample.

    Args:
        grouped_attrs: indexable dataset whose items unpack as
            ``(index, image, label, attrs)``, where ``attrs`` is keyed by method name.
            Every one of its ``len(grouped_attrs)`` items is scanned.
        desc: optional label for the progress bar.

    Returns:
        A list of method names, in the order the methods were first encountered.
    """
    has_nans = {}
    for i in trange(len(grouped_attrs), desc=desc):
        _, _, _, sample_attrs = grouped_attrs[i]
        for method_name in sample_attrs:
            if has_nans.get(method_name, False):
                continue  # already flagged; no need to scan this method again
            # as_tensor accepts both torch tensors and numpy arrays.
            method_attrs = torch.as_tensor(sample_attrs[method_name])
            has_nans[method_name] = torch.isnan(method_attrs).any().item()
    return [name for name, flagged in has_nans.items() if flagged]


for ds_name in DATASET_NAMES:
    ds_grouped_attrs = load_grouped_attributions(os.path.join(res_path, ds_name))
    for nan_method in find_nan_methods(ds_grouped_attrs, desc=ds_name):
        print(f"{ds_name}: {nan_method} has nans")