# Train a RetinaNet to Detect ElectroMagnetic Signals

**Author:** [lukewood](https://lukewood.xyz), Kevin Anderson, Peter Gerstoft<br>
**Date created:** 2022/08/16<br>
**Last modified:** 2022/08/16<br>
**Description:** Train ...

## Overview

TODO provide an overview of em-loader.

In [1]:
import sys

import keras_cv
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from keras_cv import bounding_box
from tensorflow import keras
from tensorflow.keras import callbacks as callbacks_lib
from tensorflow.keras import optimizers
from luketils import artifacts
import em_loader
import wandb
from luketils import visualization

batch_size = 8
epochs = 1

ground_truth_mapping = [
    "Ground Truth",
]
ground_truth_mapping = dict(zip(range(len(ground_truth_mapping)), ground_truth_mapping))

prediction_mapping = [
    "Prediction",
]
prediction_mapping = dict(zip(range(len(prediction_mapping)), prediction_mapping))

In [2]:
checkpoint_path = 'weights/'

In [3]:
import os
artifacts_dir = 'artifacts/naive/'
artifacts.set_base(artifacts_dir)
os.makedirs(artifacts_dir, exist_ok=True)

## Data loading

The `em_loader.load()` call below returns data in the format
`{"images": images, "bounding_boxes": bounding_boxes}`.  This format is supported in all
KerasCV preprocessing components.
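
The loader in this notebook is asked for boxes in the `"xywh"` format. As a minimal sketch, assuming the first four entries of each box are `[x, y, width, height]` with `(x, y)` the top-left corner in pixels, the conversion to and from corner coordinates looks like this. The helper names are hypothetical and not part of KerasCV or `em_loader`.

```python
def xywh_to_xyxy(box):
    """Convert [x, y, width, height] to [left, top, right, bottom]."""
    x, y, w, h = box
    return [x, y, x + w, y + h]


def xyxy_to_xywh(box):
    """Convert [left, top, right, bottom] back to [x, y, width, height]."""
    left, top, right, bottom = box
    return [left, top, right - left, bottom - top]


# A wide, short box: 30 pixels wide and 5 pixels tall.
box = [10.0, 20.0, 30.0, 5.0]
corners = xywh_to_xyxy(box)
print(corners)  # [10.0, 20.0, 40.0, 25.0]
```

Keeping the format string consistent everywhere (loader, model, metrics, plotting) matters, because a box read in the wrong convention is still a valid list of four numbers and fails silently.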

Let's load some data and verify that it looks as we expect.

In [None]:
import functools

dataset, dataset_info = em_loader.load(
    split="train",
    bounding_box_format="xywh",
    batch_size=9,
    version=2,
)


example = next(iter(dataset))
images, boxes = example["images"], example["bounding_boxes"]

plot_fn = functools.partial(visualization.plot_bounding_box_gallery, 
    images=images,
    value_range=(0, 255),
    bounding_box_format="xywh",
    y_true=boxes,
    scale=2,
    rows=3,
    cols=3,
    thickness=2,
    font_scale=1,
    class_mapping=ground_truth_mapping,
)
    
plot_fn(path=f"{artifacts_dir}/ground-truth.png")
plot_fn(show=True)

Looks like everything is structured as expected.  Now we can move on to constructing our
data pipeline.

In [None]:
# em_loader yields batched dicts; unpackage_dict (below) converts each batch
# to an (images, bounding_boxes) tuple.  bounding_boxes are ragged.
train_ds, train_dataset_info = em_loader.load(
    bounding_box_format="xywh", split="train", batch_size=batch_size, version=2
)
val_ds, val_dataset_info = em_loader.load(
    bounding_box_format="xywh", split="val", batch_size=batch_size, version=2
)


def unpackage_dict(inputs):
    return inputs["images"], inputs["bounding_boxes"]


train_ds = train_ds.map(unpackage_dict, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.map(unpackage_dict, num_parallel_calls=tf.data.AUTOTUNE)

train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.prefetch(tf.data.AUTOTUNE)
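
The comment in the cell above notes that the bounding boxes are ragged: each image can contain a different number of boxes, so a batch has no fixed second dimension. The sketch below illustrates the idea in plain Python by padding a ragged batch to a dense one. The `-1.0` sentinel and the `pad_boxes` helper are assumptions for illustration, not necessarily what `em_loader` or KerasCV does internally.

```python
def pad_boxes(batch, pad_value=-1.0):
    """Pad a ragged batch so every image has the same number of boxes."""
    max_boxes = max((len(boxes) for boxes in batch), default=0)
    # Length of a single box; fall back to 4 if the batch has no boxes at all.
    box_len = next((len(b) for boxes in batch for b in boxes), 4)
    padded = []
    for boxes in batch:
        missing = max_boxes - len(boxes)
        filler = [[pad_value] * box_len for _ in range(missing)]
        padded.append([list(b) for b in boxes] + filler)
    return padded


# Three images with 2, 0, and 1 boxes respectively (xywh).
ragged_batch = [
    [[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 2.0, 8.0]],
    [],
    [[1.0, 2.0, 3.0, 4.0]],
]
dense_batch = pad_boxes(ragged_batch)
print([len(boxes) for boxes in dense_batch])  # [2, 2, 2]
```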

Our data pipeline is now complete.  We can now move on to model creation and training.

## Model creation

We'll use the KerasCV API to construct a RetinaNet model.  In this tutorial we use
a ResNet50 backbone with pretrained ImageNet weights.  Fine-tuning typically freezes the
backbone before training; note that the cell below leaves the backbone trainable.
Because `include_rescaling=True` is set, inputs to the model are expected to be in the
range `[0, 255]`.

In [None]:
model = keras_cv.models.RetinaNet(
    classes=1,
    bounding_box_format="xywh",
    backbone="resnet50",
    backbone_weights="imagenet",
    include_rescaling=True,
    evaluate_train_time_metrics=False,
)

In [None]:
optimizer = tf.optimizers.SGD(global_clipnorm=10.0)
metrics = [
    keras_cv.metrics.COCOMeanAveragePrecision(
        class_ids=range(1),
        bounding_box_format="xywh",
        name="MaP",
    ),
    keras_cv.metrics.COCORecall(
        class_ids=range(1),
        bounding_box_format="xywh",
        max_detections=100,
        name="Recall",
    ),
]

model.compile(
    classification_loss=keras_cv.losses.FocalLoss(from_logits=True, reduction="none"),
    box_loss=keras_cv.losses.SmoothL1Loss(l1_cutoff=1.0, reduction="none"),
    optimizer=optimizer,
    metrics=metrics,
)
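
To make the two losses passed to `compile()` more concrete, here is a minimal per-element sketch of their standard formulations: binary focal loss as described in the RetinaNet paper, and smooth L1 with a cutoff of 1.0. It works on a probability rather than logits (the notebook uses `from_logits=True`), and `alpha=0.25`, `gamma=2.0` are the paper's values, assumed here rather than read from the KerasCV source.

```python
import math


def focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Binary focal loss for one prediction.

    p: predicted probability of the positive class (after the sigmoid).
    y: 1 for a positive anchor, 0 for background.
    """
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    # (1 - p_t) ** gamma shrinks the loss of well-classified examples.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def smooth_l1(pred, target):
    """Smooth L1 with a cutoff of 1.0: quadratic near zero, linear beyond."""
    d = abs(pred - target)
    if d < 1.0:
        return 0.5 * d * d
    return d - 0.5


print(focal_loss(0.9, 1) < focal_loss(0.1, 1))  # True
print(smooth_l1(0.5, 0.0))  # 0.125
print(smooth_l1(3.0, 0.0))  # 2.5
```

Because confidently classified anchors contribute almost nothing, focal loss values can be very small in absolute terms even when detection quality is poor, which is worth keeping in mind when reading the loss curves.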

All that is left to do is construct some callbacks:

In [None]:
callbacks = [
    callbacks_lib.TensorBoard(log_dir="logs"),
    callbacks_lib.ReduceLROnPlateau(patience=7),
    callbacks_lib.ModelCheckpoint(checkpoint_path, save_weights_only=True),
]

And run `model.fit()`!

In [None]:
history = model.fit(
    train_ds,
    validation_data=val_ds.take(20),
    epochs=50,
    callbacks=callbacks,
)
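
The `ReduceLROnPlateau(patience=7)` callback used above lowers the learning rate when the monitored metric stops improving. The simulation below is a simplified sketch of that behavior, assuming the Keras defaults (monitoring `val_loss`, `factor=0.1`, `min_delta=1e-4`) and the default SGD learning rate of 0.01. It ignores cooldown and the minimum learning rate, and the loss values are made up for illustration.

```python
def simulate_reduce_lr_on_plateau(
    val_losses, initial_lr=0.01, factor=0.1, patience=7, min_delta=1e-4
):
    """Return the learning rate in effect at the start of each epoch."""
    lr = initial_lr
    best = float("inf")
    wait = 0  # epochs since the last meaningful improvement
    schedule = []
    for loss in val_losses:
        schedule.append(lr)
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr *= factor  # plateau detected: shrink the learning rate
                wait = 0
    return schedule


# Made-up validation losses: two improving epochs, then a plateau.
losses = [1.0, 0.9] + [0.9] * 8
schedule = simulate_reduce_lr_on_plateau(losses)
print(schedule[8] > schedule[9])  # True: the rate drops after 7 stalled epochs
```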

## Model Evaluation

First, let's plot our learning curves.

In [None]:
metrics = history.history

In [None]:
', '.join(metrics.keys())

In [None]:
metrics_to_plot = {
    'Train': metrics['loss'],
    'Validation': metrics['val_loss'],
}

plot_fn = functools.partial(visualization.line_plot, data=metrics_to_plot, title='Loss', xlabel='Epochs', ylabel='Loss', transparent=True)
plot_fn(path=f'{artifacts_dir}/loss.png')
plot_fn(show=True)

In [None]:
metrics_to_plot = {
    'Mean Average Precision': metrics['val_MaP'],
}

plot_fn = functools.partial(visualization.line_plot, data=metrics_to_plot, title='Mean Average Precision', xlabel='Epochs', ylabel='Mean Average Precision')
plot_fn(path=f'{artifacts_dir}/MaP.png')
plot_fn(show=True)

In [None]:
metrics_to_plot = {
    'Recall': metrics['val_Recall'],
}

plot_fn = functools.partial(visualization.line_plot, data=metrics_to_plot, title='Recall', xlabel='Epochs', ylabel='Recall')
plot_fn(path=f'{artifacts_dir}/Recall.png')
plot_fn(show=True)

In [None]:
metrics_to_plot = {
    'Mean Average Precision': metrics['val_MaP'],
    'Recall': metrics['val_Recall'],
}

plot_fn = functools.partial(visualization.line_plot, data=metrics_to_plot, title='COCO Metrics', xlabel='Epochs')
plot_fn(path=f'{artifacts_dir}/COCO_Metrics.png')
plot_fn(show=True)

In [None]:
metrics_to_plot = {
    'Train': metrics['box_loss'],
    'Validation': metrics['val_box_loss'],
}

plot_fn = functools.partial(visualization.line_plot, data=metrics_to_plot, title='Box Loss', xlabel='Epochs', ylabel='Box Loss')
plot_fn(path=f'{artifacts_dir}/box_loss.png')
plot_fn(show=True)

In [None]:
metrics_to_plot = {
    'Train': metrics['classification_loss'],
    'Validation': metrics['val_classification_loss'],
}

plot_fn = functools.partial(visualization.line_plot, data=metrics_to_plot, title='Classification Loss', xlabel='Epochs', ylabel='Classification Loss')
plot_fn(path=f'{artifacts_dir}/classification_loss.png')
plot_fn(show=True)

In [None]:
metrics_to_plot = {
    'Train Box Loss': metrics['box_loss'],
    'Validation Box Loss': metrics['val_box_loss'],
    'Train Classification Loss': metrics['classification_loss'],
    'Validation Classification Loss': metrics['val_classification_loss'],
}

plot_fn = functools.partial(visualization.line_plot, data=metrics_to_plot, title='All Losses', xlabel='Epochs', ylabel='Loss')
plot_fn(path=f'{artifacts_dir}/all_losses.png')
plot_fn(show=True)
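
Besides plotting, the same history dictionary can be inspected numerically. The sketch below computes the per-epoch gap between validation and training loss and the epoch with the lowest validation loss. Both helpers are hypothetical, and the numbers are made up for illustration; they are not results from this notebook.

```python
def generalization_gap(history):
    """Validation loss minus training loss, per epoch."""
    return [val - train for train, val in zip(history["loss"], history["val_loss"])]


def best_epoch(history, key="val_loss"):
    """Index of the epoch with the lowest value for `key`."""
    values = history[key]
    return min(range(len(values)), key=values.__getitem__)


# Made-up history in the same shape as `history.history`.
fake_history = {
    "loss": [1.0, 0.5, 0.2, 0.1],
    "val_loss": [1.1, 0.7, 0.6, 0.8],
}
gaps = generalization_gap(fake_history)
print(best_epoch(fake_history))  # 2
```

A gap that keeps growing while the training loss keeps falling is one common sign of overfitting; here it could be checked with `generalization_gap(metrics)` while `metrics` still holds `history.history`.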

In [None]:
model.load_weights(checkpoint_path)

In [None]:
metrics = model.evaluate(val_ds, return_dict=True)

In [None]:
print("FINAL METRICS:", metrics)

Let's write our metrics to our artifact directory.

In [None]:
os.makedirs(f'{artifacts_dir}/metrics/', exist_ok=True)
for metric in metrics:
    with open(f"{artifacts_dir}/metrics/{metric}.txt", "w") as f:
        f.write(str(round(metrics[metric], 3)))

Our metrics tell a mixed story.

The loss is very low, but our MaP and Recall are not that great.  For perspective, in PascalVOC you can expect a MaP of 0.35 in a state-of-the-art workflow, yet there the loss converges at around 2.0 instead of 0.001.  So either we have an overfitting problem, or something else is going wrong.
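
Both COCO metrics decide whether a prediction matches a ground-truth box by their intersection over union (IoU), with mean average precision averaged over IoU thresholds from 0.5 to 0.95. The sketch below computes IoU for two `xywh` boxes in plain Python; `iou_xywh` is a hypothetical helper, not the KerasCV implementation.

```python
def iou_xywh(a, b):
    """Intersection over union of two [x, y, width, height] boxes."""
    ax, ay, aw, ah = a
    bx, by, bw, bh = b
    inter_w = max(0.0, min(ax + aw, bx + bw) - max(ax, bx))
    inter_h = max(0.0, min(ay + ah, by + bh) - max(ay, by))
    inter = inter_w * inter_h
    union = aw * ah + bw * bh - inter
    return inter / union if union > 0 else 0.0


# A 100x2 ground-truth box and a prediction shifted down by a single pixel.
ground_truth = [0.0, 0.0, 100.0, 2.0]
prediction = [0.0, 1.0, 100.0, 2.0]
print(round(iou_xywh(ground_truth, prediction), 3))  # 0.333
```

For very thin boxes, a one-pixel offset is enough to push IoU below the 0.5 threshold, so a prediction that looks close can still count as a miss.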

For further investigation, let's visualize our detections.

In [None]:
import functools


def visualize_detections(model, split="train"):
    # Load a single batch of 9 examples, enough to fill a 3x3 gallery.
    ds, _ = em_loader.load(
        bounding_box_format="xywh", split=split, batch_size=9
    )
    ds = ds.map(unpackage_dict, num_parallel_calls=tf.data.AUTOTUNE)
    images, y_true = next(iter(ds.take(1)))
    y_pred = model.predict(images)

    plot_fn = functools.partial(
        visualization.plot_bounding_box_gallery,
        images,
        value_range=(0, 255),
        bounding_box_format="xywh",
        y_true=y_true,
        y_pred=y_pred,
        scale=2,
        rows=3,
        cols=3,
        thickness=2,
        font_scale=1,
        legend=True,
    )
    # Show the gallery inline, then save a copy to the artifacts directory.
    plot_fn(show=True)
    plot_fn(path=f"{artifacts_dir}/{split}.png")

In [None]:
visualize_detections(model, split="train")

In [None]:
visualize_detections(model, split="val")

Uh oh!  Looks like something is going wrong!  It seems that we are not catching any of the small, skinny, or oblong bounding boxes.
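
One way to put a number on this visual impression is to measure how elongated each ground-truth box is and count how many of the thin ones are missed. The sketch below only does the first step. The helpers and the threshold of 4.0 are arbitrary assumptions for illustration, not values used elsewhere in this notebook.

```python
def elongation(box):
    """Ratio of the longer side to the shorter side of an xywh box (>= 1)."""
    _, _, w, h = box
    if min(w, h) <= 0:
        return float("inf")  # degenerate box with no area
    return max(w, h) / min(w, h)


def split_by_shape(boxes, max_elongation=4.0):
    """Separate roughly square boxes from skinny or oblong ones."""
    regular, elongated = [], []
    for box in boxes:
        target = elongated if elongation(box) > max_elongation else regular
        target.append(box)
    return regular, elongated


# Made-up boxes: one square, one wide and flat, one tall and narrow.
boxes = [[0, 0, 10, 10], [0, 0, 100, 2], [5, 5, 3, 60]]
regular, elongated = split_by_shape(boxes)
print(len(regular), len(elongated))  # 1 2
```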

## Conclusion

While the loss converges, we are not able to effectively detect signals in our spectrograms.
That concludes part one of our notebooks.  In part 2, we will dive into why we can't
detect the signals and begin solving the issue.