# =============================================================================
# GFL Prototype for P&ID Symbol Detection
# Implementation of Kim et al.'s two-network strategy with GFL
# =============================================================================

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import json
import os
from tqdm.auto import tqdm
from pathlib import Path
import warnings
import ipywidgets as widgets
widgets.IntSlider()


# Set device
# Use the MPS backend when PyTorch reports it as available; otherwise fall back to the CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: mps


# --- Data Loader Initialization ---
# The following cells set up the dataset and the dataloader.

In [16]:
# Path: notebooks/modeling/gfl_prototype.ipynb
# Change: Insert after "## Data Loader Integration" markdown

import sys
sys.path.append(r'/Users/mokshdutt/developer/P&ID/src/preprocessing')  # Adjust if your path differs

from coco_dataloader import CocoDetectionDataset
from torch.utils.data import DataLoader

In [17]:
# Collate function for the DataLoader created below; define it before the DataLoader is initialized.
def collate_padd(batch):
    """Collate ``(image, target)`` pairs into one zero-padded image batch.

    Each image is padded on its right and bottom edges up to the largest
    height and width found in the batch, and the padded images are stacked
    along a new leading dimension.

    Args:
        batch: Sequence of ``(image, target)`` pairs. The last two dimensions
            of every image are treated as height and width.

    Returns:
        Tuple ``(images, targets)``: the stacked padded images and the targets
        as a list in batch order.
    """
    images, targets = zip(*batch)
    batch_h = max(image.shape[-2] for image in images)
    batch_w = max(image.shape[-1] for image in images)

    def pad_to_batch_size(image):
        # Padding tuple covers the last two dims: (left, right, top, bottom).
        extra_w = batch_w - image.shape[-1]
        extra_h = batch_h - image.shape[-2]
        return torch.nn.functional.pad(image, (0, extra_w, 0, extra_h))

    return torch.stack([pad_to_batch_size(image) for image in images]), list(targets)


In [18]:
# Path: notebooks/modeling/gfl_prototype.ipynb
# Change: Insert after Data Loader import cell
# Create transform pipeline
# Transform pipeline applied by the dataset to every image it loads.
transform = transforms.Compose([
    transforms.ToTensor(),  # Critical for converting PIL to tensor
    # Add other transforms here if needed (resize, normalize, etc.)
])

# COCO-style dataset; both paths are absolute and machine-specific.
dataset = CocoDetectionDataset(
    root=r'/Users/mokshdutt/developer/P&ID/data/raw/paliwal_dataset/Images',
    annFile=r'/Users/mokshdutt/developer/P&ID/data/processed/annotations/paliwal_coco.json',
    transform= transform
)

# One image per batch, loaded in the main process (num_workers=0); collate_padd
# pads the images of a batch to a common size before stacking them.
dataloader = DataLoader(
    dataset, batch_size=1, shuffle=True, collate_fn=collate_padd, pin_memory=False, num_workers=0, persistent_workers=False)
    

print(f"Total samples: {len(dataset)}")

Total samples: 500


# =============================================================================
# 1. GFL Loss Functions Implementation
# =============================================================================

In [19]:
class QualityFocalLoss(nn.Module):
    """Quality Focal Loss for the GFL classification branch.

    Binary cross-entropy between the logits and soft targets, scaled per
    element by ``|target - sigmoid(logit)| ** beta``.

    Args:
        beta: Exponent of the modulating factor.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element loss.
    """

    def __init__(self, beta=2.0, reduction='mean'):
        super().__init__()
        self.reduction = reduction
        self.beta = beta

    def forward(self, inputs, targets):
        """Compute the loss.

        Args:
            inputs: Predicted logits [N, num_classes].
            targets: Target IoU-aware classification scores [N, num_classes].

        Returns:
            A scalar for ``'mean'``/``'sum'`` reduction, otherwise a tensor
            with the same shape as ``inputs``.
        """
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # Distance between the soft target and the predicted probability.
        gap = (targets - inputs.sigmoid()).abs()
        per_element = gap.pow(self.beta) * bce

        if self.reduction == 'sum':
            return per_element.sum()
        if self.reduction == 'mean':
            return per_element.mean()
        return per_element

