In [None]:
# --- ENVIRONMENT CHECK ABOUT GPU ---

import torch
import subprocess

# Query CUDA once and reuse the answer for both the report and the branch below.
cuda_ok = torch.cuda.is_available()

for report_line in (
    f"PyTorch version   : {torch.__version__}",
    f"CUDA available    : {cuda_ok}",
    f"CUDA version      : {torch.version.cuda}",
):
    print(report_line)

if not cuda_ok:
    print("No GPU detected")
else:
    # Only device 0 is inspected; multi-GPU boxes report the first card.
    props = torch.cuda.get_device_properties(0)
    total_vram_gb = props.total_memory / 1024**3
    print(f"GPU               : {props.name}")
    print(f"VRAM              : {total_vram_gb:.1f} GB")
    print(f"SM count          : {props.multi_processor_count}")
    # 20 GB is the threshold below which the training batch size deserves a second look.
    print("High-VRAM GPU detected" if total_vram_gb >= 20
          else "VRAM < 20GB, check batch size settings.")

# Report the ultralytics version, or how to install it when missing.
try:
    import ultralytics
except ImportError:
    print("Ultralytics not installed — run: pip install ultralytics")
else:
    print(f"Ultralytics       : {ultralytics.__version__}")

In [None]:
# --- CONFIGS ---
# ------------------------------------

import os
import yaml
import random
import colorsys
from pathlib import Path

import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
from ultralytics import YOLO

ROOT        = Path("..").resolve()
CONFIG_YAML = ROOT / "config.yml"
RUNS_DIR    = ROOT / "runs"
VAL_IMG_DIR = ROOT / "images" / "val"

# LOAD CLASS NAMES
# YOLO dataset YAMLs may give `names` either as a list or as an {index: name}
# mapping. Every later cell indexes NAMES positionally (NAMES[i],
# NAMES.index(...), enumerate(NAMES)), so a mapping is normalised to a list
# ordered by class index.
with open(CONFIG_YAML, encoding="utf-8") as f:
    cfg = yaml.safe_load(f)
NAMES = cfg["names"]
if isinstance(NAMES, dict):
    NAMES = [NAMES[k] for k in sorted(NAMES, key=int)]
NC = cfg.get("nc", len(NAMES))
if NC != len(NAMES):
    # A names/nc mismatch is a dataset-config error; fail early with a clear message
    # rather than deep inside training.
    raise ValueError(f"{CONFIG_YAML}: nc={NC} but {len(NAMES)} class names are listed")
print(f"Classes ({NC}): {NAMES}")

# --- TRAINING HYPERPARAMETERS ---
# --------------------------------

EPOCHS     = 150
BATCH_SIZE = 64          # batch size intended for 640px inputs
IMG_SIZE   = 1024
DEVICE     = ("cuda" if torch.cuda.is_available()
              else "mps" if torch.backends.mps.is_available()
              else "cpu")

# BATCH_SIZE is tuned for 640px; scale it down for larger inputs.
# At 1024px each image has (1024/640)^2 = 2.56x the pixels of a 640px image, and
# activation memory grows roughly with pixel count, so:
#     batch = BATCH_SIZE * (640 / IMG_SIZE)^2     (e.g. 64 @ 640px -> 25 @ 1024px)
# Per the original note, batch=64 at 1024px runs out of memory even on a 32 GB RTX 5090.
SAFE_BATCH = max(1, int(BATCH_SIZE * (640 / IMG_SIZE) ** 2))
print(f"Adjusted batch size for imgsz={IMG_SIZE}: {BATCH_SIZE} → {SAFE_BATCH}")

