In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.nn.functional import cross_entropy, one_hot

# model.py - Fix the RepViTBlock for mixed precision

class RepViTBlock(nn.Module):
    def __init__(self, in_channels, se_ratio=0.25, stride=1):
        super().__init__()
        self.stride = stride
        self.in_channels = in_channels
        self.fused = False

        # depthwise 3x3
        self.rbr_dense = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, stride, 1, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels)
        )

        # depthwise 1x1
        self.rbr_1x1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride, 0, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels)
        )

        # pointwise conv
        self.pwconv = nn.Conv2d(in_channels, in_channels, 1, bias=False)
        self.bn_pw = nn.BatchNorm2d(in_channels)

        # SE block
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, max(1, int(in_channels * se_ratio)), 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(max(1, int(in_channels * se_ratio)), in_channels, 1),
            nn.Sigmoid()
        )

        self.shortcut = (stride == 1)

    def forward(self, x):
        # Check if we're using reparameterized version
        if hasattr(self, 'rbr_reparam'):
            # Using reparameterized convolution
            if self.rbr_reparam.weight.device != x.device:
                self.rbr_reparam = self.rbr_reparam.to(x.device)
            if self.rbr_reparam.weight.dtype != x.dtype:
                self.rbr_reparam = self.rbr_reparam.to(x.dtype)
            y = self.rbr_reparam(x)
        else:
            # Using original branches - ensure they're on correct device
            if not self.fused:
                if self.rbr_dense[0].weight.device != x.device:
                    self.rbr_dense = self.rbr_dense.to(x.device)
                if self.rbr_1x1[0].weight.device != x.device:
                    self.rbr_1x1 = self.rbr_1x1.to(x.device)

            y = self.rbr_dense(x) + self.rbr_1x1(x)

        # Ensure other components are on correct device
        if self.pwconv.weight.device != x.device:
            self.pwconv = self.pwconv.to(x.device)
            self.bn_pw = self.bn_pw.to(x.device)
        if self.se[1].weight.device != x.device:
            self.se = self.se.to(x.device)

        y = self.pwconv(y)
        y = self.bn_pw(y)
        y = y * self.se(y)
        y = F.relu(y)
        return x + y if self.shortcut else y

    def _fuse_reparam(self):
        # Ensure fusion happens on the correct device
        device = self.rbr_dense[0].weight.device
        k3, b3 = self._fuse_conv_bn(self.rbr_dense)
        k1, b1 = self._fuse_conv_bn(self.rbr_1x1)
        k1_pad = F.pad(k1, [1, 1, 1, 1])
        fused_k = k3 + k1_pad
        fused_b = b3 + b1

        self.rbr_reparam = nn.Conv2d(
            self.in_channels, self.in_channels, kernel_size=3,
            stride=self.stride, padding=1, groups=self.in_channels, bias=True
        ).to(device)

        with torch.no_grad():
            self.rbr_reparam.weight.copy_(fused_k)
            self.rbr_reparam.bias.copy_(fused_b)

        # Only delete if they exist (for safety)
        if hasattr(self, 'rbr_dense'):
            del self.rbr_dense
        if hasattr(self, 'rbr_1x1'):
            del self.rbr_1x1
        self.fused = True

    @staticmethod
    def _fuse_conv_bn(branch):
        conv = branch[0]
        bn = branch[1]
        w = conv.weight
        if conv.bias is None:
            bias = torch.zeros(w.size(0), device=w.device)
        else:
            bias = conv.bias

        bn_var_rsqrt = 1.0 / torch.sqrt(bn.running_var + bn.eps)
        w_fused = w * (bn.weight * bn_var_rsqrt).reshape(-1, 1, 1, 1)
        b_fused = bn.bias + (bias - bn.running_mean) * bn_var_rsqrt * bn.weight
        return w_fused, b_fused

    def fuse(self):
        """Fuse the block for inference"""
        if not self.fused:
            self._fuse_reparam()
        return self

# =========================================================
# RepViT Backbone
# =========================================================
class RepViTBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            RepViTBlock(32)
        )
        self.stage2 = nn.Sequential(
            RepViTBlock(32, stride=2),
            RepViTBlock(32)
        )
        self.stage3 = nn.Sequential(
            RepViTBlock(32, stride=2),
            RepViTBlock(32),
            RepViTBlock(32)
        )

    def forward(self, x):
        c2 = self.stage1(x)
        c3 = self.stage2(c2)
        c4 = self.stage3(c3)
        return [c2, c3, c4]