class DistributionFocalLoss(nn.Module):
    """Distribution Focal Loss for bounding box regression.

    A continuous target distance is supervised through the two integer bins
    that bracket it: the cross-entropy towards each bin is weighted by how
    close the target lies to that bin.

    Args:
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-sample loss.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target):
        """Compute the loss.

        Args:
            pred: Predicted distribution [N, n+1] before softmax.
            target: Target distance values [N]. Values are clamped to the
                representable range ``[0, n]``.

        Returns:
            A scalar for ``'mean'``/``'sum'`` reduction, otherwise a tensor
            of shape [N].
        """
        last_bin = pred.size(-1) - 1
        # Distances outside [0, n] cannot be expressed by the n+1 bins.
        target = target.clamp(min=0, max=last_bin)

        dis_left = target.long()
        dis_right = dis_left + 1
        weight_left = dis_right.float() - target
        weight_right = target - dis_left.float()

        # When target == n the right neighbour would be bin n+1, which does
        # not exist. Its weight is exactly 0 there, so indexing the last bin
        # instead keeps the value unchanged while avoiding the out-of-range
        # class index.
        right_index = dis_right.clamp(max=last_bin)

        loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left + \
               F.cross_entropy(pred, right_index, reduction='none') * weight_right

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

# =============================================================================
# 2. ResNet Backbone Implementation
# =============================================================================

In [20]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers.

    The skip connection is the identity unless the stride or the channel
    count changes, in which case a 1x1 conv + batch-norm projection is used.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution (and of the projection).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        # In-place add on the freshly computed tensor; ``x`` is not modified.
        main += self.shortcut(x)
        return self.relu(main)

class ResNetBackbone(nn.Module):
    """ResNet-style feature extractor returning four feature maps.

    Args:
        block: Residual block class, called as ``block(in_planes, planes,
            stride)`` and exposing an ``expansion`` attribute.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Accepted but not used by this backbone.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64

        # Stem: 7x7 stride-2 convolution followed by a stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages: (planes, stride of the first block in the stage).
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for level, ((planes, stride), depth) in enumerate(zip(stage_specs, num_blocks), start=1):
            setattr(self, f"layer{level}", self._make_layer(block, planes, depth, stride=stride))

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        stage = []
        for block_stride in (stride, *([1] * (num_blocks - 1))):
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return ``[c2, c3, c4, c5]`` at 1/4, 1/8, 1/16 and 1/32 resolution."""
        feat = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        pyramid = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feat = stage(feat)
            pyramid.append(feat)
        return pyramid

# =============================================================================
# 3. GFL Detection Head Implementation
# =============================================================================

In [21]:
class GFLHead(nn.Module):
    """GFL detection head with a classification and a regression tower.

    The same towers are applied to every feature level. The classification
    output has ``num_classes`` channels; the regression output has
    ``4 * (reg_max + 1)`` channels.

    Args:
        in_channels: Channel count of every input feature map and of the
            tower convolutions.
        num_classes: Number of classification channels.
        num_convs: Number of 3x3 convolutions in each tower.
        reg_max: Largest bin index of the per-side distance distribution.
    """

    def __init__(self, in_channels=256, num_classes=8, num_convs=4, reg_max=16):
        super().__init__()
        self.num_classes = num_classes
        self.reg_max = reg_max

        # Two parallel towers of channel-preserving 3x3 convolutions.
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for _ in range(num_convs):
            self.cls_convs.append(self._tower_conv(in_channels))
            self.reg_convs.append(self._tower_conv(in_channels))

        # Output layers
        self.gfl_cls = nn.Conv2d(in_channels, num_classes, 3, padding=1)
        self.gfl_reg = nn.Conv2d(in_channels, 4 * (reg_max + 1), 3, padding=1)

        self._init_weights()

    @staticmethod
    def _tower_conv(channels):
        """Return one channel-preserving 3x3 tower convolution."""
        return nn.Conv2d(channels, channels, 3, stride=1, padding=1)

    @staticmethod
    def _run_tower(convs, feat):
        """Apply every conv in ``convs`` to ``feat``, each followed by ReLU."""
        for conv in convs:
            feat = F.relu(conv(feat))
        return feat

    def _init_weights(self):
        """Draw conv weights from N(0, 0.01^2) and zero the conv biases."""
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            nn.init.normal_(conv.weight, std=0.01)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def forward(self, feats):
        """Run the head on every feature level.

        Args:
            feats: Iterable of feature maps, one per level.

        Returns:
            Tuple ``(cls_outputs, reg_outputs)`` of per-level lists.
        """
        cls_outputs, reg_outputs = [], []
        for feat in feats:
            cls_outputs.append(self.gfl_cls(self._run_tower(self.cls_convs, feat)))
            reg_outputs.append(self.gfl_reg(self._run_tower(self.reg_convs, feat)))
        return cls_outputs, reg_outputs

# =============================================================================
# 4. Complete GFL Detector
# =============================================================================

In [22]:
class GFLDetector(nn.Module):
    """Complete GFL detector: ResNet backbone, 1x1 lateral convs and GFL head.

    Args:
        num_classes: Number of classification channels of the head.
        reg_max: Largest bin index of the head's distance distribution.
    """

    def __init__(self, num_classes=8, reg_max=16):
        super().__init__()

        # BasicBlock with stage depths [3, 4, 6, 3] is the ResNet-34 layout;
        # its stage outputs have 64/128/256/512 channels.
        self.backbone = ResNetBackbone(BasicBlock, [3, 4, 6, 3], num_classes)

        # Simplified feature pyramid: one 1x1 conv per level maps every
        # backbone output to 256 channels (no top-down pathway).
        self.fpn = nn.ModuleList([
            nn.Conv2d(64, 256, 1),  # P2
            nn.Conv2d(128, 256, 1),  # P3
            nn.Conv2d(256, 256, 1),  # P4
            nn.Conv2d(512, 256, 1),  # P5
        ])

        # GFL detection head
        self.head = GFLHead(256, num_classes, reg_max=reg_max)

        # Loss functions
        self.qfl_loss = QualityFocalLoss()
        self.dfl_loss = DistributionFocalLoss()

    def forward(self, x, targets=None):
        """Run the detector.

        Args:
            x: Batched input images.
            targets: Optional per-image targets.

        Returns:
            The loss dict from ``calculate_losses`` when the module is in
            training mode and ``targets`` is given; otherwise the tuple
            ``(cls_outputs, reg_outputs)`` of per-level prediction lists.
        """
        backbone_feats = self.backbone(x)
        fpn_feats = [lateral(feat) for lateral, feat in zip(self.fpn, backbone_feats)]
        cls_outputs, reg_outputs = self.head(fpn_feats)

        if self.training and targets is not None:
            return self.calculate_losses(cls_outputs, reg_outputs, targets)
        return cls_outputs, reg_outputs

    def calculate_losses(self, cls_outputs, reg_outputs, targets):
        """Calculate the (placeholder) GFL losses.

        This is a simplified calculation: every location is scored against an
        all-zero classification target. ``reg_outputs`` and ``targets`` are
        not used yet; a real implementation needs anchor generation and
        target assignment.

        Returns:
            Dict with ``'total_loss'`` and ``'cls_loss'``. Both hold the
            classification loss summed over all feature levels, since it is
            the only loss term computed here.
        """
        cls_loss = 0
        for cls_out, _reg_out in zip(cls_outputs, reg_outputs):
            # [B, C, H, W] -> [B * H * W, C]
            logits = cls_out.flatten(2).permute(0, 2, 1).flatten(0, 1)
            cls_loss = cls_loss + self.qfl_loss(logits, torch.zeros_like(logits))

        # Report the summed loss under both keys; previously 'cls_loss' only
        # carried the value of the last pyramid level.
        return {'total_loss': cls_loss, 'cls_loss': cls_loss}


# =============================================================================
# 5. Data Preprocessing for Two-Network Strategy
# =============================================================================

In [23]:
class PIDSymbolDataset(Dataset):
    """P&ID images with symbol boxes filtered by size for one of two networks.

    Each sub-folder of ``<data_dir>/raw/paliwal_dataset`` is expected to hold
    an ``image.jpg`` and a ``symbols.npy`` file containing a pickled dict with
    ``'bounding_box'`` and ``'class_ids'`` entries.

    Args:
        data_dir: Root data directory.
        mode: ``'small'`` keeps symbols whose box diagonal is at most the
            threshold, ``'large'`` keeps those above it.
        transform: Optional extra transform applied after ``ToTensor``.
        symbol_size_threshold: Diagonal length separating small from large.
    """

    def __init__(self, data_dir, mode='small', transform=None, symbol_size_threshold=700):
        self.data_dir = Path(data_dir)
        self.mode = mode
        self.threshold = symbol_size_threshold
        # ToTensor always runs first; the optional user transform follows it.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            *([transform] if transform else [])
        ])
        self.samples = self._load_samples()

    def _load_samples(self):
        """Scan the dataset folders and keep those with matching symbols.

        Folders are visited in sorted order so that sample indices are
        reproducible (directory iteration order is otherwise arbitrary).
        Folders that fail to load are reported and skipped.
        """
        samples = []
        paliwal_dir = self.data_dir / 'raw' / 'paliwal_dataset'
        for folder in sorted(paliwal_dir.iterdir()):
            if not folder.is_dir():
                continue
            try:
                symbols_file = folder / 'symbols.npy'
                if not symbols_file.exists():
                    continue
                # SECURITY: allow_pickle=True unpickles arbitrary objects;
                # only load annotation files from a trusted source.
                symbols_data = np.load(symbols_file, allow_pickle=True).item()
                filtered_symbols = self._filter_by_size(symbols_data)
                if filtered_symbols:
                    samples.append({
                        'folder': folder,
                        'symbols': symbols_data,
                        'filtered_indices': filtered_symbols
                    })
            except Exception as e:
                # Best-effort loading: one broken folder must not abort the scan.
                print(f"Error loading {folder}: {e}")
        print(f"Loaded {len(samples)} samples for {self.mode} network")
        return samples

    def _filter_by_size(self, symbols_data):
        """Return indices of boxes whose diagonal matches ``self.mode``.

        Boxes with fewer than four coordinates are ignored; a dict without a
        ``'bounding_box'`` entry yields an empty list.
        """
        filtered = []
        for i, bbox in enumerate(symbols_data.get('bounding_box', [])):
            if len(bbox) < 4:
                continue
            x1, y1, x2, y2 = bbox[:4]
            diagonal = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
            is_small = diagonal <= self.threshold
            if (self.mode == 'small' and is_small) or (self.mode == 'large' and not is_small):
                filtered.append(i)
        return filtered

    @staticmethod
    def _build_target(symbols, indices):
        """Build the ``{'boxes', 'labels'}`` target for the given indices.

        Indices lacking either a box or a class id are skipped. ``boxes`` is
        always shaped ``[K, 4]`` — including ``[0, 4]`` when nothing is left —
        so downstream code can rely on the second dimension.
        """
        all_boxes = symbols['bounding_box']
        all_labels = symbols['class_ids']
        boxes = []
        labels = []
        for i in indices:
            if i < len(all_boxes) and i < len(all_labels):
                x1, y1, x2, y2 = all_boxes[i][:4]
                boxes.append([x1, y1, x2, y2])
                labels.append(all_labels[i])
        return {
            'boxes': torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.as_tensor(labels, dtype=torch.long),
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for sample ``idx``."""
        sample = self.samples[idx]
        image_path = sample['folder'] / 'image.jpg'  # Adjust if your image filenames differ
        image = Image.open(image_path).convert('RGB')
        # self.transform is always a Compose starting with ToTensor.
        image = self.transform(image)
        target = self._build_target(sample['symbols'], sample['filtered_indices'])
        return image, target


