In [16]:
# Import các thư viện cần thiết

# Standard library imports
import os
import csv
import copy
import math
import random
import shutil
import time
from os import environ
from platform import system

# Third-party imports
import cv2
import numpy as np
import yaml
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset as TorchDataset, DataLoader, Subset
from torchvision.ops import box_iou, nms

# Optional imports
try:
    import onnx
except ImportError:
    onnx = None

try:
    import albumentations
except ImportError:
    albumentations = None

try:
    from roboflow import Roboflow
except ImportError:
    os.system("pip install roboflow")
    from roboflow import Roboflow

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [17]:
try:
    DRIVE_SAVE_PATH = "/kaggle/working/"
    os.makedirs(DRIVE_SAVE_PATH, exist_ok=True)

    SAVE_PATH = os.path.join(DRIVE_SAVE_PATH, "custom_yolo_model.pth")
    DATASET_PATH = "/kaggle/input/wild-animals-detection-yolov8"  # Fixed path

    CHECKPOINT_DIR = os.path.join(DRIVE_SAVE_PATH, "checkpoints")
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

except:
    SAVE_PATH = "./custom_yolo_model.pth"
    DATASET_PATH = "./roboflow_dataset"
    CHECKPOINT_DIR = "./checkpoints"
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

RESUME_TRAINING = True
SAVE_CHECKPOINT_EVERY = 10

# KIẾN TRÚC

In [18]:
class Conv(nn.Module):
    def __init__(self,in_channels, out_channels,kernel_size=3,stride=1,padding=1,groups=1,activation=True):
        super().__init__()
        self.conv=nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False,groups=groups)
        self.bn=nn.BatchNorm2d(out_channels,eps=0.001,momentum=0.03)
        self.act=nn.SiLU(inplace=True) if activation else nn.Identity()

    def forward(self,x):
        return self.act(self.bn(self.conv(x)))

