<a href="https://colab.research.google.com/github/MosheDorZarka/YoloV3/blob/main/YOLOv3_SelfAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


### In this project, I aim to enhance the YoloV3 model by integrating a novel Self-Attention (SA) mechanism. My primary objective is to investigate whether this augmented YoloV3 model outperforms its traditional counterparts in terms of accuracy and efficiency. This endeavor represents an exploration in the field of object detection, leveraging advanced techniques to potentially enhance the performance metrics.

In [20]:
!pip install -q --upgrade keras-cv
!pip install -q --upgrade keras

In [None]:
!pip show keras
!pip show keras-cv

In [22]:
import os

# Keras picks its backend once, when the package is first imported, so the
# environment variable has to be set before `import keras` / `import keras_cv`.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import keras_cv
import tensorflow as tf
import numpy as np

from keras.layers import (
    Layer,
    Conv2D,
    Embedding,
    BatchNormalization,
    LeakyReLU,
    MultiHeadAttention,
    Dense,
    Dropout,
    LayerNormalization,
    Input
)
# NOTE: `max` and `sum` imported here shadow the Python builtins for the rest
# of the notebook; every later `max(...)` / `sum(...)` call is the keras.ops one.
from keras.ops import (
    expand_dims,
    arange,
    reshape,
    shape,
    scatter_update,
    cast,
    minimum,
    argmax,
    stack,
    concatenate,
    zeros,
    ones_like,
    zeros_like,
    tile,
    sigmoid,
    exp,
    log,
    max,
    where,
    square,
    sqrt,
    sum,
    average,
    binary_crossentropy,
    sparse_categorical_crossentropy,
    meshgrid,
    equal,
    one_hot,
    softmax
)
from keras_cv.bounding_box import (
    compute_iou,
    convert_format,
    to_dense
)
from keras_cv.visualization import (
    plot_bounding_box_gallery
)
from keras_cv.layers import (
    RandomFlip,
    JitteredResize,
    MultiClassNonMaxSuppression,
    Resizing
)
from keras.regularizers import l2
from keras.activations import gelu
from keras.applications import ResNet50
from keras.applications.resnet import preprocess_input
from keras.callbacks import ReduceLROnPlateau
from keras import Model
from tensorflow_datasets import load
from tensorflow import data as tf_data

# Config

In [23]:
# Input images are resized to IMAGE_DIMENSION x IMAGE_DIMENSION pixels.
IMAGE_DIMENSION = 416

# Anchor (width, height) pairs in pixels, three per detection scale.
# Row i is used together with GRID_SIZES[i].
ANCHORS = [
    [[116, 90], [156, 198], [373, 326]],
    [[30, 61], [62, 45], [59, 119]],
    [[10, 13], [16, 30], [33, 23]],
]

