In [None]:
import math
import time

import tensorflow as tf

from configuration import IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS, EPOCHS, NUM_CLASSES, BATCH_SIZE, save_model_dir, \
    load_weights_before_training, load_weights_from_epoch, save_frequency, test_images_during_training, \
    test_images_dir_list
from core.ground_truth import ReadDataset, MakeGT
from core.loss import SSDLoss
from core.make_dataset import TFDataset
from core.ssd import SSD, ssd_prediction
from utils.visualize import visualize_training_results


def print_model_summary(network):
    """Build ``network`` for the configured image size and print its summary.

    The leading dimension of the build shape is ``None`` (presumably the batch
    axis, per the Keras ``build(input_shape)`` convention), followed by
    ``IMAGE_HEIGHT``, ``IMAGE_WIDTH`` and ``CHANNELS`` from the configuration.
    """
    shape = (None, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS)
    network.build(input_shape=shape)
    network.summary()


if __name__ == '__main__':
    # GPU settings: let TensorFlow grow GPU memory on demand rather than
    # reserving it all at start-up.
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    dataset = TFDataset()
    train_data, train_count = dataset.generate_datatset()
    # Number of batches per epoch, computed once as a plain int. Previously
    # tf.math.ceil ran on every step, creating a tensor each iteration and
    # printing the total as a float ("2141.0").
    steps_per_epoch = math.ceil(train_count / BATCH_SIZE)

    ssd = SSD()
    print_model_summary(network=ssd)

    if load_weights_before_training:
        ssd.load_weights(filepath=save_model_dir+"epoch-{}".format(load_weights_from_epoch))
        print("Successfully load weights!")
    else:
        # Not resuming: the epoch loop starts at load_weights_from_epoch + 1,
        # so -1 makes training begin at epoch 0.
        load_weights_from_epoch = -1

    # loss
    loss = SSDLoss()

    # optimizer: learning rate decays by a factor of 0.96 every 20000 steps
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=1e-3,
                                                                 decay_steps=20000,
                                                                 decay_rate=0.96)
    optimizer = tf.optimizers.Adam(learning_rate=lr_schedule)

    # metrics: running means over the current epoch (reset after each epoch)
    loss_metric = tf.metrics.Mean()
    cls_loss_metric = tf.metrics.Mean()
    reg_loss_metric = tf.metrics.Mean()

    def train_step(batch_images, batch_labels):
        """Run one optimisation step on a batch and update the loss metrics.

        The forward pass, prediction decoding, ground-truth construction and
        loss all run under the gradient tape. Gradients of the total loss are
        then applied to every trainable variable of ``ssd``.
        """
        with tf.GradientTape() as tape:
            pred = ssd(batch_images, training=True)
            output = ssd_prediction(feature_maps=pred, num_classes=NUM_CLASSES)
            gt = MakeGT(batch_labels, pred)
            gt_boxes = gt.generate_gt_boxes()
            loss_value, cls_loss, reg_loss = loss(y_true=gt_boxes, y_pred=output)
        gradients = tape.gradient(loss_value, ssd.trainable_variables)
        optimizer.apply_gradients(grads_and_vars=zip(gradients, ssd.trainable_variables))
        loss_metric.update_state(values=loss_value)
        cls_loss_metric.update_state(values=cls_loss)
        reg_loss_metric.update_state(values=reg_loss)

    for epoch in range(load_weights_from_epoch + 1, EPOCHS):
        # perf_counter is monotonic, so it is the right clock for measuring intervals.
        start_time = time.perf_counter()
        for step, batch_data in enumerate(train_data):
            images, labels = ReadDataset().read(batch_data)
            train_step(batch_images=images, batch_labels=labels)
            time_per_step = (time.perf_counter() - start_time) / (step + 1)
            print("Epoch: {}/{}, step: {}/{}, {:.2f}s/step, loss: {:.5f}, "
                  "cls loss: {:.5f}, reg loss: {:.5f}".format(epoch,
                                                              EPOCHS,
                                                              step,
                                                              steps_per_epoch,
                                                              time_per_step,
                                                              loss_metric.result(),
                                                              cls_loss_metric.result(),
                                                              reg_loss_metric.result()))
        # The printed losses are epoch-cumulative means; start fresh next epoch.
        loss_metric.reset_states()
        cls_loss_metric.reset_states()
        reg_loss_metric.reset_states()

        if epoch % save_frequency == 0:
            ssd.save_weights(filepath=save_model_dir+"epoch-{}".format(epoch), save_format="tf")

        if test_images_during_training:
            visualize_training_results(pictures=test_images_dir_list, model=ssd, epoch=epoch)

    ssd.save_weights(filepath=save_model_dir+"saved_model", save_format="tf")