TRAIN_CFG = dict(
    model         = "yolov8s.pt",
    data          = str(CONFIG_YAML),
    epochs        = EPOCHS,
    imgsz         = IMG_SIZE,
    batch         = SAFE_BATCH,       # scaled for IMG_SIZE (see above)
    workers       = 8,
    device        = DEVICE,
    project       = str(RUNS_DIR),
    name          = "yolov8s_thermal",
    # False: if runs/yolov8s_thermal exists, a new numbered dir (…thermal2, …) is
    # created. True: reuse/overwrite the same dir. Neither option resumes
    # training — use model.train(resume=True) on last.pt for that.
    exist_ok      = False,
    # ── Optimiser ──
    optimizer     = "AdamW",
    lr0           = 0.001,
    lrf           = 0.01,             # final LR = lr0 * lrf
    momentum      = 0.937,            # used as beta1 by Adam-family optimisers
    weight_decay  = 0.0005,
    warmup_epochs = 3,
    # ── Augmentation ──
    # Custom aug pipeline already exists, keep built-in conservative
    mosaic        = 0.5,
    mixup         = 0.1,
    copy_paste    = 0.1,
    degrees       = 5.0,
    translate     = 0.1,
    scale         = 0.5,
    fliplr        = 0.5,
    flipud        = 0.0,              # thermal images are always upright
    hsv_h         = 0.0,              # thermal is greyscale — no hue shift
    hsv_s         = 0.0,              # no saturation shift
    hsv_v         = 0.3,              # brightness variation only
    # ── Loss weights ──
    box           = 7.5,
    cls           = 0.5,
    dfl           = 1.5,
    # ── Precision ──
    # Ultralytics AMP autocasts to FP16 (with a GradScaler), not BF16 —
    # NOTE(review): confirm against the installed ultralytics version.
    amp           = True,
    # ── Early stopping ──
    patience      = 30,               # epochs without val improvement before stopping
    # ── Saving ──
    save          = True,
    save_period   = 10,               # also checkpoint every 10 epochs
    val           = True,
    plots         = True,
    verbose       = True,
)

print("\nTraining config:")
for k, v in TRAIN_CFG.items():
    print(f"  {k:<18}: {v}")

model = YOLO(TRAIN_CFG["model"])

# "model" selects the checkpoint above; it is not a train() argument.
results = model.train(**{k: v for k, v in TRAIN_CFG.items() if k != "model"})

# Path to best weights — used in all cells below
BEST_WEIGHTS = Path(results.save_dir) / "weights" / "best.pt"
print(f"\n✓ Training complete")
print(f"  Best weights : {BEST_WEIGHTS}")
print(f"  Results dir  : {results.save_dir}")

In [None]:
# Plot loss and mAP curves from the CSV YOLO writes during training.
# Run this after training completes or while it's running (partial curves).

import pandas as pd

# Auto-find the latest run if BEST_WEIGHTS is unset or points at a missing file
# (e.g. the kernel was restarted and the training cell skipped).
if "BEST_WEIGHTS" not in globals() or not Path(BEST_WEIGHTS).exists():
    run_dirs = sorted(
        (p for p in Path(RUNS_DIR).glob("yolov8s_thermal*") if p.is_dir()),
        key=lambda p: p.stat().st_mtime,
    )
    if not run_dirs:
        raise FileNotFoundError(f"No run directories found in {RUNS_DIR}")
    latest_run   = run_dirs[-1]
    BEST_WEIGHTS = latest_run / "weights" / "best.pt"
    print(f"Using latest run: {latest_run}")
    if not BEST_WEIGHTS.exists():
        # Curves can still be plotted mid-training, but the evaluation,
        # inference and export cells need best.pt.
        print(f"  Note: {BEST_WEIGHTS} does not exist yet")
else:
    latest_run = Path(BEST_WEIGHTS).parent.parent

results_csv = latest_run / "results.csv"
if not results_csv.exists():
    print("results.csv not found — training may not have started yet")
else:
    df = pd.read_csv(results_csv)
    # Header cells may carry padding spaces; strip so the lookups below match.
    df.columns = df.columns.str.strip()

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.patch.set_facecolor("#0e0e0e")
    fig.suptitle("YOLOv8s Training Curves", color="white", fontsize=15)

    # (primary column, optional secondary column, panel title)
    plot_pairs = [
        ("train/box_loss",       "val/box_loss",           "Box Loss"),
        ("train/cls_loss",       "val/cls_loss",           "Class Loss"),
        ("train/dfl_loss",       "val/dfl_loss",           "DFL Loss"),
        ("metrics/mAP50(B)",     None,                     "mAP@0.5"),
        ("metrics/mAP50-95(B)", None,                      "mAP@0.5:0.95"),
        ("metrics/precision(B)","metrics/recall(B)",       "Precision / Recall"),
    ]

    for ax, (col1, col2, title) in zip(axes.flat, plot_pairs):
        ax.set_facecolor("#1a1a1a")
        ax.tick_params(colors="white")
        ax.set_title(title, color="white", fontsize=11)
        ax.set_xlabel("Epoch", color="grey")
        if col1 in df.columns:
            # Loss panels are train-vs-val; metric panels are labelled by metric
            # name (so the precision line is no longer mislabelled "Train").
            label1 = "Train" if "loss" in col1 else col1.split("/")[-1]
            ax.plot(df["epoch"], df[col1], label=label1, color="steelblue", linewidth=1.8)
        if col2 and col2 in df.columns:
            label2 = "Val" if "loss" in col2 else col2.split("/")[-1]
            ax.plot(df["epoch"], df[col2], label=label2, color="coral", linewidth=1.8)
        # Only draw a legend when something was plotted (avoids matplotlib's
        # "No artists with labels" warning for missing columns).
        if ax.get_legend_handles_labels()[0]:
            ax.legend(facecolor="#2a2a2a", labelcolor="white", fontsize=9)
        for spine in ax.spines.values():
            spine.set_edgecolor("#333")

    plt.tight_layout()
    plt.show()
    print(f"\nFinal epoch metrics:")
    metric_cols = [c for c in df.columns if "metric" in c.lower() or "map" in c.lower()]
    print(df[metric_cols].tail(1).to_string(index=False))

