In [1]:
import glob
import os

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display
from PIL import Image

In [2]:
def glob_sorted(pattern):
    """Return the paths matching ``pattern`` in ascending lexicographic order.

    ``glob.glob`` makes no ordering promise, so the matches are sorted to give
    a stable, reproducible file list.
    """
    matches = list(glob.glob(pattern))
    matches.sort()
    return matches

In [26]:
# Look up the Open Images class id (e.g. '/m/...') for the "Shrimp" label.
# The path is resolved relative to the current user's home directory instead
# of a hard-coded /Users/<name> prefix so the cell runs on other machines.
# NOTE(review): assumes the FiftyOne download lives under ~/fiftyone — confirm.
classes_csv = os.path.expanduser(
    "~/fiftyone/open-images-v7/train/metadata/classes.csv"
)
# classes.csv has no header row: column 0 is the class id, column 1 the label.
classes_df = pd.read_csv(classes_csv, header=None, names=["id", "label"])

shrimp_matches = classes_df.loc[classes_df["label"] == "Shrimp", "id"]
if shrimp_matches.empty:
    # Fail with a readable message instead of an IndexError on an empty array.
    raise ValueError(f"No 'Shrimp' label found in {classes_csv}")
shrimp_id = shrimp_matches.iloc[0]
shrimp_id

'/m/0ll1f78'

In [3]:
# Flip to True to (re)load the Open Images v7 "Shrimp" test split through the
# FiftyOne dataset zoo; left False so a normal run skips this step.
reload_dataset = False
if reload_dataset:
    # fiftyone is only imported when it is actually needed.
    import fiftyone as fo
    import fiftyone.zoo as foz

    dataset = foz.load_zoo_dataset(
        "open-images-v7",
        classes=["Shrimp"],
        label_types=["detections", "points"],
        split="test",
    )

In [58]:
# Root of the local Open Images v7 test split.
dataset_dir = "./data/open-images-v7/test"
# Load the detection annotations and lower-case every column name.
detections_df = pd.read_csv(f"{dataset_dir}/detections.csv").rename(columns=str.lower)

In [59]:
# Sorted list of the test-split images. The pattern requires a literal ".jpg"
# extension (the previous "*jpg" also matched names merely ending in "jpg"),
# matching the ".jpg" suffix that is stripped when deriving image ids later.
img_list = glob_sorted(f"{dataset_dir}/data/*.jpg")

In [55]:
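# One-off preprocessing, kept for reference (disabled): restrict the detections
# table to the images present in img_list and overwrite detections.csv with it.
# img_names = [img.split("/")[-1].removesuffix(".jpg") for img in img_list]
# fltrd_detection_df = detections_df[detections_df["imageid"].isin(img_names)]
# fltrd_detection_df.to_csv("./data/open-images-v7/test/detections.csv", index=None)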

In [63]:
import matplotlib.patches as patches


def get_pixels(img_file):
    """Read an image file and return its pixel data as a NumPy array.

    Args:
        img_file: Path to an image file readable by PIL.

    Returns:
        The array produced by ``np.asarray`` on the opened image.
    """
    # The context manager releases the underlying file even if decoding fails;
    # the array is built inside the block, while the image is still open.
    with Image.open(img_file) as img:
        pixels = np.asarray(img)
    return pixels


def get_rect_box(rel_xmin, rel_xmax, rel_ymin, rel_ymax, width, height):
    """Build a red, unfilled matplotlib rectangle for one bounding box.

    The ``rel_*`` arguments are scaled by ``width`` / ``height`` to obtain the
    rectangle's anchor point and its size.

    Returns:
        A ``matplotlib.patches.Rectangle`` (not yet attached to any Axes).
    """
    anchor = (width * rel_xmin, height * rel_ymin)
    box_width = (rel_xmax - rel_xmin) * width
    box_height = (rel_ymax - rel_ymin) * height
    return patches.Rectangle(
        anchor,
        box_width,
        box_height,
        linewidth=1,
        edgecolor="r",
        facecolor="none",
    )


def get_img_with_bounding_box(
    img_file,
    detections_df,
    outdir_for_label,
    isgroupof=0,
    isoccluded=0,
    show_img=False,
    label_id=None,
):
    """Write a YOLO-format label file for one image and optionally plot it.

    One line ``0 <x_center> <y_center> <width> <height>`` is written per
    matching detection, computed directly from the relative ``xmin``/``xmax``/
    ``ymin``/``ymax`` columns. The label file is always (re)written, so an image
    without matching detections gets an empty file.

    Args:
        img_file: Path to the ``.jpg`` image; its basename without the suffix
            is matched against the ``imageid`` column.
        detections_df: DataFrame with columns ``imageid``, ``labelname``,
            ``xmin``, ``xmax``, ``ymin`` and ``ymax``.
        outdir_for_label: Directory that receives ``<imageid>.txt``; it is
            created if missing.
        isgroupof: Accepted for interface compatibility; no filtering on this
            value is applied.
        isoccluded: Accepted for interface compatibility; no filtering on this
            value is applied.
        show_img: When True, display the image with the boxes drawn on it.
        label_id: Value of ``labelname`` to keep. Defaults to the notebook-level
            ``shrimp_id`` when not given.

    Returns:
        None.
    """
    if label_id is None:
        # Backward-compatible default: the class id resolved earlier in the notebook.
        label_id = shrimp_id

    img_basename = os.path.basename(img_file).removesuffix(".jpg")

    # Select this image's rows for the target class in one vectorized filter.
    is_target = (detections_df["imageid"] == img_basename) & (
        detections_df["labelname"] == label_id
    )
    boxes = list(
        detections_df.loc[is_target, ["xmin", "xmax", "ymin", "ymax"]].itertuples(
            index=False, name=None
        )
    )

    # Make sure the destination exists (an empty string means the current directory).
    if outdir_for_label:
        os.makedirs(outdir_for_label, exist_ok=True)

    label_path = os.path.join(outdir_for_label, f"{img_basename}.txt")
    with open(label_path, "w") as label_outfile:
        for rel_xmin, rel_xmax, rel_ymin, rel_ymax in boxes:
            label_outfile.write(
                f"0 {(rel_xmin + rel_xmax) / 2} {(rel_ymin + rel_ymax) / 2} "
                f"{rel_xmax - rel_xmin} {rel_ymax - rel_ymin}\n"
            )

    if not show_img:
        # Pixel data is only needed for plotting; skip decoding the image otherwise.
        return

    pixels = get_pixels(img_file)
    # Only the first two dimensions are needed, so 2-D (grayscale) arrays work too.
    height, width = pixels.shape[:2]

    fig, ax = plt.subplots()
    ax.imshow(pixels)
    for rel_xmin, rel_xmax, rel_ymin, rel_ymax in boxes:
        ax.add_patch(
            get_rect_box(rel_xmin, rel_xmax, rel_ymin, rel_ymax, width, height)
        )
    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate memory.
    plt.close(fig)

In [65]:
# Write one YOLO label file per test image.
labels_dir = "data/open-images-v7/test/labels/"
# Create the output directory up front so the first write cannot fail on a
# missing folder.
os.makedirs(labels_dir, exist_ok=True)

for img_file in img_list:
    get_img_with_bounding_box(
        img_file,
        detections_df,
        # NOTE: isgroupof / isoccluded are passed through, but the function
        # applies no filtering on them.
        isgroupof=0,
        isoccluded=1,
        outdir_for_label=labels_dir,
        show_img=False,
    )

In [None]:
"""
Training and fine tuning
"""
from ultralytics import YOLO, settings

# Set the Ultralytics "datasets_dir" setting to the local ./data folder.
settings.update({"datasets_dir": "./data"})


def train_custom_model(
    epochs: int,
    data: str = "./YOLO_configs/openimage_shrimp.yaml",
    device: str = "mps",
):
    """Fine-tune a YOLOv8n detector and return the trained model.

    The network is built from ``yolov8n.yaml`` and the pretrained
    ``yolov8n.pt`` weights are loaded into it before training.

    Args:
        epochs: Number of training epochs.
        data: Path to the dataset YAML passed to ``model.train``.
        device: Device string passed to ``model.train``; defaults to ``"mps"``.
            Pass another value (for example ``"cpu"``) on machines without MPS.

    Returns:
        The ``YOLO`` model object after ``train`` has run, so later cells can
        use it without reloading weights from disk.

    NOTE(review): the recorded output of this cell shows the train and val
    scans reading the same ``.../train/labels/shrimp`` directory — confirm the
    dataset YAML points validation at a held-out split.
    """
    # Build from YAML and transfer the pretrained weights (recommended for training).
    model = YOLO("yolov8n.yaml").load("yolov8n.pt")
    model.info()

    model.train(
        data=data,
        epochs=epochs,
        device=device,
        fraction=0.5,
        # imgsz=(480, 848),
        verbose=True,
        weight_decay=0.01,
        lr0=0.1,
        # NOTE(review): Ultralytics appears to document `dropout` as a
        # classification-task setting — confirm it has an effect for detection.
        dropout=0.1,
    )
    return model


# Bind the result so the trained model stays available to later cells (and the
# cell does not echo the returned object).
trained_model = train_custom_model(epochs=100)



                   from  n    params  module                                       arguments                     
  0                  -1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]                 
  1                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]                
  2                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]             
  3                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]                
  4                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]             
  5                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]               
  6                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]           
  7                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128

GETTING A DEVICE: mps
HAS MPS= True
TORCH MPS AVAILABLE: True
TORCH_VERSION: True
FOUND MPS DEVICE!!!


YOLOv8n summary: 225 layers, 3011043 parameters, 3011027 gradients, 8.2 GFLOPs

Transferred 319/355 items from pretrained weights
[34m[1mTensorBoard: [0mStart with 'tensorboard --logdir runs/detect/train12', view at http://localhost:6006/
Freezing layer 'model.22.dfl.conv.weight'
[34m[1mtrain: [0mScanning /Users/javkhlan-ochirganbat/repos/machine-learning/underwater_object/data/open-images-v7/train/labels/shrimp... 308 images, 0 backgrounds, 0 corrupt: 100%|██████████| 308/308 [00:00<00:00, 5014.73it/s][0m
[34m[1mtrain: [0mNew cache created: /Users/javkhlan-ochirganbat/repos/machine-learning/underwater_object/data/open-images-v7/train/labels/shrimp.cache
[34m[1mval: [0mScanning /Users/javkhlan-ochirganbat/repos/machine-learning/underwater_object/data/open-images-v7/train/labels/shrimp... 617 images, 0 backgrounds, 0 corrupt: 100%|██████████| 617/617 [00:00<00:00, 5410.77it/s][0m
[34m[1mval: [0mNew cache created: /Users/javkhlan-ochirganbat/repos/machine-learning/underw