# Grid resolution of each detection scale: 416 / 32, 416 / 16, 416 / 8.
GRID_SIZES = [IMAGE_DIMENSION // stride for stride in (32, 16, 8)]

NUM_CLASSES = 20

# Display names indexed by class id. The list holds 21 entries: the trailing
# "Total" sits at index 20, one past the NUM_CLASSES label range.
CLASSES = [
    "Aeroplane", "Bicycle", "Bird", "Boat", "Bottle",
    "Bus", "Car", "Cat", "Chair", "Cow",
    "Dining Table", "Dog", "Horse", "Motorbike", "Person",
    "Potted Plant", "Sheep", "Sofa", "Train", "Tvmonitor",
    "Total",
]

BATCH_SIZE = 8

# Utils

In [24]:
def bbox_to_output_loop(classes, boxes):
  """
  Encode ground-truth boxes as per-scale training targets, one scale at a time.

  Reference (loop) implementation of bbox_to_output_vec. For every scale each
  box is assigned to the grid cell containing its centre and to the anchor of
  that scale whose width/height overlaps the box best.

  Args:
  - classes: A tensor of shape (N, )
  - boxes: A tensor of shape (N, 4) in rel_center_xywh format

  Returns:
  - A list with one tensor per scale, each of shape
    (scale_grid_size, scale_grid_size, 3, 4 + 1 + 1). The last axis holds
    (cell-relative x, cell-relative y, w, h, objectness, class id).
  """
  targets = []

  for scale_idx, scale_anchors in enumerate(ANCHORS):
    grid_size = GRID_SIZES[scale_idx]

    # integer index of the cell holding each box centre, shape: (N, )
    cell_x = cast(boxes[..., 0] * grid_size, dtype="int32")
    cell_y = cast(boxes[..., 1] * grid_size, dtype="int32")

    # anchors expressed relative to the image size, shape: (3, 2)
    anchors_rel = np.array(scale_anchors, dtype="float32") / IMAGE_DIMENSION
    anchor_w = expand_dims(anchors_rel[..., 0], axis=0) # shape: (1, 3)
    anchor_h = expand_dims(anchors_rel[..., 1], axis=0) # shape: (1, 3)

    # width/height-only IoU between every box and every anchor
    box_w = boxes[..., 2:3] # shape: (N, 1)
    box_h = boxes[..., 3:4] # shape: (N, 1)
    overlap = minimum(box_w, anchor_w) * minimum(box_h, anchor_h) # shape: (N, 3)
    combined = (box_w * box_h + anchor_w * anchor_h) - overlap # shape: (N, 3)
    anchor_idx = argmax(overlap / combined, axis=-1) # shape: (N, )

    # centre offset measured from the top-left corner of the owning cell
    cell_xy = cast(stack([cell_x, cell_y], axis=-1), dtype="float32")
    offset_xy = boxes[..., 0:2] * grid_size - cell_xy # shape: (N, 2)
    box_target = concatenate([offset_xy, boxes[..., 2:4]], axis=-1) # shape: (N, 4)

    class_col = expand_dims(classes, axis=-1)
    rows = concatenate([box_target, ones_like(class_col), class_col], axis=-1) # shape: (N, 4 + 1 + 1)

    # targets are addressed as (row, column, anchor)
    scatter_idx = stack([cell_y, cell_x, anchor_idx], axis=-1) # shape: (N, 3)
    empty_grid = zeros(shape=(grid_size, grid_size, 3, 4 + 1 + 1))

    targets.append(scatter_update(inputs=empty_grid, indices=scatter_idx, updates=rows))

  return targets

In [25]:
def bbox_to_output_vec(classes, boxes):
  """
  Encode ground-truth boxes as training targets for all three scales at once.

  Vectorised counterpart of bbox_to_output_loop: the per-scale quantities are
  computed along an extra "scale" axis instead of inside a Python loop.

  Args:
  - classes: A tensor of shape (N, )
  - boxes: A tensor of shape (N, 4) in rel_center_xywh format

  Returns:
  - scale_1, scale_2, scale_3: A tuple of tensors in the shape of
    (scale_grid_size, scale_grid_size, 3, 4 + 1 + 1)
  """
  grids = np.array([GRID_SIZES], dtype="float32") # shape: (1, 3)

  # owning cell of every box on every scale, shape: (N, 3)
  cell_x = cast(boxes[..., 0:1] * grids, dtype="int32")
  cell_y = cast(boxes[..., 1:2] * grids, dtype="int32")

  # anchors relative to the image size, shape: (3, 3, 2) = (scale, anchor, wh)
  anchors_rel = np.array(ANCHORS, dtype="float32") / IMAGE_DIMENSION
  anchor_w = expand_dims(anchors_rel[..., 0], axis=0) # shape: (1, 3, 3)
  anchor_h = expand_dims(anchors_rel[..., 1], axis=0) # shape: (1, 3, 3)

  # width/height-only IoU of each box against each anchor of each scale
  box_w = expand_dims(boxes[..., 2:3], axis=1) # shape: (N, 1, 1)
  box_h = expand_dims(boxes[..., 3:4], axis=1) # shape: (N, 1, 1)
  overlap = minimum(box_w, anchor_w) * minimum(box_h, anchor_h) # shape: (N, 3, 3)
  combined = (box_w * box_h + anchor_w * anchor_h) - overlap # shape: (N, 3, 3)
  anchor_idx = argmax(overlap / combined, axis=-1) # shape: (N, 3)

  # centre offsets relative to the owning cell, per scale
  cell_xy = cast(stack([cell_x, cell_y], axis=-1), dtype="float32") # shape: (N, 3, 2)
  scaled_xy = expand_dims(boxes[..., 0:2], axis=1) * expand_dims(grids, axis=-1) # shape: (N, 3, 2)
  offset_xy = scaled_xy - cell_xy
  wh_per_scale = tile(expand_dims(boxes[..., 2:4], axis=1), repeats=(1, 3, 1)) # shape: (N, 3, 2)
  box_target = concatenate([offset_xy, wh_per_scale], axis=-1) # shape: (N, 3, 4)

  class_col = tile(reshape(classes, newshape=(-1, 1, 1)), repeats=(1, 3, 1)) # shape: (N, 3, 1)
  rows = concatenate([box_target, ones_like(class_col), class_col], axis=-1) # shape (N, 3, 4 + 1 + 1)

  # targets are addressed as (row, column, anchor)
  scatter_idx = stack([cell_y, cell_x, anchor_idx], axis=-1) # shape: (N, 3, 3)

  per_scale = [
      scatter_update(
          inputs=zeros(shape=(grid_size, grid_size, 3, 4 + 1 + 1)),
          indices=scatter_idx[..., k, :],
          updates=rows[..., k, :],
      )
      for k, grid_size in enumerate(GRID_SIZES)
  ]

  return tuple(per_scale)

In [26]:
def unpackage_raw_sample(sample):
  """
  Convert a raw tfds sample into the dictionary layout used by KerasCV layers:
  {"images": image,
   "bounding_boxes": {"classes": classes, "boxes": boxes}}.
  The sample's boxes are assumed to be in rel_yxyx format and are converted
  to center_xywh.

  Args:
  - sample: a raw tfds sample (FeaturesDict) with "image" and "objects" entries.

  Returns:
  - x: the sample in KerasCV format.
  """
  image = sample["image"]
  objects = sample["objects"]

  # the image is passed along so relative coordinates can be scaled to pixels
  boxes_xywh = convert_format(
      boxes=objects["bbox"],
      images=image,
      source="rel_yxyx",
      target="center_xywh",
  )

  return {
      "images": image,
      "bounding_boxes": {"classes": objects["label"], "boxes": boxes_xywh},
  }

In [27]:
def preprocessing_wrapper(augmenters):
  """
  Build a tf.data map function that applies `augmenters` in order and then
  produces the (image, labels) pair consumed during training.

  Args:
  - augmenters: an iterable of callables taking and returning a KerasCV sample.

  Returns:
  - preprocessing: the map function.
  """
  def preprocessing(x):
    """
    Turn a KerasCV-format sample into a model input and its labels.

    Args:
    - x: a sample in KerasCV format.

    Returns:
    - image, labels: A tuple of the preprocessed image and (scale_1, scale_2, scale_3) label.
    """
    for augment in augmenters:
      x = augment(x)

    # ResNet50 input normalisation
    image = preprocess_input(x["images"])

    # pixel coordinates -> coordinates relative to the image size
    dense_boxes = to_dense(x["bounding_boxes"])
    dense_boxes["boxes"] = dense_boxes["boxes"] / IMAGE_DIMENSION
    x["bounding_boxes"] = dense_boxes

    labels = bbox_to_output_vec(**dense_boxes)

    return image, labels

  return preprocessing

In [28]:
def output_to_bbox(scale, anchors):
  """
  Extracting from a single model output the boxes and class predictions.
  This function is particularly helpful before applying NonMaxSuppression layer.

  The first spatial axis of `scale` is treated as the grid row (y) and the
  second as the grid column (x), matching the (grid_y, grid_x, anchor) layout
  written by bbox_to_output_vec / bbox_to_output_loop.

  Args:
  - scale: A tensor of shape (Batch, grid_rows, grid_cols, 3 * (4 + 1 + NUM_CLASSES))
  - anchors: anchor (width, height) pairs of this scale; must broadcast against
    (..., 3, 2). Callers in this file pass ANCHORS[i] / IMAGE_DIMENSION.

  Returns:
  - box_prediction: A tensor of shape (Batch, N, 4), center_xywh scaled by
    IMAGE_DIMENSION, with N = grid_rows * grid_cols * 3
  - class_prediction: A tensor of shape (Batch, N, NUM_CLASSES) holding
    objectness * softmax(class logits)
  """
  batch_size, grid_rows, grid_cols, _ = shape(scale)

  # meshgrid uses "xy" indexing by default: passing (x values, y values) yields
  # arrays of shape (len(y), len(x)) = (grid_rows, grid_cols), which is the
  # layout of `scale`. Entry [row, col] of cell_xy is (col, row) = (x, y).
  col_offsets, row_offsets = meshgrid(
      arange(0, grid_cols, dtype="float32"),
      arange(0, grid_rows, dtype="float32"),
  )
  cell_xy = stack([col_offsets, row_offsets], axis=-1) # shape: (grid_rows, grid_cols, 2)
  cell_xy = expand_dims(cell_xy, axis=-2) # shape: (grid_rows, grid_cols, 1, 2)

  # split the channel axis into (anchor, 4 + 1 + NUM_CLASSES)
  scale = reshape(scale, newshape=(batch_size, grid_rows, grid_cols, 3, -1))
  scale_xy = sigmoid(scale[..., 0:2])
  scale_wh = exp(scale[..., 2:4]) * anchors
  scale_confidence = sigmoid(scale[..., 4:5])
  scale_cls = softmax(scale[..., 5:])

  # cell-relative centre -> image-relative centre; x is divided by the number
  # of columns and y by the number of rows
  boxes_xy = (scale_xy + cell_xy) / (grid_cols, grid_rows)
  boxes_wh = scale_wh
  boxes_xywh = concatenate([boxes_xy, boxes_wh], axis=-1) * IMAGE_DIMENSION

  classes = scale_confidence * scale_cls

  box_prediction = reshape(
      boxes_xywh,
      newshape=(batch_size, grid_rows * grid_cols * 3, -1)
      )
  class_prediction = reshape(
      classes,
      newshape=(batch_size, grid_rows * grid_cols * 3, -1)
      )

  return box_prediction, class_prediction

In [29]:
def visualize_progress(model, confidence_threshold=0.7, iou_threshold=0.1, skip=10):
  """
  Plot the ground truth and the model's predictions for one evaluation sample,
  to follow the model's progress during training.

  The sample is read from the module-level `eval_ds`, which must yield raw
  tfds samples (entries "image" and "objects").

  Args:
  - model: the model being trained
  - confidence_threshold: confidence threshold passed to non-max suppression
  - iou_threshold: IoU threshold passed to non-max suppression
  - skip: how many samples of eval_ds to skip before picking one

  """
  raw_sample = next(iter(eval_ds.skip(skip).take(1)))
  image = raw_sample["image"]
  objects = raw_sample["objects"]

  gt_boxes = convert_format(
      boxes=objects["bbox"],
      source="rel_yxyx",
      target="center_xywh",
      images=image,
  )

  # the resize layer works on batches, so add a leading batch axis of 1
  batched_sample = {
      "images": expand_dims(image, axis=0),
      "bounding_boxes": {
          "classes": expand_dims(objects["label"], axis=0),
          "boxes": expand_dims(gt_boxes, axis=0),
      },
  }
  resizer = Resizing(
      IMAGE_DIMENSION,
      IMAGE_DIMENSION,
      bounding_box_format="center_xywh",
      pad_to_aspect_ratio=True,
  )
  sample_resized = resizer(batched_sample)

  model_input = preprocess_input(sample_resized["images"])
  scale_1, scale_2, scale_3 = model(model_input)

  # decode each scale with its own anchors (relative to the image size)
  decoded = [
      output_to_bbox(raw_output, anchors=np.array(scale_anchors) / IMAGE_DIMENSION)
      for raw_output, scale_anchors in zip((scale_1, scale_2, scale_3), ANCHORS)
  ]
  box_prediction = concatenate([boxes for boxes, _ in decoded], axis=1)
  class_prediction = concatenate([scores for _, scores in decoded], axis=1)

  suppressor = MultiClassNonMaxSuppression(
      bounding_box_format="center_xywh",
      from_logits=False,
      confidence_threshold=confidence_threshold,
      iou_threshold=iou_threshold,
  )
  bounding_boxes_pred = suppressor(box_prediction, class_prediction)

  gallery_options = dict(
      value_range=(0, 255),
      rows=1,
      cols=1,
      scale=5,
      bounding_box_format="center_xywh",
      font_scale=0.5,
      class_mapping=dict(enumerate(CLASSES)),
  )

  # first figure: ground truth, second figure: predictions
  plot_bounding_box_gallery(
      sample_resized["images"],
      y_true=sample_resized["bounding_boxes"],
      **gallery_options,
  )
  plot_bounding_box_gallery(
      sample_resized["images"],
      y_pred=bounding_boxes_pred,
      **gallery_options,
  )

# Dataset

In [30]:
# Pascal VOC 2007 train / validation splits as raw tfds samples.
train_ds, eval_ds = load(
    name="voc/2007",
    split=["train", "validation"],
    with_info=False,
    shuffle_files=True,
)

In [31]:
# Raw tfds samples -> KerasCV {"images", "bounding_boxes"} dictionaries.
train_ds = train_ds.map(unpackage_raw_sample, num_parallel_calls=tf_data.AUTOTUNE)
eval_ds = eval_ds.map(unpackage_raw_sample, num_parallel_calls=tf_data.AUTOTUNE)

In [32]:
# Training-time augmentation, applied in list order by preprocessing_wrapper:
# a horizontal flip followed by a jittered resize to the model input size.
# Both layers are told the boxes are in center_xywh so they transform them too.
train_ds_augmenters = [
    RandomFlip(
        mode="horizontal",
        bounding_box_format="center_xywh"
        ),
    JitteredResize(
        target_size=(IMAGE_DIMENSION, IMAGE_DIMENSION),
        scale_factor=(0.65, 1.25),
        bounding_box_format="center_xywh"
        )
]

In [33]:
# Evaluation uses no random augmentation: only a resize to the model input
# size with pad_to_aspect_ratio=True, applied to image and boxes alike.
eval_ds_augmenters = [Resizing(IMAGE_DIMENSION,
                               IMAGE_DIMENSION,
                               bounding_box_format="center_xywh",
                               pad_to_aspect_ratio=True)]

In [34]:
# Turn KerasCV samples into (image, (scale_1, scale_2, scale_3)) training pairs,
# each split with its own augmenter list.
train_ds = train_ds.map(preprocessing_wrapper(train_ds_augmenters), num_parallel_calls=tf_data.AUTOTUNE)
eval_ds = eval_ds.map(preprocessing_wrapper(eval_ds_augmenters), num_parallel_calls=tf_data.AUTOTUNE)

In [35]:
# Only the training split is shuffled; the shuffle buffer holds
# BATCH_SIZE * 4 samples. Both splits are batched and prefetched.
train_ds = train_ds.shuffle(BATCH_SIZE * 4).batch(BATCH_SIZE).prefetch(tf_data.AUTOTUNE)
eval_ds = eval_ds.batch(BATCH_SIZE).prefetch(tf_data.AUTOTUNE)

# Model Building

In [36]:
class YoloConvBlock(Layer):
  """
  Helper convolution block used by the per-scale output layer.

  Pipeline: conv (no bias) -> batch norm -> LeakyReLU(0.1) -> 1x1 conv that
  projects to num_anchors * (num_classes + 5) channels.
  """

  def __init__(self, filters=256, kernel_size=3, num_classes=20, num_anchors=3, **kwargs):
    super().__init__(**kwargs)
    head_channels = num_anchors * (num_classes + 5)

    # bias is omitted on conv1 because batch norm follows it directly
    self.conv1 = Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        padding="same",
        use_bias=False,
        kernel_regularizer=l2(1e-5),
    )
    self.conv2 = Conv2D(
        filters=head_channels,
        kernel_size=1,
        padding="same",
        kernel_regularizer=l2(1e-5),
    )
    self.bn = BatchNormalization()
    self.activation = LeakyReLU(negative_slope=0.1)

  def call(self, x):
    features = self.conv1(x)
    features = self.bn(features)
    features = self.activation(features)

    return self.conv2(features)

