In [None]:
# 1. Imports & Setup
import os
import sys
import json
import time
import math
import glob
import random
import copy
import shutil
from pathlib import Path
from datetime import datetime
from collections import defaultdict

import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from IPython.display import display, clear_output
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import GradScaler

import fiftyone as fo
import fiftyone.zoo as foz

# Set Seed
def set_seed(seed=42):
    """Seed every RNG used by the notebook and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to ``random``, NumPy and PyTorch (CPU and,
            when CUDA is available, every CUDA device).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
print("‚úÖ Notebook Updated for Fixed Dataset Training")



In [None]:
# 2. Configuration
# --- USER SETTINGS ---
QUICK_TEST = True  # Set to True for a fast smoke test
NUM_SAMPLES = 1000 if QUICK_TEST else 20000  # Total images on disk
TOTAL_EPOCHS = 3 if QUICK_TEST else 100     # Training goal

# Project layout: everything lives under ./yolo-lab (resolved against the CWD).
BASE_DIR = os.path.abspath("yolo-lab")
DIRS = {
    "datasets": os.path.join(BASE_DIR, "datasets"),
    "runs": os.path.join(BASE_DIR, "runs"),
    "configs": os.path.join(BASE_DIR, "configs"),
}
for d in DIRS.values():
    os.makedirs(d, exist_ok=True)

# Full runs reuse one stable run directory; quick tests get a fresh,
# timestamped directory each time this cell is executed.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
EXP_NAME = "yolov8_fixed_dataset"
RUN_NAME = EXP_NAME # Stable run name
if QUICK_TEST:
    RUN_NAME += f"_test_{timestamp}"
RUN_DIR = os.path.join(DIRS["runs"], RUN_NAME)
os.makedirs(RUN_DIR, exist_ok=True)

# Model & Training Config
# NOTE: every key appears exactly once. A repeated key in a dict literal is
# legal Python but the later value silently wins, so "grad_clip_norm" is
# defined only in the Optimizer section.
CFG = {
    "exp_name": EXP_NAME,
    "run_name": RUN_NAME,
    "seed": 42,
    "imgsz": 640,
    "batch_size": 8 if QUICK_TEST else 16,
    "num_classes": 80,
    "time_limit": 36000, # 10 hours default

    # Model
    "width": 1.0,
    "depth": 1.0,
    "reg_max": 16,
    "head_hidden": 256,
    "backbone": "yolov8_cspdarknet",

    # Optimizer
    "optimizer": "adamw",
    "lr": 1e-3,
    "weight_decay": 0.05,
    "cosine_schedule": True,
    "epochs": TOTAL_EPOCHS,
    "amp": True,
    "grad_clip_norm": 10.0,
    "ema_decay": 0.9998,

    # Loss
    "tal_alpha": 1.0,
    "tal_beta": 6.0,
    "tal_topk": 10,
    "tal_center_radius": 2.5,
    "loss_weights": {"box": 7.5, "cls": 0.5, "dfl": 1.5},

    # Augmentation
    "letterbox_pad": 114,
    "hflip_p": 0.5,
    "hsv_h": 0.015,
    "hsv_s": 0.7,
    "hsv_v": 0.4,
    "accumulate": 1,
    "min_lr_ratio": 0.05,
    "pretrained": None,

    # Paths (Fixed local dataset)
    "data_root": os.path.join(DIRS["datasets"], "coco_fixed"),
    "train_img_dir": "images/train",
    "train_lbl_dir": "labels/train",
    "val_img_dir": "images/val",
    "val_lbl_dir": "labels/val",
}

print("Run Directory:", RUN_DIR)



In [None]:
# 3. Cleanup Utilities
def confirm_action(prompt):
    """Ask for interactive confirmation on stdin.

    Args:
        prompt: Question shown to the user; the hint
            ``(type 'yes' to confirm)`` is appended.

    Returns:
        True only if the reply, lower-cased, is exactly ``yes``.
    """
    return input(f"{prompt} (type 'yes' to confirm): ").lower() == 'yes'

def clear_last_checkpoint():
    """Delete ``last.pt`` from the current RUN_DIR after user confirmation."""
    path = os.path.join(RUN_DIR, "last.pt")
    if not os.path.exists(path):
        print(f"‚ö†Ô∏è No checkpoint found at {path}")
        return
    # Nothing is removed unless the user explicitly types "yes".
    if confirm_action(f"Clean checkpoint at {path}?"):
        os.remove(path)
        print(f"üóëÔ∏è Deleted {path}")

def wipe_photos():
    """Remove the ``images`` folder under ``CFG["data_root"]`` after confirmation."""
    path = os.path.join(CFG["data_root"], "images")
    if not os.path.exists(path):
        print(f"‚ö†Ô∏è No photos found at {path}")
        return
    # Recursive delete, guarded by an explicit "yes" from the user.
    if confirm_action(f"WIPE ALL PHOTOS in {path}?"):
        shutil.rmtree(path)
        print(f"üóëÔ∏è Deleted {path}")

def wipe_data():
    """Remove the ``labels`` folder under ``CFG["data_root"]`` after confirmation."""
    path = os.path.join(CFG["data_root"], "labels")
    if not os.path.exists(path):
        print(f"‚ö†Ô∏è No data found at {path}")
        return
    # Recursive delete, guarded by an explicit "yes" from the user.
    if confirm_action(f"WIPE ALL DATA (labels) in {path}?"):
        shutil.rmtree(path)
        print(f"üóëÔ∏è Deleted {path}")

def full_wipe():
    """Delete the whole BASE_DIR tree after confirmation, then recreate it empty.

    Does nothing if the user declines or BASE_DIR does not exist. After a
    successful wipe, every directory in DIRS and RUN_DIR is created again.
    """
    if not confirm_action("DANGER: WIPE EVERYTHING (runs, datasets, configs)?"):
        return
    if not os.path.exists(BASE_DIR):
        return
    shutil.rmtree(BASE_DIR)
    print(f"üóëÔ∏è Deleted {BASE_DIR}")
    # Re-create structure for safety
    for folder in list(DIRS.values()) + [RUN_DIR]:
        os.makedirs(folder, exist_ok=True)
    print("‚úÖ Structure re-initialized.")

# Announce which cleanup helpers this cell defined (none of them run automatically).
print("‚õëÔ∏è Cleanup Utils Loaded: clear_last_checkpoint(), wipe_photos(), wipe_data(), full_wipe()")



In [None]:
# 4. Model Architecture
def autopad(k, p=None, d=1):
    """Return the padding for a convolution, defaulting to half the kernel size.

    Args:
        k: Kernel size, an int or a sequence of ints.
        p: Explicit padding; returned unchanged when not None.
        d: Dilation; when > 1 the effective kernel size ``d * (k - 1) + 1``
            is used to derive the padding.

    Returns:
        ``p`` if given, else ``k // 2`` (an int, or a list for sequence ``k``).
    """
    scalar = isinstance(k, int)
    if d > 1:
        # Effective kernel extent once dilation gaps are accounted for.
        k = d * (k - 1) + 1 if scalar else [d * (side - 1) + 1 for side in k]
    if p is not None:
        return p
    return k // 2 if scalar else [side // 2 for side in k]

class Conv(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and an activation.

    ``act=True`` uses the shared ``default_act`` (SiLU); an ``nn.Module``
    passed as ``act`` is used as-is; any other value disables the activation.
    """

    default_act = nn.SiLU()

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(
            c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False
        )
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch norm and activation in that order."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)