# =============================================================================
# 6. Adaptive NMS Implementation
# =============================================================================

In [24]:
def adaptive_nms(boxes, scores, labels, iou_thresholds, score_threshold=0.05):
    """
    Adaptive Non-Maximum Suppression with class-specific IoU thresholds.

    Detections are visited in descending score order; each kept detection
    suppresses lower-scored detections of the same class whose IoU with it
    exceeds that class's threshold.

    Args:
        boxes: Tensor [N, 4] in (x1, y1, x2, y2) format
        scores: Tensor [N] detection scores
        labels: Tensor [N] class labels
        iou_thresholds: Dict mapping class_id to IoU threshold (0.5 is used
            for classes missing from the dict)
        score_threshold: Minimum score threshold

    Returns:
        Long tensor with the indices of the kept detections, ordered by
        descending score. The indices refer to the *input* tensors, so
        ``boxes[keep]`` selects the surviving boxes.
    """
    # Remember where each surviving candidate sits in the caller's tensors;
    # all later bookkeeping happens in the filtered + sorted view.
    candidates = torch.nonzero(scores > score_threshold, as_tuple=False).flatten()
    if candidates.numel() == 0:
        return torch.empty(0, dtype=torch.long)

    order = torch.argsort(scores[candidates], descending=True)
    candidates = candidates[order]
    boxes = boxes[candidates]
    labels = labels[candidates]
    original_indices = candidates.tolist()

    keep = []
    # Same device as the boxes so it can be indexed with device tensors.
    suppressed = torch.zeros(len(boxes), dtype=torch.bool, device=boxes.device)

    for i, current_label in enumerate(labels):
        if suppressed[i]:
            continue
        keep.append(original_indices[i])

        # Only lower-scored detections of the same class can be suppressed.
        same_class = labels[i + 1:] == current_label
        if not same_class.any():
            continue

        iou_threshold = iou_thresholds.get(current_label.item(), 0.5)
        later_idx = torch.where(same_class)[0] + i + 1
        ious = calculate_iou(boxes[i:i + 1], boxes[later_idx])
        suppressed[later_idx[ious > iou_threshold]] = True

    return torch.tensor(keep, dtype=torch.long)

