In [30]:
!pip install ultralytics --quiet
from tqdm import tqdm
import os
import timm, torch
import torchvision
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.functional import cross_entropy
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision.ops import batched_nms
import numpy
import math
import cv2
import random
from PIL import Image
import copy
from time import time
import yaml

In [39]:
FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp'


class YoloDataset(data.Dataset):
    def __init__(self, foldername=Path('../.'), imgsz=640, names={}, augment=False):
        self.names = names
        self.mosaic = augment
        self.augment = augment
        self.imgsz = imgsz

        # Read labels
        filenames = sorted(str(p) for p in foldername.rglob("*") if p.is_file())
        
        labels = self.load_label(filenames)
        self.labels = list(labels.values())
        self.filenames = list(labels.keys())  # update
        self.n = len(self.filenames)  # number of samples
        self.indices = range(self.n)
        # Albumentations (optional, only used if package is installed)
        self.albumentations = Albumentations()

    def __getitem__(self, index):
        index = self.indices[index]

        if self.mosaic and random.random() < self.params['mosaic']:
            # Load MOSAIC
            image, label = self.load_mosaic(index, self.params)
            # MixUp augmentation
            if random.random() < self.params['mix_up']:
                index = random.choice(self.indices)
                mix_image1, mix_label1 = image, label
                mix_image2, mix_label2 = self.load_mosaic(index, self.params)

                image, label = mix_up(mix_image1, mix_label1, mix_image2, mix_label2)
        else:
            # Load image
            image, shape = self.load_image(index)
            h, w = image.shape[:2]

            # Resize
            image, ratio, pad = resize(image, self.input_size, self.augment)
            label = self.labels[index].copy()
            if label.size:
                label[:, 1:] = xywhn2xyxy(label[:, 1:], ratio[0] * w, ratio[1] * h, pad[0], pad[1])
            if self.augment:
                image, label = random_perspective(image, label, self.params)

        nl = len(label)  # number of labels
        h, w = image.shape[:2]
        cls = label[:, 0:1]
        box = label[:, 1:5]
        box = xyxy2xywhn(box, w, h)

        if self.augment:
            # Albumentations
            image, box, cls = self.albumentations(image, box, cls)
            nl = len(box)  # update after albumentations
            # HSV color-space
            augment_hsv(image, self.params)
            # Flip up-down
            if random.random() < self.params['flip_ud']:
                image = numpy.flipud(image)
                if nl:
                    box[:, 1] = 1 - box[:, 1]
            # Flip left-right
            if random.random() < self.params['flip_lr']:
                image = numpy.fliplr(image)
                if nl:
                    box[:, 0] = 1 - box[:, 0]

        target_cls = torch.zeros((nl, 1))
        target_box = torch.zeros((nl, 4))
        if nl:
            target_cls = torch.from_numpy(cls)
            target_box = torch.from_numpy(box)

        # Convert HWC to CHW, BGR to RGB
        sample = image.transpose((2, 0, 1))[::-1]
        sample = numpy.ascontiguousarray(sample)

        return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)

    def __len__(self):
        return len(self.filenames)

    def load_image(self, i):
        image = cv2.imread(self.filenames[i])
        h, w = image.shape[:2]
        r = self.input_size / max(h, w)
        if r != 1:
            image = cv2.resize(image,
                               dsize=(int(w * r), int(h * r)),
                               interpolation=resample() if self.augment else cv2.INTER_LINEAR)
        return image, (h, w)

    def load_mosaic(self, index, params):
        label4 = []
        border = [-self.input_size // 2, -self.input_size // 2]
        image4 = numpy.full((self.input_size * 2, self.input_size * 2, 3), 0, dtype=numpy.uint8)
        y1a, y2a, x1a, x2a, y1b, y2b, x1b, x2b = (None, None, None, None, None, None, None, None)

        xc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))
        yc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))

        indices = [index] + random.choices(self.indices, k=3)
        random.shuffle(indices)

        for i, index in enumerate(indices):
            # Load image
            image, _ = self.load_image(index)
            shape = image.shape
            if i == 0:  # top left
                x1a = max(xc - shape[1], 0)
                y1a = max(yc - shape[0], 0)
                x2a = xc
                y2a = yc
                x1b = shape[1] - (x2a - x1a)
                y1b = shape[0] - (y2a - y1a)
                x2b = shape[1]
                y2b = shape[0]
            if i == 1:  # top right
                x1a = xc
                y1a = max(yc - shape[0], 0)
                x2a = min(xc + shape[1], self.input_size * 2)
                y2a = yc
                x1b = 0
                y1b = shape[0] - (y2a - y1a)
                x2b = min(shape[1], x2a - x1a)
                y2b = shape[0]
            if i == 2:  # bottom left
                x1a = max(xc - shape[1], 0)
                y1a = yc
                x2a = xc
                y2a = min(self.input_size * 2, yc + shape[0])
                x1b = shape[1] - (x2a - x1a)
                y1b = 0
                x2b = shape[1]
                y2b = min(y2a - y1a, shape[0])
            if i == 3:  # bottom right
                x1a = xc
                y1a = yc
                x2a = min(xc + shape[1], self.input_size * 2)
                y2a = min(self.input_size * 2, yc + shape[0])
                x1b = 0
                y1b = 0
                x2b = min(shape[1], x2a - x1a)
                y2b = min(y2a - y1a, shape[0])

            pad_w = x1a - x1b
            pad_h = y1a - y1b
            image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]

            # Labels
            label = self.labels[index].copy()
            if len(label):
                label[:, 1:] = xywhn2xyxy(label[:, 1:], shape[1], shape[0], pad_w, pad_h)
            label4.append(label)

        # Concat/clip labels
        label4 = numpy.concatenate(label4, 0)
        for x in label4[:, 1:]:
            numpy.clip(x, 0, 2 * self.input_size, out=x)

        # Augment
        image4, label4 = random_perspective(image4, label4, params, border)

        return image4, label4

    @staticmethod
    def collate_fn(batch):
        samples, cls, box, indices = zip(*batch)

        cls = torch.cat(cls, dim=0)
        box = torch.cat(box, dim=0)

        new_indices = list(indices)
        for i in range(len(indices)):
            new_indices[i] += i
        indices = torch.cat(new_indices, dim=0)

        targets = {'cls': cls,
                   'box': box,
                   'idx': indices}
        return torch.stack(samples, dim=0), targets

    @staticmethod
    def load_label(filenames):
        path = f'{os.path.dirname(filenames[0])}.cache'
        if os.path.exists(path):
            return torch.load(path, weights_only=False)
        x = {}
        for filename in filenames:
            try:
                # verify images
                with open(filename, 'rb') as f:
                    image = Image.open(f)
                    image.verify()  # PIL verify
                shape = image.size  # image size
                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
                assert image.format.lower() in FORMATS, f'invalid image format {image.format}'

                # verify labels
                a = f'{os.sep}images{os.sep}'
                b = f'{os.sep}labels{os.sep}'
                if os.path.isfile(b.join(filename.rsplit(a, 1)).rsplit('.', 1)[0] + '.txt'):
                    with open(b.join(filename.rsplit(a, 1)).rsplit('.', 1)[0] + '.txt') as f:
                        label = [x.split() for x in f.read().strip().splitlines() if len(x)]
                        label = numpy.array(label, dtype=numpy.float32)
                    nl = len(label)
                    if nl:
                        assert (label >= 0).all()
                        assert label.shape[1] == 5
                        assert (label[:, 1:] <= 1).all()
                        _, i = numpy.unique(label, axis=0, return_index=True)
                        if len(i) < nl:  # duplicate row check
                            label = label[i]  # remove duplicates
                    else:
                        label = numpy.zeros((0, 5), dtype=numpy.float32)
                else:
                    label = numpy.zeros((0, 5), dtype=numpy.float32)
            except FileNotFoundError:
                label = numpy.zeros((0, 5), dtype=numpy.float32)
            except AssertionError:
                continue
            x[filename] = label
        torch.save(x, path)
        return x


