In [28]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import UpSampling2D, Conv2DTranspose, AveragePooling2D, ZeroPadding2D, Add, Conv2D
from pprint import pprint

from models.blocks import create_res_conv_block, create_input_layer, conv_block
from models.rangeview_branch import create_range_view_branch

In [35]:
# Backbone: five residual stages over the BEV grid. Stages 2-5 halve the
# spatial resolution (recorded shapes of f2..f5: 400, 200, 100, 50 for the
# 800x800 input); stage 1 keeps the input resolution.
def _res_stage(features, filters, n_blocks, downsample, kernel_size=3):
    """Stack `n_blocks` calls to `create_res_conv_block` with width `filters`.

    If `downsample` is true, the first block is called with a True flag as the
    fourth positional argument and every later block with False. Otherwise
    the flag is omitted and `create_res_conv_block` uses its own default.
    Returns the output tensor of the last block.
    """
    for block_idx in range(n_blocks):
        if downsample:
            features = create_res_conv_block(features, filters, kernel_size, block_idx == 0)
        else:
            features = create_res_conv_block(features, filters, kernel_size)
    return features


bev_in = create_input_layer((800, 800, 10), 'bev')

f1 = _res_stage(bev_in, 24, 2, downsample=False)  # same res
f2 = _res_stage(f1, 48, 4, downsample=True)       # 1 / 2 res
f3 = _res_stage(f2, 96, 4, downsample=True)       # 1 / 4 res
f4 = _res_stage(f3, 192, 4, downsample=True)      # 1 / 8 res
f5 = _res_stage(f4, 256, 4, downsample=True)      # 1 / 16 res

# Keep `x` bound to the deepest feature map, as the original loop left it.
x = f5

In [36]:
# Range-view branch. The helper returns a dict whose 'inputs' entry is
# presumably the list of Keras input tensors (the commented-out Model call at
# the end of the notebook concatenates it with [bev_in]) and whose 'outputs'
# entry holds four feature maps, unpacked below as l2..l5.
RV_INPUT_NAMES = [
    'rgb_img_input',
    'depth_map_input',
    'intensity_map_input',
    'height_map_input',
]

rv_net = create_range_view_branch(
    input_shape=(375, 1242, 3),
    input_names=RV_INPUT_NAMES,
)
rv_inputs, rv_outputs = rv_net['inputs'], rv_net['outputs']

l2, l3, l4, l5 = rv_outputs

In [37]:
# Display the range-view feature tensors. In the recorded output l3 and l4
# have identical shape (94, 311, 48) and l5 has 92 channels.
# NOTE(review): confirm these strides/widths are what create_range_view_branch
# is meant to produce.
(l2, l3, l4, l5)

(<tf.Tensor 'average_13/Identity:0' shape=(None, 188, 621, 24) dtype=float32>,
 <tf.Tensor 'average_15/Identity:0' shape=(None, 94, 311, 48) dtype=float32>,
 <tf.Tensor 'average_17/Identity:0' shape=(None, 94, 311, 48) dtype=float32>,
 <tf.Tensor 'average_19/Identity:0' shape=(None, 47, 156, 92) dtype=float32>)

In [38]:
# FPN + detection heads.
# Every pyramid level is brought onto f3's grid (1/4 of the BEV input, 200x200
# in the recorded output) and summed element-wise, so all four levels must
# agree on height, width and channel count.
# The fusion width is read from f3 (96 in the recorded output) instead of being
# hard-coded. Otherwise, changing the backbone's third-stage width would
# desynchronise the 1x1 projections from the identity branch and break Add().
fpn_channels = int(f3.shape[-1])

# NOTE: assumes conv_block(x, filters, kernel_size, stride) — TODO confirm.
# f2 (1/2 res): 1x1 projection, then AveragePooling2D (default 2x2 pool,
# stride 2) down to 1/4 res.
f2_dw = conv_block(f2, fpn_channels, 1, 1)
f2_dw = AveragePooling2D()(f2_dw)