Model: "ssd_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
res_net50_1 (ResNet50)       multiple                  14785408  
_________________________________________________________________
conv2d_118 (Conv2D)          multiple                  262400    
_________________________________________________________________
conv2d_119 (Conv2D)          multiple                  1180160   
_________________________________________________________________
conv2d_120 (Conv2D)          multiple                  65664     
_________________________________________________________________
conv2d_121 (Conv2D)          multiple                  295168    
_________________________________________________________________
conv2d_122 (Conv2D)          multiple                  32896     
_________________________________________________________________
conv2d_123 (Conv2D)          multiple                  295168

Epoch: 0/50, step: 58/2141.0, 10.77s/step, loss: 6611.22900, cls loss: 13189.37305, reg loss: 33.08218
Epoch: 0/50, step: 59/2141.0, 10.80s/step, loss: 6503.34082, cls loss: 12974.14648, reg loss: 32.53175
Epoch: 0/50, step: 60/2141.0, 10.79s/step, loss: 6398.75928, cls loss: 12765.51660, reg loss: 32.00027
Epoch: 0/50, step: 61/2141.0, 10.79s/step, loss: 6297.31445, cls loss: 12563.13770, reg loss: 31.48919
Epoch: 0/50, step: 62/2141.0, 10.80s/step, loss: 6199.10742, cls loss: 12367.22266, reg loss: 30.99109
Epoch: 0/50, step: 63/2141.0, 10.84s/step, loss: 6104.17529, cls loss: 12177.74023, reg loss: 30.60897
Epoch: 0/50, step: 64/2141.0, 10.82s/step, loss: 6011.98389, cls loss: 11993.76465, reg loss: 30.20181
Epoch: 0/50, step: 65/2141.0, 10.82s/step, loss: 5922.40137, cls loss: 11815.05566, reg loss: 29.74584
Epoch: 0/50, step: 66/2141.0, 10.81s/step, loss: 5835.44141, cls loss: 11641.57715, reg loss: 29.30289
Epoch: 0/50, step: 67/2141.0, 10.82s/step, loss: 5751.22852, cls loss: 11 [... output truncated ...]

Epoch: 0/50, step: 138/2141.0, 10.87s/step, loss: 2854.79272, cls loss: 5694.00000, reg loss: 15.58529
Epoch: 0/50, step: 139/2141.0, 10.86s/step, loss: 2834.83862, cls loss: 5654.20166, reg loss: 15.47510
Epoch: 0/50, step: 140/2141.0, 10.85s/step, loss: 2815.16699, cls loss: 5614.96631, reg loss: 15.36697
Epoch: 0/50, step: 141/2141.0, 10.85s/step, loss: 2795.76758, cls loss: 5576.27002, reg loss: 15.26507
Epoch: 0/50, step: 142/2141.0, 10.84s/step, loss: 2776.57422, cls loss: 5537.98926, reg loss: 15.15924
Epoch: 0/50, step: 143/2141.0, 10.85s/step, loss: 2758.20312, cls loss: 5500.43506, reg loss: 15.97110
Epoch: 0/50, step: 144/2141.0, 10.86s/step, loss: 2739.57910, cls loss: 5463.28369, reg loss: 15.87429
Epoch: 0/50, step: 145/2141.0, 10.85s/step, loss: 2721.80933, cls loss: 5426.78223, reg loss: 16.83603
Epoch: 0/50, step: 146/2141.0, 10.86s/step, loss: 2703.67676, cls loss: 5390.61182, reg loss: 16.74121
Epoch: 0/50, step: 147/2141.0, 10.86s/step, loss: 2685.77368, cls loss: 5 [... output truncated ...]

Epoch: 0/50, step: 218/2141.0, 11.08s/step, loss: 1831.98669, cls loss: 3650.23608, reg loss: 13.73704
Epoch: 0/50, step: 219/2141.0, 11.08s/step, loss: 1823.83167, cls loss: 3633.98804, reg loss: 13.67485
Epoch: 0/50, step: 220/2141.0, 11.09s/step, loss: 1815.76514, cls loss: 3617.91650, reg loss: 13.61333
Epoch: 0/50, step: 221/2141.0, 11.09s/step, loss: 1807.72668, cls loss: 3601.90039, reg loss: 13.55248
Epoch: 0/50, step: 222/2141.0, 11.09s/step, loss: 1799.78503, cls loss: 3586.07788, reg loss: 13.49189
Epoch: 0/50, step: 223/2141.0, 11.10s/step, loss: 1791.92395, cls loss: 3570.41260, reg loss: 13.43487
Epoch: 0/50, step: 224/2141.0, 11.10s/step, loss: 1784.11499, cls loss: 3554.85425, reg loss: 13.37543
Epoch: 0/50, step: 225/2141.0, 11.11s/step, loss: 1776.37927, cls loss: 3539.44116, reg loss: 13.31722
Epoch: 0/50, step: 226/2141.0, 11.12s/step, loss: 1768.70618, cls loss: 3524.15332, reg loss: 13.25869
Epoch: 0/50, step: 227/2141.0, 11.12s/step, loss: 1761.11670, cls loss: 3 [... output truncated ...]