def xywhn2xyxy(x, w=640, h=640, pad_w=0, pad_h=0):
    # Convert nx4 boxes
    # from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = numpy.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + pad_w  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + pad_h  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + pad_w  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + pad_h  # bottom right y
    return y


def xyxy2xywhn(x, w, h):
    # warning: inplace clip
    x[:, [0, 2]] = x[:, [0, 2]].clip(0, w - 1E-3)  # x1, x2
    x[:, [1, 3]] = x[:, [1, 3]].clip(0, h - 1E-3)  # y1, y2

    # Convert nx4 boxes
    # from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
    y = numpy.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
    return y


def resample():
    choices = (cv2.INTER_AREA,
               cv2.INTER_CUBIC,
               cv2.INTER_LINEAR,
               cv2.INTER_NEAREST,
               cv2.INTER_LANCZOS4)
    return random.choice(choices)


def augment_hsv(image, params):
    # HSV color-space augmentation
    h = params['hsv_h']
    s = params['hsv_s']
    v = params['hsv_v']

    r = numpy.random.uniform(-1, 1, 3) * [h, s, v] + 1
    h, s, v = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2HSV))

    x = numpy.arange(0, 256, dtype=r.dtype)
    lut_h = ((x * r[0]) % 180).astype('uint8')
    lut_s = numpy.clip(x * r[1], 0, 255).astype('uint8')
    lut_v = numpy.clip(x * r[2], 0, 255).astype('uint8')

    hsv = cv2.merge((cv2.LUT(h, lut_h), cv2.LUT(s, lut_s), cv2.LUT(v, lut_v)))
    cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR, dst=image)  # no return needed


def resize(image, input_size, augment):
    # Resize and pad image while meeting stride-multiple constraints
    shape = image.shape[:2]  # current shape [height, width]

    # Scale ratio (new / old)
    r = min(input_size / shape[0], input_size / shape[1])
    if not augment:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    pad = int(round(shape[1] * r)), int(round(shape[0] * r))
    w = (input_size - pad[0]) / 2
    h = (input_size - pad[1]) / 2

    if shape[::-1] != pad:  # resize
        image = cv2.resize(image,
                           dsize=pad,
                           interpolation=resample() if augment else cv2.INTER_LINEAR)
    top, bottom = int(round(h - 0.1)), int(round(h + 0.1))
    left, right = int(round(w - 0.1)), int(round(w + 0.1))
    image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT)  # add border
    return image, (r, r), (w, h)


def candidates(box1, box2):
    # box1(4,n), box2(4,n)
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    aspect_ratio = numpy.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))  # aspect ratio
    return (w2 > 2) & (h2 > 2) & (w2 * h2 / (w1 * h1 + 1e-16) > 0.1) & (aspect_ratio < 100)


def random_perspective(image, label, params, border=(0, 0)):
    h = image.shape[0] + border[0] * 2
    w = image.shape[1] + border[1] * 2

    # Center
    center = numpy.eye(3)
    center[0, 2] = -image.shape[1] / 2  # x translation (pixels)
    center[1, 2] = -image.shape[0] / 2  # y translation (pixels)

    # Perspective
    perspective = numpy.eye(3)

    # Rotation and Scale
    rotate = numpy.eye(3)
    a = random.uniform(-params['degrees'], params['degrees'])
    s = random.uniform(1 - params['scale'], 1 + params['scale'])
    rotate[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    shear = numpy.eye(3)
    shear[0, 1] = math.tan(random.uniform(-params['shear'], params['shear']) * math.pi / 180)
    shear[1, 0] = math.tan(random.uniform(-params['shear'], params['shear']) * math.pi / 180)

    # Translation
    translate = numpy.eye(3)
    translate[0, 2] = random.uniform(0.5 - params['translate'], 0.5 + params['translate']) * w
    translate[1, 2] = random.uniform(0.5 - params['translate'], 0.5 + params['translate']) * h

    # Combined rotation matrix, order of operations (right to left) is IMPORTANT
    matrix = translate @ shear @ rotate @ perspective @ center
    if (border[0] != 0) or (border[1] != 0) or (matrix != numpy.eye(3)).any():  # image changed
        image = cv2.warpAffine(image, matrix[:2], dsize=(w, h), borderValue=(0, 0, 0))

    # Transform label coordinates
    n = len(label)
    if n:
        xy = numpy.ones((n * 4, 3))
        xy[:, :2] = label[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ matrix.T  # transform
        xy = xy[:, :2].reshape(n, 8)  # perspective rescale or affine

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        box = numpy.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # clip
        box[:, [0, 2]] = box[:, [0, 2]].clip(0, w)
        box[:, [1, 3]] = box[:, [1, 3]].clip(0, h)
        # filter candidates
        indices = candidates(box1=label[:, 1:5].T * s, box2=box.T)

        label = label[indices]
        label[:, 1:5] = box[indices]

    return image, label


def mix_up(image1, label1, image2, label2):
    # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
    alpha = numpy.random.beta(a=32.0, b=32.0)  # mix-up ratio, alpha=beta=32.0
    image = (image1 * alpha + image2 * (1 - alpha)).astype(numpy.uint8)
    label = numpy.concatenate((label1, label2), 0)
    return image, label


class Albumentations:
    def __init__(self):
        self.transform = None
        try:
            import albumentations

            transforms = [albumentations.Blur(p=0.01),
                          albumentations.CLAHE(p=0.01),
                          albumentations.ToGray(p=0.01),
                          albumentations.MedianBlur(p=0.01)]
            self.transform = albumentations.Compose(transforms,
                                                    albumentations.BboxParams(format='yolo', label_fields=['class_labels']))

        except ImportError:  # package not installed, skip
            pass

    def __call__(self, image, box, cls):
        if self.transform:
            x = self.transform(image=image,
                               bboxes=box,
                               class_labels=cls)
            image = x['image']
            box = numpy.array(x['bboxes'])
            cls = numpy.array(x['class_labels'])
        return image, box, cls

In [40]:
def setup_seed():
    """
    Setup random seed.
    """
    random.seed(0)
    numpy.random.seed(0)
    torch.manual_seed(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def setup_multi_processes():
    """
    Setup multi-processing environment variables.
    """
    import cv2
    from os import environ
    from platform import system

    # set multiprocess start method as `fork` to speed up the training
    if system() != 'Windows':
        torch.multiprocessing.set_start_method('fork', force=True)

    # disable opencv multithreading to avoid system being overloaded
    cv2.setNumThreads(0)

    # setup OMP threads
    if 'OMP_NUM_THREADS' not in environ:
        environ['OMP_NUM_THREADS'] = '1'

    # setup MKL threads
    if 'MKL_NUM_THREADS' not in environ:
        environ['MKL_NUM_THREADS'] = '1'


def export_onnx(args):
    import onnx  # noqa

    inputs = ['images']
    outputs = ['outputs']
    dynamic = {'outputs': {0: 'batch', 1: 'anchors'}}

    m = torch.load('./weights/best.pt')['model'].float()
    x = torch.zeros((1, 3, args.input_size, args.input_size))

    torch.onnx.export(m.cpu(), x.cpu(),
                      f='./weights/best.onnx',
                      verbose=False,
                      opset_version=12,
                      # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
                      do_constant_folding=True,
                      input_names=inputs,
                      output_names=outputs,
                      dynamic_axes=dynamic or None)

    # Checks
    model_onnx = onnx.load('./weights/best.onnx')  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    onnx.save(model_onnx, './weights/best.onnx')
    # Inference example
    # https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/autobackend.py


def xywh2xyxy(x):
    y = x.clone() if isinstance(x, torch.Tensor) else numpy.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y


def make_anchors(x, strides, offset=0.5):
    assert x is not None
    anchor_tensor, stride_tensor = [], []
    dtype, device = x[0].dtype, x[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = x[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + offset  # shift y
        sy, sx = torch.meshgrid(sy, sx)
        anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_tensor), torch.cat(stride_tensor)


def compute_metric(output, target, iou_v):
    # intersection(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    (a1, a2) = target[:, 1:].unsqueeze(1).chunk(2, 2)
    (b1, b2) = output[:, :4].unsqueeze(0).chunk(2, 2)
    intersection = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
    # IoU = intersection / (area1 + area2 - intersection)
    iou = intersection / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - intersection + 1e-7)

    correct = numpy.zeros((output.shape[0], iou_v.shape[0]))
    correct = correct.astype(bool)
    for i in range(len(iou_v)):
        # IoU > threshold and classes match
        x = torch.where((iou >= iou_v[i]) & (target[:, 0:1] == output[:, 5]))
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1),
                                 iou[x[0], x[1]][:, None]), 1).cpu().numpy()  # [label, detect, iou]
            if x[0].shape[0] > 1:
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[numpy.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[numpy.unique(matches[:, 0], return_index=True)[1]]
            correct[matches[:, 1].astype(int), i] = True
    return torch.tensor(correct, dtype=torch.bool, device=output.device)


def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.7):
    max_wh = 7680
    max_det = 300
    max_nms = 30000

    bs = outputs.shape[0]  # batch size
    nc = outputs.shape[1] - 4  # number of classes
    xc = outputs[:, 4:4 + nc].amax(1) > confidence_threshold  # candidates

    # Settings
    start = time()
    limit = 0.5 + 0.05 * bs  # seconds to quit after
    output = [torch.zeros((0, 6), device=outputs.device)] * bs
    for index, x in enumerate(outputs):  # image index, image inference
        x = x.transpose(0, -1)[xc[index]]  # confidence

        # If none remain process next image
        if not x.shape[0]:
            continue

        # matrix nx6 (box, confidence, cls)
        box, cls = x.split((4, nc), 1)
        box = xywh2xyxy(box)  # (cx, cy, w, h) to (x1, y1, x2, y2)
        if nc > 1:
            i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * max_wh  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes, scores
        indices = torchvision.ops.nms(boxes, scores, iou_threshold)  # NMS
        indices = indices[:max_det]  # limit detections

        output[index] = x[indices]
        if (time() - start) > limit:
            break  # time limit exceeded

    return output


