In [1]:
import torch

from utils import *
from dataset_utils import *
from sal_resnet import resnet50
from torch.nn import functional as F

from alexnet import alexnet
from squeezenet import squeezenet1_1
from densenet import densenet121

from sal_resnet import resnet50, resnext50_32x4d, wide_resnet50_2
from madry_models import (
    vit_base_patch16_224 as vit_b_16,
    deit_base_patch16_224 as deit_b_16,
)

# Global run configuration: fixed seed, float32 as the default tensor dtype and
# the 'medium' float32 matmul precision setting.
seed = 1
set_seed(seed)  # helper is not defined in this notebook; presumably comes from the star imports above
torch.set_float32_matmul_precision('medium')
torch.set_default_dtype(torch.float32)
# Optional image backend, left disabled:
# torchvision.set_image_backend('accimage')

In [2]:
# Preprocessing for images and their pixel masks. Both pipelines resize to 224,
# center-crop to 224x224 and convert to a tensor; they differ only in the
# interpolation mode (bicubic for images, nearest for masks — presumably so
# that mask values are not blended during resizing).
img_transform = transforms.Compose([
    transforms.Resize(size=224, interpolation=BICUBIC, max_size=None, antialias=None),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
])

mask_transform = transforms.Compose([
    transforms.Resize(size=224, interpolation=NEAREST, max_size=None, antialias=None),
    transforms.CenterCrop((224, 224)),
    # NOTE(review): bare `ToTensor` here vs `transforms.ToTensor` in the image
    # pipeline — confirm whether this is meant to be a different class.
    ToTensor(),
])

In [3]:
# Dataset pairing images with pixel masks; the loader built from it yields
# (images, masks, labels) batches in the evaluation loop.
pixel_imagenet = PixelImageNet(
    IMAGENET_PATH,
    PIXEL_IMAGENET_PATH,
    mask_transform=mask_transform,
    img_transform=img_transform,
)

In [4]:
from torch.utils.data import SubsetRandomSampler

# Worker count and batch size both scale with the number of visible GPUs.
gpu_size = torch.cuda.device_count() * 256
num_workers = torch.cuda.device_count() * 4

# Sequential, complete pass over the dataset (no shuffling, last partial batch kept).
pixel_imagenet_loader = torch.utils.data.DataLoader(
    pixel_imagenet,
    batch_size=max(100, gpu_size),  # never below 100, which also covers the zero-GPU case
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
    pin_memory=True,
    # sampler=SubsetRandomSampler(indices=),
)

In [5]:
def blackout(imgs):
    """Return an all-zero tensor with the same shape, dtype and device as `imgs`.

    Used as a replacement image for masked-out pixels; `imgs` is not modified.
    """
    # zeros_like already inherits the device (and dtype) of its argument.
    return torch.zeros_like(imgs)

def greyout(imgs):
    """Return a tensor shaped like `imgs` filled with a constant per-channel colour.

    The three channel values (0.485, 0.456, 0.406) equal the `mean` given to the
    notebook's `normalizer`, so the fill maps to zero once normalized. The
    constant has shape (3, 1, 1) and is broadcast against `imgs`, which therefore
    needs a size-3 channel axis in the third-from-last position. `imgs` is not
    modified.
    """
    channel_fill = torch.tensor([0.485, 0.456, 0.406], device=imgs.device).view(3, 1, 1)
    canvas = torch.zeros_like(imgs)
    return canvas + channel_fill

In [6]:
from efficientnet import efficientnet_b0
from squeezenet import squeezenet1_1
from mobilenet import mobilenet_v2, mobilenet_v3_large
from densenet import densenet121
# (constructor, label) pairs for the architectures under evaluation. The labels
# double as dictionary keys for the predictions collected later.
model_list = [
    (resnet50, 'resnet50'),
    # (deit_b_16, 'deit_b_16'),  # currently disabled
    (wide_resnet50_2, 'wide_resnet50'),
    (alexnet, 'AlexNet'),
    (efficientnet_b0, 'EfficientNet'),
    (mobilenet_v3_large, 'MobileNet'),
    (squeezenet1_1, 'SqueezeNet'),
    (densenet121, 'DenseNet'),
]

# Instantiate every model with pretrained weights, wrap it, move it to `device`
# and switch it to eval mode.
# NOTE(review): `MyDataParallel`, `device` and `timm` are not imported in the
# visible cells — presumably provided by the star imports; confirm.
models = [
    (MyDataParallel(build(pretrained=True)).to(device).eval(), label)
    for build, label in model_list
]
models.append((
    MyDataParallel(timm.create_model('resnet50', pretrained=True)).to(device).eval(),
    'ResNet50_timm',
))

# Masking baselines. A callable produces the replacement image that is blended
# in where the mask is set; a list is instead handed to the model itself as the
# third element of an (image, mask, value) tuple.
baselines = [
    (blackout, "Blackout"),
    (greyout, "Greyout"),
    ([1000, 1000], "Layer mask"),
]

# Per-channel input normalization applied right before every forward pass.
normalizer = transforms.Normalize(std=[0.229, 0.224, 0.225], mean=[0.485, 0.456, 0.406])


In [8]:
import os

