In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
import torch
import numpy as np
torch.cuda.set_device(0)

In [2]:
# Fetch and extract the Imagenette dataset; `path` is its local root folder.
dataset_url = URLs.IMAGENETTE
path = untar_data(dataset_url)

In [3]:
# Imagenette is laid out as train/ and val/ folders under `path`. Use fastai's
# default augmentations with flipping disabled, resize to 224 px, and normalize
# with ImageNet statistics to match the pretrained ResNet used below.
tfms = get_transforms(do_flip=False)
data = (ImageDataBunch
        .from_folder(path, train='train', valid='val',
                     ds_tfms=tfms, size=224, bs=64)
        .normalize(imagenet_stats))

In [4]:
# ResNet-34 learner on `data` (accuracy metric), with weights restored from the
# checkpoint saved under the name 'imagenet_bs64'; then show the layer table.
learn = (cnn_learner(data, models.resnet34, metrics=accuracy)
         .load('imagenet_bs64'))
learn.summary()

Sequential
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [64, 112, 112]       9,408      False     
______________________________________________________________________
BatchNorm2d          [64, 112, 112]       128        True      
______________________________________________________________________
ReLU                 [64, 112, 112]       0          False     
______________________________________________________________________
MaxPool2d            [64, 56, 56]         0          False     
______________________________________________________________________
Conv2d               [64, 56, 56]         36,864     False     
______________________________________________________________________
BatchNorm2d          [64, 56, 56]         128        True      
______________________________________________________________________
ReLU                 [64, 56, 56]         0          False     
___________________________________________________

In [7]:
class SaveFeatures:
    """Forward hook that keeps the most recent output of a module.

    The hook is registered on construction. After every forward pass through
    `m`, that module's output is available as `self.features` (None until the
    first pass).

    Args:
        m: the `torch.nn.Module` to observe.
        detach: if True, store `outp.detach()` so the saved tensor does not keep
            the autograd graph (and every activation in it) alive between
            forward passes; requires `m` to return a single tensor. Defaults to
            False, which stores the raw output so gradients can still flow
            through `features`.
    """

    def __init__(self, m, detach=False):
        self.detach = detach
        # Defined up front so reading it before any forward pass gives None
        # instead of an AttributeError.
        self.features = None
        self.handle = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, m, inp, outp):
        """Forward-hook callback (module, inputs, output): remember the output."""
        self.features = outp.detach() if self.detach else outp

    def remove(self):
        """Unregister the hook; `features` keeps the last captured value."""
        self.handle.remove()
        
# Hook the stem ReLU (net[0][2]) and the four residual stages (net[0][4] .. net[0][7];
# net[0][4] is the Sequential of BasicBlocks shown in the `net[0]` listing). A stage
# hook captures the output of that stage's last BasicBlock, not every BasicBlock.
#
# Remove hooks left over from a previous run of this cell first: otherwise every
# re-run stacks another set of hooks on the same modules, and all of them fire on
# each forward pass. Removing an already-removed handle is harmless.
for stale in globals().get('sf', ()):
    stale.remove()

# NOTE(review): per the summary table the stem ReLU outputs 64x112x112, but the
# recorded shape listings in this notebook start at 64x56x56, and one lists only
# four entries. They were produced by a different hook list; re-run from a fresh kernel.
net = learn.model
sf = [SaveFeatures(m) for m in [net[0][2], net[0][4], net[0][5], net[0][6], net[0][7]]]

In [8]:
# Show the convolutional body so the child indices used by the feature hooks are visible.
net = learn.model
body = net[0]
body

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, ker

#### Checking a single image

In [9]:
# Pull the first batch from the training loader: images `x` and labels `y`
# (batch size 64, images 224x224, per the DataBunch settings above).
first_batch = next(iter(data.train_dl))
x, y = first_batch

In [10]:
# Second image of the batch (index 1), keeping a leading batch axis of size 1.
img1 = x[1].unsqueeze(0)
img1.shape

torch.Size([1, 3, 224, 224])

In [11]:
%timeit net(torch.autograd.Variable(img1))

4.16 ms ± 1.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
# Activation shape captured by each hook for the single image.
[hook.features.shape for hook in sf]

[torch.Size([1, 64, 56, 56]),
 torch.Size([1, 64, 56, 56]),
 torch.Size([1, 128, 28, 28]),
 torch.Size([1, 256, 14, 14]),
 torch.Size([1, 512, 7, 7])]

#### Checking the entire batch

In [13]:
%timeit net(torch.autograd.Variable(x))

The slowest run took 13.41 times longer than the fastest. This could mean that an intermediate result is being cached.
13.8 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [17]:
# Activation shape captured by each hook for the full batch.
[saved.features.shape for saved in sf]

[torch.Size([64, 64, 56, 56]),
 torch.Size([64, 128, 28, 28]),
 torch.Size([64, 256, 14, 14]),
 torch.Size([64, 512, 7, 7])]

#### Checking that np.save / np.load round-trip the features (not used later: the features are taken directly at runtime)

In [27]:
# Save each hooked feature map to numpyarray<i>.npy in the working directory,
# reload it, and check the round trip is lossless. Both loops walk `sf` itself,
# so every registered hook is covered. A hard-coded count (such as range(4))
# silently skips hooks whenever the hook list grows.
for n, hook in enumerate(sf):
    fm = hook.features.cpu().detach().numpy()  # move to CPU and convert to a NumPy array
    np.save(f'numpyarray{n}.npy', fm)

for n, hook in enumerate(sf):
    fm = np.load(f'numpyarray{n}.npy')
    # Reports the mismatching elements on failure and, unlike a bare `assert`,
    # is not stripped when Python runs with -O.
    np.testing.assert_array_equal(fm, hook.features.cpu().detach().numpy())