def smooth(y, f=0.1):
    # Box filter of fraction f
    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
    p = numpy.ones(nf // 2)  # ones padding
    yp = numpy.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
    return numpy.convolve(yp, numpy.ones(nf) / nf, mode='valid')  # y-smoothed


def plot_pr_curve(px, py, ap, names, save_dir):
    from matplotlib import pyplot
    fig, ax = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = numpy.stack(py, axis=1)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}")  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)

    ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    ax.set_title("Precision-Recall Curve")
    fig.savefig(save_dir, dpi=250)
    pyplot.close(fig)


def plot_curve(px, py, names, save_dir, x_label="Confidence", y_label="Metric"):
    from matplotlib import pyplot

    figure, ax = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py):
            ax.plot(px, y, linewidth=1, label=f"{names[i]}")  # plot(confidence, metric)
    else:
        ax.plot(px, py.T, linewidth=1, color="grey")  # plot(confidence, metric)

    y = smooth(py.mean(0), f=0.05)
    ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}")
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    ax.set_title(f"{y_label}-Confidence Curve")
    figure.savefig(save_dir, dpi=250)
    pyplot.close(figure)


def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1E-16):
    """
    Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:  True positives (nparray, nx1 or nx10).
        conf:  Object-ness value from 0-1 (nparray).
        output:  Predicted object classes (nparray).
        target:  True object classes (nparray).
    # Returns
        The average precision
    """
    # Sort by object-ness
    i = numpy.argsort(-conf)
    tp, conf, output = tp[i], conf[i], output[i]

    # Find unique classes
    unique_classes, nt = numpy.unique(target, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes, number of detections

    # Create Precision-Recall curve and compute AP for each class
    p = numpy.zeros((nc, 1000))
    r = numpy.zeros((nc, 1000))
    ap = numpy.zeros((nc, tp.shape[1]))
    px, py = numpy.linspace(start=0, stop=1, num=1000), []  # for plotting
    for ci, c in enumerate(unique_classes):
        i = output == c
        nl = nt[ci]  # number of labels
        no = i.sum()  # number of outputs
        if no == 0 or nl == 0:
            continue

        # Accumulate FPs and TPs
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        # Recall
        recall = tpc / (nl + eps)  # recall curve
        # negative x, xp because xp decreases
        r[ci] = numpy.interp(-px, -conf[i], recall[:, 0], left=0)

        # Precision
        precision = tpc / (tpc + fpc)  # precision curve
        p[ci] = numpy.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

        # AP from recall-precision curve
        for j in range(tp.shape[1]):
            m_rec = numpy.concatenate(([0.0], recall[:, j], [1.0]))
            m_pre = numpy.concatenate(([1.0], precision[:, j], [0.0]))

            # Compute the precision envelope
            m_pre = numpy.flip(numpy.maximum.accumulate(numpy.flip(m_pre)))

            # Integrate area under curve
            x = numpy.linspace(start=0, stop=1, num=101)  # 101-point interp (COCO)
            ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x)  # integrate
            if plot and j == 0:
                py.append(numpy.interp(px, m_rec, m_pre))  # precision at mAP@0.5

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)
    if plot:
        names = dict(enumerate(names))  # to dict
        names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
        plot_pr_curve(px, py, ap, names, save_dir="./weights/PR_curve.png")
        plot_curve(px, f1, names, save_dir="./weights/F1_curve.png", y_label="F1")
        plot_curve(px, p, names, save_dir="./weights/P_curve.png", y_label="Precision")
        plot_curve(px, r, names, save_dir="./weights/R_curve.png", y_label="Recall")
    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
    m_pre, m_rec = p.mean(), r.mean()
    map50, mean_ap = ap50.mean(), ap.mean()
    return tp, fp, m_pre, m_rec, map50, mean_ap


def compute_iou(box1, box2, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
    ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
    c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
    rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
    # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
    v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
    with torch.no_grad():
        alpha = v / (v - iou + (1 + eps))
    return iou - (rho2 / c2 + v * alpha)  # CIoU


def strip_optimizer(filename):
    x = torch.load(filename, map_location="cpu")
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, f=filename)


def clip_gradients(model, max_norm=10.0):
    parameters = model.parameters()
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm)


def load_weight(model, ckpt):
    dst = model.state_dict()
    src = torch.load(ckpt)['model'].float().cpu()

    ckpt = {}
    for k, v in src.state_dict().items():
        if k in dst and v.shape == dst[k].shape:
            ckpt[k] = v

    model.load_state_dict(state_dict=ckpt, strict=False)
    return model


def set_params(model, decay):
    p1 = []
    p2 = []
    norm = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k)
    for m in model.modules():
        for n, p in m.named_parameters(recurse=0):
            if not p.requires_grad:
                continue
            if n == "bias":  # bias (no decay)
                p1.append(p)
            elif n == "weight" and isinstance(m, norm):  # norm-weight (no decay)
                p1.append(p)
            else:
                p2.append(p)  # weight (with decay)
    return [{'params': p1, 'weight_decay': 0.00},
            {'params': p2, 'weight_decay': decay}]


