Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
163 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#!/bin/bash | ||
|
||
DATA_PATH="PATH/TO/IMAGENET/val" | ||
DATA_PATH="$HOME/data/raw/imagenet/val" | ||
|
||
|
||
# use 8 x 8 grid of patches to drop | ||
python evaluate.py \ | ||
--model_name deit_tiny_patch16_224 \ | ||
--test_dir "$DATA_PATH" \ | ||
--random_drop \ | ||
--shuffle_size 8 8 | ||
|
||
# use grid of patches with offset from top left | ||
python evaluate.py \ | ||
--model_name deit_tiny_patch16_224 \ | ||
--test_dir "$DATA_PATH" \ | ||
--random_drop \ | ||
--random_offset_drop | ||
|
||
# pixel level drop | ||
python evaluate.py \ | ||
--model_name deit_tiny_patch16_224 \ | ||
--test_dir "$DATA_PATH" \ | ||
--random_drop \ | ||
--shuffle_size 224 224 | ||
|
||
# lesion study - feature drop | ||
python evaluate.py \ | ||
--model_name deit_tiny_patch16_224 \ | ||
--test_dir "$DATA_PATH" \ | ||
--lesion \ | ||
--block_index 0 2 4 8 10 | ||
|
||
# lesion study - feature drop on resnet | ||
python evaluate.py \ | ||
--model_name resnet_drop \ | ||
--test_dir "$DATA_PATH" \ | ||
--lesion \ | ||
--block_index 1 2 3 4 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,4 @@ | |
from .t2t_vit_se import * | ||
from .tnt import * | ||
from .vit import * | ||
from .resnet import drop_resnet50 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import torch | ||
import torchvision | ||
from torchvision.models.resnet import Bottleneck, load_state_dict_from_url, model_urls | ||
|
||
|
||
class NewResnet(torchvision.models.ResNet): | ||
|
||
def _forward_impl(self, x, drop_percent=None, drop_layer=0): | ||
# See note [TorchScript super()] | ||
x = self.conv1(x) | ||
x = self.bn1(x) | ||
x = self.relu(x) | ||
x = self.maxpool(x) | ||
|
||
if drop_layer == 1: | ||
mask = torch.rand(x.shape[2:], device=x.device) | ||
mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0) | ||
x = x * mask | ||
x = self.layer1(x) | ||
|
||
if drop_layer == 2: | ||
mask = torch.rand(x.shape[2:], device=x.device) | ||
mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0) | ||
x = x * mask | ||
x = self.layer2(x) | ||
|
||
if drop_layer == 3: | ||
mask = torch.rand(x.shape[2:], device=x.device) | ||
mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0) | ||
x = x * mask | ||
x = self.layer3(x) | ||
|
||
if drop_layer == 4: | ||
mask = torch.rand(x.shape[2:], device=x.device) | ||
mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0) | ||
x = x * mask | ||
x = self.layer4(x) | ||
|
||
if drop_layer == 5: | ||
mask = torch.rand(x.shape[2:], device=x.device) | ||
mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0) | ||
x = x * mask | ||
|
||
x = self.avgpool(x) | ||
x = torch.flatten(x, 1) | ||
x = self.fc(x) | ||
|
||
return x | ||
|
||
def forward(self, x, drop_percent=None, drop_layer=None): | ||
return self._forward_impl(x, drop_percent=drop_percent, drop_layer=drop_layer) | ||
|
||
|
||
def _resnet(arch, block, layers, pretrained, progress, **kwargs): | ||
model = NewResnet(block, layers, **kwargs) | ||
if pretrained: | ||
state_dict = load_state_dict_from_url(model_urls[arch], | ||
progress=progress) | ||
model.load_state_dict(state_dict) | ||
return model | ||
|
||
|
||
def drop_resnet50(pretrained=False, progress=True, **kwargs): | ||
r"""ResNet-50 model from | ||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
progress (bool): If True, displays a progress bar of the download to stderr | ||
""" | ||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, | ||
**kwargs) | ||
|
||
|
||
if __name__ == '__main__': | ||
model = drop_resnet50(pretrained=True) | ||
sample = torch.randn((1, 3, 224, 224)) | ||
out = model(sample, drop_layer=1, drop_percent=0.25) |