# =========================================================
# Ghost Module
# =========================================================
class GhostModule(nn.Module):
    def __init__(self, in_ch, out_ch, ratio=2, kernel_size=1, dw_size=3, stride=1, relu=True):
        super().__init__()
        init_ch = out_ch // ratio
        new_ch = out_ch - init_ch
        self.primary = nn.Sequential(
            nn.Conv2d(in_ch, init_ch, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(init_ch),
            nn.ReLU(inplace=True) if relu else nn.Identity()
        )
        self.cheap = nn.Sequential(
            nn.Conv2d(init_ch, new_ch, dw_size, 1, dw_size // 2, groups=init_ch, bias=False),
            nn.BatchNorm2d(new_ch),
            nn.ReLU(inplace=True) if relu else nn.Identity()
        )

    def forward(self, x):
        y = self.primary(x)
        z = self.cheap(y)
        return torch.cat([y, z], dim=1)

# =========================================================
# GhostNeck (multi-scale features)
# =========================================================
class GhostNeck(nn.Module):
    def __init__(self, chs=[32, 32, 32]):
        super().__init__()
        c2, c3, c4 = chs
        self.reduce_c4 = GhostModule(c4, 64)            # p5 channels = 64
        self.reduce_c3 = GhostModule(c3 + 64, 48)       # p4 channels = 48
        self.reduce_c2 = GhostModule(c2 + 48, 32)       # p3 channels = 32
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, feats):
        c2, c3, c4 = feats
        p5 = self.reduce_c4(c4)
        p4 = self.reduce_c3(torch.cat([self.up(p5), c3], dim=1))
        p3 = self.reduce_c2(torch.cat([self.up(p4), c2], dim=1))
        return p3, p4, p5

# =========================================================
# GhostHead (detection head)
# =========================================================
class GhostHead(nn.Module):
    def __init__(self, in_ch, num_classes):
        super().__init__()
        self.conv = GhostModule(in_ch, 32)
        # Output: 4 box coordinates + num_classes
        self.pred = nn.Conv2d(32, 4 + num_classes, 1)

    def forward(self, x):
        x = self.conv(x)
        x = self.pred(x)
        return x

# =========================================================
# YOLOv8Hybrid (single or multi-head)
# =========================================================
class YOLOv8Hybrid(nn.Module):
    def __init__(self, num_classes=80, multi_head=False):
        super().__init__()
        self.multi_head = multi_head
        self.num_classes = num_classes
        self.backbone = RepViTBackbone()
        self.neck = GhostNeck([32, 32, 32])

        if multi_head:
            self.head_p3 = GhostHead(32, num_classes)
            self.head_p4 = GhostHead(48, num_classes)
            self.head_p5 = GhostHead(64, num_classes)
        else:
            # project p4/p5 -> p3 channels before summation
            self.proj_p4 = nn.Conv2d(48, 32, 1, bias=False)
            self.bn_p4 = nn.BatchNorm2d(32)
            self.proj_p5 = nn.Conv2d(64, 32, 1, bias=False)
            self.bn_p5 = nn.BatchNorm2d(32)
            self.head = GhostHead(32, num_classes)

    def forward(self, x):
        p3, p4, p5 = self.neck(self.backbone(x))

        if self.multi_head:
            out_p3 = self.head_p3(p3)
            out_p4 = self.head_p4(p4)
            out_p5 = self.head_p5(p5)
            return [out_p3, out_p4, out_p5]
        else:
            p4_up = F.interpolate(p4, size=p3.shape[2:], mode='nearest')
            p5_up = F.interpolate(p5, size=p3.shape[2:], mode='nearest')

            p4_proj = self.bn_p4(self.proj_p4(p4_up))
            p5_proj = self.bn_p5(self.proj_p5(p5_up))

            fused = p3 + p4_proj + p5_proj
            return self.head(fused)

In [2]:
import copy
import math
import random
import time

import numpy
import torch
import torchvision
from torch.nn.functional import cross_entropy, one_hot


def setup_seed():
    """
    Setup random seed.
    """
    random.seed(0)
    numpy.random.seed(0)
    torch.manual_seed(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def setup_multi_processes():
    """
    Setup multi-processing environment variables.
    """
    import cv2
    from os import environ
    from platform import system

    # set multiprocess start method as `fork` to speed up the training
    if system() != 'Windows':
        torch.multiprocessing.set_start_method('fork', force=True)

    # disable opencv multithreading to avoid system being overloaded
    cv2.setNumThreads(0)

    # setup OMP threads
    if 'OMP_NUM_THREADS' not in environ:
        environ['OMP_NUM_THREADS'] = '1'

    # setup MKL threads
    if 'MKL_NUM_THREADS' not in environ:
        environ['MKL_NUM_THREADS'] = '1'


def scale(coords, shape1, gain, pad):
    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain[0]  # gain_x == gain_y for letterbox
    coords[:, 0].clamp_(0, shape1[1])  # x1
    coords[:, 1].clamp_(0, shape1[0])  # y1
    coords[:, 2].clamp_(0, shape1[1])  # x2
    coords[:, 3].clamp_(0, shape1[0])  # y2
    return coords


def make_anchors(x, strides, offset=0.5):
    """
    Generate anchors from features
    """
    assert x is not None
    anchor_points, stride_tensor = [], []
    for i, stride in enumerate(strides):
        _, _, h, w = x[i].shape
        sx = torch.arange(end=w, dtype=x[i].dtype, device=x[i].device) + offset  # shift x
        sy = torch.arange(end=h, dtype=x[i].dtype, device=x[i].device) + offset  # shift y
        sy, sx = torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=x[i].dtype, device=x[i].device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)


def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    # intersection(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    (a1, a2), (b1, b2) = box1[:, None].chunk(2, 2), box2.chunk(2, 1)
    intersection = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)

    # IoU = intersection / (area1 + area2 - intersection)
    box1 = box1.T
    box2 = box2.T

    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    return intersection / (area1[:, None] + area2 - intersection)


def wh2xy(x):
    y = x.clone()
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y


def non_max_suppression(prediction, conf_threshold=0.25, iou_threshold=0.45):
    nc = prediction.shape[1] - 4  # number of classes
    xc = prediction[:, 4:4 + nc].amax(1) > conf_threshold  # candidates

    # Settings
    max_wh = 7680  # (pixels) maximum box width and height
    max_det = 300  # the maximum number of boxes to keep after NMS
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()

    start = time.time()
    outputs = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for index, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        x = x.transpose(0, -1)[xc[index]]  # confidence

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (box, conf, cls)
        box, cls = x.split((4, nc), 1)
        # center_x, center_y, width, height) to (x1, y1, x2, y2)
        box = wh2xy(box)
        if nc > 1:
            i, j = (cls > conf_threshold).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_threshold]
        # Check shape
        if not x.shape[0]:  # no boxes
            continue
        # sort by confidence and remove excess boxes
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]

        # Batched NMS
        c = x[:, 5:6] * max_wh  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_threshold)  # NMS
        i = i[:max_det]  # limit detections
        outputs[index] = x[i]
        if (time.time() - start) > 0.5 + 0.05 * prediction.shape[0]:
            print(f'WARNING ⚠️ NMS time limit {0.5 + 0.05 * prediction.shape[0]:.3f}s exceeded')
            break  # time limit exceeded

    return outputs


def smooth(y, f=0.05):
    # Box filter of fraction f
    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
    p = numpy.ones(nf // 2)  # ones padding
    yp = numpy.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
    return numpy.convolve(yp, numpy.ones(nf) / nf, mode='valid')  # y-smoothed


def compute_ap(tp, conf, pred_cls, target_cls, eps=1e-16):
    """
    Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:  True positives (nparray, nx1 or nx10).
        conf:  Object-ness value from 0-1 (nparray).
        pred_cls:  Predicted object classes (nparray).
        target_cls:  True object classes (nparray).
    # Returns
        The average precision
    """
    # Sort by object-ness
    i = numpy.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes, nt = numpy.unique(target_cls, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes, number of detections

    # Create Precision-Recall curve and compute AP for each class
    p = numpy.zeros((nc, 1000))
    r = numpy.zeros((nc, 1000))
    ap = numpy.zeros((nc, tp.shape[1]))
    px, py = numpy.linspace(0, 1, 1000), []  # for plotting
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        nl = nt[ci]  # number of labels
        no = i.sum()  # number of outputs
        if no == 0 or nl == 0:
            continue

        # Accumulate FPs and TPs
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        # Recall
        recall = tpc / (nl + eps)  # recall curve
        # negative x, xp because xp decreases
        r[ci] = numpy.interp(-px, -conf[i], recall[:, 0], left=0)

        # Precision
        precision = tpc / (tpc + fpc)  # precision curve
        p[ci] = numpy.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

        # AP from recall-precision curve
        for j in range(tp.shape[1]):
            m_rec = numpy.concatenate(([0.0], recall[:, j], [1.0]))
            m_pre = numpy.concatenate(([1.0], precision[:, j], [0.0]))

            # Compute the precision envelope
            m_pre = numpy.flip(numpy.maximum.accumulate(numpy.flip(m_pre)))

            # Integrate area under curve
            x = numpy.linspace(0, 1, 101)  # 101-point interp (COCO)
            ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x)  # integrate

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)

    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
    m_pre, m_rec = p.mean(), r.mean()
    map50, mean_ap = ap50.mean(), ap.mean()
    return tp, fp, m_pre, m_rec, map50, mean_ap