def plot_lr(args, optimizer, scheduler, num_steps):
    from matplotlib import pyplot

    optimizer = copy.copy(optimizer)
    scheduler = copy.copy(scheduler)

    y = []
    for epoch in range(args.epochs):
        for i in range(num_steps):
            step = i + num_steps * epoch
            scheduler.step(step, optimizer)
            y.append(optimizer.param_groups[0]['lr'])
    pyplot.plot(y, '.-', label='LR')
    pyplot.xlabel('step')
    pyplot.ylabel('LR')
    pyplot.grid()
    pyplot.xlim(0, args.epochs * num_steps)
    pyplot.ylim(0)
    pyplot.savefig('./weights/lr.png', dpi=200)
    pyplot.close()


class CosineLR:
    def __init__(self, args, params, num_steps):
        max_lr = params['max_lr']
        min_lr = params['min_lr']

        warmup_steps = int(max(params['warmup_epochs'] * num_steps, 100))
        decay_steps = int(args.epochs * num_steps - warmup_steps)

        warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps))

        decay_lr = []
        for step in range(1, decay_steps + 1):
            alpha = math.cos(math.pi * step / decay_steps)
            decay_lr.append(min_lr + 0.5 * (max_lr - min_lr) * (1 + alpha))

        self.total_lr = numpy.concatenate((warmup_lr, decay_lr))

    def step(self, step, optimizer):
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.total_lr[step]


class LinearLR:
    def __init__(self, args, params, num_steps):
        max_lr = params['max_lr']
        min_lr = params['min_lr']

        warmup_steps = int(max(params['warmup_epochs'] * num_steps, 100))
        decay_steps = int(args.epochs * num_steps - warmup_steps)

        warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
        decay_lr = numpy.linspace(max_lr, min_lr, decay_steps)

        self.total_lr = numpy.concatenate((warmup_lr, decay_lr))

    def step(self, step, optimizer):
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.total_lr[step]


