In [1]:
import tensorflow as tf
from data_prep import FireData, DistanceData
from keras.callbacks import LearningRateScheduler
from models import get_depth_estimation_model, get_fire_segmentation_model, dice_bce_loss, UNet, depth_loss_function
from tqdm.auto import tqdm
from typing import Callable
import tensorflow_addons as tfa

In [2]:
# Custom Keras callback that drives tqdm progress bars for training and validation.
class ProgressCallback(tf.keras.callbacks.Callback):
    """Show a tqdm progress bar for each training epoch and each validation pass.

    Meant to be used with ``model.fit(..., verbose=0)`` so Keras' own progress
    output does not duplicate these bars. During ``fit``, Keras runs the
    validation pass (the ``on_test_*`` hooks) after the last training batch but
    *before* ``on_epoch_end``. The training and validation bars are therefore
    kept in separate attributes. The epoch bar receives the full epoch logs
    (including ``val_*`` entries) when the epoch ends.

    Args:
        epochs: total number of epochs (stored; not currently displayed).
        train_total: number of training batches per epoch (length of the epoch bar).
        test_total: number of validation batches per pass (length of the testing bar).
    """

    def __init__(self, epochs, train_total, test_total):
        super().__init__()
        self.epochs = epochs
        self.train_total = train_total
        self.test_total = test_total
        self.progress_bar = None       # bar for the current training epoch
        self.test_progress_bar = None  # bar for the current validation pass

    def on_epoch_begin(self, epoch, logs=None):
        # Defensive: close a bar left open (e.g. by an interrupted epoch) before starting a new one.
        if self.progress_bar is not None:
            self.progress_bar.close()
        self.progress_bar = tqdm(total=self.train_total, unit=' batch', desc=f'Epoch {epoch:04}'.rjust(15))

    def on_train_batch_end(self, batch, logs=None):
        if logs:
            self.progress_bar.set_postfix(logs)
        self.progress_bar.update()

    def on_epoch_end(self, epoch, logs=None):
        if self.progress_bar is None:
            return
        # Epoch-level logs, including validation metrics when validation ran.
        if logs:
            self.progress_bar.set_postfix(logs)
        self.progress_bar.close()
        self.progress_bar = None

    def on_test_begin(self, logs=None):
        # Separate handle: must not replace the still-open epoch bar.
        self.test_progress_bar = tqdm(total=self.test_total, unit=' batch', desc=f'{"Testing".rjust(15)}')

    def on_test_batch_end(self, batch, logs=None):
        if logs:
            self.test_progress_bar.set_postfix(logs)
        self.test_progress_bar.update()

    def on_test_end(self, logs=None):
        if self.test_progress_bar is None:
            return
        if logs:
            self.test_progress_bar.set_postfix(logs)
        self.test_progress_bar.close()
        self.test_progress_bar = None


In [3]:
import tensorflow as tf
from keras.callbacks import Callback

class LRDecayCallback(Callback):
    """Decay the optimizer learning rate when the per-batch training loss stops improving.

    After every training batch, the reported ``loss`` is compared with the best
    loss seen so far. After ``patience`` consecutive batches without
    improvement, the learning rate is multiplied by ``factor``, floored at
    ``min_lr``, and the counter resets. The learning rate is never raised: if it
    is already at or below ``min_lr``, nothing changes.

    Args:
        factor: multiplicative decay applied to the learning rate.
        patience: number of consecutive non-improving batches before decaying.
        min_lr: lower bound for the decayed learning rate.
    """

    def __init__(self, factor=0.1, patience=10, min_lr=1e-6):
        super(LRDecayCallback, self).__init__()
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.loss_history = []  # every per-batch loss observed
        self.best_loss = float('inf')
        self.counter = 0  # consecutive non-improving batches

    def on_train_batch_end(self, batch, logs=None):
        current_loss = (logs or {}).get('loss')
        if current_loss is None:
            # No loss reported for this batch; nothing to compare against.
            return
        self.loss_history.append(current_loss)

        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter < self.patience:
            return

        # `learning_rate` is the canonical attribute; `lr` is a deprecated alias.
        lr = tf.keras.backend.get_value(self.model.optimizer.learning_rate)
        new_lr = max(lr * self.factor, self.min_lr)
        # Only ever decrease: if lr is already below min_lr, max() would raise it.
        if new_lr < lr:
            tf.keras.backend.set_value(self.model.optimizer.learning_rate, new_lr)
        self.counter = 0