def strip_optimizer(filename):
    x = torch.load(filename, map_location=torch.device('cpu'))
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, filename)


def clip_gradients(model, max_norm=10.0):
    parameters = model.parameters()
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm)


class EMA:
    """
    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA - keep in float32 to avoid reparameterization issues
        self.ema = copy.deepcopy(model).eval().float()  # Always keep EMA in float32
        self.updates = updates  # number of EMA updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        if hasattr(model, 'module'):
            model = model.module
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach().to(v.dtype)


class AverageMeter:
    def __init__(self):
        self.num = 0
        self.sum = 0
        self.avg = 0

    def update(self, v, n):
        if not math.isnan(float(v)):
            self.num = self.num + n
            self.sum = self.sum + v * n
            self.avg = self.sum / self.num


class ComputeLoss:
    def __init__(self, model, params):
        super().__init__()
        if hasattr(model, 'module'):
            model = model.module

        device = next(model.parameters()).device

        # Get model parameters
        self.multi_head = hasattr(model, 'head_p3')
        self.num_classes = model.num_classes

        # For single head, get output channels from head
        if not self.multi_head:
            m = model.head
            self.nc = self.num_classes
            # Output channels: 4 (bbox) + num_classes
            self.no = 4 + self.nc
        else:
            # For multi-head, use p3 head as reference
            m = model.head_p3
            self.nc = self.num_classes
            self.no = 4 + self.nc

        self.bce = torch.nn.BCEWithLogitsLoss(reduction='none')

        # Define strides based on your model's feature map scales
        if self.multi_head:
            self.stride = torch.tensor([8, 16, 32], device=device)
        else:
            self.stride = torch.tensor([8], device=device)

        self.device = device
        self.params = params

        # Task aligned assigner
        self.top_k = 10
        self.alpha = 0.5
        self.beta = 6.0
        self.eps = 1e-9

        self.bs = 1
        self.num_max_boxes = 0

        # DFL Loss params
        self.dfl_ch = 1  # Simplified DFL for now
        self.project = torch.arange(self.dfl_ch, dtype=torch.float, device=device)

    def __call__(self, outputs, targets):
        # Handle different output formats
        if self.multi_head:
            # Multi-head outputs: [p3, p4, p5]
            p3, p4, p5 = outputs
            output_tensors = [p3, p4, p5]
        else:
            # Single-head output
            output_tensors = [outputs]

        # Get the first output for shape reference
        x = output_tensors[0]
        batch_size = x.shape[0]

        # Concatenate outputs along spatial dimension for multi-head
        if self.multi_head:
            # For multi-head, we need to handle each head separately
            all_outputs = []
            for output in output_tensors:
                # Flatten spatial dimensions: [B, C, H, W] -> [B, C, H*W]
                flattened = output.view(batch_size, self.no, -1)
                all_outputs.append(flattened)
            # Concatenate along spatial dimension: [B, C, H1*W1 + H2*W2 + H3*W3]
            output_cat = torch.cat(all_outputs, 2)
        else:
            # Single head: just flatten
            output_cat = x.view(batch_size, self.no, -1)

        # Split into box predictions and class scores
        pred_output = output_cat[:, :4, :]  # [B, 4, N]
        pred_scores = output_cat[:, 4:, :]  # [B, num_classes, N]

        pred_output = pred_output.permute(0, 2, 1).contiguous()  # [B, N, 4]
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()  # [B, N, num_classes]

        # Calculate size based on first output feature map
        size = torch.tensor(x.shape[2:], dtype=pred_scores.dtype, device=self.device)
        size = size * self.stride[0]

        # Create anchors
        anchor_points, stride_tensor = self.make_anchors(output_tensors, self.stride, 0.5)

        # Process targets
        if targets.shape[0] == 0:
            gt = torch.zeros(pred_scores.shape[0], 0, 5, device=self.device)
        else:
            i = targets[:, 0]
            _, counts = i.unique(return_counts=True)
            gt = torch.zeros(pred_scores.shape[0], counts.max(), 5, device=self.device)
            for j in range(pred_scores.shape[0]):
                matches = i == j
                n = matches.sum()
                if n:
                    gt[j, :n] = targets[matches, 1:]
            # Convert from normalized to pixel coordinates
            from data_loader import wh2xy
            gt[..., 1:5] = wh2xy(gt[..., 1:5].mul_(size[[1, 0, 1, 0]]))

        gt_labels, gt_bboxes = gt.split((1, 4), 2)
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

        # Process predictions to bounding boxes
        b, a, c = pred_output.shape
        # Simple box decoding (replace DFL with direct regression for now)
        pred_bboxes = torch.sigmoid(pred_output)  # Simple activation for box coordinates

        # Convert from center+wh to xyxy format
        a, b_vals = torch.split(pred_bboxes, 2, -1)
        pred_bboxes_xyxy = torch.cat((anchor_points - a, anchor_points + b_vals), -1)

        scores = pred_scores.detach().sigmoid()
        bboxes = (pred_bboxes_xyxy.detach() * stride_tensor).type(gt_bboxes.dtype)

        # Task-aligned assignment
        target_bboxes, target_scores, fg_mask = self.assign(scores, bboxes,
                                                          gt_labels, gt_bboxes, mask_gt,
                                                          anchor_points * stride_tensor)

        target_bboxes /= stride_tensor
        target_scores_sum = max(target_scores.sum(), 1)

        # cls loss
        loss_cls = self.bce(pred_scores, target_scores.to(pred_scores.dtype))
        loss_cls = loss_cls.sum() / target_scores_sum

        # box loss
        loss_box = torch.zeros(1, device=self.device)
        loss_dfl = torch.zeros(1, device=self.device)

        if fg_mask.sum():
            weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
            loss_box = self.iou(pred_bboxes_xyxy[fg_mask], target_bboxes[fg_mask])
            loss_box = ((1.0 - loss_box) * weight).sum() / target_scores_sum

            # Simplified DFL loss
            a, b_vals = torch.split(target_bboxes, 2, -1)
            target_lt_rb = torch.cat((anchor_points - a, b_vals - anchor_points), -1)
            target_lt_rb = target_lt_rb.clamp(0, self.dfl_ch - 1.01)
            loss_dfl = self.df_loss(pred_output[fg_mask].view(-1, 4), target_lt_rb[fg_mask])
            loss_dfl = (loss_dfl * weight).sum() / target_scores_sum

        loss_cls *= self.params['cls']
        loss_box *= self.params['box']
        loss_dfl *= self.params['dfl']
        return loss_cls + loss_box + loss_dfl

    @staticmethod
    def make_anchors(x, strides, offset=0.5):
        """Generate anchors from features"""
        assert x is not None
        anchor_points, stride_tensor = [], []
        for i, stride in enumerate(strides):
            _, _, h, w = x[i].shape
            sx = torch.arange(end=w, dtype=x[i].dtype, device=x[i].device) + offset
            sy = torch.arange(end=h, dtype=x[i].dtype, device=x[i].device) + offset
            sy, sx = torch.meshgrid(sy, sx, indexing='ij')
            anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
            stride_tensor.append(torch.full((h * w, 1), stride, dtype=x[i].dtype, device=x[i].device))
        return torch.cat(anchor_points), torch.cat(stride_tensor)

    @torch.no_grad()
    def assign(self, pred_scores, pred_bboxes, true_labels, true_bboxes, true_mask, anchors):
        """Task-aligned assignment"""
        self.bs = pred_scores.size(0)
        self.num_max_boxes = true_bboxes.size(1)

        if self.num_max_boxes == 0:
            device = true_bboxes.device
            return (torch.zeros_like(pred_bboxes).to(device),
                    torch.zeros_like(pred_scores).to(device),
                    torch.zeros_like(pred_scores[..., 0]).to(device))

        # Simplified assignment for now
        i = torch.zeros([2, self.bs, self.num_max_boxes], dtype=torch.long)
        i[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.num_max_boxes)
        i[1] = true_labels.long().squeeze(-1)

        overlaps = self.iou(true_bboxes.unsqueeze(2), pred_bboxes.unsqueeze(1))
        overlaps = overlaps.squeeze(3).clamp(0)

        # Simple assignment based on IoU
        max_overlaps, max_indices = overlaps.max(2)
        fg_mask = max_overlaps > 0.5

        target_bboxes = torch.zeros_like(pred_bboxes)
        target_scores = torch.zeros_like(pred_scores)

        for b in range(self.bs):
            for j in range(self.num_max_boxes):
                if true_mask[b, j] and fg_mask[b, j]:
                    idx = max_indices[b, j]
                    target_bboxes[b, idx] = true_bboxes[b, j]
                    target_scores[b, idx, true_labels[b, j].long()] = 1.0

        return target_bboxes, target_scores, fg_mask

    @staticmethod
    def df_loss(pred_dist, target):
        """Simplified distribution focal loss"""
        # For now, use smooth L1 loss as placeholder
        return F.smooth_l1_loss(pred_dist, target, reduction='none')

    @staticmethod
    def iou(box1, box2, eps=1e-7):
        """Calculate IoU between boxes"""
        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)

        # Intersection area
        inter_x1 = torch.max(b1_x1, b2_x1)
        inter_y1 = torch.max(b1_y1, b2_y1)
        inter_x2 = torch.min(b1_x2, b2_x2)
        inter_y2 = torch.min(b1_y2, b2_y2)

        inter_area = (inter_x2 - inter_x1).clamp(0) * (inter_y2 - inter_y1).clamp(0)

        # Union Area
        b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
        b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)

        union_area = b1_area + b2_area - inter_area + eps

        # IoU
        iou = inter_area / union_area

        return iou


In [3]:
# data_loader.py

import math
import os
import random
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils import data
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Configuration
DATASET_CONFIG = {
    'input_size': 640,
    'batch_size': 16,
    'num_workers': 1,
    'train_test_split': 0.2,
    'random_state': 42
}

TRAIN_AUGMENTATION = {
    'mosaic': 1.0,
    'mix_up': 0.1,
    'hsv_h': 0.015,
    'hsv_s': 0.7,
    'hsv_v': 0.4,
    'degrees': 0.0,
    'translate': 0.1,
    'scale': 0.5,
    'shear': 0.0,
    'flip_lr': 0.5,
    'flip_ud': 0.0,
}

VAL_AUGMENTATION = {
    'mosaic': 0.0,
    'mix_up': 0.0,
    'hsv_h': 0.0,
    'hsv_s': 0.0,
    'hsv_v': 0.0,
    'degrees': 0.0,
    'translate': 0.0,
    'scale': 0.0,
    'shear': 0.0,
    'flip_lr': 0.0,
    'flip_ud': 0.0,
}

FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp'

# Helper functions (keep all the same as before)
def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
    y = np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + pad_w
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + pad_h
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + pad_w
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + pad_h
    return y

def xy2wh(x, w=640, h=640):
    x[:, [0, 2]] = x[:, [0, 2]].clip(0, w - 1E-3)
    x[:, [1, 3]] = x[:, [1, 3]].clip(0, h - 1E-3)
    y = np.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h
    y[:, 2] = (x[:, 2] - x[:, 0]) / w
    y[:, 3] = (x[:, 3] - x[:, 1]) / h
    return y

def resample():
    choices = (cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4)
    return random.choice(choices)

def augment_hsv(image, params):
    h = params['hsv_h']
    s = params['hsv_s']
    v = params['hsv_v']
    r = np.random.uniform(-1, 1, 3) * [h, s, v] + 1
    h, s, v = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2HSV))
    x = np.arange(0, 256, dtype=r.dtype)
    lut_h = ((x * r[0]) % 180).astype('uint8')
    lut_s = np.clip(x * r[1], 0, 255).astype('uint8')
    lut_v = np.clip(x * r[2], 0, 255).astype('uint8')
    im_hsv = cv2.merge((cv2.LUT(h, lut_h), cv2.LUT(s, lut_s), cv2.LUT(v, lut_v)))
    cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=image)

def resize(image, input_size, augment):
    shape = image.shape[:2]
    r = min(input_size / shape[0], input_size / shape[1])
    if not augment:
        r = min(r, 1.0)
    pad = int(round(shape[1] * r)), int(round(shape[0] * r))
    w = (input_size - pad[0]) / 2
    h = (input_size - pad[1]) / 2
    if shape[::-1] != pad:
        image = cv2.resize(image, dsize=pad, interpolation=resample() if augment else cv2.INTER_LINEAR)
    top, bottom = int(round(h - 0.1)), int(round(h + 0.1))
    left, right = int(round(w - 0.1)), int(round(w + 0.1))
    image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT)
    return image, (r, r), (w, h)

def candidates(box1, box2):
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    aspect_ratio = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))
    return (w2 > 2) & (h2 > 2) & (w2 * h2 / (w1 * h1 + 1e-16) > 0.1) & (aspect_ratio < 100)

def random_perspective(samples, targets, params, border=(0, 0)):
    h = samples.shape[0] + border[0] * 2
    w = samples.shape[1] + border[1] * 2
    center = np.eye(3)
    center[0, 2] = -samples.shape[1] / 2
    center[1, 2] = -samples.shape[0] / 2
    perspective = np.eye(3)
    rotate = np.eye(3)
    a = random.uniform(-params['degrees'], params['degrees'])
    s = random.uniform(1 - params['scale'], 1 + params['scale'])
    rotate[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
    shear = np.eye(3)
    shear[0, 1] = math.tan(random.uniform(-params['shear'], params['shear']) * math.pi / 180)
    shear[1, 0] = math.tan(random.uniform(-params['shear'], params['shear']) * math.pi / 180)
    translate = np.eye(3)
    translate[0, 2] = random.uniform(0.5 - params['translate'], 0.5 + params['translate']) * w
    translate[1, 2] = random.uniform(0.5 - params['translate'], 0.5 + params['translate']) * h
    matrix = translate @ shear @ rotate @ perspective @ center
    if (border[0] != 0) or (border[1] != 0) or (matrix != np.eye(3)).any():
        samples = cv2.warpAffine(samples, matrix[:2], dsize=(w, h), borderValue=(0, 0, 0))
    n = len(targets)
    if n:
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)
        xy = xy @ matrix.T
        xy = xy[:, :2].reshape(n, 8)
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
        new[:, [0, 2]] = new[:, [0, 2]].clip(0, w)
        new[:, [1, 3]] = new[:, [1, 3]].clip(0, h)
        indices = candidates(box1=targets[:, 1:5].T * s, box2=new.T)
        targets = targets[indices]
        targets[:, 1:5] = new[indices]
    return samples, targets

def mix_up(image1, label1, image2, label2):
    alpha = np.random.beta(32.0, 32.0)
    image = (image1 * alpha + image2 * (1 - alpha)).astype(np.uint8)
    label = np.concatenate((label1, label2), 0)
    return image, label

class Albumentations:
    def __init__(self):
        self.transform = None
        try:
            import albumentations as album
            transforms = [album.Blur(p=0.01), album.CLAHE(p=0.01),
                         album.ToGray(p=0.01), album.MedianBlur(p=0.01)]
            self.transform = album.Compose(transforms, album.BboxParams('yolo', ['class_labels']))
        except ImportError:
            pass

    def __call__(self, image, label):
        if self.transform and len(label) > 0:
            x = self.transform(image=image, bboxes=label[:, 1:], class_labels=label[:, 0])
            image = x['image']
            label = np.array([[c, *b] for c, b in zip(x['class_labels'], x['bboxes'])])
        return image, label

class CarDetectionDataset(data.Dataset):
    def __init__(self, filenames, input_size, params, augment):
        self.params = params
        self.mosaic = augment
        self.augment = augment
        self.input_size = input_size

        # Load labels without cache in read-only directories
        cache = self.load_label(filenames)
        labels, shapes = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.filenames = list(cache.keys())
        self.n = len(shapes)
        self.indices = range(self.n)
        self.albumentations = Albumentations()

    def __getitem__(self, index):
        index = self.indices[index]
        params = self.params
        mosaic = self.mosaic and random.random() < params['mosaic']

        if mosaic:
            # Load MOSAIC
            image, label = self.load_mosaic(index, params)
            shapes = None  # Mosaic doesn't have original shapes

            # MixUp augmentation
            if random.random() < params['mix_up']:
                index = random.choice(self.indices)
                mix_image1, mix_label1 = image, label
                mix_image2, mix_label2 = self.load_mosaic(index, params)
                image, label = mix_up(mix_image1, mix_label1, mix_image2, mix_label2)
        else:
            # Load image
            image, shape = self.load_image(index)
            h, w = image.shape[:2]

            # Resize
            image, ratio, pad = resize(image, self.input_size, self.augment)
            shapes = shape, ((h / shape[0], w / shape[1]), pad)  # FIXED: define shapes here

            label = self.labels[index].copy()
            if label.size:
                label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, pad[0], pad[1])
            if self.augment:
                image, label = random_perspective(image, label, params)

        nl = len(label)
        if nl:
            label[:, 1:5] = xy2wh(label[:, 1:5], image.shape[1], image.shape[0])

        if self.augment:
            image, label = self.albumentations(image, label)
            nl = len(label)
            augment_hsv(image, self.params)
            if random.random() < params['flip_ud']:
                image = np.flipud(image)
                if nl:
                    label[:, 2] = 1 - label[:, 2]
            if random.random() < params['flip_lr']:
                image = np.fliplr(image)
                if nl:
                    label[:, 1] = 1 - label[:, 1]

        target = torch.zeros((nl, 6))
        if nl:
            target[:, 1:] = torch.from_numpy(label)

        sample = image.transpose((2, 0, 1))[::-1]
        sample = np.ascontiguousarray(sample)

        return torch.from_numpy(sample), target, shapes

    def __len__(self):
        return len(self.filenames)

    def load_image(self, i):
        image = cv2.imread(self.filenames[i])
        if image is None:
            raise ValueError(f"Could not load image: {self.filenames[i]}")
        h, w = image.shape[:2]
        r = self.input_size / max(h, w)
        if r != 1:
            image = cv2.resize(image,
                             dsize=(int(w * r), int(h * r)),
                             interpolation=resample() if self.augment else cv2.INTER_LINEAR)
        return image, (h, w)

    def load_mosaic(self, index, params):
        label4 = []
        image4 = np.full((self.input_size * 2, self.input_size * 2, 3), 0, dtype=np.uint8)
        border = [-self.input_size // 2, -self.input_size // 2]
        xc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))
        yc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))
        indices = [index] + random.choices(self.indices, k=3)
        random.shuffle(indices)

        for i, index in enumerate(indices):
            image, _ = self.load_image(index)
            shape = image.shape
            if i == 0:
                x1a, y1a = max(xc - shape[1], 0), max(yc - shape[0], 0)
                x2a, y2a = xc, yc
                x1b, y1b = shape[1] - (x2a - x1a), shape[0] - (y2a - y1a)
                x2b, y2b = shape[1], shape[0]
            elif i == 1:
                x1a, y1a = xc, max(yc - shape[0], 0)
                x2a, y2a = min(xc + shape[1], self.input_size * 2), yc
                x1b, y1b = 0, shape[0] - (y2a - y1a)
                x2b, y2b = min(shape[1], x2a - x1a), shape[0]
            elif i == 2:
                x1a, y1a = max(xc - shape[1], 0), yc
                x2a, y2a = xc, min(self.input_size * 2, yc + shape[0])
                x1b, y1b = shape[1] - (x2a - x1a), 0
                x2b, y2b = shape[1], min(y2a - y1a, shape[0])
            elif i == 3:
                x1a, y1a = xc, yc
                x2a, y2a = min(xc + shape[1], self.input_size * 2), min(self.input_size * 2, yc + shape[0])
                x1b, y1b = 0, 0
                x2b, y2b = min(shape[1], x2a - x1a), min(y2a - y1a, shape[0])

            image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]
            pad_w, pad_h = x1a - x1b, y1a - y1b
            label = self.labels[index].copy()
            if len(label):
                label[:, 1:] = wh2xy(label[:, 1:], shape[1], shape[0], pad_w, pad_h)
            label4.append(label)

        label4 = np.concatenate(label4, 0)
        for x in label4[:, 1:]:
            np.clip(x, 0, 2 * self.input_size, out=x)
        image4, label4 = random_perspective(image4, label4, params, border)
        return image4, label4

    @staticmethod
    def collate_fn(batch):
        samples, targets, shapes = zip(*batch)
        for i, item in enumerate(targets):
            item[:, 0] = i
        return torch.stack(samples, 0), torch.cat(targets, 0), shapes

    @staticmethod
    def load_label(filenames):
        """Load labels without using cache in read-only directories"""
        # Use current directory for cache instead of dataset directory
        cache_dir = './dataset_cache'
        os.makedirs(cache_dir, exist_ok=True)

        # Create a unique cache filename based on the first image directory
        first_file_dir = os.path.dirname(filenames[0])
        cache_name = os.path.basename(first_file_dir) + '.cache'
        cache_path = os.path.join(cache_dir, cache_name)

        # Try to load cache if it exists
        if os.path.exists(cache_path):
            try:
                print(f"Loading cache from: {cache_path}")
                # Load with weights_only=False to handle numpy arrays
                return torch.load(cache_path, weights_only=False)
            except Exception as e:
                print(f"Cache loading failed: {e}, regenerating...")

        x = {}
        valid_count = 0
        for filename in filenames:
            try:
                with open(filename, 'rb') as f:
                    image = Image.open(f)
                    image.verify()
                shape = image.size

                if not ((shape[0] > 9) & (shape[1] > 9)):
                    print(f"Warning: image size {shape} <10 pixels in {filename}")
                    continue

                if image.format.lower() not in FORMATS:
                    print(f"Warning: invalid image format {image.format} in {filename}")
                    continue

                # Find label file - handle different dataset structures
                label_path = None
                possible_paths = [
                    filename.replace('images', 'labels').rsplit('.', 1)[0] + '.txt',
                    filename.replace('testing_images', 'testing_labels').rsplit('.', 1)[0] + '.txt',
                    filename.replace('training_images', 'training_labels').rsplit('.', 1)[0] + '.txt',
                ]

                for path in possible_paths:
                    if os.path.isfile(path):
                        label_path = path
                        break

                if label_path and os.path.isfile(label_path):
                    with open(label_path) as f:
                        label_lines = [x.split() for x in f.read().strip().splitlines() if len(x)]
                        label = np.array(label_lines, dtype=np.float32)
                    nl = len(label)
                    if nl:
                        if label.shape[1] != 5:
                            print(f"Warning: labels require 5 columns in {label_path}")
                            continue
                        if not (label >= 0).all():
                            print(f"Warning: negative label values in {label_path}")
                            continue
                        if not (label[:, 1:] <= 1).all():
                            print(f"Warning: non-normalized coordinates in {label_path}")
                            continue
                        _, i = np.unique(label, axis=0, return_index=True)
                        if len(i) < nl:
                            label = label[i]
                    else:
                        label = np.zeros((0, 5), dtype=np.float32)
                else:
                    # No label file found
                    label = np.zeros((0, 5), dtype=np.float32)
                    print(f"Warning: No label file found for {filename}")

                x[filename] = [label, shape]
                valid_count += 1

            except Exception as e:
                print(f"Error loading {filename}: {e}")
                continue

        print(f"Successfully loaded {valid_count}/{len(filenames)} images")

        # Save cache if we have valid data
        if valid_count > 0:
            try:
                torch.save(x, cache_path)
                print(f"Cache saved to: {cache_path}")
            except Exception as e:
                print(f"Warning: Could not save cache: {e}")

        return x

def find_image_files(dataset_path):
    """Find all image files in the dataset"""
    image_files = []
    for root, dirs, files in os.walk(dataset_path):
        for file in files:
            if file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                image_files.append(os.path.join(root, file))
    return image_files

def create_data_loaders(dataset_path):
    """Create training and validation data loaders"""
    image_files = find_image_files(dataset_path)
    print(f"Found {len(image_files)} images")

    if len(image_files) == 0:
        raise ValueError("No images found in the dataset path!")

    # Split dataset
    train_files, val_files = train_test_split(
        image_files,
        test_size=DATASET_CONFIG['train_test_split'],
        random_state=DATASET_CONFIG['random_state']
    )

    print(f"Creating training dataset with {len(train_files)} images...")
    train_dataset = CarDetectionDataset(
        train_files,
        DATASET_CONFIG['input_size'],
        TRAIN_AUGMENTATION,
        augment=True
    )

    print(f"Creating validation dataset with {len(val_files)} images...")
    val_dataset = CarDetectionDataset(
        val_files,
        DATASET_CONFIG['input_size'],
        VAL_AUGMENTATION,
        augment=False
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=DATASET_CONFIG['batch_size'],
        shuffle=True,
        num_workers=0,  # Set to 0 to avoid multiprocessing issues
        pin_memory=False,  # Set to False since we're not using GPU for data loading
        collate_fn=CarDetectionDataset.collate_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=DATASET_CONFIG['batch_size'],
        shuffle=False,
        num_workers=0,  # Set to 0 to avoid multiprocessing issues
        pin_memory=False,  # Set to False since we're not using GPU for data loading
        collate_fn=CarDetectionDataset.collate_fn
    )

    print(f"Training samples: {len(train_files)}")
    print(f"Validation samples: {len(val_files)}")
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    return train_loader, val_loader

In [4]:
# Install required packages
!pip install albumentations kagglehub
import kagglehub

# Download dataset
path = kagglehub.dataset_download("sshikamaru/car-object-detection")
print(f"Dataset path: {path}")

# Create data loaders
train_loader, val_loader = create_data_loaders(path)

# Test the data loader
for images, targets, shapes in train_loader:
    print(f"Batch - Images: {images.shape}, Targets: {targets.shape}")
    break

Using Colab cache for faster access to the 'car-object-detection' dataset.
Dataset path: /kaggle/input/car-object-detection
Found 1176 images
Creating training dataset with 940 images...
Loading cache from: ./dataset_cache/testing_images.cache


  self._set_keys()


Creating validation dataset with 236 images...
Loading cache from: ./dataset_cache/testing_images.cache
Training samples: 940
Validation samples: 236
Train batches: 59
Val batches: 59
Batch - Images: torch.Size([16, 3, 640, 640]), Targets: torch.Size([0, 6])


In [5]:
import copy
import csv
import os
import warnings
import sys
from pathlib import Path

import numpy as np
import torch
import tqdm
import yaml

from torch.utils.data import DataLoader
import torch.nn as nn



warnings.filterwarnings("ignore")
FORMATS = ('bmp','dng','jpeg','jpg','mpo','png','tif','tiff','webp')

def _list_images(root):
    root = Path(root)
    out = []
    for ext in FORMATS:
        out += [str(p) for p in root.rglob(f'*.{ext}')]
        out += [str(p) for p in root.rglob(f'*.{ext.upper()}')]
    return out

class YoloTrainer:
    def __init__(self, args, params):
        self.args = args
        self.params = params
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print("Device:", self.device)

        # Model configuration
        num_classes = params.get('nc', 5)  # Default to 1 class (car)
        multi_head = params.get('multi_head', False)

        self.model = YOLOv8Hybrid(num_classes=num_classes, multi_head=multi_head).to(self.device)
        self.world_size = 1
        self.accumulate = max(round(64 / args.batch_size), 1)
        self.params['weight_decay'] = float(self.params.get('weight_decay', 0.0005))

        self.optimizer = self._setup_optimizer()
        for g in self.optimizer.param_groups:
            g.setdefault('initial_lr', g.get('lr', self.params.get('lr0', 0.01)))

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=self._get_lr_lambda()
        )

        self.ema = EMA(self.model)

        # Loss parameters
        loss_params = {
            'cls': params.get('cls_loss', 0.5),
            'box': params.get('box_loss', 7.5),
            'dfl': params.get('dfl_loss', 1.5)
        }
        self.criterion = ComputeLoss(self.model, loss_params)

        self.scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
        self.best_map = 0.0
        os.makedirs('weights', exist_ok=True)

    def _get_lr_lambda(self):
        def lr_lambda(epoch):
            return (1 - epoch / float(self.args.epochs)) * \
                   (1.0 - float(self.params.get('lrf', 0.01))) + \
                   float(self.params.get('lrf', 0.01))
        return lr_lambda

    def _setup_optimizer(self):
        biases, bn_weights, weights = [], [], []
        for module in self.model.modules():
            if hasattr(module, 'bias') and isinstance(getattr(module, 'bias'), nn.Parameter):
                biases.append(module.bias)
            if isinstance(module, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                if hasattr(module, 'weight'):
                    bn_weights.append(module.weight)
            elif hasattr(module, 'weight') and isinstance(getattr(module, 'weight'), nn.Parameter):
                weights.append(module.weight)

        lr0 = float(self.params.get('lr0', 0.01))
        momentum = float(self.params.get('momentum', 0.937))
        weight_decay = float(self.params.get('weight_decay', 0.0005))

        optimizer = torch.optim.SGD(biases, lr=lr0, momentum=momentum, nesterov=True)
        optimizer.add_param_group({'params': weights, 'weight_decay': weight_decay, 'lr': lr0})
        optimizer.add_param_group({'params': bn_weights, 'weight_decay': 0.0, 'lr': lr0})

        return optimizer

    def _load_dataset(self, train=True):
        """Load dataset using the CarDetectionDataset class"""
        data_cfg = self.params.get('data', {})

        if train:
            # For training, use the dataset path directly
            dataset_path = data_cfg.get('path', '/kaggle/input/car-object-detection')
            print(f"Loading {'training' if train else 'validation'} dataset from: {dataset_path}")

            # Use the create_data_loaders function from data_loader
            train_loader, val_loader = create_data_loaders(dataset_path)
            return train_loader if train else val_loader
        else:
            # For validation, return the validation loader
            dataset_path = data_cfg.get('path', '/kaggle/input/car-object-detection')
            train_loader, val_loader = create_data_loaders(dataset_path)
            return val_loader

    def train_epoch(self, loader, epoch, warmup_iters, num_batches):
        """Train for one epoch"""
        self.model.train()
        running_loss = 0.0
        pbar = tqdm.tqdm(loader, desc=f"Epoch {epoch+1}/{self.args.epochs}")

        self.optimizer.zero_grad()

        for i, (images, targets, _) in enumerate(pbar):
            iteration = i + epoch * num_batches
            images = images.to(self.device).float() / 255.0
            targets = targets.to(self.device)

            # Learning rate warmup
            if iteration < warmup_iters:
                lr_scale = np.interp(iteration, [0, warmup_iters], [0.2, 1.0])
                for g in self.optimizer.param_groups:
                    g['lr'] = g['initial_lr'] * lr_scale

            # Mixed precision training
            if torch.cuda.is_available():
                with torch.cuda.amp.autocast():
                    outputs = self.model(images)
                    loss = self.criterion(outputs, targets)

                self.scaler.scale(loss).backward()

                # Gradient accumulation
                if (i + 1) % self.accumulate == 0:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()
            else:
                # CPU training
                outputs = self.model(images)
                loss = self.criterion(outputs, targets)
                loss.backward()

                # Gradient accumulation
                if (i + 1) % self.accumulate == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

            running_loss += loss.item()
            pbar.set_postfix(loss=running_loss / (i + 1))

            # Update EMA
            self.ema.update(self.model)

        return running_loss / len(loader)

    def train(self):
        """Main training loop"""
        train_loader = self._load_dataset(train=True)
        val_loader = self._load_dataset(train=False)

        num_batches = len(train_loader)
        warmup_iters = max(round(self.params.get('warmup_epochs', 3) * num_batches), 500)

        # Create log file
        csv_file = open('weights/training_log.csv', 'w', newline='')
        writer = csv.DictWriter(csv_file, fieldnames=['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP'])
        writer.writeheader()

        print(f"Starting training for {self.args.epochs} epochs...")
        print(f"Training samples: {len(train_loader.dataset)}")
        print(f"Validation samples: {len(val_loader.dataset)}")
        print(f"Batch size: {self.args.batch_size}, Accumulate: {self.accumulate}")

        for epoch in range(self.args.epochs):
            # Train one epoch
            train_loss = self.train_epoch(train_loader, epoch, warmup_iters, num_batches)

            # Update learning rate
            self.scheduler.step()

            # Validate
            val_loss, map50, map_val = self.validate(val_loader)

            # Log results
            writer.writerow({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'mAP50': map50,
                'mAP': map_val
            })
            csv_file.flush()

            print(f"Epoch {epoch+1}/{self.args.epochs} - "
                  f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                  f"mAP50: {map50:.4f}, mAP: {map_val:.4f}")

            # Save best model
            if map_val > self.best_map:
                self.best_map = map_val
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.ema.ema.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'best_map': self.best_map,
                    'params': self.params
                }, "weights/best_model.pth")
                print(f"Saved best model with mAP: {map_val:.4f}")

            # Save latest model
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.ema.ema.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'best_map': self.best_map,
                'params': self.params
            }, "weights/latest_model.pth")

            # Save checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.ema.ema.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'best_map': self.best_map,
                    'params': self.params
                }, f"weights/checkpoint_epoch_{epoch+1}.pth")

        csv_file.close()
        print(f"Training completed! Best mAP: {self.best_map:.4f}")

    @torch.no_grad()
    def validate(self, loader):
        self.model.eval()  # Use regular model instead of EMA
        val_loss = 0
        preds = []
        targets = []

        with torch.no_grad():
            for batch_idx, (images, labels, _) in enumerate(loader):
                images = images.to(self.device, non_blocking=True).float() / 255.0
                labels = labels.to(self.device)
                targets.append(labels)

                # Use regular model instead of EMA
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                val_loss += loss.item()
                preds.append(outputs)

        map50, map_val = self.calculate_map(loader)
        return val_loss / len(loader), map50, map_val

    # @torch.no_grad()
    # def validate(self, loader):
    #   """Validate the model"""
    #   self.ema.ema.eval()
    #   running_loss = 0.0

    #   # Disable mixed precision for validation to avoid dtype issues
    #   for images, targets, _ in tqdm.tqdm(loader, desc="Validating"):
    #       images = images.to(self.device).float() / 255.0  # Ensure float32
    #       targets = targets.to(self.device)

    #       # Use float32 for validation to avoid reparameterization issues
    #       outputs = self.ema.ema(images.float())  # Force float32
    #       loss = self.criterion(outputs, targets)

    #       running_loss += loss.item()

    #   # Calculate mAP (placeholder)
    #   map50, map_val = self.calculate_map(loader)

    #   return running_loss / len(loader), map50, map_val

    def calculate_map(self, loader):
        """Calculate mAP - placeholder implementation"""
        # This is a simplified version. You can integrate proper mAP calculation
        # using your validation dataset and detection metrics
        return 0.5, 0.3  # Placeholder values

    def test(self):
        """Test the model"""
        val_loader = self._load_dataset(train=False)
        val_loss, map50, map_val = self.validate(val_loader)

        print(f"Test Results - Loss: {val_loss:.4f}, mAP50: {map50:.4f}, mAP: {map_val:.4f}")
        return map50, map_val

# Simple configuration class for Colab
class TrainingConfig:
    def __init__(self):
        self.input_size = 416
        self.batch_size = 8
        self.epochs = 50
        self.train = True
        self.test = False
        self.resume = ''
        self.data_path = '/kaggle/input/car-object-detection'

# Default configuration
default_config = {
    'nc': 1,  # number of classes (car)
    'multi_head': False,
    'lr0': 0.01,
    'lrf': 0.01,
    'momentum': 0.937,
    'weight_decay': 0.0005,
    'warmup_epochs': 3,
    'cls_loss': 0.5,
    'box_loss': 7.5,
    'dfl_loss': 1.5,
    'data': {
        'path': '/kaggle/input/car-object-detection'
    }
}

def start_training(epochs=50, batch_size=16, resume_checkpoint=''):
    """
    Start training in Google Colab

    Args:
        epochs (int): Number of training epochs
        batch_size (int): Batch size for training
        resume_checkpoint (str): Path to checkpoint to resume from
    """
    # Create configuration
    config = TrainingConfig()
    config.epochs = epochs
    config.batch_size = batch_size
    config.resume = resume_checkpoint

    # Create trainer
    trainer = YoloTrainer(config, default_config)

    # Resume from checkpoint if specified
    if resume_checkpoint and os.path.exists(resume_checkpoint):
        print(f"Resuming from checkpoint: {resume_checkpoint}")
        checkpoint = torch.load(resume_checkpoint, map_location=trainer.device)
        trainer.ema.ema.load_state_dict(checkpoint['model_state_dict'])
        trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        trainer.best_map = checkpoint.get('best_map', 0.0)
        print(f"Resumed from epoch {checkpoint['epoch']}, best mAP: {trainer.best_map:.4f}")

    # Start training
    trainer.train()

def start_testing(checkpoint_path='weights/best_model.pth'):
    """
    Test the model in Google Colab

    Args:
        checkpoint_path (str): Path to model checkpoint to test
    """
    config = TrainingConfig()
    config.train = False
    config.test = True

    trainer = YoloTrainer(config, default_config)

    if os.path.exists(checkpoint_path):
        print(f"Loading model from: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=trainer.device)
        trainer.ema.ema.load_state_dict(checkpoint['model_state_dict'])
        trainer.best_map = checkpoint.get('best_map', 0.0)
        print(f"Loaded model from epoch {checkpoint['epoch']}, best mAP: {trainer.best_map:.4f}")

    trainer.test()

# Simple function to show training progress
def show_training_progress():
    """Show training progress from log file"""
    if os.path.exists('weights/training_log.csv'):
        import pandas as pd
        df = pd.read_csv('weights/training_log.csv')
        print("Training Progress:")
        print(df.tail(10))
    else:
        print("No training log found.")

# Usage examples for Colab:
def example_usage():
    """
    Example usage in Google Colab:

    # Start training with default parameters
    start_training()

    # Start training with custom parameters
    start_training(epochs=100, batch_size=8)

    # Resume training from checkpoint
    start_training(epochs=50, batch_size=16, resume_checkpoint='weights/checkpoint_epoch_10.pth')

    # Test the model
    start_testing('weights/best_model.pth')

    # Show training progress
    show_training_progress()
    """
    print(example_usage.__doc__)

if __name__ == "__main__":
    # This will run when executed directly in Colab
    print("YOLOv8 Hybrid Trainer for Google Colab")
    print("=" * 50)

    # Show example usage
    example_usage()


YOLOv8 Hybrid Trainer for Google Colab

    Example usage in Google Colab:
    
    # Start training with default parameters
    start_training()
    
    # Start training with custom parameters
    start_training(epochs=100, batch_size=8)
    
    # Resume training from checkpoint
    start_training(epochs=50, batch_size=16, resume_checkpoint='weights/checkpoint_epoch_10.pth')
    
    # Test the model
    start_testing('weights/best_model.pth')
    
    # Show training progress
    show_training_progress()
    


In [None]:
start_training(epochs=50, batch_size=16)

Device: cuda
Loading training dataset from: /kaggle/input/car-object-detection
Found 1176 images
Creating training dataset with 940 images...
Loading cache from: ./dataset_cache/testing_images.cache
Creating validation dataset with 236 images...
Loading cache from: ./dataset_cache/testing_images.cache
Training samples: 940
Validation samples: 236
Train batches: 59
Val batches: 59
Found 1176 images
Creating training dataset with 940 images...
Loading cache from: ./dataset_cache/testing_images.cache
Creating validation dataset with 236 images...
Loading cache from: ./dataset_cache/testing_images.cache
Training samples: 940
Validation samples: 236
Train batches: 59
Val batches: 59
Starting training for 50 epochs...
Training samples: 940
Validation samples: 940
Batch size: 16, Accumulate: 4


Epoch 1/50:  75%|███████▍  | 44/59 [00:53<00:14,  1.03it/s, loss=7.49e+5]