In [None]:
import tensorflow as tf
import time

from configuration import IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS, EPOCHS, NUM_CLASSES, BATCH_SIZE, save_model_dir, \
    load_weights_before_training, load_weights_from_epoch, save_frequency, test_images_during_training, \
    test_images_dir_list
from core.ground_truth import ReadDataset, MakeGT
from core.loss import SSDLoss
from core.make_dataset import TFDataset
from core.ssd import SSD, ssd_prediction
from utils.visualize import visualize_training_results


def print_model_summary(network):
    """Build ``network`` for the configured input size and print its summary.

    The model is built with an input shape of
    ``(None, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS)`` (batch dimension left
    unspecified) and then ``network.summary()`` is called.

    Args:
        network: Object exposing ``build(input_shape=...)`` and ``summary()``.
    """
    input_shape = (None, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS)
    network.build(input_shape=input_shape)
    network.summary()


if __name__ == '__main__':
    # Training entry point: builds the dataset and the SSD model, optionally
    # restores weights from a previous epoch, then runs the epoch/step loop
    # with periodic weight checkpoints and optional visualisation.

    # GPU settings: enable memory growth on every visible GPU.
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    dataset = TFDataset()
    train_data, train_count = dataset.generate_datatset()

    ssd = SSD()
    print_model_summary(network=ssd)

    # Decide which epoch to start from without rebinding the imported
    # configuration constant `load_weights_from_epoch`.
    if load_weights_before_training:
        ssd.load_weights(filepath=save_model_dir+"epoch-{}".format(load_weights_from_epoch))
        print("Successfully load weights!")
        start_epoch = load_weights_from_epoch + 1
    else:
        start_epoch = 0

    # loss
    loss = SSDLoss()

    # optimizer: Adam with an exponentially decaying learning rate.
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=1e-3,
                                                                 decay_steps=20000,
                                                                 decay_rate=0.96)
    optimizer = tf.optimizers.Adam(learning_rate=lr_schedule)

    # metrics: running means, reset at the end of every epoch.
    loss_metric = tf.metrics.Mean()
    cls_loss_metric = tf.metrics.Mean()
    reg_loss_metric = tf.metrics.Mean()

    # The number of steps per epoch does not change during training, so it is
    # computed once here instead of on every step; converting to int makes the
    # progress line read "step: 58/2141" rather than "step: 58/2141.0".
    steps_per_epoch = int(tf.math.ceil(train_count / BATCH_SIZE))

    def train_step(batch_images, batch_labels):
        """Run one optimisation step and update the running loss metrics.

        Args:
            batch_images: Image batch fed to the SSD model.
            batch_labels: Labels for the batch, passed to ``MakeGT``.
        """
        with tf.GradientTape() as tape:
            pred = ssd(batch_images, training=True)
            output = ssd_prediction(feature_maps=pred, num_classes=NUM_CLASSES)
            gt = MakeGT(batch_labels, pred)
            gt_boxes = gt.generate_gt_boxes()
            loss_value, cls_loss, reg_loss = loss(y_true=gt_boxes, y_pred=output)
        gradients = tape.gradient(loss_value, ssd.trainable_variables)
        optimizer.apply_gradients(grads_and_vars=zip(gradients, ssd.trainable_variables))
        loss_metric.update_state(values=loss_value)
        cls_loss_metric.update_state(values=cls_loss)
        reg_loss_metric.update_state(values=reg_loss)

    for epoch in range(start_epoch, EPOCHS):
        start_time = time.time()
        for step, batch_data in enumerate(train_data):
            # NOTE(review): ReadDataset is constructed on every step; it could
            # be hoisted out of the loop if it is stateless -- confirm.
            images, labels = ReadDataset().read(batch_data)
            train_step(batch_images=images, batch_labels=labels)
            # Average wall-clock time per step since the start of this epoch.
            time_per_step = (time.time() - start_time) / (step + 1)
            print("Epoch: {}/{}, step: {}/{}, {:.2f}s/step, loss: {:.5f}, "
                  "cls loss: {:.5f}, reg loss: {:.5f}".format(epoch,
                                                              EPOCHS,
                                                              step,
                                                              steps_per_epoch,
                                                              time_per_step,
                                                              loss_metric.result(),
                                                              cls_loss_metric.result(),
                                                              reg_loss_metric.result()))
        # Start the next epoch with fresh running means.
        loss_metric.reset_states()
        cls_loss_metric.reset_states()
        reg_loss_metric.reset_states()

        # Periodic checkpoint, named after the epoch that just finished.
        if epoch % save_frequency == 0:
            ssd.save_weights(filepath=save_model_dir+"epoch-{}".format(epoch), save_format="tf")

        if test_images_during_training:
            visualize_training_results(pictures=test_images_dir_list, model=ssd, epoch=epoch)

    # Final weights after the last epoch.
    ssd.save_weights(filepath=save_model_dir+"saved_model", save_format="tf")