class EMA:
    """
    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = copy.deepcopy(model).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        if hasattr(model, 'module'):
            model = model.module
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()


class AverageMeter:
    def __init__(self):
        self.num = 0
        self.sum = 0
        self.avg = 0

    def update(self, v, n):
        if not math.isnan(float(v)):
            self.num = self.num + n
            self.sum = self.sum + v * n
            self.avg = self.sum / self.num


class Assigner(torch.nn.Module):
    def __init__(self, nc=80, top_k=10, alpha=0.5, beta=6.0, eps=1E-9):
        super().__init__()
        self.top_k = top_k
        self.nc = nc
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        batch_size = pd_scores.size(0)
        num_max_boxes = gt_bboxes.size(1)

        if num_max_boxes == 0:
            device = gt_bboxes.device
            return (torch.zeros_like(pd_bboxes).to(device),
                    torch.zeros_like(pd_scores).to(device),
                    torch.zeros_like(pd_scores[..., 0]).to(device))

        num_anchors = anc_points.shape[0]
        shape = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
        mask_in_gts = torch.cat((anc_points[None] - lt, rb - anc_points[None]), dim=2)
        mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)
        na = pd_bboxes.shape[-2]
        gt_mask = (mask_in_gts * mask_gt).bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        bbox_scores[gt_mask] = pd_scores[ind[0], :, ind[1]][gt_mask]  # b, max_num_obj, h*w

        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, num_max_boxes, -1, -1)[gt_mask]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[gt_mask]
        overlaps[gt_mask] = compute_iou(gt_boxes, pd_boxes).squeeze(-1).clamp_(0)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)

        top_k_mask = mask_gt.expand(-1, -1, self.top_k).bool()
        top_k_metrics, top_k_indices = torch.topk(align_metric, self.top_k, dim=-1, largest=True)
        if top_k_mask is None:
            top_k_mask = (top_k_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(top_k_indices)
        top_k_indices.masked_fill_(~top_k_mask, 0)

        mask_top_k = torch.zeros(align_metric.shape, dtype=torch.int8, device=top_k_indices.device)
        ones = torch.ones_like(top_k_indices[:, :, :1], dtype=torch.int8, device=top_k_indices.device)
        for k in range(self.top_k):
            mask_top_k.scatter_add_(-1, top_k_indices[:, :, k:k + 1], ones)
        mask_top_k.masked_fill_(mask_top_k > 1, 0)
        mask_top_k = mask_top_k.to(align_metric.dtype)
        mask_pos = mask_top_k * mask_in_gts * mask_gt

        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, num_max_boxes, -1)
            max_overlaps_idx = overlaps.argmax(1)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
            fg_mask = mask_pos.sum(-2)
        target_gt_idx = mask_pos.argmax(-2)

        # Assigned target
        index = torch.arange(end=batch_size, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_index = target_gt_idx + index * num_max_boxes
        target_labels = gt_labels.long().flatten()[target_index]

        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_index]

        # Assigned target scores
        target_labels.clamp_(0)

        target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
                                    dtype=torch.int64,
                                    device=target_labels.device)
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        # Normalize
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_bboxes, target_scores, fg_mask.bool()


class QFL(torch.nn.Module):
    def __init__(self, beta=2.0):
        super().__init__()
        self.beta = beta
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, outputs, targets):
        bce_loss = self.bce_loss(outputs, targets)
        return torch.pow(torch.abs(targets - outputs.sigmoid()), self.beta) * bce_loss


class VFL(torch.nn.Module):
    def __init__(self, alpha=0.75, gamma=2.00, iou_weighted=True):
        super().__init__()
        assert alpha >= 0.0
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, outputs, targets):
        assert outputs.size() == targets.size()
        targets = targets.type_as(outputs)

        if self.iou_weighted:
            focal_weight = targets * (targets > 0.0).float() + \
                           self.alpha * (outputs.sigmoid() - targets).abs().pow(self.gamma) * \
                           (targets <= 0.0).float()

        else:
            focal_weight = (targets > 0.0).float() + \
                           self.alpha * (outputs.sigmoid() - targets).abs().pow(self.gamma) * \
                           (targets <= 0.0).float()

        return self.bce_loss(outputs, targets) * focal_weight


class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=1.5):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, outputs, targets):
        loss = self.bce_loss(outputs, targets)

        if self.alpha > 0:
            alpha_factor = targets * self.alpha + (1 - targets) * (1 - self.alpha)
            loss *= alpha_factor

        if self.gamma > 0:
            outputs_sigmoid = outputs.sigmoid()
            p_t = targets * outputs_sigmoid + (1 - targets) * (1 - outputs_sigmoid)
            gamma_factor = (1.0 - p_t) ** self.gamma
            loss *= gamma_factor

        return loss


class BoxLoss(torch.nn.Module):
    def __init__(self, dfl_ch):
        super().__init__()
        self.dfl_ch = dfl_ch

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        # IoU loss
        weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
        iou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_box = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        a, b = target_bboxes.chunk(2, -1)
        target = torch.cat((anchor_points - a, b - anchor_points), -1)
        target = target.clamp(0, self.dfl_ch - 0.01)
        loss_dfl = self.df_loss(pred_dist[fg_mask].view(-1, self.dfl_ch + 1), target[fg_mask])
        loss_dfl = (loss_dfl * weight).sum() / target_scores_sum

        return loss_box, loss_dfl

    @staticmethod
    def df_loss(pred_dist, target):
        # Distribution Focal Loss (DFL)
        # https://ieeexplore.ieee.org/document/9792391
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        left_loss = cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape)
        right_loss = cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape)
        return (left_loss * wl + right_loss * wr).mean(-1, keepdim=True)


class ComputeLoss:
    def __init__(self, model, params):
        if hasattr(model, 'module'):
            model = model.module

        device = next(model.parameters()).device

        m = model[22]  # Head() module

        self.params = params
        self.stride = m.stride
        self.nc = m.nc
        self.no = m.no
        self.reg_max = m.ch
        self.device = device

        self.assigner = Assigner(self.nc)
        self.box_loss = BoxLoss(m.ch - 1).to(device)
        self.cls_loss = torch.nn.BCEWithLogitsLoss(reduction='none')

        self.project = torch.arange(m.ch, dtype=torch.float, device=device)

    def box_decode(self, anchor_points, pred_dist):
        b, a, c = pred_dist.shape
        pred_dist = pred_dist.view(b, a, 4, c // 4)
        pred_dist = pred_dist.softmax(3)
        pred_dist = pred_dist.matmul(self.project.type(pred_dist.dtype))
        lt, rb = pred_dist.chunk(2, -1)
        x1y1 = anchor_points - lt
        x2y2 = anchor_points + rb
        return torch.cat(tensors=(x1y1, x2y2), dim=-1)

    def __call__(self, outputs, targets):
        x = torch.cat([i.view(outputs[0].shape[0], self.no, -1) for i in outputs], dim=2)
        boxes, scores = x.split(split_size=(self.reg_max * 4, self.nc), dim=1)

        boxes = boxes.permute(0, 2, 1).contiguous()
        scores = scores.permute(0, 2, 1).contiguous()

        data_type = scores.dtype
        batch_size = scores.shape[0]
        input_size = torch.tensor(outputs[0].shape[2:], device=self.device, dtype=data_type) * self.stride[0]
        anchor_points, stride_tensor = make_anchors(outputs, self.stride, offset=0.5)

        idx = targets['idx'].view(-1, 1)
        cls = targets['cls'].view(-1, 1)
        box = targets['box']

        targets = torch.cat(tensors=(idx, cls, box), dim=1).to(self.device)
        if targets.shape[0] == 0:
            gt = torch.zeros(batch_size, 0, 5, device=self.device)
        else:
            i = targets[:, 0]
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            gt = torch.zeros(batch_size, counts.max(), 5, device=self.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    gt[j, :n] = targets[matches, 1:]
            x = gt[..., 1:5].mul_(input_size[[1, 0, 1, 0]])
            y = torch.empty_like(x)
            dw = x[..., 2] / 2  # half-width
            dh = x[..., 3] / 2  # half-height
            y[..., 0] = x[..., 0] - dw  # top left x
            y[..., 1] = x[..., 1] - dh  # top left y
            y[..., 2] = x[..., 0] + dw  # bottom right x
            y[..., 3] = x[..., 1] + dh  # bottom right y
            gt[..., 1:5] = y

        target_labels, target_bboxes = gt.split(split_size=(1, 4), dim=2)
        target_mask = target_bboxes.sum(2, keepdim=True).gt_(0)

        decoded_boxes = self.box_decode(anchor_points, boxes)
        assigned_targets = self.assigner(scores.detach().sigmoid(),
                                         (decoded_boxes.detach() * stride_tensor).type(target_bboxes.dtype),
                                         anchor_points * stride_tensor, target_labels, target_bboxes, target_mask)
        target_bboxes, target_scores, fg_mask = assigned_targets

        target_scores_sum = max(target_scores.sum(), 1)

        loss_cls = self.cls_loss(scores, target_scores.to(data_type)).sum() / target_scores_sum  # BCE

        # Box loss
        loss_box = torch.zeros(1, device=self.device)
        loss_dfl = torch.zeros(1, device=self.device)
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss_box, loss_dfl = self.box_loss(boxes,
                                               decoded_boxes,
                                               anchor_points,
                                               target_bboxes,
                                               target_scores,
                                               target_scores_sum, fg_mask)

        loss_box *= self.params['box']  # box gain
        loss_cls *= self.params['cls']  # cls gain
        loss_dfl *= self.params['dfl']  # dfl gain

        return loss_box, loss_cls, loss_dfl

In [41]:
class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
        self.act = nn.SiLU(inplace=True)
    
    def forward(self, x):
        return self.act(self.bn(self.conv(x)))
    
class Bottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, shortcut=True):
        super().__init__()
        self.cv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.cv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.add = shortcut

    def forward(self, x):
        x_in = x
        x = self.cv1(x)
        x = self.cv2(x)
        if self.add:
            x += x_in
        return x
        

class C2f(nn.Module):
    def __init__(self, in_channels, out_channels, num_bottlenecks, shortcut=True):
        super().__init__()
        self.mid_channels = out_channels // 2
        self.num_bottlenecks = num_bottlenecks
        self.cv1 = Conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.cv2 = Conv((num_bottlenecks+2)*out_channels//2, out_channels, kernel_size=1, stride=1, padding=0)
        self.m = nn.ModuleList([Bottleneck(self.mid_channels, self.mid_channels, shortcut) for _ in range(num_bottlenecks)]) # n bottlenecks
        self.add = shortcut
    
    def forward(self, x):
        x = self.cv1(x)
        x1, x2 = x[:, :x.shape[1]//2, :, :], x[:, x.shape[1]//2:, :, :]
        outputs = [x1, x2] # x1 is fed to the bottlenecks

        for i in range(self.num_bottlenecks):
            x1 = self.m[i](x1)
            outputs.insert(0, x1)
        
        outputs = torch.cat(outputs, dim=1)
        out = self.cv2(outputs)
        return out

    
class SPPF(nn.Module): # EXPLORE WHY!!!!
    def __init__(self, in_channels, out_channels, kernel_size=5): #kernel_size = size of maxpool
        super().__init__()
        hidden_channels = in_channels // 2
        self.cv1 = Conv(in_channels, hidden_channels, kernel_size=1, stride=1, padding=0) # WHY???
        self.pool = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, ceil_mode=False) # WHY???
        self.cv2 = Conv(4*hidden_channels, out_channels, kernel_size=1, stride=1, padding=0)
    
    def forward(self, x):
        x = self.cv1(x)

        y1 = self.pool(x)
        y2 = self.pool(y1)
        y3 = self.pool(y2)

        y = torch.cat([x,y1,y2,y3], dim=1)
        
        y = self.cv2(y)
        return y

    
def yolo_params(version): # return d,w,r
    if version == 'n':
        return (1/3,1/4,2.0)
    elif version == 's':
        return (1/3,1/2,2.0)
    elif version == 'm':
        return (2/3,3/4,1.5)
    elif version == 'l':
        return (1.0,1.0,1.0)
    elif version == 'x':
        return (1.0, 1.25, 1.0)
# [b, c, h, w]
# P3: [B, 256, 80, 80] → stride 8
# P4: [B, 512, 40, 40] → stride 16
# P5: [B, 1024, 20, 20] → stride 32


# DFL (“Distribution Focal Loss”) stores probabilities over reg_max(ch) bins for each side (l,t,r,b).
# To decode those distributions into distances, you take the expected value over bins:

# softmax over bins → probabilities

# dot product with [0,1,2,…,reg_max−1] → expected bin index

# (later you multiply by stride to get pixels)

class Concat(nn.Module):
    def __init__(self, dim=1): 
        super().__init__()
        self.dim = dim
    def forward(self, xs):       # xs is a tuple/list of tensors
        return torch.cat(xs, self.dim)



class DFL(nn.Module):
    def __init__(self, ch=16):
        super().__init__()
        self.ch = ch
        self.conv = nn.Conv2d(in_channels=ch, out_channels=1, kernel_size=1, stride=1, padding=0, bias=False).requires_grad_(False)

        x = torch.arange(self.ch, dtype=torch.float).view(1, self.ch, 1, 1)
        self.conv.weight.data.copy_(x)
    
    def forward(self, x): # x = [B, C_in, c]
        b, c, a = x.shape # b = B  c = C_in = 4*ch  a = c
        x = x.view(b, 4, self.ch, a).transpose(1, 2)  # [B, ch(values), 4, c]

        x = x.softmax(1)  # [B, ch(softmax values), 4, c]
        x = self.conv(x)  # [B, 1, 4, c]
        return x.view(b, 4, a)  # [B, 4, c] so it returns the l,t,r,b values(in bin) for every batch (we don't need the out channel of conv)


class Detect(nn.Module):
    def __init__(self, version, ch=16, nc=4):
        super().__init__()
        self.ch=ch                          # dfl channels
        self.coordinates=self.ch*4          # number of bounding boxes coordinates
        self.nc=nc                 # 4 for our dataset
        self.no=self.coordinates+self.nc    # num of outputs per anchor box
        self.stride=torch.zeros(0)          # strides computed during build
        d,w,r = yolo_params(version=version)

        self.cv2=nn.ModuleList([
            # for box
            nn.Sequential(Conv(int(256*w), self.coordinates, kernel_size=3, stride=1, padding=1),
                          Conv(self.coordinates, self.coordinates, kernel_size=3, stride=1, padding=1),
                          nn.Conv2d(self.coordinates, self.coordinates, kernel_size=1, stride=1, padding=0)),
            
            nn.Sequential(Conv(int(512*w), self.coordinates, kernel_size=3, stride=1, padding=1),
                          Conv(self.coordinates, self.coordinates, kernel_size=3, stride=1, padding=1),
                          nn.Conv2d(self.coordinates, self.coordinates, kernel_size=1, stride=1, padding=0)),
            
            nn.Sequential(Conv(int(512*w*r), self.coordinates, kernel_size=3, stride=1, padding=1),
                          Conv(self.coordinates, self.coordinates, kernel_size=3, stride=1, padding=1),
                          nn.Conv2d(self.coordinates, self.coordinates, kernel_size=1, stride=1, padding=0)),
        ])

        # for classification
        self.cv3=nn.ModuleList([
            nn.Sequential(Conv(int(256*w), self.nc, kernel_size=3, stride=1, padding=1),
                          Conv(self.nc, self.nc, kernel_size=3, stride=1, padding=1),
                          nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1, padding=0)),
            
            nn.Sequential(Conv(int(512*w), self.nc, kernel_size=3, stride=1, padding=1),
                          Conv(self.nc, self.nc, kernel_size=3, stride=1, padding=1),
                          nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1, padding=0)),
            
            nn.Sequential(Conv(int(512*w*r), self.nc, kernel_size=3, stride=1, padding=1),
                          Conv(self.nc, self.nc, kernel_size=3, stride=1, padding=1),
                          nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1, padding=0)),
        ])

        # dfl
        self.dfl = DFL()

    def forward(self, x): # x = (out1,out2,out3), outx = [B, chx, wx, hx]
        for i in range(len(self.cv2)):
            box = self.cv2[i](x[i])     # [b, num_coordinates, w, h]
            cls = self.cv3[i](x[i])     # [b, num_classes, w, h]
            x[i] = torch.cat((box, cls), dim=1) # [b, num_coordinates+num_classes, w, h] 

        # in training no dfl output
        if self.training:
            return x    # [3,b,num_coordinates+num_classes,w,h]
        
        # in inference, dfl produces refined bounding box coordinates
        anchors, strides = (i.transpose(0, 1) for i in self.make_anchors(x, self.stride))

        # concatinate predictions from all detection layers
        x = torch.cat([i.view(x[0].shape[0], self.no, -1) for i in x], dim=2)  # [b, 4*self.ch + self.nc, sum_i(h[i]w[i])]

        # split out predictions for box and cls
        #       box=[b,4*self.ch,sum_i(h[i]w[i])]
        #       cls=[b,self.nc,sum_i(h[i]w[i])]
        box, cls = x.split(split_size=(4*self.ch, self.nc), dim=1)

        a, b = self.dfl(box).chunk(2, 1)    # a=b=[b,2*self.ch,sum_i(h[i]w[i])]
        a = anchors.unsqueeze(0) - a
        b = anchors.unsqueeze(0) + b
        box = torch.cat(tensors=((a + b) / 2, b - a), dim=1)

        return torch.cat(tensors=(box * strides, cls.sigmoid()), dim=1)
    
    def make_anchors(self, x, strides, offset=0.5):
        # x = list of feature maps: x = [x[0],...,x[N-1]], N=num_detection_heads=3
        # each having shape [b, ch, w, h]
        # each feature map x[i] gives output output[i] = w*h anchor coordinates + w*h stride values
        # strides = coefficient of how much feature map is reduced compared to the original image
        assert x is not None
        anchor_tensor, stride_tensor = [], []
        dtype, device = x[0].dtype, x[0].device
        for i, stride in enumerate(strides):
            _, _, h, w = x[i].shape
            sx = torch.arange(end=w, device=device, dtype=dtype) + offset 
            sy = torch.arange(end=h, device=device, dtype=dtype) + offset
            sy, sx = torch.meshgrid(sy, sx)
            anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2))
            stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
        return torch.cat(anchor_tensor), torch.cat(stride_tensor)
    


class YOLO(nn.Module):
    def __init__(self, task=None, verbose=False, version='n', in_channels=3):
        super().__init__()
        d, w, r = yolo_params(version)
        # self.stride=torch.zeros(0)
        self.predictor = None  # reuse predictor
        self.model = None  # model object
        self.trainer = None  # trainer object
        self.ckpt = {}  # if loaded from *.pt
        self.cfg = None  # if loaded from *.yaml
        self.ckpt_path = None
        self.overrides = {}  # overrides for trainer object
        self.metrics = None  # validation/training metrics
        self.session = None  # HUB session
        self.task = task  # task type
        self.model_name = None  # model name
        self.model = nn.ModuleList([
            # backbone
            Conv(in_channels, int(64*w), kernel_size=3, stride=2, padding=1),              #0
            Conv(int(64*w), int(128*w), kernel_size=3, stride=2, padding=1),               #1
            C2f(int(128*w), int(128*w), num_bottlenecks=int(3*d), shortcut=True),          #2
            Conv(int(128*w), int(256*w), kernel_size=3, stride=2, padding=1),              #3
            C2f(int(256*w), int(256*w), num_bottlenecks=int(6*d), shortcut=True),          #4
            Conv(int(256*w), int(512*w), kernel_size=3, stride=2, padding=1),              #5
            C2f(int(512*w), int(512*w), num_bottlenecks=int(6*d), shortcut=True),          #6
            Conv(int(512*w), int(512*w*r), kernel_size=3, stride=2, padding=1),            #7
            C2f(int(512*w*r), int(512*w*r), num_bottlenecks=int(3*d), shortcut=True),      #8
            SPPF(int(512*w*r), int(512*w*r)),                                              #9

            # neck
            nn.Upsample(scale_factor=2, mode='nearest'),                                   #10
            Concat(),                                                                      #11
            C2f(int(512*w*(1+r)), int(512*w), num_bottlenecks=int(3*d), shortcut=False),   #12
            nn.Upsample(scale_factor=2, mode='nearest'),                                   #13
            Concat(),                                                                      #14
            C2f(int(768*w), int(256*w), num_bottlenecks=int(3*d), shortcut=False),         #15
            Conv(int(256*w), int(256*w), kernel_size=3, stride=2, padding=1),              #16
            Concat(),                                                                      #17
            C2f(int(768*w), int(512*w), num_bottlenecks=int(3*d), shortcut=False),         #18
            Conv(int(512*w), int(512*w), kernel_size=3, stride=2, padding=1),              #19
            Concat(),                                                                      #20
            C2f(int(512*w*(1+r)), int(512*w*r), num_bottlenecks=int(3*d), shortcut=False), #21

            # head
            Detect(version=version),                                                       #22
        ])
        # Delete super().training for accessing self.model.training
        del self.training

    def forward(self, x):
        # backbone forward
        x = self.model[0](x)
        x = self.model[1](x)
        x = self.model[2](x)
        x = self.model[3](x)
        out1 = self.model[4](x) # for concat
        x = self.model[5](out1)
        out2 = self.model[6](x) # for concat
        x = self.model[7](out2)
        x = self.model[8](x)
        out3 = self.model[9](x)

        # neck forward
        res_1 = out3 # for residual connection
        x = self.model[10](out3)
        x = self.model[11]((x, out2))
        res_2 = self.model[12](x) # for concat
        x = self.model[13](res_2)
        x = self.model[14]((x, out1))
        x1 = self.model[15](x) # for detect
        x = self.model[16](x1)
        x = self.model[17]((x, res_2))
        x2 = self.model[18](x) # for detect
        x = self.model[19](x2)
        x = self.model[20]((x, res_1))
        x3 = self.model[21](x) # for detect

        return self.model[22]([x1,x2,x3])

    def train(self, mode: bool | None = None, **ultra_kwargs):
        # 1) Pure PyTorch toggle if mode is given and no kwargs
        if (mode is not None) and (not ultra_kwargs):
            return super().train(mode)

        # 2) Ultralytics-style API when kwargs are provided
        if ultra_kwargs:
            return self._ultra_train(**ultra_kwargs)

        # 3) Default: behave like model.train(True)
        return super().train(True)

    # ---- the high-level trainer ----
    def _ultra_train(
        self,
        data: str,
        epochs: int = 100,
        imgsz: int = 640,
        batch: int = 16,
        name: str = "exp",
        project: str = "runs/train",
        device: int | str = 0,
        patience: int = 50,
        cos_lr: bool = True,
        lr: float = 5e-4,
        weight_decay: float = 5e-4,
        workers: int = 8,
        amp: bool = True,
        grad_clip: float | None = 10.0,
    ):
        # device
        if isinstance(device, (int, str)) and str(device).isdigit() and torch.cuda.is_available():
            device = torch.device(f"cuda:{device}")
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)

        # I/O
        run_dir = os.path.join(project, name)
        os.makedirs(run_dir, exist_ok=True)
        best_path, last_path = os.path.join(run_dir, "best.pt"), os.path.join(run_dir, "last.pt")

        # data.yaml (expects keys: train, val, names/nc)
        with open(data, "r") as f:
            cfg = yaml.safe_load(f)
        nc = cfg.get("nc", len(cfg["names"]))
        names = cfg.get("names")
        

        # build loaders (plug in YOUR dataset + collate_fn)
        train_ds = YoloDataset(Path(cfg["train"]), imgsz=imgsz, names=names)           # <-- your class
        val_path = cfg.get("val")
        val_ds   = YoloDataset(Path(val_path), imgsz=imgsz, names=names) if val_path else None

        train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True,
                                  num_workers=workers, pin_memory=True,
                                  collate_fn=train_ds.collate_fn)
        val_loader = None
        if val_ds:
            val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False,
                                    num_workers=max(1, workers//2), pin_memory=True,
                                    collate_fn=val_ds.collate_fn, drop_last=False)

        def _gpu_mem_str(device):
            if device.type == "cuda":
                try:
                    m = torch.cuda.memory_reserved(device.index if device.index is not None else 0)
                except Exception:
                    m = torch.cuda.max_memory_allocated()
                return f"{m / (1024**3):>7.2f}G"
            return f"{0.0:>7.2f}G"
        
        class _EMAval:
            def __init__(self, beta=0.9): self.b=beta; self.v=None
            def upd(self,x): x=float(x); self.v=x if self.v is None else self.b*self.v+(1-self.b)*x; return self.v
        
        def _epoch_header():
            print(f"{'Epoch':>10} {'GPU_mem':>9} {'box_loss':>9} {'cls_loss':>9} {'dfl_loss':>9} {'Instances':>10} {'Size':>10}")

        # loss, opt, sched
        criterion = ComputeLoss(self.model, {"nc": nc, "names": names})          # <-- your loss
        optimizer = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = (torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
                     if cos_lr else torch.optim.lr_scheduler.MultiStepLR(optimizer, [int(0.8*epochs)], gamma=0.1))
        scaler = torch.amp.GradScaler("cuda", enabled=(amp and device.type == "cuda"))

        # loop with early stopping on val loss
        best_val = float("inf"); bad_epochs = 0
        for epoch in range(epochs):
            super().train(True)  # PyTorch training mode
            _epoch_header()
            eb, ec, ed = _EMAval(), _EMAval(), _EMAval()
            epoch_loss = 0.0

            pbar = tqdm(enumerate(train_loader), total=len(train_loader), leave=True, ncols=120,
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]")
            
            for i, (imgs, targets) in pbar:
                imgs = imgs.to(device, non_blocking=True).float() / 255.0
                instances = int(targets["cls"].numel())

                optimizer.zero_grad(set_to_none=True)
                with torch.autocast("cuda", dtype=torch.float16, enabled=(amp and device.type == "cuda")):
                    outputs = self(imgs)  # training path → 3 feature maps
                    box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
                    loss = box_loss + cls_loss + dfl_loss

                scaler.scale(loss).backward()
                if grad_clip is not None:
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(self.parameters(), grad_clip)
                scaler.step(optimizer)
                scaler.update()
                
                epoch_loss += loss.item()
                # running EMAs for smooth prints
                b, c, d = eb.upd(box_loss.item()), ec.upd(cls_loss.item()), ed.upd(dfl_loss.item())
                desc = f"{epoch+1:>7}/{epochs:<3} {_gpu_mem_str(device)} {b:>9.3f} {c:>9.3f} {d:>9.3f} {instances:>10d} {imgsz:>10d}:"
                pbar.set_description_str(desc)
    
            
            # for imgs, targets in train_loader:
            #     imgs = imgs.to(device, non_blocking=True).float() / 255.0
            #     optimizer.zero_grad(set_to_none=True)
            #     with torch.autocast("cuda", dtype=torch.float16, enabled=(amp and device.type == "cuda")):
            #         outputs = self(imgs)  # your forward: returns list[3] in training mode
            #         box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
            #         loss = box_loss + cls_loss + dfl_loss

            #     scaler.scale(loss).backward()
            #     if grad_clip is not None:
            #         scaler.unscale_(optimizer)
            #         nn.utils.clip_grad_norm_(self.parameters(), grad_clip)
            #     scaler.step(optimizer)
            #     scaler.update()
            #     epoch_loss += loss.item()

            scheduler.step()
            train_loss = epoch_loss / max(1, len(train_loader))

            # validation
            val_loss = train_loss
            if val_loader is not None:
                super().eval()
                # header above the val bar
                print((" " * 17) + "Class     Images  Instances      Box(P          R      mAP50  mAP50-95):", end=" ")
                
                vbar = tqdm(enumerate(val_loader), total=len(val_loader), leave=True, ncols=120,
                    bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]")

                iouv = torch.linspace(0.5, 0.95, 10, device=device)
                tp_list, conf_list, pcls_list, tcls_list = [], [], [], []
                total = 0.0

                with torch.no_grad():
                    for _, (imgs, targets) in vbar:
                        imgs = imgs.to(device, non_blocking=True).float() / 255.0
                        outputs = self(imgs)  # inference path → [B, 4+nc, N]
                        vb, vc, vd = criterion(outputs, targets)
                        total += (vb + vc + vd).item()
        
                        # NMS → list of [ni,6] (x1,y1,x2,y2,conf,cls)
                        preds = non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.7)
        
                        B, _, H, W = imgs.shape
                        for b in range(B):
                            p = preds[b]
                            m = (targets["idx"] == b)
                            if m.any():
                                gt_cls = targets["cls"][m].to(device)
                                gt_box = xywhn2xyxy(targets["box"][m].to(device), W, H)
                                t = torch.zeros((gt_cls.numel(), 5), device=device)
                                t[:, 0] = gt_cls; t[:, 1:] = gt_box
                            else:
                                t = torch.zeros((0, 5), device=device)
        
                            if p.numel() == 0:
                                if t.numel():
                                    tcls_list.append(t[:, 0].cpu().numpy())
                                continue
        
                            correct = compute_metric(p, t, iouv)  # [n_det, 10] bool
                            tp_list.append(correct.cpu().numpy().astype(int))
                            conf_list.append(p[:, 4].cpu().numpy())
                            pcls_list.append(p[:, 5].cpu().numpy())
                            if t.numel():
                                tcls_list.append(t[:, 0].cpu().numpy())
        
                val_loss = total / max(1, len(val_loader))
        
                # reduce metrics across dataset
                if conf_list:
                    tp   = np.concatenate(tp_list, 0)
                    conf = np.concatenate(conf_list, 0)
                    pcls = np.concatenate(pcls_list, 0).astype(int)
                    tcls = np.concatenate(tcls_list, 0).astype(int) if tcls_list else np.zeros((0,), dtype=int)
                    _, _, P, R, mAP50, mAP5095 = compute_ap(tp, conf, pcls, tcls)
                else:
                    P = R = mAP50 = mAP5095 = float("nan")
                    tcls = np.zeros((0,), dtype=int)
        
                val_images = len(val_loader.dataset)
                val_instances = int(tcls.shape[0])
                print(f"\n{'':>19}{'all':>7}{val_images:>11}{val_instances:>12}"
                      f"{P:>11.3f}{R:>11.3f}{mAP50:>11.3f}{mAP5095:>11.3f}\n")
        
            # ------------------- Save & early stop -------------------
            torch.save(self.state_dict(), last_path)
            improved = val_loss < best_val - 1e-6
            if improved:
                best_val, bad_epochs = val_loss, 0
                torch.save(self.state_dict(), best_path)
            else:
                bad_epochs += 1
        
            if bad_epochs >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
        return self
        # return {"best_val_loss": best_val, "best": best_path, "last": last_path}

In [34]:
from pathlib import Path

# yaml_text = """\
# min_lr: 0.000100000000            # initial learning rate
# max_lr: 0.010000000000            # maximum learning rate
# momentum: 0.9370000000            # SGD momentum/Adam beta1
# weight_decay: 0.000500            # optimizer weight decay
# warmup_epochs: 3.00000            # warmup epochs
# box: 7.500000000000000            # box loss gain
# cls: 0.500000000000000            # cls loss gain
# dfl: 1.500000000000000            # dfl loss gain
# hsv_h: 0.0150000000000            # image HSV-Hue augmentation (fraction)
# hsv_s: 0.7000000000000            # image HSV-Saturation augmentation (fraction)
# hsv_v: 0.4000000000000            # image HSV-Value augmentation (fraction)
# degrees: 0.00000000000            # image rotation (+/- deg)
# translate: 0.100000000            # image translation (+/- fraction)
# scale: 0.5000000000000            # image scale (+/- gain)
# shear: 0.0000000000000            # image shear (+/- deg)
# flip_ud: 0.00000000000            # image flip up-down (probability)
# flip_lr: 0.50000000000            # image flip left-right (probability)
# mosaic: 1.000000000000            # image mosaic (probability)
# mix_up: 0.000000000000            # image mix-up (probability)
# names:
#   0: bacterial
#   1: fungal
#   2: pest
#   3: physio
# """

