# =============================================================================
# GFL Prototype for P&ID Symbol Detection
# Implementation of Kim et al.'s two-network strategy with GFL
# =============================================================================

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import json
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device once for the whole notebook: CUDA when a GPU is
# visible to PyTorch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


# =============================================================================
# 1. GFL Loss Functions Implementation
# =============================================================================

In [13]:
class QualityFocalLoss(nn.Module):
    """Quality Focal Loss (QFL) from Generalized Focal Loss (Li et al., 2020).

    Per-element binary cross-entropy on the logits, scaled by
    ``|target - sigmoid(logit)| ** beta`` so that entries already predicted
    close to their soft target contribute less.

    Args:
        beta: Exponent of the modulating factor.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            un-reduced element-wise loss.
    """

    def __init__(self, beta=2.0, reduction='mean'):
        super().__init__()
        self.beta = beta
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the QFL.

        Args:
            inputs: Classification logits, e.g. ``[N, num_classes]``.
            targets: Soft (IoU-aware) targets with the same shape as
                ``inputs``; expected to lie in ``[0, 1]``.

        Returns:
            A scalar for ``'mean'``/``'sum'`` reduction, otherwise a tensor
            shaped like ``inputs``.
        """
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        scale = (targets - inputs.sigmoid()).abs().pow(self.beta)
        per_element = bce * scale

        if self.reduction == 'sum':
            return per_element.sum()
        if self.reduction == 'mean':
            return per_element.mean()
        return per_element

class DistributionFocalLoss(nn.Module):
    """Distribution Focal Loss (DFL) for box-edge regression (Li et al., 2020).

    A continuous target ``t`` is represented by its two neighbouring integer
    bins ``floor(t)`` and ``floor(t) + 1``. The loss is the cross-entropy
    against each bin, weighted linearly by how close ``t`` is to that bin.

    Args:
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            per-sample loss.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target):
        """Compute the DFL.

        Args:
            pred: Logits over the discrete bins ``0..reg_max``, shape
                ``[N, reg_max + 1]`` (before softmax).
            target: Continuous target distances, shape ``[N]``. Values are
                clamped into ``[0, reg_max]``, the range the bins can represent.

        Returns:
            A scalar for ``'mean'``/``'sum'`` reduction, otherwise a ``[N]``
            tensor.
        """
        reg_max = pred.size(-1) - 1
        # Out-of-range targets would yield negative weights or invalid class
        # indices for cross_entropy, so clamp into the representable range.
        target = target.clamp(min=0, max=reg_max)
        # A target exactly equal to reg_max would make dis_right == reg_max + 1,
        # which is out of range. Capping dis_left at reg_max - 1 instead gives
        # weight_left = 0 and weight_right = 1, i.e. all mass on the last bin.
        dis_left = target.long().clamp(max=reg_max - 1)
        dis_right = dis_left + 1
        weight_left = dis_right.float() - target
        weight_right = target - dis_left.float()

        loss = (F.cross_entropy(pred, dis_left, reduction='none') * weight_left
                + F.cross_entropy(pred, dis_right, reduction='none') * weight_right)

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

# =============================================================================
# 2. ResNet Backbone Implementation
# =============================================================================

In [14]:
class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block as used in ResNet-18/34.

    ``conv-bn-relu -> conv-bn``, added to a shortcut, then ReLU. The shortcut
    is the identity unless the stride or channel count changes, in which case
    a strided 1x1 conv + BatchNorm projection is used.

    Args:
        in_planes: Input channel count.
        planes: Output channel count (times ``expansion``, which is 1).
        stride: Stride of the first 3x3 conv and of the projection shortcut.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Layers are created in the same order as before so that seeded
        # weight initialisation is unchanged.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # Empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        identity = self.shortcut(x)
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return F.relu(residual + identity)

