additional occlusion strategies
kahnchana committed Jun 8, 2021
1 parent 6ed98d9 commit 9b6edae
Showing 7 changed files with 163 additions and 9 deletions.
25 changes: 24 additions & 1 deletion evaluate.py
@@ -183,7 +183,16 @@ def main(args, device, verbose=True):
pass
return 0

clean_out = model(normalize(img.clone(), mean=mean, std=std))
if args.lesion:
if "resnet" in args.model_name:
clean_out = model(normalize(img.clone(), mean=mean, std=std), drop_layer=args.block_index,
drop_percent=args.drop_count)
else:
clean_out = model(normalize(img.clone(), mean=mean, std=std), block_index=args.block_index,
drop_rate=args.drop_count)
else:
clean_out = model(normalize(img.clone(), mean=mean, std=std))

if isinstance(clean_out, list):
clean_out = clean_out[-1]
clean_acc += torch.sum(clean_out.argmax(dim=-1) == label).item()
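
For clarity, the lesion branch above passes the same two occlusion parameters under different keyword names depending on the backbone: the drop-enabled ResNet expects drop_layer / drop_percent, while the DeiT variants expect block_index / drop_rate. A minimal sketch of that dispatch (the helper name lesion_forward is illustrative, not part of the commit):

# Sketch only: route the occlusion arguments to the keyword names each
# backbone expects, mirroring the branch in main() above.
def lesion_forward(model, inp, model_name, block_index, drop_count):
    if "resnet" in model_name:
        # NewResnet zeroes spatial locations before the chosen stage
        return model(inp, drop_layer=block_index, drop_percent=drop_count)
    # DeiT variants zero patch tokens before the chosen transformer block
    return model(inp, block_index=block_index, drop_rate=drop_count)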
@@ -248,5 +257,19 @@ def main(args, device, verbose=True):
if not opt.test_image:
json.dump(acc_dict, open(f"report/dino/{opt.model_name}.json", "w"), indent=4)

elif opt.lesion:
for rand_exp in range(opt.exp_count):
acc_dict[f"run_{rand_exp:03d}"] = {}
block_index_list = opt.block_index
for cur_block_num in block_index_list:
opt.block_index = cur_block_num
acc_dict[f"run_{rand_exp:03d}"][f"{cur_block_num}"] = {}
for drop_count in [0.25, 0.50, 0.75]:
opt.drop_count = drop_count
acc = main(args=opt, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))
acc_dict[f"run_{rand_exp:03d}"][f"{cur_block_num}"][f"{drop_count}"] = acc
if not opt.test_image:
json.dump(acc_dict, open(f"report/lesion/{opt.model_name}.json", "w"), indent=4)

else:
print("No arguments specified: finished running")
1 change: 0 additions & 1 deletion scripts/evaluate_occlusion.sh
@@ -1,7 +1,6 @@
#!/bin/bash

DATA_PATH="PATH/TO/IMAGENET/val"
DATA_PATH="$HOME/data/raw/imagenet/val"

python evaluate.py \
--model_name deit_tiny_patch16_224 \
40 changes: 40 additions & 0 deletions scripts/evaluate_occlusion_supp.sh
@@ -0,0 +1,40 @@
#!/bin/bash

DATA_PATH="PATH/TO/IMAGENET/val"
DATA_PATH="$HOME/data/raw/imagenet/val"


# use 8 x 8 grid of patches to drop
python evaluate.py \
--model_name deit_tiny_patch16_224 \
--test_dir "$DATA_PATH" \
--random_drop \
--shuffle_size 8 8

# use grid of patches with offset from top left
python evaluate.py \
--model_name deit_tiny_patch16_224 \
--test_dir "$DATA_PATH" \
--random_drop \
--random_offset_drop

# pixel level drop
python evaluate.py \
--model_name deit_tiny_patch16_224 \
--test_dir "$DATA_PATH" \
--random_drop \
--shuffle_size 224 224

# lesion study - feature drop
python evaluate.py \
--model_name deit_tiny_patch16_224 \
--test_dir "$DATA_PATH" \
--lesion \
--block_index 0 2 4 8 10

# lesion study - feature drop on resnet
python evaluate.py \
--model_name resnet_drop \
--test_dir "$DATA_PATH" \
--lesion \
--block_index 1 2 3 4 5
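
These commands assume evaluate.py exposes the corresponding flags; the argument parser itself is not part of this diff, so the sketch below only illustrates the argument shapes the lesion loop relies on (--block_index as a list of ints, --drop_count as a float set inside the loop):

# Assumed argument definitions (not shown in this commit); names and types
# are inferred from how evaluate.py uses opt.* above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lesion', action='store_true')
parser.add_argument('--block_index', type=int, nargs='+', default=[0])
parser.add_argument('--drop_count', type=float, default=0.25)
parser.add_argument('--exp_count', type=int, default=1)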
4 changes: 4 additions & 0 deletions utils.py
@@ -50,6 +50,10 @@ def get_model(args, pretrained=True):
model = models.__dict__[args.model_name](pretrained=pretrained)
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
elif 'resnet_drop' in args.model_name:
model = vit_models.drop_resnet50(pretrained=True)
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
elif 'deit' in args.model_name:
model = create_model(args.model_name, pretrained=pretrained)
mean = (0.485, 0.456, 0.406)
1 change: 1 addition & 0 deletions vit_models/__init__.py
@@ -8,3 +8,4 @@
from .t2t_vit_se import *
from .tnt import *
from .vit import *
from .resnet import drop_resnet50
23 changes: 16 additions & 7 deletions vit_models/deit.py
@@ -2,6 +2,7 @@
# All rights reserved.
import math

import numpy as np
import torch
import torch.nn as nn
from functools import partial
@@ -12,7 +13,6 @@

import random


__all__ = [
'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
@@ -59,17 +59,17 @@ def forward(self, x):
x = [self.head(x) for x, _ in list_out]
x_dist = [self.head_dist(x_dist) for _, x_dist in list_out]
if self.training:
return [(out, out_dist) for out, out_dist in zip(x, x_dist)]
return [(out, out_dist) for out, out_dist in zip(x, x_dist)]
else:
# during inference, return the average of both classifier predictions
return [(out+out_dist) / 2 for out, out_dist in zip(x, x_dist)]
return [(out + out_dist) / 2 for out, out_dist in zip(x, x_dist)]


class VanillaVisionTransformer(VisionTransformer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def forward_features(self, x):
def forward_features(self, x, block_index=None, drop_rate=0):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add the dist_token
B, nc, w, h = x.shape
@@ -104,21 +104,30 @@ def forward_features(self, x):

layer_wise_tokens = []
for idx, blk in enumerate(self.blocks):

if block_index is not None and idx == block_index:
token = x[:, :1, :]
features = x[:, 1:, :]
row = np.random.choice(range(x.shape[1] - 1), size=int(drop_rate*x.shape[1]), replace=False)
features[:, row, :] = 0.0
x = torch.cat((token, features), dim=1)

x = blk(x)
layer_wise_tokens.append(x)

layer_wise_tokens = [self.norm(x) for x in layer_wise_tokens]

return [x[:, 0] for x in layer_wise_tokens], [x for x in layer_wise_tokens]

def forward(self, x, patches=False):
list_out, patch_out = self.forward_features(x)
def forward(self, x, block_index=None, drop_rate=0, patches=False):
list_out, patch_out = self.forward_features(x, block_index, drop_rate)
x = [self.head(x) for x in list_out]
if patches:
return x, patch_out
else:
return x


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
model = VanillaVisionTransformer(
Expand Down Expand Up @@ -236,4 +245,4 @@ def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
return model
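
The block_index / drop_rate path added to forward_features above zeroes a random subset of patch tokens immediately before the selected transformer block, leaving the class token untouched. A self-contained sketch of that step (shapes and the function name are illustrative, not part of the commit):

# Standalone sketch of the token-dropping step in forward_features above;
# x is [B, 1 + num_patches, dim] with the class token at position 0.
import numpy as np
import torch

def drop_patch_tokens(x, drop_rate):
    token = x[:, :1, :]                     # keep the class token
    features = x[:, 1:, :]                  # patch tokens (a view into x)
    num_drop = int(drop_rate * x.shape[1])  # same count as computed in the diff
    row = np.random.choice(range(x.shape[1] - 1), size=num_drop, replace=False)
    features[:, row, :] = 0.0               # zero the selected patch tokens
    return torch.cat((token, features), dim=1)

# illustrative usage, e.g. DeiT-tiny: 196 patch tokens + 1 class token
x = torch.randn(2, 197, 192)
x = drop_patch_tokens(x, drop_rate=0.5)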
78 changes: 78 additions & 0 deletions vit_models/resnet.py
@@ -0,0 +1,78 @@
import torch
import torchvision
from torchvision.models.resnet import Bottleneck, load_state_dict_from_url, model_urls


class NewResnet(torchvision.models.ResNet):

def _forward_impl(self, x, drop_percent=None, drop_layer=0):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

if drop_layer == 1:
mask = torch.rand(x.shape[2:], device=x.device)
mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0)
x = x * mask
x = self.layer1(x)

if drop_layer == 2:
mask = torch.rand(x.shape[2:], device=x.device)
mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0)
x = x * mask
x = self.layer2(x)

if drop_layer == 3:
mask = torch.rand(x.shape[2:], device=x.device)
mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0)
x = x * mask
x = self.layer3(x)

if drop_layer == 4:
mask = torch.rand(x.shape[2:], device=x.device)
mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0)
x = x * mask
x = self.layer4(x)

if drop_layer == 5:
mask = torch.rand(x.shape[2:], device=x.device)
mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0)
x = x * mask

x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)

return x

def forward(self, x, drop_percent=None, drop_layer=None):
return self._forward_impl(x, drop_percent=drop_percent, drop_layer=drop_layer)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = NewResnet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model


def drop_resnet50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)


if __name__ == '__main__':
model = drop_resnet50(pretrained=True)
sample = torch.randn((1, 3, 224, 224))
out = model(sample, drop_layer=1, drop_percent=0.25)
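
For each stage, the drop above samples a single spatial mask and broadcasts it over the batch and channel dimensions, zeroing roughly a drop_percent fraction of feature-map locations. A minimal standalone sketch of that masking step (the helper name is illustrative):

# Standalone sketch of the spatial feature-map drop used in NewResnet above.
import torch

def drop_spatial(x, drop_percent):
    # x: [B, C, H, W]; keep a location when its random draw exceeds drop_percent
    mask = torch.rand(x.shape[2:], device=x.device)
    mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0)
    return x * mask

# illustrative usage on a layer1-sized ResNet-50 feature map
feat = torch.randn(1, 256, 56, 56)
feat = drop_spatial(feat, drop_percent=0.25)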
