In [None]:
from ultralytics import YOLO
import torch

# Train the Ultralytics YOLOv8l model on the Zhousidun data

In [None]:
# Start from the COCO-pretrained YOLOv8l checkpoint
model = YOLO('yolov8l.pt')

# Count the CUDA devices torch can see; every one of them is handed to the trainer
gpus = torch.cuda.device_count()

# Fine-tune on the dataset for 200 epochs at 640 px input size.
# batch=-1 asks Ultralytics to choose the batch size automatically.
# NOTE(review): automatic batch sizing may not apply when several GPUs are used -- confirm in the training log.
train_kwargs = dict(
    data='/home/ritwik/data/zeus/yolov8/data.yaml',
    epochs=200,
    imgsz=640,
    augment=True,
    device=list(range(gpus)),
    batch=-1,
)
results = model.train(**train_kwargs)

# Evaluate the model on the synthetic Blender scene

In [None]:
from ultralytics import YOLO
# Reload the best checkpoint of the yolo_best_200ep run (path is relative to the
# notebook's working directory) and validate it on the Blender "valid" split.
weights_path = 'runs/detect/yolo_best_200ep/weights/best.pt'
model = YOLO(weights_path)
model.val(
    data='/home/ritwik/data/blender_boat/data_valid.yaml',
    imgsz=640,
    plots=True,
)

In [None]:
# Validate the same model on the oblique-view Blender split
model.val(
    data='/home/ritwik/data/blender_boat/data_oblique.yaml',
    imgsz=640,
    plots=True,
)

In [None]:
# Validate the same model on the satellite-view Blender split
model.val(
    data='/home/ritwik/data/blender_boat/data_satellite.yaml',
    imgsz=640,
    plots=True,
)

# Various plotting functions

In [None]:
def yolo_to_xywh(cx, cy, nw, nh, img_width, img_height):
    """Convert a normalized, centre-based box into a pixel-space corner box.

    Args:
        cx, cy: Box centre, as fractions of the image width / height.
        nw, nh: Box width / height, as fractions of the image width / height.
        img_width, img_height: Image size used to scale the fractions.

    Returns:
        Tuple ``(x, y, w, h)`` where ``(x, y)`` is the centre shifted back by
        half the box size and scaled by the image size, and ``(w, h)`` is the
        scaled box size.
    """
    half_w, half_h = nw / 2, nh / 2
    left = (cx - half_w) * img_width
    top = (cy - half_h) * img_height
    return left, top, nw * img_width, nh * img_height

In [None]:
from pathlib import Path
from ultralytics.utils.plotting import Annotator
from PIL import Image
import numpy as np
from matplotlib.patches import Rectangle
from ultralytics import YOLO
import matplotlib.pyplot as plt

# Detector weights taken from the yolo_best_200ep run directory
model = YOLO('/home/ritwik/zeus/runs/detect/yolo_best_200ep/weights/best.pt')

# Gather the PNG images of the "valid" split; the last entry of the glob result is dropped
image_dir = Path("/home/ritwik/data/blender_boat/valid/images")
images = [*image_dir.glob("*.png")]
images = images[:-1]

# Derive each label path by textual substitution on the image path
labels = []
for image_path in images:
    label_path = str(image_path).replace("images", "labels")
    labels.append(label_path.replace(".png", ".txt"))

print(len(images))

In [None]:
# Draw the validation images in a 16x4 grid with ground-truth boxes in green
# and model predictions in red, then save the figure.
fig, ax = plt.subplots(16, 4, figsize=(8, 16))
ax = ax.flatten()

# Hide the frame of every panel up front so unused trailing panels stay blank.
for axis in ax:
    axis.axis('off')

if len(images) > len(ax):
    print(f"Grid holds {len(ax)} panels; skipping the last {len(images) - len(ax)} images")