class Bottleneck(nn.Module):
    """Two stacked ``Conv`` layers with an optional residual connection.

    The residual is applied only when ``shortcut`` is truthy and the input
    and output channel counts match. ``e`` scales the hidden width relative
    to ``c2``; ``k`` holds the two kernel sizes; ``g`` is the group count of
    the second convolution.
    """

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        first_k, second_k = k[0], k[1]
        self.cv1 = Conv(c1, hidden, first_k, 1)
        self.cv2 = Conv(hidden, c2, second_k, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Return ``cv2(cv1(x))``, plus ``x`` when the shortcut is active."""
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + out
        return out

class C2f(nn.Module):
    """Split-and-concatenate block built from ``n`` chained Bottlenecks.

    ``cv1`` expands the input to ``2 * c`` channels, which are split in two.
    Each Bottleneck consumes the most recent chunk and its output is appended,
    so ``cv2`` fuses ``(2 + n) * c`` concatenated channels down to ``c2``.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        super().__init__()
        self.c = int(c2 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)
        blocks = [
            Bottleneck(self.c, self.c, shortcut, g, k=(3, 3), e=1.0)
            for _ in range(n)
        ]
        self.m = nn.ModuleList(blocks)

    def forward(self, x):
        """Run the split / chained-bottleneck / fuse pipeline on ``x``."""
        chunks = list(self.cv1(x).chunk(2, 1))
        for block in self.m:
            # Each block feeds on the latest chunk, forming a chain.
            chunks.append(block(chunks[-1]))
        return self.cv2(torch.cat(chunks, 1))

class SPPF(nn.Module):
    """Spatial pyramid pooling (fast): three chained max-pools, then fusion.

    The input is reduced to ``c1 // 2`` channels, pooled three times in
    sequence with the same stride-1 kernel, and the four resulting maps are
    concatenated and projected to ``c2`` channels.
    """

    def __init__(self, c1, c2, k=5):
        super().__init__()
        hidden = c1 // 2
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        """Return the fused multi-scale pooled features for ``x``."""
        pyramid = [self.cv1(x)]
        for _ in range(3):
            # Re-pooling the previous result widens the receptive field.
            pyramid.append(self.m(pyramid[-1]))
        return self.cv2(torch.cat(tuple(pyramid), 1))

class CSPDarknet(nn.Module):
    """Backbone: a stride-2 stem followed by four stride-2 Conv + C2f stages.

    ``width`` scales the base channel counts (64..1024) and ``depth`` scales
    the per-stage C2f repeat counts (3, 6, 6, 3). ``forward`` returns the
    outputs of stages 2, 3 and 4 (the last one after SPPF).
    """

    def __init__(self, width=1.0, depth=1.0):
        super().__init__()
        base_c = [64, 128, 256, 512, 1024]
        base_d = [3, 6, 6, 3]
        self.c = [int(ch * width) for ch in base_c]
        self.d = [max(round(rep * depth), 1) if rep > 1 else rep for rep in base_d]
        ch0, ch1, ch2, ch3, ch4 = self.c
        n1, n2, n3, n4 = self.d

        # Modules are created in network order so parameter initialisation
        # consumes the RNG in a fixed sequence.
        self.stem = Conv(3, ch0, 3, 2)
        self.stage1 = nn.Sequential(
            Conv(ch0, ch1, 3, 2),
            C2f(ch1, ch1, n=n1, shortcut=True),
        )
        self.stage2 = nn.Sequential(
            Conv(ch1, ch2, 3, 2),
            C2f(ch2, ch2, n=n2, shortcut=True),
        )
        self.stage3 = nn.Sequential(
            Conv(ch2, ch3, 3, 2),
            C2f(ch3, ch3, n=n3, shortcut=True),
        )
        self.stage4 = nn.Sequential(
            Conv(ch3, ch4, 3, 2),
            C2f(ch4, ch4, n=n4, shortcut=True),
            SPPF(ch4, ch4, k=5),
        )

    def forward(self, x):
        """Return the ``(c3, c4, c5)`` feature maps from stages 2, 3 and 4."""
        shallow = self.stage1(self.stem(x))
        c3 = self.stage2(shallow)
        c4 = self.stage3(c3)
        c5 = self.stage4(c4)
        return c3, c4, c5

class YOLOv8PAFPN(nn.Module):
    """Neck with a top-down pass followed by a bottom-up pass over (c3, c4, c5).

    Output channel counts equal the input ones (c3, c4, c5). The ``out_ch``,
    ``width`` and ``depth`` arguments are accepted but not used by this
    implementation; every C2f here uses ``n=3``.
    """

    def __init__(self, c3, c4, c5, out_ch=256, width=1.0, depth=1.0):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        # Top-down branch: reduce channels, upsample, fuse with the finer map.
        self.reduce5 = Conv(c5, c4, 1, 1)
        self.c2f_p4 = C2f(c4 + c4, c4, n=3, shortcut=False)
        self.reduce4 = Conv(c4, c3, 1, 1)
        self.c2f_p3 = C2f(c3 + c3, c3, n=3, shortcut=False)
        # Bottom-up branch: stride-2 downsample, fuse with the coarser map.
        self.down3 = Conv(c3, c3, 3, 2)
        self.c2f_n4 = C2f(c3 + c4, c4, n=3, shortcut=False)
        self.down4 = Conv(c4, c4, 3, 2)
        self.c2f_n5 = C2f(c4 + c5, c5, n=3, shortcut=False)

    def forward(self, c3, c4, c5):
        """Return the three fused maps, ordered from finest to coarsest."""
        # Top-down.
        td4_in = torch.cat([self.up(self.reduce5(c5)), c4], dim=1)
        td4 = self.c2f_p4(td4_in)
        td3_in = torch.cat([self.up(self.reduce4(td4)), c3], dim=1)
        out3 = self.c2f_p3(td3_in)
        # Bottom-up.
        out4 = self.c2f_n4(torch.cat([self.down3(out3), td4], dim=1))
        out5 = self.c2f_n5(torch.cat([self.down4(out4), c5], dim=1))
        return out3, out4, out5

class Integral(nn.Module):
    """Turn per-bin logits into the expected bin index.

    A softmax is taken over the last dimension (``reg_max + 1`` bins) and the
    result is the probability-weighted mean of the indices ``0..reg_max``.
    The index vector is a non-persistent buffer, so it is not saved in the
    state dict.
    """

    def __init__(self, reg_max=16):
        super().__init__()
        self.reg_max = int(reg_max)
        bin_index = torch.arange(self.reg_max + 1, dtype=torch.float32)
        self.register_buffer("proj", bin_index, persistent=False)

    def forward(self, logits):
        """Return the expectation over the last dimension of ``logits``."""
        probs = F.softmax(logits, dim=-1)
        return torch.sum(probs * self.proj, dim=-1)

class YoloV8LiteHead(nn.Module):
    """Decoupled detection head with separate class and box towers per level.

    Each input level gets two 3x3 ``Conv`` towers and two 1x1 prediction
    convolutions: one emits ``num_classes`` logits, the other
    ``4 * (reg_max + 1)`` distribution logits. Class biases start at
    ``-log((1 - p) / p)`` with ``p = 0.01``.
    """

    def __init__(self, in_channels_list, num_classes=80, hidden=256, reg_max=16):
        super().__init__()
        self.num_classes = num_classes
        self.reg_max = reg_max
        self.integral = Integral(self.reg_max)
        self.cls_towers = nn.ModuleList()
        self.reg_towers = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.box_preds = nn.ModuleList()

        box_channels = 4 * (self.reg_max + 1)
        for level_channels in in_channels_list:
            cls_tower = nn.Sequential(
                Conv(level_channels, hidden, 3, 1), Conv(hidden, hidden, 3, 1)
            )
            self.cls_towers.append(cls_tower)
            reg_tower = nn.Sequential(
                Conv(level_channels, hidden, 3, 1), Conv(hidden, hidden, 3, 1)
            )
            self.reg_towers.append(reg_tower)
            self.cls_preds.append(nn.Conv2d(hidden, num_classes, 1))
            self.box_preds.append(nn.Conv2d(hidden, box_channels, 1))

        # Bias Init
        prior = 0.01
        bias_value = -math.log((1 - prior) / prior)
        for pred in self.cls_preds:
            nn.init.constant_(pred.bias, bias_value)

    def forward(self, features):
        """Return ``(cls_outs, box_outs)``: one tensor per input feature map."""
        cls_outs, box_outs = [], []
        for level, feat in enumerate(features):
            cls_feat = self.cls_towers[level](feat)
            cls_outs.append(self.cls_preds[level](cls_feat))
            reg_feat = self.reg_towers[level](feat)
            box_outs.append(self.box_preds[level](reg_feat))
        return cls_outs, box_outs

class YoloModel(nn.Module):
    """Backbone + neck + head detector wired from the global ``CFG``.

    ``width``, ``depth`` and ``reg_max`` are read from ``CFG`` (defaults 1.0,
    1.0 and 16). The ``backbone`` argument is accepted but a ``CSPDarknet``
    is always built.
    """

    def __init__(self, num_classes=80, backbone="yolov8_cspdarknet", head_hidden=256, fpn_out=256):
        super().__init__()
        width, depth = CFG.get("width", 1.0), CFG.get("depth", 1.0)
        self.backbone = CSPDarknet(width=width, depth=depth)
        c3, c4, c5 = (int(ch * width) for ch in (256, 512, 1024))
        self.neck = YOLOv8PAFPN(c3=c3, c4=c4, c5=c5, out_ch=fpn_out, width=width, depth=depth)
        self.head = YoloV8LiteHead(
            in_channels_list=[c3, c4, c5],
            num_classes=num_classes,
            hidden=head_hidden,
            reg_max=CFG.get("reg_max", 16),
        )
        self.strides = [8, 16, 32]

    def forward(self, x, targets=None):
        """Run the detector.

        Returns ``(losses, stats)`` when the module is in training mode,
        ``targets`` is given and a ``criterion`` attribute has been attached;
        otherwise returns the raw head-output dict with keys ``features``,
        ``cls``, ``box`` and ``strides``.
        """
        p3, p4, p5 = self.neck(*self.backbone(x))
        cls_outs, box_outs = self.head([p3, p4, p5])
        head_out = {
            "features": [p3, p4, p5],
            "cls": cls_outs,
            "box": box_outs,
            "strides": self.strides,
        }

        wants_loss = self.training and targets is not None and hasattr(self, "criterion")
        if not wants_loss:
            return head_out
        losses, stats = self.criterion(head_out, targets)
        return losses, stats



In [None]:
# 5. Utils, Loss & Training Helpers
def make_grid(h, w, stride, device):
    """Return flattened cell-centre coordinates of an ``h`` x ``w`` grid.

    Args:
        h, w: Grid height and width in cells.
        stride: Size of one cell in pixels.
        device: Torch device for the returned tensors.

    Returns:
        ``(cx, cy)``: two 1-D tensors of length ``h * w`` in row-major order,
        holding ``(index + 0.5) * stride`` for each cell.
    """
    centre_x = (torch.arange(w, device=device) + 0.5) * stride
    centre_y = (torch.arange(h, device=device) + 0.5) * stride
    cy, cx = torch.meshgrid(centre_y, centre_x, indexing="ij")
    return cx.reshape(-1), cy.reshape(-1)

def box_iou_xyxy_matrix(a, b):
    """Pairwise IoU between two sets of ``(x1, y1, x2, y2)`` boxes.

    Args:
        a: Tensor of shape ``[Na, 4]``.
        b: Tensor of shape ``[Nb, 4]``.

    Returns:
        Tensor of shape ``[Na, Nb]``; all zeros when either input is empty.
        Negative widths/heights are clamped to zero and ``1e-6`` is added to
        the union to avoid division by zero.
    """
    if a.numel() == 0 or b.numel() == 0:
        return a.new_zeros((a.shape[0], b.shape[0]))

    def _area(boxes):
        width = (boxes[:, 2] - boxes[:, 0]).clamp(min=0)
        height = (boxes[:, 3] - boxes[:, 1]).clamp(min=0)
        return width * height

    top_left = torch.maximum(a[:, None, :2], b[None, :, :2])
    bottom_right = torch.minimum(a[:, None, 2:], b[None, :, 2:])
    overlap = (bottom_right - top_left).clamp(min=0)
    inter = overlap[..., 0] * overlap[..., 1]
    union = _area(a)[:, None] + _area(b)[None, :] - inter
    return inter / (union + 1e-6)

def bbox_iou(box1, box2, eps=1e-7):
    """Element-wise IoU between matching ``(x1, y1, x2, y2)`` boxes.

    The last dimension of both inputs must be 4. The coordinate axis is kept
    with size 1, so ``[N, 4]`` inputs produce an ``[N, 1]`` result.

    Args:
        box1, box2: Broadcast-compatible box tensors.
        eps: Added to the union to avoid division by zero.
    """
    ax1, ay1, ax2, ay2 = box1.chunk(4, -1)
    bx1, by1, bx2, by2 = box2.chunk(4, -1)
    overlap_w = (torch.min(ax2, bx2) - torch.max(ax1, bx1)).clamp(0)
    overlap_h = (torch.min(ay2, by2) - torch.max(ay1, by1)).clamp(0)
    inter = overlap_w * overlap_h
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    return inter / (area_a + area_b - inter + eps)

class DetectionLoss(nn.Module):
    """Classification (BCE) + box (IoU and DFL) loss over task-aligned targets.

    Args:
        num_classes: Number of class logits per anchor.
        image_size: Image side length forwarded to the target assigner.
        strides: Stride of each prediction level.
        lambda_box: Weight of the ``1 - IoU`` term.
        lambda_cls: Weight of the classification term.
        lambda_dfl: Weight of the distribution-focal term.
        dfl_ch: Number of distribution bins per box side.
    """

    def __init__(self, num_classes, image_size, strides, lambda_box=7.5, lambda_cls=0.5, lambda_dfl=1.5, dfl_ch=17):
        super().__init__()
        self.nc = num_classes
        self.imgsz = image_size
        self.strides = strides
        self.lambda_box = lambda_box
        self.lambda_cls = lambda_cls
        self.lambda_dfl = lambda_dfl
        self.dfl_ch = dfl_ch
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, head_out, targets):
        """Compute the loss for one batch.

        Args:
            head_out: Dict with ``"cls"`` and ``"box"`` lists of per-level
                tensors shaped ``[B, C, H, W]`` / ``[B, 4 * dfl_ch, H, W]``.
            targets: Dict with ``"labels"``, ``"boxes"`` and ``"batch_index"``.

        Returns:
            ``(losses, stats)`` where ``losses`` has keys ``loss``,
            ``loss_cls`` and ``loss_box`` and ``stats`` has ``num_pos``.
        """
        cls_outs = head_out["cls"]
        box_outs = head_out["box"]
        B = cls_outs[0].shape[0]
        device = cls_outs[0].device

        pred_logits = torch.cat([x.permute(0, 2, 3, 1).reshape(B, -1, self.nc) for x in cls_outs], 1)
        pred_dist = torch.cat([x.permute(0, 2, 3, 1).reshape(B, -1, 4 * self.dfl_ch) for x in box_outs], 1)

        targets_per_image, levels = self.build_targets(cls_outs, box_outs, targets)

        # Anchor centres and strides depend only on the feature-map shapes,
        # so build them once per batch instead of once per image.
        grid_x, grid_y, grid_stride = [], [], []
        for lev in levels:
            cx, cy = make_grid(lev["H"], lev["W"], lev["stride"], device)
            grid_x.append(cx)
            grid_y.append(cy)
            grid_stride.append(torch.full_like(cx, float(lev["stride"])))
        grid_x = torch.cat(grid_x)
        grid_y = torch.cat(grid_y)
        grid_stride = torch.cat(grid_stride)
        bin_index = torch.arange(self.dfl_ch, device=device).float()

        loss_cls = torch.tensor(0.0, device=device)
        loss_box = torch.tensor(0.0, device=device)
        num_pos_total = 0.0
        total_anchors = pred_logits.shape[1]  # Total anchors per image for cls normalization

        for b in range(B):
            t = targets_per_image[b]
            pos_idx = t["pos_index"]
            num_pos = len(pos_idx)
            num_pos_total += num_pos

            # Classification: all-zero target except soft scores at positives.
            t_cls = torch.zeros_like(pred_logits[b])
            if num_pos > 0:
                # Targets are constants: detach so no gradient flows back
                # into the predictions through the assignment scores.
                t_cls[pos_idx] = t["t_cls_soft"].detach().to(t_cls.dtype)
            loss_cls = loss_cls + self.bce(pred_logits[b], t_cls).sum()
            if num_pos == 0:
                continue

            # Box (IoU + DFL)
            stride_pos = grid_stride[pos_idx].unsqueeze(-1)
            p_box_pos = pred_dist[b][pos_idx].view(-1, 4, self.dfl_ch)

            # Target distances in bin units, split between the two
            # neighbouring integer bins by linear interpolation.
            t_bins = (t["t_box_ltrb"] / stride_pos).clamp(0, self.dfl_ch - 1.01)
            tl = t_bins.long()
            tr = tl + 1
            wl = tr.float() - t_bins
            wr = t_bins - tl.float()

            flat_logits = p_box_pos.view(-1, self.dfl_ch)
            l_dfl = (
                F.cross_entropy(flat_logits, tl.view(-1), reduction="none").view(-1, 4) * wl
                + F.cross_entropy(flat_logits, tr.view(-1), reduction="none").view(-1, 4) * wr
            ).mean(-1)

            # Expected distances (pixels) -> predicted xyxy boxes.
            p_ltrb = (p_box_pos.softmax(-1) * bin_index).sum(-1) * stride_pos
            ax, ay = grid_x[pos_idx], grid_y[pos_idx]
            p_xyxy = torch.stack(
                [ax - p_ltrb[:, 0], ay - p_ltrb[:, 1], ax + p_ltrb[:, 2], ay + p_ltrb[:, 3]], -1
            )

            # bbox_iou keeps a trailing size-1 axis ([N, 1]); squeeze it so
            # the sum with l_dfl ([N]) stays per-anchor instead of
            # broadcasting to an [N, N] matrix that inflates the loss N-fold.
            iou = bbox_iou(p_xyxy, t["t_box_xyxy"]).squeeze(-1)
            loss_box = loss_box + ((1.0 - iou) * self.lambda_box + l_dfl * self.lambda_dfl).sum()

        # Normalize: cls by total anchors (stable), box by num positives
        cls_norm = B * total_anchors  # Total anchor-class predictions
        box_norm = max(num_pos_total, 1.0)
        loss_cls_scaled = (loss_cls / cls_norm) * self.lambda_cls
        loss_box_scaled = loss_box / box_norm
        return {"loss": loss_cls_scaled + loss_box_scaled, "loss_cls": loss_cls_scaled, "loss_box": loss_box_scaled}, {"num_pos": num_pos_total}

    def build_targets(self, cls_outs, box_outs, targets):
        """Split the flat batch targets per image and run the task-aligned assigner."""
        gt_classes = targets["labels"]
        gt_boxes = targets["boxes"]
        batch_idx = targets["batch_index"]
        B = cls_outs[0].shape[0]
        gt_cls_list, gt_box_list = [], []
        for i in range(B):
            mask = batch_idx == i
            gt_cls_list.append(gt_classes[mask])
            gt_box_list.append(gt_boxes[mask])
        return build_targets_task_aligned(cls_outs, box_outs, self.strides, gt_cls_list, gt_box_list, self.imgsz)

