In [1]:
from models import get_depth_estimation_model, get_fire_segmentation_model, dice_bce_loss
import tensorflow as tf
from tqdm.auto import tqdm

In [6]:
# Enable the %tensorboard notebook magic used later to inspect the training logs.
# Re-running this in the same kernel only prints an "already loaded" notice; %reload_ext avoids it.
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [3]:
# Custom Keras callback that renders tqdm progress bars for training epochs and test/validation passes.
class ProgressCallback(tf.keras.callbacks.Callback):
    """Keras callback that draws tqdm progress bars in place of Keras' built-in logger.

    Intended for ``model.fit(..., verbose=0)``. One bar is created per training epoch
    and one per test/validation pass. The most recent ``logs`` dict is shown as the
    bar's postfix.

    Args:
        epochs: Total number of epochs given to ``fit``. It is stored on the instance
            but not currently used by the bars.
        train_total: Number of training batches per epoch (total of each epoch bar).
        test_total: Number of test/validation batches per pass (total of each testing bar).
    """

    def __init__(self, epochs, train_total, test_total):
        super().__init__()
        self.epochs = epochs
        self.train_total = train_total
        self.test_total = test_total
        # Keras runs validation (the on_test_* hooks) *inside* a training epoch, so the
        # training bar and the testing bar need separate attributes. Sharing one
        # attribute would orphan the unclosed training bar.
        self.progress_bar = None  # bar for the current training epoch
        self.test_progress_bar = None  # bar for the current test/validation pass

    def on_epoch_begin(self, epoch, logs=None):
        self.progress_bar = tqdm(total=self.train_total, unit=' batch', desc=f'Epoch {epoch:04}'.rjust(15))

    def on_train_batch_end(self, batch, logs=None):
        if logs:
            self.progress_bar.set_postfix(logs)
        self.progress_bar.update()

    def on_epoch_end(self, epoch, logs=None):
        # Epoch-end logs also carry the val_* metrics, so they belong on the epoch's bar.
        if logs:
            self.progress_bar.set_postfix(logs)
        self.progress_bar.close()

    def on_test_begin(self, logs=None):
        self.test_progress_bar = tqdm(total=self.test_total, unit=' batch', desc=f'{"Testing".rjust(15)}')

    def on_test_batch_end(self, batch, logs=None):
        if logs:
            self.test_progress_bar.set_postfix(logs)
        self.test_progress_bar.update()

    def on_test_end(self, logs=None):
        # Close the testing bar here; nothing else would close it
        # (e.g. during a standalone model.evaluate()).
        if logs:
            self.test_progress_bar.set_postfix(logs)
        self.test_progress_bar.close()

In [4]:
import tensorflow as tf
from data_prep import FireData
from keras.callbacks import LearningRateScheduler


def learning_rate_schedule(epoch, learning_rate, decay=0.95, min_learning_rate=1e-7):
    """Exponential learning-rate decay with a lower bound, for ``LearningRateScheduler``.

    Each call multiplies the current learning rate by ``decay`` and clamps the result
    so it never drops below ``min_learning_rate``.

    Args:
        epoch: Index of the epoch about to start. It is unused, because the decay is
            compounded from the current rate rather than computed from the epoch number.
        learning_rate: The optimizer's current learning rate.
        decay: Multiplicative factor applied per call (default 0.95).
        min_learning_rate: Lower bound on the returned rate (default 1e-7).

    Returns:
        The learning rate to use for the upcoming epoch.
    """
    return max(learning_rate * decay, min_learning_rate)

# Decay the learning rate at the start of every epoch.
# The tf.keras namespace is used consistently here: mixing classes from the standalone
# `keras` package with `tf.keras` can be incompatible on some TensorFlow/Keras versions.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(learning_rate_schedule)

# Training configuration.
EPOCHS = 20
BATCH_SIZE = 16
LEARNING_RATE = 0.0001
optimizer = tf.keras.optimizers.Adam(
    learning_rate=LEARNING_RATE,
    amsgrad=False,
)
# NOTE(review): `compile` below passes no `loss=`; the model returned by
# get_fire_segmentation_model presumably applies `loss_function` itself. Confirm in models.py.
model = get_fire_segmentation_model(loss_function=dice_bce_loss)

# Compile the model, reusing the optimizer configured above rather than building a second one.
# Keras expects `metrics` as a list (or dict).
model.compile(optimizer=optimizer,
              metrics=[tf.keras.metrics.MeanAbsoluteError()])
train_loader = FireData(split='train', batch_size=BATCH_SIZE)
valid_loader = FireData(split='val', batch_size=BATCH_SIZE)

# verbose=0 because ProgressCallback draws the progress bars instead.
# Keep the History object so the loss/metric curves can be inspected afterwards.
history = model.fit(
    train_loader,
    epochs=EPOCHS,
    validation_data=valid_loader,
    validation_steps=len(valid_loader),
    verbose=0,
    callbacks=[model.checkpoint_callback,
               model.tensor_board_callback,
               ProgressCallback(epochs=EPOCHS, train_total=len(train_loader), test_total=len(valid_loader)),
               lr_scheduler]
)


     Epoch 0000:   0%|          | 0/1373 [00:00<?, ? batch/s]



        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0001:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0002:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0003:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0004:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0005:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0006:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0007:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0008:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0009:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0010:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0011:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0012:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0013:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0014:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0015:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0016:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0017:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0018:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

     Epoch 0019:   0%|          | 0/1373 [00:00<?, ? batch/s]

        Testing:   0%|          | 0/343 [00:00<?, ? batch/s]

<keras.callbacks.History at 0x24e5f804dc0>

In [7]:
# Start (or reuse, if one is already running) an inline TensorBoard on the training logs.
# NOTE(review): the log directory is hard-coded. It presumably has to match the directory
# written by model.tensor_board_callback (defined in models.py). Confirm.
%tensorboard --logdir 'Models/UNet Fire Segmentation/logs'

Reusing TensorBoard on port 6006 (pid 4244), started 0:04:23 ago. (Use '!kill 4244' to kill it.)