<a href="https://colab.research.google.com/github/addy1997/Adaptive-Feature-Pyramid-Network/blob/main/Adaptive_FPN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jan  3 18:07:03 2022

@author: adwaitnaik
"""

'''
Feature Pyramid Networks implementation

# References:
- [Scale Adaptive Feature Pyramid Networks for 2D Object Detection](
    https://doi.org/10.1155/2020/8839979)
- [Feature Pyramid Networks for Object Detection](
    https://arxiv.org/abs/1612.03144)
'''

import tensorflow as tf
from tensorflow.keras import layers

class FPN(tf.keras.Model):
    """Feature Pyramid Network neck that turns backbone maps C2-C5 into P2-P5.

    The model works in three steps:

    1. Lateral connections: each backbone map is projected to
       ``n_out_channels`` by a 1x1 convolution.
    2. Top-down pathway: each coarser merged map is upsampled 2x and added
       to the lateral projection of the next finer map.
    3. Smoothing: every merged map is passed through a 3x3 'SAME'
       convolution.

    Args:
        n_out_channels: Number of channels of every output map P2-P5.
        verbose: Accepted for backward compatibility; it is not used.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, n_out_channels=256, verbose=None, **kwargs):
        super().__init__(**kwargs)
        self.n_out_channels = n_out_channels

        # Stride-2 3x3 convolution labelled "top layer".
        # NOTE(review): it is created here but never used by call(). It is
        # kept so code that references `conv_1` keeps working.
        self.conv_1 = layers.Conv2D(
            n_out_channels, kernel_size=(3, 3), strides=2,
            kernel_initializer='he_normal', name='fpn_c1')

        # Lateral connections: 1x1 convolutions projecting C2-C5 to
        # n_out_channels so they can be summed with the top-down maps.
        self.fpn_conv2p2 = layers.Conv2D(
            n_out_channels, kernel_size=(1, 1),
            kernel_initializer='he_normal', name='fpn_conv2p2')
        self.fpn_conv3p3 = layers.Conv2D(
            n_out_channels, kernel_size=(1, 1),
            kernel_initializer='he_normal', name='fpn_conv3p3')
        self.fpn_conv4p4 = layers.Conv2D(
            n_out_channels, kernel_size=(1, 1),
            kernel_initializer='he_normal', name='fpn_conv4p4')
        self.fpn_conv5p5 = layers.Conv2D(
            n_out_channels, kernel_size=(1, 1),
            kernel_initializer='he_normal', name='fpn_conv5p5')

        # Top-down pathway: 2x upsampling of P5, P4 and P3, in that order.
        self.fpn_upsampled_p5 = layers.UpSampling2D(size=(2, 2), name='fpn_p5upsampled')
        self.fpn_upsampled_p4 = layers.UpSampling2D(size=(2, 2), name='fpn_p4upsampled')
        self.fpn_upsampled_p3 = layers.UpSampling2D(size=(2, 2), name='fpn_p3upsampled')

        # Smoothing: 3x3 'SAME' convolutions applied to each merged map.
        # No extra max-pooled pyramid level is produced.
        self.fpn_p5 = layers.Conv2D(n_out_channels, (3, 3), padding='SAME',
                                    kernel_initializer='he_normal', name='fpn_p5')
        self.fpn_p4 = layers.Conv2D(n_out_channels, (3, 3), padding='SAME',
                                    kernel_initializer='he_normal', name='fpn_p4')
        self.fpn_p3 = layers.Conv2D(n_out_channels, (3, 3), padding='SAME',
                                    kernel_initializer='he_normal', name='fpn_p3')
        self.fpn_p2 = layers.Conv2D(n_out_channels, (3, 3), padding='SAME',
                                    kernel_initializer='he_normal', name='fpn_p2')

    @staticmethod
    def _top_down_merge(lateral, upsampled):
        """Add an upsampled coarser map to a lateral map, cropping if needed.

        When a finer level has an odd spatial size (e.g. 25 -> 13 with a
        'same'-padded stride-2 backbone), the 2x-upsampled coarser map is one
        pixel larger (26) than its lateral partner. Without cropping, the
        addition would fail.

        Every axis of ``upsampled`` is trimmed to the static size of the
        matching ``lateral`` axis. Axes of unknown size (None) are left as-is.
        When the shapes already match, the crop is a no-op.
        """
        crop = tuple(slice(None, dim) for dim in lateral.shape)
        return lateral + upsampled[crop]

    def call(self, inputs, training=None):
        """Build the feature pyramid.

        Args:
            inputs: Sequence ``(C2, C3, C4, C5)`` of backbone feature maps,
                finest first. NOTE(review): each level is assumed to have
                about half the spatial size of the previous one, matching the
                2x upsampling factors. Confirm this against the backbone.
            training: Unused; present for Keras call-signature compatibility.

        Returns:
            List ``[P2, P3, P4, P5]``. Each map has ``n_out_channels``
            channels and the spatial size of its lateral projection of
            ``C2``..``C5``.
        """
        C2, C3, C4, C5 = inputs

        # Top-down pathway: start from the coarsest lateral projection and
        # repeatedly add the upsampled coarser map to the next finer one.
        P5 = self.fpn_conv5p5(C5)
        P4 = self._top_down_merge(self.fpn_conv4p4(C4), self.fpn_upsampled_p5(P5))
        P3 = self._top_down_merge(self.fpn_conv3p3(C3), self.fpn_upsampled_p4(P4))
        P2 = self._top_down_merge(self.fpn_conv2p2(C2), self.fpn_upsampled_p3(P3))

        # Smooth every merged map with its 3x3 convolution.
        P2 = self.fpn_p2(P2)
        P3 = self.fpn_p3(P3)
        P4 = self.fpn_p4(P4)
        P5 = self.fpn_p5(P5)

        return [P2, P3, P4, P5]

    def compute_output_tensor_shape(self, input_tensor_shape):
        """Return the static output shapes for the given C2-C5 input shapes.

        NOTE: this is not the Keras ``compute_output_shape`` hook, so Keras
        does not call it automatically.

        Args:
            input_tensor_shape: Sequence of exactly four shapes (C2..C5). Each
                may be a ``tf.TensorShape``, tuple or list. The last axis is
                treated as the channel axis (channels-last layout).

        Returns:
            List of four ``tf.TensorShape`` objects. Each equals its input
            shape with the last axis replaced by ``n_out_channels``.

        Raises:
            ValueError: if the number of shapes is not four.
        """
        # Normalizing through tf.TensorShape lets plain tuples and lists work
        # as well as TensorShape objects.
        shapes = [tf.TensorShape(shape).as_list() for shape in input_tensor_shape]
        if len(shapes) != 4:
            raise ValueError(
                'Expected 4 input shapes (C2, C3, C4, C5), got %d' % len(shapes))
        for shape in shapes:
            shape[-1] = self.n_out_channels
        return [tf.TensorShape(shape) for shape in shapes]

if __name__ == '__main__':
    # Synthetic backbone outputs for a batch of 2. Each level halves the
    # spatial size and doubles the channel count of the level before it.
    batch_size = 2
    backbone_maps = [
        tf.random.normal((batch_size, 256, 256, 256)),   # C2
        tf.random.normal((batch_size, 128, 128, 512)),   # C3
        tf.random.normal((batch_size, 64, 64, 1024)),    # C4
        tf.random.normal((batch_size, 32, 32, 2048)),    # C5
    ]

    model = FPN()
    P2, P3, P4, P5 = model(backbone_maps)

    # Report the pyramid from the coarsest level to the finest.
    for label, feature_map in (('P5', P5), ('P4', P4), ('P3', P3), ('P2', P2)):
        print(f'{label} shape:{feature_map.shape.as_list()}')