# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2 # OpenCV for image loading/resizing (example dependency)
import os # For file path handling (example dependency)
# Assume external libraries for XML/JSON parsing, proposal loading, IoU calculation, NMS

# --- Utility Functions (Conceptual - Needs Implementation) ---

def calculate_iou(box_a, boxes_b):
    """Compute the IoU of one box against each of several boxes.

    Boxes use (x1, y1, x2, y2) corner format. A 1e-6 epsilon is added to
    every width/height so degenerate boxes never produce a zero union.

    Args:
        box_a (tensor): Single box [4] as (x1, y1, x2, y2).
        boxes_b (tensor): Candidate boxes [N, 4].
    Returns:
        tensor: IoU of ``box_a`` with every row of ``boxes_b``, shape [N].
    """
    eps = 1e-6

    # Corners of the overlap rectangle between box_a and every row of boxes_b.
    left = torch.max(box_a[0], boxes_b[:, 0])
    top = torch.max(box_a[1], boxes_b[:, 1])
    right = torch.min(box_a[2], boxes_b[:, 2])
    bottom = torch.min(box_a[3], boxes_b[:, 3])

    # Negative extents mean "no overlap" and are clamped to zero.
    overlap_w = (right - left + eps).clamp(min=0)
    overlap_h = (bottom - top + eps).clamp(min=0)
    intersection = overlap_w * overlap_h

    area_a = (box_a[2] - box_a[0] + eps) * (box_a[3] - box_a[1] + eps)
    areas_b = (boxes_b[:, 2] - boxes_b[:, 0] + eps) * (boxes_b[:, 3] - boxes_b[:, 1] + eps)

    union = area_a + areas_b - intersection
    return intersection / union


def calculate_regression_targets(proposal, gt_box):
    """Calculates ground truth regression targets (tx, ty, tw, th).

    Uses the R-CNN parameterization: center offsets scaled by the proposal
    size, and log-space width/height ratios. Works on a single box pair
    (shape [4]) or on batches of boxes ([..., 4], broadcast together).

    Args:
        proposal (tensor): Proposal box(es) [..., 4] as (x1, y1, x2, y2).
        gt_box (tensor): Ground truth box(es) [..., 4] as (x1, y1, x2, y2),
            broadcastable against ``proposal``.
    Returns:
        tensor: float32 targets [..., 4] (tx, ty, tw, th), on the inputs' device.
    """
    eps = 1e-6  # avoids zero widths/heights for degenerate boxes

    prop_w = proposal[..., 2] - proposal[..., 0] + eps
    prop_h = proposal[..., 3] - proposal[..., 1] + eps
    prop_cx = proposal[..., 0] + 0.5 * prop_w
    prop_cy = proposal[..., 1] + 0.5 * prop_h

    gt_w = gt_box[..., 2] - gt_box[..., 0] + eps
    gt_h = gt_box[..., 3] - gt_box[..., 1] + eps
    gt_cx = gt_box[..., 0] + 0.5 * gt_w
    gt_cy = gt_box[..., 1] + 0.5 * gt_h

    vx = (gt_cx - prop_cx) / prop_w
    vy = (gt_cy - prop_cy) / prop_h
    vw = torch.log(gt_w / prop_w)
    vh = torch.log(gt_h / prop_h)
    # torch.stack (rather than torch.tensor([...])) keeps the device and
    # supports batched inputs.
    return torch.stack([vx, vy, vw, vh], dim=-1).to(torch.float32)

def apply_regression_offsets(proposal, offsets):
    """Applies predicted (tx, ty, tw, th) offsets to proposal box(es).

    Inverse of ``calculate_regression_targets``. Works on a single box
    (shape [4]) or on batches ([..., 4]) of proposals and offsets.

    Args:
        proposal (tensor): Proposal box(es) [..., 4] as (x1, y1, x2, y2).
        offsets (tensor): Predicted offsets [..., 4] (tx, ty, tw, th),
            broadcastable against ``proposal``.
    Returns:
        tensor: Refined box(es) [..., 4] as (x1, y1, x2, y2).
    """
    eps = 1e-6  # same epsilon as the target encoding
    prop_w = proposal[..., 2] - proposal[..., 0] + eps
    prop_h = proposal[..., 3] - proposal[..., 1] + eps
    prop_cx = proposal[..., 0] + 0.5 * prop_w
    prop_cy = proposal[..., 1] + 0.5 * prop_h

    # unbind(-1) splits the last axis, so [N, 4] offsets give per-column
    # tensors (plain tuple unpacking would iterate rows instead).
    tx, ty, tw, th = offsets.unbind(-1)

    pred_cx = prop_w * tx + prop_cx
    pred_cy = prop_h * ty + prop_cy
    pred_w = prop_w * torch.exp(tw)
    pred_h = prop_h * torch.exp(th)

    pred_x1 = pred_cx - 0.5 * pred_w
    pred_y1 = pred_cy - 0.5 * pred_h
    pred_x2 = pred_cx + 0.5 * pred_w
    pred_y2 = pred_cy + 0.5 * pred_h
    return torch.stack([pred_x1, pred_y1, pred_x2, pred_y2], dim=-1)


def non_maximum_suppression(boxes, scores, iou_threshold):
    """Greedy Non-Maximum Suppression for the boxes of one class.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other remaining box whose IoU with it exceeds ``iou_threshold``.

    Args:
        boxes (tensor): Boxes [N, 4] as (x1, y1, x2, y2).
        scores (tensor): Scores [N,] for the boxes.
        iou_threshold (float): Boxes with IoU strictly above this are suppressed.
    Returns:
        tensor: int64 indices of the kept boxes, in descending score order.
    """
    # torchvision.ops.nms is a drop-in alternative when available.
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    eps = 1e-6
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1 + eps) * (y2 - y1 + eps)
    remaining = scores.argsort(descending=True)

    kept = []
    while remaining.numel() > 0:
        best = remaining[0]
        kept.append(best)
        rest = remaining[1:]
        if rest.numel() == 0:
            break

        # Overlap of the current best box with every other remaining box.
        inter_w = torch.clamp(torch.minimum(x2[best], x2[rest]) - torch.maximum(x1[best], x1[rest]) + eps, min=0.0)
        inter_h = torch.clamp(torch.minimum(y2[best], y2[rest]) - torch.maximum(y1[best], y1[rest]) + eps, min=0.0)
        overlap = inter_w * inter_h
        iou = overlap / (areas[best] + areas[rest] - overlap)

        # Survivors keep their (descending-score) order.
        remaining = rest[iou <= iou_threshold]

    return torch.tensor(kept, dtype=torch.int64, device=boxes.device)


# --- Data Loading and Preprocessing ---

class ObjectDetectionDataset(Dataset):
    """Conceptual Dataset for loading images, annotations, and proposals.

    Each item is a dict with:
      * ``'image'``: rescaled, ImageNet-normalized image tensor.
      * ``'proposals'``: proposals [N_prop, 4] (x1, y1, x2, y2) in rescaled-image coordinates.
      * ``'labels'``: per-proposal class label [N_prop] (long), see below.
      * ``'bbox_targets'``: per-proposal normalized regression targets [N_prop, 4];
        zero for every non-foreground proposal.
      * ``'image_id'``: identifier of the image (for debugging/evaluation).

    Proposal labels (Fast R-CNN, section 2.3):
      * ``> 0``: foreground, max IoU with a GT box >= ``fg_iou_thresh`` (that GT's class).
      * ``0``: background, max IoU in ``[bg_iou_thresh_lo, fg_iou_thresh)``.
      * ``-1``: ignored, max IoU below ``bg_iou_thresh_lo``. Pass
        ``bg_iou_thresh_lo=0.0`` to treat every non-foreground proposal as background.
    """

    def __init__(self, image_dir, annotation_dir, proposal_dir,  # Basic directories required
                 num_classes, class_map,  # Class variables
                 target_scale=600, max_scale=1000,  # Image scaling parameters
                 use_random_scale=False, scales=(480, 576, 688, 864, 1200),  # Multi-scale training
                 use_flip=True,  # Random horizontal flip
                 fg_iou_thresh=0.5, bg_iou_thresh_lo=0.1):  # Proposal labelling thresholds
        super().__init__()
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.proposal_dir = proposal_dir  # Assume proposals are pre-computed
        self.num_classes = num_classes
        self.class_map = class_map  # e.g., {'background': 0, 'car': 1, ...}
        self.target_scale = target_scale  # Shorter-side length for single-scale processing
        self.max_scale = max_scale  # Upper bound on the longer side
        self.use_random_scale = use_random_scale  # For multi-scale augmentation
        self.scales = scales
        self.use_flip = use_flip
        self.fg_iou_thresh = fg_iou_thresh
        self.bg_iou_thresh_lo = bg_iou_thresh_lo

        self.image_ids = self._load_image_ids()  # List of image identifiers

        # Standard ImageNet normalization
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # --- IMPORTANT: Add BBox Target Normalization ---
        # self.bbox_means, self.bbox_stds = self._calculate_bbox_stats() # Needs separate pass over dataset
        # Dummy values for now - REPLACE WITH ACTUAL STATS
        self.bbox_means = torch.tensor([0.0] * 4)
        self.bbox_stds = torch.tensor([1.0] * 4)  # Using 1.0 means no normalization initially

    def _load_image_ids(self):
        # Placeholder: Scan annotation_dir or use a predefined list
        # Returns a list of unique identifiers (e.g., filenames without extension)
        # Example: return sorted([f.split('.')[0] for f in os.listdir(self.annotation_dir)])
        return ["image_001", "image_002"]  # Dummy IDs

    def _load_annotation(self, image_id):
        # Placeholder: Load and parse annotation file (e.g., XML for VOC)
        # Should return:
        #   - gt_boxes (tensor): [N_gt, 4] ground truth boxes [x1, y1, x2, y2]
        #   - gt_labels (tensor): [N_gt,] ground truth class labels (integer indices)
        # Example:
        # xml_path = os.path.join(self.annotation_dir, f"{image_id}.xml")
        # boxes, labels = parse_voc_xml(xml_path, self.class_map)
        # return torch.tensor(boxes, dtype=torch.float32), torch.tensor(labels, dtype=torch.long)
        # Dummy annotation:
        gt_boxes = torch.tensor([[50, 50, 250, 250], [300, 100, 400, 200]], dtype=torch.float32)
        gt_labels = torch.tensor([1, 2], dtype=torch.long)  # Assume class 1 and 2 exist
        return gt_boxes, gt_labels

    def _load_proposals(self, image_id):
        # Placeholder: Load pre-computed proposals (e.g., from a .mat or .txt file)
        # Should return:
        #   - proposals (tensor): [N_prop, 4] proposal boxes [x1, y1, x2, y2]
        # Example:
        # prop_path = os.path.join(self.proposal_dir, f"{image_id}_proposals.txt")
        # proposals = load_proposals_from_file(prop_path)
        # return torch.tensor(proposals, dtype=torch.float32)
        # Dummy proposals:
        proposals = torch.randint(0, 300, (2000, 4)).float()  # ~2000 proposals
        proposals[:, 2] += proposals[:, 0] + 50  # Ensure x2 > x1
        proposals[:, 3] += proposals[:, 1] + 50  # Ensure y2 > y1
        return proposals

    def __len__(self):
        return len(self.image_ids)

    @staticmethod
    def _flip_boxes(boxes, width):
        """Return a horizontally mirrored copy of [N, 4] (x1, y1, x2, y2) boxes.

        x1' = width - x2 and x2' = width - x1. The input tensor is not modified.
        """
        flipped = boxes.clone()
        flipped[:, 0] = width - boxes[:, 2]
        flipped[:, 2] = width - boxes[:, 0]
        return flipped

    def _compute_im_scale(self, height, width):
        """Return the resize factor for an image of the given size.

        The shorter side is scaled to the target (or a randomly chosen) scale,
        unless that would push the longer side past ``max_scale``.
        """
        scale = self.target_scale
        if self.use_random_scale:  # Multi-scale training augmentation
            scale = np.random.choice(self.scales)

        im_size_min = min(height, width)
        im_size_max = max(height, width)
        im_scale = float(scale) / float(im_size_min)
        # Prevent the longer side from exceeding max_scale.
        if np.round(im_scale * im_size_max) > self.max_scale:
            im_scale = float(self.max_scale) / float(im_size_max)
        return im_scale

    def _assign_targets(self, proposals, gt_boxes, gt_labels):
        """Label every proposal and compute normalized regression targets.

        Args:
            proposals: [N_prop, 4] proposals (rescaled coordinates).
            gt_boxes: [N_gt, 4] ground-truth boxes (rescaled coordinates).
            gt_labels: [N_gt] ground-truth class labels.
        Returns:
            (labels [N_prop] long, bbox_targets [N_prop, 4] float32) following the
            label convention in the class docstring.
        """
        num_props = proposals.shape[0]

        if gt_boxes.shape[0] > 0:
            # IoU is symmetric, so looping over the (few) GT boxes builds the
            # same [N_prop, N_gt] matrix as looping over every proposal, with
            # far fewer Python iterations.
            ious = torch.stack([calculate_iou(gt, proposals) for gt in gt_boxes], dim=1)
            max_ious, gt_assignment = ious.max(dim=1)
        else:  # No ground-truth objects: nothing can be foreground.
            max_ious = torch.zeros(num_props)
            gt_assignment = torch.full((num_props,), -1, dtype=torch.long)  # Invalid assignment

        # Proposals start out ignored (-1); only the [lo, fg) IoU band is background.
        labels = torch.full((num_props,), -1, dtype=torch.long)
        bg_mask = (max_ious >= self.bg_iou_thresh_lo) & (max_ious < self.fg_iou_thresh)
        labels[bg_mask] = 0

        fg_inds = torch.where(max_ious >= self.fg_iou_thresh)[0]
        bbox_targets = torch.zeros((num_props, 4), dtype=torch.float32)
        if fg_inds.numel() > 0:
            assigned = gt_assignment[fg_inds]
            labels[fg_inds] = gt_labels[assigned]

            # Regression targets only for foreground proposals.
            targets_raw = torch.stack([
                calculate_regression_targets(prop, gt)
                for prop, gt in zip(proposals[fg_inds], gt_boxes[assigned])
            ])
            means = self.bbox_means.to(targets_raw.device)
            stds = self.bbox_stds.to(targets_raw.device)
            bbox_targets[fg_inds] = (targets_raw - means) / stds

        return labels, bbox_targets

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]

        # 1. Load image. cv2.imread returns None (instead of raising) for a
        # missing or unreadable file.
        img_path = os.path.join(self.image_dir, f"{image_id}.jpg")  # Assuming jpg
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Image not found at {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB
        H_orig, W_orig, _ = img.shape

        # 2. Load annotations and proposals (original-image coordinates).
        gt_boxes, gt_labels = self._load_annotation(image_id)
        proposals = self._load_proposals(image_id)

        # 3. Augmentation: random horizontal flip. _flip_boxes returns new
        # tensors, so the loaders' tensors are never modified in place.
        if self.use_flip and torch.rand(1) < 0.5:
            img = cv2.flip(img, 1)  # Horizontal flip
            proposals = self._flip_boxes(proposals, W_orig)
            if gt_boxes.shape[0] > 0:
                gt_boxes = self._flip_boxes(gt_boxes, W_orig)

        # 4. Rescale image and boxes.
        im_scale = self._compute_im_scale(H_orig, W_orig)
        new_H = int(np.round(H_orig * im_scale))
        new_W = int(np.round(W_orig * im_scale))
        img_resized = cv2.resize(img, (new_W, new_H), interpolation=cv2.INTER_LINEAR)
        img_tensor = self.image_transform(img_resized)  # ToTensor + normalization

        proposals_resized = proposals * im_scale
        gt_boxes_resized = gt_boxes * im_scale

        # 5. Assign labels and regression targets to proposals.
        labels, bbox_targets = self._assign_targets(proposals_resized, gt_boxes_resized, gt_labels)

        return {
            'image': img_tensor,
            'proposals': proposals_resized,  # Keep resized proposals for potential use
            'labels': labels,
            'bbox_targets': bbox_targets,
            'image_id': image_id  # Keep track for debugging/evaluation
        }


def _sample_roi_indices(candidates, count):
    """Randomly pick ``count`` entries of the 1-D index tensor ``candidates``.

    Samples without replacement when there are at least ``count`` candidates
    and with replacement otherwise. Returns an empty int64 array when there
    are no candidates or nothing is requested.
    """
    if count <= 0 or candidates.numel() == 0:
        return np.array([], dtype=np.int64)
    pool = candidates.cpu().numpy()
    return np.random.choice(pool, size=count, replace=pool.shape[0] < count)


def _pad_and_stack_images(images):
    """Stack [C, H, W] images into [B, C, H_max, W_max], zero-padding each at the bottom/right."""
    max_h = max(img.shape[-2] for img in images)
    max_w = max(img.shape[-1] for img in images)
    padded = images[0].new_zeros((len(images), images[0].shape[0], max_h, max_w))
    for slot, img in zip(padded, images):
        slot[:, : img.shape[-2], : img.shape[-1]] = img
    return padded


def collate_fn_fast_rcnn(batch, R=128, N=2, num_classes=20):
    """Custom collate function for hierarchical sampling (Fast R-CNN, section 2.3).

    Selects up to ``N`` images from ``batch`` and samples ``R // N`` RoIs per
    image, 25% of them foreground (label > 0) and the rest background
    (label == 0). Any other label (e.g. -1 for "ignore") is never sampled.
    Foreground/background indices are drawn with replacement only when too
    few candidates exist.

    Rescaled images generally differ in size, so they are zero-padded at the
    bottom/right to a common size. Coordinates keep their top-left origin,
    so RoIs stay valid.

    Args:
        batch: list of dicts as returned by the dataset's ``__getitem__``
            (keys 'image', 'proposals', 'labels', 'bbox_targets').
        R: total number of RoIs per mini-batch.
        N: number of images per mini-batch.
        num_classes: unused; kept for interface compatibility.
    Returns:
        images [B, C, H_max, W_max];
        rois [N_roi, 5] as (batch_idx, x1, y1, x2, y2) in IMAGE scale (not
        divided by the backbone stride; RoI pooling's spatial_scale handles that);
        labels [N_roi] long;
        bbox_targets [N_roi, 4], targets for each RoI's own class only (the
        loss selects the matching predicted offsets using ``labels``).
    Raises:
        ValueError: if ``batch`` is empty.
    """
    if len(batch) == 0:
        raise ValueError("collate_fn_fast_rcnn received an empty batch")

    # 1. Select N images for the batch (fewer if the batch is smaller).
    num_images_in_batch = min(N, len(batch))
    selected_indices = np.random.choice(len(batch), num_images_in_batch, replace=False)
    selected_batch = [batch[i] for i in selected_indices]

    rois_per_image = R // num_images_in_batch
    fg_rois_per_image = int(np.round(0.25 * rois_per_image))  # 25% foreground

    images_list, rois_list, labels_list, bbox_targets_list = [], [], [], []
    for batch_idx, data in enumerate(selected_batch):
        images_list.append(data['image'])

        proposals = data['proposals']
        labels = data['labels']
        bbox_targets = data['bbox_targets']

        # 2. Sample foreground, then fill the remaining slots with background.
        fg_sampled = _sample_roi_indices(torch.where(labels > 0)[0], fg_rois_per_image)
        bg_sampled = _sample_roi_indices(torch.where(labels == 0)[0], rois_per_image - fg_sampled.shape[0])

        keep_inds = np.concatenate([fg_sampled, bg_sampled])
        if keep_inds.size == 0:
            continue  # Image contributes no RoIs (it is still part of `images`).
        keep = torch.from_numpy(keep_inds)

        sampled_proposals = proposals[keep]
        # Prepend the image's position in the output batch as the RoI batch index.
        batch_col = torch.full((sampled_proposals.shape[0], 1), batch_idx,
                               dtype=sampled_proposals.dtype, device=sampled_proposals.device)
        rois_list.append(torch.cat([batch_col, sampled_proposals], dim=1))
        labels_list.append(labels[keep])
        bbox_targets_list.append(bbox_targets[keep])

    images = _pad_and_stack_images(images_list)
    if not rois_list:  # No RoIs could be sampled from any image.
        return (images, torch.empty((0, 5)), torch.empty((0,), dtype=torch.long),
                torch.empty((0, 4)))

    rois = torch.cat(rois_list, 0)
    labels = torch.cat(labels_list, 0)
    bbox_targets = torch.cat(bbox_targets_list, 0)
    return images, rois, labels, bbox_targets


# --- Backbone Network Definitions (VGG16 and ResNet50 as before) ---
class Backbone(nn.Module):
    """Base class for the feature-extraction backbones.

    Subclasses fill in the attributes below and override ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Filled in by subclasses:
        #   features      - module producing the shared feature map
        #   out_channels  - channel count of that feature map
        #   output_size   - (h, w) size of the pooled RoI features
        self.features = self.out_channels = self.output_size = None
        # Image-to-feature-map downsampling factor; subclasses override as needed.
        self.stride = 16.0

    def forward(self, x):
        """Return the feature map for image batch ``x``; must be overridden."""
        raise NotImplementedError

class VGG16(Backbone):
    """VGG-16 backbone: conv1_1..conv5_3 feature map plus the fc6/fc7 RoI head.

    Args:
        pretrained: load ImageNet weights via torchvision (newer torchvision
            versions deprecate ``pretrained`` in favor of ``weights``).
    """

    def __init__(self, pretrained=True):
        super().__init__()
        vgg16 = models.vgg16(pretrained=pretrained)
        # Drop the final max-pool so the trunk ends at conv5_3; the four
        # remaining max-pools give a total stride of 16.
        self.features = nn.Sequential(*list(vgg16.features.children())[:-1])
        self.out_channels = 512
        self.output_size = (7, 7)
        self.stride = 16.0

        # RoI head = torchvision's classifier minus its last layer (fc8):
        # fc6 -> ReLU -> Dropout -> fc7 -> ReLU -> Dropout. Reusing these
        # modules keeps the pretrained fc6/fc7 weights; Fast R-CNN initializes
        # them from the ImageNet model rather than randomly. With
        # pretrained=False, torchvision has already initialized them.
        self.roi_head_feature_extractor = nn.Sequential(*list(vgg16.classifier.children())[:-1])

    def forward(self, x):
        """Return the conv5_3 feature map; the FC head runs after RoI pooling."""
        feature_map = self.features(x)
        return feature_map


class ResNet50(Backbone):
    """ResNet-50 backbone (conv1 ... layer4) with a newly initialized FC RoI head.

    Args:
        pretrained: load ImageNet weights via torchvision.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        resnet = models.resnet50(pretrained=pretrained)
        # Use layers up to the end of the last conv block (layer4).
        self.features = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
        )
        self.out_channels = 2048  # Output channels of layer4
        self.output_size = (7, 7)  # RoI pool size the head is built for
        # conv1, maxpool, layer2, layer3 and layer4 each halve the resolution,
        # so layer4 features have stride 32. RoI pooling uses 1 / stride, so
        # a wrong value would misplace every RoI on the feature map.
        self.stride = 32.0

        # RoI head applied after RoI pooling, which already yields
        # output_size cells, so no extra pooling is needed. nn.Flatten accepts
        # both pooled [N, C, h, w] input and already-flattened [N, C*h*w] input.
        self.roi_head_feature_extractor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.out_channels * self.output_size[0] * self.output_size[1], 1024),
            nn.ReLU(True),
            nn.Linear(1024, 4096),  # 4096-d output to match the sibling heads' input size
            nn.ReLU(True),
        )
        # He initialization for the new FC layers.
        for layer in self.roi_head_feature_extractor:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Return the layer4 feature map; the RoI head runs after RoI pooling."""
        feature_map = self.features(x)
        return feature_map


# --- RoI Pooling Layer (Simplified - using torchvision.ops) ---
# Using torchvision's RoIPool is much more robust and efficient
from torchvision.ops import RoIPool as TorchVisionRoIPool

class RoIPoolWrapper(nn.Module):
    """nn.Module adapter around torchvision's ``RoIPool``.

    Args:
        output_size: (h, w) grid each RoI is max-pooled to.
        spatial_scale: factor torchvision applies to RoI coordinates to map
            them onto the feature map (1 / backbone stride for image-scale RoIs).
    """

    def __init__(self, output_size, spatial_scale):
        super().__init__()
        pool = TorchVisionRoIPool(output_size, spatial_scale)
        self.roi_pool = pool

    def forward(self, features, rois):
        """Pool ``features`` inside every RoI.

        Args:
            features: feature maps, (B, C, H_feat, W_feat).
            rois: (N_roi, 5) rows of (batch_idx, x1, y1, x2, y2) in image
                coordinates; the configured spatial_scale converts them to
                feature-map coordinates.
        Returns:
            Pooled features, (N_roi, C, h, w).
        """
        pooled = self.roi_pool(features, rois)
        return pooled


# --- Fast R-CNN Model Definition (Revised) ---
class FastRCNN(nn.Module):
    """Fast R-CNN detector: backbone -> RoI pooling -> shared RoI head -> sibling outputs.

    Args:
        num_classes: number of object classes K (background excluded).
        backbone_name: "vgg16" or "resnet50" (case-insensitive).
        pretrained: load ImageNet weights for the backbone.
        roi_output_size: (h, w) of each pooled RoI; should match the size the
            backbone's RoI head was built for (``backbone.output_size``).
    Raises:
        ValueError: for an unknown ``backbone_name``.
    """

    def __init__(self, num_classes, backbone_name="vgg16", pretrained=True, roi_output_size=(7, 7)):
        super().__init__()
        self.num_classes = num_classes

        name = backbone_name.lower()
        if name == "vgg16":
            self.backbone = VGG16(pretrained=pretrained)
        elif name == "resnet50":
            self.backbone = ResNet50(pretrained=pretrained)
        else:
            raise ValueError("Invalid backbone name.")

        # RoI pooling maps image-scale RoIs onto the feature map via 1 / stride.
        spatial_scale = 1.0 / self.backbone.stride
        self.roi_pool = RoIPoolWrapper(output_size=roi_output_size, spatial_scale=spatial_scale)

        # Same module object as the backbone's head (parameters are shared).
        self.roi_head_feature_extractor = self.backbone.roi_head_feature_extractor

        # A head that starts with nn.Linear needs flattened [N_roi, C*h*w]
        # input. Other heads (e.g. ones starting with pooling or Flatten)
        # receive the pooled [N_roi, C, h, w] tensor as-is.
        first_layer = next(iter(self.roi_head_feature_extractor.children()),
                           self.roi_head_feature_extractor)
        self._flatten_before_head = isinstance(first_layer, nn.Linear)

        # Sibling output layers; both backbones' heads emit 4096 features.
        feature_dim = 4096
        self.cls_score_head = nn.Linear(feature_dim, num_classes + 1)  # +1 for background
        self.bbox_pred_head = nn.Linear(feature_dim, num_classes * 4)  # 4 offsets per object class

        # --- Initialize sibling layers ---
        nn.init.normal_(self.cls_score_head.weight, std=0.01)
        nn.init.constant_(self.cls_score_head.bias, 0)
        nn.init.normal_(self.bbox_pred_head.weight, std=0.001)
        nn.init.constant_(self.bbox_pred_head.bias, 0)

    def forward(self, images, rois):
        """Score and regress every RoI.

        Args:
            images: (B, 3, H_img, W_img) image batch.
            rois: (N_roi, 5) rows of (batch_idx, x1, y1, x2, y2) in IMAGE scale.
        Returns:
            cls_score: (N_roi, num_classes + 1) class logits.
            bbox_pred_offsets: (N_roi, num_classes * 4) per-class box offsets.
        """
        # 1. Shared feature map from the backbone.
        feature_map = self.backbone(images)

        # 2. RoI pooling (spatial scaling handled by the pool): (N_roi, C, h, w).
        pooled_features = self.roi_pool(feature_map, rois)

        # 3. Shared RoI head, with input shaped for its first layer.
        if self._flatten_before_head:
            pooled_features = torch.flatten(pooled_features, start_dim=1)
        shared_roi_features = self.roi_head_feature_extractor(pooled_features)

        # 4. Sibling heads.
        cls_score = self.cls_score_head(shared_roi_features)
        bbox_pred_offsets = self.bbox_pred_head(shared_roi_features)
        return cls_score, bbox_pred_offsets

# --- Loss Function Definitions (Mostly Unchanged) ---
def fast_rcnn_loss(cls_score, bbox_pred, labels, bbox_targets,
                   bbox_inside_weights=None, bbox_outside_weights=None, lambda_loc=1.0):
    """Calculates the Fast R-CNN multi-task loss L = L_cls + lambda_loc * L_loc.

    Args:
        cls_score (tensor): [N_roi, K+1] class scores (index 0 = background).
        bbox_pred (tensor): [N_roi, K*4] predicted offsets; foreground class u
            (1..K) uses columns 4*(u-1) : 4*u.
        labels (tensor): [N_roi] integer class labels (0 for bg).
        bbox_targets (tensor): GT offsets, either [N_roi, K*4] (only relevant
            entries filled) or compact [N_roi, 4] (offsets for each RoI's own
            class, the layout returned by collate_fn_fast_rcnn).
        bbox_inside_weights (tensor, optional): [N_roi, K*4] mask for L_loc
            components (1 if relevant, 0 otherwise).
        bbox_outside_weights (tensor, optional): [N_roi, K*4] balancing weight
            for L_loc (typically 1/N_roi for relevant entries).
        lambda_loc (float): Balancing factor.

    When both weight tensors are omitted, L_loc is the Smooth L1 loss between
    the predicted and target offsets of each foreground RoI's true class,
    summed and divided by N_roi. This equals a class mask for the inside
    weights and 1/N_roi for the outside weights. With zero RoIs all three
    losses are 0 (mean cross-entropy over nothing would be NaN).

    Returns:
        tensor: Total loss.
        tensor: Classification loss.
        tensor: Localization loss.
    Raises:
        ValueError: if only one of the two weight tensors is given.
    """
    if (bbox_inside_weights is None) != (bbox_outside_weights is None):
        raise ValueError("bbox_inside_weights and bbox_outside_weights must be given together")

    num_rois = cls_score.shape[0]
    if num_rois == 0:
        zero = cls_score.sum() * 0.0  # keeps the result attached to the graph
        return zero, zero, zero

    # Classification loss (cross entropy over K+1 classes).
    loss_cls = F.cross_entropy(cls_score, labels)

    if bbox_inside_weights is not None:
        # Expanded layout: the masks zero out irrelevant entries, and the
        # outside weights provide the normalization.
        loss_box_all = F.smooth_l1_loss(
            bbox_pred * bbox_inside_weights,
            bbox_targets * bbox_inside_weights,
            reduction='none'
        )
        loss_box = torch.sum(bbox_outside_weights * loss_box_all)
    else:
        # Compact layout: select the true-class offsets of foreground RoIs.
        fg = torch.nonzero(labels > 0, as_tuple=True)[0]
        fg_slot = labels[fg] - 1  # bbox_pred has no background slot
        pred_fg = bbox_pred.reshape(num_rois, -1, 4)[fg, fg_slot]
        if bbox_targets.shape[1] == 4:
            target_fg = bbox_targets[fg]
        else:
            target_fg = bbox_targets.reshape(num_rois, -1, 4)[fg, fg_slot]
        loss_box = F.smooth_l1_loss(pred_fg, target_fg, reduction='sum') / num_rois

    loss = loss_cls + lambda_loc * loss_box
    return loss, loss_cls, loss_box


# --- Usage Instructions / Inference ---

def _rescale_for_inference(img_rgb, target_scale, max_scale):
    """Resize an image using the same rule as the dataset's ``__getitem__``.

    The shorter side is scaled to ``target_scale``. If that would push the
    longer side above ``max_scale``, the longer side is scaled to
    ``max_scale`` instead.

    Args:
        img_rgb (np.ndarray): H x W (x C) image.
        target_scale (int): Desired length of the shorter side.
        max_scale (int): Upper bound on the longer side.

    Returns:
        tuple: (resized image, scale factor applied, resized height, resized width).
    """
    h_orig, w_orig = img_rgb.shape[:2]
    short_side = min(h_orig, w_orig)
    long_side = max(h_orig, w_orig)
    im_scale = float(target_scale) / float(short_side)
    if np.round(im_scale * long_side) > max_scale:
        im_scale = float(max_scale) / float(long_side)
    new_h = int(np.round(h_orig * im_scale))
    new_w = int(np.round(w_orig * im_scale))
    img_resized = cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    return img_resized, im_scale, new_h, new_w


def detect_objects(model, image_path, proposal_loader, device,
                   conf_threshold=0.7, nms_threshold=0.3,
                   target_scale=600, max_scale=1000,
                   bbox_means=None, bbox_stds=None, # For de-normalization
                   num_classes=20, class_map_inv=None): # To map indices back to names
    """Performs object detection on a single image.

    Args:
        model: Fast R-CNN model called as ``model(image_batch, rois)``. It is
            expected to return ``(cls_score [N_roi, K+1], bbox_pred [N_roi, K*4])``.
        image_path (str): Path of the image to read with OpenCV.
        proposal_loader (callable): ``proposal_loader(image_path)`` returns an
            array-like [N_prop, 4] of [x1, y1, x2, y2] boxes in original image
            coordinates, or ``None``.
        device: Torch device to run inference on.
        conf_threshold (float): Minimum class probability for a box to be kept.
        nms_threshold (float): IoU threshold passed to NMS.
        target_scale (int): Target length of the shorter image side.
        max_scale (int): Maximum length of the longer image side.
        bbox_means, bbox_stds: Optional 4-element sequences/arrays/tensors used
            to de-normalize predicted offsets as ``offset * std + mean``.
            Default to zeros / ones, i.e. no de-normalization.
        num_classes (int): Number of foreground classes K.
        class_map_inv (dict, optional): Maps class index -> class name.

    Returns:
        tuple: ``(boxes, scores, labels)``.
            - ``boxes``: [M, 4] numpy array in original image coordinates.
            - ``scores``: [M] numpy array.
            - ``labels``: [M] integer numpy array, or a list of class names
              when ``class_map_inv`` is given.

        Three empty lists are returned when there are no proposals or no
        detections.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    model.eval()  # Set model to evaluation mode
    model.to(device)

    # 1. Load and preprocess the image.
    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        raise FileNotFoundError(f"Image not found at {image_path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_resized, im_scale, new_H, new_W = _rescale_for_inference(img_rgb, target_scale, max_scale)

    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    img_tensor = img_transform(img_resized).unsqueeze(0).to(device)  # Add batch dimension

    # 2. Load proposals, given in original image coordinates.
    proposals_img_scale = proposal_loader(image_path)
    if proposals_img_scale is None or len(proposals_img_scale) == 0:
        print("No proposals found for image.")
        return [], [], []
    # as_tensor accepts numpy arrays, lists or tensors. Forcing float32 avoids
    # float64 RoIs, which torch.tensor() would produce from a float64 ndarray.
    proposals_img_scale = torch.as_tensor(proposals_img_scale, dtype=torch.float32, device=device)

    # The network only sees the resized image, so the RoIs must be in that
    # image's coordinate frame. Passing original-scale boxes would misalign RoI
    # pooling whenever im_scale != 1.
    proposals_resized_scale = proposals_img_scale * im_scale
    batch_idx_tensor = torch.zeros((proposals_resized_scale.shape[0], 1),
                                   dtype=proposals_resized_scale.dtype, device=device)
    rois = torch.cat([batch_idx_tensor, proposals_resized_scale], dim=1)  # (N_roi, 5): [batch_idx, x1, y1, x2, y2]

    # 3. Run the forward pass.
    with torch.no_grad():
        cls_score, bbox_pred_offsets = model(img_tensor, rois)

    # 4. Post-process the outputs.
    probs = F.softmax(cls_score, dim=1)  # (N_roi, K+1)

    # De-normalization statistics. Defaults are the identity transform.
    # Inputs are coerced to tensors so lists or numpy arrays also work.
    offset_dtype = bbox_pred_offsets.dtype
    if bbox_means is None:
        bbox_means = torch.zeros(4, dtype=offset_dtype, device=device)
    else:
        bbox_means = torch.as_tensor(bbox_means, dtype=offset_dtype, device=device)
    if bbox_stds is None:
        bbox_stds = torch.ones(4, dtype=offset_dtype, device=device)
    else:
        bbox_stds = torch.as_tensor(bbox_stds, dtype=offset_dtype, device=device)

    final_boxes = []
    final_scores = []
    final_labels = []

    # Class 0 is background and is skipped.
    for class_idx in range(1, num_classes + 1):
        class_scores = probs[:, class_idx]
        keep_inds = torch.where(class_scores >= conf_threshold)[0]
        if keep_inds.numel() == 0:
            continue

        # bbox_pred has K*4 columns (no background block), so foreground class j
        # owns columns (j-1)*4 .. j*4-1.
        class_bbox_offsets = bbox_pred_offsets[:, (class_idx - 1) * 4:class_idx * 4]

        filtered_proposals = proposals_resized_scale[keep_inds]
        filtered_scores = class_scores[keep_inds]
        filtered_offsets_denorm = class_bbox_offsets[keep_inds] * bbox_stds + bbox_means

        # Apply the offsets to each proposal to get refined boxes.
        refined_boxes = torch.zeros_like(filtered_proposals)
        for i in range(filtered_proposals.shape[0]):
            refined_boxes[i] = apply_regression_offsets(filtered_proposals[i], filtered_offsets_denorm[i])

        # Clip to the *resized* image. The slices are views, so clamp_ writes
        # into refined_boxes.
        refined_boxes[:, 0::2].clamp_(min=0, max=new_W - 1)  # x1, x2
        refined_boxes[:, 1::2].clamp_(min=0, max=new_H - 1)  # y1, y2

        keep_nms_inds = non_maximum_suppression(refined_boxes, filtered_scores, nms_threshold)
        if keep_nms_inds.numel() == 0:
            continue

        final_class_scores = filtered_scores[keep_nms_inds]
        # Map the boxes back to original image coordinates.
        final_boxes.append(refined_boxes[keep_nms_inds] / im_scale)
        final_scores.append(final_class_scores)
        final_labels.append(torch.full_like(final_class_scores, class_idx, dtype=torch.long))

    if not final_boxes:
        return [], [], []

    # Concatenate the results across all classes.
    final_boxes = torch.cat(final_boxes, dim=0)
    final_scores = torch.cat(final_scores, dim=0)
    final_labels = torch.cat(final_labels, dim=0)

    if class_map_inv:
        final_label_names = [class_map_inv[label.item()] for label in final_labels]
        return final_boxes.cpu().numpy(), final_scores.cpu().numpy(), final_label_names
    return final_boxes.cpu().numpy(), final_scores.cpu().numpy(), final_labels.cpu().numpy()


# --- Example Usage Outline ---

# 1. Setup (Paths, Classes, etc.)
# IMAGE_DIR = "path/to/your/images"
# ANNOTATION_DIR = "path/to/your/annotations"
# PROPOSAL_DIR = "path/to/your/proposals"
# NUM_CLASSES = 20 # e.g., PASCAL VOC
# CLASS_MAP = {'background': 0, 'aeroplane': 1, ...} # Map class names to integers
# CLASS_MAP_INV = {v: k for k, v in CLASS_MAP.items()} # For inference output
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. Create Dataset and DataLoader (for Training)
# train_dataset = ObjectDetectionDataset(IMAGE_DIR, ANNOTATION_DIR, PROPOSAL_DIR, NUM_CLASSES, CLASS_MAP, use_flip=True, use_random_scale=True)
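# # Note: Need robust collate_fn implementation
# # batch_size must equal N (images per mini-batch). With batch_size=None, DataLoader disables
# # automatic batching, so collate_fn would receive one sample at a time instead of a list of N.
# # Use functools.partial instead of a lambda so collate_fn can be pickled for worker processes.
# train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, # N = 2 images per mini-batch
#                           num_workers=4, collate_fn=functools.partial(collate_fn_fast_rcnn, R=128, N=2, num_classes=NUM_CLASSES))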

# 3. Initialize Model and Optimizer (for Training)
# model = FastRCNN(num_classes=NUM_CLASSES, backbone_name="vgg16", pretrained=True).to(DEVICE)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30000, gamma=0.1) # Example scheduler

# 4. Training Loop (Conceptual)
# num_iterations = 40000 # e.g., 30k + 10k
# model.train()
# for i, batch_data in enumerate(train_loader):
#     if i >= num_iterations: break
#     images, rois, labels, bbox_targets = batch_data
#     images = images.to(DEVICE)
#     rois = rois.to(DEVICE)
#     labels = labels.to(DEVICE)
#     bbox_targets = bbox_targets.to(DEVICE) # Needs proper expansion/masking based on labels for loss calculation

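#     # --- Calculate loss weights (conceptual) ---
#     bbox_inside_weights = torch.zeros_like(bbox_targets.repeat(1, NUM_CLASSES)) # Placeholder for masking L_loc
#     bbox_outside_weights = torch.zeros_like(bbox_targets.repeat(1, NUM_CLASSES)) # Placeholder for balancing L_loc
#     # Populate the weights from the labels. For each foreground RoI with label c (1..K), set columns
#     # (c-1)*4 .. c*4-1 of bbox_inside_weights to 1 and of bbox_outside_weights to a balancing value
#     # (e.g. 1/N_roi). Leave background rows (label 0) at 0. Expand bbox_targets to the same
#     # [N_roi, K*4] layout. Getting these weights right is crucial for the loss function.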

#     optimizer.zero_grad()
#     cls_score, bbox_pred = model(images, rois)

#     loss, loss_cls, loss_loc = fast_rcnn_loss(
#         cls_score, bbox_pred, labels,
#         bbox_targets, # Pass appropriately structured/expanded targets
#         bbox_inside_weights, bbox_outside_weights, # Pass weights
#         lambda_loc=1.0
#     )

#     loss.backward()
#     optimizer.step()
#     lr_scheduler.step()

#     if i % 100 == 0: # Logging example
#         print(f"Iter: {i}, Loss: {loss.item():.4f}, Loss Cls: {loss_cls.item():.4f}, Loss Loc: {loss_loc.item():.4f}")

     # Add validation, checkpointing etc.

# 5. Inference / Detection (Example)
# Load trained model weights
# model.load_state_dict(torch.load("path/to/trained_model.pth"))
# model.eval()
# model.to(DEVICE)

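# Define a conceptual proposal loader function
# def my_proposal_loader(image_path):
#    # Load proposals for this image_path (e.g., from Selective Search output)
#    # Return numpy array [N_prop, 4] of [x1, y1, x2, y2] boxes in original image coordinates
#    xy = np.random.rand(500, 2) * 150     # Dummy top-left corners
#    wh = np.random.rand(500, 2) * 50 + 1  # Dummy widths/heights (> 0, so x2 > x1 and y2 > y1)
#    return np.hstack([xy, xy + wh]).astype(np.float32)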

# Perform detection
# try:
#     boxes, scores, labels = detect_objects(
#         model, "path/to/test_image.jpg", my_proposal_loader, DEVICE,
#         conf_threshold=0.8, nms_threshold=0.3,
#         # Pass actual bbox means/stds if normalization was used
#         num_classes=NUM_CLASSES, class_map_inv=CLASS_MAP_INV
#     )
#     print("Detections:")
#     for box, score, label in zip(boxes, scores, labels):
#         print(f"  Label: {label}, Score: {score:.3f}, Box: {box}")
#     # Add visualization code here (draw boxes on image)
# except FileNotFoundError as e:
#     print(e)