[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/SatelliteVu/SatelliteVu-AWS-Disaster-Response-Hackathon/blob/main/deep_learning/train.ipynb)

In this notebook, we train a segmentation model on the fire spread data: a ResU-Net or a SatU-Net, selected by `model_config["MODEL_NAME"]`.

In [None]:
import os
import re
import numpy as np
import tensorflow as tf
from typing import Dict, List, Optional, Text, Tuple
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import colors

from datagen import get_dataset
from config import common_config, dataset_config, training_config, model_config

In [None]:
# Parameters: make sure the input/output roots exist and derive the paths used below.
# exist_ok=True makes this cell safe to re-run and avoids a check-then-create race.
os.makedirs(common_config["INPUT_DIR"], exist_ok=True)
os.makedirs(common_config["OUTPUT_DIR"], exist_ok=True)

# Root directory holding the training/evaluation data files.
input_data_dir = os.path.join(common_config["INPUT_DIR"], "data")

# Optional parent model: the model cell accepts at most one .h5 file in this directory.
parent_model_dir = os.path.join(common_config["INPUT_DIR"], "parent_model")
model_pattern = os.path.join(parent_model_dir, "*.h5")

In [None]:
# Get the subsets for training and validation. Both splits share every loader
# setting except the file pattern and whether the examples are shuffled.
def _make_dataset(file_pattern, shuffle):
    """Build the `get_dataset` input pipeline for one data split.

    Args:
        file_pattern: pattern appended to `input_data_dir` by plain string
            concatenation (so it presumably carries its own leading path
            separator — TODO confirm against dataset_config).
        shuffle: whether `get_dataset` should shuffle the examples.

    Returns:
        Whatever `get_dataset` returns for this split.
    """
    return get_dataset(
        input_data_dir + file_pattern,
        data_size=model_config["IMG_SIZE"][0],
        sample_size=model_config["IMG_SIZE"][0],
        batch_size=training_config["BATCH_SIZE"],
        num_in_channels=len(dataset_config["INPUT_FEATURES"]),
        compression_type=None,
        clip_and_normalize=False,
        clip_and_rescale=True,
        random_crop=False,
        center_crop=False,
        shuffle=shuffle,
    )


train_dataset = _make_dataset(dataset_config["TRAIN_DATASET_PATTERN"], shuffle=True)
eval_dataset = _make_dataset(dataset_config["EVAL_DATASET_PATTERN"], shuffle=False)

In [None]:
# Plot a few training examples to sanity-check the input pipeline:
# one row per example, one column per input channel, plus a final label column.
train_inputs, train_labels = next(iter(train_dataset))

TITLES = dataset_config["INPUT_FEATURES"] + dataset_config["OUTPUT_FEATURES"]

# Never request more rows than the batch actually contains (BATCH_SIZE may be < 5).
n_rows = min(5, train_inputs.shape[0])
n_features = train_inputs.shape[3]
n_cols = n_features + 1

# Two-colour map used for the last input channel and for the label channel.
CMAP = colors.ListedColormap(['silver', 'orangered'])
BOUNDS = [0., 1.]
NORM = colors.BoundaryNorm(BOUNDS, CMAP.N)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 6.5), squeeze=False)

for i in range(n_rows):
    for j in range(n_cols):
        ax = axes[i, j]
        if i == 0:
            ax.set_title(TITLES[j], fontsize=13)
        if j < n_features - 1:
            # Continuous input features.
            ax.imshow(train_inputs[i, :, :, j], cmap='viridis')
        elif j == n_features - 1:
            # Last input channel, drawn with the two-colour mask map.
            ax.imshow(train_inputs[i, :, :, -1], cmap=CMAP, norm=NORM)
        else:
            # Label (first output channel).
            ax.imshow(train_labels[i, :, :, 0], cmap=CMAP, norm=NORM)
        ax.axis('off')
fig.tight_layout()

Select the loss function, build the model architecture, and optionally load parent-model weights

In [None]:
import model_resunet
import model_satunet
import glob
import sys
import keras
from metrics import dice_coef, get_loss_function

# Get the loss function selected in the training config.
loss_function = get_loss_function(training_config["LOSS_FUNCTION_NAME"])

# Define model architecture. Both builders receive the input shape
# [IMG_SIZE[0], IMG_SIZE[1], number of input features].
input_shape = [
    model_config["IMG_SIZE"][0],
    model_config["IMG_SIZE"][1],
    len(dataset_config["INPUT_FEATURES"]),
]
if model_config["MODEL_NAME"] == "resunet":
    model = model_resunet.get_model(input_shape)
elif model_config["MODEL_NAME"] == "satunet":
    model = model_satunet.get_model(input_shape, num_layers=model_config["NB_LAYERS"])
