In [5]:
import torch
import torchvision

In [9]:
# ResNet-18 with torchvision's pretrained weights; its child modules are
# inspected and re-used as a feature extractor below.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; kept as-is to match the torchvision version this notebook ran on.
m = torchvision.models.resnet18(
    pretrained=True,
)

In [10]:
# Print the names of the ResNet's top-level child modules, in order.
# The keys of `return_layers` used with IntermediateLayerGetter below
# (e.g. 'layer1'..'layer4') refer to these names.
for name, module in m.named_children():
    print(name)

conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc


## IntermediateLayerGetter

In [2]:
# ResNet stages whose outputs we want; only layer1..layer4 exist on the model.
returned_layers = [1, 2, 3, 4]
assert 0 < min(returned_layers) and max(returned_layers) < 5
# Map each stage's module name to its positional output key:
# 'layer1' -> '0', 'layer2' -> '1', ...
return_layers = {
    f'layer{layer_num}': str(out_idx)
    for out_idx, layer_num in enumerate(returned_layers)
}
return_layers

{'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}

In [11]:
# Wrap the ResNet so that its forward pass returns a dict of intermediate
# outputs. `return_layers` maps child-module name -> output key.
# For example, {'layer1': 'feat1', 'layer3': 'feat2'} would extract only
# layer1 and layer3, under the names `feat1` and `feat2`.
# NOTE(review): `_utils` is a private torchvision module path. Newer torchvision
# also offers `torchvision.models.feature_extraction.create_feature_extractor`;
# confirm its availability before switching.
new_m = torchvision.models._utils.IntermediateLayerGetter(m, return_layers)

# Random input only: the printed shapes are what matter here, not the values.
out = new_m(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])

[('0', torch.Size([1, 64, 56, 56])), ('1', torch.Size([1, 128, 28, 28])), ('2', torch.Size([1, 256, 14, 14])), ('3', torch.Size([1, 512, 7, 7]))]


In [12]:
# Channel width of layer1's output, derived from the ResNet's final `inplanes`
# (512 for resnet18): 512 // 8 == 64.
in_channels_stage2 = m.inplanes // 8
# Each later stage doubles the width (layer1 -> 64, layer2 -> 128, ...), giving
# one input width per requested stage for the FPN.
in_channels_list = [
    in_channels_stage2 * pow(2, stage - 1)
    for stage in returned_layers
]

In [13]:
# Should equal the channel count of layer1's output (key '0' above): 64.
in_channels_stage2

64

In [14]:
# The ResNet's final `inplanes`; the `// 8` above is taken from this value.
# Here it matches layer4's output width (512).
m.inplanes

512

In [15]:
# Per-stage input widths the FPN must accept, one entry per element of returned_layers.
in_channels_list

[64, 128, 256, 512]

## FeaturePyramidNetwork

In [21]:
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool

In [16]:
# FPN that projects each input level (widths from in_channels_list) to a
# common 256 output channels.
fpn = FeaturePyramidNetwork(
    in_channels_list=in_channels_list,
    out_channels=256,
)

In [17]:
# Feed the IntermediateLayerGetter output dict `out` (keys '0'..'3') through the
# freshly constructed FPN (no weights loaded); each level comes back with 256 channels.
fpn_out = fpn(out)

In [19]:
# Spatial sizes are unchanged from the inputs; only the channel count is unified.
print([(level, feat.shape) for level, feat in fpn_out.items()])

[('0', torch.Size([1, 256, 56, 56])), ('1', torch.Size([1, 256, 28, 28])), ('2', torch.Size([1, 256, 14, 14])), ('3', torch.Size([1, 256, 7, 7]))]


In [22]:
# Same FPN configuration, plus a LastLevelMaxPool block that appends an extra,
# coarser 'pool' level (4x4 from the 7x7 level, per the output below).
extra_blocks = LastLevelMaxPool()
fpn_2 = FeaturePyramidNetwork(
    in_channels_list,
    256,
    extra_blocks=extra_blocks,
)

In [23]:
# Same input dict `out` as for `fpn` above; fpn_2 additionally returns a 'pool' level.
fpn_2_out = fpn_2(out)

In [25]:
# Expect the four FPN levels followed by the extra 'pool' level.
print([(level, feat.shape) for level, feat in fpn_2_out.items()])

[('0', torch.Size([1, 256, 56, 56])), ('1', torch.Size([1, 256, 28, 28])), ('2', torch.Size([1, 256, 14, 14])), ('3', torch.Size([1, 256, 7, 7])), ('pool', torch.Size([1, 256, 4, 4]))]


## BackboneWithFPN

In [26]:
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

In [28]:
# Build a ResNet-18 + FPN backbone in one call (IntermediateLayerGetter + FPN +
# LastLevelMaxPool, as assembled by hand above).
# `backbone_name` is passed by keyword. Newer torchvision releases make it
# keyword-only and map the legacy `pretrained=` onto `weights=`. Older releases
# also accept it by keyword, so this call works on both.
# trainable_layers: per torchvision docs, the number of trainable ResNet layers
# counted from the final block (5 = all layers; smaller values such as 3 freeze
# the early stages).
backbone = resnet_fpn_backbone(
    backbone_name='resnet18',
    pretrained=True,
    trainable_layers=5,  # all layers
)

In [32]:
# A 200x200 input (not a multiple of 32) shows how the strided stages round odd
# sizes up: 50 -> 25 -> 13 -> 7 -> 4.
# NOTE(review): this rebinds `out`, which previously held the IntermediateLayerGetter
# output fed to `fpn`/`fpn_2` above. Re-running those FPN cells after this one
# will pass them the 256-channel backbone output and likely fail on a channel mismatch.
x = torch.rand(1, 3, 200, 200)
out = backbone(x)
print([(level, feat.shape) for level, feat in out.items()])

[('0', torch.Size([1, 256, 50, 50])), ('1', torch.Size([1, 256, 25, 25])), ('2', torch.Size([1, 256, 13, 13])), ('3', torch.Size([1, 256, 7, 7])), ('pool', torch.Size([1, 256, 4, 4]))]


In [33]:
# Width (last spatial dim) of every backbone output level, in order.
feature_dims = [feat.shape[-1] for feat in out.values()]
feature_dims

[50, 25, 13, 7, 4]