class ResNetBackbone(nn.Module):
    """ResNet feature extractor returning the outputs of all four stages.

    Args:
        block: Residual block class exposing an ``expansion`` attribute and
            constructed as ``block(in_planes, planes, stride)``.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Accepted for call compatibility; this backbone has no
            classifier and does not use it.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_planes = 64

        # Stem: stride-2 7x7 conv followed by stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages; every stage after the first halves the spatial size.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1."""
        layer_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in layer_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return ``[c2, c3, c4, c5]`` at strides 4, 8, 16 and 32 of the input."""
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        pyramid = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            pyramid.append(x)
        return pyramid

# =============================================================================
# 3. GFL Detection Head Implementation
# =============================================================================

In [15]:
class GFLHead(nn.Module):
    """GFL detection head: per-level class logits and box-edge distributions.

    One classification tower and one regression tower (``num_convs`` 3x3
    conv + ReLU each) are shared across all pyramid levels passed to
    ``forward``.

    Args:
        in_channels: Channel count of every input feature map and of the towers.
        num_classes: Number of classification logits per location.
        num_convs: Number of conv + ReLU layers in each tower.
        reg_max: Largest DFL bin index; the regression output has
            ``4 * (reg_max + 1)`` channels (one distribution per box edge).
        prior_prob: Initial foreground probability used to set the
            classifier bias to ``-log((1 - p) / p)``, as in focal-loss style
            heads. Pass ``None`` to keep a zero bias (the previous behaviour).

    Raises:
        ValueError: If ``prior_prob`` is not ``None`` and not in ``(0, 1)``.
    """

    def __init__(self, in_channels=256, num_classes=8, num_convs=4, reg_max=16,
                 prior_prob=0.01):
        super().__init__()
        if prior_prob is not None and not 0.0 < prior_prob < 1.0:
            raise ValueError(f"prior_prob must be in (0, 1) or None, got {prior_prob}")
        self.num_classes = num_classes
        self.reg_max = reg_max
        self.prior_prob = prior_prob

        # Separate cls/reg towers; each tower is reused for every level.
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for _ in range(num_convs):
            self.cls_convs.append(
                nn.Conv2d(in_channels, in_channels, 3, stride=1, padding=1))
            self.reg_convs.append(
                nn.Conv2d(in_channels, in_channels, 3, stride=1, padding=1))

        # Output layers
        self.gfl_cls = nn.Conv2d(in_channels, num_classes, 3, padding=1)
        self.gfl_reg = nn.Conv2d(in_channels, 4 * (reg_max + 1), 3, padding=1)

        self._init_weights()

    def _init_weights(self):
        """Use N(0, 0.01) weights and zero biases, then apply the classifier prior bias."""
        import math

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        # With a zero bias every location starts at sigmoid(0) = 0.5, so the
        # overwhelming number of background locations dominates early
        # training. Starting at prior_prob keeps the initial loss well scaled.
        if self.prior_prob is not None:
            bias_cls = -math.log((1.0 - self.prior_prob) / self.prior_prob)
            nn.init.constant_(self.gfl_cls.bias, bias_cls)

    def forward(self, feats):
        """Run both towers on every level.

        Args:
            feats: Sequence of feature maps, each ``[B, in_channels, H, W]``.

        Returns:
            ``(cls_outputs, reg_outputs)``: lists with one tensor per level,
            shaped ``[B, num_classes, H, W]`` and ``[B, 4 * (reg_max + 1), H, W]``
            (all convs are stride 1 with padding 1, so H and W are preserved).
        """
        cls_outputs = []
        reg_outputs = []

        for feat in feats:
            cls_feat = feat
            for conv in self.cls_convs:
                cls_feat = F.relu(conv(cls_feat))

            reg_feat = feat
            for conv in self.reg_convs:
                reg_feat = F.relu(conv(reg_feat))

            cls_outputs.append(self.gfl_cls(cls_feat))
            reg_outputs.append(self.gfl_reg(reg_feat))

        return cls_outputs, reg_outputs

# =============================================================================
# 4. Complete GFL Detector
# =============================================================================

In [16]:
class GFLDetector(nn.Module):
    """GFL detector: ResNet backbone -> lateral 1x1 projections -> GFL head.

    Args:
        num_classes: Number of object classes predicted by the head.
        reg_max: Largest DFL bin index used by the regression branch.

    ``forward`` returns a loss dict when the module is in training mode and
    ``targets`` is given; otherwise it returns per-level
    ``(cls_outputs, reg_outputs)`` lists.
    """

    def __init__(self, num_classes=8, reg_max=16):
        super().__init__()

        # ResNet-34 layout: BasicBlock with [3, 4, 6, 3] blocks per stage.
        # (ResNet-50 would need Bottleneck blocks and 4x wider stage outputs.)
        self.backbone = ResNetBackbone(BasicBlock, [3, 4, 6, 3], num_classes)

        # Simplified "FPN": lateral 1x1 convs only, projecting each stage
        # (64/128/256/512 channels) to 256. There is no top-down pathway or
        # 3x3 smoothing conv.
        self.fpn = nn.ModuleList([
            nn.Conv2d(64, 256, 1),   # P2
            nn.Conv2d(128, 256, 1),  # P3
            nn.Conv2d(256, 256, 1),  # P4
            nn.Conv2d(512, 256, 1),  # P5
        ])

        # GFL detection head
        self.head = GFLHead(256, num_classes, reg_max=reg_max)

        # Loss functions
        self.qfl_loss = QualityFocalLoss()
        self.dfl_loss = DistributionFocalLoss()

    def forward(self, x, targets=None):
        """Run backbone, lateral projections and head; see the class docstring for returns."""
        fpn_feats = [lateral(feat) for lateral, feat in zip(self.fpn, self.backbone(x))]
        cls_outputs, reg_outputs = self.head(fpn_feats)

        if self.training and targets is not None:
            return self.calculate_losses(cls_outputs, reg_outputs, targets)
        return cls_outputs, reg_outputs

    def calculate_losses(self, cls_outputs, reg_outputs, targets):
        """Placeholder GFL loss.

        NOTE(review): this is not a real training objective. ``targets`` and
        ``reg_outputs`` are ignored, DFL is never applied, and every location
        is pushed toward an all-zero classification target. Replace it with
        label assignment (anchor/point matching) before real training.

        Returns:
            Dict with ``'total_loss'`` and ``'cls_loss'``, both the QFL summed
            over all pyramid levels (a 0-d zero tensor if there are no levels).
        """
        level_losses = []
        for cls_out in cls_outputs:
            # [B, C, H, W] -> [B * H * W, C]
            logits = cls_out.flatten(2).permute(0, 2, 1).flatten(0, 1)
            level_losses.append(self.qfl_loss(logits, torch.zeros_like(logits)))

        # Previously 'cls_loss' held only the last level's loss while
        # 'total_loss' summed every level; both now report the full sum.
        cls_loss = sum(level_losses) if level_losses else torch.zeros(())
        return {'total_loss': cls_loss, 'cls_loss': cls_loss}


# =============================================================================
# 5. Data Preprocessing for Two-Network Strategy
# =============================================================================

In [17]:
class PIDSymbolDataset(Dataset):
    """P&ID symbol dataset split by symbol size for the two-network strategy.

    Each sample is a drawing folder under ``<data_dir>/raw/paliwal_dataset/``
    that contains a ``symbols.npy`` annotation file. Only symbols whose
    bounding-box diagonal falls on this dataset's side of
    ``symbol_size_threshold`` are kept. Folders are scanned in sorted order,
    so sample indices are reproducible across machines.

    Args:
        data_dir: Dataset root directory.
        mode: ``'small'`` keeps symbols with diagonal <= threshold;
            ``'large'`` keeps symbols with diagonal > threshold.
        transform: Optional callable applied to the image tensor.
        symbol_size_threshold: Diagonal length separating the two networks,
            in the annotation's coordinate units.

    Raises:
        ValueError: If ``mode`` is not ``'small'`` or ``'large'``.
        FileNotFoundError: If the Paliwal dataset directory does not exist.
    """

    VALID_MODES = ('small', 'large')

    def __init__(self, data_dir, mode='small', transform=None, symbol_size_threshold=700):
        # Any other mode would silently match no symbols and give an empty dataset.
        if mode not in self.VALID_MODES:
            raise ValueError(f"mode must be one of {self.VALID_MODES}, got {mode!r}")
        self.data_dir = Path(data_dir)
        self.mode = mode  # 'small' or 'large'
        self.transform = transform
        self.threshold = symbol_size_threshold
        self.samples = self._load_samples()

    def _load_samples(self):
        """Scan drawing folders and keep those with at least one in-range symbol.

        Returns:
            List of ``{'folder': Path, 'symbols': [indices into bounding_box]}``.
        """
        samples = []
        paliwal_dir = self.data_dir / 'raw' / 'paliwal_dataset'
        if not paliwal_dir.is_dir():
            raise FileNotFoundError(f"Paliwal dataset directory not found: {paliwal_dir}")

        # iterdir() order is filesystem-dependent; sort for stable indices.
        for folder in sorted(p for p in paliwal_dir.iterdir() if p.is_dir()):
            symbols_file = folder / 'symbols.npy'
            if not symbols_file.exists():
                continue
            try:
                # SECURITY: allow_pickle=True can execute arbitrary code from a
                # crafted .npy file; only load annotation files you trust.
                symbols_data = np.load(symbols_file, allow_pickle=True).item()
                filtered_symbols = self._filter_by_size(symbols_data)
            except Exception as e:
                # Best-effort scan: a malformed folder is reported and skipped.
                print(f"Error loading {folder}: {e}")
                continue
            if filtered_symbols:
                samples.append({
                    'folder': folder,
                    'symbols': filtered_symbols
                })

        print(f"Loaded {len(samples)} samples for {self.mode} network")
        return samples

    def _filter_by_size(self, symbols_data):
        """Return indices of ``symbols_data['bounding_box']`` matching this mode.

        NOTE(review): assumes the first four values of each box are
        (x1, y1, x2, y2); confirm against the Paliwal annotation format.
        Boxes with fewer than four values are skipped.
        """
        filtered = []
        if 'bounding_box' not in symbols_data:
            return filtered

        for i, bbox in enumerate(symbols_data['bounding_box']):
            if len(bbox) < 4:
                continue
            x1, y1, x2, y2 = bbox[:4]
            diagonal = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
            if self.mode == 'small' and diagonal <= self.threshold:
                filtered.append(i)
            elif self.mode == 'large' and diagonal > self.threshold:
                filtered.append(i)
        return filtered

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, target)``. Currently placeholder data, not the real drawing."""
        # Indexing validates idx. The sample is unused until real image loading exists.
        sample = self.samples[idx]

        # TODO: load the actual drawing for `sample` instead of random data.
        if self.mode == 'small':
            # Patch-sized dummy input for the small-symbol network.
            image = torch.randn(3, 512, 512)
        else:
            # Downscaled full-image dummy input for the large-symbol network.
            image = torch.randn(3, 1024, 1024)

        # Placeholder target with a single box.
        target = {
            'boxes': torch.tensor([[100, 100, 200, 200]], dtype=torch.float32),
            'labels': torch.tensor([1], dtype=torch.long),
            'scores': torch.tensor([1.0], dtype=torch.float32)
        }

        if self.transform:
            image = self.transform(image)

        return image, target

