In [None]:
import tensorflow as tf
from data_prep import FireData, DistanceData
from keras.callbacks import LearningRateScheduler
from models import get_depth_estimation_model, get_fire_segmentation_model, dice_bce_loss, UNet, depth_loss_function
from tqdm.auto import tqdm
from typing import Callable
import tensorflow_addons as tfa

In [None]:
# Load the TensorBoard notebook extension; it provides the `%tensorboard`
# line magic that a later cell uses to display the training logs inline.
%load_ext tensorboard

In [None]:
# Custom Keras callback that reports training / evaluation progress with tqdm bars.
class ProgressCallback(tf.keras.callbacks.Callback):
    """Display one tqdm bar per training epoch and one per evaluation pass.

    During ``fit`` with validation data, Keras runs the evaluation pass
    *before* it calls ``on_epoch_end``. The training bar and the evaluation
    bar are therefore tracked separately, so that starting an evaluation never
    discards (and leaks) the still-open training bar.

    Args:
        epochs: Number of epochs the training run is configured for. Stored
            on the instance; not otherwise used by the callback.
        train_total: Number of batches per training epoch (length of the
            training bar).
        test_total: Number of batches per evaluation pass (length of the
            evaluation bar).
    """

    def __init__(self, epochs, train_total, test_total):
        super().__init__()
        self.epochs = epochs
        self.train_total = train_total
        self.test_total = test_total
        self.progress_bar = None       # bar of the training epoch in progress
        self.test_progress_bar = None  # bar of the evaluation pass in progress

    @staticmethod
    def _finish(bar, logs=None):
        """Show ``logs`` (if any) on ``bar`` and close it; no-op when ``bar`` is None."""
        if bar is None:
            return
        if logs:
            bar.set_postfix(logs)
        bar.close()

    def on_epoch_begin(self, epoch, logs=None):
        # Defensive: close a bar left open by an epoch that never reached on_epoch_end.
        self._finish(self.progress_bar)
        self.progress_bar = tqdm(total=self.train_total, unit=' batch', desc=f'Epoch {epoch:04}'.rjust(15))
        if logs:
            self.progress_bar.set_postfix(logs)

    def on_train_batch_end(self, batch, logs=None):
        if self.progress_bar is None:
            return
        if logs:
            self.progress_bar.set_postfix(logs)
        self.progress_bar.update()

    def on_epoch_end(self, epoch, logs=None):
        # Epoch-end logs go on the training bar, which is then closed.
        self._finish(self.progress_bar, logs)
        self.progress_bar = None

    def on_train_end(self, logs=None):
        # Make sure nothing stays open once training has finished.
        self._finish(self.progress_bar)
        self.progress_bar = None
        self._finish(self.test_progress_bar)
        self.test_progress_bar = None

    def on_test_begin(self, logs=None):
        self._finish(self.test_progress_bar)
        self.test_progress_bar = tqdm(total=self.test_total, unit=' batch', desc='Testing'.rjust(15))

    def on_test_batch_end(self, batch, logs=None):
        if self.test_progress_bar is None:
            return
        if logs:
            self.test_progress_bar.set_postfix(logs)
        self.test_progress_bar.update()

    def on_test_end(self, logs=None):
        # Close the evaluation bar here so it is released after validation
        # inside fit() as well as after a standalone evaluate() call.
        self._finish(self.test_progress_bar, logs)
        self.test_progress_bar = None