@torch.no_grad()
def build_targets_task_aligned(cls_outs, box_outs, strides, gt_classes, gt_boxes_xyxy, image_size):
    """Assign ground-truth boxes to anchors with a task-aligned metric.

    For each image, anchors whose centre lies within ``tal_center_radius``
    strides of a GT centre are candidates. Candidates are scored with
    ``sigmoid(cls)^alpha * IoU^beta``; each GT keeps its ``tal_topk`` best
    anchors, and an anchor claimed by several GTs goes to the best-scoring one.

    The whole assignment runs under ``torch.no_grad()``: targets are constants
    and must not carry gradients back into the predictions.

    Args:
        cls_outs: Per-level class logits, each ``[B, C, H, W]``.
        box_outs: Per-level distribution logits, each ``[B, 4 * bins, H, W]``;
            the bin count is read from the channel dimension.
        strides: Stride of each level.
        gt_classes: List (length B) of per-image class-index tensors.
        gt_boxes_xyxy: List (length B) of per-image ``[Ng, 4]`` xyxy boxes.
        image_size: Upper bound used to clamp decoded predictions.

    Returns:
        ``(per_image_targets, levels)``. Each target dict holds
        ``t_cls_soft`` ``[P, C]``, ``t_box_xyxy`` ``[P, 4]``, ``t_box_ltrb``
        ``[P, 4]`` and ``pos_index`` ``[P]``; ``levels`` describes each
        level's ``H``, ``W``, ``stride`` and flat ``start``/``end`` offsets.
    """
    device = cls_outs[0].device
    B, C = cls_outs[0].shape[0], cls_outs[0].shape[1]
    levels, grids, start = [], [], 0
    for cl, s in zip(cls_outs, strides):
        _, _, H, W = cl.shape
        levels.append({"H": H, "W": W, "stride": s, "start": start, "end": start + H * W})
        grids.append(make_grid(H, W, s, device))
        start += H * W
    num_anchors = start

    alpha, beta, topk, cr = CFG["tal_alpha"], CFG["tal_beta"], CFG["tal_topk"], CFG["tal_center_radius"]

    def _no_positives():
        # Target dict for an image without any assigned anchor.
        return {
            "t_cls_soft": torch.zeros(0, C, device=device),
            "t_box_xyxy": torch.zeros(0, 4, device=device),
            "t_box_ltrb": torch.zeros(0, 4, device=device),
            "pos_index": torch.zeros(0, dtype=torch.long, device=device),
        }

    per_image_targets = []
    for b in range(B):
        gtc, gtb = gt_classes[b], gt_boxes_xyxy[b]
        Ng = int(gtc.numel())
        if Ng == 0:
            per_image_targets.append(_no_positives())
            continue

        cls_flat = torch.cat([cl[b].permute(1, 2, 0).reshape(-1, C) for cl in cls_outs], 0)

        # Decode the current box predictions (expected distance per side).
        preds_xyxy = []
        for bx, lev, (cx, cy) in zip(box_outs, levels, grids):
            n_bins = bx.shape[1] // 4  # inferred, not hard-coded to 17
            probs = bx[b].reshape(4, n_bins, -1).permute(2, 0, 1).softmax(-1)
            dist = (probs * torch.arange(n_bins, device=device)).sum(-1) * lev["stride"]
            preds_xyxy.append(torch.stack([cx - dist[:, 0], cy - dist[:, 1], cx + dist[:, 2], cy + dist[:, 3]], -1))
        preds_xyxy = torch.cat(preds_xyxy, 0).clamp(0, image_size)

        # Candidate mask: anchor centre within `cr` strides of a GT centre.
        gt_centers = 0.5 * (gtb[:, :2] + gtb[:, 2:])
        mask = torch.zeros(num_anchors, Ng, dtype=torch.bool, device=device)
        for lev, (cx, cy) in zip(levels, grids):
            half = cr * lev["stride"]
            ax, ay = cx.view(-1, 1), cy.view(-1, 1)
            in_cr = (ax >= gt_centers[:, 0] - half) & (ay >= gt_centers[:, 1] - half) & \
                    (ax <= gt_centers[:, 0] + half) & (ay <= gt_centers[:, 1] + half)
            mask[lev["start"]:lev["end"]] |= in_cr

        align = (cls_flat.sigmoid()[:, gtc].pow(alpha)) * (box_iou_xyxy_matrix(preds_xyxy, gtb).pow(beta))
        # Non-candidates get a tiny negative score so they never beat the
        # initial best_score below.
        align = torch.where(mask, align, torch.full_like(align, -1e-9))

        val, idx = torch.topk(align, min(topk, align.shape[0]), dim=0)
        best_gt = torch.full((num_anchors,), -1, dtype=torch.long, device=device)
        best_score = torch.full((num_anchors,), -1e-9, device=device)
        for j in range(Ng):
            # An anchor goes to GT j only if j scores strictly higher than
            # whatever claimed it before.
            better = val[:, j] > best_score[idx[:, j]]
            best_gt[idx[better, j]] = j
            best_score[idx[better, j]] = val[better, j]

        p_idx = torch.nonzero(best_gt >= 0).squeeze(1)
        if p_idx.numel() == 0:
            per_image_targets.append(_no_positives())
            continue

        gt_idx = best_gt[p_idx]
        soft = torch.zeros(len(p_idx), C, device=device)
        soft[torch.arange(len(p_idx), device=device), gtc[gt_idx]] = best_score[p_idx]

        # Distances from each positive anchor centre to its GT box sides.
        ltrb = torch.empty(len(p_idx), 4, device=device)
        for i, lev in enumerate(levels):
            m = (p_idx >= lev["start"]) & (p_idx < lev["end"])
            if m.any():
                ii = p_idx[m] - lev["start"]
                ax, ay = grids[i][0][ii], grids[i][1][ii]
                box = gtb[gt_idx[m]]
                ltrb[m] = torch.stack([ax - box[:, 0], ay - box[:, 1], box[:, 2] - ax, box[:, 3] - ay], -1)
        per_image_targets.append({"t_cls_soft": soft, "t_box_xyxy": gtb[gt_idx], "t_box_ltrb": ltrb, "pos_index": p_idx})
    return per_image_targets, levels