def calculate_iou(box1, box2):
    """Calculate IoU between one box and a set of boxes.

    Args:
        box1: Tensor [1, 4] in (x1, y1, x2, y2) format.
        box2: Tensor [N, 4] in (x1, y1, x2, y2) format.

    Returns:
        Tensor [N] with the IoU of ``box1`` against every row of ``box2``.
    """
    # Intersection rectangle; clamping handles non-overlapping boxes.
    x1_max = torch.max(box1[:, 0], box2[:, 0])
    y1_max = torch.max(box1[:, 1], box2[:, 1])
    x2_min = torch.min(box1[:, 2], box2[:, 2])
    y2_min = torch.min(box1[:, 3], box2[:, 3])

    intersection = torch.clamp(x2_min - x1_max, min=0) * torch.clamp(y2_min - y1_max, min=0)

    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    union = area1 + area2 - intersection

    # Small epsilon avoids division by zero for degenerate boxes.
    return intersection / (union + 1e-6)


# =============================================================================
# 7. Training Setup and Utilities
# =============================================================================


In [25]:
def create_gfl_models(num_classes=8, reg_max=16, device='mps'):
    """
    Build the two detectors of the two-network strategy.

    Both networks share the same architecture and are moved to ``device``.

    Returns:
        Tuple ``(small_net, large_net)``: the small-symbol network (for
        patches) and the large-symbol network (for full images).
    """
    def build_detector():
        return GFLDetector(num_classes=num_classes, reg_max=reg_max).to(device)

    small_symbol_net = build_detector()
    large_symbol_net = build_detector()
    return small_symbol_net, large_symbol_net
    