In [37]:
class YoloTransformerEncoderBlock(Layer):
  """
  A Transformer encoder block applied to the flattened grid of one YOLO scale,
  intended to capture long-range dependencies between grid cells.

  The input is a sequence of grid_size**2 tokens with
  num_anchors * (num_classes + 5) features each. A learned position embedding
  is added, followed by multi-head self-attention and a two-layer GELU
  feed-forward network, each wrapped in a residual connection and followed by
  layer normalisation.
  """

  def __init__(self, grid_size, num_anchors=3, num_classes=20, num_heads=8, **kwargs):
    super().__init__(**kwargs)
    num_tokens = grid_size**2
    embed_dim = num_anchors * (num_classes + 5)

    # one learned embedding per grid cell
    self.positions = arange(start=0, stop=num_tokens, step=1)
    self.position_embedding = Embedding(input_dim=num_tokens, output_dim=embed_dim)
    self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=0.1)
    self.ln1 = LayerNormalization()
    self.ln2 = LayerNormalization()
    # feed-forward network: expand to twice the width, then project back
    self.dense1 = Dense(units=2 * embed_dim, activation=gelu)
    self.dense2 = Dense(units=embed_dim, activation=gelu)
    self.drop1 = Dropout(0.1)
    self.drop2 = Dropout(0.1)


  def call(self, x):
    position_codes = self.position_embedding(self.positions)
    tokens = x + expand_dims(position_codes, axis=0)

    # self-attention sub-block (residual, then layer norm)
    attended = self.ln1(self.mha(tokens, tokens) + tokens)

    # feed-forward sub-block (residual, then layer norm)
    hidden = self.drop1(self.dense1(attended))
    hidden = self.drop2(self.dense2(hidden))

    return self.ln2(attended + hidden)

