# Models Module

> Object detection models: a Faster R-CNN wrapper, a YOLO placeholder, and a model factory

In [None]:
#| default_exp models

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models.detection import FasterRCNN as TorchFasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from typing import Dict, List, Tuple, Union, Optional, Callable

from objdetect.core import box_xyxy_to_cxcywh, box_cxcywh_to_xyxy

In [None]:
#| hide
from nbdev.showdoc import *

## Backbone Networks

In [None]:
#| export
def create_backbone(name="resnet50", pretrained=True, trainable_layers=None):
    """Create a ResNet backbone network with FPN.

    Args:
        name: Backbone name passed to torchvision's ``resnet_fpn_backbone``
            (resnet18, resnet34, resnet50, resnet101).
        pretrained: Whether to load torchvision's default pretrained weights
            for the ResNet body. If False, the body is randomly initialised.
        trainable_layers: Number of trainable backbone layers (0 to 5).
            If None, 3 is used when ``pretrained`` is True and 5 otherwise.
            When ``pretrained`` is False, any value other than 5 is replaced
            by 5 with a warning: freezing randomly initialised layers would
            leave them permanently untrained.

    Returns:
        Backbone network with FPN, as returned by ``resnet_fpn_backbone``.
    """
    import warnings

    if pretrained:
        if trainable_layers is None:
            trainable_layers = 3
    else:
        if trainable_layers is not None and trainable_layers != 5:
            # Same policy as torchvision's own detection model builders.
            warnings.warn(
                f"trainable_layers={trainable_layers} has no effect when "
                "pretrained=False; falling back to trainable_layers=5 so "
                "that all layers are trainable.",
                stacklevel=2,
            )
        trainable_layers = 5

    return resnet_fpn_backbone(
        backbone_name=name,
        weights="DEFAULT" if pretrained else None,
        trainable_layers=trainable_layers,
    )

## Object Detection Models

