In [3]:
"""
YOLO Wrapper Class - Drop-in replacement for Ultralytics YOLO
Provides the same interface for training, validation, and inference
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
import yaml
import os
from pathlib import Path
from tqdm import tqdm
import sys

sys.path.append('/kaggle/usr/lib/notebooks/a21101131/yolov11-model/notebooks/a21101131')

from yolov11_model import YOLOv11Model, make_anchors

import torchvision


# ============================================================================
# Loss Functions
# ============================================================================

class YOLOLoss(nn.Module):
    """YOLOv11 training loss: focal classification loss, CIoU box loss and DFL.

    The loss consumes the raw per-level head outputs, decodes boxes through the
    DFL module, assigns ground-truth boxes to anchor points with a simplified
    centre-prior strategy and returns the weighted sum of the three terms.
    """

    def __init__(self, model, nc=80):
        """
        Args:
            model: Detection model; ``model.head.stride`` must hold the stride
                of every output level (e.g. [8, 16, 32]).
            nc: Number of object classes.
        """
        super().__init__()
        self.nc = nc
        self.reg_max = 16
        self.topk = 10  # top-k candidates per GT box

        # Loss weights
        self.box_weight = 7.5
        self.cls_weight = 0.5
        self.dfl_weight = 1.5

        # Focal loss parameters
        self.focal_gamma = 1.5
        self.focal_alpha = 0.25

        # Get strides from model head
        self.stride = model.head.stride  # e.g. [8, 16, 32]

        # Cache DFL module (frozen weights, no need to recreate each step)
        from yolov11_model import DFL as _DFL
        self.dfl = _DFL(self.reg_max)

    # ------------------------------------------------------------------
    # Focal loss (replaces plain BCE)
    # ------------------------------------------------------------------
    @staticmethod
    def focal_loss(pred, target, gamma=1.5, alpha=0.25, eps=1e-7):
        """Binary focal loss — downweights easy negatives so positives matter.

        Args:
            pred: [*, nc] raw logits.
            target: [*, nc] soft targets in [0, 1].
            gamma: Focusing exponent applied to ``1 - p_t``.
            alpha: Weight of the positive class; negatives get ``1 - alpha``.
            eps: Unused; kept so existing callers keep working.

        Returns:
            Element-wise loss with the same shape as ``pred`` (no reduction;
            the caller decides how to normalise).
        """
        p = pred.sigmoid()
        bce = nn.functional.binary_cross_entropy_with_logits(pred, target, reduction='none')
        # Modulating factor: emphasises hard examples
        p_t = p * target + (1 - p) * (1 - target)
        modulating = (1 - p_t) ** gamma
        # Alpha weighting
        alpha_t = alpha * target + (1 - alpha) * (1 - target)
        return alpha_t * modulating * bce

    # ------------------------------------------------------------------
    # Element-wise IoU (not pairwise)
    # ------------------------------------------------------------------
    @staticmethod
    def elementwise_ciou(box1, box2, eps=1e-7):
        """CIoU between matched pairs. box1, box2: [N, 4] xyxy. Returns [N]."""
        iou = YOLOLoss.elementwise_iou(box1, box2, eps)

        # Squared diagonal of the smallest enclosing box
        enc_x1 = torch.min(box1[:, 0], box2[:, 0])
        enc_y1 = torch.min(box1[:, 1], box2[:, 1])
        enc_x2 = torch.max(box1[:, 2], box2[:, 2])
        enc_y2 = torch.max(box1[:, 3], box2[:, 3])
        c2 = (enc_x2 - enc_x1) ** 2 + (enc_y2 - enc_y1) ** 2 + eps

        # Squared distance between box centres
        cx1, cy1 = (box1[:, 0] + box1[:, 2]) / 2, (box1[:, 1] + box1[:, 3]) / 2
        cx2, cy2 = (box2[:, 0] + box2[:, 2]) / 2, (box2[:, 1] + box2[:, 3]) / 2
        rho2 = (cx1 - cx2) ** 2 + (cy1 - cy2) ** 2

        # Aspect-ratio consistency term
        w1, h1 = box1[:, 2] - box1[:, 0], box1[:, 3] - box1[:, 1]
        w2, h2 = box2[:, 2] - box2[:, 0], box2[:, 3] - box2[:, 1]
        v = (4 / (torch.pi ** 2)) * (torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps))) ** 2
        # alpha is a trade-off weight, not a learnable quantity: keep it out
        # of the autograd graph as in the reference CIoU implementations.
        with torch.no_grad():
            alpha = v / (1 - iou + v + eps)
        return iou - (rho2 / c2 + alpha * v)

    @staticmethod
    def elementwise_iou(box1, box2, eps=1e-7):
        """Plain IoU between matched pairs. box1, box2: [N, 4] xyxy. Returns [N]."""
        inter_x1 = torch.max(box1[:, 0], box2[:, 0])
        inter_y1 = torch.max(box1[:, 1], box2[:, 1])
        inter_x2 = torch.min(box1[:, 2], box2[:, 2])
        inter_y2 = torch.min(box1[:, 3], box2[:, 3])
        inter = (inter_x2 - inter_x1).clamp(0) * (inter_y2 - inter_y1).clamp(0)
        a1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
        a2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
        return inter / (a1 + a2 - inter + eps)

    @staticmethod
    def _xywh2xyxy(boxes):
        """Convert cx,cy,w,h to x1,y1,x2,y2."""
        cx, cy, w, h = boxes.unbind(-1)
        return torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)

    # ------------------------------------------------------------------
    # Target assignment (simplified center-based)
    # ------------------------------------------------------------------
    def assign_targets(self, anchors, strides, targets_list, pred_boxes, pred_cls, bs, img_size=640):
        """Assign GT boxes to anchors using a simplified center-prior strategy.

        For every GT box the candidate anchors are those whose centre lies
        inside the box (falling back to the ``topk`` nearest anchors when none
        does); among the candidates, the ``topk`` with the highest IoU between
        their predicted box and the GT become positives. When several GT boxes
        select the same anchor, the GT processed last wins.

        Args:
            anchors: [N, 2] anchor centres (x, y) in grid units.
            strides: [N, 1] stride of every anchor.
            targets_list: Per-image label tensors, each [M, 5]
                (cls, cx, cy, w, h) with coordinates normalised to [0, 1].
            pred_boxes: [B, N, 4] predicted boxes (cx, cy, w, h) in pixels.
            pred_cls: [B, N, nc] class logits (currently unused).
            bs: Batch size.
            img_size: Network input size in pixels used to de-normalise the
                labels; either a single number (square input) or a
                ``(width, height)`` pair. Defaults to 640.

        Returns:
            Tuple ``(assigned_gt_boxes [B, N, 4] xyxy pixels,
            assigned_cls [B, N, nc], fg_mask [B, N] bool,
            assigned_ltrb [B, N, 4] in grid units)``.
        """
        device = anchors.device
        N = anchors.shape[0]
        anchor_pixels = anchors * strides  # [N, 2] anchor center in pixels (x, y)
        ax = anchor_pixels[:, 0]  # [N]
        ay = anchor_pixels[:, 1]  # [N]

        if isinstance(img_size, (int, float)):
            img_w = img_h = float(img_size)
        else:
            img_w, img_h = (float(v) for v in img_size)
        # Per-coordinate scale for (cx, cy, w, h)
        gt_scale = torch.tensor([img_w, img_h, img_w, img_h], device=device)

        assigned_gt_boxes = torch.zeros(bs, N, 4, device=device)
        assigned_cls = torch.zeros(bs, N, self.nc, device=device)
        fg_mask = torch.zeros(bs, N, dtype=torch.bool, device=device)
        assigned_ltrb = torch.zeros(bs, N, 4, device=device)

        for b in range(bs):
            targets = targets_list[b]  # [M, 5]
            if len(targets) == 0:
                continue
            targets = targets.to(device)
            gt_cls = targets[:, 0].long()  # [M]
            # Normalised labels -> pixel space of the network input
            gt_cxcywh_px = targets[:, 1:5] * gt_scale  # [M, 4]
            gt_xyxy = self._xywh2xyxy(gt_cxcywh_px)    # [M, 4]

            M = gt_xyxy.shape[0]

            # [M, N] mask: anchor center inside GT box
            inside_x = (ax.unsqueeze(0) >= gt_xyxy[:, 0:1]) & (ax.unsqueeze(0) <= gt_xyxy[:, 2:3])
            inside_y = (ay.unsqueeze(0) >= gt_xyxy[:, 1:2]) & (ay.unsqueeze(0) <= gt_xyxy[:, 3:4])
            inside = inside_x & inside_y  # [M, N]

            pred_b_xyxy = self._xywh2xyxy(pred_boxes[b])  # [N, 4]

            for j in range(M):
                cand_mask = inside[j]  # [N]
                if cand_mask.sum() == 0:
                    # No anchor centre falls inside this GT (tiny box):
                    # fall back to the anchors closest to its centre.
                    gt_cx, gt_cy = gt_cxcywh_px[j, 0], gt_cxcywh_px[j, 1]
                    dists = (ax - gt_cx) ** 2 + (ay - gt_cy) ** 2
                    _, topk_idx = dists.topk(min(self.topk, N), largest=False)
                    cand_mask = torch.zeros(N, dtype=torch.bool, device=device)
                    cand_mask[topk_idx] = True

                cand_idx = cand_mask.nonzero(as_tuple=False).squeeze(-1)  # [K]

                # IoU between candidates' predicted boxes and this GT
                gt_j_expanded = gt_xyxy[j:j+1].expand(cand_idx.shape[0], 4)
                iou = self.elementwise_iou(pred_b_xyxy[cand_idx], gt_j_expanded)  # [K]

                # Pick top-k by IoU
                k = min(self.topk, iou.shape[0])
                _, topk_local = iou.topk(k)
                sel_idx = cand_idx[topk_local]

                fg_mask[b, sel_idx] = True
                assigned_gt_boxes[b, sel_idx] = gt_xyxy[j]

                # Class target: hard label = 1.0 (not IoU-weighted, so positives get full signal)
                cls_target = torch.zeros(self.nc, device=device)
                cls_target[gt_cls[j]] = 1.0
                assigned_cls[b, sel_idx] = cls_target.unsqueeze(0)

                # LTRB in grid units for DFL
                anc_sel = anchor_pixels[sel_idx]  # [k, 2]
                ltrb_l = anc_sel[:, 0] - gt_xyxy[j, 0]
                ltrb_t = anc_sel[:, 1] - gt_xyxy[j, 1]
                ltrb_r = gt_xyxy[j, 2] - anc_sel[:, 0]
                ltrb_b = gt_xyxy[j, 3] - anc_sel[:, 1]
                ltrb = torch.stack([ltrb_l, ltrb_t, ltrb_r, ltrb_b], dim=-1)  # [k, 4]
                ltrb = ltrb / strides[sel_idx]  # convert to grid units
                assigned_ltrb[b, sel_idx] = ltrb

        return assigned_gt_boxes, assigned_cls, fg_mask, assigned_ltrb

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(self, preds, targets):
        """
        Args:
            preds: list of feature maps [P3, P4, P5], each [B, no, H, W]
            targets: tuple of label tensors per image, each [M, 5] (cls, cx, cy, w, h) normalised
        Returns:
            loss: weighted total loss (single-element tensor)
            loss_items: [lbox, lcls, ldfl] detached
        """
        device = preds[0].device
        bs = preds[0].shape[0]
        reg_max = self.reg_max

        # Build anchors from feature maps
        anchors, strides = make_anchors(preds, self.stride.to(device))

        # Network input size implied by the first feature map and its stride,
        # so labels are de-normalised correctly for any training resolution
        # (assumes preds[0] is the level that self.stride[0] describes).
        first_stride = float(self.stride[0])
        img_size = (preds[0].shape[3] * first_stride, preds[0].shape[2] * first_stride)

        # Concatenate predictions across levels: [B, no, N]
        no = reg_max * 4 + self.nc
        pred_cat = torch.cat([p.view(bs, no, -1) for p in preds], dim=2)
        pred_box_raw = pred_cat[:, :reg_max * 4, :]  # [B, 4*reg_max, N]
        pred_cls_raw = pred_cat[:, reg_max * 4:, :]   # [B, nc, N]

        # DFL decode to get LTRB
        self.dfl = self.dfl.to(device)
        pred_ltrb = self.dfl(pred_box_raw)  # [B, 4, N]

        # Decode LTRB to cxcywh in pixels
        anc_t = anchors.transpose(0, 1).unsqueeze(0)  # [1, 2, N]
        str_t = strides.transpose(0, 1).unsqueeze(0)  # [1, 1, N]
        lt, rb = pred_ltrb.chunk(2, dim=1)
        x1y1 = anc_t - lt
        x2y2 = anc_t + rb
        cxcy = (x1y1 + x2y2) / 2
        wh = (x2y2 - x1y1).clamp(min=0)
        pred_boxes_px = torch.cat([cxcy, wh], dim=1) * str_t  # [B, 4, N]

        pred_boxes_t = pred_boxes_px.permute(0, 2, 1)  # [B, N, 4]
        pred_cls_t = pred_cls_raw.permute(0, 2, 1)      # [B, N, nc]

        # Assign targets (detach predictions so assignment doesn't affect gradients)
        with torch.no_grad():
            assigned_gt, assigned_cls, fg_mask, assigned_ltrb = self.assign_targets(
                anchors, strides, targets, pred_boxes_t.detach(), pred_cls_t.detach(), bs,
                img_size=img_size,
            )

        # Normaliser for the classification term; clamped so an image batch
        # without positives does not divide by zero.
        num_pos = fg_mask.sum().clamp(min=1).float()

        # --- Box loss (CIoU on positive anchors, element-wise) ---
        lbox = torch.zeros(1, device=device)
        if fg_mask.any():
            pred_pos_xyxy = self._xywh2xyxy(pred_boxes_t[fg_mask])  # [P, 4]
            gt_pos_xyxy = assigned_gt[fg_mask]                        # [P, 4]
            ciou = self.elementwise_ciou(pred_pos_xyxy, gt_pos_xyxy)  # [P]
            lbox = (1.0 - ciou).mean()

        # --- Classification loss (FOCAL LOSS on all anchors) ---
        # Focal weighting keeps the many easy negatives from drowning out the positives
        cls_loss_all = self.focal_loss(
            pred_cls_t, assigned_cls,
            gamma=self.focal_gamma, alpha=self.focal_alpha
        )  # [B, N, nc]
        lcls = cls_loss_all.sum() / num_pos

        # --- DFL loss (cross-entropy on raw reg_max-bin distributions for positives) ---
        ldfl = torch.zeros(1, device=device)
        if fg_mask.any():
            raw_box = pred_box_raw.permute(0, 2, 1)  # [B, N, 4*reg_max]
            raw_pos = raw_box[fg_mask]                 # [P, 4*reg_max]
            target_ltrb = assigned_ltrb[fg_mask]       # [P, 4]
            # Keep targets strictly below the last bin so the right neighbour exists
            target_ltrb = target_ltrb.clamp(0, reg_max - 1 - 0.01)
            raw_pos = raw_pos.view(-1, 4, reg_max)     # [P, 4, reg_max]
            tl = target_ltrb.long()                    # left (floor) bin
            tr = (tl + 1).clamp(max=reg_max - 1)       # right bin
            wl = tr.float() - target_ltrb              # weight of the left bin
            wr = 1.0 - wl
            log_probs = nn.functional.log_softmax(raw_pos, dim=-1)
            loss_l = -log_probs.gather(-1, tl.unsqueeze(-1)).squeeze(-1) * wl
            loss_r = -log_probs.gather(-1, tr.unsqueeze(-1)).squeeze(-1) * wr
            ldfl = (loss_l + loss_r).mean()

        loss = self.box_weight * lbox + self.cls_weight * lcls + self.dfl_weight * ldfl
        return loss, torch.cat([lbox.reshape(1), lcls.reshape(1), ldfl.reshape(1)]).detach()


# ============================================================================
# Dataset
# ============================================================================

class YOLODataset(Dataset):
    """Dataset of images with YOLO-format label files.

    Every ``*.jpg`` / ``*.png`` file in ``img_dir`` is paired with the text
    file of the same stem in ``label_dir``; each label row is
    ``cls cx cy w h`` with coordinates normalised to the original image.
    Images are letterboxed to ``img_size`` x ``img_size`` and the labels are
    re-normalised to that padded square.
    """

    def __init__(self, img_dir, label_dir, img_size=640, augment=False):
        """
        Args:
            img_dir: Directory containing the images.
            label_dir: Directory containing the ``.txt`` label files.
            img_size: Side length of the square network input.
            augment: Stored on the instance; no augmentation is applied by
                this class.
        """
        self.img_dir = Path(img_dir)
        self.label_dir = Path(label_dir)
        self.img_size = img_size
        self.augment = augment

        # Get all image files
        self.img_files = sorted(list(self.img_dir.glob('*.jpg')) +
                               list(self.img_dir.glob('*.png')))

        print(f"Found {len(self.img_files)} images in {img_dir}")

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for sample ``idx``.

        Returns:
            image: float tensor [3, img_size, img_size], RGB, scaled to [0, 1].
            labels: float tensor [M, 5] (cls, cx, cy, w, h), normalised to the
                letterboxed image.

        Raises:
            OSError: If the image file cannot be decoded.
            ValueError: If a label row does not have exactly five numbers.
        """
        img_path = self.img_files[idx]
        img = cv2.imread(str(img_path))
        if img is None:
            # cv2.imread signals failure with None instead of raising
            raise OSError(f"Failed to read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        labels = self._read_labels(self.label_dir / (img_path.stem + '.txt'))

        # Letterbox the image and move the labels into the padded frame
        img, labels = self.resize_image(img, labels)

        # Convert to tensor
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        labels = torch.from_numpy(labels).float()

        return img, labels

    @staticmethod
    def _read_labels(label_path):
        """Parse a YOLO label file into a float array of shape [M, 5].

        A missing file yields an empty ``(0, 5)`` array. Blank lines are
        skipped; any other line must contain exactly five numbers.
        """
        rows = []
        if label_path.exists():
            with open(label_path, 'r') as f:
                for line_no, line in enumerate(f, 1):
                    fields = line.split()
                    if not fields:
                        continue  # tolerate blank / trailing empty lines
                    if len(fields) != 5:
                        raise ValueError(
                            f"{label_path}:{line_no}: expected 5 values "
                            f"(cls cx cy w h), got {len(fields)}"
                        )
                    rows.append([float(v) for v in fields])
        return np.array(rows, dtype=np.float64).reshape(-1, 5)

    def resize_image(self, img, labels):
        """Resize image to target size while maintaining aspect ratio.

        The image is scaled so its longer side fits ``img_size``, then centred
        on a grey (114) square canvas. Labels are converted from coordinates
        normalised to the original image to coordinates normalised to the
        padded canvas.

        Returns:
            ``(padded_image, adjusted_labels)``; ``labels`` is not modified
            in place.
        """
        h, w = img.shape[:2]
        scale = min(self.img_size / h, self.img_size / w)
        # Guard against a zero-sized side for extreme aspect ratios
        new_h = max(1, int(h * scale))
        new_w = max(1, int(w * scale))

        # Resize
        img = cv2.resize(img, (new_w, new_h))

        # Pad to square, splitting the padding evenly on both sides
        top = (self.img_size - new_h) // 2
        left = (self.img_size - new_w) // 2

        img_padded = np.full((self.img_size, self.img_size, 3), 114, dtype=np.uint8)
        img_padded[top:top+new_h, left:left+new_w] = img

        if len(labels) > 0:
            # normalised (original image) -> pixels in the padded canvas
            # -> normalised to the padded canvas
            labels = labels.copy()
            labels[:, 1] = (labels[:, 1] * new_w + left) / self.img_size  # cx
            labels[:, 2] = (labels[:, 2] * new_h + top) / self.img_size   # cy
            labels[:, 3] = labels[:, 3] * new_w / self.img_size           # w
            labels[:, 4] = labels[:, 4] * new_h / self.img_size           # h

        return img_padded, labels


# ============================================================================
# Metrics and Evaluation
# ============================================================================

class Metrics:
    """Container for the four headline detection metrics.

    Attributes set by ``reset``: ``map50`` (mAP@0.5), ``map`` (mAP@0.5:0.95),
    ``mp`` (mean precision) and ``mr`` (mean recall).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every metric back to 0.0."""
        self.map50 = self.map = self.mp = self.mr = 0.0

    def __str__(self):
        summary = f"mAP@0.5: {self.map50:.4f}, mAP@0.5:0.95: {self.map:.4f}, P: {self.mp:.4f}, R: {self.mr:.4f}"
        return summary


class BoxMetrics:
    """Plain holder for box-detection metrics, all initialised to 0.0.

    Fields: ``map50`` (mAP@0.5), ``map`` (mAP@0.5:0.95), ``mp`` (precision),
    ``mr`` (recall).
    """

    def __init__(self):
        self.map50, self.map, self.mp, self.mr = 0.0, 0.0, 0.0, 0.0


class ValidationResults:
    """Result object of a validation run; box metrics live under ``.box``."""

    def __init__(self):
        box_metrics = BoxMetrics()  # fresh, zero-initialised metrics
        self.box = box_metrics


# ============================================================================
# Trainer
# ============================================================================

class Trainer:
    """Training logic for YOLO model.

    Owns the optimizer, LR scheduler, datasets/dataloaders and checkpointing
    for one training run.
    """

    def __init__(self, model, data_config, args):
        """
        Args:
            model: Model to train; moved to CUDA when available, else CPU.
            data_config: Dataset description; reads the keys ``nc``, ``path``,
                ``train`` and ``val``. Labels are looked up under
                ``<path>/labels/train`` and ``<path>/labels/val``.
            args: Training options; reads the keys ``epochs``, ``imgsz``,
                ``batch``, ``workers``, ``project`` and ``name``.
        """
        self.model = model
        self.data_config = data_config
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Move model to device
        self.model.to(self.device)

        # Loss function
        self.criterion = YOLOLoss(model, nc=self.data_config['nc'])

        # Optimizer
        self.optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0005)

        # Learning rate scheduler (stepped once per epoch in train())
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=args['epochs'], eta_min=0.00001
        )

        # Datasets
        self.train_dataset = YOLODataset(
            os.path.join(data_config['path'], data_config['train']),
            os.path.join(data_config['path'], 'labels/train'),
            img_size=args['imgsz']
        )

        self.val_dataset = YOLODataset(
            os.path.join(data_config['path'], data_config['val']),
            os.path.join(data_config['path'], 'labels/val'),
            img_size=args['imgsz']
        )

        # Dataloaders
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=args['batch'],
            shuffle=True,
            num_workers=args['workers'],
            pin_memory=True,
            collate_fn=self.collate_fn
        )

        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=args['batch'],
            shuffle=False,
            num_workers=args['workers'],
            pin_memory=True,
            collate_fn=self.collate_fn
        )

        # Training state
        self.epoch = 0
        self.best_fitness = 0.0

        # Create save directory
        self.save_dir = Path(args['project']) / args['name']
        self.save_dir.mkdir(parents=True, exist_ok=True)
        (self.save_dir / 'weights').mkdir(exist_ok=True)

    def collate_fn(self, batch):
        """Stack images into one tensor; keep labels as a per-image tuple.

        Labels cannot be stacked because every image has a different number
        of boxes.
        """
        imgs, labels = zip(*batch)
        imgs = torch.stack(imgs, 0)
        return imgs, labels

    def train(self):
        """Main training loop.

        Validates every 5th epoch and after the final epoch, saving
        ``best.pt`` whenever the fitness improves, and writes ``last.pt``
        after every epoch.

        Returns:
            ``{'success': True}`` once all epochs have run.
        """
        print(f"\nStarting training for {self.args['epochs']} epochs...")
        print(f"Device: {self.device}")
        print(f"Training images: {len(self.train_dataset)}")
        print(f"Validation images: {len(self.val_dataset)}")

        for epoch in range(self.args['epochs']):
            self.epoch = epoch
            print(f"\nEpoch {epoch + 1}/{self.args['epochs']}")

            # Train one epoch
            self.train_one_epoch()

            # Validate
            if (epoch + 1) % 5 == 0 or epoch == self.args['epochs'] - 1:
                metrics = self.validate()

                # Save best model
                fitness = metrics.map50  # Use mAP@0.5 as fitness
                if fitness > self.best_fitness:
                    self.best_fitness = fitness
                    self.save_checkpoint('best.pt')
                    print(f"New best model saved! mAP@0.5: {fitness:.4f}")

            # Save last
            self.save_checkpoint('last.pt')

            # Update learning rate
            self.scheduler.step()

        print("\nTraining complete!")
        return {'success': True}

    def train_one_epoch(self):
        """Train for one epoch, showing running and average loss in the progress bar."""
        self.model.train()
        pbar = tqdm(self.train_loader, desc='Training')

        total_loss = 0
        for i, (imgs, targets) in enumerate(pbar):
            imgs = imgs.to(self.device)

            # Forward
            self.optimizer.zero_grad()
            preds = self.model(imgs)

            # Calculate loss
            loss, loss_items = self.criterion(preds, targets)

            # Backward
            loss.backward()
            self.optimizer.step()

            # Update progress bar
            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}',
                            'avg_loss': f'{total_loss/(i+1):.4f}'})

    def validate(self):
        """Validate the model by running inference and comparing to GT.

        Detections are confidence-filtered (0.25), NMS-ed (IoU 0.45) and then
        matched to ground truth of the same class at IoU >= 0.5.

        Returns:
            ``BoxMetrics`` with precision (``mp``), recall (``mr``), an
            F1-based approximation of mAP@0.5 (``map50``) and a rough
            estimate of mAP@0.5:0.95 (``map``).
        """
        self.model.eval()
        print("\nValidating...")

        iou_threshold = 0.5
        conf_threshold = 0.25
        nms_threshold = 0.45
        all_tp = 0       # true positives at IoU >= 0.5
        all_fp = 0       # false positives
        all_n_gt = 0     # total ground truth boxes

        with torch.no_grad():
            for imgs, targets in tqdm(self.val_loader, desc='Validation'):
                imgs = imgs.to(self.device)
                # Labels are normalised to the network input, so scale them
                # by the actual batch size instead of a fixed 640.
                img_h, img_w = imgs.shape[2], imgs.shape[3]
                preds = self.model(imgs)  # inference mode: (y, x)
                if isinstance(preds, tuple):
                    preds = preds[0]  # [B, 4+nc, N]

                for b in range(imgs.shape[0]):
                    pred_b = preds[b].transpose(0, 1)  # [N, 4+nc]
                    boxes = pred_b[:, :4]  # cx, cy, w, h in pixels
                    max_scores, max_cls = pred_b[:, 4:].max(dim=1)

                    # Confidence filter
                    keep = max_scores >= conf_threshold
                    if keep.sum() > 0:
                        boxes = boxes[keep]
                        max_scores = max_scores[keep]
                        max_cls = max_cls[keep]
                        # cx,cy,w,h -> xyxy
                        x1 = boxes[:, 0] - boxes[:, 2] / 2
                        y1 = boxes[:, 1] - boxes[:, 3] / 2
                        x2 = boxes[:, 0] + boxes[:, 2] / 2
                        y2 = boxes[:, 1] + boxes[:, 3] / 2
                        det_xyxy = torch.stack([x1, y1, x2, y2], dim=1)
                        # NMS
                        nms_keep = torchvision.ops.nms(det_xyxy, max_scores, nms_threshold)
                        det_xyxy = det_xyxy[nms_keep]
                        det_scores = max_scores[nms_keep]
                        det_cls = max_cls[nms_keep]
                    else:
                        det_xyxy = torch.zeros(0, 4, device=imgs.device)
                        det_scores = torch.zeros(0, device=imgs.device)
                        det_cls = torch.zeros(0, dtype=torch.long, device=imgs.device)

                    # Ground truth for this image (normalised -> pixels)
                    gt = targets[b]  # [M, 5]
                    if len(gt) > 0:
                        gt = gt.to(imgs.device)
                        gt_cls = gt[:, 0].long()
                        gt_xyxy = torch.stack([
                            (gt[:, 1] - gt[:, 3] / 2) * img_w,
                            (gt[:, 2] - gt[:, 4] / 2) * img_h,
                            (gt[:, 1] + gt[:, 3] / 2) * img_w,
                            (gt[:, 2] + gt[:, 4] / 2) * img_h,
                        ], dim=1)  # [M, 4]
                    else:
                        gt_cls = torch.zeros(0, dtype=torch.long, device=imgs.device)
                        gt_xyxy = torch.zeros(0, 4, device=imgs.device)

                    all_n_gt += gt_xyxy.shape[0]
                    tp, fp = self._count_matches(
                        det_xyxy, det_scores, det_cls, gt_xyxy, gt_cls, iou_threshold
                    )
                    all_tp += tp
                    all_fp += fp

        # Compute metrics
        metrics = ValidationResults()
        precision = all_tp / (all_tp + all_fp + 1e-7)
        recall = all_tp / (all_n_gt + 1e-7)
        # Approximate mAP@0.5 as F1-like measure (proper AP requires full PR curve)
        metrics.box.map50 = 2 * precision * recall / (precision + recall + 1e-7)
        metrics.box.map = metrics.box.map50 * 0.6  # rough estimate for 0.5:0.95
        metrics.box.mp = precision
        metrics.box.mr = recall

        print(f"\nValidation Results:")
        print(f"  mAP@0.5: {metrics.box.map50:.4f}")
        print(f"  mAP@0.5:0.95: {metrics.box.map:.4f}")
        print(f"  Precision: {metrics.box.mp:.4f}")
        print(f"  Recall: {metrics.box.mr:.4f}")
        print(f"  (TP={all_tp}, FP={all_fp}, GT={all_n_gt})")

        return metrics.box

    @staticmethod
    def _count_matches(det_xyxy, det_scores, det_cls, gt_xyxy, gt_cls, iou_threshold=0.5):
        """Count true/false positives for one image with greedy class-aware matching.

        Detections are visited in descending score order; each is a true
        positive when its best-overlapping ground truth of the same class
        reaches ``iou_threshold`` and has not been claimed by an earlier
        detection, otherwise a false positive.

        Args:
            det_xyxy: [D, 4] detections (x1, y1, x2, y2).
            det_scores: [D] confidences.
            det_cls: [D] predicted class indices.
            gt_xyxy: [G, 4] ground-truth boxes (x1, y1, x2, y2).
            gt_cls: [G] ground-truth class indices.
            iou_threshold: Minimum IoU for a match.

        Returns:
            ``(tp, fp)`` as Python ints.
        """
        n_det = det_xyxy.shape[0]
        if n_det == 0:
            return 0, 0
        if gt_xyxy.shape[0] == 0:
            return 0, n_det

        # Pairwise IoU, [D, G]
        d = det_xyxy.unsqueeze(1)  # [D, 1, 4]
        g = gt_xyxy.unsqueeze(0)   # [1, G, 4]
        ix1 = torch.max(d[..., 0], g[..., 0])
        iy1 = torch.max(d[..., 1], g[..., 1])
        ix2 = torch.min(d[..., 2], g[..., 2])
        iy2 = torch.min(d[..., 3], g[..., 3])
        inter = (ix2 - ix1).clamp(0) * (iy2 - iy1).clamp(0)
        a_d = (d[..., 2] - d[..., 0]) * (d[..., 3] - d[..., 1])
        a_g = (g[..., 2] - g[..., 0]) * (g[..., 3] - g[..., 1])
        iou_mat = inter / (a_d + a_g - inter + 1e-7)

        # A detection may only match ground truth of its own class
        same_cls = det_cls.unsqueeze(1) == gt_cls.unsqueeze(0)
        iou_mat = torch.where(same_cls, iou_mat, torch.zeros_like(iou_mat))

        tp = 0
        fp = 0
        matched_gt = set()
        for di in det_scores.argsort(descending=True).tolist():
            ious = iou_mat[di]  # [G]
            best_gt = ious.argmax().item()
            if ious[best_gt].item() >= iou_threshold and best_gt not in matched_gt:
                tp += 1
                matched_gt.add(best_gt)
            else:
                fp += 1
        return tp, fp

    def save_checkpoint(self, filename):
        """Save model, optimizer and training state to ``<save_dir>/weights/<filename>``."""
        checkpoint = {
            'epoch': self.epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_fitness': self.best_fitness,
            'nc': self.model.nc,
        }
        torch.save(checkpoint, self.save_dir / 'weights' / filename)


