In [1]:
from utils import *
from dataset_utils import *
import timm
from sal_resnet import resnet50, resnext50_32x4d, wide_resnet50_2
from alexnet import alexnet
from efficientnet import efficientnet_b0
from squeezenet import squeezenet1_1
from madry_models import vit_base_patch16_224 as vit_b_16, deit_base_patch16_224 as deit_b_16

from mobilenet import mobilenet_v2, mobilenet_v3_large
from densenet import densenet121
from collections import defaultdict

from skimage.segmentation import quickshift, slic

# Seed handed to the star-imported `set_seed` helper.
# NOTE(review): presumably this seeds every RNG used below (torch / numpy / random) -- confirm in utils.
seed = 1

set_seed(seed)
# Newly created floating-point tensors default to float32.
torch.set_default_dtype(torch.float32)
# Select the accimage backend for torchvision image loading (the accimage package must be installed).
torchvision.set_image_backend('accimage')

### Update the ImageNet path and the Salient ImageNet path before running.

## If using Salient ImageNet, make sure that the file paths used in the SalientImageNet class are correct,
### since there is no standard path for this dataset. All information and files can be found at the GitHub link:
### https://github.com/singlasahil14/salient_imagenet

# The evaluation loop further down unpacks each batch of this dataset as (imgs, masks, labels).
# NOTE(review): typ='core' presumably selects the core-feature masks -- confirm in SalientImageNet.
salient_dataset = SalientImageNet(IMAGENET_PATH, SALIENT_IMAGENET_PATH, typ='core')

In [2]:
# Scale the worker count and the batch size with the number of visible GPUs.
n_gpus = torch.cuda.device_count()
num_workers = 4*n_gpus
gpu_size = 128*n_gpus

# Sequential (unshuffled) loader over the full dataset; the batch size never drops below 100,
# which is what is used when no GPU is visible.
salient_loader = torch.utils.data.DataLoader(
    salient_dataset,
    batch_size=max(gpu_size, 100),
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=False,
)

In [3]:
def gaussian_noise(imgs, noise):
    """Add zero-mean Gaussian noise to a batch of images and clamp the result to [0, 1].

    Args:
        imgs: image tensor; the noise is drawn with the same shape.
        noise: scalar multiplier applied to the standard-normal samples.

    Returns:
        A tensor shaped like ``imgs`` on the same device as ``imgs``, clipped to [0, 1].
    """
    # Draw the noise on the device of `imgs` rather than on the global `device`: the batch
    # handed in may still live on the CPU, and adding tensors from two devices raises.
    perturbation = torch.randn(imgs.shape, device=imgs.device)
    return (imgs + noise*perturbation).clip(0, 1)

def blackout(imgs):
    """Return an all-zero replacement batch shaped like ``imgs``, allocated on the global ``device``."""
    black = torch.zeros_like(imgs, device=device)
    return black

def greyout(imgs):
    """Return a batch shaped like ``imgs`` filled with a constant per-channel colour.

    The three channel values (0.485, 0.456, 0.406) are the same numbers this notebook uses
    as the normalisation mean. The result lives on the global ``device``.
    """
    channel_fill = torch.tensor([[[0.485]], [[0.456]], [[0.406]]], device=device)
    canvas = torch.zeros_like(imgs, device=device)
    # The (3, 1, 1) fill broadcasts over the batch and spatial dimensions.
    return canvas + channel_fill

def averageout(imgs):
    """Return the mean over dims 2 and 3 of ``imgs``, keeping those dims with size 1.

    For a 4-D input the output has shape (dim0, dim1, 1, 1), so it broadcasts against the
    original batch.
    """
    return imgs.mean(dim=(2, 3), keepdim=True)

def redout(imgs):
    """Return a batch shaped like ``imgs`` with channel values (1, 0, 0), on the global ``device``."""
    channel_fill = torch.tensor([[[1.]], [[0.]], [[0.]]], device=device)
    canvas = torch.zeros_like(imgs, device=device)
    # The (3, 1, 1) fill broadcasts over the batch and spatial dimensions.
    return canvas + channel_fill

def blueout(imgs):
    """Return a batch shaped like ``imgs`` with channel values (0, 0, 1), on the global ``device``."""
    channel_fill = torch.tensor([[[0.]], [[0.]], [[1.]]], device=device)
    canvas = torch.zeros_like(imgs, device=device)
    # The (3, 1, 1) fill broadcasts over the batch and spatial dimensions.
    return canvas + channel_fill