In [None]:
#| export
class FasterRCNN(nn.Module):
    """Faster R-CNN model for object detection.

    A wrapper around torchvision's Faster R-CNN with more convenient
    initialization. It builds the ResNet+FPN backbone and a default anchor
    generator from a few high-level arguments. It also accepts batched tensors
    as input and offers a score-thresholded ``predict`` helper.
    """
    def __init__(self, num_classes, backbone_name="resnet50", pretrained_backbone=True,
                 trainable_backbone_layers=3, min_size=800, max_size=1333,
                 **kwargs):
        """
        Args:
            num_classes: Number of classes (including background)
            backbone_name: Backbone name (resnet18, resnet34, resnet50, resnet101)
            pretrained_backbone: Whether to use pretrained backbone weights
            trainable_backbone_layers: Number of trainable backbone layers
            min_size: Minimum size of the image to be rescaled
            max_size: Maximum size of the image to be rescaled
            **kwargs: Additional arguments forwarded to torchvision's
                FasterRCNN. Passing ``rpn_anchor_generator`` replaces the
                default anchor generator built here instead of raising a
                duplicate-keyword error.
        """
        super().__init__()

        backbone = create_backbone(
            name=backbone_name,
            pretrained=pretrained_backbone,
            trainable_layers=trainable_backbone_layers,
        )

        if "rpn_anchor_generator" not in kwargs:
            # One anchor size per feature level (five levels). These are
            # presumably matched to the FPN outputs of create_backbone;
            # TODO confirm if the backbone changes.
            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            kwargs["rpn_anchor_generator"] = AnchorGenerator(
                sizes=anchor_sizes, aspect_ratios=aspect_ratios
            )

        self.model = TorchFasterRCNN(
            backbone=backbone,
            num_classes=num_classes,
            min_size=min_size,
            max_size=max_size,
            **kwargs,
        )

    @staticmethod
    def _as_image_list(images):
        """Turn a tensor input into the list-of-images form the wrapped model takes."""
        if isinstance(images, torch.Tensor):
            # A 3-D tensor is one image; iterating it would split it by channel.
            if images.dim() == 3:
                return [images]
            # Otherwise split along the leading (batch) dimension.
            return list(images.unbind(0))
        return images

    def forward(self, images, targets=None):
        """
        Args:
            images: List[Tensor], a batched Tensor (split along dim 0),
                or a single 3-D image Tensor
            targets: Optional[List[Dict]], ground truth boxes and labels

        Returns:
            In training, returns losses dict
            In inference, returns List[Dict] with predictions
        """
        return self.model(self._as_image_list(images), targets)

    def freeze_backbone(self):
        """Set ``requires_grad=False`` on every parameter of ``self.model.backbone``.

        NOTE(review): the backbone built by create_backbone includes the FPN
        layers, so those are frozen too. Confirm this is intended.
        """
        for param in self.model.backbone.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self):
        """Set ``requires_grad=True`` on every parameter of ``self.model.backbone``.

        This also re-enables layers frozen at construction time through
        ``trainable_backbone_layers``.
        """
        for param in self.model.backbone.parameters():
            param.requires_grad = True

    def _prepare_for_predict(self, images):
        """Return ``images`` as a list of tensors on the model's device."""
        if isinstance(images, torch.Tensor) and images.dim() == 4:
            images = list(images.unbind(0))  # batched tensor -> one tensor per image
        elif not isinstance(images, (list, tuple)):
            images = [images]  # single PIL image or single image tensor

        # Move inputs to the model's device so CPU-side PIL conversions work
        # with a model that lives on an accelerator.
        param = next(self.parameters(), None)
        device = param.device if param is not None else None

        to_tensor = None  # built lazily, and only once, if any input is not a tensor
        batch = []
        for img in images:
            if not isinstance(img, torch.Tensor):
                if to_tensor is None:
                    to_tensor = torchvision.transforms.ToTensor()
                img = to_tensor(img)
            if device is not None:
                img = img.to(device)
            batch.append(img)
        return batch

    def predict(self, images, threshold=0.5):
        """Make predictions on images.

        The module is switched to eval mode for the call. Its previous
        training/eval mode is restored afterwards, so this is safe to call in
        the middle of a training loop.

        Args:
            images: List[PIL.Image] or PIL.Image or tensor (a single image
                tensor, or a 4-D batch tensor split along dim 0)
            threshold: Confidence threshold; detections with
                ``score >= threshold`` are kept

        Returns:
            List of prediction dictionaries with 'boxes', 'labels' and
            'scores', one per input image (an empty list for empty input)
        """
        was_training = self.training
        try:
            self.eval()
            with torch.no_grad():
                batch = self._prepare_for_predict(images)
                if not batch:
                    return []
                predictions = self.model(batch)
        finally:
            self.train(was_training)

        filtered_predictions = []
        for pred in predictions:
            keep = pred["scores"] >= threshold
            filtered_predictions.append(
                {key: pred[key][keep] for key in ("boxes", "labels", "scores")}
            )
        return filtered_predictions

## YOLO Model (Stub for future implementation)

In [None]:
#| export
class YOLO(nn.Module):
    """Placeholder YOLO detector; the real network is not implemented yet.

    Only ``num_classes`` is stored. ``backbone_name`` and any extra keyword
    arguments are accepted for API compatibility and currently ignored. The
    module has no parameters.
    """

    def __init__(self, num_classes, backbone_name="darknet", **kwargs):
        super().__init__()
        # TODO: build the backbone and detection heads; for now keep only the class count.
        self.num_classes = num_classes

    def forward(self, x, targets=None):
        """Return a constant zero loss, whatever the inputs or mode.

        The loss is a fresh leaf tensor with ``requires_grad=True``, so
        ``backward()`` can be called on it. It is connected to no parameters.
        """
        placeholder_loss = torch.tensor(0.0, requires_grad=True)
        return {"loss": placeholder_loss}

## Factory Function

In [None]:
#| export
def create_model(model_name="faster_rcnn", num_classes=91, **kwargs):
    """Factory function to create object detection models.

    Args:
        model_name: Model type, "faster_rcnn" or "yolo". Matching is
            case-insensitive, ignores surrounding whitespace and treats hyphens
            as underscores, so "Faster-RCNN" also selects Faster R-CNN.
        num_classes: Number of classes including background
        **kwargs: Additional model-specific parameters forwarded to the model
            constructor

    Returns:
        Object detection model

    Raises:
        ValueError: If ``model_name`` does not name a supported model.
    """
    key = model_name.strip().lower().replace("-", "_")

    if key == "faster_rcnn":
        return FasterRCNN(num_classes=num_classes, **kwargs)
    if key == "yolo":
        return YOLO(num_classes=num_classes, **kwargs)
    raise ValueError(
        f"Unsupported model: {model_name}. Supported models: faster_rcnn, yolo"
    )