Skip to content

Commit

Permalink
IMPROVEMENT: handle models with grayscale input (#31)
Browse files Browse the repository at this point in the history
* Handle models with grayscale input

* Fix spacing

* Fix linting

* Bump version
  • Loading branch information
MisaOgura committed May 29, 2020
1 parent ba5e3db commit 62a87ef
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 4 deletions.
2 changes: 1 addition & 1 deletion flashtorch/__init__.py
@@ -1 +1 @@
__version__ = '0.1.2'
__version__ = '0.1.3'
3 changes: 1 addition & 2 deletions flashtorch/saliency/backprop.py
Expand Up @@ -218,8 +218,7 @@ def _record_gradients(module, grad_in, grad_out):
self.gradients = grad_in[0]

for _, module in self.model.named_modules():
if isinstance(module, nn.modules.conv.Conv2d) and \
module.in_channels == 3:
if isinstance(module, nn.modules.conv.Conv2d):
module.register_backward_hook(_record_gradients)
break

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -26,7 +26,7 @@
DOCLINES = (__doc__ or '').split("\n")
long_description = "\n".join(DOCLINES[2:])

version = '0.1.2'
version = '0.1.3'

setup(
name='flashtorch',
Expand Down
31 changes: 31 additions & 0 deletions tests/test_backprop.py
Expand Up @@ -3,6 +3,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models

Expand Down Expand Up @@ -59,6 +60,21 @@ def make_expected_gradient_target(top_class):
return target


class CnnGrayscale(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=3, padding=1)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 10, kernel_size=3, stride=3, padding=1)
self.fc1 = nn.Linear(10 * 25 * 25, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))

return F.softmax(self.fc1(x.view(-1, 10 * 25 * 25)), dim=1)


#################
# Test fixtures #
#################
Expand All @@ -69,6 +85,11 @@ def model():
return models.alexnet()


@pytest.fixture
def model_grayscale():
return CnnGrayscale()


##############
# Test cases #
##############
Expand Down Expand Up @@ -162,6 +183,16 @@ def test_calc_gradients_of_top_class_if_prediction_is_wrong(mocker, model):
assert torch.all(kwargs['gradient'].eq(target))


def test_handle_greyscale_input(mocker, model_grayscale):
backprop = Backprop(model_grayscale)

input_ = torch.zeros([1, 1, 224, 224], requires_grad=True)

gradients = backprop.calculate_gradients(input_)

assert gradients.shape == (1, 224, 224)


def test_return_max_across_color_channels_if_specified(mocker, model):
backprop = Backprop(model)

Expand Down

0 comments on commit 62a87ef

Please sign in to comment.