In [5]:
from utils import *
from dataset_utils import *

import pandas as pd 
from sal_resnet import resnet50
from madry_models import vit_base_patch16_224 as vit_b_16

from collections import defaultdict
import torch.nn.functional as F
# Fixed seed for the whole run; handed to the project's `set_seed` helper below.
seed = 1

# NOTE(review): `set_seed` is not defined in this notebook (presumably it arrives
# through one of the star imports) -- confirm which RNGs it actually seeds.
set_seed(seed)
# Newly created floating-point tensors default to float32.
torch.set_default_dtype(torch.float32)
# Select torchvision's 'accimage' image backend.
torchvision.set_image_backend('accimage')

# Dataset built from the two project path constants with typ='core'.
# NOTE(review): the main loop unpacks each batch as (imgs, sal_masks, labels);
# the meaning of typ='core' is defined by SalientImageNet, not visible here.
salient_imagenet = SalientImageNet(IMAGENET_PATH, SALIENT_IMAGENET_PATH, typ='core')

In [3]:
# (constructor, short name) pairs for the backbones under test; the short name
# decides which head is stripped and is reused in the result file names.
model_list = [
    (resnet50, 'resnet50'),
    (vit_b_16, 'vit_b_16'),
]

class Identity(torch.nn.Module):
    """Pass-through module whose `forward` returns its input untouched.

    Dropped in place of a network's last layer so that the network outputs
    whatever was fed into that layer.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return `x` itself (same object, no copy)."""
        return x

# Build one eval-mode feature extractor per entry of `model_list`.
models = []
for model_type, m_name in model_list:
    model = model_type(pretrained=True)
    # Replace the final layer with a pass-through so the model returns features.
    # The first architecture key found in the name wins; unknown names keep their head.
    for arch_key, head_attr in (('resnet', 'fc'), ('vit_b_16', 'head')):
        if arch_key in m_name:
            setattr(model, head_attr, Identity())
            break
    # Wrap with the project's data-parallel helper and move to the target device.
    model = MyDataParallel(model).to(device)
    models.append(model.eval())

In [6]:
# Worker count and nominal batch size both scale with the number of visible GPUs.
num_workers = torch.cuda.device_count() * 4
gpu_size = torch.cuda.device_count() * 64

# The batch size never drops below 100, whatever the GPU count.
loader_kwargs = dict(
    batch_size=max(gpu_size, 100),
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=False,
)
salient_loader = torch.utils.data.DataLoader(salient_imagenet, **loader_kwargs)

In [7]:
def gaussian_noise(imgs, noise):
    """Add zero-mean Gaussian noise to `imgs` and clamp the result to [0, 1].

    Args:
        imgs: tensor of images; the output has the same shape.
        noise: scalar multiplier applied to standard-normal samples
            (i.e. the standard deviation of the added noise).

    Returns:
        Tensor on the same device as `imgs` with every value in [0, 1].
    """
    # Draw the noise on the input's own device: the images may still live on the
    # CPU when a corruption is applied, and mixing devices would raise.
    perturbation = noise * torch.randn(imgs.shape, device=imgs.device)
    return (imgs + perturbation).clip(0, 1)

def blackout(imgs):
    """Return an all-zero batch with the shape and dtype of `imgs`.

    The result is allocated on the notebook-level `device`, regardless of
    where `imgs` currently lives.
    """
    black = torch.zeros_like(imgs, device=device)
    return black

def greyout(imgs):
    """Return a batch shaped like `imgs` where each channel holds a constant.

    The three constants (0.485, 0.456, 0.406) are broadcast over a (3, 1, 1)
    layout, so `imgs` is expected to have 3 channels in its third-from-last
    dimension. The result is allocated on the notebook-level `device`.
    """
    channel_fill = torch.tensor([0.485, 0.456, 0.406], device=device).reshape(3, 1, 1)
    canvas = torch.zeros_like(imgs, device=device)
    return canvas + channel_fill

def blur(imgs):
    """Gaussian-blur `imgs` with a kernel size of 21 (sigma left at torchvision's default)."""
    tv_functional = torchvision.transforms.functional
    return tv_functional.gaussian_blur(imgs, kernel_size=21)