# Instantiate both detectors on the device selected at the top of the notebook.
small_net, large_net = create_gfl_models(num_classes=8, device=device)
print(f"Models initialized on {device}")

RuntimeError: MPS backend out of memory (MPS allocated: 8.15 GB, other allocations: 672.00 KB, max allowed: 1.07 GB). Tried to allocate 36.75 KB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
def setup_training(model, learning_rate=1e-4):
    """
    Create the optimizer and learning-rate schedule for ``model``.

    Returns:
        Tuple ``(optimizer, scheduler)``: an AdamW optimizer (weight decay
        1e-4) and a MultiStepLR scheduler that multiplies the learning rate
        by 0.1 at milestones 8 and 11.
    """
    adamw = torch.optim.AdamW(
        params=model.parameters(),
        lr=learning_rate,
        weight_decay=1e-4,
    )
    lr_schedule = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=adamw,
        milestones=[8, 11],
        gamma=0.1,
    )
    return adamw, lr_schedule


In [None]:
def train_epoch(model, dataloader, optimizer, device, grad_clip=None, grad_accum=1, scaler=None, scheduler=None):
    """Train ``model`` for one epoch with optional gradient accumulation.

    Args:
        model: Module called as ``model(images, targets)``; must return a dict
            containing ``'total_loss'``.
        dataloader: Iterable of ``(images, targets)`` batches, where
            ``targets`` is a list of dicts of tensors.
        optimizer: Optimizer updating the model parameters.
        device: ``torch.device`` the batches are moved to.
        grad_clip: If set, maximum gradient norm applied before each update.
        grad_accum: Number of batches whose gradients are accumulated per
            optimizer update.
        scaler: Optional gradient scaler; enables autocast when given.
        scheduler: Accepted for signature compatibility; not used here.

    Returns:
        The summed (un-scaled) batch losses divided by the number of batches,
        or 0.0 for an empty dataloader.
    """
    from tqdm.auto import tqdm

    def apply_update():
        """Clip (if requested), step the optimizer and reset the gradients."""
        if grad_clip:
            if scaler:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        if scaler:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        optimizer.zero_grad()

    model.train()
    total_loss = 0
    accumulated_loss = 0
    pending_batches = 0  # batches back-propagated since the last update
    optimizer.zero_grad()

    progress_bar = tqdm(dataloader, desc=f"Training ({device.type.upper()})", leave=False)

    for batch_idx, (images, targets) in enumerate(progress_bar):
        # Release cached allocator memory before processing the batch.
        if device.type == 'mps':
            torch.mps.empty_cache()
        elif device.type == 'cuda':
            torch.cuda.empty_cache()

        images = images.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        with torch.autocast(device_type=device.type, enabled=scaler is not None):
            outputs = model(images, targets)
            # Scale so the accumulated gradient matches one large batch.
            loss = outputs['total_loss'] / grad_accum

        if scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Wait for the backward pass to finish on MPS before continuing.
        if device.type == 'mps':
            torch.mps.synchronize()

        accumulated_loss += loss.item() * grad_accum
        pending_batches += 1

        if (batch_idx + 1) % grad_accum == 0:
            apply_update()
            progress_bar.set_postfix(loss=accumulated_loss)
            total_loss += accumulated_loss
            accumulated_loss = 0
            pending_batches = 0

            # Release cached allocator memory after the optimizer step.
            if device.type == 'mps':
                torch.mps.empty_cache()

    # Flush gradients of a trailing partial accumulation window. This is
    # decided by the batch count, not by the loss value, so zero or negative
    # losses are not dropped; the same clipping as a regular update applies.
    if pending_batches:
        apply_update()
        total_loss += accumulated_loss

    num_batches = len(dataloader)
    avg_loss = total_loss / num_batches if num_batches else 0.0
    return avg_loss