In [None]:
# --- VALIDATION SET EVALUATION ---

model = YOLO(str(BEST_WEIGHTS))

# NOTE(review): mAP is conventionally computed with a very low confidence
# threshold (Ultralytics' val default is 0.001). conf=0.15 cuts off the
# low-confidence end of the PR curve, so the mAP reported below is likely lower
# than the standard figure. 0.15 presumably mirrors the inference threshold used
# later; lower it when comparing against benchmark mAP.
metrics = model.val(
    data    = str(CONFIG_YAML),
    imgsz   = IMG_SIZE,
    batch   = SAFE_BATCH,
    device  = DEVICE,
    conf    = 0.15,
    iou     = 0.40,
    plots   = True,
    verbose = True,
)

# --- Per-class AP table ---
# metrics.box.p / .r / .ap50 / .ap hold one entry per class that actually occurs
# in the val set, ordered as metrics.box.ap_class_index — NOT one entry per class
# id. Indexing them by class id would shift every class after the first one
# missing from val, so map class id -> row explicitly.
row_of_class = {int(c): row for row, c in enumerate(metrics.box.ap_class_index)}


def _class_stat(values, cls_id: int) -> float:
    """Per-class value for `cls_id`, or NaN when the class has no val instances."""
    row = row_of_class.get(cls_id)
    if row is None or row >= len(values):
        return float("nan")
    return float(values[row])


# (name, AP@0.5, AP@0.5:0.95, precision, recall) per class id
per_class = [
    (name,
     _class_stat(metrics.box.ap50, i),
     _class_stat(metrics.box.ap,   i),
     _class_stat(metrics.box.p,    i),
     _class_stat(metrics.box.r,    i))
    for i, name in enumerate(NAMES)
]

print("\n" + "─" * 62)
print(f"{'Class':<16} {'AP@0.5':>9} {'AP@0.5:0.95':>12} {'Precision':>10} {'Recall':>8}")
print("─" * 62)
for name, a5, a9, p, r in per_class:
    print(f"{name:<16} {a5:>9.3f} {a9:>12.3f} {p:>10.3f} {r:>8.3f}")
print("─" * 62)
print(f"{'mAP':<16} {metrics.box.map50:>9.3f} {metrics.box.map:>12.3f}")
print("─" * 62)

absent = [name for name, a5, *_ in per_class if np.isnan(a5)]
if absent:
    print(f"No val instances (shown as nan): {', '.join(absent)}")

#  --- Flag classes with low recall ---
print("\n**** Classes with recall < 0.5 (missed detections):")
# NaN compares False, so classes absent from val are not flagged here.
low_recall = [(name, r) for name, _, _, _, r in per_class if r < 0.5]
if low_recall:
    for name, r in low_recall:
        print(f"  {name:<16} recall = {r:.3f}")
else:
    print("  None — all classes above 0.5 recall ✓")


# Display the confusion matrix and PR/F1 curves saved by YOLO during val.
val_dir = Path(metrics.save_dir)

# Candidate filenames per panel, first existing wins. Newer Ultralytics releases
# prefix the curve images with "Box" (BoxPR_curve.png); older ones do not.
plot_candidates = {
    "Confusion Matrix"       : ["confusion_matrix.png"],
    "Confusion Matrix (norm)": ["confusion_matrix_normalized.png"],
    "PR Curve"               : ["BoxPR_curve.png", "PR_curve.png"],
    "F1 Curve"               : ["BoxF1_curve.png", "F1_curve.png"],
}

available = {}
for title, fnames in plot_candidates.items():
    found = next((val_dir / fn for fn in fnames if (val_dir / fn).exists()), None)
    if found is not None:
        available[title] = found

if not available:
    print("No plot files found — ensure plots=True was set during val")
