In [3]:
import os
import torch
import torch.nn.functional as F
from torchvision.utils import make_grid
import ir_utils.interp_generators as igs
import ir_utils.dataloaders as dataloaders
from ir_utils.simple_models import SimpleCNN
import ir_utils.wide_resnet as wide_resnet
import ir_utils.utils as utils

# Use the first GPU when available, otherwise fall back to the CPU.
# The CUDA-only calls are guarded: torch.cuda.set_device rejects a CPU device, and
# comparing a torch.device to the string 'cpu' is always unequal, so it cannot be
# used as the guard.
use_cuda = torch.cuda.is_available()
device = torch.device(0 if use_cuda else 'cpu')
if use_cuda:
    torch.cuda.set_device(device)
    # New float tensors are created on the GPU by default.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

### Choose and load the model (robust or standard) used to generate the ground-truth saliency maps

In [9]:
# Dataset to generate saliency maps for: 'CIFAR-10' or 'MNIST'.
dataset = 'CIFAR-10'
simp = False # if True, use simple gradient, False => use SmoothGrad

# Choose the checkpoint to explain and the output sub-directory. The directory name
# records how the model was trained and which saliency method was used
# ('simp' = simple gradient, 'smooth' = SmoothGrad). Uncomment exactly one pair.
if dataset == 'CIFAR-10':
#     model_name = 'model_pgd2_eps0.314_iters7_42.pt'
#     dir_name = 'pgdL2_eps0.314_iters7_smooth_unproc'

    model_name = 'model_0.pt'
    dir_name = 'std_train_smooth_unproc'

elif dataset == 'MNIST':
    # Reset first so a CIFAR-10 choice left over from an earlier run of this cell
    # cannot silently leak into an MNIST run.
    model_name = dir_name = None
#     model_name = 'model_pgd2_eps2.5_iters40_0.pt'
#     dir_name = 'pgdL2_eps2.5_iters40_simp_unproc'

#     model_name = 'model_42.pt'
#     dir_name = 'std_train_simp_unproc'

#     model_name = 'model_42.pt'
#     dir_name = 'std_train_smooth_unproc'

#     model_name = 'model_pgd2_eps1.5_iters40_42.pt'
#     dir_name = 'pgdL2_eps1.5_iters40_smooth_unproc'

#     model_name = 'model_pgdinf_eps.3_iters40_0.pt'
#     dir_name = 'pgdinf_eps.3_iters40_simp_unproc'
    if model_name is None or dir_name is None:
        raise ValueError('No MNIST model configured: uncomment one model_name/dir_name pair.')
else:
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'CIFAR-10' or 'MNIST'.")

# Maps computed with one method must not be saved under a directory named for the other.
assert ('_simp_' in dir_name) == simp, f'dir_name {dir_name!r} does not match simp={simp}'

# makedirs also creates data/{dataset} if needed; exist_ok makes re-running the cell safe.
os.makedirs(f'data/{dataset}/{dir_name}', exist_ok=True)

dataset, simp

('CIFAR-10', False)

In [10]:
# Build the data loaders and architecture for `dataset`, then load the chosen checkpoint.
# batch_size=1: the generation loop below computes one saliency map per sample.
if dataset == 'CIFAR-10':
    train_loader, test_loader = dataloaders.cifar10(batch_size=1, augment=False)
    net = wide_resnet.Wide_ResNet(depth=28, widen_factor=10, dropout_rate=.3, num_classes=10)
    net.load_state_dict(torch.load(f'trained_models/CIFAR-10/WRN-28-10_st/{model_name}', map_location=device))
elif dataset == 'MNIST':
    train_loader, test_loader = dataloaders.mnist(batch_size=1)
    net = SimpleCNN()
    net.load_state_dict(torch.load(f'trained_models/MNIST/SimpleCNN_st/{model_name}', map_location=device))
else:
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'CIFAR-10' or 'MNIST'.")

# Move the model to the configured device (net.cuda() would crash on a CPU-only
# machine). Switch to eval mode so layers such as dropout behave deterministically.
net.to(device)
net.eval();

Files already downloaded and verified
Files already downloaded and verified



### For each sample in the training and test sets, generate a ground-truth saliency map and save the results as new datasets

In [12]:
# CIFAR-10 results are written in chunks of this many samples. Each file is named after
# the number of samples processed so far. MNIST is written as one file per split.
CHUNK_SIZE = 5000