In [None]:
# === Train GFL Model on Real Data ===
num_epochs = 1  # Adjust as needed

# Initialize models on the correct device
# NOTE(review): this builds a second pair of networks although an earlier cell
# already created `small_net`/`large_net`; that cell's recorded output is an
# MPS out-of-memory error — confirm whether re-creating the models is intended.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
small_net, large_net = create_gfl_models(num_classes=8, device=device)

# Use small_net or large_net
model = small_net

# Setup training
optimizer, scheduler = setup_training(model, learning_rate=1e-4)

# Training loop
# The scheduler is stepped here, once per epoch; it is not passed to train_epoch.
for epoch in range(num_epochs):
    avg_loss = train_epoch(model, dataloader, optimizer, device)
    scheduler.step()
    print(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.4f}")


In [None]:
def train_epoch(model, dataloader, optimizer, device, grad_clip=None, grad_accum=1, scaler=None, scheduler=None):
    """Train ``model`` for one epoch with optional gradient accumulation.

    Args:
        model: Module called as ``model(images, targets)``; must return a dict
            containing ``'total_loss'``.
        dataloader: Iterable of ``(images, targets)`` batches; images are
            expected to be already padded and stacked into one tensor.
        optimizer: Optimizer updating the model parameters.
        device: ``torch.device`` the batches are moved to.
        grad_clip: If set, maximum gradient norm applied before each update.
        grad_accum: Number of batches whose gradients are accumulated per
            optimizer update.
        scaler: Optional gradient scaler; enables autocast when given.
        scheduler: Optional LR scheduler stepped after every batch (skipped
            for ``ReduceLROnPlateau``).

    Returns:
        The summed (un-scaled) batch losses divided by the number of batches,
        or 0.0 for an empty dataloader.
    """
    from tqdm.auto import tqdm

    def apply_update():
        """Clip (if requested), step the optimizer and reset the gradients."""
        if grad_clip:
            if scaler:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        if scaler:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        optimizer.zero_grad()

    model.train()
    total_loss = 0
    accumulated_loss = 0
    pending_batches = 0  # batches back-propagated since the last update
    optimizer.zero_grad()

    progress_bar = tqdm(dataloader, desc=f"Training ({device.type.upper()})", leave=False)
    per_batch_scheduler = bool(scheduler) and not isinstance(
        scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    for batch_idx, (images, targets) in enumerate(progress_bar):
        images = images.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Mixed precision is only active when a scaler is supplied.
        with torch.autocast(device_type=device.type, enabled=scaler is not None):
            outputs = model(images, targets)
            # Scale so the accumulated gradient matches one large batch.
            loss = outputs['total_loss'] / grad_accum

        if scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        accumulated_loss += loss.item() * grad_accum
        pending_batches += 1

        if (batch_idx + 1) % grad_accum == 0:
            apply_update()
            progress_bar.set_postfix(loss=accumulated_loss)
            total_loss += accumulated_loss
            accumulated_loss = 0
            pending_batches = 0

        # NOTE(review): stepped once per batch, also on batches without an
        # optimizer update — confirm the scheduler passed in is a per-batch one.
        if per_batch_scheduler:
            scheduler.step()

    # Flush gradients of a trailing partial accumulation window. This is
    # decided by the batch count, not by the loss value, so zero or negative
    # losses are not dropped; the same clipping as a regular update applies.
    if pending_batches:
        apply_update()
        total_loss += accumulated_loss

    num_batches = len(dataloader)
    avg_loss = total_loss / num_batches if num_batches else 0.0
    return avg_loss

In [None]:
def validate_epoch(model, dataloader, device):
    """Compute the average loss of ``model`` over ``dataloader`` without gradients.

    Args:
        model: Module called as ``model(images, targets)``. If the call does
            not return a loss dict (the detector returns raw predictions in
            eval mode), ``model.calculate_losses`` is used to obtain one.
        dataloader: Iterable of ``(images, targets)`` batches. ``images`` may
            be a batched tensor or a sequence of equally sized tensors.
        device: ``torch.device`` the batches are moved to.

    Returns:
        The mean of ``'total_loss'`` over all batches (0.0 if there are none).

    Raises:
        TypeError: If a batch does not contain image tensors.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, targets in dataloader:
            # The model consumes one batched tensor, so keep (or build) one
            # instead of splitting the batch into a Python list.
            if isinstance(images, torch.Tensor):
                images = images.to(device)
            elif images and isinstance(images[0], torch.Tensor):
                images = torch.stack([img.to(device) for img in images])
            else:
                raise TypeError("Images must be tensors. Check dataset transforms.")

            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            outputs = model(images, targets)
            if not isinstance(outputs, dict):
                # Eval mode yields (cls_outputs, reg_outputs); turn them into
                # the same loss dict that training mode produces.
                cls_outputs, reg_outputs = outputs
                outputs = model.calculate_losses(cls_outputs, reg_outputs, targets)

            total_loss += outputs['total_loss'].item()
            num_batches += 1

    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"Validation loss: {avg_loss:.4f}")
    return avg_loss


In [None]:
# Evaluate the model selected in the training cell; this reuses the training dataloader.
validate_epoch(model, dataloader, device)  # Use a separate val_dataloader if you split your data


# =============================================================================
# 8. Prototype Testing and Visualization
# =============================================================================

In [None]:
def test_gfl_prototype():
    """Smoke-test the detectors and both loss functions on random data.

    Builds the two networks on MPS (or CPU), prints their parameter counts,
    runs one inference pass of the small network on a random batch and prints
    the per-level output shapes, then evaluates QFL and DFL on random inputs.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    small_net, large_net = create_gfl_models(num_classes=8, device=device)
    small_params = sum(p.numel() for p in small_net.parameters())
    large_params = sum(p.numel() for p in large_net.parameters())
    print(f"Small network parameters: {small_params:,}")
    print(f"Large network parameters: {large_params:,}")

    # Random image batch, moved to the target device.
    dummy_input = torch.randn(2, 3, 512, 512).to(device)

    with torch.no_grad():
        small_net.eval()
        cls_outputs, reg_outputs = small_net(dummy_input)
        print(f"Number of feature levels: {len(cls_outputs)}")
        for level, cls_out in enumerate(cls_outputs):
            reg_out = reg_outputs[level]
            print(f"Level {level}: cls_shape={cls_out.shape}, reg_shape={reg_out.shape}")

    # Loss modules and their random inputs live on the same device.
    qfl = QualityFocalLoss().to(device)
    dfl = DistributionFocalLoss().to(device)

    cls_pred = torch.randn(100, 8).to(device)
    cls_target = torch.rand(100, 8).to(device)
    reg_pred = torch.randn(100, 17).to(device)
    reg_target = torch.rand(100).to(device) * 16

    qfl_value = qfl(cls_pred, cls_target)
    dfl_value = dfl(reg_pred, reg_target)

    print(f"QFL Loss: {qfl_value.item():.4f}")
    print(f"DFL Loss: {dfl_value.item():.4f}")


In [None]:
def visualize_architecture():
    """Print a text diagram of the GFL pipeline and the two-network strategy."""
    rule = "=" * 60
    diagram = """
    Input Image
         ↓
    ResNet Backbone (Feature Extraction)
         ↓
    Feature Pyramid Network
         ↓
    GFL Detection Head
    ├── Classification Branch (QFL)
    └── Regression Branch (DFL)
         ↓
    Post-processing (Adaptive NMS)
         ↓
    Final Detections
    """

    # Banner
    print("\n" + rule)
    print("GFL DETECTOR ARCHITECTURE")
    print(rule)

    # Pipeline diagram
    print(diagram)

    # Size-based split between the two networks
    print("\nTwo-Network Strategy:")
    for line in (
        "├── Small Symbol Network: Patch-based detection (symbols < 700px)",
        "└── Large Symbol Network: Full-image detection (symbols > 700px)",
    ):
        print(line)

# Run sanity checks and visualization
# test_gfl_prototype() builds its own pair of networks on the available device.
test_gfl_prototype()
visualize_architecture()

In [None]:
def plot_image_with_boxes(img, boxes, labels=None, class_names=None):
    """Show an image with its bounding boxes and optional class labels.

    Tensors on any device (CPU, MPS, CUDA, ...) are accepted; they are
    detached and copied to the CPU before conversion to numpy.

    Args:
        img: Image as a CHW tensor or as an array that ``plt.imshow`` accepts.
        boxes: Boxes in ``[x1, y1, x2, y2]`` format (tensor or array-like).
        labels: Optional 1-based class ids, one per box.
        class_names: Optional sequence of names; ``class_names[label - 1]`` is
            drawn for each box, falling back to the raw label when the index
            is out of range. Labels are only drawn when both ``labels`` and
            ``class_names`` are given.
    """
    def to_numpy(value):
        # .numpy() only works for CPU tensors, so always go through .cpu()
        # (a no-op for tensors that already live on the CPU).
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    plt.figure(figsize=(10, 10))

    if isinstance(img, torch.Tensor):
        img = to_numpy(img.permute(1, 2, 0))  # CHW -> HWC
    boxes = to_numpy(boxes)
    labels = to_numpy(labels)

    plt.imshow(img)
    ax = plt.gca()

    draw_labels = labels is not None and class_names is not None
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box  # Assume [x1, y1, x2, y2] format
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                             fill=False, color='red', linewidth=2)
        ax.add_patch(rect)

        if draw_labels:
            label_idx = int(labels[i]) - 1  # Convert to 0-based index
            label = class_names[label_idx] if 0 <= label_idx < len(class_names) else str(labels[i])
            ax.text(x1, y1, label, color='white', fontsize=10,
                    bbox=dict(facecolor='red', alpha=0.7, edgecolor='none'))

    plt.axis('off')
    plt.show()