In [8]:
def blockify(m, patch_size=16):
    """Snap a 224x224 mask to a grid of square tiles.

    A tile is switched on (all ones) when the mean of the mask over that tile
    is strictly positive, otherwise it is all zeros.

    Args:
        m: floating-point tensor whose elements can be viewed as
            (-1, 1, 224, 224).
        patch_size: side length of a tile in pixels; must divide 224.
            Defaults to 16 (a 14x14 grid).

    Returns:
        Float tensor of shape (-1, 1, 224, 224) containing only 0.0 and 1.0.

    Raises:
        ValueError: if `patch_size` does not divide 224.
    """
    if patch_size <= 0 or 224 % patch_size != 0:
        raise ValueError(f"patch_size must be a positive divisor of 224, got {patch_size}")
    grid = 224 // patch_size
    # Split each spatial axis into (tile index, offset inside tile).
    tiles = torch.reshape(m, (-1, 1, grid, patch_size, grid, patch_size))
    # Average over the two in-tile axes, keeping them so the result broadcasts back.
    tile_mean = tiles.mean(dim=(3, 5), keepdim=True)
    active = torch.broadcast_to(tile_mean, tiles.shape) > 0
    return active.float().reshape(-1, 1, 224, 224)

def block_segment(imgs, masks, patch_size=16):
    """Build a segment-id map that tiles a 224x224 image with square patches.

    Patches are numbered row-major: the patch in grid row r and column c gets
    id `r * (224 // patch_size) + c`.

    Args:
        imgs: unused; kept so all segmenters share the same call signature.
        masks: tensor used only for its length (batch size) and device.
        patch_size: side length of a patch in pixels; must divide 224.

    Returns:
        Integer tensor of shape (len(masks), 1, 224, 224) on `masks.device`.
        The dtype is uint8 while every id fits (at most 256 patches) and
        int32 otherwise, so ids never wrap around.

    Raises:
        ValueError: if `patch_size` does not divide 224.
    """
    if patch_size <= 0 or 224 % patch_size != 0:
        raise ValueError(f"patch_size must be a positive divisor of 224, got {patch_size}")
    seg_num = 224 // patch_size
    num_segments = seg_num * seg_num
    # uint8 can only hold ids 0..255; finer grids (e.g. patch_size=8 -> 784 ids) need more room.
    seg_dtype = torch.uint8 if num_segments <= 256 else torch.int32
    segments = torch.arange(num_segments, dtype=seg_dtype, device=masks.device)
    segments = segments.reshape(1, 1, seg_num, 1, seg_num, 1)
    # Blow every grid cell up to patch_size x patch_size pixels.
    seg_mask = segments.repeat(1, 1, 1, patch_size, 1, patch_size)
    seg_mask = seg_mask.reshape(1, 1, 224, 224).repeat(len(masks), 1, 1, 1)
    return seg_mask

def bbox_segment(imgs, masks):
    """Segment each image into nested square boxes centred on the mask's peak.

    For every sample the arg-max position of the flattened mask is located and
    15 boxes with half-widths 150, 140, ..., 10 are painted around it, largest
    first, with labels 1..15. Smaller boxes overwrite larger ones, and boxes
    are clipped to the 224x224 frame. Pixels outside every box stay 0.

    Args:
        imgs: unused; kept so all segmenters share the same call signature.
        masks: tensor whose first dimension is the batch; the flat arg-max
            index is decoded assuming rows of width 224.

    Returns:
        uint8 tensor with the same shape as `masks`.
    """
    flat_peak = masks.view(masks.shape[0], -1).argmax(dim=1)
    peak_rows = torch.div(flat_peak, 224, rounding_mode='floor')
    peak_cols = flat_peak - peak_rows * 224
    seg_masks = torch.zeros_like(masks, dtype=torch.uint8)
    half_widths = range(150, 0, -10)
    for sample, (row, col) in enumerate(zip(peak_rows.tolist(), peak_cols.tolist())):
        for level, half in enumerate(half_widths, start=1):
            top, bottom = max(0, row - half), min(row + half, 224)
            left, right = max(0, col - half), min(col + half, 224)
            seg_masks[sample, 0, top:bottom, left:right] = level
    return seg_masks