In [19]:
# 2.1 Bottleneck: staack of 2 COnv with shortcut connnection (True/False)
class Bottleneck(nn.Module):
    def __init__(self,in_channels,out_channels,shortcut=True):
        super().__init__()
        self.conv1=Conv(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
        self.conv2=Conv(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
        self.shortcut=shortcut

    def forward(self,x):
        x_in=x # for residual connection
        x=self.conv1(x)
        x=self.conv2(x)
        if self.shortcut:
            x=x+x_in
        return x
    
# 2.2 C2f: Conv + bottleneck*N+ Conv
class C2f(nn.Module):
    def __init__(self,in_channels,out_channels, num_bottlenecks,shortcut=True):
        super().__init__()
        
        self.mid_channels=out_channels//2
        self.num_bottlenecks=num_bottlenecks

        self.conv1=Conv(in_channels,out_channels,kernel_size=1,stride=1,padding=0)
        
        # sequence of bottleneck layers
        self.m=nn.ModuleList([Bottleneck(self.mid_channels,self.mid_channels) for _ in range(num_bottlenecks)])

        self.conv2=Conv((num_bottlenecks+2)*out_channels//2,out_channels,kernel_size=1,stride=1,padding=0)
    
    def forward(self,x):
        x=self.conv1(x)

        # split x along channel dimension
        x1,x2=x[:,:x.shape[1]//2,:,:], x[:,x.shape[1]//2:,:,:]
        
        # list of outputs
        outputs=[x1,x2] # x1 is fed through the bottlenecks

        for i in range(self.num_bottlenecks):
            x1=self.m[i](x1)    # [bs,0.5c_out,w,h]
            outputs.insert(0,x1)

        outputs=torch.cat(outputs,dim=1) # [bs,0.5c_out(num_bottlenecks+2),w,h]
        out=self.conv2(outputs)

        return out

In [20]:
class SPPF(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=5):
        #kernel_size= size of maxpool
        super().__init__()
        hidden_channels=in_channels//2
        self.conv1=Conv(in_channels,hidden_channels,kernel_size=1,stride=1,padding=0)
        # concatenate outputs of maxpool and feed to conv2
        self.conv2=Conv(4*hidden_channels,out_channels,kernel_size=1,stride=1,padding=0)

        # maxpool is applied at 3 different sacles
        self.m=nn.MaxPool2d(kernel_size=kernel_size,stride=1,padding=kernel_size//2,dilation=1,ceil_mode=False)
    
    def forward(self,x):
        x=self.conv1(x)

        # apply maxpooling at diffent scales
        y1=self.m(x)
        y2=self.m(y1)
        y3=self.m(y2)

        # concantenate 
        y=torch.cat([x,y1,y2,y3],dim=1)

        # final conv
        y=self.conv2(y)

        return y


In [21]:
# backbone = DarkNet53

# return d,w,r based on version
def yolo_params(version):
    if version=='n':
        return 1/3,1/4,2.0
    elif version=='s':
        return 1/3,1/2,2.0
    elif version=='m':
        return 2/3,3/4,1.5
    elif version=='l':
        return 1.0,1.0,1.0
    elif version=='x':
        return 1.0,1.25,1.0
    
class Backbone(nn.Module):
    def __init__(self,version,in_channels=3,shortcut=True):
        super().__init__()
        d,w,r=yolo_params(version)

        # conv layers
        self.conv_0=Conv(in_channels,int(64*w),kernel_size=3,stride=2,padding=1)
        self.conv_1=Conv(int(64*w),int(128*w),kernel_size=3,stride=2,padding=1)
        self.conv_3=Conv(int(128*w),int(256*w),kernel_size=3,stride=2,padding=1)
        self.conv_5=Conv(int(256*w),int(512*w),kernel_size=3,stride=2,padding=1)
        self.conv_7=Conv(int(512*w),int(512*w*r),kernel_size=3,stride=2,padding=1)

        # c2f layers
        self.c2f_2=C2f(int(128*w),int(128*w),num_bottlenecks=int(3*d),shortcut=True)
        self.c2f_4=C2f(int(256*w),int(256*w),num_bottlenecks=int(6*d),shortcut=True)
        self.c2f_6=C2f(int(512*w),int(512*w),num_bottlenecks=int(6*d),shortcut=True)
        self.c2f_8=C2f(int(512*w*r),int(512*w*r),num_bottlenecks=int(3*d),shortcut=True)

        # sppf
        self.sppf=SPPF(int(512*w*r),int(512*w*r))
    
    def forward(self,x):
        x=self.conv_0(x)
        x=self.conv_1(x)

        x=self.c2f_2(x)

        x=self.conv_3(x)

        out1=self.c2f_4(x) # keep for output

        x=self.conv_5(out1)

        out2=self.c2f_6(x) # keep for output

        x=self.conv_7(out2)
        x=self.c2f_8(x)
        out3=self.sppf(x)

        return out1,out2,out3

print("----Nano model -----")
backbone_n=Backbone(version='n')
print(f"{sum(p.numel() for p in backbone_n.parameters())/1e6} million parameters")

print("----Small model -----")
backbone_s=Backbone(version='s')
print(f"{sum(p.numel() for p in backbone_s.parameters())/1e6} million parameters")


----Nano model -----
1.272656 million parameters
----Small model -----
5.079712 million parameters


In [22]:
# upsample = nearest-neighbor interpolation with scale_factor=2
#            doesn't have trainable paramaters
class Upsample(nn.Module):
    def __init__(self,scale_factor=2,mode='nearest'):
        super().__init__()
        self.scale_factor=scale_factor
        self.mode=mode

    def forward(self,x):
        return nn.functional.interpolate(x,scale_factor=self.scale_factor,mode=self.mode)

In [23]:
class Neck(nn.Module):
    def __init__(self,version):
        super().__init__()
        d,w,r=yolo_params(version)

        self.up=Upsample() # no trainable parameters
        self.c2f_1=C2f(in_channels=int(512*w*(1+r)), out_channels=int(512*w),num_bottlenecks=int(3*d),shortcut=False)
        self.c2f_2=C2f(in_channels=int(768*w), out_channels=int(256*w),num_bottlenecks=int(3*d),shortcut=False)
        self.c2f_3=C2f(in_channels=int(768*w), out_channels=int(512*w),num_bottlenecks=int(3*d),shortcut=False)
        self.c2f_4=C2f(in_channels=int(512*w*(1+r)), out_channels=int(512*w*r),num_bottlenecks=int(3*d),shortcut=False)

        self.cv_1=Conv(in_channels=int(256*w),out_channels=int(256*w),kernel_size=3,stride=2, padding=1)
        self.cv_2=Conv(in_channels=int(512*w),out_channels=int(512*w),kernel_size=3,stride=2, padding=1)


    def forward(self,x_res_1,x_res_2,x):    
        # x_res_1,x_res_2,x = output of backbone
        res_1=x              # for residual connection
        
        x=self.up(x)
        x=torch.cat([x,x_res_2],dim=1)

        res_2=self.c2f_1(x)  # for residual connection
        
        x=self.up(res_2)
        x=torch.cat([x,x_res_1],dim=1)

        out_1=self.c2f_2(x)

        x=self.cv_1(out_1)

        x=torch.cat([x,res_2],dim=1)
        out_2=self.c2f_3(x)

        x=self.cv_2(out_2)

        x=torch.cat([x,res_1],dim=1)
        out_3=self.c2f_4(x)

        return out_1,out_2,out_3

In [24]:
# DFL
class DFL(nn.Module):
    def __init__(self,ch=16):
        super().__init__()
        
        self.ch=ch
        
        self.conv=nn.Conv2d(in_channels=ch,out_channels=1,kernel_size=1,bias=False).requires_grad_(False)
        
        # initialize conv with [0,...,ch-1]
        x=torch.arange(ch,dtype=torch.float).reshape(1,ch,1,1)
        self.conv.weight.data[:]=torch.nn.Parameter(x) # DFL only has ch parameters

    def forward(self,x):
        # x must have num_channels = 4*ch: x=[bs,4*ch,c]
        b,c,a=x.shape                           # c=4*ch
        x=x.reshape(b,4,self.ch,a).transpose(1,2)  # [bs,ch,4,a]

        # take softmax on channel dimension to get distribution probabilities
        x=x.softmax(1)                          # [b,ch,4,a]
        x=self.conv(x)                          # [b,1,4,a]
        return x.reshape(b,4,a)                    # [b,4,a]

In [25]:
class Head(nn.Module):
    def __init__(self,version,ch=16,num_classes=5):

        super().__init__()
        self.ch=ch                          # dfl channels
        self.coordinates=self.ch*4          # number of bounding box coordinates 
        self.nc=num_classes                 # 5 for custom dataset
        self.no=self.coordinates+self.nc    # number of outputs per anchor box

        self.stride = torch.tensor([8., 16., 32.])
        
        d,w,r=yolo_params(version=version)
        
        # for bounding boxes
        self.box=nn.ModuleList([
            nn.Sequential(Conv(int(256*w),self.coordinates,kernel_size=3,stride=1,padding=1),
                          Conv(self.coordinates,self.coordinates,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.coordinates,self.coordinates,kernel_size=1,stride=1)),

            nn.Sequential(Conv(int(512*w),self.coordinates,kernel_size=3,stride=1,padding=1),
                          Conv(self.coordinates,self.coordinates,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.coordinates,self.coordinates,kernel_size=1,stride=1)),

            nn.Sequential(Conv(int(512*w*r),self.coordinates,kernel_size=3,stride=1,padding=1),
                          Conv(self.coordinates,self.coordinates,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.coordinates,self.coordinates,kernel_size=1,stride=1))
        ])

        # for classification
        self.cls=nn.ModuleList([
            nn.Sequential(Conv(int(256*w),self.nc,kernel_size=3,stride=1,padding=1),
                          Conv(self.nc,self.nc,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.nc,self.nc,kernel_size=1,stride=1)),

            nn.Sequential(Conv(int(512*w),self.nc,kernel_size=3,stride=1,padding=1),
                          Conv(self.nc,self.nc,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.nc,self.nc,kernel_size=1,stride=1)),

            nn.Sequential(Conv(int(512*w*r),self.nc,kernel_size=3,stride=1,padding=1),
                          Conv(self.nc,self.nc,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.nc,self.nc,kernel_size=1,stride=1))
        ])

        # dfl
        self.dfl=DFL()

    def forward(self,x):
        # x = output of Neck = list of 3 tensors with different resolution and different channel dim
        #     x[0]=[bs, ch0, w0, h0], x[1]=[bs, ch1, w1, h1], x[2]=[bs,ch2, w2, h2] 

        for i in range(len(self.box)):       # detection head i
            box=self.box[i](x[i])            # [bs,num_coordinates,w,h]
            cls=self.cls[i](x[i])            # [bs,num_classes,w,h]
            x[i]=torch.cat((box,cls),dim=1)  # [bs,num_coordinates+num_classes,w,h]

        # in training, no dfl output
        if self.training:
            return x                         # [3,bs,num_coordinates+num_classes,w,h]
        
        # in inference time, dfl produces refined bounding box coordinates
        anchors, strides = (i.transpose(0, 1) for i in self.make_anchors(x, self.stride))

        # concatenate predictions from all detection layers
        x = torch.cat([i.reshape(x[0].shape[0], self.no, -1) for i in x], dim=2) #[bs, 4*self.ch + self.nc, sum_i(h[i]w[i])]
        
        # split out predictions for box and cls
        #           box=[bs,4×self.ch,sum_i(h[i]w[i])]
        #           cls=[bs,self.nc,sum_i(h[i]w[i])]
        box, cls = x.split(split_size=(4 * self.ch, self.nc), dim=1)


        a, b = self.dfl(box).chunk(2, 1)  # a=b=[bs,2×self.ch,sum_i(h[i]w[i])]
        a = anchors.unsqueeze(0) - a
        b = anchors.unsqueeze(0) + b
        box = torch.cat(tensors=((a + b) / 2, b - a), dim=1)
        
        return torch.cat(tensors=(box * strides, cls.sigmoid()), dim=1)


    def make_anchors(self, x, strides, offset=0.5):
        # x= list of feature maps: x=[x[0],...,x[N-1]], in our case N= num_detection_heads=3
        #                          each having shape [bs,ch,w,h]
        #    each feature map x[i] gives output[i] = w*h anchor coordinates + w*h stride values
        
        # strides = list of stride values indicating how much 
        #           the spatial resolution of the feature map is reduced compared to the original image

        assert x is not None
        anchor_tensor, stride_tensor = [], []
        dtype, device = x[0].dtype, x[0].device
        for i, stride in enumerate(strides):
            _, _, h, w = x[i].shape
            sx = torch.arange(end=w, device=device, dtype=dtype) + offset  # x coordinates of anchor centers
            sy = torch.arange(end=h, device=device, dtype=dtype) + offset  # y coordinates of anchor centers
            sy, sx = torch.meshgrid(sy, sx)                                # all anchor centers 
            anchor_tensor.append(torch.stack((sx, sy), -1).reshape(-1, 2))
            stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
        return torch.cat(anchor_tensor), torch.cat(stride_tensor)
        

In [26]:
import torch

# fake feature maps (bs=1, ch=3)
# ví dụ: 3 head detection tương ứng stride 8, 16, 32
x = [
    torch.zeros(1, 3, 80, 80),   # P3
    torch.zeros(1, 3, 40, 40),   # P4
    torch.zeros(1, 3, 20, 20)    # P5
]
strides = [8, 16, 32]

def make_anchors(x, strides, offset=0.5):
    anchor_tensor, stride_tensor = [], []
    dtype, device = x[0].dtype, x[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = x[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + offset
        sy = torch.arange(end=h, device=device, dtype=dtype) + offset
        sy, sx = torch.meshgrid(sy, sx, indexing='ij')  # chú ý indexing
        anchor_tensor.append(torch.stack((sx, sy), -1).reshape(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_tensor), torch.cat(stride_tensor)

# Test
anchors, strides_out = make_anchors(x, strides)

print("Anchor tensor shape:", anchors.shape)       # (8400, 2)
print("Stride tensor shape:", strides_out.shape)   # (8400, 1)

# In thử 5 anchor đầu tiên
print("First 5 anchors:\n", anchors[:5])
print("First 5 strides:\n", strides_out[:5].reshape(-1))

# In thử cuối cùng (P5)
print("Last 5 anchors:\n", anchors[-5:])
print("Last 5 strides:\n", strides_out[-5:].reshape(-1))

Anchor tensor shape: torch.Size([8400, 2])
Stride tensor shape: torch.Size([8400, 1])
First 5 anchors:
 tensor([[0.5000, 0.5000],
        [1.5000, 0.5000],
        [2.5000, 0.5000],
        [3.5000, 0.5000],
        [4.5000, 0.5000]])
First 5 strides:
 tensor([8., 8., 8., 8., 8.])
Last 5 anchors:
 tensor([[15.5000, 19.5000],
        [16.5000, 19.5000],
        [17.5000, 19.5000],
        [18.5000, 19.5000],
        [19.5000, 19.5000]])
Last 5 strides:
 tensor([32., 32., 32., 32., 32.])


In [27]:
class MyYolo(nn.Module):
    def __init__(self, version, num_classes=5):
        super().__init__()
        self.backbone = Backbone(version=version)
        self.neck = Neck(version=version)
        self.head = Head(version=version, num_classes=num_classes)
        self.nc = num_classes

    def forward(self, x):
        x = self.backbone(x)              # return out1, out2, out3
        x = self.neck(x[0], x[1], x[2])   # return out_1, out_2, out_3
        return self.head(list(x))


# khởi tạo model với 5 class
model = MyYolo(version='n', num_classes=5)
print(f"{sum(p.numel() for p in model.parameters())/1e6:.2f} million parameters")

2.66 million parameters


# UTIL

In [28]:
# === DEBUG UTILS ===
import torch
from contextlib import contextmanager

DEBUG_ON = True          # Bật/tắt toàn cục
DEBUG_MAX_ELEMS = 5      # In tối đa vài phần tử để đỡ rác

def tstats(name, t, mask=None):
    if not DEBUG_ON: 
        return
    try:
        if mask is not None:
            t = t[mask]
        if t.numel() == 0:
            print(f"[{name}] empty tensor")
            return
        t_det = t.detach()
        print(f"[{name}] shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
              f"min={t_det.min().item():.6f} max={t_det.max().item():.6f} "
              f"mean={t_det.float().mean().item():.6f} sum={t_det.float().sum().item():.6f} "
              f"nnz={(t_det!=0).sum().item()}/{t_det.numel()}")
        # In vài phần tử đầu
        flat = t_det.reshape(-1)
        print(f"  sample: {flat[:min(flat.numel(), DEBUG_MAX_ELEMS)].tolist()}")
        if torch.isnan(t_det).any() or torch.isinf(t_det).any():
            print(f"  WARN: {name} contains NaN/Inf")
    except Exception as e:
        print(f"[{name}] DEBUG ERROR: {e}")

def tuniq(name, t):
    if not DEBUG_ON: 
        return
    try:
        u = torch.unique(t)
        print(f"[{name}] unique({u.numel()}): {u[:min(u.numel(), DEBUG_MAX_ELEMS)].tolist()}"
              + (" ..." if u.numel() > DEBUG_MAX_ELEMS else ""))
    except Exception as e:
        print(f"[{name}] unique() error: {e}")

@contextmanager
def debug_block(title):
    if DEBUG_ON:
        print(f"\n========== DEBUG: {title} ==========")
    yield
    if DEBUG_ON:
        print(f"========== /DEBUG: {title} ==========\n")

In [29]:
import copy
import random
from time import time

import math
import numpy
import torch
import torchvision
from torch.nn.functional import cross_entropy

def setup_seed():
    """
    Setup random seed.
    """
    random.seed(0)
    numpy.random.seed(0)
    torch.manual_seed(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def setup_multi_processes():
    """
    Setup multi-processing environment variables.
    """
    import cv2
    from os import environ
    from platform import system

    # set multiprocess start method as `fork` to speed up the training
    if system() != 'Windows':
        torch.multiprocessing.set_start_method('fork', force=True)

    # disable opencv multithreading to avoid system being overloaded
    cv2.setNumThreads(0)

    # setup OMP threads
    if 'OMP_NUM_THREADS' not in environ:
        environ['OMP_NUM_THREADS'] = '1'

    # setup MKL threads
    if 'MKL_NUM_THREADS' not in environ:
        environ['MKL_NUM_THREADS'] = '1'


def export_onnx(args):
    import onnx  # noqa

    inputs = ['images']
    outputs = ['outputs']
    dynamic = {'outputs': {0: 'batch', 1: 'anchors'}}

    m = torch.load('./weights/best.pt')['model'].float()
    x = torch.zeros((1, 3, args.input_size, args.input_size))

    torch.onnx.export(m.cpu(), x.cpu(),
                      f='./weights/best.onnx',
                      verbose=False,
                      opset_version=12,
                      # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
                      do_constant_folding=True,
                      input_names=inputs,
                      output_names=outputs,
                      dynamic_axes=dynamic or None)

    # Checks
    model_onnx = onnx.load('./weights/best.onnx')  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    onnx.save(model_onnx, './weights/best.onnx')
    # Inference example
    # https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/autobackend.py


# def wh2xy(x):
#     y = x.clone() if isinstance(x, torch.Tensor) else numpy.copy(x)
#     y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
#     y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
#     y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
#     y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
#     return y

def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
    # Convert nx4 boxes
    # from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = numpy.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + pad_w  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + pad_h  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + pad_w  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + pad_h  # bottom right y
    return y

def make_anchors(x, strides, offset=0.5):
    assert x is not None
    anchor_tensor, stride_tensor = [], []
    dtype, device = x[0].dtype, x[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = x[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + offset  # shift y
        sy, sx = torch.meshgrid(sy, sx)
        anchor_tensor.append(torch.stack((sx, sy), -1).reshape(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_tensor), torch.cat(stride_tensor)


def compute_metric(output, target, iou_v):
    # intersection(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    (a1, a2) = target[:, 1:].unsqueeze(1).chunk(2, 2)
    (b1, b2) = output[:, :4].unsqueeze(0).chunk(2, 2)
    intersection = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
    # IoU = intersection / (area1 + area2 - intersection)
    iou = intersection / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - intersection + 1e-7)

    correct = numpy.zeros((output.shape[0], iou_v.shape[0]))
    correct = correct.astype(bool)
    for i in range(len(iou_v)):
        # IoU > threshold and classes match
        x = torch.where((iou >= iou_v[i]) & (target[:, 0:1] == output[:, 5]))
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1),
                                 iou[x[0], x[1]][:, None]), 1).cpu().numpy()  # [label, detect, iou]
            if x[0].shape[0] > 1:
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[numpy.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[numpy.unique(matches[:, 0], return_index=True)[1]]
            correct[matches[:, 1].astype(int), i] = True
    return torch.tensor(correct, dtype=torch.bool, device=output.device)


def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.7):
    max_wh = 7680
    max_det = 300
    max_nms = 30000

    bs = outputs.shape[0]  # batch size
    nc = outputs.shape[1] - 4  # number of classes
    xc = outputs[:, 4:4 + nc].amax(1) > confidence_threshold  # candidates

    # Settings
    start = time()
    limit = 0.5 + 0.05 * bs  # seconds to quit after
    output = [torch.zeros((0, 6), device=outputs.device)] * bs
    for index, x in enumerate(outputs):  # image index, image inference
        x = x.transpose(0, -1)[xc[index]]  # confidence

        # If none remain process next image
        if not x.shape[0]:
            continue

        # matrix nx6 (box, confidence, cls)
        box, cls = x.split((4, nc), 1)
        box = wh2xy(box)  # (cx, cy, w, h) to (x1, y1, x2, y2)
        if nc > 1:
            i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]  #Không

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * max_wh  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes, scores
        indices = torchvision.ops.nms(boxes, scores, iou_threshold)  # NMS
        indices = indices[:max_det]  # limit detections

        output[index] = x[indices]
        if (time() - start) > limit:
            break  # time limit exceeded

    return output


def smooth(y, f=0.05):
    # Box filter of fraction f
    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
    p = numpy.ones(nf // 2)  # ones padding
    yp = numpy.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
    return numpy.convolve(yp, numpy.ones(nf) / nf, mode='valid')  # y-smoothed


def compute_ap(tp, conf, pred_cls, target_cls, eps=1e-16):
    """
    Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:  True positives (nparray, nx1 or nx10).
        conf:  Object-ness value from 0-1 (nparray).
        pred_cls:  Predicted object classes (nparray).
        target_cls:  True object classes (nparray).
    # Returns
        The average precision
    """
    # Sort by object-ness
    i = numpy.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes, nt = numpy.unique(target_cls, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes, number of detections

    # Create Precision-Recall curve and compute AP for each class
    p = numpy.zeros((nc, 1000))
    r = numpy.zeros((nc, 1000))
    ap = numpy.zeros((nc, tp.shape[1]))
    px, py = numpy.linspace(0, 1, 1000), []  # for plotting
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        nl = nt[ci]  # number of labels
        no = i.sum()  # number of outputs
        if no == 0 or nl == 0:
            continue

        # Accumulate FPs and TPs
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        # Recall
        recall = tpc / (nl + eps)  # recall curve
        # negative x, xp because xp decreases
        r[ci] = numpy.interp(-px, -conf[i], recall[:, 0], left=0)

        # Precision
        precision = tpc / (tpc + fpc)  # precision curve
        p[ci] = numpy.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

        # AP from recall-precision curve
        for j in range(tp.shape[1]):
            m_rec = numpy.concatenate(([0.0], recall[:, j], [1.0]))
            m_pre = numpy.concatenate(([1.0], precision[:, j], [0.0]))

            # Compute the precision envelope
            m_pre = numpy.flip(numpy.maximum.accumulate(numpy.flip(m_pre)))

            # Integrate area under curve
            x = numpy.linspace(0, 1, 101)  # 101-point interp (COCO)
            ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x)  # integrate

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)

    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
    m_pre, m_rec = p.mean(), r.mean()
    map50, mean_ap = ap50.mean(), ap.mean()
    return tp, fp, m_pre, m_rec, map50, mean_ap


def compute_iou(box1, box2, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
    ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
    c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
    rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
    # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
    v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
    with torch.no_grad():
        alpha = v / (v - iou + (1 + eps))

    return iou - (rho2 / c2 + v * alpha)  # CIoU


def strip_optimizer(filename):
    x = torch.load(filename, map_location="cpu")
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, f=filename)


def clip_gradients(model, max_norm=10.0):
    parameters = model.parameters()
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm)


def load_weight(model, ckpt):
    dst = model.state_dict()
    src = torch.load(ckpt)['model'].float().cpu()

    ckpt = {}
    for k, v in src.state_dict().items():
        if k in dst and v.shape == dst[k].shape:
            ckpt[k] = v

    model.load_state_dict(state_dict=ckpt, strict=False)
    return model


def set_params(model, decay):
    p1 = []
    p2 = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or name.endswith(".bias"):
            p1.append(param)
        else:
            p2.append(param)
    return [{'params': p1, 'weight_decay': 0.00},
            {'params': p2, 'weight_decay': decay}]


def plot_lr(args, optimizer, scheduler, num_steps):
    from matplotlib import pyplot

    optimizer = copy.copy(optimizer)
    scheduler = copy.copy(scheduler)

    y = []
    for epoch in range(args.epochs):
        for i in range(num_steps):
            step = i + num_steps * epoch
            scheduler.step(step, optimizer)
            y.append(optimizer.param_groups[0]['lr'])
    print(y[0])
    print(y[-1])
    pyplot.plot(y, '.-', label='LR')
    pyplot.xlabel('step')
    pyplot.ylabel('LR')
    pyplot.grid()
    pyplot.xlim(0, args.epochs * num_steps)
    pyplot.ylim(0)
    pyplot.savefig('./weights/lr.png', dpi=200)
    pyplot.close()


class CosineLR:
    def __init__(self, args, params, num_steps):
        max_lr = params['max_lr']
        min_lr = params['min_lr']

        warmup_steps = int(max(params['warmup_epochs'] * num_steps, 100))
        decay_steps = int(args.epochs * num_steps - warmup_steps)

        warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps))

        decay_lr = []
        for step in range(1, decay_steps + 1):
            alpha = math.cos(math.pi * step / decay_steps)
            decay_lr.append(min_lr + 0.5 * (max_lr - min_lr) * (1 + alpha))

        self.total_lr = numpy.concatenate((warmup_lr, decay_lr))

    def step(self, step, optimizer):
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.total_lr[step]


class LinearLR:
    def __init__(self, args, params, num_steps):
        max_lr = params['max_lr']
        min_lr = params['min_lr']

        warmup_steps = int(max(params['warmup_epochs'] * num_steps, 100))
        decay_steps = int(args.epochs * num_steps - warmup_steps)

        warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
        decay_lr = numpy.linspace(max_lr, min_lr, decay_steps)

        self.total_lr = numpy.concatenate((warmup_lr, decay_lr))

    def step(self, step, optimizer):
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.total_lr[step]


class EMA:
    """
    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = copy.deepcopy(model).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        if hasattr(model, 'module'):
            model = model.module
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()


class AverageMeter:
    def __init__(self):
        self.num = 0
        self.sum = 0
        self.avg = 0

    def update(self, v, n):
        if not math.isnan(float(v)):
            self.num = self.num + n
            self.sum = self.sum + v * n
            self.avg = self.sum / self.num


class Assigner(torch.nn.Module):
    def __init__(self, nc=5, top_k=13, alpha=1.0, beta=6.0, eps=1E-9):
        super().__init__()
        self.top_k = top_k
        self.nc = nc
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        with debug_block("Assigner.forward / inputs"):
            print(f"pd_scores: {tuple(pd_scores.shape)}  (B, A, C)")
            print(f"pd_bboxes: {tuple(pd_bboxes.shape)}  (B, A, 4)")
            print(f"anc_points: {tuple(anc_points.shape)} (A, 2)")
            print(f"gt_labels: {tuple(gt_labels.shape)}  (B, M, 1)")
            print(f"gt_bboxes: {tuple(gt_bboxes.shape)}  (B, M, 4)")
            print(f"mask_gt sum: {mask_gt.sum().item()}")
            tuniq("gt_labels uniq", gt_labels.reshape(-1))
        
        batch_size = pd_scores.size(0)
        num_max_boxes = gt_bboxes.size(1)

        if num_max_boxes == 0:
            device = gt_bboxes.device
            print("Assigner: num_max_boxes==0 -> return zeros")
            return (torch.zeros_like(pd_bboxes).to(device),
                    torch.zeros_like(pd_scores).to(device),
                    torch.zeros_like(pd_scores[..., 0]).to(device))

        num_anchors = anc_points.shape[0]
        shape = gt_bboxes.shape
        lt, rb = gt_bboxes.reshape(-1, 1, 4).chunk(2, 2)
        mask_in_gts = torch.cat((anc_points[None] - lt, rb - anc_points[None]), dim=2)
        mask_in_gts = mask_in_gts.reshape(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)

        
        with debug_block("in GT mask"):
            print("mask_in_gts sum:", mask_in_gts.sum().item())
        
        na = pd_bboxes.shape[-2]
        gt_mask = (mask_in_gts * mask_gt).bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=batch_size).reshape(-1, 1).expand(-1, num_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        bbox_scores[gt_mask] = pd_scores[ind[0], :, ind[1]][gt_mask]  # b, max_num_obj, h*w

        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, num_max_boxes, -1, -1)[gt_mask]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[gt_mask]
        overlaps[gt_mask] = compute_iou(gt_boxes, pd_boxes).squeeze(-1).clamp_(0)

        with debug_block("overlaps & scores"):
            tstats("bbox_scores (selected)", bbox_scores[gt_mask])
            tstats("overlaps (selected)", overlaps[gt_mask])
        

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        
        with debug_block("align_metric"):
            tstats("align_metric(all)", align_metric)
            tstats("align_metric(selected)", align_metric[gt_mask])

        top_k_mask = mask_gt.expand(-1, -1, self.top_k).bool()
        top_k_metrics, top_k_indices = torch.topk(align_metric, self.top_k, dim=-1, largest=True)
        if top_k_mask is None:
            top_k_mask = (top_k_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(top_k_indices)
        top_k_indices.masked_fill_(~top_k_mask, 0)

        with debug_block("top-k"):
            tstats("top_k_metrics", top_k_metrics)
            tstats("top_k_indices", top_k_indices)

        mask_top_k = torch.zeros(align_metric.shape, dtype=torch.int8, device=top_k_indices.device)
        ones = torch.ones_like(top_k_indices[:, :, :1], dtype=torch.int8, device=top_k_indices.device)
        for k in range(self.top_k):
            mask_top_k.scatter_add_(-1, top_k_indices[:, :, k:k + 1], ones)
        mask_top_k.masked_fill_(mask_top_k > 1, 0)
        mask_top_k = mask_top_k.to(align_metric.dtype)
        mask_pos = mask_top_k * mask_in_gts * mask_gt

        with debug_block("positive mask"):
            print("mask_pos sum:", mask_pos.sum().item())
            print("pos per-gt:", mask_pos.sum(-1)[mask_gt.squeeze(-1).bool()].reshape(-1).tolist()[:20])
        
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, num_max_boxes, -1)
            max_overlaps_idx = overlaps.argmax(1)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
            fg_mask = mask_pos.sum(-2)
        target_gt_idx = mask_pos.argmax(-2)

        # Assigned target
        index = torch.arange(end=batch_size, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_index = target_gt_idx + index * num_max_boxes
        target_labels = gt_labels.long().flatten()[target_index]

        target_bboxes = gt_bboxes.reshape(-1, gt_bboxes.shape[-1])[target_index]

        # SỬA
        # labels hợp lệ?
        assert (target_labels >= 0).all() and (target_labels < self.nc).all(), "Assigned labels out of range"

        
        # Assigned target scores
        target_labels.clamp_(min=0, max=self.nc - 1)

        target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
                                    #dtype=torch.int64,  #SỬA
                                    dtype=torch.float32,
                                    device=target_labels.device)
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        # Normalize
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        
        with debug_block("assigner outputs"):
            tstats("fg_mask", fg_mask)
            tstats("target_labels", target_labels)
            tstats("norm_align_metric", norm_align_metric)
            print("target_scores > 0:", (target_scores > 0).sum().item())
        
        return target_bboxes, target_scores, fg_mask.bool()


class QFL(torch.nn.Module):
    def __init__(self, beta=2.0):
        super().__init__()
        self.beta = beta
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, outputs, targets):
        bce_loss = self.bce_loss(outputs, targets)
        return torch.pow(torch.abs(targets - outputs.sigmoid()), self.beta) * bce_loss


class VFL(torch.nn.Module):
    def __init__(self, alpha=0.75, gamma=2.00, iou_weighted=True):
        super().__init__()
        assert alpha >= 0.0
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, outputs, targets):
        assert outputs.size() == targets.size()
        targets = targets.type_as(outputs)

        if self.iou_weighted:
            focal_weight = targets * (targets > 0.0).float() + \
                           self.alpha * (outputs.sigmoid() - targets).abs().pow(self.gamma) * \
                           (targets <= 0.0).float()

        else:
            focal_weight = (targets > 0.0).float() + \
                           self.alpha * (outputs.sigmoid() - targets).abs().pow(self.gamma) * \
                           (targets <= 0.0).float()

        return self.bce_loss(outputs, targets) * focal_weight


class BoxLoss(torch.nn.Module):
    def __init__(self, dfl_ch):
        super().__init__()
        self.dfl_ch = dfl_ch

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        with debug_block("BoxLoss.forward / inputs"):
            print("fg_mask sum:", fg_mask.sum().item())
            tstats("target_scores sum per pos", torch.masked_select(target_scores.sum(-1), fg_mask))
            tstats("pred_bboxes(pos)", pred_bboxes[fg_mask])
            tstats("target_bboxes(pos)", target_bboxes[fg_mask])
            
        # IoU loss
        weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
        iou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_box = ((1.0 - iou) * weight).sum() / target_scores_sum

        with debug_block("IoU part"):
            tstats("IoU(pos)", iou)
            print("loss_box(partial):", float(loss_box.detach().cpu()))

        # DFL loss
        a, b = target_bboxes.chunk(2, -1)


        # SỬA: broadcast anchor_points theo batch để khớp fg_mask
        batch_size = fg_mask.shape[0]                          
        anchors_batched = anchor_points.unsqueeze(0).expand(batch_size, -1, -1)  # [B, num_anchors, 2]

        # SỬA: dùng anchors_batched thay cho anchor_points
        target = torch.cat((anchors_batched - a, b - anchors_batched), -1)
        target = target.clamp(0, self.dfl_ch - 0.01)

        with debug_block("DFL target build"):
            tstats("anchor_points(pos)", anchors_batched[fg_mask])  # SỬA: dùng anchors_batched
            tstats("target distances(pos)", target[fg_mask])
            print("dfl_ch:", self.dfl_ch)
        
        loss_dfl = self.df_loss(pred_dist[fg_mask].reshape(-1, self.dfl_ch + 1), target[fg_mask])
        loss_dfl = (loss_dfl * weight).sum() / target_scores_sum


        with debug_block("DFL final"):
            tstats("loss_dfl(per-pos)", loss_dfl.unsqueeze(0))
            print("loss_dfl:", float(loss_dfl.detach().cpu()))
        
        return loss_box, loss_dfl

    @staticmethod
    def df_loss(pred_dist, target):
        # Distribution Focal Loss (DFL)
        # https://ieeexplore.ieee.org/document/9792391
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        left_loss = cross_entropy(pred_dist, tl.reshape(-1), reduction='none').reshape(tl.shape)
        right_loss = cross_entropy(pred_dist, tr.reshape(-1), reduction='none').reshape(tl.shape)
        return (left_loss * wl + right_loss * wr).mean(-1, keepdim=True)


class ComputeLoss:
    def __init__(self, model, params):
        if hasattr(model, 'module'):
            model = model.module

        device = next(model.parameters()).device

        m = model.head  # Head() module

        self.params = params
        self.stride = m.stride
        self.nc = m.nc
        self.no = m.no
        self.reg_max = m.ch
        self.device = device

        self.box_loss = BoxLoss(m.ch - 1).to(device)
        self.cls_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        self.assigner = Assigner(nc=self.nc, top_k=10, alpha=0.5, beta=6.0)

        self.project = torch.arange(m.ch, dtype=torch.float, device=device)

    def box_decode(self, anchor_points, pred_dist):
        b, a, c = pred_dist.shape
        pred_dist = pred_dist.reshape(b, a, 4, c // 4)
        pred_dist = pred_dist.softmax(3)
        pred_dist = pred_dist.matmul(self.project.type(pred_dist.dtype))
        lt, rb = pred_dist.chunk(2, -1)
        x1y1 = anchor_points - lt
        x2y2 = anchor_points + rb
        return torch.cat(tensors=(x1y1, x2y2), dim=-1)

    def __call__(self, outputs, targets):
        #*****************************
        with debug_block("ComputeLoss.__call__ / inputs"):
            # Kiểm tra outputs
            print("num.feature levels:", len(outputs))
            for li, o in enumerate(outputs):
                print(f"  L{li} shape={tuple(o.shape)}")  # (B, no, H, W)
    
            # Kiểm tra targets dict
            print("targets keys:", list(targets.keys()))
            tstats("targets['idx']", targets['idx'])
            tstats("targets['cls']", targets['cls'])
            tstats("targets['box']", targets['box'])
            if 'cls' in targets:
                tuniq("targets['cls'] uniq", targets['cls'])
        #*****************************************




        
        x = torch.cat([i.reshape(outputs[0].shape[0], self.no, -1) for i in outputs], dim=2)
        pred_distri, pred_scores = x.split(split_size=(self.reg_max * 4, self.nc), dim=1)

        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()



        
        #***********************************     
        with debug_block("pred tensors"):
            tstats("pred_scores(logits)", pred_scores)
            tstats("pred_scores(sigmoid)", pred_scores.sigmoid())
            tstats("pred_distri", pred_distri)
        #*********************************



        
        data_type = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        input_size = torch.tensor(outputs[0].shape[2:], device=self.device, dtype=data_type) * self.stride[0]
        anchor_points, stride_tensor = make_anchors(outputs, self.stride, offset=0.5)





        #***********************************
        with debug_block("anchors"):
            tstats("anchor_points", anchor_points)
            tstats("stride_tensor", stride_tensor)
        #*****************************




        idx = targets['idx'].reshape(-1, 1)
        cls = targets['cls'].reshape(-1, 1)
        box = targets['box']



        #***************************
        # Sanity: class hợp lệ?
        assert (cls >= 0).all(), "Found negative class id"
        assert (cls < self.nc).all(), f"Found class id >= nc ({self.nc})"
        #***************************






        targets = torch.cat((idx, cls, box), dim=1).to(self.device)
        if targets.shape[0] == 0:
            gt = torch.zeros(batch_size, 0, 5, device=self.device)
        else:
            i = targets[:, 0]
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            gt = torch.zeros(batch_size, counts.max(), 5, device=self.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    gt[j, :n] = targets[matches, 1:]
            x = gt[..., 1:5].mul_(input_size[[1, 0, 1, 0]])
            y = torch.empty_like(x)
            dw = x[..., 2] / 2  # half-width
            dh = x[..., 3] / 2  # half-height
            y[..., 0] = x[..., 0] - dw  # top left x
            y[..., 1] = x[..., 1] - dh  # top left y
            y[..., 2] = x[..., 0] + dw  # bottom right x
            y[..., 3] = x[..., 1] + dh  # bottom right y
            gt[..., 1:5] = y
        gt_labels, gt_bboxes = gt.split((1, 4), 2)
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

        #***************************
        with debug_block("GT build"):
            tstats("gt_labels", gt_labels)
            tuniq("gt_labels uniq", gt_labels.reshape(-1))
            tstats("gt_bboxes", gt_bboxes)
            print("mask_gt sum:", mask_gt.sum().item())
        #***************************





        pred_bboxes = self.box_decode(anchor_points, pred_distri)


        #***************************
        with debug_block("decoded boxes"):
            tstats("pred_bboxes(decoded)", pred_bboxes)
        #***************************


        assigned_targets = self.assigner(pred_scores.detach().sigmoid(),
                                         (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
                                         anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
        target_bboxes, target_scores, fg_mask = assigned_targets


        #***************************
        with debug_block("after assigner"):
            tstats("target_bboxes", target_bboxes)
            tstats("target_scores", target_scores)
            print("target_scores > 0:", (target_scores > 0).sum().item())
            print("fg_mask sum:", fg_mask.sum().item())
        #***************************


        # # SỬA: Quan trọng: dùng clamp để đảm bảo tensor cùng device, tránh Python int.
        target_scores_sum = target_scores.sum().clamp(min=1.0)

        loss_cls = self.cls_loss(pred_scores, target_scores.to(data_type)).sum() / target_scores_sum  # BCE

        # #SỬA: scale lại cls loss theo số lượng foreground anchors nếu muốn
        # fg_mask_cls = target_scores.sum(-1) > 0
        # pred_scores_pos = pred_scores[fg_mask_cls]
        # target_scores_pos = target_scores[fg_mask_cls]
        
        # if fg_mask_cls.sum() > 0:
        #     loss_cls = self.cls_loss(pred_scores_pos, target_scores_pos).sum() / fg_mask_cls.sum()  # sum()/num_foreground
        # else:
        #     loss_cls = torch.zeros(1, device=pred_scores.device)

        # SỐ foreground anchors (tính theo mask) #SỬA
        #num_fg_anchors = fg_mask.sum().clamp(min=1.0)  
        
        # Classification loss (BCE) #SỬA
        # Dùng mean trên lớp, sau đó sum trên foreground anchors
        #loss_cls = (self.cls_loss(pred_scores, target_scores.to(pred_scores.dtype)).sum(dim=-1)  # sum over nc
                    #[fg_mask]).sum() / num_fg_anchors

        with debug_block("cls loss"):
            print("target_scores_sum:", float(target_scores_sum.detach().cpu()))
            tstats("BCE elem", self.cls_loss(pred_scores, target_scores.to(data_type)))
            print("loss_cls:", float(loss_cls.detach().cpu()))

        # Box loss
        loss_box = torch.zeros(1, device=self.device)
        loss_dfl = torch.zeros(1, device=self.device)
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss_box, loss_dfl = self.box_loss(pred_distri,
                                               pred_bboxes,
                                               anchor_points,
                                               target_bboxes,
                                               target_scores,
                                               target_scores_sum, fg_mask)
        else:
            print("NOTE: fg_mask.sum()==0 => loss_box=0, loss_dfl=0")


        # loss_box đã tính xong từ BoxLoss.forward
        print("loss_box (before gain):", float(loss_box.detach().cpu()))
        print("box_gain:", self.params['box'])
        
        loss_box *= self.params['box']  # box gain
        loss_cls *= self.params['cls']  # cls gain
        loss_dfl *= self.params['dfl']  # dfl gain

        with debug_block("final losses (after gain)"):
            print(f"loss_box={float(loss_box.detach().cpu()):.6f} "
                  f"loss_cls={float(loss_cls.detach().cpu()):.6f} "
                  f"loss_dfl={float(loss_dfl.detach().cpu()):.6f}")
        
        return loss_box, loss_cls, loss_dfl

# DATASET

In [30]:
import math
import os
import random

import cv2
import numpy
import torch
from PIL import Image
from torch.utils import data

FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp'

class Dataset(data.Dataset):
    def __init__(self, filenames, input_size, params, augment):
        self.params = params
        self.mosaic = augment
        self.augment = augment
        self.input_size = input_size

        # Read labels
        labels = self.load_label(filenames)
        self.labels = list(labels.values())
        self.filenames = list(labels.keys())  # update
        self.n = len(self.filenames)  # number of samples
        self.indices = range(self.n)
        # Albumentations (optional, only used if package is installed)
        self.albumentations = Albumentations()

    def __getitem__(self, index):
        index = self.indices[index]

        params = self.params
        mosaic = self.mosaic and random.random() < params['mosaic']

        if mosaic:
            # Load MOSAIC
            image, label = self.load_mosaic(index, params)
            # MixUp augmentation
            if random.random() < params['mix_up']:
                index = random.choice(self.indices)
                mix_image1, mix_label1 = image, label
                mix_image2, mix_label2 = self.load_mosaic(index, params)

                image, label = mix_up(mix_image1, mix_label1, mix_image2, mix_label2)
        else:
            # Load image
            image, shape = self.load_image(index)
            h, w = image.shape[:2]

            # Resize
            image, ratio, pad = resize(image, self.input_size, self.augment)

            label = self.labels[index].copy()
            if label.size:
                label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, pad[0], pad[1])
            if self.augment:
                image, label = random_perspective(image, label, params)

        nl = len(label)  # number of labels
        h, w = image.shape[:2]
        cls = label[:, 0:1]
        box = label[:, 1:5]
        box = xy2wh(box, w, h)

        if self.augment:
            # Albumentations
            image, box, cls = self.albumentations(image, box, cls)
            nl = len(box)  # update after albumentations
            # HSV color-space
            augment_hsv(image, params)
            # Flip up-down
            if random.random() < params['flip_ud']:
                image = numpy.flipud(image)
                if nl:
                    box[:, 1] = 1 - box[:, 1]
            # Flip left-right
            if random.random() < params['flip_lr']:
                image = numpy.fliplr(image)
                if nl:
                    box[:, 0] = 1 - box[:, 0]

        target_cls = torch.zeros((nl, 1))
        target_box = torch.zeros((nl, 4))
        if nl:
            target_cls = torch.from_numpy(cls)
            target_box = torch.from_numpy(box)

        # Convert HWC to CHW, BGR to RGB
        sample = image.transpose((2, 0, 1))[::-1]
        sample = numpy.ascontiguousarray(sample)

        return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)

    def __len__(self):
        return len(self.filenames)

    def load_image(self, i):
        image = cv2.imread(self.filenames[i])
        h, w = image.shape[:2]
        r = self.input_size / max(h, w)
        if r != 1:
            image = cv2.resize(image,
                               dsize=(int(w * r), int(h * r)),
                               interpolation=resample() if self.augment else cv2.INTER_LINEAR)
        return image, (h, w)

    def load_mosaic(self, index, params):
        label4 = []
        border = [-self.input_size // 2, -self.input_size // 2]
        image4 = numpy.full((self.input_size * 2, self.input_size * 2, 3), 0, dtype=numpy.uint8)
        y1a, y2a, x1a, x2a, y1b, y2b, x1b, x2b = (None, None, None, None, None, None, None, None)

        xc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))
        yc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))

        indices = [index] + random.choices(self.indices, k=3)
        random.shuffle(indices)

        for i, index in enumerate(indices):
            # Load image
            image, _ = self.load_image(index)
            shape = image.shape
            if i == 0:  # top left
                x1a = max(xc - shape[1], 0)
                y1a = max(yc - shape[0], 0)
                x2a = xc
                y2a = yc
                x1b = shape[1] - (x2a - x1a)
                y1b = shape[0] - (y2a - y1a)
                x2b = shape[1]
                y2b = shape[0]
            if i == 1:  # top right
                x1a = xc
                y1a = max(yc - shape[0], 0)
                x2a = min(xc + shape[1], self.input_size * 2)
                y2a = yc
                x1b = 0
                y1b = shape[0] - (y2a - y1a)
                x2b = min(shape[1], x2a - x1a)
                y2b = shape[0]
            if i == 2:  # bottom left
                x1a = max(xc - shape[1], 0)
                y1a = yc
                x2a = xc
                y2a = min(self.input_size * 2, yc + shape[0])
                x1b = shape[1] - (x2a - x1a)
                y1b = 0
                x2b = shape[1]
                y2b = min(y2a - y1a, shape[0])
            if i == 3:  # bottom right
                x1a = xc
                y1a = yc
                x2a = min(xc + shape[1], self.input_size * 2)
                y2a = min(self.input_size * 2, yc + shape[0])
                x1b = 0
                y1b = 0
                x2b = min(shape[1], x2a - x1a)
                y2b = min(y2a - y1a, shape[0])

            pad_w = x1a - x1b
            pad_h = y1a - y1b
            image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]

            # Labels
            label = self.labels[index].copy()
            if len(label):
                label[:, 1:] = wh2xy(label[:, 1:], shape[1], shape[0], pad_w, pad_h)
            label4.append(label)

        # Concat/clip labels
        label4 = numpy.concatenate(label4, 0)
        for x in label4[:, 1:]:
            numpy.clip(x, 0, 2 * self.input_size, out=x)

        # Augment
        image4, label4 = random_perspective(image4, label4, params, border)

        return image4, label4

    @staticmethod
    def collate_fn(batch):
        samples, cls, box, indices = zip(*batch)

        cls = torch.cat(cls, dim=0)
        box = torch.cat(box, dim=0)

        new_indices = list(indices)
        for i in range(len(indices)):
            new_indices[i] += i
        indices = torch.cat(new_indices, dim=0)

        targets = {'cls': cls,
                   'box': box,
                   'idx': indices}
        return torch.stack(samples, dim=0), targets

    @staticmethod
    def load_label(filenames):
        x = {}
        for filename in filenames:
            try:
                # verify images
                with open(filename, 'rb') as f:
                    image = Image.open(f)
                    image.verify()  # PIL verify
                shape = image.size  # image size
                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
                assert image.format.lower() in FORMATS, f'invalid image format {image.format}'

                # verify labels
                a = f'{os.sep}images{os.sep}'
                b = f'{os.sep}labels{os.sep}'
                label_path = b.join(filename.rsplit(a, 1)).rsplit('.', 1)[0] + '.txt'
                if os.path.isfile(b.join(filename.rsplit(a, 1)).rsplit('.', 1)[0] + '.txt'):
                    with open(label_path) as f:
                        label = [x.split() for x in f.read().strip().splitlines() if len(x)]
                        label = numpy.array(label, dtype=numpy.float32)
                    nl = len(label)
                    if nl:
                        assert (label >= 0).all()
                        assert label.shape[1] == 5
                        assert (label[:, 1:] <= 1).all()
                        _, i = numpy.unique(label, axis=0, return_index=True)
                        if len(i) < nl:  # duplicate row check
                            label = label[i]  # remove duplicates
                    else:
                        label = numpy.zeros((0, 5), dtype=numpy.float32)
                else:
                    label = numpy.zeros((0, 5), dtype=numpy.float32)
            except FileNotFoundError:
                label = numpy.zeros((0, 5), dtype=numpy.float32)
            except AssertionError:
                continue
            x[filename] = label
        return x


def wh2xy(x, w, h, pad_w=0, pad_h=0):
    # Convert nx4 boxes
    # from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = numpy.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + pad_w  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + pad_h  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + pad_w  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + pad_h  # bottom right y
    return y


def xy2wh(x, w, h):
    # warning: inplace clip
    x[:, [0, 2]] = x[:, [0, 2]].clip(0, w - 1E-3)  # x1, x2
    x[:, [1, 3]] = x[:, [1, 3]].clip(0, h - 1E-3)  # y1, y2

    # Convert nx4 boxes
    # from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
    y = numpy.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
    return y


def resample():
    choices = (cv2.INTER_AREA,
               cv2.INTER_CUBIC,
               cv2.INTER_LINEAR,
               cv2.INTER_NEAREST,
               cv2.INTER_LANCZOS4)
    return random.choice(seq=choices)


def augment_hsv(image, params):
    # HSV color-space augmentation
    h = params['hsv_h']
    s = params['hsv_s']
    v = params['hsv_v']

    r = numpy.random.uniform(-1, 1, 3) * [h, s, v] + 1
    h, s, v = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2HSV))

    x = numpy.arange(0, 256, dtype=r.dtype)
    lut_h = ((x * r[0]) % 180).astype('uint8')
    lut_s = numpy.clip(x * r[1], 0, 255).astype('uint8')
    lut_v = numpy.clip(x * r[2], 0, 255).astype('uint8')

    hsv = cv2.merge((cv2.LUT(h, lut_h), cv2.LUT(s, lut_s), cv2.LUT(v, lut_v)))
    cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR, dst=image)  # no return needed


def resize(image, input_size, augment):
    # Resize and pad image while meeting stride-multiple constraints
    shape = image.shape[:2]  # current shape [height, width]

    # Scale ratio (new / old)
    r = min(input_size / shape[0], input_size / shape[1])
    if not augment:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    pad = int(round(shape[1] * r)), int(round(shape[0] * r))
    w = (input_size - pad[0]) / 2
    h = (input_size - pad[1]) / 2

    if shape[::-1] != pad:  # resize
        image = cv2.resize(image,
                           dsize=pad,
                           interpolation=resample() if augment else cv2.INTER_LINEAR)
    top, bottom = int(round(h - 0.1)), int(round(h + 0.1))
    left, right = int(round(w - 0.1)), int(round(w + 0.1))
    image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT)  # add border
    return image, (r, r), (w, h)


def candidates(box1, box2):
    # box1(4,n), box2(4,n)
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    aspect_ratio = numpy.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))  # aspect ratio
    return (w2 > 2) & (h2 > 2) & (w2 * h2 / (w1 * h1 + 1e-16) > 0.1) & (aspect_ratio < 100)


def random_perspective(image, label, params, border=(0, 0)):
    h = image.shape[0] + border[0] * 2
    w = image.shape[1] + border[1] * 2

    # Center
    center = numpy.eye(3)
    center[0, 2] = -image.shape[1] / 2  # x translation (pixels)
    center[1, 2] = -image.shape[0] / 2  # y translation (pixels)

    # Perspective
    perspective = numpy.eye(3)

    # Rotation and Scale
    rotate = numpy.eye(3)
    a = random.uniform(-params['degrees'], params['degrees'])
    s = random.uniform(1 - params['scale'], 1 + params['scale'])
    rotate[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    shear = numpy.eye(3)
    shear[0, 1] = math.tan(random.uniform(-params['shear'], params['shear']) * math.pi / 180)
    shear[1, 0] = math.tan(random.uniform(-params['shear'], params['shear']) * math.pi / 180)

    # Translation
    translate = numpy.eye(3)
    translate[0, 2] = random.uniform(0.5 - params['translate'], 0.5 + params['translate']) * w
    translate[1, 2] = random.uniform(0.5 - params['translate'], 0.5 + params['translate']) * h

    # Combined rotation matrix, order of operations (right to left) is IMPORTANT
    matrix = translate @ shear @ rotate @ perspective @ center
    if (border[0] != 0) or (border[1] != 0) or (matrix != numpy.eye(3)).any():  # image changed
        image = cv2.warpAffine(image, matrix[:2], dsize=(w, h), borderValue=(0, 0, 0))

    # Transform label coordinates
    n = len(label)
    if n:
        xy = numpy.ones((n * 4, 3))
        xy[:, :2] = label[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ matrix.T  # transform
        xy = xy[:, :2].reshape(n, 8)  # perspective rescale or affine

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        box = numpy.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # clip
        box[:, [0, 2]] = box[:, [0, 2]].clip(0, w)
        box[:, [1, 3]] = box[:, [1, 3]].clip(0, h)
        # filter candidates
        indices = candidates(box1=label[:, 1:5].T * s, box2=box.T)

        label = label[indices]
        label[:, 1:5] = box[indices]

    return image, label


def mix_up(image1, label1, image2, label2):
    # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
    alpha = numpy.random.beta(a=32.0, b=32.0)  # mix-up ratio, alpha=beta=32.0
    image = (image1 * alpha + image2 * (1 - alpha)).astype(numpy.uint8)
    label = numpy.concatenate((label1, label2), 0)
    return image, label


class Albumentations:
    def __init__(self):
        self.transform = None
        try:
            import albumentations

            transforms = [albumentations.Blur(p=0.01),
                          albumentations.CLAHE(p=0.01),
                          albumentations.ToGray(p=0.01),
                          albumentations.MedianBlur(p=0.01)]
            self.transform = albumentations.Compose(transforms,
                                                    albumentations.BboxParams('yolo', ['class_labels']))

        except ImportError:  # package not installed, skip
            pass

    def __call__(self, image, box, cls):
        if self.transform:
            x = self.transform(image=image,
                               bboxes=box,
                               class_labels=cls)
            image = x['image']
            box = numpy.array(x['bboxes'])
            cls = numpy.array(x['class_labels'])
        return image, box, cls

In [31]:
# PARAMS
params = {
    'min_lr': 0.0001,
    'max_lr': 0.01,
    'momentum': 0.937,
    'weight_decay': 0.0005,
    'warmup_epochs': 3.0,
    'box': 7.5,
    'cls': 0.5,
    'dfl': 1.5,

    # --- Tắt augmentation ---
    'hsv_h': 0.0,
    'hsv_s': 0.0,
    'hsv_v': 0.0,
    'degrees': 0.0,
    'translate': 0.0,
    'scale': 1.0,
    'shear': 0.0,
    'flip_ud': 0.0,
    'flip_lr': 0.0,
    'mosaic': 0.0,
    'mix_up': 0.0,

    # --- Dataset ---
    'nc': 5,
    'names': ['Elephant', 'Giraffe', 'Leopard', 'Lion', 'Zebra']
}

In [32]:
train_dir = os.path.join(DATASET_PATH, "train", "images")

filenames_train = []
for filename in os.listdir(train_dir):
    if filename.endswith(('.jpg', '.png', '.jpeg')):
        filenames_train.append(os.path.join(train_dir, filename))

input_size = 640

# Tạo Dataset cho tập train
train_data = Dataset(
    filenames_train,
    input_size,
    params,   # đã được định nghĩa ở cell trước
    augment=False   # False = không dùng augmentation
)

# DataLoader
train_loader = DataLoader(
    train_data,
    batch_size=16,
    num_workers=0,
    pin_memory=True,
    collate_fn=Dataset.collate_fn
)

print(f"Train_loader : {len(train_loader)} batches")

Train_loader : 562 batches


Got processor for bboxes, but no transform to process it.


In [33]:
batch=next(iter(train_loader))
print("All keys in batch      : ", batch[1].keys())
print(f"Input batch shape      : ", batch[0].shape)
print(f"Classification scores  : {batch[1]['cls'].shape}")
print(f"Box coordinates        : {batch[1]['box'].shape}")
print(f"Index identifier (which score belongs to which image): {batch[1]['idx'].shape}")

All keys in batch      :  dict_keys(['cls', 'box', 'idx'])
Input batch shape      :  torch.Size([16, 3, 640, 640])
Classification scores  : torch.Size([26, 1])
Box coordinates        : torch.Size([26, 4])
Index identifier (which score belongs to which image): torch.Size([26])


In [None]:
from tqdm import tqdm
torch.manual_seed(1337)

# model, loss and optimizer
model = MyYolo(version='n')
print(f"{sum(p.numel() for p in model.parameters())/1e6} million parameters")
print(f"Number of classes (nc): {model.nc}")

criterion = ComputeLoss(model, params)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

num_epochs = 10

model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for i, (imgs, targets) in pbar:
        imgs = imgs.float()
        outputs = model(imgs)
        
        # unpack loss
        box_loss, cls_loss, dfl_loss = criterion(outputs, targets)

        print(f"Batch {i}: cls={cls_loss.item()}, box={box_loss.item()}, dfl={dfl_loss.item()}")
        if cls_loss.item() == 0 or dfl_loss.item() == 0:
            print(">>> Targets:", targets)
            break  # hoặc tiếp tục nếu muốn kiểm tra nhiều batch
        
        loss = cls_loss + box_loss + dfl_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()

        # update progress bar
        pbar.set_postfix({
            "Total": f"{loss.item():.4f}",
            "Cls": f"{cls_loss.item():.4f}",
            "Box": f"{box_loss.item():.4f}",
            "DFL": f"{dfl_loss.item():.4f}"
        })
    
    print(f"Epoch {epoch+1}/{num_epochs} | Avg Loss: {epoch_loss/len(train_loader):.4f}")

2.662425 million parameters
Number of classes (nc): 5


Epoch 1/10:   0%|          | 0/562 [00:00<?, ?it/s]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(26,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.115385 sum=185.000000 nnz=25/26
  sample: [0.0, 1.0, 2.0, 3.0, 3.0]
[targets['cls']] shape=(26, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.538462 sum=40.000000 nnz=19/26
  sample: [0.0, 1.0, 0.0, 4.0, 4.0]
[targets['box']] shape=(26, 4) dtype=torch.float32 device=cpu min=0.091406 max=0.999998 mean=0.538792 sum=56.034370 nnz=104/104
  sample: [0.543749988079071, 0.514843761920929, 0.909375011920929, 0.965624988079071, 0.4242187440395355]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-8.163302 max=11.168692 mean=-0.165634 sum=-111305.796875 nnz=672000/672000
  sample: [-0.4034656882286072, -0.2890910804271698, -0.13177458941936493, -0.189820

torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /pytorch/aten/src/ATen/native/TensorShape.cpp:3637.)



[pred_bboxes(decoded)] shape=(16, 8400, 4) dtype=torch.float32 device=cpu min=-11.639657 max=90.763435 mean=34.838608 sum=18729236.000000 nnz=537600/537600
  sample: [-6.97528600692749, -7.008828163146973, 8.085697174072266, 7.9957709312438965, -6.034775733947754]


pd_scores: (16, 8400, 5)  (B, A, C)
pd_bboxes: (16, 8400, 4)  (B, A, 4)
anc_points: (8400, 2) (A, 2)
gt_labels: (16, 5, 1)  (B, M, 1)
gt_bboxes: (16, 5, 4)  (B, M, 4)
mask_gt sum: 26.0
[gt_labels uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


mask_in_gts sum: 74912.0


[bbox_scores (selected)] shape=(74912,) dtype=torch.float32 device=cpu min=0.001540 max=0.994088 mean=0.456062 sum=34164.500000 nnz=74912/74912
  sample: [0.49897482991218567, 0.5829175710678101, 0.44716718792915344, 0.44952020049095154, 0.4654175639152527]
[overlaps (selected)] shape=(74912,) dtype=torch.float32 device=cpu min=0.000000 max=0.923846 mean=0.066017 sum=4945.435547 nnz=43236/74912
  sample: [0.0, 0.0, 0.0, 0.0, 0.0]


[align_metric(all)] shape=(1

Epoch 1/10:   0%|          | 1/562 [00:07<1:13:58,  7.91s/it, Total=1594.1926, Cls=1587.3511, Box=2.5981, DFL=4.2435]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(19,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=8.105263 sum=154.000000 nnz=18/19
  sample: [0.0, 1.0, 2.0, 3.0, 4.0]
[targets['cls']] shape=(19, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.105263 sum=40.000000 nnz=18/19
  sample: [1.0, 4.0, 1.0, 3.0, 4.0]
[targets['box']] shape=(19, 4) dtype=torch.float32 device=cpu min=0.092188 max=0.999998 mean=0.541745 sum=41.172642 nnz=76/76
  sample: [0.42109376192092896, 0.6328117251396179, 0.8421875238418579, 0.7343734502792358, 0.559374988079071]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-7.428193 max=10.100890 mean=-0.172083 sum=-115639.640625 nnz=672000/672000
  sample: [-0.38518571853637695, -0.35229185223579407, -0.30055686831474304, -0.22

Epoch 1/10:   0%|          | 2/562 [00:13<1:00:40,  6.50s/it, Total=2016.8875, Cls=2010.2894, Box=2.3600, DFL=4.2381]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(17,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.588235 sum=129.000000 nnz=16/17
  sample: [0.0, 1.0, 2.0, 3.0, 4.0]
[targets['cls']] shape=(17, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.647059 sum=28.000000 nnz=13/17
  sample: [1.0, 3.0, 2.0, 4.0, 3.0]
[targets['box']] shape=(17, 4) dtype=torch.float32 device=cpu min=0.209375 max=0.999998 mean=0.581583 sum=39.547649 nnz=68/68
  sample: [0.4281249940395355, 0.53125, 0.43437498807907104, 0.596875011920929, 0.49609375]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-8.889069 max=16.134869 mean=-0.181981 sum=-122291.281250 nnz=672000/672000
  sample: [-0.4038163423538208, -0.43105393648147583, -0.3366694450378418, -0.22900769114494324, -0.0

Epoch 1/10:   1%|          | 3/562 [00:18<53:10,  5.71s/it, Total=2241.7019, Cls=2234.9246, Box=2.5981, DFL=4.1792]  


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(26,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=6.500000 sum=169.000000 nnz=24/26
  sample: [0.0, 0.0, 1.0, 2.0, 3.0]
[targets['cls']] shape=(26, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=3.000000 sum=78.000000 nnz=24/26
  sample: [4.0, 4.0, 2.0, 0.0, 4.0]
[targets['box']] shape=(26, 4) dtype=torch.float32 device=cpu min=0.081250 max=0.998438 mean=0.531460 sum=55.271873 nnz=104/104
  sample: [0.504687488079071, 0.5023437738418579, 0.984375, 0.9750000238418579, 0.5234375]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-8.113626 max=16.579485 mean=-0.188333 sum=-126559.460938 nnz=672000/672000
  sample: [-0.3715493977069855, -0.2806432545185089, -0.1836024671792984, -0.22017450630664825, 0.07

Epoch 1/10:   1%|          | 4/562 [00:22<47:57,  5.16s/it, Total=1578.9387, Cls=1572.0312, Box=2.7221, DFL=4.1853]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(19,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=6.789474 sum=129.000000 nnz=18/19
  sample: [0.0, 1.0, 1.0, 1.0, 2.0]
[targets['cls']] shape=(19, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.684211 sum=32.000000 nnz=16/19
  sample: [2.0, 1.0, 1.0, 1.0, 1.0]
[targets['box']] shape=(19, 4) dtype=torch.float32 device=cpu min=0.085938 max=0.998438 mean=0.606764 sum=46.114044 nnz=76/76
  sample: [0.47187501192092896, 0.4984374940395355, 0.7562500238418579, 0.996874988079071, 0.42500001192092896]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-11.683716 max=21.389832 mean=-0.192314 sum=-129234.671875 nnz=672000/672000
  sample: [-0.3442990183830261, -0.25384920835494995, -0.22569595277309418, -0.2

Epoch 1/10:   1%|          | 5/562 [00:26<45:17,  4.88s/it, Total=1992.5341, Cls=1985.6359, Box=2.7150, DFL=4.1832]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(24,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.666667 sum=184.000000 nnz=23/24
  sample: [0.0, 1.0, 2.0, 2.0, 2.0]
[targets['cls']] shape=(24, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.875000 sum=69.000000 nnz=21/24
  sample: [1.0, 3.0, 4.0, 4.0, 4.0]
[targets['box']] shape=(24, 4) dtype=torch.float32 device=cpu min=0.100000 max=0.999998 mean=0.530062 sum=50.885929 nnz=96/96
  sample: [0.4085937440395355, 0.5195304751396179, 0.542187511920929, 0.9609359502792358, 0.3179687559604645]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-11.462031 max=24.948408 mean=-0.196233 sum=-131868.750000 nnz=672000/672000
  sample: [-0.342987596988678, -0.2532123923301697, -0.21465465426445007, -0.27299

Epoch 1/10:   1%|          | 6/562 [00:30<42:47,  4.62s/it, Total=1756.1455, Cls=1749.1476, Box=2.5304, DFL=4.4675]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(22,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.909091 sum=174.000000 nnz=21/22
  sample: [0.0, 1.0, 2.0, 3.0, 3.0]
[targets['cls']] shape=(22, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.863636 sum=41.000000 nnz=18/22
  sample: [2.0, 2.0, 0.0, 0.0, 0.0]
[targets['box']] shape=(22, 4) dtype=torch.float32 device=cpu min=0.086719 max=0.999998 mean=0.542791 sum=47.765617 nnz=88/88
  sample: [0.546875, 0.3414062559604645, 0.651562511920929, 0.567187488079071, 0.5179687738418579]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-12.570154 max=26.040228 mean=-0.200149 sum=-134500.437500 nnz=672000/672000
  sample: [-0.3505748510360718, -0.26628997921943665, -0.21652087569236755, -0.26222741603851

Epoch 1/10:   1%|          | 7/562 [00:35<42:12,  4.56s/it, Total=1597.4736, Cls=1590.7538, Box=2.5385, DFL=4.1812]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(23,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=6.869565 sum=158.000000 nnz=22/23
  sample: [0.0, 1.0, 2.0, 2.0, 2.0]
[targets['cls']] shape=(23, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.956522 sum=45.000000 nnz=21/23
  sample: [4.0, 2.0, 4.0, 4.0, 4.0]
[targets['box']] shape=(23, 4) dtype=torch.float32 device=cpu min=0.084375 max=0.999998 mean=0.533526 sum=49.084358 nnz=92/92
  sample: [0.3734374940395355, 0.484375, 0.746874988079071, 0.6343749761581421, 0.520312488079071]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-11.556610 max=33.798119 mean=-0.204480 sum=-137410.781250 nnz=672000/672000
  sample: [-0.4153488278388977, -0.4115291237831116, -0.24972297251224518, -0.165999948978424

Epoch 1/10:   1%|▏         | 8/562 [00:39<41:00,  4.44s/it, Total=1621.3571, Cls=1614.5602, Box=2.6276, DFL=4.1694]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(24,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=6.708333 sum=161.000000 nnz=22/24
  sample: [0.0, 0.0, 1.0, 1.0, 2.0]
[targets['cls']] shape=(24, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.458333 sum=59.000000 nnz=22/24
  sample: [3.0, 3.0, 0.0, 0.0, 1.0]
[targets['box']] shape=(24, 4) dtype=torch.float32 device=cpu min=0.043750 max=0.996875 mean=0.508227 sum=48.789837 nnz=96/96
  sample: [0.28125, 0.48046875, 0.53125, 0.6171875, 0.75]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-11.524751 max=22.263109 mean=-0.207640 sum=-139533.906250 nnz=672000/672000
  sample: [-0.34825634956359863, -0.268350213766098, -0.22251857817173004, -0.26834312081336975, 0.04859928414225578]
[pred_scores(sig

Epoch 1/10:   2%|▏         | 9/562 [00:43<40:10,  4.36s/it, Total=1577.7734, Cls=1570.9409, Box=2.6656, DFL=4.1668]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(21,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.142857 sum=150.000000 nnz=20/21
  sample: [0.0, 1.0, 2.0, 3.0, 3.0]
[targets['cls']] shape=(21, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.000000 sum=42.000000 nnz=15/21
  sample: [2.0, 4.0, 2.0, 1.0, 1.0]
[targets['box']] shape=(21, 4) dtype=torch.float32 device=cpu min=0.135937 max=0.999998 mean=0.562984 sum=47.290615 nnz=84/84
  sample: [0.538281261920929, 0.49609375, 0.918749988079071, 0.979687511920929, 0.4906249940395355]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-18.953283 max=29.993456 mean=-0.213807 sum=-143678.000000 nnz=672000/672000
  sample: [-0.34663593769073486, -0.261007159948349, -0.21952170133590698, -0.27019202709198

Epoch 1/10:   2%|▏         | 10/562 [00:47<39:25,  4.28s/it, Total=1724.2809, Cls=1717.5409, Box=2.6028, DFL=4.1371]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(20,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.600000 sum=152.000000 nnz=19/20
  sample: [0.0, 1.0, 2.0, 3.0, 3.0]
[targets['cls']] shape=(20, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.650000 sum=53.000000 nnz=19/20
  sample: [1.0, 2.0, 3.0, 1.0, 1.0]
[targets['box']] shape=(20, 4) dtype=torch.float32 device=cpu min=0.112500 max=0.999998 mean=0.541865 sum=43.349205 nnz=80/80
  sample: [0.43281251192092896, 0.524217963218689, 0.8656250238418579, 0.9515609741210938, 0.40312498807907104]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-13.789505 max=21.985941 mean=-0.215475 sum=-144799.296875 nnz=672000/672000
  sample: [-0.3480251133441925, -0.26703885197639465, -0.22491027414798737, -0.2

Epoch 1/10:   2%|▏         | 11/562 [00:51<37:40,  4.10s/it, Total=1721.5735, Cls=1714.8988, Box=2.5422, DFL=4.1325]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(23,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=8.130435 sum=187.000000 nnz=22/23
  sample: [0.0, 1.0, 1.0, 2.0, 3.0]
[targets['cls']] shape=(23, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.260870 sum=29.000000 nnz=13/23
  sample: [1.0, 1.0, 1.0, 2.0, 2.0]
[targets['box']] shape=(23, 4) dtype=torch.float32 device=cpu min=0.055469 max=0.999998 mean=0.532549 sum=48.994526 nnz=92/92
  sample: [0.5757812261581421, 0.550000011920929, 0.7593749761581421, 0.7406250238418579, 0.4859375059604645]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-7.213044 max=16.943451 mean=-0.216500 sum=-145487.890625 nnz=672000/672000
  sample: [-0.4128319323062897, -0.40677741169929504, -0.22677989304065704, -0.1768

Epoch 1/10:   2%|▏         | 12/562 [00:55<36:46,  4.01s/it, Total=1741.0863, Cls=1733.9650, Box=2.9724, DFL=4.1489]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(20,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.250000 sum=145.000000 nnz=19/20
  sample: [0.0, 1.0, 2.0, 3.0, 4.0]
[targets['cls']] shape=(20, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.800000 sum=36.000000 nnz=16/20
  sample: [1.0, 3.0, 3.0, 4.0, 0.0]
[targets['box']] shape=(20, 4) dtype=torch.float32 device=cpu min=0.104687 max=0.999998 mean=0.576992 sum=46.159363 nnz=80/80
  sample: [0.37109375, 0.5257804989814758, 0.7421875, 0.9484359622001648, 0.5468742251396179]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-20.079485 max=36.927013 mean=-0.223332 sum=-150078.812500 nnz=672000/672000
  sample: [-0.4117221534252167, -0.42663145065307617, -0.24891811609268188, -0.1757635623216629, 0

Epoch 1/10:   2%|▏         | 13/562 [00:59<36:12,  3.96s/it, Total=1735.1504, Cls=1728.4312, Box=2.6024, DFL=4.1169]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(19,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=8.315789 sum=158.000000 nnz=18/19
  sample: [0.0, 1.0, 2.0, 3.0, 4.0]
[targets['cls']] shape=(19, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.526316 sum=48.000000 nnz=16/19
  sample: [2.0, 3.0, 2.0, 2.0, 4.0]
[targets['box']] shape=(19, 4) dtype=torch.float32 device=cpu min=0.093750 max=0.999998 mean=0.594963 sum=45.217175 nnz=76/76
  sample: [0.4671874940395355, 0.4789062440395355, 0.721875011920929, 0.7250000238418579, 0.5171874761581421]
[targets['cls'] uniq] unique(4): [0.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-23.896620 max=45.436630 mean=-0.227555 sum=-152916.875000 nnz=672000/672000
  sample: [-0.35321083664894104, -0.27438488602638245, -0.2318626046180725, -0.27194783

Epoch 1/10:   2%|▏         | 14/562 [01:03<35:47,  3.92s/it, Total=1848.8267, Cls=1842.3849, Box=2.3300, DFL=4.1118]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(25,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=8.040000 sum=201.000000 nnz=23/25
  sample: [0.0, 0.0, 1.0, 2.0, 3.0]
[targets['cls']] shape=(25, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.360000 sum=59.000000 nnz=24/25
  sample: [1.0, 1.0, 2.0, 1.0, 0.0]
[targets['box']] shape=(25, 4) dtype=torch.float32 device=cpu min=0.085938 max=0.999998 mean=0.530031 sum=53.003113 nnz=100/100
  sample: [0.512499988079071, 0.44218748807907104, 0.21562500298023224, 0.3453125059604645, 0.2718749940395355]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-20.864254 max=38.699741 mean=-0.227000 sum=-152544.156250 nnz=672000/672000
  sample: [-0.4093879163265228, -0.42198312282562256, -0.24012410640716553, -0

Epoch 1/10:   3%|▎         | 15/562 [01:06<35:40,  3.91s/it, Total=1442.6899, Cls=1435.8846, Box=2.6844, DFL=4.1209]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(18,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.000000 sum=126.000000 nnz=17/18
  sample: [0.0, 1.0, 2.0, 3.0, 3.0]
[targets['cls']] shape=(18, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.222222 sum=40.000000 nnz=16/18
  sample: [2.0, 1.0, 4.0, 4.0, 4.0]
[targets['box']] shape=(18, 4) dtype=torch.float32 device=cpu min=0.187500 max=0.999998 mean=0.580046 sum=41.763279 nnz=72/72
  sample: [0.4164062440395355, 0.672656238079071, 0.707812488079071, 0.5718749761581421, 0.55859375]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-14.215843 max=27.495424 mean=-0.228990 sum=-153881.531250 nnz=672000/672000
  sample: [-0.35717013478279114, -0.2791576683521271, -0.22867777943611145, -0.268201470375

Epoch 1/10:   3%|▎         | 16/562 [01:10<35:28,  3.90s/it, Total=2010.6237, Cls=2003.8062, Box=2.6822, DFL=4.1354]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(26,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=6.076923 sum=158.000000 nnz=24/26
  sample: [0.0, 0.0, 1.0, 2.0, 2.0]
[targets['cls']] shape=(26, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.653846 sum=43.000000 nnz=24/26
  sample: [1.0, 1.0, 3.0, 1.0, 1.0]
[targets['box']] shape=(26, 4) dtype=torch.float32 device=cpu min=0.069531 max=0.981248 mean=0.526615 sum=54.767956 nnz=104/104
  sample: [0.41484373807907104, 0.4789062440395355, 0.8296874761581421, 0.875, 0.680468738079071]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-14.536716 max=38.107899 mean=-0.234105 sum=-157318.734375 nnz=672000/672000
  sample: [-0.3997673988342285, -0.3775644302368164, -0.20238862931728363, -0.19680473208427

Epoch 1/10:   3%|▎         | 17/562 [01:14<35:36,  3.92s/it, Total=1564.3118, Cls=1557.4113, Box=2.8002, DFL=4.1002]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(16,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.500000 sum=120.000000 nnz=15/16
  sample: [0.0, 1.0, 2.0, 3.0, 4.0]
[targets['cls']] shape=(16, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.062500 sum=33.000000 nnz=15/16
  sample: [1.0, 3.0, 2.0, 2.0, 1.0]
[targets['box']] shape=(16, 4) dtype=torch.float32 device=cpu min=0.185937 max=0.996875 mean=0.595996 sum=38.143749 nnz=64/64
  sample: [0.6312500238418579, 0.5390625, 0.35624995827674866, 0.6796875, 0.4507812559604645]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-17.359341 max=51.986168 mean=-0.236797 sum=-159127.781250 nnz=672000/672000
  sample: [-0.38627591729164124, -0.34390339255332947, -0.2042395919561386, -0.22469881176948547, 

Epoch 1/10:   3%|▎         | 18/562 [01:18<35:20,  3.90s/it, Total=2175.3699, Cls=2168.5247, Box=2.7439, DFL=4.1013]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(20,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=6.750000 sum=135.000000 nnz=19/20
  sample: [0.0, 1.0, 2.0, 3.0, 3.0]
[targets['cls']] shape=(20, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.700000 sum=34.000000 nnz=15/20
  sample: [2.0, 3.0, 1.0, 1.0, 1.0]
[targets['box']] shape=(20, 4) dtype=torch.float32 device=cpu min=0.084375 max=0.999998 mean=0.574736 sum=45.978889 nnz=80/80
  sample: [0.617968738079071, 0.47578126192092896, 0.7593749761581421, 0.836718738079071, 0.38671875]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-17.546007 max=48.447388 mean=-0.241060 sum=-161992.421875 nnz=672000/672000
  sample: [-0.3667796850204468, -0.28310996294021606, -0.2155599445104599, -0.260915219783

Epoch 1/10:   3%|▎         | 19/562 [01:22<34:36,  3.82s/it, Total=1797.3412, Cls=1790.8201, Box=2.4375, DFL=4.0836]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(24,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=8.791667 sum=211.000000 nnz=23/24
  sample: [0.0, 1.0, 2.0, 3.0, 4.0]
[targets['cls']] shape=(24, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.250000 sum=54.000000 nnz=20/24
  sample: [4.0, 1.0, 1.0, 4.0, 4.0]
[targets['box']] shape=(24, 4) dtype=torch.float32 device=cpu min=0.110937 max=0.999998 mean=0.535555 sum=51.413273 nnz=96/96
  sample: [0.524218738079071, 0.5296875238418579, 0.885937511920929, 0.9375, 0.4398437440395355]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-15.789781 max=36.187393 mean=-0.243707 sum=-163770.812500 nnz=672000/672000
  sample: [-0.3708961308002472, -0.3281770646572113, -0.2398737221956253, -0.27850860357284546,

Epoch 1/10:   4%|▎         | 20/562 [01:26<34:38,  3.83s/it, Total=1455.7450, Cls=1449.0859, Box=2.5684, DFL=4.0907]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(22,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.090909 sum=156.000000 nnz=19/22
  sample: [0.0, 0.0, 0.0, 1.0, 2.0]
[targets['cls']] shape=(22, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.909091 sum=42.000000 nnz=18/22
  sample: [4.0, 4.0, 4.0, 0.0, 3.0]
[targets['box']] shape=(22, 4) dtype=torch.float32 device=cpu min=0.078125 max=0.996875 mean=0.557786 sum=49.085152 nnz=88/88
  sample: [0.34375, 0.515625, 0.6859375238418579, 0.5249999761581421, 0.4203124940395355]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-14.173998 max=28.078943 mean=-0.246287 sum=-165504.562500 nnz=672000/672000
  sample: [-0.390545517206192, -0.4720116853713989, -0.29714635014533997, -0.2177928239107132, 0.03943

Epoch 1/10:   4%|▎         | 21/562 [01:30<34:36,  3.84s/it, Total=1735.9592, Cls=1729.0696, Box=2.7843, DFL=4.1054]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(24,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=6.041667 sum=145.000000 nnz=23/24
  sample: [0.0, 1.0, 2.0, 2.0, 2.0]
[targets['cls']] shape=(24, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=0.958333 sum=23.000000 nnz=10/24
  sample: [3.0, 0.0, 0.0, 0.0, 0.0]
[targets['box']] shape=(24, 4) dtype=torch.float32 device=cpu min=0.064844 max=0.999998 mean=0.543937 sum=52.217960 nnz=96/96
  sample: [0.49609375, 0.53515625, 0.9921875, 0.9140625, 0.49921876192092896]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-16.508114 max=34.319263 mean=-0.249451 sum=-167631.328125 nnz=672000/672000
  sample: [-0.36232438683509827, -0.27496153116226196, -0.21497580409049988, -0.27650031447410583, 0.0488505624234

Epoch 1/10:   4%|▍         | 22/562 [01:34<35:09,  3.91s/it, Total=1532.7039, Cls=1525.7155, Box=2.9044, DFL=4.0840]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(20,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.800000 sum=156.000000 nnz=19/20
  sample: [0.0, 1.0, 2.0, 3.0, 4.0]
[targets['cls']] shape=(20, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.550000 sum=31.000000 nnz=14/20
  sample: [4.0, 2.0, 4.0, 1.0, 0.0]
[targets['box']] shape=(20, 4) dtype=torch.float32 device=cpu min=0.146875 max=0.999998 mean=0.543974 sum=43.517956 nnz=80/80
  sample: [0.4999992251396179, 0.4999992251396179, 0.9999984502792358, 0.9999984502792358, 0.5375000238418579]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-14.311651 max=29.816950 mean=-0.252236 sum=-169502.656250 nnz=672000/672000
  sample: [-0.3792431652545929, -0.3808777928352356, -0.2448544055223465, -0.2498

Epoch 1/10:   4%|▍         | 23/562 [01:37<34:58,  3.89s/it, Total=1697.2689, Cls=1690.5640, Box=2.6287, DFL=4.0763]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(19,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=8.000000 sum=152.000000 nnz=18/19
  sample: [0.0, 1.0, 2.0, 3.0, 4.0]
[targets['cls']] shape=(19, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.368421 sum=45.000000 nnz=17/19
  sample: [0.0, 1.0, 4.0, 2.0, 4.0]
[targets['box']] shape=(19, 4) dtype=torch.float32 device=cpu min=0.100781 max=0.999998 mean=0.567835 sum=43.155453 nnz=76/76
  sample: [0.5874999761581421, 0.484375, 0.512499988079071, 0.835156261920929, 0.5382804870605469]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-14.440008 max=28.614769 mean=-0.256499 sum=-172367.375000 nnz=672000/672000
  sample: [-0.3796234428882599, -0.47945162653923035, -0.312490314245224, -0.2387837171554565

Epoch 1/10:   4%|▍         | 24/562 [01:42<35:36,  3.97s/it, Total=1849.8157, Cls=1843.0720, Box=2.6993, DFL=4.0443]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(25,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.080000 sum=177.000000 nnz=24/25
  sample: [0.0, 1.0, 1.0, 1.0, 2.0]
[targets['cls']] shape=(25, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.520000 sum=38.000000 nnz=16/25
  sample: [3.0, 0.0, 0.0, 0.0, 1.0]
[targets['box']] shape=(25, 4) dtype=torch.float32 device=cpu min=0.087500 max=0.996875 mean=0.500680 sum=50.067955 nnz=100/100
  sample: [0.6031249761581421, 0.532031238079071, 0.7875000238418579, 0.8531249761581421, 0.66796875]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-12.446234 max=25.616982 mean=-0.261580 sum=-175781.500000 nnz=672000/672000
  sample: [-0.3759429156780243, -0.4455881118774414, -0.28123998641967773, -0.2517534494

Epoch 1/10:   4%|▍         | 25/562 [01:46<35:39,  3.99s/it, Total=1459.6965, Cls=1452.8334, Box=2.8032, DFL=4.0599]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(23,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.260870 sum=167.000000 nnz=22/23
  sample: [0.0, 1.0, 1.0, 1.0, 2.0]
[targets['cls']] shape=(23, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.782609 sum=41.000000 nnz=16/23
  sample: [0.0, 0.0, 0.0, 0.0, 4.0]
[targets['box']] shape=(23, 4) dtype=torch.float32 device=cpu min=0.073436 max=0.999998 mean=0.561413 sum=51.649986 nnz=92/92
  sample: [0.51953125, 0.55859375, 0.620312511920929, 0.8773437738418579, 0.2562499940395355]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-13.868507 max=26.673128 mean=-0.261511 sum=-175735.531250 nnz=672000/672000
  sample: [-0.37008512020111084, -0.48861658573150635, -0.32245075702667236, -0.268458753824234, 0

Epoch 1/10:   5%|▍         | 26/562 [01:50<35:30,  3.97s/it, Total=1523.2517, Cls=1516.2457, Box=2.9121, DFL=4.0939]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(21,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=6.523809 sum=137.000000 nnz=20/21
  sample: [0.0, 1.0, 2.0, 3.0, 3.0]
[targets['cls']] shape=(21, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.285714 sum=48.000000 nnz=19/21
  sample: [1.0, 2.0, 1.0, 4.0, 4.0]
[targets['box']] shape=(21, 4) dtype=torch.float32 device=cpu min=0.081250 max=0.993748 mean=0.570564 sum=47.927338 nnz=84/84
  sample: [0.503125011920929, 0.46796876192092896, 0.3687499761581421, 0.5921875238418579, 0.5523437261581421]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-11.303843 max=21.596521 mean=-0.267244 sum=-179588.187500 nnz=672000/672000
  sample: [-0.357407808303833, -0.5061233639717102, -0.3528450131416321, -0.29820

Epoch 1/10:   5%|▍         | 27/562 [01:54<35:34,  3.99s/it, Total=1609.2532, Cls=1602.9099, Box=2.3108, DFL=4.0324]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(33,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=5.666667 sum=187.000000 nnz=26/33
  sample: [0.0, 0.0, 0.0, 0.0, 0.0]
[targets['cls']] shape=(33, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.575758 sum=52.000000 nnz=25/33
  sample: [1.0, 1.0, 1.0, 1.0, 1.0]
[targets['box']] shape=(33, 4) dtype=torch.float32 device=cpu min=0.078125 max=0.988281 mean=0.451639 sum=59.616405 nnz=132/132
  sample: [0.39375001192092896, 0.518750011920929, 0.38749998807907104, 0.4546875059604645, 0.08749999850988388]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-8.271017 max=14.497089 mean=-0.267923 sum=-180044.343750 nnz=672000/672000
  sample: [-0.34810346364974976, -0.5441296696662903, -0.3895624876022339, -0.

Epoch 1/10:   5%|▍         | 28/562 [01:57<35:08,  3.95s/it, Total=1123.4913, Cls=1116.7383, Box=2.7296, DFL=4.0235]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(23,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.173913 sum=165.000000 nnz=21/23
  sample: [0.0, 0.0, 1.0, 2.0, 3.0]
[targets['cls']] shape=(23, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.130435 sum=49.000000 nnz=17/23
  sample: [0.0, 0.0, 3.0, 2.0, 0.0]
[targets['box']] shape=(23, 4) dtype=torch.float32 device=cpu min=0.159375 max=0.999998 mean=0.597987 sum=55.014835 nnz=92/92
  sample: [0.48906248807907104, 0.34140628576278687, 0.34843745827674866, 0.5843750238418579, 0.7523437738418579]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-10.252823 max=18.407978 mean=-0.274472 sum=-184445.187500 nnz=672000/672000
  sample: [-0.35479065775871277, -0.2631850242614746, -0.21940749883651733, -0

Epoch 1/10:   5%|▌         | 29/562 [02:01<34:37,  3.90s/it, Total=1510.3518, Cls=1503.8240, Box=2.4482, DFL=4.0797]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(20,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=8.000000 sum=160.000000 nnz=19/20
  sample: [0.0, 1.0, 2.0, 3.0, 4.0]
[targets['cls']] shape=(20, 1) dtype=torch.float32 device=cpu min=1.000000 max=4.000000 mean=3.000000 sum=60.000000 nnz=20/20
  sample: [2.0, 2.0, 4.0, 3.0, 3.0]
[targets['box']] shape=(20, 4) dtype=torch.float32 device=cpu min=0.165625 max=0.999998 mean=0.619902 sum=49.592182 nnz=80/80
  sample: [0.5757812261581421, 0.47734373807907104, 0.8460937738418579, 0.9546874761581421, 0.4124999940395355]
[targets['cls'] uniq] unique(4): [1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-8.657863 max=18.915436 mean=-0.277796 sum=-186679.218750 nnz=672000/672000
  sample: [-0.3571088910102844, -0.268420934677124, -0.22493235766887665, -0.309045821

Epoch 1/10:   5%|▌         | 30/562 [02:05<34:30,  3.89s/it, Total=1635.1572, Cls=1629.0804, Box=2.0692, DFL=4.0076]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(19,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=8.263158 sum=157.000000 nnz=18/19
  sample: [0.0, 1.0, 2.0, 3.0, 4.0]
[targets['cls']] shape=(19, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.210526 sum=42.000000 nnz=16/19
  sample: [2.0, 2.0, 2.0, 0.0, 1.0]
[targets['box']] shape=(19, 4) dtype=torch.float32 device=cpu min=0.087500 max=0.998438 mean=0.572882 sum=43.539059 nnz=76/76
  sample: [0.6148437261581421, 0.543749988079071, 0.762499988079071, 0.901562511920929, 0.578125]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-10.393990 max=26.464697 mean=-0.284978 sum=-191504.968750 nnz=672000/672000
  sample: [-0.3521030843257904, -0.2513987123966217, -0.22342470288276672, -0.3210511207580566

Epoch 1/10:   6%|▌         | 31/562 [02:09<33:54,  3.83s/it, Total=1829.6776, Cls=1823.2007, Box=2.4654, DFL=4.0114]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(19,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=6.368421 sum=121.000000 nnz=16/19
  sample: [0.0, 0.0, 0.0, 1.0, 1.0]
[targets['cls']] shape=(19, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=1.894737 sum=36.000000 nnz=14/19
  sample: [4.0, 4.0, 4.0, 0.0, 0.0]
[targets['box']] shape=(19, 4) dtype=torch.float32 device=cpu min=0.109375 max=0.999998 mean=0.579739 sum=44.060143 nnz=76/76
  sample: [0.768750011920929, 0.7132812738418579, 0.17812499403953552, 0.24687500298023224, 0.5921875238418579]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-7.680235 max=23.738459 mean=-0.285702 sum=-191992.000000 nnz=672000/672000
  sample: [-0.34480994939804077, -0.4165251851081848, -0.28669264912605286, -0.33

Epoch 1/10:   6%|▌         | 32/562 [02:13<34:03,  3.86s/it, Total=1726.1365, Cls=1719.9116, Box=2.2312, DFL=3.9937]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(23,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.478261 sum=172.000000 nnz=22/23
  sample: [0.0, 1.0, 2.0, 3.0, 3.0]
[targets['cls']] shape=(23, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.086957 sum=48.000000 nnz=17/23
  sample: [3.0, 3.0, 3.0, 0.0, 0.0]
[targets['box']] shape=(23, 4) dtype=torch.float32 device=cpu min=0.075000 max=0.989062 mean=0.505129 sum=46.471867 nnz=92/92
  sample: [0.5859375, 0.5453125238418579, 0.6703125238418579, 0.879687488079071, 0.4867187440395355]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-8.767775 max=29.540455 mean=-0.293132 sum=-196984.921875 nnz=672000/672000
  sample: [-0.34967610239982605, -0.24326559901237488, -0.22248384356498718, -0.330321341753

Epoch 1/10:   6%|▌         | 33/562 [02:17<33:51,  3.84s/it, Total=1515.0918, Cls=1508.8643, Box=2.2654, DFL=3.9620]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(29,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.448276 sum=216.000000 nnz=28/29
  sample: [0.0, 1.0, 2.0, 2.0, 2.0]
[targets['cls']] shape=(29, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.413793 sum=70.000000 nnz=26/29
  sample: [1.0, 2.0, 1.0, 1.0, 1.0]
[targets['box']] shape=(29, 4) dtype=torch.float32 device=cpu min=0.057813 max=0.999998 mean=0.508621 sum=58.999992 nnz=116/116
  sample: [0.5507804751396179, 0.5445305109024048, 0.8984359502792358, 0.9109359979629517, 0.836718738079071]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-7.280628 max=40.190544 mean=-0.295733 sum=-198732.406250 nnz=672000/672000
  sample: [-0.3528614044189453, -0.2760125398635864, -0.22301486134529114, -0.324

Epoch 1/10:   6%|▌         | 34/562 [02:20<33:50,  3.85s/it, Total=1178.9917, Cls=1172.2355, Box=2.7692, DFL=3.9871]


num.feature levels: 3
  L0 shape=(16, 69, 80, 80)
  L1 shape=(16, 69, 40, 40)
  L2 shape=(16, 69, 20, 20)
targets keys: ['cls', 'box', 'idx']
[targets['idx']] shape=(19,) dtype=torch.float32 device=cpu min=0.000000 max=15.000000 mean=7.526316 sum=143.000000 nnz=18/19
  sample: [0.0, 1.0, 2.0, 3.0, 4.0]
[targets['cls']] shape=(19, 1) dtype=torch.float32 device=cpu min=0.000000 max=4.000000 mean=2.157895 sum=41.000000 nnz=17/19
  sample: [1.0, 4.0, 2.0, 2.0, 2.0]
[targets['box']] shape=(19, 4) dtype=torch.float32 device=cpu min=0.118750 max=0.999998 mean=0.586081 sum=44.542171 nnz=76/76
  sample: [0.5882804989814758, 0.5687500238418579, 0.8234359622001648, 0.859375, 0.5304679870605469]
[targets['cls'] uniq] unique(5): [0.0, 1.0, 2.0, 3.0, 4.0]


[pred_scores(logits)] shape=(16, 8400, 5) dtype=torch.float32 device=cpu min=-12.245803 max=42.208099 mean=-0.299193 sum=-201057.937500 nnz=672000/672000
  sample: [-0.34517309069633484, -0.22251826524734497, -0.23873206973075867, -0.34721371531

Epoch 1/10:   6%|▌         | 34/562 [02:24<33:50,  3.85s/it, Total=1796.6754, Cls=1790.3151, Box=2.4314, DFL=3.9289]