## General steps for loading a pretrained feature extractor (backbone) for object detection tasks

In [1]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models.feature_extraction import (
    create_feature_extractor, 
    get_graph_node_names
)
# Use the GPU when one is available; otherwise fall back to the CPU so that the
# later `.to(device)` calls do not fail on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### Specify the pretrained weights, e.g. "IMAGENET1K_V1" or "IMAGENET1K_V2" (the available tags depend on the architecture; ResNet-18 only provides "IMAGENET1K_V1")
Reference: https://pytorch.org/vision/stable/models.html

In [2]:
# Tag of the pretrained weights handed to the torchvision model builder below.
weights = 'IMAGENET1K_V1'

### Load the pretrained model
Reference: https://pytorch.org/vision/stable/models.html

In [3]:
# Build ResNet-18 with the selected pretrained weights and place it on `device`.
model = models.resnet18(weights=weights).to(device)
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

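### Get the graph node names
`get_graph_node_names` traces the model and returns two lists of node names: one for train mode and one for eval mode.
Reference: https://pytorch.org/vision/stable/feature_extraction.html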

In [4]:
train_nodes, eval_nodes = get_graph_node_names(model)
# Header, a separator line, then one traced node name per line.
print('\n'.join(['train nodes', '=' * 100, *train_nodes]))

train nodes
x
conv1
bn1
relu
maxpool
layer1.0.conv1
layer1.0.bn1
layer1.0.relu
layer1.0.conv2
layer1.0.bn2
layer1.0.add
layer1.0.relu_1
layer1.1.conv1
layer1.1.bn1
layer1.1.relu
layer1.1.conv2
layer1.1.bn2
layer1.1.add
layer1.1.relu_1
layer2.0.conv1
layer2.0.bn1
layer2.0.relu
layer2.0.conv2
layer2.0.bn2
layer2.0.downsample.0
layer2.0.downsample.1
layer2.0.add
layer2.0.relu_1
layer2.1.conv1
layer2.1.bn1
layer2.1.relu
layer2.1.conv2
layer2.1.bn2
layer2.1.add
layer2.1.relu_1
layer3.0.conv1
layer3.0.bn1
layer3.0.relu
layer3.0.conv2
layer3.0.bn2
layer3.0.downsample.0
layer3.0.downsample.1
layer3.0.add
layer3.0.relu_1
layer3.1.conv1
layer3.1.bn1
layer3.1.relu
layer3.1.conv2
layer3.1.bn2
layer3.1.add
layer3.1.relu_1
layer4.0.conv1
layer4.0.bn1
layer4.0.relu
layer4.0.conv2
layer4.0.bn2
layer4.0.downsample.0
layer4.0.downsample.1
layer4.0.add
layer4.0.relu_1
layer4.1.conv1
layer4.1.bn1
layer4.1.relu
layer4.1.conv2
layer4.1.bn2
layer4.1.add
layer4.1.relu_1
avgpool
flatten
fc


In [5]:
# List the node names of the eval-mode graph (may differ from the train-mode
# graph for models whose forward pass depends on `self.training`).
print('eval nodes')
print('=' * 100)
for node in eval_nodes:
    print(node)

eval nodes
x
conv1
bn1
relu
maxpool
layer1.0.conv1
layer1.0.bn1
layer1.0.relu
layer1.0.conv2
layer1.0.bn2
layer1.0.add
layer1.0.relu_1
layer1.1.conv1
layer1.1.bn1
layer1.1.relu
layer1.1.conv2
layer1.1.bn2
layer1.1.add
layer1.1.relu_1
layer2.0.conv1
layer2.0.bn1
layer2.0.relu
layer2.0.conv2
layer2.0.bn2
layer2.0.downsample.0
layer2.0.downsample.1
layer2.0.add
layer2.0.relu_1
layer2.1.conv1
layer2.1.bn1
layer2.1.relu
layer2.1.conv2
layer2.1.bn2
layer2.1.add
layer2.1.relu_1
layer3.0.conv1
layer3.0.bn1
layer3.0.relu
layer3.0.conv2
layer3.0.bn2
layer3.0.downsample.0
layer3.0.downsample.1
layer3.0.add
layer3.0.relu_1
layer3.1.conv1
layer3.1.bn1
layer3.1.relu
layer3.1.conv2
layer3.1.bn2
layer3.1.add
layer3.1.relu_1
layer4.0.conv1
layer4.0.bn1
layer4.0.relu
layer4.0.conv2
layer4.0.bn2
layer4.0.downsample.0
layer4.0.downsample.1
layer4.0.add
layer4.0.relu_1
layer4.1.conv1
layer4.1.bn1
layer4.1.relu
layer4.1.conv2
layer4.1.bn2
layer4.1.add
layer4.1.relu_1
avgpool
flatten
fc


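### Extract all the nodes
Build a feature extractor that returns every intermediate node except the input placeholder `x` and the last two nodes (`flatten` and `fc` for ResNet-18). Each node's output is tagged `N<index>`.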

In [6]:
# Pair every traced node name with a short output tag: N0, N1, N2, ...
all_node = [(node_name, f'N{i}') for i, node_name in enumerate(train_nodes)]
# Skip the first node (input placeholder 'x') and the last two nodes.
sel_node = {node_name: tag for node_name, tag in all_node[1:-2]}
backbone = create_feature_extractor(model, return_nodes=sel_node)

In [7]:
all_node  # list of (node name, output tag) pairs for every traced node

[('x', 'N0'),
 ('conv1', 'N1'),
 ('bn1', 'N2'),
 ('relu', 'N3'),
 ('maxpool', 'N4'),
 ('layer1.0.conv1', 'N5'),
 ('layer1.0.bn1', 'N6'),
 ('layer1.0.relu', 'N7'),
 ('layer1.0.conv2', 'N8'),
 ('layer1.0.bn2', 'N9'),
 ('layer1.0.add', 'N10'),
 ('layer1.0.relu_1', 'N11'),
 ('layer1.1.conv1', 'N12'),
 ('layer1.1.bn1', 'N13'),
 ('layer1.1.relu', 'N14'),
 ('layer1.1.conv2', 'N15'),
 ('layer1.1.bn2', 'N16'),
 ('layer1.1.add', 'N17'),
 ('layer1.1.relu_1', 'N18'),
 ('layer2.0.conv1', 'N19'),
 ('layer2.0.bn1', 'N20'),
 ('layer2.0.relu', 'N21'),
 ('layer2.0.conv2', 'N22'),
 ('layer2.0.bn2', 'N23'),
 ('layer2.0.downsample.0', 'N24'),
 ('layer2.0.downsample.1', 'N25'),
 ('layer2.0.add', 'N26'),
 ('layer2.0.relu_1', 'N27'),
 ('layer2.1.conv1', 'N28'),
 ('layer2.1.bn1', 'N29'),
 ('layer2.1.relu', 'N30'),
 ('layer2.1.conv2', 'N31'),
 ('layer2.1.bn2', 'N32'),
 ('layer2.1.add', 'N33'),
 ('layer2.1.relu_1', 'N34'),
 ('layer3.0.conv1', 'N35'),
 ('layer3.0.bn1', 'N36'),
 ('layer3.0.relu', 'N37'),
 ('layer3

In [8]:
sel_node  # node name -> output tag, without the first node and the last two nodes

{'conv1': 'N1',
 'bn1': 'N2',
 'relu': 'N3',
 'maxpool': 'N4',
 'layer1.0.conv1': 'N5',
 'layer1.0.bn1': 'N6',
 'layer1.0.relu': 'N7',
 'layer1.0.conv2': 'N8',
 'layer1.0.bn2': 'N9',
 'layer1.0.add': 'N10',
 'layer1.0.relu_1': 'N11',
 'layer1.1.conv1': 'N12',
 'layer1.1.bn1': 'N13',
 'layer1.1.relu': 'N14',
 'layer1.1.conv2': 'N15',
 'layer1.1.bn2': 'N16',
 'layer1.1.add': 'N17',
 'layer1.1.relu_1': 'N18',
 'layer2.0.conv1': 'N19',
 'layer2.0.bn1': 'N20',
 'layer2.0.relu': 'N21',
 'layer2.0.conv2': 'N22',
 'layer2.0.bn2': 'N23',
 'layer2.0.downsample.0': 'N24',
 'layer2.0.downsample.1': 'N25',
 'layer2.0.add': 'N26',
 'layer2.0.relu_1': 'N27',
 'layer2.1.conv1': 'N28',
 'layer2.1.bn1': 'N29',
 'layer2.1.relu': 'N30',
 'layer2.1.conv2': 'N31',
 'layer2.1.bn2': 'N32',
 'layer2.1.add': 'N33',
 'layer2.1.relu_1': 'N34',
 'layer3.0.conv1': 'N35',
 'layer3.0.bn1': 'N36',
 'layer3.0.relu': 'N37',
 'layer3.0.conv2': 'N38',
 'layer3.0.bn2': 'N39',
 'layer3.0.downsample.0': 'N40',
 'layer3.0.dow

In [9]:
backbone  # feature extractor whose outputs are the nodes listed in sel_node

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Module(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2):

### Create a dummy input and run it through the extractor to obtain the output of every selected node

In [10]:
input_data_shape = (1, 3, 256, 256)   # (batch size, num_channels, height, width)
dummy_in = torch.randn(input_data_shape).to(device)

# This forward pass only exists to read the output shapes, so run it without
# autograd and in eval mode. In train mode every BatchNorm layer would update its
# running mean/var from this random tensor, and the extractor's layers are
# presumably the same module objects as in the pretrained `model` (it was traced
# from it), so those pretrained statistics would be overwritten.
# The previous train/eval mode is restored afterwards.
was_training = backbone.training
backbone.eval()
with torch.no_grad():
    dummy_out = backbone(dummy_in)
backbone.train(was_training)
dummy_out

{'N1': tensor([[[[-6.4126e-01, -3.9959e+00,  4.1265e-01,  ...,  6.0506e-01,
            -3.3244e+00, -2.7867e+00],
           [-3.7809e-01,  6.4483e-01,  3.6997e+00,  ..., -5.8511e-01,
             7.3598e-01,  3.8803e+00],
           [-2.7625e+00, -1.2690e+00, -1.8874e+00,  ..., -1.6388e+00,
            -1.0817e-01,  4.1829e-01],
           ...,
           [ 5.4446e+00,  1.1474e+00,  3.3516e-01,  ...,  2.8674e+00,
             2.9920e+00, -1.4572e-01],
           [ 4.8673e+00, -1.7644e+00, -3.1566e+00,  ...,  9.2524e+00,
            -1.7811e+00, -2.5516e+00],
           [ 1.7187e-01, -8.2484e+00,  6.7959e-02,  ...,  1.1893e+00,
            -4.1062e+00,  4.4818e-01]],
 
          [[ 1.3066e+00,  3.8104e-01, -1.5333e+00,  ...,  3.2389e-02,
            -3.2969e-01, -3.4744e-01],
           [-6.2798e-01,  4.3766e-01,  1.7493e+00,  ...,  4.2445e-02,
             1.7046e+00,  1.8126e+00],
           [-2.5688e+00, -3.2187e+00, -2.0745e+00,  ..., -1.0538e-03,
            -2.3473e+00, -1.5526e

### Create a dictionary mapping each node name to its per-sample feature shape (batch dimension dropped)

In [11]:
# Map each selected node name to the shape of its output without the batch
# dimension, e.g. [channels, height, width] for the convolutional nodes.
nodes_and_feature_dim = {
    node_name: list(dummy_out[tag].shape[1:])
    for node_name, tag in sel_node.items()
}
nodes_and_feature_dim

{'conv1': [64, 128, 128],
 'bn1': [64, 128, 128],
 'relu': [64, 128, 128],
 'maxpool': [64, 64, 64],
 'layer1.0.conv1': [64, 64, 64],
 'layer1.0.bn1': [64, 64, 64],
 'layer1.0.relu': [64, 64, 64],
 'layer1.0.conv2': [64, 64, 64],
 'layer1.0.bn2': [64, 64, 64],
 'layer1.0.add': [64, 64, 64],
 'layer1.0.relu_1': [64, 64, 64],
 'layer1.1.conv1': [64, 64, 64],
 'layer1.1.bn1': [64, 64, 64],
 'layer1.1.relu': [64, 64, 64],
 'layer1.1.conv2': [64, 64, 64],
 'layer1.1.bn2': [64, 64, 64],
 'layer1.1.add': [64, 64, 64],
 'layer1.1.relu_1': [64, 64, 64],
 'layer2.0.conv1': [128, 32, 32],
 'layer2.0.bn1': [128, 32, 32],
 'layer2.0.relu': [128, 32, 32],
 'layer2.0.conv2': [128, 32, 32],
 'layer2.0.bn2': [128, 32, 32],
 'layer2.0.downsample.0': [128, 32, 32],
 'layer2.0.downsample.1': [128, 32, 32],
 'layer2.0.add': [128, 32, 32],
 'layer2.0.relu_1': [128, 32, 32],
 'layer2.1.conv1': [128, 32, 32],
 'layer2.1.bn1': [128, 32, 32],
 'layer2.1.relu': [128, 32, 32],
 'layer2.1.conv2': [128, 32, 32],
 '

### Print the nodes and the corresponding feature shapes

In [12]:
# Pad each node name with dashes so that the shapes line up in one column.
gap = 35
for node_name, feat_shape in nodes_and_feature_dim.items():
    dashes = '-' * (gap - len(node_name))
    print(f'{node_name} {dashes} {feat_shape}')

conv1 ------------------------------ [64, 128, 128]
bn1 -------------------------------- [64, 128, 128]
relu ------------------------------- [64, 128, 128]
maxpool ---------------------------- [64, 64, 64]
layer1.0.conv1 --------------------- [64, 64, 64]
layer1.0.bn1 ----------------------- [64, 64, 64]
layer1.0.relu ---------------------- [64, 64, 64]
layer1.0.conv2 --------------------- [64, 64, 64]
layer1.0.bn2 ----------------------- [64, 64, 64]
layer1.0.add ----------------------- [64, 64, 64]
layer1.0.relu_1 -------------------- [64, 64, 64]
layer1.1.conv1 --------------------- [64, 64, 64]
layer1.1.bn1 ----------------------- [64, 64, 64]
layer1.1.relu ---------------------- [64, 64, 64]
layer1.1.conv2 --------------------- [64, 64, 64]
layer1.1.bn2 ----------------------- [64, 64, 64]
layer1.1.add ----------------------- [64, 64, 64]
layer1.1.relu_1 -------------------- [64, 64, 64]
layer2.0.conv1 --------------------- [128, 32, 32]
layer2.0.bn1 ----------------------- [128, 

### Decide which intermediate outputs to keep (`return_nodes`) based on the feature dimensions

In [13]:
# Keep the final ReLU output of each of the four residual stages (layer1..layer4)
# and rename them c2..c5 (FPN-style names, presumably for a detection neck).
return_nodes = {f'layer{stage}.1.relu_1': f'c{stage + 1}' for stage in range(1, 5)}
return_nodes

{'layer1.1.relu_1': 'c2',
 'layer2.1.relu_1': 'c3',
 'layer3.1.relu_1': 'c4',
 'layer4.1.relu_1': 'c5'}

### Rebuild the backbone so that it returns only the selected nodes

In [14]:
# Re-trace the model so that the extractor outputs only the stage features chosen above.
backbone = create_feature_extractor(model, return_nodes=return_nodes)
backbone

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Module(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2):

### Optionally freeze layers: either only the batch-normalization layers or the whole backbone

In [15]:
# Choose what to freeze:
#   FREEZE_BN_ONLY = True  -> freeze only the BatchNorm2d layers
#   FREEZE_BN_ONLY = False -> freeze every layer of the backbone
FREEZE_BN_ONLY = False

if FREEZE_BN_ONLY:
    frozen_modules = [module for module in backbone.modules()
                      if isinstance(module, nn.BatchNorm2d)]
else:
    frozen_modules = [backbone]

for module in frozen_modules:
    # Disable gradient computation for the module's parameters (weights/biases).
    module.requires_grad_(False)
    # BatchNorm running mean/var are buffers, not parameters: they keep being
    # updated by every forward pass in train mode, so also switch to eval mode.
    # NOTE: a later `.train()` call on the backbone (or on an enclosing model)
    # re-enables those updates.
    module.eval()

### Print the model summary

In [16]:
from torchsummary import summary

# The summary takes the per-sample input size, i.e. WITHOUT the batch dimension.
input_data_shape = (3, 256, 256)   # (num_channels, height, width)
# Store the result under a separate name so that the imported `summary`
# function is not shadowed and can still be called in later cells.
backbone_summary = summary(backbone, input_data_shape)

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 64, 128, 128]        (9,408)
├─BatchNorm2d: 1-2                       [-1, 64, 128, 128]        (128)
├─ReLU: 1-3                              [-1, 64, 128, 128]        --
├─MaxPool2d: 1-4                         [-1, 64, 64, 64]          --
├─Module: 1                              []                        --
|    └─Module: 2                         []                        --
|    |    └─Conv2d: 3-1                  [-1, 64, 64, 64]          (36,864)
|    |    └─BatchNorm2d: 3-2             [-1, 64, 64, 64]          (128)
|    |    └─ReLU: 3-3                    [-1, 64, 64, 64]          --
|    |    └─Conv2d: 3-4                  [-1, 64, 64, 64]          (36,864)
|    |    └─BatchNorm2d: 3-5             [-1, 64, 64, 64]          (128)
|    |    └─ReLU: 3-6                    [-1, 64, 64, 64]          --
|    └─Module: 2                         []                