In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class YOLOv8Loss(nn.Module):
    """Grid-cell detection loss: Smooth L1 (boxes) + BCE (objectness) + CE (classes).

    Every cell of an ``S x S`` grid carries one box ``[x, y, w, h]``, one
    objectness logit and ``num_classes`` class logits. Each term is summed
    over the selected cells, weighted, and divided by the batch size.

    Args:
        num_classes: Number of class logits per grid cell.
        S: Number of grid cells along each spatial axis.
        lambda_coord: Weight of the bounding-box term.
        lambda_obj: Weight of the objectness term (positive + negative cells).
        lambda_class: Weight of the classification term.
    """

    def __init__(self, num_classes=37, S=7, lambda_coord=5.0, lambda_obj=1.0, lambda_class=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.S = S
        self.lambda_coord = lambda_coord
        self.lambda_obj = lambda_obj
        self.lambda_class = lambda_class

    def forward(self, predictions, targets):
        """Compute the weighted loss and its individual components.

        Args:
            predictions: Raw network outputs holding ``S*S*(5 + C)`` values per
                sample, laid out per cell as ``[x, y, w, h, obj, classes...]``.
                Objectness and class entries are treated as logits.
            targets: Tensor of shape ``(batch_size, S, S, 5 + C)`` laid out per
                cell as ``[x, y, w, h, conf, classes...]``.

        Returns:
            dict with
                * ``total_loss``, ``bbox_loss``, ``obj_loss``, ``class_loss``:
                  0-dim tensors, already divided by the batch size. They are
                  tensors even when a term has no contributing cells.
                * ``num_pos`` / ``num_neg``: number of cells with ``conf > 0``
                  and ``conf == 0`` respectively.
        """
        batch_size = predictions.size(0)

        # Bring predictions into the same per-cell layout as the targets.
        # reshape (unlike view) also accepts non-contiguous inputs.
        predictions = predictions.reshape(batch_size, self.S, self.S, 5 + self.num_classes)

        pred_bbox = predictions[..., :4]
        pred_obj = predictions[..., 4:5]
        pred_class = predictions[..., 5:]

        target_bbox = targets[..., :4]
        target_obj = targets[..., 4:5]
        target_class = targets[..., 5:]

        obj_mask = target_obj > 0       # cells with objects
        noobj_mask = target_obj == 0    # cells without objects

        # Count once: every .sum()/.item() on a device tensor forces a sync.
        num_pos = obj_mask.sum().item()
        num_neg = noobj_mask.sum().item()

        # Tensor-typed zeros keep all returned losses the same type, whether
        # or not a term has any contributing cells.
        zero = predictions.new_zeros(())
        bbox_loss = obj_loss = noobj_loss = class_loss = zero

        if num_pos > 0:
            # 1. Bounding-box loss (Smooth L1) - positive cells only
            box_sel = obj_mask.expand_as(pred_bbox)
            bbox_loss = F.smooth_l1_loss(
                pred_bbox[box_sel], target_bbox[box_sel], reduction='sum'
            )

            # 2a. Objectness loss (BCE on logits) - positive cells
            obj_loss = F.binary_cross_entropy_with_logits(
                pred_obj[obj_mask], target_obj[obj_mask], reduction='sum'
            )

            # 3. Classification loss (cross entropy) - positive cells only.
            # The target class index is the argmax of the class slots.
            cell_sel = obj_mask.squeeze(-1)
            pred_class_obj = pred_class[cell_sel]
            if pred_class_obj.numel() > 0:
                target_class_obj = torch.argmax(target_class, dim=-1)[cell_sel]
                class_loss = F.cross_entropy(pred_class_obj, target_class_obj, reduction='sum')

        if num_neg > 0:
            # 2b. Objectness loss - negative cells are pushed towards 0
            neg_logits = pred_obj[noobj_mask]
            noobj_loss = F.binary_cross_entropy_with_logits(
                neg_logits, torch.zeros_like(neg_logits), reduction='sum'
            )

        total_obj_loss = obj_loss + noobj_loss

        total_loss = (
            self.lambda_coord * bbox_loss +
            self.lambda_obj * total_obj_loss +
            self.lambda_class * class_loss
        )

        # Normalize by batch size; an empty batch yields zeros instead of a
        # division by zero.
        norm = max(batch_size, 1)

        return {
            'total_loss': total_loss / norm,
            'bbox_loss': bbox_loss / norm,
            'obj_loss': total_obj_loss / norm,
            'class_loss': class_loss / norm,
            'num_pos': num_pos,
            'num_neg': num_neg
        }

In [None]:
class TargetPreparer:
    """Convert a dataloader batch into dense YOLO grid targets.

    Args:
        S: Number of grid cells along each spatial axis.
        num_classes: Number of one-hot class slots per cell.
        img_width: Image width used to normalize x coordinates and box widths.
        img_height: Image height used to normalize y coordinates and box heights.
    """

    def __init__(self, S=7, num_classes=37, img_width=640, img_height=160):
        self.S = S
        self.num_classes = num_classes
        self.img_width = img_width
        self.img_height = img_height

    def __call__(self, batch):
        """Build the target tensor for one batch.

        NOTE: the batch schema still needs to be verified against the dataset
        JSON.

        Args:
            batch: Mapping with
                * ``'captcha_string'``: sequence whose length is the batch size,
                * ``'bboxes'``: per-image sequence of ``[x1, y1, x2, y2]`` boxes
                  in the same units as ``img_width`` / ``img_height``
                  (presumably pixels),
                * ``'category_ids'``: per-image sequence of class ids (0-dim
                  tensors or plain ints) aligned with ``'bboxes'``.

        Returns:
            Float tensor of shape ``(batch_size, S, S, 5 + num_classes)`` laid
            out per cell as ``[rel_x, rel_y, width, height, conf, one_hot...]``.
            If several boxes fall into the same cell, the last one wins for
            both the box and the class. A class id outside
            ``[0, num_classes)`` leaves every class slot at zero.
        """
        batch_size = len(batch['captcha_string'])
        targets = torch.zeros(batch_size, self.S, self.S, 5 + self.num_classes)

        for b in range(batch_size):
            bboxes = batch['bboxes'][b]
            if len(bboxes) == 0:
                continue

            category_ids = batch['category_ids'][b]

            for i in range(len(bboxes)):
                x1, y1, x2, y2 = bboxes[i]
                # int() accepts both 0-dim tensors and plain Python ints
                class_id = int(category_ids[i])

                # Convert corner format to normalized center / size
                center_x = (x1 + x2) / 2.0 / self.img_width
                center_y = (y1 + y2) / 2.0 / self.img_height
                width = (x2 - x1) / self.img_width
                height = (y2 - y1) / self.img_height

                # Grid cell containing the box center, clamped to the grid
                grid_x = min(max(int(center_x * self.S), 0), self.S - 1)
                grid_y = min(max(int(center_y * self.S), 0), self.S - 1)

                # Center offset relative to the cell origin
                rel_x = center_x * self.S - grid_x
                rel_y = center_y * self.S - grid_y

                # Basic indexing returns a view, so writes land in `targets`
                cell = targets[b, grid_y, grid_x]
                cell[0] = rel_x
                cell[1] = rel_y
                cell[2] = width
                cell[3] = height
                cell[4] = 1.0  # confidence

                # Drop the class of any box written to this cell earlier;
                # otherwise the cell turns multi-hot while the box fields
                # describe only the latest box.
                cell[5:] = 0.0
                if 0 <= class_id < self.num_classes:
                    cell[5 + class_id] = 1.0

        return targets

In [5]:
# Test with dummy data
def test_loss_function():
    """Smoke-test YOLOv8Loss on random predictions and two hand-placed objects.

    Prints every entry of the returned loss dict and returns the dict.
    """
    batch_size, S, num_classes = 4, 7, 37
    cell_dim = 5 + num_classes

    # Random predictions against an initially all-background target grid
    predictions = torch.randn(batch_size, S * S * cell_dim)
    targets = torch.zeros(batch_size, S, S, cell_dim)

    # Dummy objects as (image, grid row, grid col, class id)
    dummy_objects = (
        (0, 2, 3, 0),
        (1, 1, 5, 5),
    )
    for image_idx, row, col, class_id in dummy_objects:
        targets[image_idx, row, col, 4] = 1.0
        targets[image_idx, row, col, 5 + class_id] = 1.0

    # Evaluate the loss
    criterion = YOLOv8Loss(num_classes=num_classes, S=S)
    loss_dict = criterion(predictions, targets)

    print("YOLOv8 Loss Test Results:")
    print(f"Total Loss: {loss_dict['total_loss']:.4f}")
    print(f"BBox Loss: {loss_dict['bbox_loss']:.4f}")
    print(f"Obj Loss: {loss_dict['obj_loss']:.4f}")
    print(f"Class Loss: {loss_dict['class_loss']:.4f}")
    print(f"Positive samples: {loss_dict['num_pos']}")
    print(f"Negative samples: {loss_dict['num_neg']}")

    return loss_dict

# Run the smoke test and keep the returned loss dict in `test_results`
test_results = test_loss_function()

YOLOv8 Loss Test Results:
Total Loss: 44.0681
BBox Loss: 0.6995
Obj Loss: 39.0235
Class Loss: 1.5471
Positive samples: 2
Negative samples: 194


In [6]:
# Test target preparation (if you have dataloader available)
def test_target_preparer():
    """Run TargetPreparer on a hand-written two-image batch and print a summary.

    Returns the prepared target tensor.
    """
    # Example batch format (replace with your actual batch)
    boxes_per_image = [
        torch.tensor([[100, 50, 150, 100], [200, 60, 250, 110]]),  # 2 objects
        torch.tensor([[300, 70, 350, 120]]),  # 1 object
    ]
    classes_per_image = [
        torch.tensor([0, 1]),  # classes for first image
        torch.tensor([5]),     # class for second image
    ]
    dummy_batch = {
        'captcha_string': ['ABC', 'XYZ'],
        'bboxes': boxes_per_image,
        'category_ids': classes_per_image,
    }

    preparer = TargetPreparer(S=7, num_classes=37, img_width=640, img_height=160)
    targets = preparer(dummy_batch)

    print(f"Target shape: {targets.shape}")
    print(f"Objects found: {(targets[..., 4] > 0).sum().item()}")

    # Grid cells of the first image that hold an object, as (row, col) pairs
    rows, cols = torch.where(targets[0, ..., 4] > 0)
    print(f"Image 0 - Object cells: {list(zip(rows.tolist(), cols.tolist()))}")

    return targets

# Run target test
# target_results = test_target_preparer()

In [7]:
class YOLOTrainer:
    """Minimal wrapper running single train / validation steps.

    Holds a backbone, a YOLO head, a loss function and a target preparer, and
    moves the three modules to ``device`` on construction.
    """

    def __init__(self, backbone, yolo_head, loss_fn, target_preparer, device='cuda'):
        self.backbone = backbone
        self.yolo_head = yolo_head
        self.loss_fn = loss_fn
        self.target_preparer = target_preparer
        self.device = device

        for module in (self.backbone, self.yolo_head, self.loss_fn):
            module.to(device)

    def _compute_losses(self, batch):
        """Move the batch to the device, run backbone + head, return the loss dict."""
        images = batch['image'].to(self.device)
        targets = self.target_preparer(batch).to(self.device)
        predictions = self.yolo_head(self.backbone(images))
        return self.loss_fn(predictions, targets)

    @staticmethod
    def _as_scalars(loss_dict):
        """Replace tensor values by their Python scalar; leave other values as-is."""
        scalars = {}
        for name, value in loss_dict.items():
            scalars[name] = value.item() if torch.is_tensor(value) else value
        return scalars

    def train_step(self, batch, optimizer):
        """Run one optimization step on ``batch`` and return the losses as scalars."""
        self.backbone.train()
        self.yolo_head.train()

        # Forward pass
        loss_dict = self._compute_losses(batch)

        # Backward pass
        optimizer.zero_grad()
        loss_dict['total_loss'].backward()
        optimizer.step()

        return self._as_scalars(loss_dict)

    def validate_step(self, batch):
        """Evaluate ``batch`` without gradient tracking and return the losses as scalars."""
        self.backbone.eval()
        self.yolo_head.eval()

        with torch.no_grad():
            loss_dict = self._compute_losses(batch)

        return self._as_scalars(loss_dict)

In [8]:
# Example usage with your models
def setup_training():
    """Create the loss function and target preparer used for training.

    Prints the selected device and a short status message, then returns
    ``(loss_fn, target_preparer)``.
    """
    grid_config = dict(S=7, num_classes=37)

    # Loss and target preparer share the same grid / class configuration
    loss_fn = YOLOv8Loss(lambda_coord=5.0, lambda_obj=1.0, lambda_class=1.0, **grid_config)
    target_preparer = TargetPreparer(img_width=640, img_height=160, **grid_config)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Initialize trainer (uncomment when you have your models)
    # trainer = YOLOTrainer(backbone, yolo_head, loss_fn, target_preparer, device)
    # optimizer = torch.optim.Adam(
    #     list(backbone.parameters()) + list(yolo_head.parameters()),
    #     lr=0.001
    # )

    print(f"Device: {device}")
    print("✅ Training setup complete!")
    print("Ready to integrate with your ResNet-18 backbone and YOLO head")

    return loss_fn, target_preparer

# Build the loss function and target preparer as notebook-level variables
loss_fn, target_preparer = setup_training()

Device: cuda
✅ Training setup complete!
Ready to integrate with your ResNet-18 backbone and YOLO head
