In [2]:
import sys
# True when the `google.colab` module is already loaded in this interpreter,
# i.e. the notebook is running inside Google Colab.
IN_COLAB = 'google.colab' in sys.modules

In [None]:
# Images and annotation files live in the same directory, so both path
# variables are bound to the same location in each environment.
if IN_COLAB:
    # Colab: mount Google Drive and read the dataset from it.
    from google.colab import drive

    drive.mount('/content/drive')
    path_images = path_annot = "./drive/MyDrive/obj_Train_data"
else:
    # Elsewhere: fetch the raw data through the project helper.
    from kestrix.data import download_raw_data

    download_raw_data()
    path_images = path_annot = "./data/kestrix/raw"

In [1]:
#!pip install keras-cv

In [None]:
import os
from pathlib import Path
from tqdm.auto import tqdm # show progress bars

import tensorflow as tf
from tensorflow import keras
import keras_cv

from keras_cv import bounding_box
from keras_cv import visualization

In [None]:
# Hyperparameters
SPLIT_RATIO = 0.2  # fraction of the annotation files held out for validation
BATCH_SIZE = 4  # batch size of the training and validation pipelines
LEARNING_RATE = 0.001  # learning rate handed to the Adam optimizer
EPOCH = 5  # intended number of training epochs
GLOBAL_CLIPNORM = 10.0  # `global_clipnorm` handed to the optimizer

In [None]:
# KerasCV bounding-box format name shared by the augmentation layers, the
# detector and the plotting helpers. Boxes are built as [x, y, width, height]
# in `parse_annotation`.
bounding_box_format = "center_xywh"

In [None]:
# Detection classes; the list position doubles as the numeric class id.
class_ids = ["car", "person"]

# Lookup table {numeric id: class name}, used when plotting boxes.
class_mapping = dict(enumerate(class_ids))
class_mapping

In [None]:
# Full path of every ".txt" file in the annotation directory, sorted by path.
txt_files = sorted(
    os.path.join(path_annot, entry)
    for entry in os.listdir(path_annot)
    if entry.endswith(".txt")
)
txt_files[:5]

In [None]:
def parse_annotation(txt_file, image_width=4000, image_height=3000, image_dir=None):
    """Parse one annotation text file into pixel-space boxes and class ids.

    Every non-blank line is read as whitespace-separated fields
    ``class x y width height``. The four box values are multiplied by the
    image size, so they are expected to be fractions of the image dimensions.
    Fields after the fifth are ignored.

    Args:
        txt_file: Path of the annotation file.
        image_width: Width in pixels used to scale ``x`` and ``width``.
        image_height: Height in pixels used to scale ``y`` and ``height``.
        image_dir: Directory of the matching image. Defaults to the
            module-level ``path_images`` when ``None``.

    Returns:
        A tuple ``(image_path, boxes, class_ids)``: the image path (annotation
        file stem plus ``".JPG"``), a list of ``[x, y, width, height]`` lists
        and a list of class ids as floats, one entry per annotation line.

    Raises:
        IndexError: If a non-blank line has fewer than five fields.
        ValueError: If a field cannot be converted to ``float``.
    """
    if image_dir is None:
        image_dir = path_images

    with open(txt_file) as file:
        lines = file.readlines()
        file_name = Path(file.name).stem

    image_path = os.path.join(image_dir, file_name + ".JPG")

    boxes = []
    label_ids = []
    for raw_line in lines:
        fields = raw_line.split()
        if not fields:
            # Blank lines (e.g. a trailing newline) carry no annotation.
            continue

        label_ids.append(float(fields[0]))
        boxes.append(
            [
                float(fields[1]) * image_width,
                float(fields[2]) * image_height,
                float(fields[3]) * image_width,
                float(fields[4]) * image_height,
            ]
        )

    return image_path, boxes, label_ids

In [None]:
# Parse every annotation file into three parallel lists, one entry per image.
image_paths = []
bbox = []
classes = []
for txt_file in txt_files:
    # Unpack into `image_classes`, not `class_ids`: reusing `class_ids` here
    # would overwrite the module-level list of class names with the numeric
    # ids of the last image.
    image_path, boxes, image_classes = parse_annotation(txt_file)
    image_paths.append(image_path)
    bbox.append(boxes)
    classes.append(image_classes)