def decode_outputs(head_out, strides, conf_thres=0.25, iou_thres=0.45, max_det=300, imgsz=640):
    """Turn raw head outputs into per-image detections.

    Args:
        head_out: Dict with ``"cls"`` and ``"box"`` lists of per-level tensors
            shaped ``[B, C, H, W]`` and ``[B, 4 * bins, H, W]``. The number of
            distribution bins is read from the box tensor's channel count.
        strides: Stride of each level, in the same order as the tensors.
        conf_thres: Minimum best-class probability for an anchor to be kept.
        iou_thres: IoU threshold passed to NMS (NMS is class-agnostic).
        max_det: Maximum detections kept per image after NMS.
        imgsz: Boxes are clamped to ``[0, imgsz]``.

    Returns:
        List of length B; each entry is an ``[N, 6]`` tensor with columns
        ``x1, y1, x2, y2, score, class``.
    """
    device = head_out["cls"][0].device
    B, C = head_out["cls"][0].shape[0], head_out["cls"][0].shape[1]
    final_preds = []
    for b in range(B):
        boxes, scores, clss = [], [], []
        for i, s in enumerate(strides):
            cls_map, box_map = head_out["cls"][i], head_out["box"][i]
            probs = cls_map[b].permute(1, 2, 0).reshape(-1, C).sigmoid()
            sc, cl = probs.max(1)
            mask = sc > conf_thres
            if not mask.any():
                # Nothing survives at this level: skip box decoding and the grid.
                continue

            n_bins = box_map.shape[1] // 4  # inferred, not hard-coded to 17
            dfl = box_map[b].reshape(4, n_bins, -1).permute(2, 0, 1).softmax(-1)
            dist = ((dfl * torch.arange(n_bins, device=device)).sum(-1) * s)[mask]
            cx, cy = make_grid(cls_map.shape[2], cls_map.shape[3], s, device)
            cx, cy = cx[mask], cy[mask]

            bx = torch.stack([cx - dist[:, 0], cy - dist[:, 1], cx + dist[:, 2], cy + dist[:, 3]], -1)
            boxes.append(bx)
            scores.append(sc[mask])
            clss.append(cl[mask])

        if not boxes:
            final_preds.append(torch.zeros(0, 6, device=device))
            continue
        boxes, scores, clss = torch.cat(boxes), torch.cat(scores), torch.cat(clss)

        # Clamp boxes to image size
        boxes.clamp_(0, imgsz)

        # NOTE(review): torch.ops.torchvision.nms is only registered once
        # torchvision has been imported somewhere in the process — confirm.
        keep = torch.ops.torchvision.nms(boxes, scores, iou_thres)

        # Limit max detections
        keep = keep[:max_det]

        final_preds.append(torch.cat([boxes[keep], scores[keep, None], clss[keep, None].float()], 1))
    return final_preds