# f3 (1/4 res) already matches the target grid and width, so it is used as-is.
# Despite the "_zp" suffix, no ZeroPadding2D is applied.
f3_zp = f3

# f4 (1/8 res) and f5 (1/16 res): 1x1 projection, then bilinear upsampling
# back to 1/4 res.
f4_up = conv_block(f4, fpn_channels, 1, 1)
f4_up = UpSampling2D(size=(2, 2), interpolation='bilinear')(f4_up)

f5_up = conv_block(f5, fpn_channels, 1, 1)
f5_up = UpSampling2D(size=(4, 4), interpolation='bilinear')(f5_up)

# Fail early with a readable message if a level does not line up with f3
# (e.g. after changing the input size or the stage strides).
fpn_levels = [f2_dw, f3_zp, f4_up, f5_up]
fpn_target = f3_zp.shape[1:].as_list()
for level_idx, level in enumerate(fpn_levels):
    assert level.shape[1:].as_list() == fpn_target, (
        f'FPN level {level_idx} has shape {level.shape}; '
        f'expected non-batch dims {fpn_target}')

# TODO(review): the range-view features l2..l5 are not fused here, so both
# heads below depend on bev_in only.
out = Add()(fpn_levels)

# Single-channel head with sigmoid activation (values in (0, 1)); presumably
# a per-cell objectness score.
obj_map = Conv2D(filters=1,
                 kernel_size=1,
                 padding='same',
                 activation='sigmoid',
                 name='obj_map',
                 kernel_initializer='glorot_normal')(out)
# 11-channel linear head; presumably box geometry regression.
# TODO: document the channel layout.
geo_map = Conv2D(filters=11,
                 kernel_size=1,
                 padding='same',
                 activation=None,
                 name='geo_map',
                 kernel_initializer='glorot_normal')(out)

# Shape report: backbone levels, aligned FPN levels, fused map, heads.
for tensor_group in ([f2, f3, f4, f5],
                     [f2_dw, f3_zp, f4_up, f5_up],
                     [out],
                     [obj_map, geo_map]):
    pprint(tensor_group)
    print('-----------------')
print('-----------------')

[<tf.Tensor 'add_149/Identity:0' shape=(None, 400, 400, 48) dtype=float32>,
 <tf.Tensor 'add_153/Identity:0' shape=(None, 200, 200, 96) dtype=float32>,
 <tf.Tensor 'add_157/Identity:0' shape=(None, 100, 100, 192) dtype=float32>,
 <tf.Tensor 'add_161/Identity:0' shape=(None, 50, 50, 256) dtype=float32>]
-----------------
[<tf.Tensor 'average_pooling2d_6/Identity:0' shape=(None, 200, 200, 96) dtype=float32>,
 <tf.Tensor 'add_153/Identity:0' shape=(None, 200, 200, 96) dtype=float32>,
 <tf.Tensor 'up_sampling2d_8/Identity:0' shape=(None, 200, 200, 96) dtype=float32>,
 <tf.Tensor 'up_sampling2d_9/Identity:0' shape=(None, 200, 200, 96) dtype=float32>]
-----------------
[<tf.Tensor 'add_186/Identity:0' shape=(None, 200, 200, 96) dtype=float32>]
-----------------
[<tf.Tensor 'obj_map_6/Identity:0' shape=(None, 200, 200, 1) dtype=float32>,
 <tf.Tensor 'geo_map_6/Identity:0' shape=(None, 200, 200, 11) dtype=float32>]
-----------------


In [40]:
# NOTE(review): Model.summary() prints the summary and returns None, so the
# chained form on the last line would bind `model` to None. Build the model
# first, then summarise:
#   model = Model([bev_in] + rv_inputs, [obj_map, geo_map])
#   model.summary()
# NOTE(review): rv_inputs do not currently feed obj_map/geo_map (the FPN uses
# only the BEV backbone features), so those inputs would be unused.
# model = Model([bev_in] + rv_inputs, [obj_map, geo_map]).summary()