In [4]:
def training_loop(train_data: tf.keras.utils.Sequence,
                  valid_data: tf.keras.utils.Sequence,
                  epochs: int = 50,
                  learning_rate: float = .0001,
                  optimizer: tf.keras.optimizers.Optimizer = None,
                  model: UNet = None,
                  loss: list = None,
                  learning_rate_scheduler: Callable = None,
                  metrics: list = None,
                  ):
    """Compile ``model`` and fit it on ``train_data``, validating on ``valid_data`` every epoch.

    Args:
        train_data: training batches. ``len(train_data)`` sizes the epoch progress bar.
        valid_data: validation batches. ``len(valid_data)`` is used as ``validation_steps``.
        epochs: number of epochs to train.
        learning_rate: learning rate for the default Adam optimizer. It is ignored when
            ``optimizer`` is given. A ``learning_rate_scheduler`` that ignores its
            ``lr`` argument also replaces it from epoch 0.
        optimizer: optimizer to compile with. Defaults to ``Adam(learning_rate)``.
        model: the model to train (required). It must expose ``checkpoint_callback``
            and ``tensor_board_callback`` attributes.
        loss: loss passed to ``model.compile``. When None, ``compile`` is called
            without a ``loss`` argument, as before.
            NOTE(review): presumably the UNet wrapper supplies its own loss (e.g. from
            ``get_*_model(loss_function=...)``); confirm.
        learning_rate_scheduler: optional ``schedule(epoch, lr) -> lr`` function,
            wrapped in a Keras ``LearningRateScheduler``.
        metrics: metrics for ``model.compile``. Defaults to F1Score(num_classes=2),
            Precision, AUC and Accuracy.

    Returns:
        The ``History`` object returned by ``model.fit``.

    Raises:
        ValueError: if ``model`` is None.
    """
    if model is None:
        raise ValueError('training_loop requires a model')
    if optimizer is None:
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    if metrics is None:
        # NOTE(review): tf.keras.metrics.Accuracy counts exact equality of y_true and y_pred;
        # for sigmoid or regression outputs, BinaryAccuracy (segmentation) or an error
        # metric (depth) is probably what was intended.
        metrics = [tfa.metrics.F1Score(num_classes=2),
                   tf.keras.metrics.Precision(),
                   tf.keras.metrics.AUC(),
                   tf.keras.metrics.Accuracy()]

    callbacks = [model.checkpoint_callback, model.tensor_board_callback,
                 ProgressCallback(epochs=epochs, train_total=len(train_data), test_total=len(valid_data))]
    if learning_rate_scheduler is not None:
        # Use the schedule the caller passed in, not a notebook-global function.
        callbacks.append(LearningRateScheduler(learning_rate_scheduler))

    compile_kwargs = {'optimizer': optimizer, 'metrics': metrics}
    if loss is not None:
        compile_kwargs['loss'] = loss
    model.compile(**compile_kwargs)

    return model.fit(train_data,
                     epochs=epochs,
                     validation_data=valid_data,
                     validation_steps=len(valid_data),
                     verbose=0,  # ProgressCallback renders progress instead
                     callbacks=callbacks)


In [5]:
def learning_rate_schedule(epoch, _, initial_lr=.0001, decay=.95, min_lr=1e-6):
    """Exponential per-epoch learning-rate decay for keras ``LearningRateScheduler``.

    Returns ``max(initial_lr * decay ** epoch, min_lr)``.

    The second positional argument (the optimizer's current learning rate, which
    ``LearningRateScheduler`` passes in) is ignored. The schedule therefore
    replaces whatever learning rate the optimizer was created with, starting at
    epoch 0.

    Args:
        epoch: zero-based epoch index.
        _: current learning rate (unused).
        initial_lr: learning rate at epoch 0.
        decay: multiplicative factor applied once per epoch.
        min_lr: lower bound on the returned learning rate.
    """
    return max(initial_lr * decay ** epoch, min_lr)

In [6]:
# fire_detection_model = get_fire_segmentation_model(loss_function=tfa.losses.SigmoidFocalCrossEntropy())
# training_loop(train_data=FireData(batch_size=16, split='train'),
#               valid_data=FireData(batch_size=16, split='val'),
#               model=fire_detection_model,
#               learning_rate=.001,
#               learning_rate_scheduler=learning_rate_schedule)


In [7]:
# %load_ext tensorboard
# %tensorboard --logdir 'Models/UNet Fire Segmentation/logs'

In [8]:
batch_size = 8
depth_estimation_model = get_depth_estimation_model(loss_function=depth_loss_function)
# Train and validation splits share the batch size and input resolution (height 512, width 384).
train_data, valid_data = (
    DistanceData(batch_size=batch_size, split=split_name, height=512, width=384)
    for split_name in ('train', 'val')
)
# NOTE(review): `learning_rate_schedule` ignores the optimizer's current learning rate,
# so the learning_rate=.00001 below is replaced by the schedule's own value from epoch 0.
training_loop(
    model=depth_estimation_model,
    train_data=train_data,
    valid_data=valid_data,
    epochs=5,
    learning_rate=.00001,
    learning_rate_scheduler=learning_rate_schedule,
);

     Epoch 0000:   0%|          | 0/3182 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/96 [00:00<?, ? batch/s]

     Epoch 0001:   0%|          | 0/3182 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/96 [00:00<?, ? batch/s]

     Epoch 0002:   0%|          | 0/3182 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/96 [00:00<?, ? batch/s]

     Epoch 0003:   0%|          | 0/3182 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/96 [00:00<?, ? batch/s]

     Epoch 0004:   0%|          | 0/3182 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/96 [00:00<?, ? batch/s]

In [9]:
%load_ext tensorboard
%tensorboard --logdir 'Models/UNet Depth Estimation/logs'

Launching TensorBoard...