def compute_ap(recall, precision):
    """Area under the precision-recall curve using the precision envelope.

    Sentinel points ``(0, 1)`` and ``(1, 0)`` are added, precision is made
    monotonically non-increasing from right to left, and the area is summed
    over the points where recall changes.

    Args:
        recall: 1-D array of recall values.
        precision: 1-D array of precision values, same length.
    """
    mrec = np.concatenate(([0.], recall, [1.]))
    mpre = np.concatenate(([1.], precision, [0.]))
    # Running maximum taken from the right-hand end gives the envelope.
    envelope = np.maximum.accumulate(mpre[::-1])[::-1]
    steps = np.flatnonzero(mrec[1:] != mrec[:-1])
    return np.sum((mrec[steps + 1] - mrec[steps]) * envelope[steps + 1])

def validate(model, loader, device, conf_thres=0.001, max_det=300):
    """Return the average number of detections per image on a small sample.

    The model is switched to eval mode for the pass and put back into
    training mode afterwards if it was training on entry. Iteration stops
    after the first batch that brings the image count above 100.

    Args:
        model: Detector returning the raw head-output dict in eval mode.
        loader: Iterable of ``(imgs, targets)`` batches; targets are ignored.
        device: Device the images are moved to.
        conf_thres: Confidence threshold passed to ``decode_outputs``.
        max_det: Per-image detection cap passed to ``decode_outputs``.
    """
    restore_train = model.training
    model.eval()
    det_count, img_count = 0, 0
    with torch.no_grad():
        for imgs, _ in loader:
            out = model(imgs.to(device))
            preds = decode_outputs(out, [8, 16, 32], conf_thres=conf_thres, max_det=max_det, imgsz=CFG["imgsz"])
            det_count += sum(len(p) for p in preds)
            img_count += len(preds)
            if img_count > 100:
                break  # Small sample for speed

    if restore_train:
        model.train()
    return det_count / max(1, img_count)