Model: "ssd"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
res_net50 (ResNet50)         multiple                  14785408  
_________________________________________________________________
conv2d_53 (Conv2D)           multiple                  262400    
_________________________________________________________________
conv2d_54 (Conv2D)           multiple                  1180160   
_________________________________________________________________
conv2d_55 (Conv2D)           multiple                  65664     
_________________________________________________________________
conv2d_56 (Conv2D)           multiple                  295168    
_________________________________________________________________
conv2d_57 (Conv2D)           multiple                  32896     
_________________________________________________________________
conv2d_58 (Conv2D)           multiple                  295168  

Epoch: 0/50, step: 58/2141.0, 11.07s/step, loss: 6368.51660, cls loss: 12702.53906, reg loss: 34.49294
Epoch: 0/50, step: 59/2141.0, 11.10s/step, loss: 6264.72314, cls loss: 12495.52637, reg loss: 33.91893
Epoch: 0/50, step: 60/2141.0, 11.08s/step, loss: 6163.92969, cls loss: 12294.49414, reg loss: 33.36445
Epoch: 0/50, step: 61/2141.0, 11.08s/step, loss: 6066.17236, cls loss: 12099.51074, reg loss: 32.83238
Epoch: 0/50, step: 62/2141.0, 11.08s/step, loss: 5971.63135, cls loss: 11910.93652, reg loss: 32.32457
Epoch: 0/50, step: 63/2141.0, 11.09s/step, loss: 5880.04980, cls loss: 11728.13574, reg loss: 31.96134
Epoch: 0/50, step: 64/2141.0, 11.07s/step, loss: 5791.37451, cls loss: 11551.25195, reg loss: 31.49504
Epoch: 0/50, step: 65/2141.0, 11.05s/step, loss: 5705.02441, cls loss: 11379.02832, reg loss: 31.01881
Epoch: 0/50, step: 66/2141.0, 11.04s/step, loss: 5621.45166, cls loss: 11212.34375, reg loss: 30.55717
Epoch: 0/50, step: 67/2141.0, 11.03s/step, loss: 5540.35303, cls loss: 11

Epoch: 0/50, step: 138/2141.0, 10.92s/step, loss: 2752.23145, cls loss: 5488.25781, reg loss: 16.20268
Epoch: 0/50, step: 139/2141.0, 10.92s/step, loss: 2732.95825, cls loss: 5449.82568, reg loss: 16.08794
Epoch: 0/50, step: 140/2141.0, 10.91s/step, loss: 2714.01636, cls loss: 5412.05469, reg loss: 15.97570
Epoch: 0/50, step: 141/2141.0, 10.91s/step, loss: 2695.33374, cls loss: 5374.79248, reg loss: 15.87246
Epoch: 0/50, step: 142/2141.0, 10.91s/step, loss: 2676.89062, cls loss: 5338.01611, reg loss: 15.76231
Epoch: 0/50, step: 143/2141.0, 10.92s/step, loss: 2659.33398, cls loss: 5301.89551, reg loss: 16.77002
Epoch: 0/50, step: 144/2141.0, 10.93s/step, loss: 2641.39062, cls loss: 5266.11279, reg loss: 16.66551
Epoch: 0/50, step: 145/2141.0, 10.92s/step, loss: 2624.18945, cls loss: 5230.97119, reg loss: 17.40527
Epoch: 0/50, step: 146/2141.0, 10.94s/step, loss: 2606.74536, cls loss: 5196.19189, reg loss: 17.29646
Epoch: 0/50, step: 147/2141.0, 10.94s/step, loss: 2589.49658, cls loss: 5

Epoch: 0/50, step: 218/2141.0, 11.21s/step, loss: 1767.21460, cls loss: 3520.19922, reg loss: 14.22830
Epoch: 0/50, step: 219/2141.0, 11.22s/step, loss: 1759.33594, cls loss: 3504.50659, reg loss: 14.16385
Epoch: 0/50, step: 220/2141.0, 11.22s/step, loss: 1751.55017, cls loss: 3488.99829, reg loss: 14.10068
Epoch: 0/50, step: 221/2141.0, 11.23s/step, loss: 1743.82056, cls loss: 3473.60181, reg loss: 14.03767
Epoch: 0/50, step: 222/2141.0, 11.23s/step, loss: 1736.16370, cls loss: 3458.35107, reg loss: 13.97486
Epoch: 0/50, step: 223/2141.0, 11.25s/step, loss: 1728.58508, cls loss: 3443.25513, reg loss: 13.91377
Epoch: 0/50, step: 224/2141.0, 11.26s/step, loss: 1721.06592, cls loss: 3428.27832, reg loss: 13.85204
Epoch: 0/50, step: 225/2141.0, 11.27s/step, loss: 1713.61804, cls loss: 3413.44336, reg loss: 13.79111
Epoch: 0/50, step: 226/2141.0, 11.27s/step, loss: 1706.21790, cls loss: 3398.70386, reg loss: 13.73052
Epoch: 0/50, step: 227/2141.0, 11.27s/step, loss: 1698.91040, cls loss: 3