def greenout(imgs):
    """Return a batch shaped like ``imgs`` with channel values (0, 1, 0), on the global ``device``."""
    channel_fill = torch.tensor([[[0.]], [[1.]], [[0.]]], device=device)
    canvas = torch.zeros_like(imgs, device=device)
    # The (3, 1, 1) fill broadcasts over the batch and spatial dimensions.
    return canvas + channel_fill

def gauss_blur(imgs):
    """Blur ``imgs`` with torchvision's Gaussian blur using a 5x5 kernel and sigma 1.0."""
    blur = torchvision.transforms.functional.gaussian_blur
    return blur(imgs, kernel_size=(5, 5), sigma=1.0)

In [4]:
def blockify(m):
    """Snap a 224x224 mask to a 14x14 grid of 16x16 blocks.

    The input is viewed as (-1, 1, 14, 16, 14, 16). Every block whose mean value is greater
    than zero is filled with ones, every other block with zeros.

    Returns:
        A float tensor of shape (-1, 1, 224, 224) containing only 0s and 1s.
    """
    patches = m.reshape(-1, 1, 14, 16, 14, 16)
    # One boolean per block, kept at size-1 dims so it can be expanded back over the block.
    block_is_active = patches.mean(dim=(3, 5), keepdim=True) > 0
    filled = block_is_active.expand_as(patches)
    return filled.float().reshape(-1, 1, 224, 224)

def block_segment(imgs, masks, patch_size=16):
    """Segment a 224x224 image into a regular grid of square patches.

    Segment ids run row-major from 0 to ``(224 // patch_size) ** 2 - 1``.

    Args:
        imgs: unused; kept so every segmentation function shares the ``(imgs, masks)`` signature.
        masks: only its length (batch size) and device are used.
        patch_size: side length of one patch in pixels; it has to divide 224, otherwise the
            final reshape fails.

    Returns:
        An integer tensor of shape ``(len(masks), 1, 224, 224)`` on ``masks.device``. The dtype
        is ``torch.uint8`` whenever all ids fit (up to 256 segments) and ``torch.int32`` otherwise.
    """
    num_patch = 224//patch_size
    num_segs = num_patch*num_patch
    # uint8 can only represent ids 0..255; finer grids (e.g. patch_size=8 -> 784 segments)
    # need a wider integer type to keep the ids distinct.
    seg_dtype = torch.uint8 if num_segs <= 256 else torch.int32
    segments = torch.arange(num_segs, dtype=seg_dtype,
                            device=masks.device).reshape(1, 1, num_patch, 1, num_patch, 1)
    # Stretch every grid cell to patch_size x patch_size pixels, then copy it once per sample.
    seg_mask = segments.repeat(1, 1, 1, patch_size, 1, patch_size)
    seg_mask = seg_mask.reshape(1, 1, 224, 224).repeat(len(masks), 1, 1, 1)
    return seg_mask

def bbox_segment(imgs, masks):
    """Build nested square segments centred on the arg-max position of each mask.

    For every sample the flattened arg-max of ``masks`` is converted to a (row, column)
    position assuming a row length of 224. Squares with half-widths 150, 140, ..., 10 are
    then painted around that position, largest first, with labels 1..15, so smaller squares
    overwrite larger ones. Pixels covered by no square keep label 0.

    Args:
        imgs: unused; kept for the shared ``(imgs, masks)`` segmentation signature.
        masks: per-sample masks; the painting indexes ``[sample][0]`` with 224x224 bounds.

    Returns:
        A ``torch.uint8`` tensor shaped like ``masks`` holding the square labels.
    """
    flat_peak = masks.view(masks.shape[0], -1).argmax(dim=1)
    peak_rows = torch.div(flat_peak, 224, rounding_mode='floor')
    peak_cols = flat_peak - peak_rows*224
    seg_masks = torch.zeros_like(masks, dtype=torch.uint8)
    half_widths = list(range(150, 0, -10))
    for sample, (row, col) in enumerate(zip(peak_rows.tolist(), peak_cols.tolist())):
        canvas = seg_masks[sample][0]  # view into seg_masks, so writes land in the result
        for label, half in enumerate(half_widths, start=1):
            top, bottom = max(0, row-half), min(row+half, 224)
            left, right = max(0, col-half), min(col+half, 224)
            canvas[top:bottom, left:right] = label
    return seg_masks