In [None]:
def training_loop(train_data: tf.keras.utils.Sequence,
                  valid_data: tf.keras.utils.Sequence,
                  epochs: int = 20,
                  learning_rate: float = .0001,
                  optimizer: tf.keras.optimizers.Optimizer = None,
                  model: UNet = None,
                  learning_rate_scheduler: Callable = None,
                  ):
    """Compile ``model`` and fit it on ``train_data``, validating on ``valid_data``.

    Args:
        train_data: Batched training data; ``len(train_data)`` is used as the
            number of batches per epoch for the progress bar.
        valid_data: Batched validation data; ``len(valid_data)`` is used as
            the number of validation steps.
        epochs: Number of epochs to train for.
        learning_rate: Learning rate of the Adam optimizer that is created
            when ``optimizer`` is not supplied. Ignored otherwise.
        optimizer: Optimizer to compile the model with. Defaults to Adam.
        model: Model to train. It must expose ``checkpoint_callback`` and
            ``tensor_board_callback`` attributes, which are added to the
            callback list.
        learning_rate_scheduler: Optional ``schedule(epoch, lr) -> new_lr``
            function; when given it is wrapped in a ``LearningRateScheduler``
            callback.

    Raises:
        ValueError: If ``model`` is None.
    """
    if model is None:
        raise ValueError('training_loop requires a `model` to train, got None.')
    if optimizer is None:
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # No loss is passed to compile(); presumably the model was built with its
    # own loss function -- TODO confirm against the model definitions.
    # NOTE(review): `tf.keras.metrics.Accuracy` counts exact matches between
    # predictions and labels; confirm that this is the intended metric.
    model.compile(optimizer=optimizer,
                  metrics=[tfa.metrics.F1Score(),
                           tf.keras.metrics.Precision(),
                           tf.keras.metrics.AUC(),
                           tf.keras.metrics.Accuracy()])

    callbacks = [model.checkpoint_callback, model.tensor_board_callback,
                 ProgressCallback(epochs=epochs, train_total=len(train_data), test_total=len(valid_data))]
    if learning_rate_scheduler is not None:
        # Use the schedule supplied by the caller rather than whatever
        # module-level function happens to be defined in the notebook.
        callbacks.append(LearningRateScheduler(learning_rate_scheduler))

    # verbose=0: progress is reported by ProgressCallback instead of Keras.
    model.fit(train_data,
              epochs=epochs,
              validation_data=valid_data,
              validation_steps=len(valid_data),
              verbose=0,
              callbacks=callbacks)


In [None]:
def learning_rate_schedule(epoch, learning_rate):
    """Return the learning rate for the next epoch.

    The current rate is multiplied by 0.95 and never allowed to drop below
    1e-7. ``epoch`` is accepted for the scheduler signature but is not used.
    """
    floor = 1e-7
    decayed = learning_rate * .95
    return floor if floor > decayed else decayed

# Build the fire-segmentation model with `dice_bce_loss` as its loss function and
# train it on the 'train' / 'val' splits of FireData in batches of 16.
# `epochs` is left at training_loop's default; the initial learning rate is 1e-3
# and `learning_rate_schedule` is supplied as the per-epoch scheduler.
fire_detection_model = get_fire_segmentation_model(loss_function=dice_bce_loss)
training_loop(train_data=FireData(batch_size=16, split='train'),
              valid_data=FireData(batch_size=16, split='val'),
              model=fire_detection_model,
              learning_rate=.001,
              learning_rate_scheduler=learning_rate_schedule)



In [None]:
# Open TensorBoard inline on the fire-segmentation log directory.
# NOTE(review): the path is hard-coded; presumably it is where the model's
# tensor_board_callback writes its logs -- confirm against the model definition.
%tensorboard --logdir 'Models/UNet Fire Segmentation/logs'

In [None]:
# Build the depth-estimation model with `depth_loss_function` as its loss and
# train it on the 'train' / 'val' splits of DistanceData (batch size 16,
# height=512, width=384). `epochs` is left at training_loop's default; the
# initial learning rate is 1e-3 and `learning_rate_schedule` is supplied as
# the per-epoch scheduler.
# NOTE(review): training_loop compiles every model with classification metrics
# (F1, precision, AUC, accuracy) -- confirm they are meaningful for depth output.
depth_estimation_model = get_depth_estimation_model(loss_function=depth_loss_function)
training_loop(train_data=DistanceData(batch_size=16, split='train', height=512, width=384),
              valid_data=DistanceData(batch_size=16, split='val', height=512, width=384),
              model=depth_estimation_model,
              learning_rate=.001,
              learning_rate_scheduler=learning_rate_schedule)