def visualize_inline(model, imgs, targets, step, device, classes):
    """Plot ground truth (green) and predictions (red) for up to two images.

    The model is run in eval mode and restored to its previous mode
    afterwards, so calling this mid-training does not leave the model in
    eval mode.

    Args:
        model: Detector returning the raw head-output dict in eval mode.
        imgs: Image batch tensor ``[B, 3, H, W]``; only the first two are used.
        targets: Dict with ``"boxes"`` and ``"batch_index"``.
        step: Step number shown in the title.
        device: Device the two images are moved to before inference.
        classes: Sequence mapping class index to a display name.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            out = model(imgs[:2].to(device))
            preds = decode_outputs(out, [8, 16, 32])
    finally:
        # Restore training mode even if inference raised.
        if was_training:
            model.train()

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for i in range(min(len(imgs), 2)):
        # NOTE(review): if the tensors come from cv2.imread without a channel
        # swap they are BGR and will display with swapped colours — confirm.
        img = imgs[i].cpu().permute(1, 2, 0).numpy()
        axes[i].imshow(img)
        # Plot GT
        mask = targets["batch_index"] == i
        for b in targets["boxes"][mask].cpu().numpy():
            axes[i].add_patch(Rectangle((b[0], b[1]), b[2]-b[0], b[3]-b[1], fill=False, color='green'))
        # Plot Pred (only detections scoring above 0.3)
        for p in preds[i].cpu().numpy():
            if p[4] > 0.3:
                axes[i].add_patch(Rectangle((p[0], p[1]), p[2]-p[0], p[3]-p[1], fill=False, color='red'))
                axes[i].text(p[0], p[1], f"{classes[int(p[5])]}:{p[4]:.2f}", color='white', backgroundcolor='red', fontsize=8)
    axes[0].set_title(f"Step {step} - Green:GT Red:Pred")
    plt.show()

class ModelEMA:
    """Exponential moving average of a model's state dict.

    A frozen deep copy of the model (in eval mode) is kept in ``module``.
    The effective decay ramps up with the update count as
    ``decay * (1 - exp(-updates / tau))``.
    """

    def __init__(self, model, decay=0.9998, tau=2000):
        shadow = copy.deepcopy(model).eval()
        for param in shadow.parameters():
            param.requires_grad_(False)
        self.module = shadow
        self.decay = decay
        self.tau = tau
        self.updates = 0

    def update(self, model):
        """Blend the current weights of ``model`` into the averaged copy.

        Floating-point entries are averaged in place; all other entries are
        copied over verbatim.
        """
        self.updates += 1
        d = self.decay * (1 - math.exp(-self.updates / self.tau))
        live_state = model.state_dict()
        for name, averaged in self.module.state_dict().items():
            if averaged.dtype.is_floating_point:
                averaged.mul_(d).add_(live_state[name].detach(), alpha=1.0 - d)
            else:
                averaged.copy_(live_state[name])

def save_checkpoint(epoch, model, ema, optimizer, scaler, scheduler, run_dir, step, filename="last.pt"):
    """Write a full training checkpoint to ``run_dir/filename``.

    The checkpoint is first written to ``<path>.tmp`` and then moved into
    place with ``os.replace``, which overwrites an existing file in a single
    step. A crash can therefore never leave the run without a checkpoint, as
    the former remove-then-rename sequence could.

    Args:
        epoch: Epoch number stored in the checkpoint.
        model: Model whose ``state_dict()`` is saved.
        ema: EMA wrapper; its ``module.state_dict()`` and ``updates`` are saved.
        optimizer, scaler, scheduler: Objects whose ``state_dict()`` is saved.
        run_dir: Destination directory.
        step: Global step stored in the checkpoint.
        filename: Checkpoint file name inside ``run_dir``.
    """
    ckpt = {
        "epoch": epoch, "global_step": step,
        "model": model.state_dict(),
        "ema": ema.module.state_dict(),
        "ema_updates": ema.updates,
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
        "scheduler": scheduler.state_dict(),
        "cfg": CFG
    }
    path = os.path.join(run_dir, filename)
    tmp_path = path + ".tmp"
    torch.save(ckpt, tmp_path)
    # Overwrites any existing checkpoint in one step.
    os.replace(tmp_path, path)
    print(f"üì• Saved checkpoint: {path}")



In [None]:
# 6. Dataset (YOLOv8 Style)
def augment_hsv(image, hgain=0.015, sgain=0.7, vgain=0.4):
    """Randomly scale the hue, saturation and value channels of a BGR image.

    Each channel is multiplied by ``1 + u * gain`` with ``u`` drawn uniformly
    from ``[-1, 1)`` (hue first, then saturation, then value) and clipped to
    the OpenCV 8-bit range: hue to ``[0, 179]``, the others to ``[0, 255]``.

    Args:
        image: BGR uint8 image as accepted by ``cv2.cvtColor``.
        hgain, sgain, vgain: Maximum relative change per channel.

    Returns:
        The augmented image as BGR uint8.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.float32)
    channel_specs = ((hgain, 179), (sgain, 255), (vgain, 255))
    for channel, (gain, upper) in enumerate(channel_specs):
        factor = 1 + (random.random() * 2 - 1) * gain
        hsv[..., channel] = np.clip(hsv[..., channel] * factor, 0, upper)
    return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)

def letterbox(img, new_shape=640, color=(114, 114, 114)):
    """Resize to fit a square canvas, keeping aspect ratio and padding the rest.

    Args:
        img: Image array whose first two dimensions are height and width.
        new_shape: Side length of the square output.
        color: Border colour used for the padding.

    Returns:
        ``(padded_img, r, (left, top))`` where ``r`` is the scale factor and
        ``(left, top)`` is the padding added on the left and top edges.
    """
    src_h, src_w = img.shape[:2]
    r = min(new_shape / src_h, new_shape / src_w)
    new_w, new_h = int(round(src_w * r)), int(round(src_h * r))

    # Split the leftover space between the two sides; the -0.1 / +0.1 nudges
    # settle which side takes the extra pixel when the padding is odd.
    dw = (new_shape - new_w) / 2
    dh = (new_shape - new_h) / 2
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))

    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    padded = cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return padded, r, (left, top)