def contour_segment(imgs, masks):
    """Quantise mask values into contour levels.

    Thresholds 0, 0.05, ..., 0.95 are applied in ascending order; a pixel ends up labelled
    with the number of thresholds its mask value strictly exceeds (0 when it exceeds none).

    Args:
        imgs: unused; kept for the shared ``(imgs, masks)`` segmentation signature.
        masks: tensor of mask values to quantise.

    Returns:
        A ``torch.uint8`` tensor shaped like ``masks`` with labels in 0..20.
    """
    seg_masks = torch.zeros_like(masks, dtype=torch.uint8)
    thresholds = np.arange(20)/20
    for level, threshold in enumerate(thresholds, start=1):
        # Later (higher) thresholds overwrite the labels written by earlier ones.
        seg_masks[masks > threshold] = level
    return seg_masks

def sklearn_segment(imgs, masks, skseg_fn=quickshift):
    """Segment every image of a batch with a scikit-image style segmentation function.

    Args:
        imgs: batch of channel-first images; each one is moved to the CPU and handed to
            ``skseg_fn`` as a channel-last NumPy array.
        masks: only its device is used, as the device of the returned tensor.
        skseg_fn: callable mapping one channel-last NumPy image to an array of segment labels.

    Returns:
        A float32 tensor that stacks the per-image label maps along a new batch dimension,
        each with a leading singleton dimension.
    """
    seg_masks = []
    for img in imgs.cpu():
        seg_labels = skseg_fn(img.permute(1, 2, 0).numpy())
        # torch.as_tensor honours `device` for any target device; the legacy
        # torch.Tensor(data, device=...) constructor only accepts the CPU for the default
        # tensor type. float32 keeps the dtype the legacy constructor produced.
        sm = torch.as_tensor(seg_labels, dtype=torch.float32, device=masks.device)
        seg_masks.append(sm[None, :])
    return torch.stack(seg_masks, dim=0)
                         

In [5]:
def extract_seg_order(sal_mask, seg_mask, eps=1e-6):
    """Rank segments by the total saliency they contain, least salient first.

    Args:
        sal_mask: saliency values, summed over dims (1, 2, 3) per segment.
        seg_mask: segment ids; ids 0..max are considered.
        eps: unused; kept for interface compatibility.

    Returns:
        An index tensor of shape (batch, num_segments): for each sample, the segment ids sorted
        by ascending summed saliency.
    """
    num_segs = int(seg_mask.max().item()+1)
    per_segment_saliency = [
        (sal_mask*(seg_mask == seg_id)).sum(dim=(1, 2, 3))
        for seg_id in range(num_segs)
    ]
    seg_sal_vals = torch.stack(per_segment_saliency, dim=-1)
    return seg_sal_vals.argsort(dim=-1, descending=False)

def mask_seg_inds(seg_inds, seg_mask):
    """Turn per-sample lists of segment ids into a binary pixel mask.

    Args:
        seg_inds: 2-D tensor (batch, k) of selected segment ids per sample; k may be 0.
        seg_mask: 4-D tensor of per-pixel segment ids.

    Returns:
        A float tensor shaped like ``seg_mask`` that is 1 where the pixel's segment id is among
        the sample's selected ids and 0 elsewhere.
    """
    selected = seg_inds[:, :, None, None, None]
    pixel_ids = seg_mask[:, None, :, :, :]
    # Compare every selected id with every pixel, then collapse the id dimension.
    return (selected == pixel_ids).any(dim=1).float()

def plot_metric(sal_degrees, accuracies, corr_names, save_path=None):
    """Plot one curve per entry of ``accuracies`` against ``sal_degrees``.

    Args:
        sal_degrees: x-axis values shared by all curves.
        accuracies: sequence of curves (y-values), one per corruption.
        corr_names: legend labels, indexed in step with ``accuracies``.
        save_path: when truthy, the figure is saved there and then cleared; otherwise the
            figure is shown.
    """
    for idx, curve in enumerate(accuracies):
        plt.plot(sal_degrees, curve, label=corr_names[idx])
    plt.legend()
    if not save_path:
        plt.show()
        return
    plt.savefig(save_path)
    plt.clf()

