# Core Module

> Core utilities for object detection: bounding-box format conversion, pairwise IoU, and box visualization

In [None]:
#| default_exp core

In [None]:
#| export
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Tuple, Union, Optional, Callable

In [None]:
#| hide
from nbdev.showdoc import *

## Bounding Box Utilities

In [None]:
#| export
def box_cxcywh_to_xyxy(x):
    """Turn center-format boxes into corner-format boxes.

    Args:
        x: tensor of shape (..., 4) whose last axis holds (cx, cy, w, h)

    Returns:
        tensor of the same shape whose last axis holds (x1, y1, x2, y2)
    """
    center_x, center_y, width, height = x.unbind(-1)

    # Each corner sits half an extent away from the center.
    half_w = 0.5 * width
    half_h = 0.5 * height

    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)

In [None]:
#| export
def box_xyxy_to_cxcywh(x):
    """Turn corner-format boxes into center-format boxes.

    Args:
        x: tensor of shape (..., 4) whose last axis holds (x1, y1, x2, y2)

    Returns:
        tensor of the same shape whose last axis holds (cx, cy, w, h)
    """
    left, top, right, bottom = x.unbind(-1)

    # Center is the midpoint of opposite corners; size is their difference.
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    width = right - left
    height = bottom - top

    return torch.stack((center_x, center_y, width, height), dim=-1)

In [None]:
#| export
def box_iou(boxes1, boxes2):
    """Compute pairwise intersection over union between two sets of boxes.

    Pairs whose union is exactly zero (e.g. two zero-area boxes) get an IoU
    of 0 instead of the NaN that a plain ``inter / union`` would produce.

    Args:
        boxes1: tensor of shape (N, 4) containing N boxes in (x1, y1, x2, y2) format
        boxes2: tensor of shape (M, 4) containing M boxes in (x1, y1, x2, y2) format

    Returns:
        tensor of shape (N, M) containing pairwise IoU values
    """
    area1 = torch.prod(boxes1[:, 2:] - boxes1[:, :2], 1)
    area2 = torch.prod(boxes2[:, 2:] - boxes2[:, :2], 1)

    # Broadcasting boxes1 over a new axis pairs every box in boxes1 with
    # every box in boxes2.
    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # left-top [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # right-bottom [N,M,2]

    # Non-overlapping pairs yield negative extents; clamp them to zero.
    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    # Divide by a stand-in of 1 where the union is zero, then zero those
    # entries out. This avoids 0/0 = NaN in the result and also keeps NaN
    # out of the backward pass.
    valid = union != 0
    safe_union = torch.where(valid, union, torch.ones_like(union))
    iou = inter / safe_union
    return iou.masked_fill(~valid, 0.0)

## Visualization Functions

In [None]:
#| export
def plot_boxes(img, boxes, labels=None, scores=None, class_names=None, figsize=(10, 10)):
    """Plot bounding boxes on an image.

    The caller's ``boxes`` is never modified: the conversion from normalized
    to pixel coordinates is done on a private copy.

    Args:
        img: PIL Image or tensor
        boxes: tensor (or array-like) of shape (N, 4) containing boxes in
            (x1, y1, x2, y2) format, values in [0, 1]; may be None or empty
        labels: optional tensor of shape (N,) containing class labels
        scores: optional tensor of shape (N,) containing confidence scores
        class_names: optional list of class names for label mapping
        figsize: figure size

    Returns:
        matplotlib figure
    """
    if isinstance(img, torch.Tensor):
        img = torchvision.transforms.ToPILImage()(img.cpu())

    fig, ax = plt.subplots(1, figsize=figsize)
    ax.imshow(img)

    # Random colors for different classes
    colors = np.random.rand(20, 3) if labels is not None else [[0, 1, 0]]

    if boxes is not None and len(boxes):
        if isinstance(boxes, torch.Tensor):
            # detach() allows tensors that require grad to be converted.
            boxes = boxes.detach().cpu().numpy()
        # Tensor.numpy() shares memory with a CPU tensor, so scale a copy
        # rather than rescaling the caller's data in place.
        pixel_boxes = np.array(boxes, copy=True)

        # Convert normalized coordinates to pixel coordinates
        height, width = img.height, img.width
        pixel_boxes[:, [0, 2]] *= width
        pixel_boxes[:, [1, 3]] *= height

        for i, (x1, y1, x2, y2) in enumerate(pixel_boxes):
            # Determine color based on class label
            if labels is not None:
                raw_label = labels[i]
                if isinstance(raw_label, torch.Tensor):
                    raw_label = raw_label.item()
                # Label ids index into colors/class_names, so force an int
                # even when the labels come from a float tensor.
                label_id = int(raw_label)
                color = colors[label_id % len(colors)]
            else:
                color = colors[0]

            # Plot box
            rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False,
                                 edgecolor=color, linewidth=2)
            ax.add_patch(rect)

            # Add label and score if available
            if labels is not None:
                label_txt = class_names[label_id] if class_names else f'Class {label_id}'
                if scores is not None:
                    score = scores[i].item() if isinstance(scores[i], torch.Tensor) else scores[i]
                    label_txt += f': {score:.2f}'

                ax.text(x1, y1, label_txt, bbox=dict(facecolor=color, alpha=0.5))

    # Hide the axes of this figure explicitly instead of relying on pyplot's
    # notion of the "current" axes.
    ax.axis('off')
    return fig

## Basic Tests

In [None]:
# Test box conversion functions
boxes_xyxy = torch.tensor([[0.1, 0.2, 0.5, 0.6], [0.3, 0.4, 0.7, 0.8]])
boxes_cxcywh = box_xyxy_to_cxcywh(boxes_xyxy)
boxes_xyxy2 = box_cxcywh_to_xyxy(boxes_cxcywh)

# A round-trip alone cannot catch errors that cancel out between the two
# converters, so also compare against hand-computed (cx, cy, w, h) values.
expected_cxcywh = torch.tensor([[0.3, 0.4, 0.4, 0.4], [0.5, 0.6, 0.4, 0.4]])
assert torch.allclose(boxes_cxcywh, expected_cxcywh, atol=1e-6), \
    f"Unexpected cxcywh boxes: {boxes_cxcywh}"

assert torch.allclose(boxes_xyxy, boxes_xyxy2), "Box conversion is not reversible"

In [None]:
# Test IoU calculation
boxes1 = torch.tensor([[0.1, 0.1, 0.5, 0.5], [0.3, 0.3, 0.7, 0.7]])
boxes2 = torch.tensor([[0.3, 0.3, 0.6, 0.6], [0.7, 0.7, 0.9, 0.9]])

iou = box_iou(boxes1, boxes2)
assert iou.shape == (2, 2), f"Expected shape (2, 2), got {iou.shape}"

# Hand-computed values:
#   [0, 0]: intersection 0.2 * 0.2 = 0.04, union 0.16 + 0.09 - 0.04 = 0.21
#   [1, 0]: boxes2[0] lies inside boxes1[1], so IoU = 0.09 / 0.16
#   column 1: boxes2[1] only touches or misses the others, so IoU = 0
expected_iou = torch.tensor([[0.04 / 0.21, 0.0], [0.09 / 0.16, 0.0]])
assert torch.allclose(iou, expected_iou, atol=1e-6), f"Unexpected IoU values: {iou}"