def contour_segment(imgs, masks):
    """Quantise a mask into level sets using 20 evenly spaced thresholds.

    Thresholds are 0, 0.05, ..., 0.95. A pixel receives label k (1..20) when the
    k-th threshold is the highest one it strictly exceeds; pixels that exceed
    none (values <= 0) keep label 0.

    Args:
        imgs: unused; kept so all segmenters share the same call signature.
        masks: tensor of mask values.

    Returns:
        uint8 tensor with the same shape as `masks`.
    """
    seg_masks = torch.zeros_like(masks, dtype=torch.uint8)
    for level in range(1, 21):
        threshold = (level - 1) / 20
        # Later (higher) thresholds overwrite the labels written by earlier ones.
        seg_masks.masked_fill_(masks > threshold, level)
    return seg_masks

# def sklearn_segment(imgs, masks, skseg_fn=quickshift):
#     seg_masks = []
#     for i, img in enumerate(imgs):
#         img = img.permute(1,2,0).numpy()
#         sm = torch.Tensor(skseg_fn(img), device=masks.device).to(torch.uint8)
#         seg_masks.append(sm[None,:])
#     seg_masks = torch.stack(seg_masks, dim=0)
#     return seg_masks

In [10]:
# Per-channel mean/std normalisation applied to images before they reach a model.
normalizer = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
sal_degrees = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# One row per masking strategy: (fill function, display name).
# A fill function of None means no image-space fill is used; the mask is
# handed to the model directly instead ("layer masking").
corruption_table = [
    (blackout, "Blackout"),
    (greyout, "Greyout"),
    # (blur, "Blur"),
    (None, 'Layer masking'),
]
corruptions = [fill_fn for fill_fn, _ in corruption_table]
corr_names = [label for _, label in corruption_table]

In [11]:
def get_random_patch(bs, num_patches, n, seg_masks):
    """Pick `n` distinct random segments per image and return them as binary masks.

    Args:
        bs: number of images in the batch (first dimension of `seg_masks`).
        num_patches: number of segment ids; ids are drawn from range(num_patches).
        n: how many segments to select for each image. Values larger than
            `num_patches` are capped, since ids are drawn without replacement.
        seg_masks: integer segment-id map indexed by image, e.g. (bs, 1, H, W).

    Returns:
        Float tensor of shape (n, bs, 1, H, W) on `seg_masks.device`.
        `out[k, i]` is 1.0 where image i belongs to its k-th selected segment
        and 0.0 elsewhere.
    """
    # Honour the requested count; it can never exceed the number of segments.
    n = min(n, num_patches)
    height, width = seg_masks.shape[-2:]
    masks = torch.empty(bs, n, height, width, device=seg_masks.device)
    for i in range(bs):
        # An independent permutation per image gives n distinct ids for that image.
        chosen = torch.randperm(num_patches)[:n].to(seg_masks.device)
        masks[i] = (seg_masks[i] == chosen.view(-1, 1, 1)).float()
    # (bs, n, H, W) -> (n, bs, 1, H, W): iterate over patches, broadcast over channels.
    masks = masks.permute(1, 0, 2, 3)
    return masks[:, :, None]

In [12]:
from torch.nn.functional import cosine_similarity