In [6]:
# Fractions of the segments that are corrupted; the evaluation loop runs one pass per entry.
# The loop stores the predictions made at d == 0 as the baseline for `unchanged_preds`,
# so 0 has to be evaluated first.
sal_degrees = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,  1.0]
# (corruption, display name) pairs.
#   * A callable is applied to the image batch and its output replaces the masked pixels.
#   * A list is not applied to the images; it is passed to the model together with the masked
#     input and the keep-mask.  NOTE(review): the meaning of the list entries is defined by the
#     model implementation, which is not visible in this notebook.
#   * A display name containing 'smooth' additionally triggers a Gaussian blur of the masked
#     region in the evaluation loop, so the name is load-bearing.
corruptions = [
        #        (blackout, "Blackout"),
               (greyout, "Greyout"),
               (greyout,"Edge smoothing + greyout"),
               ([1000, 1000],"Layer mask"),
#                 (averageout, 'Average'),
#                 (redout, "Red"),
#                 (blueout, "Blue"),
#                 (greenout, "Green")        
#                 ([7, 0,],"First convolution layer"),
#                 ([7, 2,],"First two resnet blocks"),
#                 ([0, 4,],"No padding"),
              ]
# (model constructor, display name) pairs; each constructor is called with pretrained=True.
mtypes = [
          # (wide_resnet50_2, 'wide_resnet50'),
          (resnet50, 'resnet50'), 
          # (deit_b_16, 'DeiT-B-16'),
          #   (squeezenet1_1, 'SqueezeNet'),
          #   (alexnet, 'AlexNet'),
          #   (densenet121, 'DenseNet'),
          #   (efficientnet_b0, 'EfficientNet'),
          #   (mobilenet_v3_large, 'MobileNet'),
        ]
# (segmentation function, display name) pairs; each function is called as fn(imgs, masks) and
# returns a per-pixel segment-id map. Batches yielding more than 200 segments are skipped.
seg_fns = [
            # (lambda x, y: sklearn_segment(x, y, 
            #                   skseg_fn=lambda x: quickshift(x,kernel_size=4,
            #                                                 max_dist=200, ratio=0.2,
            #                                                 random_seed=0)),  'quickshift'),
            (lambda x, y: block_segment(x, y, patch_size=16), '16x16'),
#             (contour_segment, 'contour'),
            # (lambda x, y: sklearn_segment(x, y, skseg_fn=slic), 'slic'),
            ]
# Order in which segments are corrupted as the degree grows:
#   'random'          - a fresh random permutation per image,
#   'start_from_sal'  - most salient segments first,
#   'start_from_nsal' - least salient segments first.
orders = [
            'random', 
            # 'start_from_sal', 
            # 'start_from_nsal'
        ] 
# (seg_fn index, order index, corruption index, model index) combinations to skip.
skip_inds = []# [(si, oi, ci, 1) for si in range(len(seg_fns)) for oi in range(len(orders)) for ci in [3, 4, 5] ]
# When True, the raw result arrays are saved under ./results/ after every batch.
write = True
# Suffix inserted into the result file names.
desc = ""
# The evaluation loop stops once more than this many images have been processed.
num_samples = 1500

In [7]:
# Similarity lookup that the evaluation loop indexes as wn_sims[predicted_class][true_class].
# NOTE(review): judging by the name these are presumably WordNet similarities between classes -- confirm.
wn_sims = np.load('./wn_sims.npy')

In [8]:
# Build every model under evaluation, in inference mode on `device`.
models = [(MyDataParallel(mtype(pretrained=True)).to(device).eval(), m_name) for mtype, m_name in mtypes]
normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# NOTE(review): this timm model is also called with the (masked input, keep-mask, list) tuple used
# for list-type corruptions below -- confirm that it accepts that input, or add it to skip_inds.
models.append((timm.create_model('resnet50', pretrained=True).to(device).eval(), 'resnet50_timm'))