# path = Path("/kaggle/working/args.yaml")  # change if not on Kaggle
# path.write_text(yaml_text)
# print("Saved to:", path)


In [35]:
!cp -r /kaggle/input/effyolo /kaggle/working/

In [36]:
# from glob import glob
# from torch.utils.data import DataLoader

# ROOT = Path("/kaggle/working/effyolo")  # or local path to apple_leaf
# IMG_TRAIN = ROOT / "train/images"
# IMG_VALID = ROOT / "valid/images"
# IMG_TEST  = ROOT / "test/images"

# def collect_images(folder):
#     return sorted(str(p) for p in folder.rglob("*") if p.is_file())

# files_train = collect_images(IMG_TRAIN)
# files_valid = collect_images(IMG_VALID)
# files_test  = collect_images(IMG_TEST)

# print(len(files_train), "train |", len(files_valid), "valid |", len(files_test), "test")


# # input_size for the model
# input_size=640

# # get params from yaml file
# with open('/kaggle/working/args.yaml', errors='ignore') as f:
#         params = yaml.safe_load(f)

# train_data=Dataset(files_train,input_size,params,augment=False)
# valid_data=Dataset(files_valid,input_size,params,augment=False)
# test_data=Dataset(files_test,input_size,params,augment=False)

# train_loader = DataLoader(train_data, batch_size=8, shuffle=True, num_workers=2, pin_memory=True, collate_fn=Dataset.collate_fn)
# valid_loader = DataLoader(valid_data, batch_size=8, shuffle=False, num_workers=2, pin_memory=True, collate_fn=Dataset.collate_fn)
# test_loader  = DataLoader(test_data, batch_size=8, shuffle=False, num_workers=2, pin_memory=True, collate_fn=Dataset.collate_fn)

