In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
import torch
# Pin work to the first GPU when one is present. Skipping the call on
# CPU-only machines keeps this cell from raising, so the notebook can still be
# restarted and run top to bottom there.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

In [2]:
# Download (if needed) and extract the Imagenette dataset; `path` is the
# extracted dataset directory that the data bunch below is built from.
path = untar_data(url=URLs.IMAGENETTE)

In [3]:
# Default fastai augmentations with flipping disabled; images come from the
# 'train' / 'val' sub-folders, resized to 224 px in batches of 64, and are
# normalized with `imagenet_stats`.
tfms = get_transforms(do_flip=False)
data = (ImageDataBunch
        .from_folder(path, train='train', valid='val',
                     ds_tfms=tfms, size=224, bs=64)
        .normalize(imagenet_stats))

In [4]:
# Build a ResNet-34 CNN learner reporting accuracy, restore the weights saved
# under the name 'imagenet_bs64' (NOTE(review): presumably an earlier training
# run on this data -- confirm), then show the per-layer summary.
learn = cnn_learner(data, models.resnet34, metrics=accuracy).load('imagenet_bs64')
learn.summary()

Sequential
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [64, 112, 112]       9,408      False     
______________________________________________________________________
BatchNorm2d          [64, 112, 112]       128        True      
______________________________________________________________________
ReLU                 [64, 112, 112]       0          False     
______________________________________________________________________
MaxPool2d            [64, 56, 56]         0          False     
______________________________________________________________________
Conv2d               [64, 56, 56]         36,864     False     
______________________________________________________________________
BatchNorm2d          [64, 56, 56]         128        True      
______________________________________________________________________
ReLU                 [64, 56, 56]         0          False     
___________________________________________________

In [5]:
class SaveFeatures:
    """Record the output of a module on every forward pass via a forward hook.

    After each forward call through the hooked module, ``self.features`` holds
    that module's most recent output. It is ``None`` until the first forward
    pass, rather than raising AttributeError.

    Call ``remove()`` when done, or use the instance as a context manager, so
    the hook stops firing.

    Args:
        m: the ``torch.nn.Module`` to hook.
        detach: if True, store ``outp.detach()`` so the saved tensor does not
            keep the autograd graph alive. This assumes the module returns a
            single tensor. Defaults to False, which stores the raw output.
    """

    def __init__(self, m, detach=False):
        self.features = None
        self.detach = detach
        self.handle = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, m, inp, outp):
        # Signature required by register_forward_hook: (module, input, output).
        self.features = outp.detach() if self.detach else outp

    def remove(self):
        """Unregister the hook; `features` keeps its last captured value."""
        self.handle.remove()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Always unhook, even if the body raised; never swallow the exception.
        self.remove()
        return False

In [6]:
# The learner's underlying PyTorch model. Its first child is the convolutional
# body: stem layers followed by Sequential stages of BasicBlocks, as printed
# below.
net = learn.model
body = net[0]
body

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, ker

In [7]:
# Hook the outputs of the four stage modules net[0][4] .. net[0][7] of the
# backbone. net[0][4] is a Sequential of BasicBlocks; the hooked outputs have
# 64/128/256/512 channels. After each forward pass, sf[i].features holds the
# output of stage i.
# Re-running this cell would otherwise stack a second hook on each module, so
# first unregister any hooks left over from a previous run.
for _stale_hook in globals().get('sf', []):
    _stale_hook.remove()
sf = [SaveFeatures(net[0][i]) for i in range(4, 8)]

#### Checking the hooked feature-map shapes for one image

In [8]:
# Take the first mini-batch (images, labels) from the training loader.
batch = next(iter(data.train_dl))
x, y = batch

In [9]:
# Second image of the batch, keeping a leading batch dimension of size 1.
img1 = x[1].unsqueeze(0)
img1.shape

torch.Size([1, 3, 224, 224])

In [10]:
# Forward one image. This fires the SaveFeatures hooks and shows the 10 class
# logits. Tensors are passed directly: torch.autograd.Variable has been a
# deprecated no-op wrapper since PyTorch 0.4.
net(img1)

tensor([[-1.6740, -2.9325,  7.3321,  1.4669, -1.5797, -0.3475, -0.7164,  2.8611,
         -2.6251, -1.3780]], device='cuda:0', grad_fn=<AddmmBackward>)

In [11]:
# Shape of each stage output captured during the single-image forward pass.
list(map(lambda saved: saved.features.shape, sf))

[torch.Size([1, 64, 56, 56]),
 torch.Size([1, 128, 28, 28]),
 torch.Size([1, 256, 14, 14]),
 torch.Size([1, 512, 7, 7])]

#### Checking the hooked feature-map shapes for the entire batch

In [12]:
# Forward the whole batch. The hooks now capture batch-sized stage outputs.
# The deprecated torch.autograd.Variable wrapper (a no-op since PyTorch 0.4)
# is dropped.
net(x)

tensor([[ 1.2025e-01,  1.0091e+01, -2.1707e+00, -3.6375e+00, -7.0766e-01,
          1.0154e+00,  2.0796e+00, -3.8686e+00, -1.9264e+00, -1.3335e+00],
        [-1.6740e+00, -2.9324e+00,  7.3321e+00,  1.4669e+00, -1.5797e+00,
         -3.4753e-01, -7.1642e-01,  2.8612e+00, -2.6251e+00, -1.3780e+00],
        [ 1.4631e+00,  5.6038e-01, -8.3666e-01, -4.3348e+00, -1.1969e+00,
          1.7359e+01, -3.2761e+00, -2.3036e+00, -1.6530e+00, -2.5572e+00],
        [-2.6287e+00, -5.4973e-01,  1.1431e+01, -2.1546e+00, -1.1265e+00,
          4.7562e-01, -1.7394e+00,  9.7366e-01, -1.4351e+00, -2.9585e+00],
        [-2.4767e+00, -2.8279e+00, -2.0044e+00, -2.5384e+00,  7.3675e-01,
         -2.0133e+00,  1.2367e+01,  5.4425e-01, -4.7577e-01, -9.9447e-01],
        [-2.1387e+00, -2.9445e-01,  3.4961e-01, -3.5609e+00,  1.1598e+01,
         -1.7425e+00, -2.3733e+00, -2.1771e+00, -5.6285e-01, -8.7082e-01],
        [-1.7832e+00, -2.4922e+00, -1.0908e+00, -1.9137e+00,  4.5941e-01,
          4.1464e-01, -1.3093e+0

In [13]:
# Stage-output shapes after the batch forward pass (leading dim = batch size).
[hook.features.shape for hook in sf]

[torch.Size([64, 64, 56, 56]),
 torch.Size([64, 128, 28, 28]),
 torch.Size([64, 256, 14, 14]),
 torch.Size([64, 512, 7, 7])]