In [None]:
# Preview the class ids parsed for the first four annotation files.
classes[:4]

In [None]:
# Preview the boxes parsed for the first four annotation files.
bbox[:4]

In [None]:
# The number of annotated objects differs between images, so the nested
# lists are converted with tf.ragged.constant instead of a dense constructor.
bbox, classes, image_paths = (
    tf.ragged.constant(values) for values in (bbox, classes, image_paths)
)

# One dataset element per image: (image path, class ids, boxes).
data = tf.data.Dataset.from_tensor_slices((image_paths, classes, bbox))

In [None]:
# Splitting data
# Number of elements that go into the validation set.
num_val = int(len(txt_files) * SPLIT_RATIO)

# `txt_files` is sorted, so take/skip on the unshuffled dataset would always
# put the alphabetically-first files into the validation set. Shuffle once
# with a fixed seed before splitting. `reshuffle_each_iteration=False` keeps
# the order identical on every pass, so the two subsets never overlap.
SPLIT_SEED = 42
shuffled_data = data.shuffle(
    buffer_size=max(len(txt_files), 1),  # buffer covers the whole dataset
    seed=SPLIT_SEED,
    reshuffle_each_iteration=False,
)

# split into train and validation
val_data = shuffled_data.take(num_val)
train_data = shuffled_data.skip(num_val)

In [None]:
def load_image(image_path):
    """Read the file at `image_path` and decode it as a 3-channel JPEG."""
    raw_bytes = tf.io.read_file(image_path)
    return tf.image.decode_jpeg(raw_bytes, channels=3)


def load_dataset(image_path, classes, bbox):
    """Build one sample dictionary from a dataset element.

    Loads the image with `load_image` and returns a dictionary with the
    float32 image under "images" and, under "bounding_boxes", the float32
    class ids ("classes") together with the unmodified boxes ("boxes").
    """
    decoded = load_image(image_path)
    class_labels = tf.cast(classes, dtype=tf.float32)
    pixels = tf.cast(decoded, dtype=tf.float32)
    return {
        "images": pixels,
        "bounding_boxes": {"classes": class_labels, "boxes": bbox},
    }

## Image Augmentation 
Reference: [KerasCV augmentation layers](https://keras.io/api/keras_cv/layers/augmentation/)

In [None]:
# Augmentation applied to training batches: horizontal flip, shear, then a
# jittered resize to 640x640. Every layer receives the box format so that it
# can handle the bounding boxes together with the images.
augmenter = keras.Sequential(
    layers=[
        keras_cv.layers.RandomFlip(
            mode="horizontal",
            bounding_box_format=bounding_box_format,
        ),
        keras_cv.layers.RandomShear(
            x_factor=0.2,
            y_factor=0.2,
            bounding_box_format=bounding_box_format,
        ),
        keras_cv.layers.JitteredResize(
            target_size=(640, 640),
            scale_factor=(0.75, 1.3),
            bounding_box_format=bounding_box_format,
        ),
    ]
)

In [None]:
# Resize applied to validation batches: 640x640, padding to keep the
# aspect ratio.
resizing = keras_cv.layers.Resizing(
    height=640,
    width=640,
    pad_to_aspect_ratio=True,
    bounding_box_format=bounding_box_format,
)

## Create Training Set

In [None]:
# Training pipeline: load samples -> shuffle -> ragged batches -> augment.
train_ds = (
    train_data.map(load_dataset, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BATCH_SIZE * 4)
    .ragged_batch(BATCH_SIZE, drop_remainder=True)
    .map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)
)


In [None]:
# Validation pipeline: load samples -> shuffle -> ragged batches -> resize.
val_ds = (
    val_data.map(load_dataset, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BATCH_SIZE * 4)
    .ragged_batch(BATCH_SIZE, drop_remainder=True)
    .map(resizing, num_parallel_calls=tf.data.AUTOTUNE)
)

# Visualization