# print(f"Train_loader : {len(train_loader)} batches")
# print(f"Train_loader : {len(valid_loader)} batches")
# print(f"Train_loader : {len(test_loader)} batches")

In [42]:
from collections import OrderedDict
from ultralytics import YOLO as UModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
u = UModel('yolov8n.pt')          # COCO-pretrained
model = YOLO().to(device)
model.model[22].stride = torch.tensor([8., 16., 32.], device=device)

tgt_sd = model.state_dict()
sd_src = u.model.state_dict()     # plain PyTorch state_dict
sd_shape_match = OrderedDict((k, v) for k, v in sd_src.items()
                             if k in tgt_sd and tgt_sd[k].shape == v.shape)

missing, unexpected = model.load_state_dict(sd_shape_match, strict=False)
print(len(missing), len(unexpected))

36 0


In [43]:
results = model.train(
    data='/kaggle/working/effyolo/data.yaml',
    epochs=200,
    imgsz=640,
    batch=8,
    name='fruit-disease-detector',
    project='/kaggle/working/runs/train',
    device='cuda',
    patience=50,
    cos_lr=True
)

IndexError: list index out of range

In [17]:
torch.manual_seed(1337)

# model, loss and optimizer
criterion=ComputeLoss(model.model, params)
optimizer=torch.optim.AdamW(model.model.parameters(), lr=0.0005, weight_decay=5e-4)