class YoloDataset(Dataset):
    """Images plus YOLO-format text labels, letterboxed to a square.

    Every file matching ``*.*`` in ``img_dir`` is treated as an image; its
    label file is ``lbl_dir/<stem>.txt`` with lines
    ``class x_center y_center width height`` in normalised coordinates.
    """

    def __init__(self, img_dir, lbl_dir, imgsz=640, augment=True):
        self.img_dir = img_dir
        self.lbl_dir = lbl_dir
        self.imgsz = imgsz
        self.augment = augment
        self.img_files = sorted(glob.glob(os.path.join(img_dir, "*.*")))
        self.lbl_files = [os.path.join(lbl_dir, Path(f).stem + ".txt") for f in self.img_files]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, i):
        """Return ``(image, target)`` for sample ``i``.

        ``image`` is a float ``[3, imgsz, imgsz]`` tensor scaled to [0, 1];
        ``target`` has ``labels`` (long) and ``boxes`` (float xyxy, in pixels
        of the letterboxed image). If the image cannot be loaded, a grey
        image with empty targets is returned and a warning is printed.
        """
        try:
            img = cv2.imread(self.img_files[i])
            if img is None:
                raise ValueError("Image decode failed")
            h0, w0 = img.shape[:2]
            if self.augment:
                img = augment_hsv(img)
            img, r, (padw, padh) = letterbox(img, self.imgsz)
        except Exception as e:
            print(f"‚ö†Ô∏è Warning: Dataset failure on {self.img_files[i]}: {e}")
            img = np.full((self.imgsz, self.imgsz, 3), 114, dtype=np.uint8)
            empty_target = {
                "labels": torch.zeros(0, dtype=torch.long),
                "boxes": torch.zeros((0, 4), dtype=torch.float32),
            }
            return torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0, empty_target

        rows = []
        label_path = self.lbl_files[i]
        if os.path.exists(label_path):
            with open(label_path) as f:
                for raw in f:
                    parts = raw.split()
                    # Blank lines split to [] and are skipped along with any
                    # line that does not have exactly five fields.
                    if len(parts) != 5:
                        continue
                    try:
                        c, x, y, w, h = map(float, parts)
                    except ValueError:
                        continue  # Skip malformed lines
                    # Normalised centre/size -> pixel corners after letterboxing.
                    x1, y1 = (x - w/2) * w0 * r + padw, (y - h/2) * h0 * r + padh
                    x2, y2 = (x + w/2) * w0 * r + padw, (y + h/2) * h0 * r + padh
                    rows.append([c, x1, y1, x2, y2])

        table = np.array(rows) if rows else np.zeros((0, 5))
        tensor_img = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
        target = {
            "labels": torch.from_numpy(table[:, 0]).long(),
            "boxes": torch.from_numpy(table[:, 1:]).float(),
        }
        return tensor_img, target

def collate_fn(batch):
    """Stack images and flatten per-image targets into one batch dict.

    Args:
        batch: Sequence of ``(image, target)`` pairs, where each target has
            ``"boxes"`` and ``"labels"`` tensors.

    Returns:
        ``(imgs, targets)``: ``imgs`` is the stacked image tensor and
        ``targets`` holds concatenated ``boxes`` and ``labels`` plus
        ``batch_index``, the source image position of each row.
    """
    imgs, targets = zip(*batch)
    batch_index = [
        torch.full((len(t["labels"]),), pos, dtype=torch.long)
        for pos, t in enumerate(targets)
    ]
    merged = {
        "boxes": torch.cat([t["boxes"] for t in targets], 0),
        "labels": torch.cat([t["labels"] for t in targets], 0),
        "batch_index": torch.cat(batch_index, 0),
    }
    return torch.stack(imgs, 0), merged



In [None]:
# 7. Dataset Preparation
# Class names passed as `classes=` to the YOLO export in ensure_dataset_ready;
# a name's position in this list is its class index (80 entries, indices 0-79).
COCO_CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']

def _count_bad_label_files(label_dir, num_classes, limit=500):
    """Count label files (among the first ``limit``) holding an out-of-range class id.

    Files are visited in sorted order so the sampled subset is the same on
    every run. Blank lines and lines whose first field is not a number are
    skipped. Every offending line is printed, but a file is counted only once.
    """
    bad_files = 0
    for label_path in sorted(glob.glob(os.path.join(label_dir, "*.txt")))[:limit]:
        file_is_bad = False
        with open(label_path) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # blank line
                try:
                    cls = int(float(fields[0]))
                except (ValueError, OverflowError):
                    continue  # non-numeric / non-finite class field
                if cls < 0 or cls >= num_classes:
                    print(f"‚ùå Bad class {cls} in {label_path}")
                    file_is_bad = True
        if file_is_bad:
            bad_files += 1
    return bad_files


def ensure_dataset_ready(size=5000):
    """Download and export a COCO-2017 subset in YOLO format unless already done.

    The export lives in ``CFG["data_root"]``. A ``.complete`` marker file in
    that directory means a previous export finished; when it is present the
    function returns immediately. Otherwise any existing directory is deleted
    and the dataset is rebuilt from scratch.

    NOTE: the marker does not record ``size``, so calling again with a
    different ``size`` reuses the existing export.

    Args:
        size: Maximum number of training samples to fetch. The validation
            split is always capped at 500 samples.
    """
    out_dir = CFG["data_root"]
    if os.path.exists(os.path.join(out_dir, ".complete")):
        print(f"‚úÖ Dataset ready at {out_dir}")
        return

    print(f"üì¶ Downloading COCO (size={size})...")
    # A directory without the marker is a partial export: start clean.
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    for split in ["train", "validation"]:
        ds = foz.load_zoo_dataset(
            "coco-2017",
            split=split,
            label_types=["detections"],
            max_samples=size if split == "train" else 500,
            shuffle=True,
        )
        ds.export(
            export_dir=out_dir,
            dataset_type=fo.types.YOLOv5Dataset,
            label_field="ground_truth",
            split="train" if split == "train" else "val",
            classes=COCO_CLASSES,
        )
        fo.delete_dataset(ds.name)

    # Verify label integrity on a sample of the training labels.
    num_classes = len(COCO_CLASSES)
    print(f"üîé Verifying label indices (0-{num_classes - 1})...")
    train_lbl = os.path.join(out_dir, "labels/train")
    if os.path.exists(train_lbl):
        bad_files = _count_bad_label_files(train_lbl, num_classes)
        if bad_files == 0:
            print("‚úÖ Label indices look correct.")
        else:
            print(f"WARNING: {bad_files} label file(s) contain out-of-range class indices.")

    # Written last so an interrupted export is never mistaken for a finished one.
    with open(os.path.join(out_dir, ".complete"), "w") as f:
        f.write("done")
    print("‚úÖ Export complete.")



In [None]:
# 8. Main Training Loop
# Seed everything, make sure the dataset is on disk, then build the loaders,
# model, loss, EMA, optimizer and AMP scaler used by the training loop.
set_seed(CFG["seed"])
ensure_dataset_ready(size=NUM_SAMPLES)

train_loader = DataLoader(
    YoloDataset(
        os.path.join(CFG["data_root"], "images/train"),
        os.path.join(CFG["data_root"], "labels/train"),
    ),
    batch_size=CFG["batch_size"],
    shuffle=True,
    collate_fn=collate_fn,
    pin_memory=True,
)
val_loader = DataLoader(
    YoloDataset(
        os.path.join(CFG["data_root"], "images/val"),
        os.path.join(CFG["data_root"], "labels/val"),
        augment=False,
    ),
    batch_size=CFG["batch_size"],
    shuffle=False,
    collate_fn=collate_fn,
)