# Every accumulator is indexed [seg_fn, order, corruption, model, masking degree].
res_shape = (len(seg_fns), len(orders), len(corruptions), len(models), len(sal_degrees))
total_hits = np.zeros(res_shape)            # correct top-1 predictions
class_counts = np.zeros((*res_shape, 1000))  # histogram of predicted classes
wn_sim_hits = np.zeros(res_shape)           # summed wn_sims[pred][label]
unchanged_preds = np.zeros(res_shape)       # predictions equal to the d == 0 prediction
total = 0
for i, (imgs, masks, labels) in enumerate(salient_loader):
    print(i)

    # One entry per corruption, so corr_imgs[ci] always lines up with corruptions[ci].
    # List-type corruptions are handled inside the model and only get a None placeholder;
    # without the placeholder a list entry placed before a callable one would shift the index.
    corr_imgs = [None if isinstance(corr, list) else corr(imgs) for corr, _ in corruptions]
    # Labels as a NumPy array, so the similarity lookup can be done for the whole batch at once.
    labels_np = labels.cpu().numpy()

    for si, (seg_fn, seg_fn_name) in enumerate(seg_fns):
        seg_mask = seg_fn(imgs, masks)
        num_segs = int(seg_mask.max().item()+1)
        print(seg_fn_name,num_segs)
        if num_segs > 200:
            print(seg_fn_name, " skipped")
            continue
        for oi, order in enumerate(orders):
            print(order)
            if order == 'start_from_sal' or order == 'start_from_nsal':
                seg_order = extract_seg_order(masks, seg_mask)
            elif order == 'random':
                # Independent random permutation of the segment ids for every image.
                rand_order = torch.argsort(torch.rand(len(masks), num_segs), dim=-1, descending=False)

            # One binary pixel mask per masking degree (1 = pixel gets corrupted).
            corr_masks = []
            for di, d in enumerate(sal_degrees):
                if order == 'start_from_sal':
                    # seg_order is ascending in saliency, so the most salient ids sit at the end.
                    m = mask_seg_inds(seg_order[:,num_segs-int(d*num_segs):], seg_mask)
                elif order == 'start_from_nsal':
                    m = mask_seg_inds(seg_order[:, :int(d*num_segs)], seg_mask)
                elif order == 'random': 
                    m = mask_seg_inds(rand_order[:, :int(d*num_segs)], seg_mask)
                corr_masks.append(m)

            imgs = imgs.to(device)
            labels = labels.to(device)

            for ci, (corr, corr_name) in enumerate(corruptions):
                for mi, (model, mname) in enumerate(models):
                    if (si, oi, ci, mi) in skip_inds:
                        continue
                    for di, d in enumerate(sal_degrees):
                        m = corr_masks[di].to(device)
                        with torch.no_grad():
                            if isinstance(corr, list):
                                # The model receives the masked normalised input, the keep-mask
                                # and the list describing the corruption.
                                logits = model((normalizer(imgs)*(1-m), 1-m, corr))
                            else:
                                # Replace the masked pixels by the corrupted image.
                                probe_imgs = corr_imgs[ci].to(device)*m + imgs*(1-m)

                                if 'smooth' in corr_name:
                                    # Blur only the masked region of the composite image.
                                    probe_imgs = probe_imgs*(1-m) + m*gauss_blur(probe_imgs)
                                logits = model(normalizer(probe_imgs))
                        preds = logits.argmax(-1)
                        if d == 0:
                            # Baseline predictions for this (corruption, model) pair.
                            init_preds = preds
                        idx = (si, oi, ci, mi, di)
                        total_hits[idx] += (preds == labels).float().sum().item()
                        class_counts[idx] += torch.bincount(preds, minlength=1000).cpu().numpy()
                        unchanged_preds[idx] += (preds == init_preds).float().sum().item()
                        # Vectorised lookup of wn_sims[pred][label] for the whole batch; this
                        # avoids one Python iteration and two .item() calls per sample.
                        wn_sim_hits[idx] += wn_sims[preds.cpu().numpy(), labels_np].sum()

    total += len(imgs)

    # Normalised copies for interactive inspection; the files below store the raw sums.
    ntotal_hits = total_hits/total
    nclass_counts = class_counts/total
    nwn_sim_hits = wn_sim_hits/total
    nunchanged_preds = unchanged_preds/total
    if write:
        print('Writing..')
        np.save(f'./results/total_hits_{desc}.npy', total_hits)
        np.save(f'./results/class_counts_{desc}.npy', class_counts)
        np.save(f'./results/wn_sim_hits_{desc}.npy', wn_sim_hits)
        np.save(f'./results/unchanged_preds_{desc}.npy', unchanged_preds)
    if total > num_samples:
        break

0
16x16 196
random
Writing..
1
16x16 196
random
Writing..
2
16x16 196
random
Writing..
3
16x16 196
random
Writing..
4
16x16 196
random
Writing..
5
16x16 196
random
Writing..
6
16x16 196
random
Writing..
7
16x16 196
random
Writing..
8
16x16 196
random
Writing..
9
16x16 196
random
Writing..
10
16x16 196
random
Writing..
11
16x16 196
random


KeyboardInterrupt: 

In [17]:
# Accuracy averaged over all masking degrees, one value per corruption
# (first segmentation function, first ordering, first model).
ntotal_hits[0,0,:,0,:].mean(-1)

array([0.24606147, 0.269047  , 0.58483988])