num_epochs=150
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))
autocast = torch.autocast('cuda', dtype=torch.float16, enabled=(device.type == 'cuda'))

model.train()
global_step = 0
for epoch in range(num_epochs):
    for imgs, targets in train_loader:          # <-- iterate the loader, not a single batch
        imgs = imgs.to(device, non_blocking=True).float() / 255.0
        # targets is a dict of tensors from your collate_fn; ComputeLoss moves them to device

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
            outputs = model(imgs)               # training mode → Head returns 3 feature maps
            box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
            loss = box_loss + cls_loss + dfl_loss

        scaler.scale(loss).backward()
        # clip_gradients(model, max_norm=10.0)  # optional
        scaler.step(optimizer)
        scaler.update()

        # if using your scheduler that steps per-iteration:
        # scheduler.step(global_step, optimizer)
        global_step += 1

    print(f"Epoch {epoch+1}/{num_epochs} "
          f"| loss: {loss.item():.4f} "
          f"| box: {box_loss.item():.4f} "
          f"| cls: {cls_loss.item():.4f} "
          f"| dfl: {dfl_loss.item():.4f}")

  with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Epoch 1/150 | loss: 1093.9884 | box: 2.2886 | cls: 1088.9510 | dfl: 2.7488


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), "/kaggle/working/yolo_weights.pt")