In [31]:
# https://herbwood.tistory.com/11?category=867198

In [32]:
import torch
import torch.nn as nn
import torchvision

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import time
import os
from typing import Tuple, List, Dict, Optional
from torch import nn, Tensor

from torch.utils.data import Dataset, DataLoader
# from torch.utils.data.sampler import Sampler
import torch.optim as optim
import sys
sys.path.append('../')

In [33]:
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool

In [34]:
# Pretrained ResNet-18 used as the detection backbone.
# NOTE(review): `pretrained=` is deprecated in newer torchvision releases in favour
# of `weights=`; confirm against the installed version.
resnet18 = torchvision.models.resnet18(pretrained=True)
# After construction `inplanes` holds the channel width of the last stage (512, see output).
resnet18.inplanes

512

In [35]:
body = IntermediateLayerGetter(resnet18, return_layers={'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'})

In [36]:
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:4]
returned_layers = [1, 2, 3, 4]
return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
print(return_layers)

in_channels_stage2 = resnet18.inplanes // 8
in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
in_channels_list

{'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}


[64, 128, 256, 512]

In [37]:
input = torch.randn(2,3,800,800)

In [38]:
out = body(input)

In [39]:
out['3'].shape

torch.Size([2, 512, 25, 25])

In [40]:
# out = resnet18(input)
# out.shape

In [41]:
# (resnet18)

In [42]:
# from resnet import resnet18
# resnet = resnet18()
# out = resnet(input)


In [43]:
out_channels = 256
fpn = FeaturePyramidNetwork(
    in_channels_list=in_channels_list,
    out_channels=out_channels,
    extra_blocks=LastLevelMaxPool(),
)

In [44]:
input = torch.randn(2,3,800,800)
out = body(input)
print(out.keys())
feat = fpn(out)
print(feat.keys())

odict_keys(['0', '1', '2', '3'])
odict_keys(['0', '1', '2', '3', 'pool'])


In [45]:
# Extra level added by LastLevelMaxPool (13x13), presumably pooled from the coarsest level '3'.
feat['pool'].shape

torch.Size([2, 256, 13, 13])

In [46]:
# Five pyramid levels: the four backbone stages plus 'pool'.
feat.keys()

odict_keys(['0', '1', '2', '3', 'pool'])

In [47]:
# Module repr: one 1x1 lateral ("inner") and one 3x3 output ("layer") conv per level.
fpn

FeaturePyramidNetwork(
  (inner_blocks): ModuleList(
    (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
    (2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (layer_blocks): ModuleList(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (extra_blocks): LastLevelMaxPool()
)

In [48]:
class fpn_network(nn.Module):
    """Simplified feature-pyramid neck: per-level 1x1 lateral conv, then a 3x3 conv.

    Each input level is projected independently. Unlike torchvision's
    ``FeaturePyramidNetwork`` this module does NOT implement the top-down
    pathway (no upsample-and-add between levels) and adds no extra 'pool' level.

    Args:
        in_channels_list: channel count of each input level. Level key ``'i'``
            is routed through ``inner_blocks[i]`` / ``layer_blocks[i]``.
            Defaults to ``[64, 128, 256, 512]`` (the ResNet-18 layer1..layer4
            widths computed earlier in this notebook). Previously this was read
            from the global ``in_channels_list``.
        out_channels: channel count of every output level (default 256).
        verbose: if True, print per-level debug information in ``forward``.
    """

    def __init__(
        self,
        in_channels_list: Optional[List[int]] = None,
        out_channels: int = 256,
        verbose: bool = False,
    ) -> None:
        super().__init__()
        if in_channels_list is None:
            in_channels_list = [64, 128, 256, 512]
        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        self.out_channel = out_channels
        self.verbose = verbose

        for in_channels in in_channels_list:
            # 1x1 lateral projection to a common width, then 3x3 smoothing conv.
            self.inner_blocks.append(nn.Conv2d(in_channels, self.out_channel, 1))
            self.layer_blocks.append(
                nn.Conv2d(self.out_channel, self.out_channel, 3, padding=1)
            )

        # Same initialisation scheme torchvision's FPN uses for its convs.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight, a=1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Project every level of ``x`` to ``out_channels`` channels.

        Args:
            x: feature maps keyed by integer strings ('0', '1', ...), e.g. the
               output of ``IntermediateLayerGetter``.

        Returns:
            Dict with the same keys (in the same order) mapping to the
            projected feature maps; spatial sizes are unchanged.

        Raises:
            ValueError: if a key is not an integer string (e.g. 'pool').
        """
        if self.verbose:
            print(list(x.keys()))
        out: Dict[str, Tensor] = {}
        for name, feature in x.items():
            try:
                idx = int(name)
            except ValueError:
                raise ValueError(
                    f"fpn_network expects integer-string level keys ('0', '1', ...), got {name!r}"
                ) from None
            if self.verbose:
                print(feature.shape)
                print(self.inner_blocks[idx])
                print(self.layer_blocks[idx])
            out[name] = self.layer_blocks[idx](self.inner_blocks[idx](feature))
        return out
        

In [49]:
# Instantiate the hand-written FPN.
# NOTE(review): this rebinds `fpn`, replacing the torchvision FeaturePyramidNetwork
# from In [43]; re-running In [44] after this cell would silently use this module.
# The cells below still inspect `feat` produced by the torchvision FPN in In [44].
fpn = fpn_network()

In [50]:
# input = torch.randn(2,3,800,800)
# out = body(input)
# print(out.keys())
# feat = fpn(out)
# print(feat.keys())

In [51]:
# Finest FPN level (from the torchvision FPN in In [44]): 800 / 200 = stride 4.
feat['0'].shape

torch.Size([2, 256, 200, 200])

In [52]:
from torchvision.models.detection.anchor_utils import AnchorGenerator

# One anchor size per pyramid level (5 levels: '0'..'3' plus 'pool'), each
# combined with three aspect ratios -> 3 anchors per spatial location.
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

In [53]:
# Base anchors per level, centred on the origin; columns appear to be (x1, y1, x2, y2).
rpn_anchor_generator.cell_anchors

[tensor([[-23., -11.,  23.,  11.],
         [-16., -16.,  16.,  16.],
         [-11., -23.,  11.,  23.]]),
 tensor([[-45., -23.,  45.,  23.],
         [-32., -32.,  32.,  32.],
         [-23., -45.,  23.,  45.]]),
 tensor([[-91., -45.,  91.,  45.],
         [-64., -64.,  64.,  64.],
         [-45., -91.,  45.,  91.]]),
 tensor([[-181.,  -91.,  181.,   91.],
         [-128., -128.,  128.,  128.],
         [ -91., -181.,   91.,  181.]]),
 tensor([[-362., -181.,  362.,  181.],
         [-256., -256.,  256.,  256.],
         [-181., -362.,  181.,  362.]])]

In [54]:
# Anchors per spatial location for each of the 5 levels.
rpn_anchor_generator.num_anchors_per_location()

[3, 3, 3, 3, 3]

In [55]:
# RPNHead below takes a single anchor count shared by all levels, so level 0's is used.
rpn_anchor_generator.num_anchors_per_location()[0]

3

In [56]:
from torchvision.models.detection.rpn import RPNHead

In [57]:
# One RPN head shared across all pyramid levels; its input width must equal the
# FPN's out_channels (256).
out_channels = 256
rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

In [60]:
# Level '2': 800 / 50 = stride 16.
feat['2'].shape

torch.Size([2, 256, 50, 50])

In [62]:
feat = list(feat.values())

In [66]:
objectness, pred_bbox_deltas = rpn_head(feat)

In [72]:
# Objectness scores for level '1': one channel per anchor (3) at each of 100x100 locations.
objectness[1].shape

torch.Size([2, 3, 100, 100])

In [73]:
# Box deltas for level '1': 12 channels = 3 anchors x 4 values.
pred_bbox_deltas[1].shape

torch.Size([2, 12, 100, 100])