def compute_salience_map(sample, label):
    """Return the saliency map for one (sample, label) batch as a 4-D tensor.

    Uses the plain input gradient when ``simp`` is True, otherwise SmoothGrad.
    Maps are left unnormalized and signed (normalize=False, abs=False), and
    ``rgb`` is enabled for CIFAR-10.
    """
    rgb = dataset == 'CIFAR-10'
    if simp:
        salience_map = igs.simple_gradient(net, sample, label,
                                           normalize=False, rgb=rgb, abs=False)
        assert sample.dim() == 4
        assert label.dim() == 1
    else:
        # SmoothGrad paper recommends 10-20% noise should be added. I.e. scale ~ .15
        salience_map = igs.smoothgrad(net, sample, label, j=50, scale=.15,
                                      normalize=False, rgb=rgb, abs=False)
        # Add a leading batch dim; the 4-D check below relies on it
        # (presumably smoothgrad returns an unbatched map).
        salience_map = salience_map.unsqueeze(0)
    assert salience_map.dim() == 4
    return salience_map


def save_split(samples, labels, salience_maps, path):
    """Concatenate the per-batch tensors along dim 0 and save them as one tuple at ``path``."""
    torch.save((torch.cat(samples, dim=0),
                torch.cat(labels, dim=0),
                torch.cat(salience_maps, dim=0)), path)
    print(f'saved {os.path.basename(path)}')


for loader_name, loader in (('training', train_loader), ('test', test_loader)):
    out_dir = f'data/{dataset}/{dir_name}'
    samples, labels, salience_maps = [], [], []
    n_done = 0  # samples processed so far in this split

    for sample, label in loader:
        if n_done % 1000 == 0:
            print(f'{n_done}/{len(loader.dataset)}')

        sample, label = sample.to(device), label.to(device)
        salience_map = compute_salience_map(sample, label)

        # Keep results on the CPU so GPU memory does not grow with the split size.
        samples.append(sample.detach().cpu())
        labels.append(label.cpu())
        salience_maps.append(salience_map.detach().cpu())
        n_done += 1

        if dataset == 'CIFAR-10' and n_done % CHUNK_SIZE == 0:
            save_split(samples, labels, salience_maps, f'{out_dir}/{loader_name}{n_done}.pt')
            samples, labels, salience_maps = [], [], []

    if dataset == 'CIFAR-10' and samples:
        # Flush a trailing partial chunk. Otherwise its samples would be silently
        # dropped whenever the split size is not a multiple of CHUNK_SIZE.
        save_split(samples, labels, salience_maps, f'{out_dir}/{loader_name}{n_done}.pt')

    if dataset == 'MNIST':
        try:
            save_split(samples, labels, salience_maps, f'{out_dir}/{loader_name}.pt')
        except OSError as err:
            # Best effort: report the actual failure and continue with the next split.
            print(f'error closing file ({err}). continuing...')

0/50000
1000/50000
2000/50000
3000/50000
4000/50000
saved training5000.pt
5000/50000
6000/50000
7000/50000
8000/50000
9000/50000
saved training10000.pt
10000/50000
11000/50000
12000/50000
13000/50000
14000/50000
saved training15000.pt
15000/50000
16000/50000
17000/50000
18000/50000
19000/50000
saved training20000.pt
20000/50000
21000/50000
22000/50000
23000/50000
24000/50000
saved training25000.pt
25000/50000
26000/50000
27000/50000
28000/50000
29000/50000
saved training30000.pt
30000/50000
31000/50000
32000/50000
33000/50000
34000/50000
saved training35000.pt
35000/50000
36000/50000
37000/50000
38000/50000
39000/50000
saved training40000.pt
40000/50000
41000/50000
42000/50000
43000/50000
44000/50000
saved training45000.pt
45000/50000
46000/50000
47000/50000
48000/50000
49000/50000
saved training50000.pt
0/10000
1000/10000
2000/10000
3000/10000
4000/10000
saved test5000.pt
5000/10000
6000/10000
7000/10000
8000/10000
9000/10000
saved test10000.pt


### Check that the saved datasets load back correctly

In [13]:
# Reload one saved test split and check its structure. CIFAR-10 is saved in chunks,
# so inspect the first test chunk. MNIST is saved as a single test.pt file.
check_file = 'test5000.pt' if dataset == 'CIFAR-10' else 'test.pt'
samples, labels, interps = torch.load(f'data/{dataset}/{dir_name}/{check_file}')
assert samples.dim() == 4 and labels.dim() == 1 and interps.dim() == 4
# Every sample must have exactly one label and one saliency map.
assert samples.size(0) == labels.size(0) == interps.size(0)
samples.shape, labels.shape, interps.shape

(torch.Size([5000, 3, 32, 32]),
 torch.Size([5000]),
 torch.Size([5000, 3, 32, 32]))

In [16]:
# Value range of the first reloaded saliency map.
first_map = interps[0]
first_map.min(), first_map.max()

(tensor(-0.1228, device='cpu'), tensor(0.1578, device='cpu'))