Epoch: 0/50, step: 298/2141.0, 11.41s/step, loss: 1303.82666, cls loss: 2596.20557, reg loss: 11.44635
Epoch: 0/50, step: 299/2141.0, 11.43s/step, loss: 1299.61755, cls loss: 2587.77905, reg loss: 11.45457
Epoch: 0/50, step: 300/2141.0, 11.43s/step, loss: 1295.39294, cls loss: 2579.36792, reg loss: 11.41670
Epoch: 0/50, step: 301/2141.0, 11.43s/step, loss: 1291.21704, cls loss: 2571.05249, reg loss: 11.38038
Epoch: 0/50, step: 302/2141.0, 11.42s/step, loss: 1287.04419, cls loss: 2562.74390, reg loss: 11.34300
Epoch: 0/50, step: 303/2141.0, 11.42s/step, loss: 1282.91565, cls loss: 2554.50586, reg loss: 11.32414
Epoch: 0/50, step: 304/2141.0, 11.42s/step, loss: 1278.80359, cls loss: 2546.31836, reg loss: 11.28763
Epoch: 0/50, step: 305/2141.0, 11.42s/step, loss: 1274.72485, cls loss: 2538.19702, reg loss: 11.25137
Epoch: 0/50, step: 306/2141.0, 11.42s/step, loss: 1270.67786, cls loss: 2530.13867, reg loss: 11.21572
Epoch: 0/50, step: 307/2141.0, 11.42s/step, loss: 1266.64392, cls loss: 2

Epoch: 0/50, step: 378/2141.0, 11.42s/step, loss: 1034.27576, cls loss: 2059.11792, reg loss: 9.43233
Epoch: 0/50, step: 379/2141.0, 11.42s/step, loss: 1031.61511, cls loss: 2053.82080, reg loss: 9.40806
Epoch: 0/50, step: 380/2141.0, 11.42s/step, loss: 1028.96265, cls loss: 2048.54028, reg loss: 9.38377
Epoch: 0/50, step: 381/2141.0, 11.42s/step, loss: 1026.35083, cls loss: 2043.31824, reg loss: 9.38196
Epoch: 0/50, step: 382/2141.0, 11.43s/step, loss: 1023.72632, cls loss: 2038.09314, reg loss: 9.35803
Epoch: 0/50, step: 383/2141.0, 11.43s/step, loss: 1021.11859, cls loss: 2032.90186, reg loss: 9.33382
Epoch: 0/50, step: 384/2141.0, 11.44s/step, loss: 1018.52094, cls loss: 2027.73022, reg loss: 9.31032
Epoch: 0/50, step: 385/2141.0, 11.44s/step, loss: 1015.93823, cls loss: 2022.58862, reg loss: 9.28649
Epoch: 0/50, step: 386/2141.0, 11.43s/step, loss: 1013.37451, cls loss: 2017.48389, reg loss: 9.26373
Epoch: 0/50, step: 387/2141.0, 11.43s/step, loss: 1010.81354, cls loss: 2012.38562

Epoch: 0/50, step: 459/2141.0, 11.48s/step, loss: 855.88007, cls loss: 1703.74963, reg loss: 8.00955
Epoch: 0/50, step: 460/2141.0, 11.48s/step, loss: 854.06104, cls loss: 1700.12720, reg loss: 7.99379
Epoch: 0/50, step: 461/2141.0, 11.48s/step, loss: 852.25085, cls loss: 1696.52319, reg loss: 7.97759
Epoch: 0/50, step: 462/2141.0, 11.48s/step, loss: 850.44434, cls loss: 1692.92712, reg loss: 7.96052
Epoch: 0/50, step: 463/2141.0, 11.48s/step, loss: 848.64673, cls loss: 1689.34851, reg loss: 7.94397
Epoch: 0/50, step: 464/2141.0, 11.49s/step, loss: 846.86597, cls loss: 1685.80322, reg loss: 7.92761
Epoch: 0/50, step: 465/2141.0, 11.49s/step, loss: 845.08691, cls loss: 1682.26208, reg loss: 7.91067
Epoch: 0/50, step: 466/2141.0, 11.50s/step, loss: 843.32251, cls loss: 1678.74902, reg loss: 7.89482
Epoch: 0/50, step: 467/2141.0, 11.50s/step, loss: 841.56830, cls loss: 1675.24890, reg loss: 7.88650
Epoch: 0/50, step: 468/2141.0, 11.50s/step, loss: 839.81018, cls loss: 1671.74915, reg loss