# =============================================================================
# 6. Adaptive NMS Implementation
# =============================================================================

In [18]:
def adaptive_nms(boxes, scores, labels, iou_thresholds, score_threshold=0.05,
                 default_iou_threshold=0.5):
    """
    Class-aware greedy NMS with a per-class IoU threshold.

    Boxes scoring at or below ``score_threshold`` are discarded. The rest are
    visited in descending score order; each kept box suppresses the later
    (lower-scoring) boxes of the same label whose IoU with it exceeds that
    label's threshold.

    Args:
        boxes: Tensor [N, 4] in (x1, y1, x2, y2) format
        scores: Tensor [N] detection scores
        labels: Tensor [N] class labels
        iou_thresholds: Dict mapping class_id (int) to IoU threshold
        score_threshold: Boxes must score strictly above this to be considered
        default_iou_threshold: Threshold for labels missing from ``iou_thresholds``

    Returns:
        LongTensor of kept indices into the ORIGINAL (unfiltered) ``boxes``,
        ``scores`` and ``labels``, ordered by descending score, on
        ``boxes.device``.
    """
    # Positions of the score-passing candidates in the caller's tensors.
    # Previously the returned indices referred to the filtered array and
    # pointed at the wrong boxes whenever anything was filtered out.
    candidate_idx = torch.nonzero(scores > score_threshold, as_tuple=False).flatten()
    if candidate_idx.numel() == 0:
        return torch.empty(0, dtype=torch.long, device=boxes.device)

    # Visit candidates from highest to lowest score.
    order = torch.argsort(scores[candidate_idx], descending=True)
    candidate_idx = candidate_idx[order]
    boxes = boxes[candidate_idx]
    labels = labels[candidate_idx]

    keep = []
    # Allocate on the boxes' device so index assignment works on GPU inputs too.
    suppressed = torch.zeros(len(boxes), dtype=torch.bool, device=boxes.device)

    for i in range(len(boxes)):
        if suppressed[i]:
            continue

        keep.append(candidate_idx[i].item())
        current_label = labels[i]
        iou_threshold = iou_thresholds.get(current_label.item(), default_iou_threshold)

        # Only later (lower-scoring) boxes of the same class can be suppressed.
        same_class_pos = torch.nonzero(labels[i + 1:] == current_label, as_tuple=False).flatten()
        if same_class_pos.numel() == 0:
            continue

        ious = calculate_iou(boxes[i:i + 1], boxes[i + 1:][same_class_pos])
        # same_class_pos is relative to i + 1; shift back to sorted-array positions.
        suppressed[same_class_pos[ious > iou_threshold] + i + 1] = True

    return torch.tensor(keep, dtype=torch.long, device=boxes.device)

