In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Model

In [3]:
# Stand-ins for the three backbone feature maps (C3, C4, C5), shaped to mimic
# Darknet outputs. Going one level deeper halves the grid side and doubles the
# channel count.
C3, C4, C5 = (
    layers.Input(shape=(grid, grid, channels), name=level)
    for level, grid, channels in (
        ('C3', 52, 256),
        ('C4', 26, 512),
        ('C5', 13, 1024),
    )
)

In [4]:
# FPN top-down pathway (YOLOv3-style) starts at the coarsest level.
# A 1x1 convolution over C5 keeps the 13x13 grid and reduces 1024 channels to 512.
P5 = layers.Conv2D(
    filters=512,
    kernel_size=(1, 1),
    padding='same',
    activation='relu',
)(C5)

In [9]:
# Sanity check: show the symbolic P5 tensor so its shape can be inspected
# (P5 is the 512-filter 1x1 convolution applied to the 13x13 C5 input).
print(P5)

<KerasTensor shape=(None, 13, 13, 512), dtype=float32, sparse=False, ragged=False, name=keras_tensor_12>


In [10]:
# Build the 26x26 level:
#   1. upsample P5 by 2x so it matches C4's spatial size,
#   2. project C4 with a 1x1 convolution (lateral branch),
#   3. stack both branches along the channel axis (512 + 512 = 1024),
#   4. fuse the stack down to 256 channels with a 3x3 convolution.
P5_upsampled = layers.UpSampling2D(size=(2, 2))(P5)
c4_lateral = layers.Conv2D(
    filters=512, kernel_size=(1, 1), padding='same', activation='relu'
)(C4)
p4_merged = layers.Concatenate()([c4_lateral, P5_upsampled])
P4 = layers.Conv2D(
    filters=256, kernel_size=(3, 3), padding='same', activation='relu'
)(p4_merged)

# Show every intermediate tensor, one per line and in creation order, so the
# shapes can be checked stage by stage.
print(P5_upsampled, c4_lateral, p4_merged, P4, sep='\n')

<KerasTensor shape=(None, 26, 26, 512), dtype=float32, sparse=False, ragged=False, name=keras_tensor_28>
<KerasTensor shape=(None, 26, 26, 512), dtype=float32, sparse=False, ragged=False, name=keras_tensor_29>
<KerasTensor shape=(None, 26, 26, 1024), dtype=float32, sparse=False, ragged=False, name=keras_tensor_30>
<KerasTensor shape=(None, 26, 26, 256), dtype=float32, sparse=False, ragged=False, name=keras_tensor_31>


In [6]:
# Build the 52x52 level with the same recipe used for P4:
# upsample P4 by 2x, project C3 with a 1x1 convolution, concatenate the two
# branches along channels (256 + 256 = 512), then fuse to 128 channels with a
# 3x3 convolution.
P4_upsampled = layers.UpSampling2D(size=(2, 2))(P4)
c3_lateral = layers.Conv2D(
    filters=256, kernel_size=(1, 1), padding='same', activation='relu'
)(C3)
p3_merged = layers.Concatenate()([c3_lateral, P4_upsampled])
P3 = layers.Conv2D(
    filters=128, kernel_size=(3, 3), padding='same', activation='relu'
)(p3_merged)

In [7]:
# Detection heads (simulate YOLOv3 outputs)
# Following the YOLOv3 output layout, every grid cell predicts NUM_ANCHORS
# boxes and each box carries BOX_ATTRS values (4 box coordinates + 1 objectness
# score) plus NUM_CLASSES class scores.
NUM_ANCHORS = 3
BOX_ATTRS = 5
NUM_CLASSES = 80
HEAD_FILTERS = NUM_ANCHORS * (BOX_ATTRS + NUM_CLASSES)  # 3 * (5 + 80) = 255


def _detection_head(features, name):
    """Map a feature map to raw per-anchor predictions with a 1x1 convolution.

    Args:
        features: Symbolic Keras tensor for one pyramid level.
        name: Name given to the convolution layer.

    Returns:
        A tensor with the same spatial size as `features` and HEAD_FILTERS
        channels. No activation is passed to the layer, so the outputs are
        raw (unactivated) values.
    """
    return layers.Conv2D(HEAD_FILTERS, (1, 1), padding='same', name=name)(features)


# One head per pyramid level; the finest grid is the one aimed at small objects.
det_small = _detection_head(P3, 'detect_small')    # 52x52
det_medium = _detection_head(P4, 'detect_medium')  # 26x26
det_large = _detection_head(P5, 'detect_large')    # 13x13

In [2]:
# Wrap the graph in a functional Keras Model: the three simulated backbone
# maps go in, one raw prediction tensor per scale comes out (finest grid first).
backbone_maps = [C3, C4, C5]
scale_predictions = [det_small, det_medium, det_large]
fpn_yolo_v3 = Model(
    inputs=backbone_maps,
    outputs=scale_predictions,
    name="YOLOv3_FPN_Simulator",
)

# Print the layer-by-layer table (output shapes, parameter counts, connections).
fpn_yolo_v3.summary()
