In [2]:
# import
import math, os, random, cv2, numpy, torch
import torch.nn as nn

### (a) Conv 

In [4]:
class Conv(nn.Module):
    def __init__(self,in_channels, out_channels,kernel_size=3,stride=1,padding=1,groups=1,activation=True):
        super().__init__()
        self.conv=nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False,groups=groups)
        self.bn=nn.BatchNorm2d(out_channels,eps=0.001,momentum=0.03)
        self.act=nn.SiLU(inplace=True) if activation else nn.Identity()

    def forward(self,x):
        return self.act(self.bn(self.conv(x)))

### (b) C2f

In [5]:
# 2.1 Bottleneck: staack of 2 COnv with shortcut connnection (True/False)
class Bottleneck(nn.Module):
    def __init__(self,in_channels,out_channels,shortcut=True):
        super().__init__()
        self.conv1=Conv(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
        self.conv2=Conv(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
        self.shortcut=shortcut

    def forward(self,x):
        x_in=x # for residual connection
        x=self.conv1(x)
        x=self.conv2(x)
        if self.shortcut:
            x=x+x_in
        return x
    
# 2.2 C2f: Conv + bottleneck*N+ Conv
class C2f(nn.Module):
    def __init__(self,in_channels,out_channels, num_bottlenecks,shortcut=True):
        super().__init__()
        
        self.mid_channels=out_channels//2
        self.num_bottlenecks=num_bottlenecks

        self.conv1=Conv(in_channels,out_channels,kernel_size=1,stride=1,padding=0)
        
        # sequence of bottleneck layers
        self.m=nn.ModuleList([Bottleneck(self.mid_channels,self.mid_channels) for _ in range(num_bottlenecks)])

        self.conv2=Conv((num_bottlenecks+2)*out_channels//2,out_channels,kernel_size=1,stride=1,padding=0)
    
    def forward(self,x):
        x=self.conv1(x)

        # split x along channel dimension
        x1,x2=x[:,:x.shape[1]//2,:,:], x[:,x.shape[1]//2:,:,:]
        
        # list of outputs
        outputs=[x1,x2] # x1 is fed through the bottlenecks

        for i in range(self.num_bottlenecks):
            x1=self.m[i](x1)    # [bs,0.5c_out,w,h]
            outputs.insert(0,x1)

        outputs=torch.cat(outputs,dim=1) # [bs,0.5c_out(num_bottlenecks+2),w,h]
        out=self.conv2(outputs)

        return out
         
# sanity check
c2f=C2f(in_channels=64,out_channels=128,num_bottlenecks=2)
print(f"{sum(p.numel() for p in c2f.parameters())/1e6} million parameters")

dummy_input=torch.rand((1,64,244,244))
dummy_input=c2f(dummy_input)
print("Output shape: ", dummy_input.shape)


0.18944 million parameters
Output shape:  torch.Size([1, 128, 244, 244])


### (c) SPPF

In [6]:
class SPPF(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=5):
        #kernel_size= size of maxpool
        super().__init__()
        hidden_channels=in_channels//2
        self.conv1=Conv(in_channels,hidden_channels,kernel_size=1,stride=1,padding=0)
        # concatenate outputs of maxpool and feed to conv2
        self.conv2=Conv(4*hidden_channels,out_channels,kernel_size=1,stride=1,padding=0)

        # maxpool is applied at 3 different sacles
        self.m=nn.MaxPool2d(kernel_size=kernel_size,stride=1,padding=kernel_size//2,dilation=1,ceil_mode=False)
    
    def forward(self,x):
        x=self.conv1(x)

        # apply maxpooling at diffent scales
        y1=self.m(x)
        y2=self.m(y1)
        y3=self.m(y2)

        # concantenate 
        y=torch.cat([x,y1,y2,y3],dim=1)

        # final conv
        y=self.conv2(y)

        return y

# sanity check
sppf=SPPF(in_channels=128,out_channels=512)
print(f"{sum(p.numel() for p in sppf.parameters())/1e6} million parameters")

dummy_input=sppf(dummy_input)
print("Output shape: ", dummy_input.shape)


0.140416 million parameters
Output shape:  torch.Size([1, 512, 244, 244])


### Putting things together

In [7]:
# backbone = DarkNet53

# return d,w,r based on version
def yolo_params(version):
    if version=='n':
        return 1/3,1/4,2.0
    elif version=='s':
        return 1/3,1/2,2.0
    elif version=='m':
        return 2/3,3/4,1.5
    elif version=='l':
        return 1.0,1.0,1.0
    elif version=='x':
        return 1.0,1.25,1.0
    
class Backbone(nn.Module):
    def __init__(self,version,in_channels=3,shortcut=True):
        super().__init__()
        d,w,r=yolo_params(version)

        # conv layers
        self.conv_0=Conv(in_channels,int(64*w),kernel_size=3,stride=2,padding=1)
        self.conv_1=Conv(int(64*w),int(128*w),kernel_size=3,stride=2,padding=1)
        self.conv_3=Conv(int(128*w),int(256*w),kernel_size=3,stride=2,padding=1)
        self.conv_5=Conv(int(256*w),int(512*w),kernel_size=3,stride=2,padding=1)
        self.conv_7=Conv(int(512*w),int(512*w*r),kernel_size=3,stride=2,padding=1)

        # c2f layers
        self.c2f_2=C2f(int(128*w),int(128*w),num_bottlenecks=int(3*d),shortcut=True)
        self.c2f_4=C2f(int(256*w),int(256*w),num_bottlenecks=int(6*d),shortcut=True)
        self.c2f_6=C2f(int(512*w),int(512*w),num_bottlenecks=int(6*d),shortcut=True)
        self.c2f_8=C2f(int(512*w*r),int(512*w*r),num_bottlenecks=int(3*d),shortcut=True)

        # sppf
        self.sppf=SPPF(int(512*w*r),int(512*w*r))
    
    def forward(self,x):
        x=self.conv_0(x)
        x=self.conv_1(x)

        x=self.c2f_2(x)

        x=self.conv_3(x)

        out1=self.c2f_4(x) # keep for output

        x=self.conv_5(out1)

        out2=self.c2f_6(x) # keep for output

        x=self.conv_7(out2)
        x=self.c2f_8(x)
        out3=self.sppf(x)

        return out1,out2,out3

print("----Nano model -----")
backbone_n=Backbone(version='n')
print(f"{sum(p.numel() for p in backbone_n.parameters())/1e6} million parameters")

print("----Small model -----")
backbone_s=Backbone(version='s')
print(f"{sum(p.numel() for p in backbone_s.parameters())/1e6} million parameters")
        

        


----Nano model -----
1.272656 million parameters
----Small model -----
5.079712 million parameters


In [8]:
# sanity check
x=torch.rand((1,3,640,640))
out1,out2,out3=backbone_n(x)
print(out1.shape)
print(out2.shape)
print(out3.shape)

torch.Size([1, 64, 80, 80])
torch.Size([1, 128, 40, 40])
torch.Size([1, 256, 20, 20])


## 2. Neck
The neck comprises of Upsample + C2f with 

**Upsample** = nearest-neighbor interpolation with scale_factor=2. It doesn't have trainable paramaters.

In [9]:
# upsample = nearest-neighbor interpolation with scale_factor=2
#            doesn't have trainable paramaters
class Upsample(nn.Module):
    def __init__(self,scale_factor=2,mode='nearest'):
        super().__init__()
        self.scale_factor=scale_factor
        self.mode=mode

    def forward(self,x):
        return nn.functional.interpolate(x,scale_factor=self.scale_factor,mode=self.mode)
    
    


In [10]:
class Neck(nn.Module):
    def __init__(self,version):
        super().__init__()
        d,w,r=yolo_params(version)

        self.up=Upsample() # no trainable parameters
        self.c2f_1=C2f(in_channels=int(512*w*(1+r)), out_channels=int(512*w),num_bottlenecks=int(3*d),shortcut=False)
        self.c2f_2=C2f(in_channels=int(768*w), out_channels=int(256*w),num_bottlenecks=int(3*d),shortcut=False)
        self.c2f_3=C2f(in_channels=int(768*w), out_channels=int(512*w),num_bottlenecks=int(3*d),shortcut=False)
        self.c2f_4=C2f(in_channels=int(512*w*(1+r)), out_channels=int(512*w*r),num_bottlenecks=int(3*d),shortcut=False)

        self.cv_1=Conv(in_channels=int(256*w),out_channels=int(256*w),kernel_size=3,stride=2, padding=1)
        self.cv_2=Conv(in_channels=int(512*w),out_channels=int(512*w),kernel_size=3,stride=2, padding=1)


    def forward(self,x_res_1,x_res_2,x):    
        # x_res_1,x_res_2,x = output of backbone
        res_1=x              # for residual connection
        
        x=self.up(x)
        x=torch.cat([x,x_res_2],dim=1)

        res_2=self.c2f_1(x)  # for residual connection
        
        x=self.up(res_2)
        x=torch.cat([x,x_res_1],dim=1)

        out_1=self.c2f_2(x)

        x=self.cv_1(out_1)

        x=torch.cat([x,res_2],dim=1)
        out_2=self.c2f_3(x)

        x=self.cv_2(out_2)

        x=torch.cat([x,res_1],dim=1)
        out_3=self.c2f_4(x)

        return out_1,out_2,out_3
    
# sanity check
neck=Neck(version='n')
print(f"{sum(p.numel() for p in neck.parameters())/1e6} million parameters")

x=torch.rand((1,3,640,640))
out1,out2,out3=Backbone(version='n')(x)
out_1,out_2,out_3=neck(out1,out2,out3)
print(out_1.shape)
print(out_2.shape)
print(out_3.shape)



0.98688 million parameters
torch.Size([1, 64, 80, 80])
torch.Size([1, 128, 40, 40])
torch.Size([1, 256, 20, 20])


### (a) DFL

In [11]:
# DFL
class DFL(nn.Module):
    def __init__(self,ch=16):
        super().__init__()
        
        self.ch=ch
        
        self.conv=nn.Conv2d(in_channels=ch,out_channels=1,kernel_size=1,bias=False).requires_grad_(False)
        
        # initialize conv with [0,...,ch-1]
        x=torch.arange(ch,dtype=torch.float).view(1,ch,1,1)
        self.conv.weight.data[:]=torch.nn.Parameter(x) # DFL only has ch parameters

    def forward(self,x):
        # x must have num_channels = 4*ch: x=[bs,4*ch,c]
        b,c,a=x.shape                           # c=4*ch
        x=x.view(b,4,self.ch,a).transpose(1,2)  # [bs,ch,4,a]

        # take softmax on channel dimension to get distribution probabilities
        x=x.softmax(1)                          # [b,ch,4,a]
        x=self.conv(x)                          # [b,1,4,a]
        return x.view(b,4,a)                    # [b,4,a]

# sanity check
dummy_input=torch.rand((1,64,128))
dfl=DFL()
print(f"{sum(p.numel() for p in dfl.parameters())} parameters")

dummy_output=dfl(dummy_input)
print(dummy_output.shape)

print(dfl)





16 parameters
torch.Size([1, 4, 128])
DFL(
  (conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
)


### (b) Head

In [12]:
class Head(nn.Module):
    def __init__(self,version,ch=16,num_classes=9):

        super().__init__()
        self.ch=ch                          # dfl channels
        self.coordinates=self.ch*4          # number of bounding box coordinates 
        self.nc=num_classes                
        self.no=self.coordinates+self.nc    # number of outputs per anchor box

        self.stride=torch.zeros(3)          # strides computed during build
        
        d,w,r=yolo_params(version=version)
        
        # for bounding boxes
        self.box=nn.ModuleList([
            nn.Sequential(Conv(int(256*w),self.coordinates,kernel_size=3,stride=1,padding=1),
                          Conv(self.coordinates,self.coordinates,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.coordinates,self.coordinates,kernel_size=1,stride=1)),

            nn.Sequential(Conv(int(512*w),self.coordinates,kernel_size=3,stride=1,padding=1),
                          Conv(self.coordinates,self.coordinates,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.coordinates,self.coordinates,kernel_size=1,stride=1)),

            nn.Sequential(Conv(int(512*w*r),self.coordinates,kernel_size=3,stride=1,padding=1),
                          Conv(self.coordinates,self.coordinates,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.coordinates,self.coordinates,kernel_size=1,stride=1))
        ])

        # for classification
        self.cls=nn.ModuleList([
            nn.Sequential(Conv(int(256*w),self.nc,kernel_size=3,stride=1,padding=1),
                          Conv(self.nc,self.nc,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.nc,self.nc,kernel_size=1,stride=1)),

            nn.Sequential(Conv(int(512*w),self.nc,kernel_size=3,stride=1,padding=1),
                          Conv(self.nc,self.nc,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.nc,self.nc,kernel_size=1,stride=1)),

            nn.Sequential(Conv(int(512*w*r),self.nc,kernel_size=3,stride=1,padding=1),
                          Conv(self.nc,self.nc,kernel_size=3,stride=1,padding=1),
                          nn.Conv2d(self.nc,self.nc,kernel_size=1,stride=1))
        ])

        # dfl
        self.dfl=DFL()

    def forward(self,x):
        # x = output of Neck = list of 3 tensors with different resolution and different channel dim
        #     x[0]=[bs, ch0, w0, h0], x[1]=[bs, ch1, w1, h1], x[2]=[bs,ch2, w2, h2] 

        for i in range(len(self.box)):       # detection head i
            box=self.box[i](x[i])            # [bs,num_coordinates,w,h]
            cls=self.cls[i](x[i])            # [bs,num_classes,w,h]
            x[i]=torch.cat((box,cls),dim=1)  # [bs,num_coordinates+num_classes,w,h]

        # in training, no dfl output
        if self.training:
            return x                         # [3,bs,num_coordinates+num_classes,w,h]
        
        # in inference time, dfl produces refined bounding box coordinates
        anchors, strides = (i.transpose(0, 1) for i in self.make_anchors(x, self.stride))

        # concatenate predictions from all detection layers
        x = torch.cat([i.view(x[0].shape[0], self.no, -1) for i in x], dim=2) #[bs, 4*self.ch + self.nc, sum_i(h[i]w[i])]
        
        # split out predictions for box and cls
        #           box=[bs,4×self.ch,sum_i(h[i]w[i])]
        #           cls=[bs,self.nc,sum_i(h[i]w[i])]
        box, cls = x.split(split_size=(4 * self.ch, self.nc), dim=1)


        a, b = self.dfl(box).chunk(2, 1)  # a=b=[bs,2×self.ch,sum_i(h[i]w[i])]
        a = anchors.unsqueeze(0) - a
        b = anchors.unsqueeze(0) + b
        box = torch.cat(tensors=((a + b) / 2, b - a), dim=1)
        
        return torch.cat(tensors=(box * strides, cls.sigmoid()), dim=1)


    def make_anchors(self, x, strides, offset=0.5):
        # x= list of feature maps: x=[x[0],...,x[N-1]], in our case N= num_detection_heads=3
        #                          each having shape [bs,ch,w,h]
        #    each feature map x[i] gives output[i] = w*h anchor coordinates + w*h stride values
        
        # strides = list of stride values indicating how much 
        #           the spatial resolution of the feature map is reduced compared to the original image

        assert x is not None
        anchor_tensor, stride_tensor = [], []
        dtype, device = x[0].dtype, x[0].device
        for i, stride in enumerate(strides):
            _, _, h, w = x[i].shape
            sx = torch.arange(end=w, device=device, dtype=dtype) + offset  # x coordinates of anchor centers
            sy = torch.arange(end=h, device=device, dtype=dtype) + offset  # y coordinates of anchor centers
            sy, sx = torch.meshgrid(sy, sx)                                # all anchor centers 
            anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2))
            stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
        return torch.cat(anchor_tensor), torch.cat(stride_tensor)
        

In [13]:

detect=Head(version='n')
print(f"{sum(p.numel() for p in detect.parameters())/1e6} million parameters")

# out_1,out_2,out_3 are output of the neck
output=detect([out_1,out_2,out_3])
print(output[0].shape)
print(output[1].shape)
print(output[2].shape)

print(detect)


0.390118 million parameters
torch.Size([1, 66, 80, 80])
torch.Size([1, 66, 40, 40])
torch.Size([1, 66, 20, 20])
Head(
  (box): ModuleList(
    (0): Sequential(
      (0): Conv(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (1): Conv(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): Sequential(
      (0): Conv(
        (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (

## 4. Putting everything together

In [14]:
class MyYolo(nn.Module):
    def __init__(self,version):
        super().__init__()
        self.backbone=Backbone(version=version)
        self.neck=Neck(version=version)
        self.head=Head(version=version)

    def forward(self,x):
        x=self.backbone(x)              # return out1,out2,out3
        x=self.neck(x[0],x[1],x[2])     # return out_1, out_2,out_3
        return self.head(list(x))
    
model=MyYolo(version='n')
print(f"{sum(p.numel() for p in model.parameters())/1e6} million parameters")
print(model)

2.649654 million parameters
MyYolo(
  (backbone): Backbone(
    (conv_0): Conv(
      (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (conv_1): Conv(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (conv_3): Conv(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (conv_5): Conv(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
     

In [15]:
# import shutil
# import os

# output_dir = "/kaggle/working"

# # Delete everything inside the directory
# for filename in os.listdir(output_dir):
#     file_path = os.path.join(output_dir, filename)
#     try:
#         if os.path.isfile(file_path) or os.path.islink(file_path):
#             os.unlink(file_path)  # remove file
#         elif os.path.isdir(file_path):
#             shutil.rmtree(file_path)  # remove foldear
#     except Exception as e:
#         print(f"Failed to delete {file_path}. Reason: {e}")

# print("/kaggle/working cleared")

In [16]:
import os
import shutil
from pathlib import Path

# Input dataset root (contains multiple class folders)
source_root = Path("/kaggle/input/urban-issues-dataset")

# Output YOLO dataset root
output_root = Path("/kaggle/working/dataset")
splits = ["train", "valid", "test"]

# Create YOLO folders
for split in splits:
    (output_root / "images" / split).mkdir(parents=True, exist_ok=True)
    (output_root / "labels" / split).mkdir(parents=True, exist_ok=True)

# Loop through each class folder
for class_folder in source_root.iterdir():
    if class_folder.is_dir():
        for split in splits:
            img_dir = class_folder / class_folder.name / split / "images"
            label_dir = class_folder / class_folder.name / split / "labels"

            if not img_dir.exists() or not label_dir.exists():
                continue  # skip if split doesn't exist for this class

            for img_file in img_dir.glob("*.*"):
                # Copy image
                dest_img_path = output_root / "images" / split / f"{class_folder.name}_{img_file.name}"
                shutil.copy(img_file, dest_img_path)

                # Copy label file (no ID remapping)
                src_label_path = label_dir / f"{img_file.stem}.txt"
                dest_label_path = output_root / "labels" / split / f"{class_folder.name}_{img_file.stem}.txt"
                if src_label_path.exists():
                    shutil.copy(src_label_path, dest_label_path)

print("YOLO dataset ready at:", output_root)

In [19]:
import sys
sys.path.append("/kaggle/input/yolo-model-training-files")

import util
from dataset import Dataset

In [25]:
# train_fixed.py  (FINAL - with defaults)
import os
import sys
import yaml
import math
import shutil
from pathlib import Path

import torch
from torch.utils.data import DataLoader
import numpy as np
import cv2
from PIL import Image

# ----------------------
# User-editable settings
# ----------------------
DATA_ROOT = os.environ.get("DATA_ROOT", "/kaggle/working/dataset")  # change if needed
TRAIN_IMG_SUB = os.path.join("images", "train")
VAL_IMG_SUB = os.path.join("images", "val")
LABELS_SUB = "labels"
INPUT_SIZE = 640
NUM_WORKERS = 0   # set to >0 after debug if desired
BATCH_SIZE = 4    # lower if OOM
EPOCHS = 4
CHECKPOINT_DIR = "weights"

# ----------------------
# Fixed Dataset class (no augmentation)
# ----------------------
class FixedDataset(torch.utils.data.Dataset):
    FORMATS = ('bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp')

    def __init__(self, filenames, input_size=640):
        self.filenames = list(filenames)
        self.input_size = input_size
        # load labels (with cache handling)
        self.labels = self._load_label(self.filenames)
        self.indices = range(len(self.filenames))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        img, (h0, w0) = self._load_image(img_path)
        img_resized, ratio, pad = resize(img, self.input_size, augment=False)
        # find label file path
        label_path = self._label_path_from_image(img_path)
        label = self.labels.get(img_path, np.zeros((0,5), dtype=np.float32)).copy()
        # if label.size:
        #     # after resize+pad, convert normalized xywh -> absolute xyxy, then later to normalized xywh relative to resized image
        #     label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w0, ratio[1] * h0, pad[0], pad[1])
        # else:
        #     label = np.zeros((0,5), dtype=np.float32)

        nl = len(label)
        h_res, w_res = img_resized.shape[:2]
        cls = label[:, 0:1] if nl else np.zeros((0,1), dtype=np.float32)
        box = label[:, 1:5] if nl else np.zeros((0,4), dtype=np.float32)
        box = xy2wh(box, w_res, h_res)  # normalized xywh w.r.t resized image

        # make consistent: cls (n,1), box (n,4)
        if nl:
            target_cls = torch.from_numpy(cls).float().view(-1,1)
            target_box = torch.from_numpy(box).float().view(-1,4)
            idx_tensor = torch.zeros(nl, dtype=torch.long)
        else:
            target_cls = torch.zeros((0,1), dtype=torch.float32)
            target_box = torch.zeros((0,4), dtype=torch.float32)
            idx_tensor = torch.zeros((0,), dtype=torch.long)

        # convert image to CHW RGB normalized tensor
        img_rgb = img_resized[:, :, ::-1].transpose((2,0,1)).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(np.ascontiguousarray(img_rgb))

        return img_tensor, target_cls, target_box, idx_tensor

    @staticmethod
    def collate_fn(batch):
        samples, cls_list, box_list, idx_list = zip(*batch)
        imgs = torch.stack(samples, dim=0)

        # normalize cls to (n,1)
        cls_fixed = []
        for c in cls_list:
            if isinstance(c, np.ndarray):
                c = torch.from_numpy(c)
            if not isinstance(c, torch.Tensor):
                c = torch.tensor(c)
            if c.numel() == 0:
                c = c.reshape(0,1)
            elif c.dim() == 1:
                c = c.view(-1,1)
            else:
                c = c.view(-1, c.shape[-1])
            cls_fixed.append(c.float())

        # normalize box to (n,4)
        box_fixed = []
        for b in box_list:
            if isinstance(b, np.ndarray):
                b = torch.from_numpy(b)
            if not isinstance(b, torch.Tensor):
                b = torch.tensor(b)
            if b.numel() == 0:
                b = b.reshape(0,4)
            elif b.dim() == 1:
                if b.numel() == 4:
                    b = b.view(1,4)
                else:
                    b = b.view(-1,4)
            else:
                b = b.view(-1,4)
            box_fixed.append(b.float())

        cls = torch.cat(cls_fixed, dim=0) if sum(c.numel() for c in cls_fixed) > 0 else torch.zeros((0,1), dtype=torch.float32)
        box = torch.cat(box_fixed, dim=0) if sum(b.numel() for b in box_fixed) > 0 else torch.zeros((0,4), dtype=torch.float32)

        # build idx offsets
        new_idx = []
        for i, it in enumerate(idx_list):
            if isinstance(it, np.ndarray):
                it = torch.from_numpy(it)
            if not isinstance(it, torch.Tensor):
                it = torch.tensor(it)
            if it.numel() == 0:
                new_idx.append(it.reshape(0))
            else:
                new_idx.append((it.long() + i))
        idx = torch.cat(new_idx, dim=0) if any(n.numel() for n in new_idx) else torch.zeros((0,), dtype=torch.long)

        targets = {'cls': cls, 'box': box, 'idx': idx}
        return imgs, targets

    # -------------- helper I/O --------------
    @staticmethod
    def _label_path_from_image(img_path):
        a = os.sep + "images" + os.sep
        b = os.sep + "labels" + os.sep
        if a in img_path:
            return b.join(img_path.rsplit(a,1)).rsplit('.',1)[0] + '.txt'
        else:
            # fallback: replace images with labels in path
            return img_path.replace(os.sep + "images" + os.sep, os.sep + "labels" + os.sep).rsplit('.',1)[0] + '.txt'

    @staticmethod
    def _load_image(path):
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(path)
        h,w = img.shape[:2]
        return img, (h,w)

    def _load_label(self, filenames):
        cache_path = f"{os.path.dirname(filenames[0])}.cache"
        # try load safe, otherwise rebuild
        if os.path.exists(cache_path):
            try:
                data = torch.load(cache_path)
                if isinstance(data, dict):
                    return data
                else:
                    print(f"[FixedDataset] cache {cache_path} invalid type; rebuilding.")
                    os.remove(cache_path)
            except Exception as e:
                print(f"[FixedDataset] failed to load cache {cache_path}: {e}. Rebuilding.")
                try:
                    os.remove(cache_path)
                except:
                    pass
    
        labels = {}
        for img_path in filenames:
            try:
                with open(img_path, 'rb') as f:
                    im = Image.open(f)
                    im.verify()
                shape = im.size
                if not ((shape[0] > 9) and (shape[1] > 9)):
                    # Skip small images (do not add to labels)
                    continue
                if im.format is None or im.format.lower() not in FixedDataset.FORMATS:
                    # Skip unsupported formats
                    continue
            except Exception:
                # Skip images that can't be opened or verified
                continue
    
            # Construct label path by replacing /images/ with /labels/ and changing extension to .txt
            img_dir, img_file = os.path.split(img_path)
            label_dir = img_dir.replace('/images/', '/labels/')
            label_file = os.path.splitext(img_file)[0] + '.txt'
            label_path = os.path.join(label_dir, label_file)
    
            if os.path.isfile(label_path):
                good = []
                with open(label_path, 'r') as lf:
                    for line in lf.read().strip().splitlines():
                        parts = line.strip().split()
                        if len(parts) >= 5:
                            try:
                                vals = [float(x) for x in parts[:5]]
                                good.append(vals)
                            except:
                                continue
                if len(good):
                    arr = np.array(good, dtype=np.float32)
                    if arr.shape[1] != 5:
                        arr = arr[:, :5]
                    arr[:,1:] = np.clip(arr[:,1:], 0.0, 1.0)
                    labels[img_path] = arr
                # else no labels -> skip this image (do not add to labels)
            # else no label file -> skip this image (do not add to labels)
    
        try:
            torch.save(labels, cache_path)
        except Exception as e:
            print(f"[FixedDataset] warning: failed to write cache: {e}")
    
        return labels
     


# ----------------------
# small helper functions (resize, wh2xy, xy2wh)
# ----------------------
def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
    y = np.copy(x)
    y[:,0] = w * (x[:,0] - x[:,2]/2) + pad_w
    y[:,1] = h * (x[:,1] - x[:,3]/2) + pad_h
    y[:,2] = w * (x[:,0] + x[:,2]/2) + pad_w
    y[:,3] = h * (x[:,1] + x[:,3]/2) + pad_h
    return y

def xy2wh(x, w, h):
    if x.size == 0:
        return x.reshape((0,4))
    x[:,[0,2]] = x[:,[0,2]].clip(0, w - 1e-3)
    x[:,[1,3]] = x[:,[1,3]].clip(0, h - 1e-3)
    y = np.copy(x)
    y[:,0] = ((x[:,0] + x[:,2]) / 2) / w
    y[:,1] = ((x[:,1] + x[:,3]) / 2) / h
    y[:,2] = (x[:,2] - x[:,0]) / w
    y[:,3] = (x[:,3] - x[:,1]) / h
    return y

def resample():
    choices = (cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LINEAR,
               cv2.INTER_NEAREST, cv2.INTER_LANCZOS4)
    return np.random.choice(choices)

def resize(image, input_size, augment=False):
    shape = image.shape[:2]  # h,w
    r = min(input_size / shape[0], input_size / shape[1])
    if not augment:
        r = min(r, 1.0)
    new_w = int(round(shape[1] * r))
    new_h = int(round(shape[0] * r))
    if (new_w, new_h) != (shape[1], shape[0]):
        image = cv2.resize(image, (new_w, new_h), interpolation=resample())
    pad_w = (input_size - new_w) / 2
    pad_h = (input_size - new_h) / 2
    top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
    left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
    image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0,0,0))
    return image, (r, r), (left, top)

# ----------------------
# Training entrypoint
# ----------------------
def main():
    # locate train images
    train_img_dir = os.path.join(DATA_ROOT, TRAIN_IMG_SUB)
    if not os.path.isdir(train_img_dir):
        raise FileNotFoundError(f"Train image folder not found: {train_img_dir}")

    img_exts = ('.jpg', '.jpeg', '.png', '.bmp')
    filenames_train = [os.path.join(train_img_dir, f) for f in sorted(os.listdir(train_img_dir)) if f.lower().endswith(img_exts)]
    if len(filenames_train) == 0:
        raise RuntimeError(f"No training images found in {train_img_dir}")

    print(f"Found {len(filenames_train)} training images.")

    # load args.yaml (if exists)
    params = {}
    args_path = os.path.join("utils", "args.yaml")
    if os.path.isfile(args_path):
        with open(args_path) as f:
            params = yaml.safe_load(f)
    else:
        print("[train_fixed] Warning: utils/args.yaml not found, using defaults.")

    # --- SET DEFAULTS FOR KEYS REQUIRED BY ComputeLoss ---
    defaults = {
        'box': 7.5,
        'cls': 0.5,
        'dfl': 1.5,
        'min_lr': 1e-4,
        'max_lr': 1e-2,
        'momentum': 0.937,
        'weight_decay': 5e-4,
        'warmup_epochs': 3.0,
        'hsv_h': 0.015,
        'hsv_s': 0.7,
        'hsv_v': 0.4,
        'degrees': 0.0,
        'translate': 0.1,
        'scale': 0.5,
        'shear': 0.0,
        'flip_ud': 0.0,
        'flip_lr': 0.5,
        'mosaic': 0.0,
        'mix_up': 0.0,
        'img_size': INPUT_SIZE,
        'S': 20,
        'B': 2,
        'C': 2
    }
    for k, v in defaults.items():
        params.setdefault(k, v)

    # small summary print
    print("Using params (sample):")
    for k in ['box', 'cls', 'dfl', 'lr', 'min_lr', 'max_lr', 'weight_decay', 'img_size']:
        print(f"  {k}: {params.get(k)}")

    batch_size = params.get("batch_size", BATCH_SIZE)
    epochs = params.get("epochs", EPOCHS)
    lr = params.get("lr", 5e-4)

    # dataset & loader
    ds = FixedDataset(filenames_train, input_size=INPUT_SIZE)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS, collate_fn=FixedDataset.collate_fn, pin_memory=True)
    print("Train loader batches:", len(loader))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MyYolo(version='n')  # adapt if your constructor differs
    model = model.to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f} M")

    # pass params (now guaranteed to contain 'box','cls','dfl', etc.)
    criterion = util.ComputeLoss(model, params)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=params.get("weight_decay", 5e-4))

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    # training loop
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for batch_idx, (imgs, targets) in enumerate(loader):
            imgs = imgs.float().to(device)
            # move targets to device
            targets = {k: v.to(device) for k,v in targets.items()}

            outputs = model(imgs)  # use your model forward signature
            losses = criterion(outputs, targets)  # returns tuple/list
            loss = sum(losses)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if (batch_idx + 1) % 20 == 0:
                avg = running_loss / (batch_idx + 1)
                print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(loader)} | avg_loss: {avg:.4f}")

        epoch_avg = running_loss / max(1, len(loader))
        print(f"Epoch {epoch+1} finished. Avg Loss: {epoch_avg:.4f}")

        # checkpoint
        ckpt = {"model": model.state_dict(), "epoch": epoch+1, "optimizer": optimizer.state_dict()}
        torch.save(ckpt, os.path.join(CHECKPOINT_DIR, f"epoch{epoch+1}.pt"))
        torch.save(ckpt, os.path.join(CHECKPOINT_DIR, "last.pt"))

    print("Training complete.")

if __name__ == "__main__":
    main()

Found 5394 training images.
Using params (sample):
  box: 7.5
  cls: 0.5
  dfl: 1.5
  lr: None
  min_lr: 0.0001
  max_lr: 0.01
  weight_decay: 0.0005
  img_size: 640
[FixedDataset] failed to load cache /kaggle/working/dataset/images/train.cache: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray._reconstruct was not an allowed global by default. Please use `torch.serialization.add_safe_global