def calculate_iou(box1, box2):
    """IoU between one reference box and each of N boxes.

    Args:
        box1: Tensor ``[1, 4]`` in (x1, y1, x2, y2) format.
        box2: Tensor ``[N, 4]`` in the same format.

    Returns:
        Tensor ``[N]`` of IoU values. A 1e-6 term in the denominator guards
        against division by zero for degenerate boxes.
    """
    ax1, ay1, ax2, ay2 = (box1[:, k] for k in range(4))
    bx1, by1, bx2, by2 = (box2[:, k] for k in range(4))

    # Overlap extents, floored at zero when the boxes do not intersect.
    inter_w = (torch.min(ax2, bx2) - torch.max(ax1, bx1)).clamp(min=0)
    inter_h = (torch.min(ay2, by2) - torch.max(ay1, by1)).clamp(min=0)
    inter = inter_w * inter_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)

    return inter / (area_a + area_b - inter + 1e-6)


# =============================================================================
# 7. Training Setup and Utilities
# =============================================================================


In [19]:
def create_gfl_models(num_classes=8):
    """Build the (small-symbol, large-symbol) detector pair.

    Both detectors currently use the same configuration (``reg_max=16``).
    They are separate instances with independent weights, built small
    first and then large.

    Returns:
        ``(small_net, large_net)``.
    """
    small_net, large_net = (
        GFLDetector(num_classes=num_classes, reg_max=16) for _ in range(2)
    )
    return small_net, large_net