# zip stops at the shortest input, so surplus images cannot index past the grid.
for axis, img_path, label_path in zip(ax, images, labels):
    # Release the file handle as soon as the pixels are decoded.
    with Image.open(img_path) as raw_img:
        pil_img = raw_img.convert('RGB')
    width, height = pil_img.size
    img = np.array(pil_img)
    axis.imshow(img)

    # Pass the PIL image, not the RGB array: Ultralytics interprets numpy
    # input as BGR, so the array would be run with red and blue swapped.
    results = model.predict(pil_img, imgsz=640)

    # Ground truth: "class cx cy w h" per line, normalized to the image size.
    with open(label_path, 'r') as label_file:
        for gt in label_file:
            splits = gt.split()
            if len(splits) < 5:
                continue  # blank or truncated line
            cx, cy, nw, nh = (float(value) for value in splits[1:5])
            x, y, w, h = yolo_to_xywh(cx, cy, nw, nh, width, height)
            rect = Rectangle((x, y), w, h, linewidth=2, edgecolor='g', facecolor='none')
            axis.add_patch(rect)

    # Predictions: Rectangle is anchored at a corner, so use corner-format
    # boxes (xyxy); boxes.xywh is centre-based and would draw shifted boxes.
    for r in results:
        for x1, y1, x2, y2 in r.boxes.xyxy.cpu().tolist():
            rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='r', facecolor='none')
            axis.add_patch(rect)

plt.tight_layout(pad=0.1, h_pad=0.1)
plt.savefig("all_blender_valid.png")
plt.show()

In [None]:
# Gather the PNG images of the "satellite" split; the last entry of the glob result is dropped
satellite_dir = Path("/home/ritwik/data/blender_boat/satellite/images")
images = list(satellite_dir.glob("*.png"))
images = images[:-1]

# Label paths mirror the image paths: "images" -> "labels", ".png" -> ".txt"
labels = list(
    map(
        lambda image_path: str(image_path).replace("images", "labels").replace(".png", ".txt"),
        images,
    )
)

print(len(images))

In [None]:
# Draw the satellite-view images in a 4x4 grid with ground-truth boxes in
# green and model predictions in red, then save the figure.
fig, ax = plt.subplots(4, 4, figsize=(12, 8))
ax = ax.flatten()

# Hide the frame of every panel up front so unused trailing panels stay blank.
for axis in ax:
    axis.axis('off')

if len(images) > len(ax):
    print(f"Grid holds {len(ax)} panels; skipping the last {len(images) - len(ax)} images")

# zip stops at the shortest input, so surplus images cannot index past the grid.
for axis, img_path, label_path in zip(ax, images, labels):
    # Release the file handle as soon as the pixels are decoded.
    with Image.open(img_path) as raw_img:
        pil_img = raw_img.convert('RGB')
    width, height = pil_img.size
    img = np.array(pil_img)
    axis.imshow(img)

    # Pass the PIL image, not the RGB array: Ultralytics interprets numpy
    # input as BGR, so the array would be run with red and blue swapped.
    results = model.predict(pil_img, imgsz=640)

    # Ground truth: "class cx cy w h" per line, normalized to the image size.
    with open(label_path, 'r') as label_file:
        for gt in label_file:
            splits = gt.split()
            if len(splits) < 5:
                continue  # blank or truncated line
            cx, cy, nw, nh = (float(value) for value in splits[1:5])
            x, y, w, h = yolo_to_xywh(cx, cy, nw, nh, width, height)
            rect = Rectangle((x, y), w, h, linewidth=2, edgecolor='g', facecolor='none')
            axis.add_patch(rect)

    # Predictions: Rectangle is anchored at a corner, so use corner-format
    # boxes (xyxy); boxes.xywh is centre-based and would draw shifted boxes.
    for r in results:
        for x1, y1, x2, y2 in r.boxes.xyxy.cpu().tolist():
            rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='r', facecolor='none')
            axis.add_patch(rect)

plt.tight_layout(w_pad=0.1, h_pad=0.05)
plt.savefig("satellite_blender_valid.jpg")
plt.show()