# ============================================================================
# Inference and Results
# ============================================================================

class Results:
    """Detections for one image together with what is needed to draw them.

    Attributes:
        orig_img: Image the detections belong to.
        boxes: Iterable of boxes, each unpacking to (x1, y1, x2, y2).
        scores: Confidence per box.
        classes: Class index per box.
        names: Lookup from class index to display name.
    """

    def __init__(self, orig_img, boxes, scores, classes, names):
        self.orig_img = orig_img
        self.boxes = boxes
        self.scores = scores
        self.classes = classes
        self.names = names

    def plot(self):
        """Return a copy of the image with every box and its label drawn in green."""
        canvas = self.orig_img.copy()
        colour = (0, 255, 0)

        detections = zip(self.boxes, self.scores, self.classes)
        for box, score, cls in detections:
            left, top, right, bottom = (int(v) for v in box)

            # Box outline
            cv2.rectangle(canvas, (left, top), (right, bottom), colour, 2)

            # Caption placed just above the top-left corner
            caption = f'{self.names[int(cls)]} {score:.2f}'
            cv2.putText(canvas, caption, (left, top - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, colour, 2)

        return canvas


# ============================================================================
# Main YOLO Class
# ============================================================================

class YOLO:
    """
    YOLOv11 Interface - Drop-in replacement for Ultralytics YOLO

    Usage:
        model = YOLO('yolo11n.pt')  # Load pretrained
        results = model.train(data='data.yaml', epochs=100)
        metrics = model.val()
        results = model.predict('image.jpg')
    """

    # Side length of the square letterboxed input used by predict()/export().
    _IMGSZ = 640

    def __init__(self, model='yolo11n.pt', task='detect'):
        """
        Initialize YOLO model

        Args:
            model: Model size ('yolo11n.pt', 'yolo11s.pt', etc.) or path to weights
            task: Task type ('detect', 'segment', 'classify'); stored only
        """
        self.task = task
        self.model_path = model
        self.model_size = self._parse_model_size(model)
        self.model = None
        self.trainer = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load model if weights exist, otherwise start from fresh weights
        if os.path.exists(model):
            self.load(model)
        else:
            print(f"Initializing new YOLOv11{self.model_size.upper()} model...")
            self.model = YOLOv11Model(nc=80, model_size=self.model_size)

    @staticmethod
    def _parse_model_size(model):
        """Return the scale letter (n/s/m/l/x) encoded in a weights file name.

        Accepts both the 'yolo11<size>' and 'yolov11<size>' spellings, with or
        without a directory prefix or extra suffix (e.g. 'runs/yolo11s-face.pt').
        Falls back to 'n' (nano) when no scale letter can be found.
        """
        import re  # local import keeps this block self-contained

        filename = Path(str(model)).name.lower()
        # The scale letter must directly follow the family name and must not
        # be the first letter of a longer word ('yolo11model' is not 'm').
        match = re.search(r'yolov?11([nslmx])(?![a-z0-9])', filename)
        return match.group(1) if match else 'n'

    @staticmethod
    def _resolve_sources(source):
        """Expand a predict() source into a flat list of paths and/or arrays.

        Args:
            source: A file path, a directory path (its *.jpg and *.png files
                are used, sorted), an already-loaded image array, or a
                list/tuple mixing any of these.

        Raises:
            ValueError: If a path is neither an existing file nor a directory.
        """
        if isinstance(source, (list, tuple)):
            expanded = []
            for entry in source:
                expanded.extend(YOLO._resolve_sources(entry))
            return expanded

        if isinstance(source, (str, os.PathLike)):
            path = Path(source)
            if path.is_file():
                return [path]
            if path.is_dir():
                return sorted(list(path.glob('*.jpg')) + list(path.glob('*.png')))
            raise ValueError(f"Invalid source: {source}")

        # Anything else (e.g. a numpy image) is handed to predict() unchanged.
        return [source]

    @staticmethod
    def _letterbox(img_bgr, imgsz):
        """Resize a BGR image to fit an imgsz x imgsz canvas, padding with gray.

        Returns:
            (padded_rgb, scale, left, top): the uint8 RGB canvas, the resize
            factor, and the x/y offsets of the image inside the canvas.
        """
        orig_h, orig_w = img_bgr.shape[:2]
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        scale = min(imgsz / orig_h, imgsz / orig_w)
        # Keep at least one pixel per side so extreme aspect ratios do not
        # produce an empty resize target.
        new_h = max(1, int(orig_h * scale))
        new_w = max(1, int(orig_w * scale))
        img_resized = cv2.resize(img_rgb, (new_w, new_h))

        top = (imgsz - new_h) // 2
        left = (imgsz - new_w) // 2
        img_padded = np.full((imgsz, imgsz, 3), 114, dtype=np.uint8)
        img_padded[top:top + new_h, left:left + new_w] = img_resized
        return img_padded, scale, left, top

    def load(self, weights):
        """Load model weights

        Checkpoints in this project's format (a dict whose 'model' entry is a
        state dict) are restored. Any other file content results in a freshly
        initialized 80-class model, as does a missing file.

        Args:
            weights: Path to the checkpoint file.
        """
        print(f"Loading weights from {weights}...")

        if os.path.exists(weights):
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(weights, map_location='cpu')

            state = checkpoint.get('model') if isinstance(checkpoint, dict) else None
            if isinstance(state, dict):
                # Load from our checkpoint format
                nc = checkpoint.get('nc', None)
                if nc is None:
                    # Infer nc from saved cv3 final conv output shape
                    nc = state['head.cv3.0.2.weight'].shape[0]
                self.model = YOLOv11Model(nc=nc, model_size=self.model_size)
                self.model.load_state_dict(state)
            else:
                # 'model' missing or not a state dict (e.g. a pickled module):
                # initialize new model (can't load Ultralytics weights directly)
                print("Note: Creating new model (cannot load Ultralytics weights)")
                self.model = YOLOv11Model(nc=80, model_size=self.model_size)
        else:
            self.model = YOLOv11Model(nc=80, model_size=self.model_size)

        self.model.to(self.device)
        print("Model loaded successfully!")

    def train(self, data, epochs=100, imgsz=640, batch=16, device=0,
              workers=2, project='runs/detect', name='exp', patience=50,
              save=True, plots=True, **kwargs):
        """
        Train the model

        Args:
            data: Path to data.yaml configuration file
            epochs: Number of training epochs
            imgsz: Input image size
            batch: Batch size
            device: Device to train on
            workers: Number of dataloader workers
            project: Project name
            name: Experiment name
            patience: Early stopping patience
            save: Save checkpoints
            plots: Create training plots
            **kwargs: Accepted for call compatibility; currently ignored.

        Returns:
            Whatever ``Trainer.train()`` returns.

        Raises:
            KeyError: If the data config defines neither 'nc' nor 'names'.
        """
        # Load data config
        with open(data, 'r') as f:
            data_config = yaml.safe_load(f)

        # Class count: explicit 'nc', else derived from the 'names' entry
        nc = data_config.get('nc')
        if nc is None:
            class_names = data_config.get('names')
            if not class_names:
                raise KeyError("data config must define 'nc' or 'names'")
            nc = len(class_names)
            data_config['nc'] = nc

        # Initialize model with correct number of classes
        if self.model is None or self.model.nc != nc:
            self.model = YOLOv11Model(nc=nc, model_size=self.model_size)

        # Training arguments
        args = {
            'epochs': epochs,
            'imgsz': imgsz,
            'batch': batch,
            'device': device,
            'workers': workers,
            'project': project,
            'name': name,
            'patience': patience,
            'save': save,
            'plots': plots,
        }

        # Create trainer and run it
        self.trainer = Trainer(self.model, data_config, args)
        return self.trainer.train()

    def val(self, data=None, **kwargs):
        """Validate the model

        Validation reuses the trainer created by train(); ``data`` and
        ``**kwargs`` are accepted for call compatibility but not used. Without
        a trainer a default ``ValidationResults().box`` is returned.
        """
        if self.trainer is None:
            print("No trainer available. Please train the model first or provide data config.")
            metrics = ValidationResults()
            return metrics.box

        return self.trainer.validate()

    def predict(self, source, save=False, conf=0.25, iou=0.45,
                agnostic_nms=False, **kwargs):
        """
        Run inference on images

        Args:
            source: Image path, directory, loaded BGR image array
                (assumed HxWx3 uint8, like ``cv2.imread`` output), or a
                list/tuple of these
            save: Accepted for call compatibility; results are not written
            conf: Confidence threshold
            iou: IoU threshold for NMS
            agnostic_nms: If True, suppress overlapping boxes regardless of
                class; by default NMS runs separately per class
            **kwargs: Accepted for call compatibility; currently ignored.

        Returns:
            List of ``Results``, one per input image, with boxes in original
            image pixel coordinates (x1, y1, x2, y2).

        Raises:
            ValueError: If a source path is invalid or an image cannot be read.
        """
        self.model.eval()
        self.model.to(self.device)

        # Build class names
        nc = self.model.nc
        names = {i: str(i) for i in range(nc)}
        if nc == 1:
            names = {0: 'face'}

        imgsz = self._IMGSZ
        results = []

        for item in self._resolve_sources(source):
            # Load original image (arrays are used as-is)
            if isinstance(item, np.ndarray):
                img = item
            else:
                img = cv2.imread(str(item))
                if img is None:
                    raise ValueError(f"Could not read image: {item}")
            orig_h, orig_w = img.shape[:2]

            img_padded, scale, left, top = self._letterbox(img, imgsz)

            # To tensor
            img_tensor = torch.from_numpy(img_padded).permute(2, 0, 1).float() / 255.0
            img_tensor = img_tensor.unsqueeze(0).to(self.device)

            # Inference — returns (y, x) where y is [B, 4+nc, N]
            with torch.no_grad():
                pred = self.model(img_tensor)
            if isinstance(pred, tuple):
                pred = pred[0]  # take decoded output

            # pred: [1, 4+nc, N] -> [N, 4+nc]
            pred = pred[0].transpose(0, 1)

            # Split box (cx, cy, w, h) from class scores; best class per anchor
            box_cxcywh = pred[:, :4]
            max_scores, max_cls = pred[:, 4:].max(dim=1)

            # Confidence filter
            keep = max_scores >= conf
            if not keep.any():
                results.append(Results(img, np.zeros((0, 4)), np.array([]), np.array([]), names))
                continue

            box_cxcywh = box_cxcywh[keep]
            max_scores = max_scores[keep]
            max_cls = max_cls[keep]

            # Convert cx, cy, w, h -> x1, y1, x2, y2 (in letterbox space)
            half_wh = box_cxcywh[:, 2:] / 2
            centers = box_cxcywh[:, :2]
            boxes_xyxy = torch.cat([centers - half_wh, centers + half_wh], dim=1)

            # NMS: per class by default so overlapping objects of different
            # classes do not suppress each other
            if agnostic_nms:
                nms_keep = torchvision.ops.nms(boxes_xyxy, max_scores, iou)
            else:
                nms_keep = torchvision.ops.batched_nms(boxes_xyxy, max_scores, max_cls, iou)
            boxes_xyxy = boxes_xyxy[nms_keep]
            max_scores = max_scores[nms_keep]
            max_cls = max_cls[nms_keep]

            # Undo the letterbox (offset, then scale) and clamp to image bounds
            boxes_xyxy[:, 0::2] = ((boxes_xyxy[:, 0::2] - left) / scale).clamp(0, orig_w)
            boxes_xyxy[:, 1::2] = ((boxes_xyxy[:, 1::2] - top) / scale).clamp(0, orig_h)

            results.append(Results(
                img,
                boxes_xyxy.cpu().numpy(),
                max_scores.cpu().numpy(),
                max_cls.cpu().numpy(),
                names,
            ))

        return results

    def export(self, format='onnx', **kwargs):
        """Export model to different formats

        Only 'onnx' is implemented; the file is written to the current
        working directory as ``yolov11<size>.onnx``.

        Args:
            format: Target format name.
            **kwargs: Accepted for call compatibility; currently ignored.

        Returns:
            The output file name for 'onnx', otherwise None.
        """
        print(f"Exporting model to {format}...")

        if format != 'onnx':
            print(f"Export format {format} not implemented yet")
            return None

        # A freshly initialized model may still live on the CPU while
        # self.device is CUDA; align both and trace in eval mode.
        self.model.to(self.device).eval()
        dummy_input = torch.randn(1, 3, self._IMGSZ, self._IMGSZ, device=self.device)
        out_path = f"yolov11{self.model_size}.onnx"
        torch.onnx.export(
            self.model,
            dummy_input,
            out_path,
            input_names=['images'],
            output_names=['output'],
            dynamic_axes={'images': {0: 'batch'}, 'output': {0: 'batch'}}
        )
        print(f"Model exported to {out_path}")
        return out_path


# Demo: print a banner, build the default model and report its basic facts
print("YOLOv11 Custom Implementation", "=" * 50, sep="\n")

# 'yolo11n.pt' selects the nano scale; weights are loaded only if that file exists
model = YOLO('yolo11n.pt')
print(
    f"Model initialized: YOLOv11{model.model_size.upper()}",
    f"Device: {model.device}",
    f"Parameters: {sum(p.numel() for p in model.model.parameters()):,}",
    sep="\n",
)


YOLOv11 Custom Implementation
Initializing new YOLOv11N model...
Model initialized: YOLOv11N
Device: cpu
Parameters: 2,656,000