else:
    # squeeze=False keeps a 2-D axes array even for a single panel.
    fig, axes = plt.subplots(1, len(available), figsize=(8 * len(available), 7), squeeze=False)
    fig.patch.set_facecolor("#0e0e0e")

    for ax, (title, path) in zip(axes[0], available.items()):
        ax.axis("off")
        img = cv2.imread(str(path))
        if img is None:   # unreadable image: leave the panel blank instead of crashing
            ax.set_title(f"{title} (unreadable)", color="white", fontsize=12)
            continue
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax.set_title(title, color="white", fontsize=12)

    plt.tight_layout()
    plt.show()

In [None]:
# --- INFERENCE ON SPECIFIC CLASS --- 
TARGET_CLASS = "large_vehicle"   # change to any class name
NUM_SAMPLES  = 4
CONF_THRESH  = 0.15
IOU_THRESH   = 0.40

target_idx = NAMES.index(TARGET_CLASS) if TARGET_CLASS in NAMES else None
if target_idx is None:
    raise ValueError(f"Class '{TARGET_CLASS}' not in {NAMES}")

target_imgs = []
for txt in (ROOT / "labels" / "val").glob("*.txt"):
    for ln in txt.read_text().splitlines():
        if ln.strip() and int(ln.split()[0]) == target_idx:
            img_path = VAL_IMG_DIR / (txt.stem + ".jpg")
            if img_path.exists():
                target_imgs.append(img_path)
            break

print(f"Found {len(target_imgs)} val images containing '{TARGET_CLASS}'")
if not target_imgs:
    print("No images found — check TARGET_CLASS name and val label directory")
else:
    samples = random.sample(target_imgs, min(NUM_SAMPLES, len(target_imgs)))
    results = model.predict(
        source  = [str(p) for p in samples],
        imgsz   = IMG_SIZE,
        conf    = CONF_THRESH,
        iou     = IOU_THRESH,
        device  = DEVICE,
        verbose = False,
    )

    for img_path, result in zip(samples, results):
        bgr = cv2.imread(str(img_path))
        if bgr is None:
            continue
        rgb      = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        lbl_path = ROOT / "labels" / "val" / (img_path.stem + ".txt")
        gt_img   = draw_gt_boxes(rgb, lbl_path, NAMES)
        pred_img = draw_pred_boxes(rgb, result, NAMES)

        gt_count   = sum(1 for ln in lbl_path.read_text().splitlines()
                         if ln.strip() and int(ln.split()[0]) == target_idx)
        pred_count = sum(1 for b in (result.boxes or [])
                         if int(b.cls[0]) == target_idx)

        fig, axes = plt.subplots(1, 2, figsize=(18, 7))
        fig.patch.set_facecolor("#0e0e0e")
        fig.suptitle(f"Class filter: {TARGET_CLASS}  |  {img_path.name}",
                     color="#aaa", fontsize=9)
        axes[0].imshow(gt_img)
        axes[0].set_title(f"Ground Truth  ({gt_count} {TARGET_CLASS})",
                          color="white", fontsize=12)
        axes[0].axis("off")
        axes[1].imshow(pred_img)
        axes[1].set_title(f"Prediction  ({pred_count} {TARGET_CLASS} detected)",
                          color="white", fontsize=12)
        axes[1].axis("off")
        plt.tight_layout()
        plt.show()

In [None]:
# --- INFERENCE AND SIDE BY SIDE COMPARISON ---

NUM_SAMPLES  = 6
CONF_THRESH  = 0.15
IOU_THRESH   = 0.40
SEED         = 42

IMG_EXTS    = (".jpg", ".jpeg", ".png", ".bmp")
VAL_LBL_DIR = ROOT / "labels" / "val"

model = YOLO(str(BEST_WEIGHTS))


def class_color(cls_id: int):
    """Consistent HSV-derived colour per class index."""
    h = (cls_id * 37) % 360
    r, g, b = colorsys.hsv_to_rgb(h / 360.0, 0.9, 0.95)
    return (int(r*255), int(g*255), int(b*255))


def draw_gt_boxes(img_rgb: np.ndarray, label_path: Path, names: list) -> np.ndarray:
    """Draw ground truth YOLO boxes onto an RGB image."""
    h, w   = img_rgb.shape[:2]
    canvas = img_rgb.copy()
    if not label_path.exists():
        return canvas
    for ln in label_path.read_text().splitlines():
        parts = ln.strip().split()
        if len(parts) < 5:
            continue
        cls = int(parts[0])
        cx, cy, bw, bh = map(float, parts[1:5])
        x1 = max(0, int((cx - bw/2) * w))
        y1 = max(0, int((cy - bh/2) * h))
        x2 = min(w, int((cx + bw/2) * w))
        y2 = min(h, int((cy + bh/2) * h))
        color = class_color(cls)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2, cv2.LINE_AA)
        label = names[cls] if cls < len(names) else str(cls)
        _draw_label(canvas, label, x1, y1, color)
    return canvas