Epoch: 0/50, step: 541/2141.0, 11.62s/step, loss: 728.89294, cls loss: 1450.89172, reg loss: 6.89292
Epoch: 0/50, step: 542/2141.0, 11.62s/step, loss: 727.57660, cls loss: 1448.27148, reg loss: 6.88048
Epoch: 0/50, step: 543/2141.0, 11.62s/step, loss: 726.28589, cls loss: 1445.67871, reg loss: 6.89181
Epoch: 0/50, step: 544/2141.0, 11.62s/step, loss: 724.98364, cls loss: 1443.08667, reg loss: 6.87938
Epoch: 0/50, step: 545/2141.0, 11.63s/step, loss: 723.68207, cls loss: 1440.49585, reg loss: 6.86699
Epoch: 0/50, step: 546/2141.0, 11.63s/step, loss: 722.38489, cls loss: 1437.91370, reg loss: 6.85475
Epoch: 0/50, step: 547/2141.0, 11.63s/step, loss: 721.08899, cls loss: 1435.33447, reg loss: 6.84226
Epoch: 0/50, step: 548/2141.0, 11.63s/step, loss: 719.79852, cls loss: 1432.76575, reg loss: 6.83003
Epoch: 0/50, step: 549/2141.0, 11.63s/step, loss: 718.51416, cls loss: 1430.20935, reg loss: 6.81769
Epoch: 0/50, step: 550/2141.0, 11.63s/step, loss: 717.23547, cls loss: 1427.66370, reg loss

Epoch: 0/50, step: 623/2141.0, 11.72s/step, loss: 635.01862, cls loss: 1263.94336, reg loss: 6.09262
Epoch: 0/50, step: 624/2141.0, 11.72s/step, loss: 634.01886, cls loss: 1261.95349, reg loss: 6.08290
Epoch: 0/50, step: 625/2141.0, 11.72s/step, loss: 633.02179, cls loss: 1259.96899, reg loss: 6.07326
Epoch: 0/50, step: 626/2141.0, 11.72s/step, loss: 632.02948, cls loss: 1257.99402, reg loss: 6.06370
Epoch: 0/50, step: 627/2141.0, 11.72s/step, loss: 631.05646, cls loss: 1256.05627, reg loss: 6.05528
Epoch: 0/50, step: 628/2141.0, 11.72s/step, loss: 630.07385, cls loss: 1254.10022, reg loss: 6.04610
Epoch: 0/50, step: 629/2141.0, 11.72s/step, loss: 629.10248, cls loss: 1252.16577, reg loss: 6.03785
Epoch: 0/50, step: 630/2141.0, 11.72s/step, loss: 628.12476, cls loss: 1250.21973, reg loss: 6.02844
Epoch: 0/50, step: 631/2141.0, 11.72s/step, loss: 627.15198, cls loss: 1248.28357, reg loss: 6.01913
Epoch: 0/50, step: 632/2141.0, 11.72s/step, loss: 626.17999, cls loss: 1246.34863, reg loss

Epoch: 0/50, step: 705/2141.0, 11.66s/step, loss: 562.63715, cls loss: 1119.85034, reg loss: 5.42255
Epoch: 0/50, step: 706/2141.0, 11.66s/step, loss: 561.85901, cls loss: 1118.30103, reg loss: 5.41554
Epoch: 0/50, step: 707/2141.0, 11.66s/step, loss: 561.07806, cls loss: 1116.74670, reg loss: 5.40801
Epoch: 0/50, step: 708/2141.0, 11.66s/step, loss: 560.29865, cls loss: 1115.19531, reg loss: 5.40063
Epoch: 0/50, step: 709/2141.0, 11.66s/step, loss: 559.52814, cls loss: 1113.66150, reg loss: 5.39331
Epoch: 0/50, step: 710/2141.0, 11.66s/step, loss: 558.75464, cls loss: 1112.12207, reg loss: 5.38582
Epoch: 0/50, step: 711/2141.0, 11.66s/step, loss: 557.98236, cls loss: 1110.58484, reg loss: 5.37840
Epoch: 0/50, step: 712/2141.0, 11.65s/step, loss: 557.21259, cls loss: 1109.05273, reg loss: 5.37101
Epoch: 0/50, step: 713/2141.0, 11.65s/step, loss: 556.44733, cls loss: 1107.52942, reg loss: 5.36383
Epoch: 0/50, step: 714/2141.0, 11.65s/step, loss: 555.68146, cls loss: 1106.00500, reg loss