In [None]:
# Linearity experiment.
# For each batch, patch size, masking strategy and model: extract features of the
# image with a single random patch visible (once per patch), sum those feature
# vectors, and compare the sum with the features obtained when all selected
# patches are visible together.
patch_sizes = [32, 16]
# Running sums over images, indexed [corruption, patch size, model].
all_dists = np.zeros((len(corruptions), len(patch_sizes), len(models)))
all_cos_sims = np.zeros((len(corruptions), len(patch_sizes), len(models)))
all_feature_vector_mags = np.zeros((len(corruptions), len(patch_sizes), len(models)))
baseline_dist = 0
seg_fn = block_segment
num_patches = 4
total = 0
eps = 1e-8  # keeps the relative distance finite when the combined feature norm is 0
ad = []
bd = []
for i, (imgs, sal_masks, labels) in enumerate(salient_loader):
    print(i)
    imgs = imgs.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        # The fill images do not depend on the patch size, so build them once per
        # batch. A None placeholder keeps this list index-aligned with `corruptions`
        # wherever the None (layer-masking) entry sits in that list.
        corr_imgs = [corr(imgs).to(device) if corr is not None else None
                     for corr in corruptions]

    for pi, ps in enumerate(patch_sizes):
        seg_mask = seg_fn(imgs, sal_masks, patch_size=ps).to(device)
        num_segs = seg_mask.max().item()+1
        print(pi, ps, num_segs)
        with torch.no_grad():
            masks = get_random_patch(len(imgs), num_segs, num_patches, seg_mask)
            # Union of the selected patches (they are disjoint, so the sum stays binary).
            all_mask = masks.sum(dim=0)

            for ci, corr in enumerate(corruptions):
                for mi, model in enumerate(models):
                    if corr is None:
                        # Layer masking: the mask is passed to the model together with the image.
                        # NOTE(review): the meaning of [100, 5] is defined by the project models.
                        features_list = [model((normalizer(imgs), m, [100, 5])).cpu() for m in masks]
                        combined_features = model((normalizer(imgs), all_mask, [100, 5])).cpu()
                    else:
                        # Image-space masking: keep the image inside the mask, fill the rest.
                        fill = corr_imgs[ci]
                        features_list = [model(normalizer(fill*(1-m) + imgs*m)).cpu() for m in masks]
                        combined_features = model(normalizer(fill*(1-all_mask) + imgs*all_mask)).cpu()
                    basis_features = torch.stack(features_list, dim=1)
                    summed_features = torch.sum(basis_features, dim=1)
                    cos_sims = cosine_similarity(summed_features, combined_features)
                    # Distance between summed and combined features, relative to the combined norm.
                    dists = torch.linalg.norm(summed_features - combined_features, dim=-1)/(eps + torch.linalg.norm(combined_features, dim=-1))
                    feature_vector_mags = torch.linalg.norm(basis_features, dim=-1).mean(1)

                    # .item() turns the 0-d tensors into plain floats before they enter numpy.
                    all_dists[ci, pi, mi] += dists.sum().item()
                    all_cos_sims[ci, pi, mi] += cos_sims.sum().item()
                    all_feature_vector_mags[ci, pi, mi] += feature_vector_mags.sum().item()

    total += len(imgs)
    # Running means: normalise first, round afterwards, so the 3 decimals refer
    # to the per-image average rather than to the raw sum.
    print(np.round(all_dists.transpose()/total, 3))
    print(np.round(all_cos_sims.transpose()/total, 3))
    print(np.round(all_feature_vector_mags.transpose()/total, 3))

    # Stop once a little over 1000 images have been evaluated.
    if total > 1_000:
        break


0
0 32 49


In [None]:
# Persist the per-image means (running sums divided by the number of images seen).
for out_path, running_sum in (
    ('./results/linearity_all_dists.npy', all_dists),
    ('./results/linearity_all_cos_sims.npy', all_cos_sims),
    ('./results/linearity_all_feature_vector_mags.npy', all_feature_vector_mags),
):
    np.save(out_path, running_sum/total)

In [None]:
# Export one LaTeX table per model: rows are patch sizes, columns are the
# masking strategies. `all_dists` holds running sums, so divide by the number
# of images before rounding to report per-image means (as the saved .npy files do).
mean_dists = np.round(all_dists / total, 3).transpose()  # -> [model, patch size, corruption]
for i in range(len(model_list)):
    table = pd.DataFrame(mean_dists[i], index=[f'{ps} X {ps}' for ps in patch_sizes])
    # Append mode is kept from the original; the context manager closes the handle.
    with open(f"./results/linearity_{model_list[i][1]}.csv", 'a') as table_file:
        table.to_latex(table_file, header=corr_names)