In [None]:
import random

import tensorflow as tf

# Path of the exported YOLOv5 TensorFlow SavedModel.
YOLO_SAVED_MODEL_DIR = 'yolov5_tf'
input_shape = (640, 640, 3)


class YoloSavedModelLayer(tf.keras.layers.Layer):
    """Wrap a TensorFlow SavedModel so it can be used as a freezable Keras layer.

    `tf.saved_model.load` returns a plain trackable object, not a Keras model. It has
    no `.layers` attribute, so a per-layer freeze loop raises AttributeError, and it
    cannot be called directly on a symbolic `tf.keras.Input`. Inside a Layer's `call`
    it receives real tensors, and `self.trainable` freezes/unfreezes it as one unit.

    Args:
        saved_model_dir: directory passed to `tf.saved_model.load`.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, saved_model_dir, **kwargs):
        super().__init__(**kwargs)
        self._saved_model = tf.saved_model.load(saved_model_dir)
        # Assigning the restored tf.Variables to an attribute registers them as this
        # layer's weights (tf.keras 2.x attribute tracking). `trainable` then decides
        # whether the optimizer may update them.
        # NOTE(review): Keras 3 does not auto-track tf.Variables this way. There,
        # `tf.keras.layers.TFSMLayer` is the supported wrapper; confirm the TF version.
        self._yolo_variables = list(getattr(self._saved_model, 'variables', []))
        if not self._yolo_variables:
            tf.get_logger().warning(
                'SavedModel at %s exposes no variables; the backbone cannot be fine-tuned.',
                saved_model_dir)

    def call(self, images):
        return self._saved_model(images)


base_yolo_model = YoloSavedModelLayer(YOLO_SAVED_MODEL_DIR, name='yolov5_backbone')

inputs = tf.keras.Input(shape=input_shape)
# NOTE(review): assumes the SavedModel was exported with a dynamic batch dimension.
# A fixed batch-size export will not accept a (None, 640, 640, 3) input; confirm.
features = base_yolo_model(inputs)

# The functional API (rather than Sequential) lets us attach a new classification
# head to the output tensor of the wrapped, pretrained backbone.
# NOTE(review): GlobalAveragePooling2D needs a 4-D (batch, H, W, C) feature map. Verify
# what this SavedModel actually returns. A detection export may return box predictions
# (or a tuple) instead; in that case, export a feature-map output or change the pooling.
x = tf.keras.layers.GlobalAveragePooling2D()(features)
x = tf.keras.layers.Dense(256, activation='relu')(x)
x = tf.keras.layers.Dropout(0.5)(x)
# A single sigmoid unit gives P(positive class) for a two-class problem and pairs with
# binary_crossentropy. A softmax head would instead need one unit per class.
output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(x)

model = tf.keras.Model(inputs=inputs, outputs=output_layer)

# Freeze the whole backbone for the first training phase, so only the new head trains.
base_yolo_model.trainable = False


In [None]:
# Phase-1 compile: Adam with default settings, and binary cross-entropy to match the
# single sigmoid output unit (binary classification).
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['accuracy'],
)

# NOTE(review): the head pools YOLO's output into one per-image probability (Dense(1)),
# so this model neither trains nor predicts bounding boxes.

In [None]:
# Build train/validation pipelines of (image, label) pairs from the labelled PNGs.
IMAGE_GLOB = "labelled_images/*.png"
SEED = 42
VAL_FRACTION = 0.2
BATCH_SIZE = 32  # define batch size for training efficiency

# Listing files alone yields bare path strings (no pixels, no labels). Using the same
# listing for both sets would also put every validation image into training.
# Split the files once and deterministically, so no image is in both sets even if this
# cell is re-run.
image_paths = sorted(tf.io.gfile.glob(IMAGE_GLOB))
if len(image_paths) < 2:
    raise FileNotFoundError(
        f"need at least 2 images matching {IMAGE_GLOB!r} for a train/val split, "
        f"found {len(image_paths)}")
random.Random(SEED).shuffle(image_paths)
n_val = max(1, int(len(image_paths) * VAL_FRACTION))
val_paths, train_paths = image_paths[:n_val], image_paths[n_val:]


def load_example(path):
    """Read one PNG and its label.

    Returns:
        image: float32 tensor of shape `input_shape`, values scaled to [0, 1].
        label: float32 tensor of shape (1,), matching the model's single sigmoid output.
    """
    image = tf.io.decode_png(tf.io.read_file(path), channels=3)
    # Plain resize (no letterboxing). NOTE(review): confirm this matches the
    # preprocessing the YOLO export expects.
    image = tf.image.resize(image, input_shape[:2]) / 255.0
    # NOTE(review): ASSUMES YOLO-format sidecar labels: `foo.png` -> `foo.txt`, one
    # `<class_id> cx cy w h` row per box. The image label is the first row's class id
    # (0 or 1), falling back to 0 for an empty file. Replace this if the real labelling
    # scheme differs.
    label_text = tf.io.read_file(tf.strings.regex_replace(path, r"\.png$", ".txt"))
    first_token = tf.strings.split(tf.strings.strip(label_text))[:1]
    class_id = tf.strings.to_number(first_token, out_type=tf.float32)
    label = tf.reduce_max(tf.concat([class_id, [0.0]], axis=0))
    return image, tf.reshape(label, [1])


def make_dataset(paths, shuffle):
    """Turn a list of PNG paths into a batched, prefetched (image, label) dataset."""
    ds = tf.data.Dataset.from_tensor_slices(paths)
    if shuffle:
        # Shuffle the cheap path strings (not decoded images); reshuffled every epoch.
        ds = ds.shuffle(len(paths), seed=SEED)
    ds = ds.map(load_example, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


train = make_dataset(train_paths, shuffle=True)
val = make_dataset(val_paths, shuffle=False)


In [None]:
# Phase 1: train the new classification head. The model-building cell freezes the
# YOLO backbone for this phase.
INITIAL_EPOCHS = 10
history = model.fit(
    train,
    validation_data=val,
    epochs=INITIAL_EPOCHS,
)

# NOTE(review): with no extension, tf.keras 2.x writes the TF SavedModel format here.
# Keras 3 requires a `.keras` suffix instead; confirm the installed version.
model.save("trained_yolov5_tf")


In [None]:
# Phase 2: unfreeze the whole YOLO backbone and fine-tune the entire network.
# `tf.saved_model.load` objects have no `.layers` attribute, so a per-layer loop raises
# AttributeError. Toggle the backbone as one unit instead.
# NOTE(review): this only has an effect when `base_yolo_model` is a Keras layer that
# tracks the YOLO variables. On a raw SavedModel object it is a silent no-op.
base_yolo_model.trainable = True

# Changes to `trainable` only take effect after re-compiling. Use a much lower learning
# rate for fine-tuning so the pretrained weights move only slightly.
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss='binary_crossentropy',
              metrics=['accuracy'])

FINE_TUNE_EPOCHS = 10
# Continue epoch numbering from phase 1 so the two histories and logs line up. This
# still trains for FINE_TUNE_EPOCHS additional epochs.
phase1_epochs = len(history.epoch)
history2 = model.fit(train,
                     epochs=phase1_epochs + FINE_TUNE_EPOCHS,
                     initial_epoch=phase1_epochs,
                     validation_data=val)

model.save("fine_tuned_yolov5_tf")