In [None]:
def visualize_dataset(inputs, value_range, rows, cols, bounding_box_format):
    """Plot the first batch of `inputs` with its ground-truth boxes.

    `inputs` is a dataset whose elements are dictionaries with "images" and
    "bounding_boxes" entries. `value_range`, `rows`, `cols` and
    `bounding_box_format` are forwarded to
    `visualization.plot_bounding_box_gallery`.
    """
    first_batch = next(iter(inputs.take(1)))
    visualization.plot_bounding_box_gallery(
        first_batch["images"],
        value_range=value_range,
        rows=rows,
        cols=cols,
        y_true=first_batch["bounding_boxes"],
        scale=5,
        font_scale=0.7,
        bounding_box_format=bounding_box_format,
        class_mapping=class_mapping,
    )

In [None]:
# One augmented training batch ...
visualize_dataset(
    train_ds,
    value_range=(0, 255),
    rows=2,
    cols=2,
    bounding_box_format=bounding_box_format,
)

# ... and one resized validation batch.
visualize_dataset(
    val_ds,
    value_range=(0, 255),
    rows=2,
    cols=2,
    bounding_box_format=bounding_box_format,
)

In [None]:
def dict_to_tuple(inputs):
    """Turn a sample dictionary into an `(images, dense boxes)` tuple.

    The bounding boxes are converted with `bounding_box.to_dense` using
    `max_boxes=32`.
    """
    dense_boxes = bounding_box.to_dense(inputs["bounding_boxes"], max_boxes=32)
    return inputs["images"], dense_boxes


