In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class YOLOv8Head(nn.Module):
    """
    Grid-based detection head for CAPTCHA character detection.

    Architecture: three 3x3 conv/BN/ReLU stages (spatial size preserved),
    a 1x1 prediction conv, then adaptive average pooling to an S x S grid.

    Input:  feature maps (batch_size, in_channels, H, W). The intended
            backbone is presumably ResNet18 giving (batch_size, 512, 5, 20);
            any H, W works because adaptive pooling resamples to S x S.
    Output: predictions (batch_size, S*S*(5 + num_classes)), flattened
            grid-cell-major (row, col) with each cell's vector contiguous:
            [x, y, w, h, objectness, class_0, class_1, ...]

    Args:
        in_channels: channel count of the incoming feature map.
        num_classes: number of classes predicted per grid cell.
        S: side length of the output prediction grid.
    """

    def __init__(self, in_channels=512, num_classes=37, S=7):
        super().__init__()
        self.num_classes = num_classes
        self.S = S

        # Output channels: bbox(4) + objectness(1) + classes(num_classes)
        self.output_channels = 5 + num_classes

        # Feature processing layers (3x3, padding=1 -> spatial size unchanged)
        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(256)

        self.conv2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)

        self.conv3 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Final per-position prediction layer
        self.pred_conv = nn.Conv2d(64, self.output_channels, kernel_size=1)

        # Resample whatever spatial size arrives to the S x S output grid
        self.adaptive_pool = nn.AdaptiveAvgPool2d((S, S))

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal conv weights, zero conv biases, identity BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # NOTE(review): this also applies ReLU-gain init to pred_conv,
                # whose output is not followed by a ReLU — confirm intended.
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Run the head on a backbone feature map.

        Args:
            x: feature maps (batch_size, in_channels, H, W).

        Returns:
            predictions: (batch_size, S*S*(5 + num_classes))
        """
        # Process features through conv layers (H, W preserved)
        x = F.relu(self.bn1(self.conv1(x)))  # (B, 256, H, W)
        x = F.relu(self.bn2(self.conv2(x)))  # (B, 128, H, W)
        x = F.relu(self.bn3(self.conv3(x)))  # (B, 64, H, W)

        # Get per-position predictions
        x = self.pred_conv(x)  # (B, 5+num_classes, H, W)

        # Resample to the S x S grid. NOTE(review): pooling is applied to the
        # raw predictions, so each grid cell averages one or more neighbouring
        # input positions (bins overlap when H or W is smaller than S).
        x = self.adaptive_pool(x)  # (B, 5+num_classes, S, S)

        # Channels last so each cell's prediction vector is contiguous after
        # flattening. permute() yields a non-contiguous tensor, on which
        # .view() raises; reshape() copies when necessary.
        x = x.permute(0, 2, 3, 1)  # (B, S, S, 5+num_classes)
        return x.reshape(x.size(0), -1)  # (B, S*S*(5+num_classes))



In [None]:
# Simpler head
class SimpleYOLOv8Head(nn.Module):
    """
    Lighter variant of YOLOv8Head with fewer layers.

    Architecture: one 3x3 conv/BN/ReLU stage (spatial size preserved), a 1x1
    prediction conv, then adaptive average pooling to an S x S grid.

    Input:  feature maps (batch_size, in_channels, H, W).
    Output: predictions (batch_size, S*S*(5 + num_classes)), flattened
            grid-cell-major with each cell's (5 + num_classes) vector
            contiguous.

    Args:
        in_channels: channel count of the incoming feature map.
        num_classes: number of classes predicted per grid cell.
        S: side length of the output prediction grid.
    """

    def __init__(self, in_channels=512, num_classes=37, S=7):
        super().__init__()
        self.num_classes = num_classes
        self.S = S
        self.output_channels = 5 + num_classes

        # Single conv + prediction
        self.conv = nn.Conv2d(in_channels, 128, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(128)
        self.pred_conv = nn.Conv2d(128, self.output_channels, kernel_size=1)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((S, S))

        # Initialize weights the same way YOLOv8Head does: Kaiming-normal conv
        # weights, zero conv biases (previously left at PyTorch's default
        # random init), identity BatchNorm.
        for conv in (self.conv, self.pred_conv):
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
            nn.init.zeros_(conv.bias)
        nn.init.ones_(self.bn.weight)
        nn.init.zeros_(self.bn.bias)

    def forward(self, x):
        """
        Run the head on a backbone feature map.

        Args:
            x: feature maps (batch_size, in_channels, H, W).

        Returns:
            predictions: (batch_size, S*S*(5 + num_classes))
        """
        x = F.relu(self.bn(self.conv(x)))  # (B, 128, H, W)
        x = self.pred_conv(x)              # (B, 5+num_classes, H, W)
        x = self.adaptive_pool(x)          # (B, 5+num_classes, S, S)

        # permute() makes the tensor non-contiguous, so .view() would raise;
        # reshape() handles it.
        x = x.permute(0, 2, 3, 1)          # (B, S, S, 5+num_classes)
        return x.reshape(x.size(0), -1)