def setup_training(model, learning_rate=1e-4):
    """Create an AdamW optimizer and a step learning-rate schedule.

    The schedule multiplies the learning rate by 0.1 at milestones 8 and 11,
    counted in ``scheduler.step()`` calls.

    Returns:
        ``(optimizer, scheduler)``.
    """
    params = model.parameters()
    optimizer = torch.optim.AdamW(params, lr=learning_rate, weight_decay=1e-4)
    return optimizer, torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[8, 11], gamma=0.1
    )

def _move_to_device(obj, device):
    """Recursively move tensors inside dicts/lists/tuples to ``device``; leave other values as-is."""
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: _move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_move_to_device(value, device) for value in obj]
    if isinstance(obj, tuple):
        return tuple(_move_to_device(value, device) for value in obj)
    return obj


def train_epoch(model, dataloader, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module called as ``model(images, targets)`` in training mode;
            it must return a dict with a scalar ``'total_loss'`` tensor.
        dataloader: Any iterable of ``(images, targets)`` batches. It does not
            need ``__len__``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device that images and any tensors in ``targets`` are moved to.

    Returns:
        Mean ``total_loss`` over the processed batches, or ``0.0`` if the
        dataloader yielded no batches (previously a ZeroDivisionError).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch_idx, (images, targets) in enumerate(dataloader):
        images = images.to(device)
        # Targets must live on the same device as the model outputs they are compared against.
        targets = _move_to_device(targets, device)

        optimizer.zero_grad()

        outputs = model(images, targets)
        loss = outputs['total_loss']

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        if batch_idx % 10 == 0:
            print(f'Batch {batch_idx}, Loss: {loss.item():.4f}')

    return total_loss / num_batches if num_batches else 0.0

# =============================================================================
# 8. Prototype Testing and Visualization
# =============================================================================

In [20]:
def test_gfl_prototype():
    """Smoke-test the detectors, both losses and adaptive NMS on synthetic data.

    Prints parameter counts, per-level output shapes, loss values and the NMS
    result. Random tensors are drawn in a fixed order (model init, forward
    input, then loss inputs) so seeded runs are reproducible.
    """
    print("Testing GFL Prototype Implementation...")

    # --- Model construction -------------------------------------------------
    small_net, large_net = create_gfl_models(num_classes=8)
    for tag, net in (("Small", small_net), ("Large", large_net)):
        n_params = sum(p.numel() for p in net.parameters())
        print(f"{tag} network parameters: {n_params:,}")

    # --- Forward pass (batch of two 512x512 RGB images) ---------------------
    batch = torch.randn(2, 3, 512, 512)
    small_net.eval()
    with torch.no_grad():
        cls_maps, reg_maps = small_net(batch)
    print(f"Number of feature levels: {len(cls_maps)}")
    for level, (cls_map, reg_map) in enumerate(zip(cls_maps, reg_maps)):
        print(f"Level {level}: cls_shape={cls_map.shape}, reg_shape={reg_map.shape}")

    # --- Loss functions ------------------------------------------------------
    cls_logits = torch.randn(100, 8)
    cls_targets = torch.rand(100, 8)
    dist_logits = torch.randn(100, 17)  # reg_max + 1 bins
    dist_targets = torch.rand(100) * 16
    qfl_value = QualityFocalLoss()(cls_logits, cls_targets)
    dfl_value = DistributionFocalLoss()(dist_logits, dist_targets)
    print(f"QFL Loss: {qfl_value.item():.4f}")
    print(f"DFL Loss: {dfl_value.item():.4f}")

    # --- Adaptive NMS: two overlapping pairs, one pair per class -------------
    boxes = torch.tensor(
        [
            [10, 10, 50, 50],
            [15, 15, 55, 55],
            [100, 100, 150, 150],
            [105, 105, 155, 155],
        ],
        dtype=torch.float32,
    )
    scores = torch.tensor([0.9, 0.8, 0.85, 0.75])
    labels = torch.tensor([0, 0, 1, 1])
    per_class_iou = {0: 0.5, 1: 0.6}

    kept = adaptive_nms(boxes, scores, labels, per_class_iou)
    print(f"NMS kept {len(kept)} detections: {kept}")

    print("✅ GFL Prototype test completed successfully!")

def visualize_architecture(symbol_size_threshold=700):
    """Print a text diagram of the GFL pipeline and the two-network split.

    Args:
        symbol_size_threshold: Bounding-box diagonal length separating the
            small- and large-symbol networks. It should match the value used
            for dataset filtering; 700 is ``PIDSymbolDataset``'s default.
            Small symbols are those with diagonal <= threshold.
    """
    print("\n" + "="*60)
    print("GFL DETECTOR ARCHITECTURE")
    print("="*60)

    print("""
    Input Image
         ↓
    ResNet Backbone (Feature Extraction)
         ↓
    Feature Pyramid Network
         ↓
    GFL Detection Head
    ├── Classification Branch (QFL)
    └── Regression Branch (DFL)
         ↓
    Post-processing (Adaptive NMS)
         ↓
    Final Detections
    """)

    # The split is on bbox diagonal length, inclusive (<=) for small symbols.
    # The old text ("symbols < 700px") misstated both the measure and the boundary.
    print("\nTwo-Network Strategy:")
    print(f"├── Small Symbol Network: Patch-based detection "
          f"(symbol diagonal <= {symbol_size_threshold}px)")
    print(f"└── Large Symbol Network: Full-image detection "
          f"(symbol diagonal > {symbol_size_threshold}px)")

if __name__ == "__main__":
    # Smoke tests first, then the architecture overview and a roadmap.
    test_gfl_prototype()
    visualize_architecture()

    banner = "=" * 60
    print("\n" + banner)
    print("NEXT STEPS")
    print(banner)

    next_steps = [
        "Implement proper data loading from Paliwal dataset",
        "Add anchor generation and matching logic",
        "Implement proper loss calculation with ground truth",
        "Add evaluation metrics (mAP, precision, recall)",
        "Create training loop with validation",
        "Implement inference pipeline with visualization",
    ]
    for number, step in enumerate(next_steps, start=1):
        print(f"{number}. {step}")

Testing GFL Prototype Implementation...
Small network parameters: 26,427,276
Large network parameters: 26,427,276
Number of feature levels: 4
Level 0: cls_shape=torch.Size([2, 8, 128, 128]), reg_shape=torch.Size([2, 68, 128, 128])
Level 1: cls_shape=torch.Size([2, 8, 64, 64]), reg_shape=torch.Size([2, 68, 64, 64])
Level 2: cls_shape=torch.Size([2, 8, 32, 32]), reg_shape=torch.Size([2, 68, 32, 32])
Level 3: cls_shape=torch.Size([2, 8, 16, 16]), reg_shape=torch.Size([2, 68, 16, 16])
QFL Loss: 0.1454
DFL Loss: 3.3212
NMS kept 2 detections: tensor([0, 2])
✅ GFL Prototype test completed successfully!

GFL DETECTOR ARCHITECTURE

    Input Image
         ↓
    ResNet Backbone (Feature Extraction)
         ↓
    Feature Pyramid Network
         ↓
    GFL Detection Head
    ├── Classification Branch (QFL)
    └── Regression Branch (DFL)
         ↓
    Post-processing (Adaptive NMS)
         ↓
    Final Detections
    

Two-Network Strategy:
├── Small Symbol Network: Patch-based detection (symbo