In [1]:
# This block allows us to import from the benchmark folder,
# as if it was a package installed using pip
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from captum import attr
from attrbench import models, attribution
from tqdm import tqdm
import torch

In [3]:
class TestCifar(Dataset):
    def __init__(self):
        self.cifar10 = datasets.CIFAR10(root="../data/CIFAR10/", download=False, train=False, transform=transforms.ToTensor())
    
    def __getitem__(self, index):
        data, target = self.cifar10[index]
        return data, target, index
    
    def __len__(self):
        return len(self.cifar10)

In [4]:
ds = TestCifar()
dl = DataLoader(ds, batch_size=256, shuffle=False, drop_last=True, num_workers=4)

In [5]:
device = "cuda"
model = models.Resnet(version="resnet18", output_logits=True, num_classes=10,
                      params_loc="../data/models/CIFAR10/resnet18.pt")
model.to(device)
model.eval()
pass

In [6]:
#saliency = attribution.GradCAM(model, model.get_last_conv_layer(), (32, 32))
saliency = attribution.GuidedGradCAM(model, model.get_last_conv_layer(), (32, 32))
#saliency = attribution.OldGuidedGradCAM(model, model.get_last_conv_layer())

In [7]:
bad_idxs = []
for batch_idx, (data, target, idx) in enumerate(tqdm(dl)):
    data = data.to(device)
    target = target.to(device)
    attrs = saliency(data, target)
    
    abs_attrs = torch.abs(attrs.flatten(1))
    max_abs_per_img = torch.max(abs_attrs, dim=1)[0]
    if torch.any(max_abs_per_img == 0):
        bad_idxs.append(idx[torch.where(max_abs_per_img == 0)])
if len(bad_idxs) > 0:
    bad_idxs = torch.cat(bad_idxs)

  "required_grads has been set automatically." % index
  "See the documentation of nn.Upsample for details.".format(mode))
  "Setting backward hooks on ReLU activations."




  3%|▎         | 1/39 [00:00<00:34,  1.09it/s]



  5%|▌         | 2/39 [00:01<00:31,  1.16it/s]



  8%|▊         | 3/39 [00:02<00:29,  1.21it/s]



 10%|█         | 4/39 [00:03<00:28,  1.24it/s]



 13%|█▎        | 5/39 [00:03<00:26,  1.27it/s]



 15%|█▌        | 6/39 [00:04<00:25,  1.29it/s]



 18%|█▊        | 7/39 [00:05<00:24,  1.30it/s]



 21%|██        | 8/39 [00:06<00:23,  1.31it/s]



 23%|██▎       | 9/39 [00:06<00:22,  1.32it/s]



 26%|██▌       | 10/39 [00:07<00:21,  1.32it/s]



 28%|██▊       | 11/39 [00:08<00:21,  1.33it/s]



 31%|███       | 12/39 [00:09<00:20,  1.33it/s]



 33%|███▎      | 13/39 [00:09<00:19,  1.33it/s]



 36%|███▌      | 14/39 [00:10<00:18,  1.33it/s]



 38%|███▊      | 15/39 [00:11<00:18,  1.33it/s]



 41%|████      | 16/39 [00:12<00:17,  1.33it/s]



 44%|████▎     | 17/39 [00:12<00:16,  1.33it/s]



 46%|████▌     | 18/39 [00:13<00:15,  1.33it/s]



 49%|████▊     | 19/39 [00:14<00:14,  1.33it/s]



 51%|█████▏    | 20/39 [00:15<00:14,  1.33it/s]



 54%|█████▍    | 21/39 [00:15<00:13,  1.33it/s]



 56%|█████▋    | 22/39 [00:16<00:12,  1.33it/s]



 59%|█████▉    | 23/39 [00:17<00:12,  1.33it/s]



 62%|██████▏   | 24/39 [00:18<00:11,  1.33it/s]



 64%|██████▍   | 25/39 [00:18<00:10,  1.33it/s]



 67%|██████▋   | 26/39 [00:19<00:09,  1.33it/s]



 69%|██████▉   | 27/39 [00:20<00:09,  1.33it/s]



 72%|███████▏  | 28/39 [00:21<00:08,  1.33it/s]



 74%|███████▍  | 29/39 [00:21<00:07,  1.33it/s]



 77%|███████▋  | 30/39 [00:22<00:06,  1.33it/s]



 79%|███████▉  | 31/39 [00:23<00:06,  1.33it/s]



 82%|████████▏ | 32/39 [00:24<00:05,  1.33it/s]



 85%|████████▍ | 33/39 [00:24<00:04,  1.33it/s]



 87%|████████▋ | 34/39 [00:25<00:03,  1.33it/s]



 90%|████████▉ | 35/39 [00:26<00:03,  1.33it/s]



 92%|█████████▏| 36/39 [00:27<00:02,  1.33it/s]



 95%|█████████▍| 37/39 [00:27<00:01,  1.33it/s]



 97%|█████████▋| 38/39 [00:28<00:00,  1.33it/s]



100%|██████████| 39/39 [00:29<00:00,  1.32it/s]






In [8]:
len(bad_idxs)

6624