In [38]:
class YoloConvSelfAttention(Layer):
  """
  The output layer applied to each scale: a conv block, a Transformer encoder
  over the flattened grid added back to the conv features, then batch norm,
  dropout and a final 1x1 prediction convolution.
  """

  def __init__(self, grid_size, num_classes, num_anchors=3, filters=256, kernel_size=3, num_heads=8, **kwargs):
    super().__init__(**kwargs)
    # model layers
    self.conv_block = YoloConvBlock(
        filters=filters,
        kernel_size=kernel_size,
        num_classes=num_classes,
        num_anchors=num_anchors,
    )
    self.encoder_block = YoloTransformerEncoderBlock(
        grid_size=grid_size,
        num_anchors=num_anchors,
        num_classes=num_classes,
        num_heads=num_heads,
    )
    self.bn = BatchNormalization()
    self.drop = Dropout(0.2)
    self.pred = Conv2D(filters=num_anchors * (4 + 1 + num_classes), kernel_size=1)

  def call(self, x):
    # conv_out shape: (batch_size, grid_size, grid_size, num_anchors * (4 + 1 + num_classes))
    conv_out = self.conv_block(x)
    batch, rows, cols, channels = shape(conv_out)

    # flatten the grid into a token sequence for the encoder
    tokens = reshape(conv_out, [batch, rows * cols, channels])
    tokens = self.encoder_block(tokens)

    # back to the grid layout, plus a skip connection around the encoder
    fused = reshape(tokens, [batch, rows, cols, channels]) + conv_out
    fused = self.drop(self.bn(fused))

    return self.pred(fused)