else:
    sys.exit("Provided wrong model name")

# Check if an input parent model was added to the parent model directory.
parent_model_paths = glob.glob(model_pattern)

if len(parent_model_paths) > 1:
    print("Only one parent model can be added to the parent model directory")
    sys.exit()

# If training from a parent model, load its weights into the model.
if model_config["TRAIN_FROM_PARENT_MODEL"]:
    # Fail with a clear message instead of an IndexError when no file was provided.
    if not parent_model_paths:
        sys.exit(
            "TRAIN_FROM_PARENT_MODEL is set but no .h5 file was found in {}".format(parent_model_dir)
        )
    parent_model_name = parent_model_paths[0]
    print("Loading parent model weights")
    model.load_weights(parent_model_name)
else:
    parent_model_name = None

# Unfreeze encoder layers if necessary. The flag must be set on each layer object
# (not on the keras.layers module), and before compile() so it takes effect.
if model_config["UNFREEZE_ALL_LAYERS"]:
    for layer in model.layers:
        layer.trainable = True

model.summary()

In [None]:
# Callbacks
import wandb
from wandb.keras import WandbCallback

callbacks = []

# Optional: WandB callback config and init. These values are logged with the run.
config = {
    "dataset_id": dataset_config["DATASET_ID"],
    "img_size": model_config["IMG_SIZE"],
    "model_architecture": model_config["MODEL_NAME"],
    "num_layers_satunet": model_config["NB_LAYERS"],
    "unfreeze_all_layers": model_config["UNFREEZE_ALL_LAYERS"],
    "parent_model_name": parent_model_name,
    "optimizer": training_config["OPTIMIZER_NAME"],
    "learning_rate": training_config["INITIAL_LEARNING_RATE"],
    "loss_function": training_config["LOSS_FUNCTION_NAME"],
    "epochs": training_config["NB_EPOCHS"],
    "batch_size": training_config["BATCH_SIZE"],
    # NOTE(review): this list is fixed and does not follow LOSS_FUNCTION_NAME;
    # confirm it matches the custom objects actually needed to reload the model.
    "custom_objects": [
        "dice_coef",
        "focal_tversky_loss"
        ],
    "input_features": dataset_config["INPUT_FEATURES"]
    }
wandb.init(project="fire-model", config=config)
run_name = wandb.run.name
callbacks.append(WandbCallback())

# Learning-rate schedule used by the optimizer cell. Keras schedules are evaluated
# per optimizer step (i.e. per batch), so with staircase=True the rate is multiplied
# by 0.96 every 15 batches, not every 15 epochs.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    training_config["INITIAL_LEARNING_RATE"], decay_steps=15, decay_rate=0.96, staircase=True
    )

# Define checkpoints callback: keep only the best weights (by ModelCheckpoint's
# default monitor) in OUTPUT_DIR/model/fire_model_<run name>.h5.
checkpoint_path = os.path.join(common_config["OUTPUT_DIR"], "model", "fire_model_{}.h5".format(run_name))
# The "model" sub-directory is not created elsewhere, and HDF5 saving fails when
# the target directory is missing, so create it up front.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, save_weights_only=True, save_best_only=True
    )
callbacks.append(checkpoint_cb)

In [None]:
# Define optimizer. Every optimizer except Nadam uses the decaying `lr_schedule`.
# Nadam gets the configured initial learning rate as a constant instead of the
# schedule; previously it silently fell back to the Keras default even though
# INITIAL_LEARNING_RATE is logged as the run's learning rate.
# NOTE(review): older tf.keras Nadam versions reject LearningRateSchedule objects,
# which is presumably why the schedule is not passed here; confirm for the installed TF.
optimizer_name = training_config["OPTIMIZER_NAME"]
if optimizer_name == "adam":
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
elif optimizer_name == "nadam":
    optimizer = tf.keras.optimizers.Nadam(learning_rate=training_config["INITIAL_LEARNING_RATE"])
elif optimizer_name == "rmsprop":
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr_schedule)
elif optimizer_name == "sgd":
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
else:
    sys.exit("Wrong optimizer name provided")

In [None]:
# Compile the model with the configured optimizer and loss, tracking Dice,
# PR-AUC, precision and recall; then train on the training split while
# validating on the evaluation split after each epoch.
eval_metrics = [
    dice_coef,
    tf.keras.metrics.AUC(curve="PR"),
    tf.keras.metrics.Precision(),
    tf.keras.metrics.Recall(),
]
model.compile(optimizer=optimizer, loss=loss_function, metrics=eval_metrics)

history = model.fit(
    train_dataset,
    validation_data=eval_dataset,
    epochs=training_config["NB_EPOCHS"],
    callbacks=callbacks,
)