# In [1]:
import tensorflow as tf
from tensorflow import keras
from keras.layers import *
from keras.models import *
from keras.activations import *
import numpy as np

# In [2]:
class region_proposal_network():
    """Region proposal network head on channels-first feature maps.

    Three convolution blocks, each starting with a stride-2 convolution,
    downsample the input by 2, 4 and 8. Each block's output is brought back
    to 1/2 of the input resolution by a transposed convolution (strides 1, 2
    and 4). The three upsampled maps are concatenated along the channel axis,
    and two 1x1 convolutions produce a probability map and a regression map.

    NOTE(review): this class does not subclass ``keras.Model``/``Layer``, so
    ``call`` must be invoked explicitly (``rpn.call(x)``), and Keras does not
    collect the sub-layers' variables on this object. Consider subclassing
    ``Model`` if that is intended.
    """

    def __init__(self, num_anchors_per_cell):
        """Build all sub-layers (they are created unbuilt; Keras builds them on first use).

        Args:
            num_anchors_per_cell: number of anchors per output cell. The
                probability map gets this many channels and the regression
                map gets ``7 * num_anchors_per_cell`` channels.
        """
        super(region_proposal_network, self).__init__()
        self.num_anchors_per_cell = num_anchors_per_cell
        bn = self.batch_norm

        # block 1: 4 convs, 128 channels, first conv downsamples by 2
        self.conv1_block1, self.bn1_block1 = self.conv(128, (3,3),(2,2)), bn()
        self.conv2_block1, self.bn2_block1 = self.conv(128, (3,3),(1,1)), bn()
        self.conv3_block1, self.bn3_block1 = self.conv(128, (3,3),(1,1)), bn()
        self.conv4_block1, self.bn4_block1 = self.conv(128, (3,3),(1,1)), bn()

        # block 2: 6 convs, 128 channels, first conv downsamples by 2
        self.conv1_block2, self.bn1_block2 = self.conv(128, (3,3),(2,2)), bn()
        self.conv2_block2, self.bn2_block2 = self.conv(128, (3,3),(1,1)), bn()
        self.conv3_block2, self.bn3_block2 = self.conv(128, (3,3),(1,1)), bn()
        self.conv4_block2, self.bn4_block2 = self.conv(128, (3,3),(1,1)), bn()
        self.conv5_block2, self.bn5_block2 = self.conv(128, (3,3),(1,1)), bn()
        self.conv6_block2, self.bn6_block2 = self.conv(128, (3,3),(1,1)), bn()

        # block 3: 6 convs, 256 channels, first conv downsamples by 2
        self.conv1_block3, self.bn1_block3 = self.conv(256, (3,3),(2,2)), bn()
        self.conv2_block3, self.bn2_block3 = self.conv(256, (3,3),(1,1)), bn()
        self.conv3_block3, self.bn3_block3 = self.conv(256, (3,3),(1,1)), bn()
        self.conv4_block3, self.bn4_block3 = self.conv(256, (3,3),(1,1)), bn()
        self.conv5_block3, self.bn5_block3 = self.conv(256, (3,3),(1,1)), bn()
        self.conv6_block3, self.bn6_block3 = self.conv(256, (3,3),(1,1)), bn()

        # deconvolutions: strides 1/2/4 bring blocks 1/2/3 (at 1/2, 1/4, 1/8
        # of the input resolution) all back to 1/2 resolution so they can be
        # concatenated.
        self.deconv_1, self.deconv_bn1 = self.deconv(256, (3,3), (1,1)), bn()
        self.deconv_2, self.deconv_bn2 = self.deconv(256, (2,2), (2,2)), bn()
        self.deconv_3, self.deconv_bn3 = self.deconv(256, (4,4), (4,4)), bn()

        # probability and regression maps (1x1 convolutions)
        self.prob_map_conv = self.conv(self.num_anchors_per_cell,(1,1),(1,1))
        self.reg_map_conv = self.conv(7*self.num_anchors_per_cell, (1,1),(1,1))

    def batch_norm(self):
        """Return a trainable BatchNormalization layer over the channel axis.

        The network uses ``data_format="channels_first"``, so features live on
        axis 1. Keras' default ``axis=-1`` would normalize over the width
        axis instead.
        """
        return BatchNormalization(axis=1, trainable=True)

    def conv(self, out_chan, kernel_size, stride_size):
        """Return an unapplied SAME-padded, channels-first Conv2D layer."""
        conv_output = Conv2D(out_chan, kernel_size, stride_size, padding='SAME', data_format="channels_first")
        return conv_output

    def deconv(self, out_chan, kernel_size, stride_size):
        """Return an unapplied SAME-padded, channels-first Conv2DTranspose layer."""
        deconv_output = Conv2DTranspose(out_chan, kernel_size, stride_size, padding='SAME', data_format="channels_first")
        return deconv_output

    def conv_block(self, block_id, input):
        """Apply every ``conv{i}_block{block_id}`` -> ``bn{i}_block{block_id}`` -> ReLU in order.

        Layers are looked up by name starting at ``i = 1``. The loop stops at
        the first index for which either attribute is missing.

        Args:
            block_id: which block (1, 2 or 3) to run.
            input: tensor fed to the first convolution of the block.

        Returns:
            The output of the last conv/bn/ReLU stage, or ``input`` unchanged
            if the block has no layers.
        """
        i = 1
        out = input
        while True:
            try:
                c = getattr(self, "conv{}_block{}".format(i, block_id))
                b = getattr(self, "bn{}_block{}".format(i, block_id))
            except AttributeError:
                # No more layers in this block; any other error propagates.
                break

            out = tf.nn.relu(b(c(out)))
            i = i + 1

        return out

    def deconv_block(self, i, input):
        """Apply ``deconv_{i}`` -> ``deconv_bn{i}`` -> ReLU to ``input``."""
        c = getattr(self, "deconv_{}".format(i))
        b = getattr(self, "deconv_bn{}".format(i))
        return tf.nn.relu(b(c(input)))

    def call(self, input):
        """Run the RPN on a channels-first feature map.

        Args:
            input: rank-4 tensor ``[batch, channels, height, width]``. Height
                and width must be multiples of 8 when they are statically
                known (checked by assertion).

        Returns:
            ``(prob_map, reg_map)``, both channels-last at half the input
            resolution:

            * ``prob_map``: ``[batch, height/2, width/2, num_anchors_per_cell]``,
              passed through a sigmoid.
            * ``reg_map``: ``[batch, height/2, width/2, 7 * num_anchors_per_cell]``,
              with no activation. NOTE(review): presumably 7 box parameters per
              anchor; confirm against the loss/decoder.
        """
        input_shape = input.shape
        # Unknown (None) spatial dims cannot be checked statically; skip them
        # instead of failing with TypeError on `None % 8`.
        assert len(input_shape)==4 and all(
            dim is None or dim % 8 == 0 for dim in (input_shape[-2], input_shape[-1])
        ), "The input must be of shape [Batch_size, channels, map_height, map_width] with map_height and map_width multiple of 8, got {}".format(input_shape)

        # Each block's output feeds both the next block and its own
        # deconvolution; all three deconvolutions end at 1/2 input resolution.
        output_block1 = self.conv_block(1, input)
        output_deconv1 = self.deconv_block(1, output_block1)

        output_block2 = self.conv_block(2, output_block1)
        output_deconv2 = self.deconv_block(2, output_block2)

        output_block3 = self.conv_block(3, output_block2)
        output_deconv3 = self.deconv_block(3, output_block3)

        # Concatenate along the channel axis (axis 1 in channels_first).
        output = tf.concat([output_deconv3, output_deconv2, output_deconv1], axis=1)

        prob_map = self.prob_map_conv(output)
        reg_map = self.reg_map_conv(output)
        # channels_first -> channels_last: [B, C, H, W] -> [B, H, W, C]
        prob_map = tf.transpose(prob_map, (0,2,3,1))
        reg_map = tf.transpose(reg_map, (0,2,3,1))

        prob_map = tf.nn.sigmoid(prob_map)

        return prob_map, reg_map