# Calculate total training steps for proper LR scheduling.
# The scheduler advances once per optimizer step, and the training loop only
# takes an optimizer step after every full window of `accumulate` batches, so
# the schedule length must be counted in optimizer steps, not batches.
accumulate = CFG.get("accumulate", 1)
steps_per_epoch = len(train_loader)
optimizer_steps_per_epoch = max(1, steps_per_epoch // accumulate)
total_training_steps = optimizer_steps_per_epoch * CFG["epochs"]
warmup_steps = min(1000, total_training_steps // 10)  # 10% warmup, max 1000 steps
print(f"Steps per epoch: {steps_per_epoch}, Total steps: {total_training_steps}, Warmup steps: {warmup_steps}")

model = YoloModel(num_classes=CFG["num_classes"]).to(device)

# Optional warm start. A checkpoint may be either a bare state dict or a dict
# wrapping it under "model"; strict=False tolerates missing/unexpected keys.
if CFG.get("pretrained"):
    pt = CFG["pretrained"]
    if os.path.exists(pt):
        print(f"üì¶ Loading weights from {pt}...")
        st = torch.load(pt, map_location=device)
        model.load_state_dict(st["model"] if "model" in st else st, strict=False)
    else:
        print(f"‚ö†Ô∏è Pretrained {pt} not found.")

# Take class count and image size from CFG so the loss always matches the model.
model.criterion = DetectionLoss(
    num_classes=CFG["num_classes"],
    image_size=CFG["imgsz"],
    strides=[8, 16, 32],
).to(device)
ema = ModelEMA(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG["lr"], weight_decay=CFG["weight_decay"])
scaler = GradScaler("cuda", enabled=CFG["amp"])

# Cosine scheduler with warmup (LambdaLR)
def get_lr(step):
    """Return the LR multiplier for ``step``: linear warmup, then cosine decay.

    While ``step`` is below ``warmup_steps`` the multiplier is
    ``step / warmup_steps``. Afterwards it follows a half cosine from 1.0
    down to ``CFG["min_lr_ratio"]`` (0.05 if unset) over the remaining
    ``total_training_steps - warmup_steps`` steps.
    """
    if step >= warmup_steps:
        floor_ratio = CFG.get("min_lr_ratio", 0.05)
        # max(1, ...) guards against a zero-length decay phase.
        decay_span = max(1, total_training_steps - warmup_steps)
        frac = (step - warmup_steps) / decay_span
        return floor_ratio + (1 - floor_ratio) * (0.5 * (1.0 + math.cos(math.pi * frac)))
    return step / warmup_steps

# get_lr returns a multiplier applied to the optimizer's base LR at each scheduler step.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=get_lr)

start_epoch, global_step = 0, 0

# Skip checkpoint loading in QUICK_TEST mode (always start fresh)
if not QUICK_TEST:
    ckpt_path = os.path.join(RUN_DIR, "last.pt")
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["model"])
        ema.module.load_state_dict(ckpt["ema"])
        ema.updates = ckpt["ema_updates"]
        optimizer.load_state_dict(ckpt["optimizer"])
        scaler.load_state_dict(ckpt["scaler"])
        # Checkpoints without a "scheduler" entry are still accepted.
        if "scheduler" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler"])
        start_epoch, global_step = ckpt["epoch"], ckpt["global_step"]
        print(f"Resumed from epoch {start_epoch}, step {global_step}, LR={scheduler.get_last_lr()[0]:.2e}")

accumulate = CFG.get("accumulate", 1)
log_file = os.path.join(RUN_DIR, "train_log.csv")
# Write the CSV header only once so a resumed run keeps appending to the same log.
if not os.path.exists(log_file):
    with open(log_file, "w") as f:
        f.write("epoch,step,loss,lr,avg_pos,val_avg_preds\n")

start_time = time.time()

# Epoch loop: train with AMP + gradient accumulation, validate the EMA weights,
# append one CSV row per epoch and (outside QUICK_TEST) checkpoint after each epoch.
try:
    for epoch in range(start_epoch, CFG["epochs"]):
        # The wall-clock budget is only checked between epochs.
        if time.time() - start_time > CFG["time_limit"]:
            print(f"‚è∞ Time limit reached ({CFG['time_limit']}s). Stopping training.")
            break

        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CFG['epochs']}")
        # NOTE: this also discards gradients left over from an incomplete
        # accumulation window at the end of the previous epoch.
        optimizer.zero_grad()

        total_loss = 0.0
        total_num_pos = 0.0
        # Defined up front so the epoch summary still works if the loader yields no batches.
        avg_loss = 0.0
        avg_pos = 0.0

        for batch_idx, (imgs, targets) in enumerate(pbar):
            imgs = imgs.to(device, non_blocking=True)
            targets = {k: v.to(device, non_blocking=True) for k, v in targets.items()}

            with torch.amp.autocast("cuda", enabled=CFG["amp"]):
                losses, stats = model(imgs, targets)
                # Divide so gradients summed over `accumulate` batches average out.
                loss = losses["loss"] / accumulate

            scaler.scale(loss).backward()

            if (batch_idx + 1) % accumulate == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.get("grad_clip_norm", 10.0))

                scale_before = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()

                # Only step scheduler/EMA/global_step if the optimizer actually stepped (not skipped by AMP)
                stepped = scaler.get_scale() >= scale_before
                if stepped:
                    scheduler.step()
                    ema.update(model)
                    global_step += 1
                optimizer.zero_grad()

                # Visualize once per 200 real optimizer steps. Checking only right
                # after global_step advances avoids repeating the plot while the
                # counter sits on a multiple of 200 (accumulation, skipped steps, resume).
                if stepped and global_step % 200 == 0:
                    visualize_inline(ema.module, imgs, targets, global_step, device, COCO_CLASSES)

            total_loss += losses["loss"].item()
            total_num_pos += stats["num_pos"]

            # total_loss sums the undivided loss, so the plain mean is the per-batch average.
            avg_loss = total_loss / (batch_idx + 1)
            avg_pos = total_num_pos / (batch_idx + 1)

            if batch_idx % 10 == 0:
                pbar.set_postfix({"loss": f"{avg_loss:.3f}", "lr": f"{scheduler.get_last_lr()[0]:.2e}", "avg_pos": f"{avg_pos:.1f}"})

        # Validation
        val_preds = validate(ema.module, val_loader, device, conf_thres=0.001)
        print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | Avg Pos: {avg_pos:.1f} | Val Avg Preds/Img: {val_preds:.2f} (EMA)")

        with open(log_file, "a") as f:
            f.write(f"{epoch+1},{global_step},{avg_loss:.4f},{scheduler.get_last_lr()[0]:.2e},{avg_pos:.1f},{val_preds:.2f}\n")

        # Skip checkpoint saving in QUICK_TEST mode
        if not QUICK_TEST:
            save_checkpoint(epoch + 1, model, ema, optimizer, scaler, scheduler, RUN_DIR, global_step)

except KeyboardInterrupt:
    print("üõë Interrupted by User! Saving checkpoint...")
    if not QUICK_TEST:
        # Save `epoch` (not epoch + 1) so a resume restarts the interrupted epoch.
        save_checkpoint(epoch, model, ema, optimizer, scaler, scheduler, RUN_DIR, global_step)

print("Training complete!" if not QUICK_TEST else "Quick test complete!")