# Switch both pipelines from dictionaries to (images, boxes) tuples and
# prefetch the batches.
train_ds = train_ds.map(
    dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

val_ds = val_ds.map(
    dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

# Creating the model

We can switch the backbone by using `keras_cv.models.YOLOV8Detector.from_preset` and another [preset](https://keras.io/api/keras_cv/models/tasks/yolo_v8_detector/)

In [None]:
 # We will use yolov8 small backbone with coco weights
# The preset name selects the backbone variant.
backbone = keras_cv.models.YOLOV8Backbone.from_preset("yolo_v8_s_backbone_coco")

**TODO:** fine-tune the prediction decoder (the `NonMaxSuppression` thresholds below).

In [None]:

# Non-max-suppression layer intended as a custom prediction decoder.
# NOTE(review): it only takes effect when passed to the detector as
# `prediction_decoder=`; the visible YOLOV8Detector call leaves that disabled.
# NOTE(review): from_logits=True assumes the model emits un-activated class
# scores; confirm this holds for YOLOV8Detector before enabling the decoder.
prediction_decoder = keras_cv.layers.NonMaxSuppression(
    bounding_box_format=bounding_box_format,
    from_logits=True,
    iou_threshold=0.2,
    confidence_threshold=0.7,
)

The `NonMaxSuppression` layer is responsible for pruning underconfident boxes. Raising the `confidence_threshold` will cause the model to only output boxes that have a higher confidence score. `iou_threshold` controls how much two boxes must overlap, measured as intersection over union (IoU), before one of them is pruned.

Next, let's build a YOLOV8 model using the `YOLOV8Detector`, which accepts a feature extractor as the `backbone` argument, a `num_classes` argument that specifies the number of object classes to detect (here the size of the `class_mapping` dictionary), a `bounding_box_format` argument that informs the model of the format of the bounding boxes in the dataset, and finally an `fpn_depth` argument that specifies the depth of the feature pyramid network (FPN).

# Compilation
Loss used for YOLOV8

1. Classification Loss: This loss function measures the discrepancy between the predicted and the actual class probabilities. Here `binary_crossentropy`, a common choice for binary classification problems, is used: each detected object is classified as either belonging or not belonging to a given object class (such as a person or a car).

2. Box Loss: `box_loss` is the loss function used to measure the difference between the predicted bounding boxes and the ground truth. In this case, the Complete IoU (CIoU) metric is used, which not only measures the overlap between predicted and ground truth bounding boxes but also considers the difference in aspect ratio, center distance, and box size. Together, these loss functions help optimize the model for object detection by minimizing the difference between the predicted and ground truth class probabilities and bounding boxes.

In [None]:
# Assemble the detector around the backbone; one output class per entry of
# `class_mapping`.
yolo = keras_cv.models.YOLOV8Detector(
    backbone=backbone,
    num_classes=len(class_mapping),
    bounding_box_format=bounding_box_format,
    fpn_depth=1,
    # prediction_decoder=prediction_decoder,  # custom NMS decoder left disabled
)

You will always want to include a global_clipnorm when training object detection models. This is to remedy exploding gradient problems that frequently occur when training object detection models.

In [None]:
# Adam with global gradient-norm clipping; both values come from the
# hyperparameter cell.
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE, global_clipnorm=GLOBAL_CLIPNORM)

To achieve the best results on your dataset, you'll likely want to hand craft a `PiecewiseConstantDecay` learning rate schedule.

In [None]:
# Binary cross-entropy for the class scores, CIoU for the box regression.
yolo.compile(
    optimizer=optimizer,
    classification_loss="binary_crossentropy",
    box_loss="ciou",
)

## COCO Metric Callback
The most popular object detection metrics are COCO metrics, which were published alongside the MSCOCO dataset. KerasCV provides an easy-to-use suite of COCO metrics under the `keras_cv.callbacks.PyCOCOCallback` symbol. Note that we use a Keras callback instead of a Keras metric to compute COCO metrics. This is because computing COCO metrics requires storing all of a model's predictions for the entire evaluation dataset in memory at once, which is impractical to do during training time.

# Train model

In [None]:
# COCO metrics are computed on the validation set through a callback.
coco_metrics_callback = keras_cv.callbacks.PyCOCOCallback(
    val_ds,
    bounding_box_format,
)

# Train for the number of epochs set in the hyperparameter cell. The call
# previously hard-coded epochs=2, which left the `EPOCH` setting unused.
yolo.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCH,
    callbacks=[coco_metrics_callback],
)

In [None]:
# The Drive folder only exists after `drive.mount` in Colab; elsewhere fall
# back to a local directory so the save does not fail on a missing path.
model_dir = "./drive/MyDrive" if IN_COLAB else "./models"
os.makedirs(model_dir, exist_ok=True)
yolo.save(os.path.join(model_dir, "model.keras"))

# Visualize Predictions

In [None]:
def visualize_detections(model, dataset, bounding_box_format):
    """Plot ground truth and predictions for the first batch of `dataset`.

    `dataset` yields `(images, boxes)` tuples. The images are run through
    `model.predict`, the predictions are converted with
    `bounding_box.to_ragged`, and both box sets are drawn in a 2x2 gallery.
    """
    batch_images, ground_truth = next(iter(dataset.take(1)))
    predictions = bounding_box.to_ragged(model.predict(batch_images))
    visualization.plot_bounding_box_gallery(
        batch_images,
        value_range=(0, 255),
        bounding_box_format=bounding_box_format,
        y_true=ground_truth,
        y_pred=predictions,
        scale=4,
        rows=2,
        cols=2,
        show=True,
        font_scale=0.7,
        class_mapping=class_mapping,
    )


# Compare the trained model's predictions with the validation labels.
visualize_detections(
    model=yolo,
    dataset=val_ds,
    bounding_box_format=bounding_box_format,
)

## Function to visualize a prediction

In [None]:
def visualize_prediction(model, image_batch, bounding_box_format):
    """Run `model` on `image_batch` and plot the predicted boxes.

    Args:
        model: Object with a `predict` method whose output can be passed as
            `y_pred` to `visualization.plot_bounding_box_gallery`.
        image_batch: Batch of images handed to `model.predict` and to the
            gallery plot.
        bounding_box_format: Bounding-box format name forwarded to the plot.

    Returns:
        The raw output of `model.predict(image_batch)`.
    """
    predictions = model.predict(image_batch)
    visualization.plot_bounding_box_gallery(
        image_batch,
        value_range=(0, 255),
        rows=1,
        cols=1,
        y_pred=predictions,
        scale=5,
        font_scale=0.7,
        bounding_box_format=bounding_box_format,
        class_mapping=class_mapping,
    )
    return predictions


# `pretrained_model` and `image_batch` were never defined in this notebook,
# so the cell failed on a fresh kernel. Use the trained detector and the
# first validation batch instead.
image_batch, _ = next(iter(val_ds.take(1)))
y_pred = visualize_prediction(yolo, image_batch, bounding_box_format)