def draw_pred_boxes(img_rgb: np.ndarray, result, names: list) -> np.ndarray:
    """Draw YOLO prediction boxes onto an RGB image."""
    canvas = img_rgb.copy()
    boxes  = result.boxes
    if boxes is None or len(boxes) == 0:
        return canvas
    for box in boxes:
        cls   = int(box.cls[0])
        conf  = float(box.conf[0])
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        color = class_color(cls)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2, cv2.LINE_AA)
        label = f"{names[cls] if cls < len(names) else cls} {conf:.2f}"
        _draw_label(canvas, label, x1, y1, color)
    return canvas


def _draw_label(img: np.ndarray, text: str, x: int, y: int, color: tuple):
    """Draw filled label rectangle above a bounding box."""
    font        = cv2.FONT_HERSHEY_SIMPLEX
    scale       = 0.48
    thickness   = 1
    (tw, th), _ = cv2.getTextSize(text, font, scale, thickness)
    pad = 4
    ly1 = max(0, y - th - pad * 2)
    ly2 = y
    lx2 = min(img.shape[1], x + tw + pad)
    cv2.rectangle(img, (x, ly1), (lx2, ly2), color, -1)
    lum = (0.299*color[0] + 0.587*color[1] + 0.114*color[2]) / 255
    tc  = (0, 0, 0) if lum > 0.55 else (255, 255, 255)
    cv2.putText(img, text, (x + pad, ly2 - pad), font, scale, tc, thickness, cv2.LINE_AA)


all_imgs = sorted([p for p in VAL_IMG_DIR.iterdir() if p.suffix.lower() in IMG_EXTS])
if not all_imgs:
    raise FileNotFoundError(f"No images found in {VAL_IMG_DIR}")

if SEED is not None:
    random.seed(SEED)
samples = random.sample(all_imgs, min(NUM_SAMPLES, len(all_imgs)))

results = model.predict(
    source  = [str(p) for p in samples],
    imgsz   = IMG_SIZE,   
    conf    = CONF_THRESH,
    iou     = IOU_THRESH,
    device  = DEVICE,
    verbose = False,
)

for img_path, result in zip(samples, results):
    bgr = cv2.imread(str(img_path))
    if bgr is None:
        continue
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    lbl_path = VAL_LBL_DIR / (img_path.stem + ".txt")
    gt_img   = draw_gt_boxes(rgb,  lbl_path, NAMES)
    pred_img = draw_pred_boxes(rgb, result,  NAMES)

    gt_count   = len(lbl_path.read_text().splitlines()) if lbl_path.exists() else 0
    pred_count = len(result.boxes) if result.boxes is not None else 0

    fig, axes = plt.subplots(1, 2, figsize=(18, 7))
    fig.patch.set_facecolor("#0e0e0e")
    fig.suptitle(img_path.name, color="#aaa", fontsize=9)

    axes[0].imshow(gt_img)
    axes[0].set_title(f"Ground Truth  ({gt_count} boxes)",
                      color="white", fontsize=12, pad=8)
    axes[0].axis("off")

    axes[1].imshow(pred_img)
    axes[1].set_title(f"Prediction  ({pred_count} boxes)  conf≥{CONF_THRESH}",
                      color="white", fontsize=12, pad=8)
    axes[1].axis("off")

    plt.tight_layout()
    plt.show()
    print(f"  GT: {gt_count} boxes  |  Pred: {pred_count} boxes  |  {img_path.name}")

In [None]:
# --- EXPORT MODEL ---
# Export best.pt for deployment. A TensorRT engine can only be built on an
# NVIDIA GPU, so on MPS/CPU this cell falls back to ONNX instead of failing.

model = YOLO(str(BEST_WEIGHTS))

if DEVICE == "cuda":
    # --- TensorRT ---
    exported = model.export(
        format    = "engine",
        imgsz     = IMG_SIZE,
        half      = True,      # FP16 engine
        device    = DEVICE,
        workspace = 8,         # TensorRT builder workspace limit (GiB)
        verbose   = False,
    )
else:
    # --- ONNX (portable fallback) ---
    print(f"TensorRT export requires CUDA (DEVICE={DEVICE!r}) — exporting ONNX instead")
    exported = model.export(
        format   = "onnx",
        imgsz    = IMG_SIZE,
        dynamic  = False,      # fixed square IMG_SIZE input
        simplify = True,
        device   = "cpu",      # ONNX export does not need an accelerator
    )

print(f"\n Model exported to: {exported}")
print(f"  Load with: model = YOLO({str(exported)!r})")