In [39]:
class YoloSelfAttention(Model):
  """
  End-to-end yolov3 model with self attention mechanism. The model uses
  ResNet50 as its backbone, with an IMAGE_DIMENSION X IMAGE_DIMENSION X 3 input.
  It returns three outputs (scales):
  (batch_size, 13 | 26 | 52, 13 | 26 | 52, 3 * (4 + 1 + NUM_CLASSES))

  The backbone is frozen on construction; call unfreeze_backbone() to
  fine-tune it.
  """

  def __init__(self, num_anchors=3, num_classes=NUM_CLASSES, **kwargs):
    super().__init__(**kwargs)

    base_model = ResNet50(
        include_top=False,
        weights="imagenet",
        input_shape=(IMAGE_DIMENSION, IMAGE_DIMENSION, 3),
    )
    # feature maps tapped from the backbone, finest resolution first
    tap_names = ("conv3_block4_out", "conv4_block6_out", "conv5_block3_out")
    taps = [base_model.get_layer(tap_name).output for tap_name in tap_names]
    self.backbone = Model(inputs=base_model.input, outputs=taps)
    self.freeze_backbone()

    head_options = dict(num_classes=num_classes, num_anchors=num_anchors)
    self.scale1 = YoloConvSelfAttention(grid_size=IMAGE_DIMENSION // 32, **head_options)
    self.scale2 = YoloConvSelfAttention(grid_size=IMAGE_DIMENSION // 16, **head_options)
    self.scale3 = YoloConvSelfAttention(grid_size=IMAGE_DIMENSION // 8, **head_options)


  def call(self, x, training=False):
    # backbone output shape: (N, 52 | 26 | 13, 52 | 26 | 13, 512 | 1024 | 2048)
    fine, medium, coarse = self.backbone(x, training=training)

    # outputs are ordered coarsest grid first
    out_1 = self.scale1(coarse)
    out_2 = self.scale2(medium)
    out_3 = self.scale3(fine)

    return out_1, out_2, out_3

  def freeze_backbone(self):
    """
    Freeze the backbone trainable parameters. Use when transfer learning.
    """
    self.backbone.trainable = False

  def unfreeze_backbone(self):
    """
    Unfreeze the backbone trainable parameters. Use when fine-tuning.
    """
    self.backbone.trainable = True



In [40]:
def yolo_loss_wrraper(anchors):
  """
  Build a yolo loss function bound to the anchors of one scale.

  Args:
  - anchors: anchor (width, height) pairs of the scale; multiplied with
    exp(raw wh), so they must broadcast against (..., 3, 2). Callers in this
    file pass ANCHORS[i] / IMAGE_DIMENSION.

  Returns:
  - yolo_loss: a callable (y_true, y_pred) -> loss value for that scale.
  """
  def yolo_loss(y_true, y_pred, lambda_coord=10, lambda_obj=1, lambda_no_obj=1e-1, lambda_cls=1):
    """
    Classic yolo loss function. In this function we assume y_true correspond
    to one scale only.

    The first spatial axis is treated as the grid row (y) and the second as the
    grid column (x), matching the (grid_y, grid_x, anchor) layout written by
    bbox_to_output_vec / bbox_to_output_loop.

    Args:
    - y_true: A tensor of shape (batch_size, grid_rows, grid_cols, 3, 4 + 1 + 1)
    - y_pred: A tensor of shape (batch_size, grid_rows, grid_cols, 3 * (4 + 1 + NUM_CLASSES))
    - lambda_coord: weight of the box (xy and wh) term
    - lambda_obj: weight of the confidence term on cells that hold an object
    - lambda_no_obj: weight of the confidence term on cells without an object
    - lambda_cls: weight of the classification term

    Returns:
    - loss value: the weighted terms, each summed over every element of the batch
    """
    batch_size, grid_rows, grid_cols, _ = shape(y_pred)
    y_pred = reshape(y_pred, newshape=(batch_size, grid_rows, grid_cols, 3, -1))

    # coordinates loss
    y_pred_xy = sigmoid(y_pred[..., 0:2])
    y_pred_wh = exp(y_pred[..., 2:4]) * anchors
    y_true_xy = y_true[..., 0:2]
    y_true_wh = y_true[..., 2:4]

    # object & no object masks
    object_mask = y_true[..., 4:5]

    # meshgrid uses "xy" indexing by default: passing (x values, y values)
    # yields arrays of shape (len(y), len(x)) = (grid_rows, grid_cols).
    # Entry [row, col] of grid_xy is (col, row) = (x, y).
    col_offsets, row_offsets = meshgrid(
        arange(0, grid_cols, dtype="float32"),
        arange(0, grid_rows, dtype="float32"),
    )
    grid_xy = stack([col_offsets, row_offsets], axis=-1) # shape: (grid_rows, grid_cols, 2)
    grid_xy = expand_dims(grid_xy, axis=-2) # shape: (grid_rows, grid_cols, 1, 2)

    # cell-relative centres -> image-relative centres (x / columns, y / rows)
    y_pred_xy_non_cell_rel = (y_pred_xy + grid_xy) / (grid_cols, grid_rows)
    y_true_xy_non_cell_rel = (y_true_xy + grid_xy) / (grid_cols, grid_rows)

    y_pred_xywh = concatenate([y_pred_xy_non_cell_rel, y_pred_wh], axis=-1)
    y_true_xywh = concatenate([y_true_xy_non_cell_rel, y_true_wh], axis=-1)

    # NOTE(review): both inputs are flattened over batch, grid and anchors, so
    # this looks like an all-pairs IoU across the whole batch (every prediction
    # against every target slot, empty ones included). Confirm the memory cost
    # is acceptable on the 52 x 52 scale.
    iou = compute_iou(
        reshape(y_pred_xywh, newshape=(-1, 4)),
        reshape(y_true_xywh, newshape=(-1, 4)),
        bounding_box_format="center_xywh"
        ) # shape: (batch_size * grid_rows * grid_cols * 3, ...)

    iou = reshape(iou, newshape=(batch_size, grid_rows, grid_cols, 3, -1))
    iou_max = max(iou, axis=-1, keepdims=True)

    # predictions whose best IoU reaches 0.6 are not penalised as "no object"
    no_object_mask = where(iou_max < 0.6, 1 - object_mask, zeros_like(object_mask))

    # despite its name this term covers both the xy and the sqrt(wh) errors
    xy_loss = lambda_coord * sum(object_mask * (square(y_pred_xy - y_true_xy) + \
                                                square(sqrt(y_pred_wh) - sqrt(y_true_wh))))

    # object confidence loss
    y_pred_confidence = sigmoid(y_pred[..., 4:5])
    y_true_confidence = y_true[..., 4:5]
    bce_output = binary_crossentropy(y_true_confidence, y_pred_confidence) # shape: (batch_size, grid_rows, grid_cols, 3, 1)
    confidence_loss = lambda_obj * sum(object_mask * bce_output) + \
                      lambda_no_obj * sum(no_object_mask * bce_output)

    # classification loss
    y_pred_cls = y_pred[..., 5:]
    y_true_cls = y_true[..., 5]
    sc_ce_output = expand_dims(sparse_categorical_crossentropy(y_true_cls, y_pred_cls, from_logits=True), axis=-1) # shape: (batch_size, grid_rows, grid_cols, 3, 1)
    classification_loss = lambda_cls * sum(object_mask * sc_ce_output)

    return xy_loss + confidence_loss + classification_loss

  return yolo_loss

# Model Training

In [41]:
# Optimizer for the first training stage (the model is built with a frozen
# backbone): SGD with momentum, with the global gradient norm clipped at 10.
optimizer = keras.optimizers.SGD(
    learning_rate=1e-3,
    momentum=0.9,
    global_clipnorm=10.0
    )

In [42]:
# One loss per scale, each bound to that scale's anchors expressed relative
# to the image size.
loss_1, loss_2, loss_3 = (
    yolo_loss_wrraper(np.array(scale_anchors, dtype="float64") / IMAGE_DIMENSION)
    for scale_anchors in ANCHORS
)

In [43]:
# Multiply the learning rate by 0.2 when the monitored metric has not improved
# by at least min_delta=1 for 3 consecutive epochs.
callbacks = [ReduceLROnPlateau(factor=0.2, patience=3, min_delta=1)]

In [None]:
# Build the detector; the constructor creates the ImageNet-initialised
# ResNet50 backbone and freezes it.
model = YoloSelfAttention()

In [45]:
# One loss per model output, in the order returned by YoloSelfAttention.call
# (13, 26 and 52 grids), which is also the order of ANCHORS.
model.compile(
    optimizer=optimizer,
    loss=[loss_1, loss_2, loss_3]
)

In [None]:
# First stage: train the heads on top of the frozen backbone. Validation runs
# on the first 250 batches of eval_ds.
model.fit(
    train_ds,
    validation_data=eval_ds.take(250),
    epochs=30,
    callbacks=callbacks,
)

Epoch 1/30


# Model Fine Tuning

In [None]:
# Make the ResNet50 backbone trainable for the fine-tuning stage; the model is
# re-compiled below before training continues.
model.unfreeze_backbone()

In [None]:
# Fine-tuning optimizer: Adam with a much smaller learning rate than the first
# stage (1e-5 vs 1e-3), and the same global gradient-norm clipping.
optimizer = keras.optimizers.Adam(
    learning_rate=1e-5,
    global_clipnorm=10.0
    )

In [None]:
# Fresh per-scale losses for the fine-tuning stage, each bound to that scale's
# anchors expressed relative to the image size.
loss_1, loss_2, loss_3 = (
    yolo_loss_wrraper(np.array(scale_anchors, dtype="float64") / IMAGE_DIMENSION)
    for scale_anchors in ANCHORS
)

In [None]:
# Same schedule as the first stage but with a shorter patience (2 epochs).
callbacks = [ReduceLROnPlateau(factor=0.2, patience=2, min_delta=1)]

In [None]:
# Re-compile with the fine-tuning optimizer after the backbone was unfrozen.
model.compile(
    optimizer=optimizer,
    loss=[loss_1, loss_2, loss_3]
)

In [None]:
# Second stage: fine-tune the whole network, backbone included.
model.fit(
    train_ds,
    validation_data=eval_ds.take(250),
    epochs=20,
    callbacks=callbacks,
)

In [None]:
# Persist the trained weights only (not the full model).
model.save_weights("yolov3_sa_voc2007.weights.h5")

# Visual tests

In [None]:
# Re-load the raw splits: this rebinds train_ds / eval_ds to unprocessed tfds
# samples, which is what visualize_progress reads from eval_ds.
train_ds, eval_ds = load(
    name="voc/2007",
    split=["train", "validation"],
    with_info=False,
    shuffle_files=True,
)

In [None]:
# Plot ground truth and predictions for the eval sample after the first 500.
# visualize_progress reads the module-level eval_ds and expects raw tfds
# samples, hence the re-load in the previous cell.
visualize_progress(model, confidence_threshold=0.5, iou_threshold=0.2, skip=500)

# Sanity Check

In [None]:
# Useful function for this checks section

def scale_to_pred(scale, anchors):
  """
  Turn one label tensor into raw "logits" laid out like a model output, by
  inverting the decoding done in output_to_bbox.

  Args:
  - scale: a label tensor of shape (grid_size, grid_size, 3, 4 + 1 + 1)
  - anchors: the anchors of that scale, relative to the image size

  Returns:
  - A tensor of shape (grid_size, grid_size, 3 * (4 + 1 + NUM_CLASSES))
  """
  grid_size, _, _, _ = shape(scale)

  # inverse sigmoid for the centre, inverse of exp(.) * anchors for the size
  xy = scale[..., 0:2]
  xy_logits = log(xy / (1 - xy))
  wh_logits = log(scale[..., 2:4] / anchors)

  # saturated logits: +100 where the label is set, -100 elsewhere
  confidence_logits = where(scale[..., 4:5] == 1, 100, -100)
  class_logits = one_hot(cast(scale[..., 5], "int32"), NUM_CLASSES) * 200 - 100

  merged = concatenate([xy_logits, wh_logits, confidence_logits, class_logits], axis=-1)

  return reshape(merged, newshape=(grid_size, grid_size, -1))

In [None]:
# Check if bbox_to_output_vec == bbox_to_output_loop & functions are correctly implemented

inputs = {
    "classes": tf.constant([1, 2, 3]),
    "boxes": tf.constant([
        [0.35, 0.35, 0.2, 0.8],
        [0.38, 0.38, 0.3, 0.8],
        [0.24038462, 0.36057692, 0.12019231, 0.18028846],
    ]),
}
scale_1, scale_2, scale_3 = bbox_to_output_vec(**inputs)
scale_1_l, scale_2_l, scale_3_l = bbox_to_output_loop(**inputs)

# one boolean per scale: True when both encoders produced identical targets
tuple(
    tf.reduce_all(equal(vectorised, looped))
    for vectorised, looped in zip(
        (scale_1, scale_2, scale_3),
        (scale_1_l, scale_2_l, scale_3_l),
    )
)

In [None]:
# Check the postproccessing step

def post_proccessing(scale_1, scale_2, scale_3):
  """
  Round-trip check helper: convert the three label tensors to raw outputs,
  decode them like real predictions and run non-max suppression.

  Args:
  - scale_1, scale_2, scale_3: label tensors of a single sample, one per scale

  Returns:
  - The output of MultiClassNonMaxSuppression for that sample.
  """
  box_parts = []
  class_parts = []

  for labels, scale_anchors, grid_size in zip((scale_1, scale_2, scale_3), ANCHORS, GRID_SIZES):
    anchors_rel = np.array(scale_anchors) / IMAGE_DIMENSION

    # labels -> raw outputs, with a leading batch axis of 1
    raw = scale_to_pred(labels, anchors_rel)
    raw = reshape(raw, newshape=(1, grid_size, grid_size, -1))

    boxes, scores = output_to_bbox(raw, anchors=anchors_rel)
    box_parts.append(boxes)
    class_parts.append(scores)

  box_prediction = concatenate(box_parts, axis=1)
  class_prediction = concatenate(class_parts, axis=1)

  suppressor = MultiClassNonMaxSuppression(
      bounding_box_format="center_xywh",
      from_logits=False,
      confidence_threshold=0.3,
  )

  return suppressor(box_prediction, class_prediction)

post_proccessing(scale_1, scale_2, scale_3)

In [None]:
def temp_preprocessing_wrapper(augmenters):
  """
  Debug variant of preprocessing_wrapper: the image is returned without the
  ResNet50 preprocessing, and the dense boxes (back in pixels) are returned
  alongside the labels so both can be plotted.

  Args:
  - augmenters: an iterable of callables taking and returning a KerasCV sample.

  Returns:
  - temp_preprocessing: the processing function.
  """
  def temp_preprocessing(x):
    """
    Turn a KerasCV-format sample into an image, its labels and its boxes.

    Args:
    - x: a sample in KerasCV format.

    Returns:
    - image, labels, bounding_boxes: the augmented image, the
      (scale_1, scale_2, scale_3) label and the dense bounding boxes.
    """
    for augment in augmenters:
      x = augment(x)

    dense_boxes = to_dense(x["bounding_boxes"])
    x["bounding_boxes"] = dense_boxes

    # labels are built from image-relative boxes; the boxes are scaled back to
    # pixels afterwards for plotting
    dense_boxes["boxes"] = dense_boxes["boxes"] / IMAGE_DIMENSION
    labels = bbox_to_output_vec(**dense_boxes)
    dense_boxes["boxes"] = dense_boxes["boxes"] * IMAGE_DIMENSION

    return x["images"], labels, dense_boxes

  return temp_preprocessing

In [None]:
# Check visually the pipeline correctness

sample = next(iter(train_ds.take(1)))

# train_ds is re-loaded from tfds in the "Visual tests" section, so at this
# point it yields raw samples keyed "image" / "objects". temp_preprocessing
# reads "images" / "bounding_boxes", so raw samples are converted first;
# samples already in KerasCV format are passed through untouched.
if isinstance(sample, dict) and "image" in sample:
  sample = unpackage_raw_sample(sample)

process_sample_fn = temp_preprocessing_wrapper(eval_ds_augmenters)
image, (scale_1, scale_2, scale_3), bounding_boxes_true = process_sample_fn(sample)

# the plotting helper expects batched inputs, so add a leading batch axis of 1
bounding_boxes_true = {
    "classes": expand_dims(bounding_boxes_true["classes"], axis=0),
    "boxes": expand_dims(bounding_boxes_true["boxes"], axis=0),
    }

# labels -> raw outputs -> decoded boxes; these should land on the ground truth
bounding_boxes_pred = post_proccessing(scale_1, scale_2, scale_3)

class_mapping = dict(zip(range(len(CLASSES)), CLASSES))
plot_bounding_box_gallery(
    expand_dims(image, axis=0),
    value_range=(0, 255),
    rows=1,
    cols=1,
    y_true=bounding_boxes_true,
    y_pred=bounding_boxes_pred,
    scale=5,
    bounding_box_format="center_xywh",
    font_scale=0.6,
    class_mapping=class_mapping
)