# Create the output directory before the (long) evaluation so that a missing
# ./results folder cannot make the final pickle.dump calls fail after all the
# forward passes have already been paid for.
os.makedirs('./results', exist_ok=True)

# Predicted class indices accumulated over the whole loader:
#   all_clean_preds[model_name]               -> unmodified images
#   all_masked_preds[model_name][baseline]    -> dataset masks
#   all_sqmasked_preds[model_name][baseline]  -> masks rolled by 112 along dims 2 and 3
all_clean_preds = defaultdict(list)
all_masked_preds = {model_name: defaultdict(list) for _, model_name in models}
all_sqmasked_preds = {model_name: defaultdict(list) for _, model_name in models}

all_labels = []
hi, wi = torch.arange(224), torch.arange(224)  # not read in this cell; kept in case other cells use them
for i, (imgs, masks, labels) in enumerate(pixel_imagenet_loader):
    all_labels += list(labels.numpy())

    # Circular shift of the masks by 112 along dims 2 and 3: the element at
    # index j moves to (j - 112) mod size. With the 224x224 crops produced by
    # mask_transform this is a half-image shift on both axes.
    # NOTE(review): presumably a control mask that no longer lines up with the
    # annotated region — confirm intent.
    sq_masks = torch.roll(masks, shifts=(-112, -112), dims=(2, 3))

    # Model-independent tensors, computed once per batch instead of once per
    # model/baseline. `normalizer` was built without inplace=True, so `imgs`
    # itself stays un-normalized for the blending below.
    norm_imgs = normalizer(imgs)
    keep_masks, sq_keep_masks = 1 - masks, 1 - sq_masks

    for model, model_name in models:
        with torch.no_grad():
            clean_logits = model(norm_imgs)
            clean_preds = list(clean_logits.argmax(-1).cpu().numpy())
            all_clean_preds[model_name] += clean_preds
            for b, bname in baselines:
                if isinstance(b, list):
                    # List baseline: the model receives (masked normalized
                    # image, keep-mask, b) as one tuple; how the mask and `b`
                    # are used is defined by the model implementation.
                    logits = model((norm_imgs * keep_masks, keep_masks, b))
                    sq_logits = model((norm_imgs * sq_keep_masks, sq_keep_masks, b))
                else:
                    # Callable baseline: where the mask is set, pixels are
                    # replaced by b(imgs) before normalization.
                    fill = b(imgs)
                    logits = model(normalizer(imgs * keep_masks + fill * masks))
                    sq_logits = model(normalizer(imgs * sq_keep_masks + fill * sq_masks))
                all_masked_preds[model_name][bname] += list(logits.argmax(-1).cpu().numpy())
                all_sqmasked_preds[model_name][bname] += list(sq_logits.argmax(-1).cpu().numpy())

        # Progress report every 5th batch (skipping the first): restricted to
        # the samples this model classifies correctly on clean images, print
        # the fraction still classified correctly under each baseline — dataset
        # masks first (Blackout, Greyout, Layer mask), then the rolled masks in
        # the same order.
        if i % 5 == 0 and i != 0:
            labels_arr = np.array(all_labels)
            hits = labels_arr == np.array(all_clean_preds[model_name])
            retained = [
                (np.array(preds[model_name][bname]) == labels_arr)[hits].mean()
                for preds in (all_masked_preds, all_sqmasked_preds)
                for bname in ('Blackout', 'Greyout', 'Layer mask')
            ]
            print(i, model_name, *retained)

# Persist everything; the 'broken_masked_preds' file holds the rolled-mask
# predictions (all_sqmasked_preds).
desc = 'more_models'
outputs = {
    'clean_preds': all_clean_preds,
    'masked_preds': all_masked_preds,
    'labels': all_labels,
    'broken_masked_preds': all_sqmasked_preds,
}
for stem, payload in outputs.items():
    with open(f'./results/pixel_imgnet_{stem}_{desc}.pkl', 'wb+') as fp:
        pickle.dump(payload, fp)

5 resnet50 0.23318385650224216 0.21748878923766815 0.16442451420029897 0.6831091180866966 0.7227204783258595 0.8101644245142003
5 wide_resnet50 0.23727647867950483 0.23727647867950483 0.19119669876203577 0.7455295735900963 0.7861072902338377 0.8431911966987621
5 AlexNet 0.18886861313868614 0.16697080291970803 0.1478102189781022 0.4625912408759124 0.5611313868613139 0.6751824817518248
5 EfficientNet 0.25734767025089605 0.207168458781362 0.16774193548387098 0.7677419354838709 0.8767025089605734 0.7849462365591398
5 MobileNet 0.24517512508934952 0.24088634739099357 0.1658327376697641 0.7169406719085061 0.8227305218012866 0.725518227305218
5 SqueezeNet 0.1395112016293279 0.1364562118126273 0.11812627291242363 0.48879837067209775 0.560081466395112 0.659877800407332
5 DenseNet 0.20061967467079783 0.203718048024787 0.13632842757552285 0.7025561580170411 0.7257939581719597 0.7567776917118513
5 ResNet50_timm 0.29859154929577464 0.28732394366197184 0.22605633802816902 0.8492957746478873 0.909154

KeyboardInterrupt: 