In [None]:
# Open an SSH session as root to a remote host (port 20189), forwarding local
# port 8080 to port 8080 on the remote machine's localhost.
# NOTE(review): this runs from a notebook cell; an interactive login or password
# prompt will likely block the cell — consider running it in a terminal, or with
# `-N -f` if only the tunnel is needed. Also avoid keeping the host address and
# root login hard-coded in a shared notebook.
!ssh -p 20189 root@174.88.252.15 -L 8080:localhost:8080

In [None]:
# Mount Google Drive (Colab only) so the dataset under /content/drive/MyDrive is reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Data preparation

In [None]:
import os
import glob
import yaml
import cv2
import numpy as np
import xml.etree.ElementTree as ET
from pathlib import Path
from tqdm.auto import tqdm
import shutil
import logging
import argparse
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Tuple, Optional


# Set up logging for the converter (console + conversion.log in the CWD).
# basicConfig() is a silent no-op when the root logger already has handlers,
# which is the case in Colab (log lines there come out in the default
# "INFO:YOLO2VOC:..." format rather than the one below). `force=True`
# (Python 3.8+) closes and replaces any existing root handlers so this
# configuration actually takes effect, including when the cell is re-run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("conversion.log")
    ],
    force=True,
)
logger = logging.getLogger("YOLO2VOC")


def validate_paths(yolo_path: str, output_path: str) -> bool:
    """
    Check that the YOLO dataset and the output location are usable.

    The YOLO dataset directory must exist and contain a ``data.yaml`` file.
    If the parent directory of ``output_path`` is missing, it is created.

    Args:
        yolo_path: Path to YOLO format dataset
        output_path: Path to output Pascal VOC format dataset

    Returns:
        bool: True when conversion can proceed, False otherwise (the reason
        is logged).
    """
    if not os.path.exists(yolo_path):
        logger.error(f"YOLO dataset path does not exist: {yolo_path}")
        return False

    yaml_file = os.path.join(yolo_path, "data.yaml")
    if not os.path.exists(yaml_file):
        logger.error(f"data.yaml not found at: {yaml_file}")
        return False

    parent_dir = os.path.dirname(output_path)
    # An empty dirname means a bare relative name: there is nothing to create.
    if not parent_dir or os.path.exists(parent_dir):
        return True

    logger.warning(f"Output parent directory does not exist: {parent_dir}")
    try:
        os.makedirs(parent_dir, exist_ok=True)
        logger.info(f"Created output parent directory: {parent_dir}")
    except Exception as err:
        logger.error(f"Failed to create output directory: {err}")
        return False
    return True


def create_xml_annotation(
    image_path: str,
    image_width: int,
    image_height: int,
    boxes: List[List[int]],
    class_ids: List[int],
    class_names: List[str]
) -> ET.ElementTree:
    """
    Build the Pascal VOC annotation tree for a single image.

    Args:
        image_path: Path to the image (used for folder, filename and path tags)
        image_width: Width of the image in pixels
        image_height: Height of the image in pixels
        boxes: Bounding boxes as [x_min, y_min, x_max, y_max]
        class_ids: Class index for each box, parallel to ``boxes``
        class_names: Class names indexed by class id

    Returns:
        ET.ElementTree: The ``<annotation>`` tree. Boxes whose class id is out
        of range for ``class_names`` are logged and omitted.
    """
    def add_child(parent, tag, text=None):
        # Append <tag> under parent, setting its text only when given.
        node = ET.SubElement(parent, tag)
        if text is not None:
            node.text = text
        return node

    img = Path(image_path)
    root = ET.Element("annotation")

    add_child(root, "folder", str(img.parent.name))
    add_child(root, "filename", str(img.name))
    add_child(root, "path", str(image_path))
    add_child(add_child(root, "source"), "database", "Unknown")

    size_node = add_child(root, "size")
    add_child(size_node, "width", str(image_width))
    add_child(size_node, "height", str(image_height))
    add_child(size_node, "depth", "3")  # images are assumed to be 3-channel

    add_child(root, "segmented", "0")

    corner_tags = ("xmin", "ymin", "xmax", "ymax")
    for box, class_id in zip(boxes, class_ids):
        if not 0 <= class_id < len(class_names):
            logger.warning(f"Invalid class_id {class_id}, skipping this box")
            continue

        obj = add_child(root, "object")
        add_child(obj, "name", class_names[class_id])
        add_child(obj, "pose", "Unspecified")
        add_child(obj, "truncated", "0")
        add_child(obj, "difficult", "0")

        bndbox = add_child(obj, "bndbox")
        for idx, tag in enumerate(corner_tags):
            add_child(bndbox, tag, str(int(box[idx])))

    return ET.ElementTree(root)


def yolo_to_voc_box(
    x_center: float,
    y_center: float,
    width: float,
    height: float,
    img_width: int,
    img_height: int
) -> Optional[List[int]]:
    """
    Convert one normalised YOLO box into integer Pascal VOC pixel corners.

    Args:
        x_center: Box centre x as a fraction of the image width
        y_center: Box centre y as a fraction of the image height
        width: Box width as a fraction of the image width
        height: Box height as a fraction of the image height
        img_width: Image width in pixels
        img_height: Image height in pixels

    Returns:
        [x_min, y_min, x_max, y_max] clipped to the image, or None when the box
        has a negative size or does not overlap the image. Coordinates are
        truncated with int(); a box that collapses to zero pixels after
        truncation is widened to one pixel where the image allows.
    """
    if width < 0 or height < 0:
        return None

    left = x_center - width / 2
    right = x_center + width / 2
    top = y_center - height / 2
    bottom = y_center + height / 2

    # Reject boxes with no overlap with the image. Without this, clipping a
    # box lying wholly left of / above the image produced a spurious 1-pixel
    # box on the image edge.
    if right <= 0 or bottom <= 0 or left >= 1 or top >= 1:
        return None

    x_min = max(0, int(left * img_width))
    y_min = max(0, int(top * img_height))
    x_max = min(img_width, int(right * img_width))
    y_max = min(img_height, int(bottom * img_height))

    # Sub-pixel boxes truncate to zero size; widen them to one pixel.
    if x_max <= x_min:
        x_max = min(x_min + 1, img_width)
    if y_max <= y_min:
        y_max = min(y_min + 1, img_height)

    if x_max <= x_min or y_max <= y_min:
        return None
    return [x_min, y_min, x_max, y_max]


def _read_yolo_labels(
    label_path: str,
    img_width: int,
    img_height: int
) -> Tuple[List[List[int]], List[int]]:
    """
    Parse a YOLO label file into VOC pixel boxes and their class ids.

    A missing label file yields no boxes. Malformed lines and unusable boxes are
    logged and skipped. An error while reading the file is logged, and the
    boxes parsed so far are returned.
    """
    boxes: List[List[int]] = []
    class_ids: List[int] = []
    if not os.path.exists(label_path):
        return boxes, class_ids

    try:
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                # Only the first five fields (class cx cy w h) are used.
                # NOTE(review): polygon (segmentation) labels would be misread
                # here; confirm the dataset uses detection-format labels.
                if len(parts) < 5:
                    continue
                try:
                    class_id = int(float(parts[0]))
                    x_center, y_center, width, height = map(float, parts[1:5])
                    box = yolo_to_voc_box(x_center, y_center, width, height, img_width, img_height)
                except (ValueError, OverflowError) as e:
                    # NaN/inf values surface here; skip just this line.
                    logger.warning(f"Error parsing line in {label_path}: {line.strip()} - {e}")
                    continue

                if box is None:
                    logger.warning(
                        f"Skipping box outside the image or with invalid dimensions in {label_path}: {line.strip()}"
                    )
                    continue

                boxes.append(box)
                class_ids.append(class_id)
    except Exception as e:
        logger.error(f"Error reading label file {label_path}: {e}")

    return boxes, class_ids


def process_image(
    img_path: str,
    split: str,
    yolo_dataset_path: str,
    output_path: str,
    class_names: List[str],
    overwrite: bool = False
) -> Optional[str]:
    """
    Copy one image into the VOC layout and write its XML annotation.

    Args:
        img_path: Path to the image
        split: Dataset split directory name (e.g. train, valid, test)
        yolo_dataset_path: Path to YOLO format dataset
        output_path: Path to output Pascal VOC format dataset
        class_names: List of class names
        overwrite: Whether to overwrite existing files

    Returns:
        str: Image name (file stem) if processed or already present, None on failure
    """
    try:
        img_filename = os.path.basename(img_path)
        img_name = os.path.splitext(img_filename)[0]
        dest_img_path = os.path.join(output_path, split, 'JPEGImages', img_filename)
        xml_path = os.path.join(output_path, split, 'Annotations', f"{img_name}.xml")

        # Skip if both outputs already exist and overwrite is False.
        if not overwrite and os.path.exists(dest_img_path) and os.path.exists(xml_path):
            return img_name

        # Decode the image: this both validates it and gives its pixel size.
        img = cv2.imread(img_path)
        if img is None:
            logger.warning(f"Could not read image {img_path}, skipping...")
            return None

        try:
            shutil.copy(img_path, dest_img_path)
        except Exception as e:
            logger.error(f"Error copying image file {img_path}: {e}")
            return None

        img_height, img_width = img.shape[:2]
        label_path = os.path.join(yolo_dataset_path, split, "labels", f"{img_name}.txt")
        boxes, class_ids = _read_yolo_labels(label_path, img_width, img_height)

        xml_tree = create_xml_annotation(dest_img_path, img_width, img_height, boxes, class_ids, class_names)
        try:
            xml_tree.write(xml_path)
        except Exception as e:
            logger.error(f"Error writing XML file {xml_path}: {e}")
            return None

        return img_name
    except Exception as e:
        logger.error(f"Error processing image {img_path}: {e}")
        return None


def find_validation_split(splits) -> Optional[str]:
    """
    Return the validation split name present in `splits`, or None.

    Datasets name the validation split either 'val' or 'valid' (this file's
    defaults use 'valid'). 'val' is preferred when both are present.
    """
    return next((name for name in ('val', 'valid') if name in splits), None)


def normalize_class_names(names) -> List[str]:
    """
    Return the ``names`` entry of data.yaml as a list indexed by class id.

    data.yaml may store names as a list or as a ``{class_id: name}`` mapping.
    Iterating a mapping yields its keys, so a mapping is converted to a list
    ordered by id. A KeyError is raised if the ids are not 0..n-1.
    """
    if isinstance(names, dict):
        return [names[i] for i in range(len(names))]
    return list(names)


def _find_split_images(img_dir: str) -> List[str]:
    """Return the sorted *.jpg / *.png / *.jpeg paths in `img_dir` (case-sensitive match)."""
    patterns = ("*.jpg", "*.png", "*.jpeg")
    return sorted(
        path for pattern in patterns for path in glob.glob(os.path.join(img_dir, pattern))
    )


def _write_name_list(path: str, names: List[str]) -> None:
    """Write one image id per line to `path`, replacing any existing file."""
    with open(path, 'w') as f:
        f.writelines(f"{name}\n" for name in names)


def convert_yolo_to_pascal_voc(
    yolo_dataset_path: str,
    output_path: str,
    splits: Tuple[str, ...] = ('train', 'valid', 'test'),
    overwrite: bool = False,
    num_workers: int = 4
) -> bool:
    """
    Convert YOLO format dataset to Pascal VOC format for SSD training

    Args:
        yolo_dataset_path: Path to YOLO format dataset with data.yaml file
        output_path: Path to output Pascal VOC format dataset
        splits: Dataset splits to convert
        overwrite: Whether to overwrite existing files
        num_workers: Number of worker threads for parallel processing

    Returns:
        bool: False if the conversion was aborted (path validation failed,
        data.yaml unreadable, or output directories could not be created),
        True otherwise. Per-image and per-split failures are logged only.
    """
    if not validate_paths(yolo_dataset_path, output_path):
        logger.error("Path validation failed. Aborting conversion.")
        return False

    # Load class names from data.yaml
    try:
        with open(os.path.join(yolo_dataset_path, "data.yaml"), 'r') as f:
            yaml_data = yaml.safe_load(f)
        if not yaml_data or 'names' not in yaml_data:
            logger.error("Invalid data.yaml file: missing 'names' field")
            return False
        class_names = normalize_class_names(yaml_data['names'])
        logger.info(f"Classes: {class_names}")
    except Exception as e:
        logger.error(f"Error reading data.yaml: {e}")
        return False

    # Create output directories
    try:
        os.makedirs(output_path, exist_ok=True)
        for split in splits:
            os.makedirs(os.path.join(output_path, split, 'JPEGImages'), exist_ok=True)
            os.makedirs(os.path.join(output_path, split, 'Annotations'), exist_ok=True)
            os.makedirs(os.path.join(output_path, split, 'ImageSets', 'Main'), exist_ok=True)
        os.makedirs(os.path.join(output_path, 'ImageSets', 'Main'), exist_ok=True)
    except Exception as e:
        logger.error(f"Error creating output directories: {e}")
        return False

    missing_splits = []
    # Successfully converted image ids per processed (non-missing) split.
    split_image_names = {}

    for split in splits:
        logger.info(f"\nProcessing {split} split...")
        img_dir = os.path.join(yolo_dataset_path, split, "images")
        if not os.path.exists(img_dir):
            logger.warning(f"{img_dir} does not exist, marking as missing...")
            missing_splits.append(split)
            continue

        img_paths = _find_split_images(img_dir)
        if not img_paths:
            logger.warning(f"No images found in {img_dir}, marking as missing...")
            missing_splits.append(split)
            continue

        logger.info(f"Found {len(img_paths)} images in {split} split")

        # Images are converted in worker threads; the lambda is fully consumed
        # within this iteration, so capturing the loop variable `split` is safe.
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            results = list(tqdm(
                executor.map(
                    lambda img_path: process_image(
                        img_path, split, yolo_dataset_path, output_path, class_names, overwrite
                    ),
                    img_paths
                ),
                total=len(img_paths),
                desc=f"Converting {split} set"
            ))

        # None marks a failed conversion.
        successful_images = [img_name for img_name in results if img_name]
        split_image_names[split] = successful_images

        if successful_images:
            list_file = os.path.join(output_path, split, 'ImageSets', 'Main', f"{split}.txt")
            try:
                _write_name_list(list_file, successful_images)
                logger.info(f"Successfully processed {len(successful_images)} images for {split} split")
            except Exception as e:
                logger.error(f"Error writing image list file for {split} split: {e}")
        else:
            logger.warning(f"No images processed for {split} split")

    if missing_splits:
        logger.warning(f"\nThe following splits were missing or empty: {', '.join(missing_splits)}")

    # labelmap.txt: one "<index> <class name>" line per class.
    try:
        _write_name_list(
            os.path.join(output_path, "labelmap.txt"),
            [f"{i} {class_name}" for i, class_name in enumerate(class_names)]
        )
    except Exception as e:
        logger.error(f"Error writing labelmap.txt: {e}")

    # trainval.txt: train ids followed by validation ids. The validation split
    # may be named 'val' or 'valid'; the lists come from this run rather than
    # from list files on disk, which could be stale from an earlier run.
    val_split = find_validation_split(splits)
    trainval_written = False
    if 'train' in split_image_names and val_split in split_image_names:
        try:
            _write_name_list(
                os.path.join(output_path, 'ImageSets', 'Main', 'trainval.txt'),
                split_image_names['train'] + split_image_names[val_split]
            )
            trainval_written = True
        except Exception as e:
            logger.error(f"Error creating trainval.txt: {e}")

    # Print summary
    logger.info("\nConversion completed successfully!")
    logger.info(f"Pascal VOC format dataset created at: {output_path}")
    logger.info("Directory structure:")
    for split in splits:
        if split not in split_image_names:
            continue
        img_count = len(glob.glob(os.path.join(output_path, split, 'JPEGImages', '*')))
        xml_count = len(glob.glob(os.path.join(output_path, split, 'Annotations', '*.xml')))
        logger.info(f"  - {split}/")
        logger.info(f"      - JPEGImages/ (contains {img_count} images)")
        logger.info(f"      - Annotations/ (contains {xml_count} XML files)")
        logger.info(f"      - ImageSets/Main/{split}.txt (contains {len(split_image_names[split])} entries)")
    logger.info("  - labelmap.txt (class mapping file)")
    if trainval_written:
        logger.info(f"  - ImageSets/Main/trainval.txt (combined train+{val_split} list)")

    return True


# For use in a Jupyter/Colab notebook - avoids argparse conflicts
def run_conversion(
    yolo_path='/content/drive/MyDrive/DATASETS/MANGO/adjusted_dataset_20250305_203807',
    output_path='/content/ssd_dataset',
    splits='train,valid,test',
    overwrite=False,
    num_workers=4,
    log_level='INFO'
):
    """
    Run the YOLO to Pascal VOC conversion with specified parameters.
    This function is designed to be called directly in a Jupyter notebook.

    Args:
        yolo_path: Path to YOLO format dataset
        output_path: Path to output Pascal VOC format dataset
        splits: Comma-separated split names (whitespace around names is
            ignored); a list/tuple of names is also accepted
        overwrite: Whether to overwrite existing files
        num_workers: Number of worker threads for parallel processing
        log_level: Logging level name, case-insensitive (DEBUG, INFO, WARNING,
            ERROR, CRITICAL); unknown names fall back to INFO with a warning

    Returns:
        bool: False if the conversion raised or reported an abort, True otherwise.
    """
    # Resolve the level name before the try block: an unknown name used to
    # raise AttributeError here and crash the call.
    level = getattr(logging, str(log_level).upper(), None)
    if not isinstance(level, int):
        logger.warning(f"Unknown log level {log_level!r}, falling back to INFO")
        level = logging.INFO
    logger.setLevel(level)

    logger.info("Starting conversion with parameters:")
    logger.info(f"  YOLO dataset path: {yolo_path}")
    logger.info(f"  Output path: {output_path}")
    logger.info(f"  Splits: {splits}")
    logger.info(f"  Overwrite: {overwrite}")
    logger.info(f"  Num workers: {num_workers}")

    try:
        raw_splits = splits.split(',') if isinstance(splits, str) else list(splits)
        # Strip whitespace ("train, valid") and drop empty entries ("train,").
        splits_tuple = tuple(name.strip() for name in raw_splits if name.strip())

        result = convert_yolo_to_pascal_voc(
            yolo_path,
            output_path,
            splits_tuple,
            overwrite,
            num_workers
        )
        # Only an explicit False is treated as an aborted conversion; a None
        # return carries no status and is treated as success.
        if result is False:
            logger.error("Conversion aborted; see the errors above.")
            return False

        logger.info("\n=== SSD DATASET PREPARATION COMPLETE ===")
        return True
    except Exception as e:
        logger.critical(f"Unhandled error during conversion: {e}", exc_info=True)
        return False


# Entry point. Run as a script, this parses command-line arguments and performs
# the conversion. Inside Jupyter/Colab (IPython already imported), it only prints
# usage hints, because the kernel's own argv would confuse argparse.
if __name__ == "__main__":
    if 'IPython' in sys.modules:
        print("Running in Jupyter/Colab environment. Use the run_conversion function instead of command-line arguments.")
        print("Example: run_conversion(yolo_path='/path/to/yolo', output_path='/path/to/output')")
    else:
        parser = argparse.ArgumentParser(description='Convert YOLO format dataset to Pascal VOC format')
        parser.add_argument('--yolo-path', type=str, default='/content/drive/MyDrive/DATASETS/MANGO/adjusted_dataset_20250305_203807',
                            help='Path to YOLO format dataset with data.yaml file')
        parser.add_argument('--output-path', type=str, default='/content/ssd_dataset',
                            help='Path to output Pascal VOC format dataset')
        parser.add_argument('--splits', type=str, default='train,valid,test',
                            help='Dataset splits to convert (comma-separated)')
        parser.add_argument('--overwrite', action='store_true',
                            help='Overwrite existing files')
        parser.add_argument('--num-workers', type=int, default=4,
                            help='Number of worker threads for parallel processing')
        parser.add_argument('--log-level', type=str, default='INFO',
                            choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                            help='Logging level')

        args = parser.parse_args()

        success = run_conversion(
            args.yolo_path,
            args.output_path,
            args.splits,
            args.overwrite,
            args.num_workers,
            args.log_level
        )
        # Report failure through the exit status so shell scripts / CI can detect it.
        sys.exit(0 if success else 1)

Running in Jupyter/Colab environment. Use the run_conversion function instead of command-line arguments.
Example: run_conversion(yolo_path='/path/to/yolo', output_path='/path/to/output')


In [None]:
# Run the YOLO -> Pascal VOC conversion using run_conversion()'s default paths and splits.
run_conversion()

INFO:YOLO2VOC:Starting conversion with parameters:
INFO:YOLO2VOC:  YOLO dataset path: /content/drive/MyDrive/DATASETS/MANGO/adjusted_dataset_20250305_203807
INFO:YOLO2VOC:  Output path: /content/ssd_dataset
INFO:YOLO2VOC:  Splits: train,valid,test
INFO:YOLO2VOC:  Overwrite: False
INFO:YOLO2VOC:  Num workers: 4
INFO:YOLO2VOC:Classes: ['Anthracnose', 'Bacterial-Black-spot', 'Damaged-mango', 'Fruitly', 'Mechanical-damage', 'Others']
INFO:YOLO2VOC:
Processing train split...
INFO:YOLO2VOC:Found 2306 images in train split


Converting train set:   0%|          | 0/2306 [00:00<?, ?it/s]

INFO:YOLO2VOC:Successfully processed 2306 images for train split
INFO:YOLO2VOC:
Processing valid split...
INFO:YOLO2VOC:Found 493 images in valid split


Converting valid set:   0%|          | 0/493 [00:00<?, ?it/s]

INFO:YOLO2VOC:Successfully processed 493 images for valid split
INFO:YOLO2VOC:
Processing test split...
INFO:YOLO2VOC:Found 492 images in test split


Converting test set:   0%|          | 0/492 [00:00<?, ?it/s]

INFO:YOLO2VOC:Successfully processed 492 images for test split
INFO:YOLO2VOC:
Conversion completed successfully!
INFO:YOLO2VOC:Pascal VOC format dataset created at: /content/ssd_dataset
INFO:YOLO2VOC:Directory structure:
INFO:YOLO2VOC:  - train/
INFO:YOLO2VOC:      - JPEGImages/ (contains 2306 images)
INFO:YOLO2VOC:      - Annotations/ (contains 2306 XML files)
INFO:YOLO2VOC:      - ImageSets/Main/train.txt (contains 2306 entries)
INFO:YOLO2VOC:  - valid/
INFO:YOLO2VOC:      - JPEGImages/ (contains 493 images)
INFO:YOLO2VOC:      - Annotations/ (contains 493 XML files)
INFO:YOLO2VOC:      - ImageSets/Main/valid.txt (contains 493 entries)
INFO:YOLO2VOC:  - test/
INFO:YOLO2VOC:      - JPEGImages/ (contains 492 images)
INFO:YOLO2VOC:      - Annotations/ (contains 492 XML files)
INFO:YOLO2VOC:      - ImageSets/Main/test.txt (contains 492 entries)
INFO:YOLO2VOC:  - labelmap.txt (class mapping file)
INFO:YOLO2VOC:  - ImageSets/Main/trainval.txt (combined train+val list)
INFO:YOLO2VOC:
=== SS

True

In [None]:
import os
import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import cv2
import argparse
import shutil
import glob
import xml.etree.ElementTree as ET
from pathlib import Path
from tqdm.auto import tqdm
import pandas as pd
from PIL import Image
import torchvision
from torchvision import transforms
from torchvision.models.detection import ssd300_vgg16
from torchvision.models.detection.ssd import SSDHead
from torchvision.models.detection.anchor_utils import DefaultBoxGenerator
from torch.utils.data import Dataset, DataLoader
from torchvision.ops import box_iou
import datetime

# Detect Google Colab by trying to import its `drive` helper. When it imports,
# mark IN_COLAB and mount Google Drive at /content/drive; otherwise assume a
# local environment.
# NOTE(review): Drive is also mounted in an earlier cell of this notebook.
try:
    from google.colab import drive
    IN_COLAB = True
    print("Running in Google Colab environment")
    # Mount Google Drive
    drive.mount('/content/drive')
    print("Google Drive mounted successfully")
except ImportError:
    IN_COLAB = False
    print("Running in local environment (not Colab)")

# Select the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed torch's and NumPy's global RNGs so runs are repeatable.
# NOTE(review): Python's `random` module and cuDNN nondeterminism are not
# controlled here.
torch.manual_seed(42)
np.random.seed(42)

# ===========================
# DATASET CLASS FOR PASCAL VOC
# ===========================

class PascalVOCDataset(Dataset):
    """
    Dataset for Pascal VOC format data laid out per split.

    Expects ``<root>/labelmap.txt`` and, for each split,
    ``<root>/<split>/{JPEGImages,Annotations,ImageSets/Main/<split>.txt}``.
    Items are ``(image, target)`` pairs. ``target`` holds ``boxes`` (float32,
    N x 4, xmin/ymin/xmax/ymax), ``labels`` (int64, N), ``image_id``, ``area``
    and ``iscrowd``. Label 0 is reserved for ``__background__``.
    """

    # Extensions tried, in this order, when locating an image by id. ".jpeg" is
    # needed because the YOLO->VOC conversion step also copies *.jpeg files.
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, root, split='train', transforms=None):
        """
        Args:
            root (string): Root directory of the VOC Dataset.
            split (string): Split directory name, e.g. 'train', 'val'/'valid', or 'test'
            transforms (callable, optional): Called as transforms(img, target);
                must return the transformed (img, target) pair.
        """
        self.root = root
        self.split = split
        self.transforms = transforms

        # Load class names from labelmap (index 0 is the background class).
        self.classes = self._load_class_names()
        self.num_classes = len(self.classes)
        print(f"Found {self.num_classes} classes: {self.classes}")

        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        split_file = os.path.join(root, split, 'ImageSets', 'Main', f'{split}.txt')
        if not os.path.exists(split_file):
            raise FileNotFoundError(f"Split file not found: {split_file}")

        with open(split_file, 'r') as f:
            # Skip blank lines (e.g. a trailing newline) so they don't become bogus ids.
            self.ids = [line.strip() for line in f if line.strip()]

        print(f"Loaded {len(self.ids)} images for {split} split")

    def _load_class_names(self):
        """
        Read class names from ``<root>/labelmap.txt``.

        Each line has the form ``'<index> <class name>'``. The index is ignored,
        and names are kept in file order. Returns the names prefixed with
        ``'__background__'``.
        """
        labelmap_file = os.path.join(self.root, 'labelmap.txt')
        if not os.path.exists(labelmap_file):
            raise FileNotFoundError(f"Labelmap file not found: {labelmap_file}")

        classes = []
        with open(labelmap_file, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 2:
                    # Re-join with spaces in case the class name contains spaces.
                    classes.append(' '.join(parts[1:]))

        return ['__background__'] + classes

    def __len__(self):
        return len(self.ids)

    def _find_image_path(self, img_id):
        """
        Return the path of image `img_id` in this split.

        Each extension in IMAGE_EXTENSIONS is tried in order.

        Raises:
            FileNotFoundError: if no file with a supported extension exists.
        """
        image_dir = os.path.join(self.root, self.split, 'JPEGImages')
        for ext in self.IMAGE_EXTENSIONS:
            candidate = os.path.join(image_dir, f'{img_id}{ext}')
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(f"Image not found: {img_id}")

    def __getitem__(self, idx):
        img_id = self.ids[idx]

        img_path = self._find_image_path(img_id)
        # The context manager releases the file handle right away instead of
        # leaving it open until the lazily-loaded PIL image is garbage-collected.
        with Image.open(img_path) as pil_img:
            img = pil_img.convert("RGB")

        anno_path = os.path.join(self.root, self.split, 'Annotations', f'{img_id}.xml')
        # The dataset index serves as the numeric image_id.
        target = self._parse_voc_xml(ET.parse(anno_path).getroot(), img_id=idx)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def _parse_voc_xml(self, node, img_id):
        """
        Build a detection target dict from a parsed VOC ``<annotation>`` element.

        Objects whose class is unknown, or whose box has non-positive width or
        height, are skipped with a printed warning.
        """
        target = {}
        boxes = []
        labels = []

        for obj in node.findall('object'):
            name = obj.find('name').text

            if name not in self.class_to_idx:
                print(f"Warning: Class '{name}' not in class map, skipping")
                continue

            bbox = obj.find('bndbox')
            xmin = float(bbox.find('xmin').text)
            ymin = float(bbox.find('ymin').text)
            xmax = float(bbox.find('xmax').text)
            ymax = float(bbox.find('ymax').text)

            if xmin >= xmax or ymin >= ymax:
                filename = node.findtext('filename', default='<unknown>')
                print(f"Warning: Invalid box coordinates {xmin, ymin, xmax, ymax} in {filename}, skipping")
                continue

            # class_to_idx already includes '__background__' at 0, so this is
            # the final label; no offset is needed.
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_to_idx[name])

        if boxes:
            target["boxes"] = torch.tensor(boxes, dtype=torch.float32)
            target["labels"] = torch.tensor(labels, dtype=torch.int64)
        else:
            # Images without valid objects still get correctly-shaped empty tensors.
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0), dtype=torch.int64)

        target["image_id"] = torch.tensor([img_id], dtype=torch.int64)
        target["area"] = (target["boxes"][:, 3] - target["boxes"][:, 1]) * (target["boxes"][:, 2] - target["boxes"][:, 0])
        target["iscrowd"] = torch.zeros((len(target["boxes"])), dtype=torch.int64)

        return target

# ============================
# TRANSFORMS AND DATA LOADING
# ============================

class Compose:
    """Chain detection transforms, each called as t(image, target) -> (image, target)."""

    def __init__(self, transforms):
        # Stored as given; applied in iteration order.
        self.transforms = transforms

    def __call__(self, image, target):
        """Thread the (image, target) pair through every transform in turn."""
        for step in self.transforms:
            image, target = step(image, target)
        return image, target

class ToTensor:
    """Turn the PIL image of an (image, target) pair into a tensor; the target passes through."""

    def __call__(self, image, target):
        to_tensor = transforms.ToTensor()
        return to_tensor(image), target

class Resize:
    """Resize a PIL image to a square of side `size` and rescale its boxes to match."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        """
        Return the resized image and `target` with rescaled boxes.

        ``boxes`` and ``area`` in `target` are replaced when it has boxes.
        An empty target is returned unchanged.
        """
        src_w, src_h = image.size  # PIL reports (width, height)
        image = transforms.Resize((self.size, self.size))(image)

        if target["boxes"].shape[0] == 0:
            return image, target

        scaled = target["boxes"].clone()
        scaled[:, 0::2] *= self.size / src_w  # xmin, xmax
        scaled[:, 1::2] *= self.size / src_h  # ymin, ymax

        target["boxes"] = scaled
        target["area"] = (scaled[:, 3] - scaled[:, 1]) * (scaled[:, 2] - scaled[:, 0])
        return image, target

def get_transform(train, img_size=300):
    """
    Build the (image, target) preprocessing pipeline: resize to a square, then tensor conversion.

    `train` is accepted for API compatibility but ignored; training and
    evaluation use the same pipeline (no augmentation).
    """
    return Compose([Resize(img_size), ToTensor()])

def collate_fn(batch):
    """
    Transpose a list of (image, target) samples into (images, targets) tuples.

    Detection targets have a variable number of boxes, so samples are grouped
    rather than stacked into one tensor.
    """
    images_and_targets = zip(*batch)
    return tuple(images_and_targets)

# ============================
# UTILITY CLASSES
# ============================

class SmoothedValue:
    """
    Running statistics over a stream of numbers.

    Keeps the last `window_size` values for windowed median/mean, plus a
    running total and count over every value seen, for the global mean.
    """

    def __init__(self, window_size=20):
        self.window_size = window_size
        self.reset()

    def reset(self):
        """Discard all recorded values and totals."""
        self.values = []
        self.total = 0.0
        self.count = 0

    def update(self, value):
        """Record `value`, evicting the oldest windowed value once the window is full."""
        self.values.append(value)
        if len(self.values) - self.window_size > 0:
            del self.values[0]
        self.total += value
        self.count += 1

    @property
    def median(self):
        """Median of the windowed values, or 0.0 when none are recorded."""
        if not self.values:
            return 0.0
        return np.median(self.values).item()

    @property
    def avg(self):
        """Mean of the windowed values, or 0.0 when none are recorded."""
        if not self.values:
            return 0.0
        return np.mean(self.values).item()

    @property
    def global_avg(self):
        """Mean over every value ever recorded, or 0.0 before the first update."""
        if self.count <= 0:
            return 0.0
        return self.total / self.count

    def __str__(self):
        return "{:.4f} ({:.4f})".format(self.global_avg, self.avg)

class MetricLogger:
    """Collect named SmoothedValue meters and print periodic progress lines."""

    def __init__(self, delimiter="\t"):
        self.meters = {}
        self.delimiter = delimiter

    def update(self, **kwargs):
        """
        Record one value per keyword into the meter of that name.

        Scalar objects with an ``.item()`` method and ``ndim == 0`` (e.g. 0-d
        tensors, NumPy scalars) are converted to Python numbers first. This
        keeps meters from holding on to tensor objects.
        """
        for name, value in kwargs.items():
            if getattr(value, "ndim", None) == 0 and hasattr(value, "item"):
                value = value.item()
            if name not in self.meters:
                self.meters[name] = SmoothedValue()
            self.meters[name].update(value)

    def __getattr__(self, attr):
        # Called only when normal attribute lookup fails. `meters` is read via
        # __dict__ so that an instance without it (e.g. during copy/unpickling,
        # before __init__ has run) raises AttributeError instead of recursing
        # into __getattr__ forever.
        meters = self.__dict__.get("meters", {})
        if attr in meters:
            return meters[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        return self.delimiter.join(f"{name}: {meter}" for name, meter in self.meters.items())

    def log_every(self, iterable, print_freq, header=None):
        """
        Yield the items of `iterable`, printing progress every `print_freq` items.

        `iterable` must support len(). Per-item time is measured between
        consecutive yields, so it includes the caller's work on each item.
        A total-time line is printed when iteration finishes.
        """
        label = "" if header is None else header
        if header is not None:
            print(header)

        total = len(iterable)
        width = len(str(total))  # pad the counter so progress lines align
        iter_time = SmoothedValue()
        start_time = time.time()
        end = time.time()

        for i, obj in enumerate(iterable):
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                # Items still to come after this one.
                eta_seconds = iter_time.global_avg * (total - i - 1)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    f"{label} [{i:{width}d}/{total}]  "
                    f"eta: {eta_string}  "
                    f"time: {iter_time.global_avg:.4f}  "
                    f"{self}"
                )
            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        per_item = total_time / total if total else 0.0  # avoid ZeroDivisionError on empty input
        print(f"{label} Time: {total_time_str} ({per_item:.4f} s / it)")

# ============================
# MODEL DEFINITION
# ============================

def create_ssd_model(num_classes, pretrained=True):
    """Create a torchvision SSD300-VGG16 model with a new head for ``num_classes``.

    Args:
        num_classes: number of output classes, including background.
        pretrained: when True, start from torchvision's default SSD300_VGG16
            weights; otherwise no detector checkpoint is loaded.
            NOTE(review): torchvision's default ``weights_backbone`` may still load
            ImageNet VGG16 features when ``pretrained`` is False, so the backbone
            is not necessarily random. Confirm for the installed version.

    Returns:
        The SSD model with its ``head`` replaced by a freshly initialised ``SSDHead``.

    Raises:
        RuntimeError: if the constructed model has no ``head`` attribute.
    """
    try:
        # Newer torchvision: multi-weight API.
        from torchvision.models.detection.ssd import SSD300_VGG16_Weights
    except ImportError:
        SSD300_VGG16_Weights = None

    if SSD300_VGG16_Weights is not None:
        # Pass weights=None instead of the deprecated pretrained=False.
        weights = SSD300_VGG16_Weights.DEFAULT if pretrained else None
        model = ssd300_vgg16(weights=weights)
    else:
        # Older torchvision only understands the boolean flag.
        model = ssd300_vgg16(pretrained=pretrained)

    if not hasattr(model, 'head'):
        # SSD has no Faster R-CNN style ``roi_heads`` to fall back on.
        raise RuntimeError("Unsupported torchvision SSD model: no 'head' attribute to replace")

    # One entry per feature map: anchors per location and channels of that map.
    num_anchors = model.anchor_generator.num_anchors_per_location()
    in_channels = getattr(model.backbone, 'out_channels', None)
    if in_channels is None:
        # Typical values for SSD300 with VGG16 when the backbone does not expose them
        in_channels = [512, 1024, 512, 256, 256, 256]

    model.head = SSDHead(in_channels, num_anchors, num_classes)

    print(f"Created SSD300 model with {'pretrained' if pretrained else 'random'} VGG16 backbone")
    return model

# ============================
# MAIN EXECUTION
# ============================

if __name__ == "__main__":
    # NOTE(review): train_ssd_model and the functions it calls are defined further
    # down in this file/cell, after this block. Run top to bottom in a fresh
    # interpreter, the call below therefore raises NameError, which is reported as
    # "Error during SSD model training". It presumably only succeeded in the
    # notebook because the kernel still held definitions from an earlier run.
    # Moving this block to the very end of the file/cell fixes the ordering.

    # Parse arguments
    parser = argparse.ArgumentParser(description='Train SSD model for mango disease detection (NO AUGMENTATION)')
    parser.add_argument('--voc-path', type=str, default='/content/ssd_dataset',
                        help='Path to Pascal VOC format dataset')
    parser.add_argument('--output-dir', type=str, default='/content/ssd_model',
                        help='Path to save model outputs')
    parser.add_argument('--gdrive-dir', type=str, default='/content/drive/MyDrive/MANGO/PROJECT/mango_ssd',
                        help='Google Drive directory to save model (for Colab users)')
    parser.add_argument('--epochs', type=int, default=50,
                        help='Number of epochs for training')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='Batch size for training')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate')
    parser.add_argument('--pretrained', action='store_true',
                        help='Use pretrained VGG backbone')
    parser.add_argument('--image-size', type=int, default=300,
                        help='Image size for SSD300')
    parser.add_argument('--eval-freq', type=int, default=5,
                        help='Frequency of evaluation during training')
    parser.add_argument('--save-freq', type=int, default=10,
                        help='Frequency of saving model checkpoints')

    # For IPython/Jupyter/Colab: ignore the process argv and use the defaults.
    if 'ipykernel' in sys.modules or 'IPython' in sys.modules or IN_COLAB:
        args = parser.parse_args([])
        args.pretrained = True  # Default to using pretrained backbone in Colab
        print("Running in notebook/Colab mode with default arguments")
    else:
        args = parser.parse_args()

    # Create output directories
    os.makedirs(args.output_dir, exist_ok=True)
    if IN_COLAB and args.gdrive_dir:
        os.makedirs(args.gdrive_dir, exist_ok=True)

    # Verify dataset existence
    print("\n==== VERIFYING DATASET ====")
    if not os.path.exists(args.voc_path) or not os.path.exists(os.path.join(args.voc_path, 'labelmap.txt')):
        print(f"Error: Pascal VOC dataset not found at {args.voc_path}")
        print("Please ensure you have converted your YOLO dataset to Pascal VOC format.")
        sys.exit(1)

    # Check if training split exists
    train_dir = os.path.join(args.voc_path, 'train')
    if not os.path.exists(train_dir):
        print(f"Error: Training directory not found at {train_dir}")
        sys.exit(1)

    # The validation split may be named either 'val' or 'valid'.
    if not any(os.path.exists(os.path.join(args.voc_path, name)) for name in ('val', 'valid')):
        print(f"Error: Validation directory not found at {args.voc_path}/val or {args.voc_path}/valid")
        sys.exit(1)

    print("Dataset verification completed successfully.")

    # Train the model
    try:
        model = train_ssd_model(args)
    except Exception as e:
        print(f"Error during SSD model training: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    # train_ssd_model returns None (instead of raising) when it cannot build the
    # data loaders or the model. Treat that as a failure, not a success.
    if model is None:
        print("Error during SSD model training: setup failed, see the error printed above.")
        sys.exit(1)
    print("SSD model training completed successfully.")

    print("\n==== ALL STEPS COMPLETED SUCCESSFULLY ====")
    best_model_path = os.path.join(args.output_dir, 'best_model.pth')
    if os.path.exists(best_model_path):
        print(f"SSD model trained and saved to {best_model_path}")
        if IN_COLAB and args.gdrive_dir:
            gdrive_model_path = os.path.join(args.gdrive_dir, 'best_model.pth')
            if os.path.exists(gdrive_model_path):
                print(f"Model also backed up to Google Drive at {gdrive_model_path}")
    else:
        print("Warning: Best model file not found. Check logs for errors during training.")

# ============================
# TRAINING FUNCTIONS
# ============================

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10):
    """Run one training pass over ``data_loader`` and return its MetricLogger.

    For each batch the model is called in train mode with ``(images, targets)``,
    and the dict of losses it returns is summed and back-propagated.

    A batch whose total loss is NaN or infinite is skipped: no backward pass, no
    optimizer step, no metric update. A warning is printed instead, so one bad
    batch cannot poison the weights. If every batch is skipped, the returned
    logger has no ``loss`` meter.

    Args:
        model: detection model returning a dict of loss tensors in train mode.
        optimizer: optimizer over the model's trainable parameters.
        data_loader: yields ``(images, targets)`` batches of tensors / dicts of tensors.
        device: device that images and target tensors are moved to.
        epoch: epoch number, used only for logging.
        print_freq: progress is printed every this many batches.

    Returns:
        MetricLogger with a ``loss`` meter plus one meter per loss-dict key.
    """
    import math

    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    header = f'Epoch: [{epoch}]'

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Forward pass: the model returns a dict of individual loss terms.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        if not math.isfinite(loss_value):
            details = {k: v.item() for k, v in loss_dict.items()}
            print(f"Warning: non-finite loss {loss_value} in epoch {epoch}, skipping batch ({details})")
            continue

        # Backward pass and optimize
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Record the total and each individual loss term.
        metric_logger.update(loss=loss_value)
        metric_logger.update(**{k: v.item() for k, v in loss_dict.items()})

    return metric_logger

def evaluate(model, data_loader, device, epoch):
    """Run the model over ``data_loader`` without gradients and return its mAP.

    Predictions and the (device-moved) targets for the whole loader are gathered
    and scored with ``calculate_mAP``. The result is printed as
    ``"Epoch {epoch}: mAP = ..."``. The model is left in eval mode.
    """
    model.eval()
    progress = MetricLogger(delimiter="  ")
    collected_preds = []
    collected_targets = []

    with torch.no_grad():
        for batch_images, batch_targets in progress.log_every(data_loader, 10, f'Validation: [{epoch}]'):
            batch_images = [img.to(device) for img in batch_images]
            batch_targets = [{key: value.to(device) for key, value in tgt.items()} for tgt in batch_targets]
            # Keep one prediction entry per image, aligned with its target.
            collected_preds.extend(model(batch_images))
            collected_targets.extend(batch_targets)

    score = calculate_mAP(collected_preds, collected_targets)
    print(f"Epoch {epoch}: mAP = {score:.4f}")
    return score

def _pairwise_box_iou(boxes1, boxes2):
    """Return the (N, M) IoU matrix between two sets of (xmin, ymin, xmax, ymax) boxes.

    Same definition as ``torchvision.ops.box_iou`` for well-formed boxes. It is
    kept local so the metric depends only on ``torch``.
    """
    area1 = (boxes1[:, 2] - boxes1[:, 0]).clamp(min=0) * (boxes1[:, 3] - boxes1[:, 1]).clamp(min=0)
    area2 = (boxes2[:, 2] - boxes2[:, 0]).clamp(min=0) * (boxes2[:, 3] - boxes2[:, 1]).clamp(min=0)
    top_left = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])
    wh = (bottom_right - top_left).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    union = area1[:, None] + area2[None, :] - inter
    return inter / union.clamp(min=1e-12)


def _average_precision(recalls, precisions):
    """Area under a PR curve using all-point interpolation (VOC2010+ style).

    Precision is first replaced by its running maximum taken from the right (the
    precision envelope). The area is then summed over every point where recall
    increases.
    """
    mrec = torch.cat([torch.zeros(1), recalls, torch.ones(1)])
    mpre = torch.cat([torch.zeros(1), precisions, torch.zeros(1)])
    mpre = torch.flip(torch.cummax(torch.flip(mpre, dims=[0]), dim=0).values, dims=[0])
    changed = torch.nonzero(mrec[1:] != mrec[:-1]).squeeze(1)
    return float(torch.sum((mrec[changed + 1] - mrec[changed]) * mpre[changed + 1]))


def calculate_mAP(predictions, targets, iou_threshold=0.5):
    """Calculate dataset-level mean Average Precision (Pascal VOC style).

    For each foreground class (label > 0) with at least one ground-truth box,
    the class's detections from *all* images are ranked by score and matched
    greedily:
      - A detection is a true positive when its best-IoU ground-truth box of the
        same class in the same image has IoU >= ``iou_threshold`` and has not
        already been matched.
      - Every other detection is a false positive, including detections on
        images with no ground truth of that class.
    AP is the all-point interpolated area under that class's PR curve. mAP is
    the mean over the evaluated classes.

    Args:
        predictions: per-image dicts with 'boxes' (N, 4) xyxy, 'scores' (N,) and
            'labels' (N,), aligned by index with ``targets`` (extra entries in the
            longer list are ignored, as with ``zip``).
        targets: per-image dicts with 'boxes' (M, 4) xyxy and 'labels' (M,).
        iou_threshold: minimum IoU for a detection to count as a true positive.

    Returns:
        float: the mAP, or 0.0 when no foreground ground truth exists.
    """
    gt_by_class = {}   # class -> {image index: (M, 4) ground-truth boxes}
    det_by_class = {}  # class -> {image index: ((K, 4) boxes, (K,) scores)}
    for img_idx, (pred, target) in enumerate(zip(predictions, targets)):
        # Work on CPU: this is bookkeeping, not heavy compute.
        gt_boxes = target['boxes'].detach().cpu().float()
        gt_labels = target['labels'].detach().cpu()
        for cls in gt_labels.unique().tolist():
            if cls != 0:  # 0 is background
                gt_by_class.setdefault(cls, {})[img_idx] = gt_boxes[gt_labels == cls]

        pred_boxes = pred['boxes'].detach().cpu().float()
        pred_scores = pred['scores'].detach().cpu().float()
        pred_labels = pred['labels'].detach().cpu()
        for cls in pred_labels.unique().tolist():
            keep = pred_labels == cls
            det_by_class.setdefault(cls, {})[img_idx] = (pred_boxes[keep], pred_scores[keep])

    if not gt_by_class:
        return 0.0

    class_aps = []
    for cls in sorted(gt_by_class):
        gt_per_image = gt_by_class[cls]
        num_gt = sum(len(boxes) for boxes in gt_per_image.values())

        # (score, image index, IoU row vs. that image's GT, or None if it has none)
        ranked = []
        for img_idx, (boxes, scores) in det_by_class.get(cls, {}).items():
            gt = gt_per_image.get(img_idx)
            ious = _pairwise_box_iou(boxes, gt) if gt is not None else None
            for k, score in enumerate(scores.tolist()):
                ranked.append((score, img_idx, None if ious is None else ious[k]))

        if not ranked:
            class_aps.append(0.0)  # ground truth exists but nothing was detected
            continue

        ranked.sort(key=lambda entry: entry[0], reverse=True)
        matched = {img_idx: torch.zeros(len(boxes), dtype=torch.bool)
                   for img_idx, boxes in gt_per_image.items()}
        is_tp = torch.zeros(len(ranked))
        for rank, (_, img_idx, iou_row) in enumerate(ranked):
            if iou_row is None:
                continue  # no GT of this class in the image: false positive
            best_iou, best_gt = iou_row.max(dim=0)
            if best_iou >= iou_threshold and not matched[img_idx][best_gt]:
                matched[img_idx][best_gt] = True
                is_tp[rank] = 1.0

        tp_cum = torch.cumsum(is_tp, dim=0)
        fp_cum = torch.cumsum(1.0 - is_tp, dim=0)
        recalls = tp_cum / num_gt
        precisions = tp_cum / (tp_cum + fp_cum)  # denominator >= 1
        class_aps.append(_average_precision(recalls, precisions))

    return float(sum(class_aps) / len(class_aps))

# ============================
# VISUALIZATION FUNCTIONS
# ============================

def _eval_epoch_positions(num_epochs, num_evals):
    """Return 1-based epoch numbers at which ``num_evals`` evaluations most likely ran.

    Evaluations are assumed evenly spaced, with the most recent one at the last
    recorded epoch (``num_epochs``). This matches "evaluate every k-th epoch" when
    ``num_epochs`` is a multiple of k. Returns ``1..num_evals`` when there are more
    evaluations than epochs, and ``[]`` when there are none.
    """
    if num_evals <= 0:
        return []
    if num_epochs < num_evals:
        return list(range(1, num_evals + 1))
    step = num_epochs // num_evals
    return [num_epochs - (num_evals - 1 - k) * step for k in range(num_evals)]


def plot_loss_curve(train_losses, val_maps, output_dir):
    """Save training-loss and validation-mAP curves to ``<output_dir>/training_curves.png``.

    Epochs are plotted 1-based. ``val_maps`` may be sparser than ``train_losses``
    (evaluation every few epochs). Its points are placed so the last one lines up
    with the last training epoch and are drawn with markers, so a single
    evaluation is still visible. Nothing is drawn in the mAP panel before the
    first evaluation.
    """
    plt.figure(figsize=(12, 5))

    # Plot losses
    plt.subplot(1, 2, 1)
    plt.plot(list(range(1, len(train_losses) + 1)), train_losses, 'b-')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)

    # Plot mAP
    plt.subplot(1, 2, 2)
    if val_maps:
        eval_epochs = _eval_epoch_positions(len(train_losses), len(val_maps))
        plt.plot(eval_epochs, val_maps, 'r-o')
    plt.title('Validation mAP')
    plt.xlabel('Epoch')
    plt.ylabel('mAP')
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_curves.png'))
    plt.close()

def visualize_predictions(model, dataset, device, num_images=5, output_dir=None, score_threshold=0.5):
    """Save a grid comparing ground truth with predictions for random samples.

    Each row is one sample drawn with ``np.random.choice`` (without replacement).
    Ground-truth boxes are drawn in green on the left; predictions scoring above
    ``score_threshold`` are drawn in red on the right. The figure is written to
    ``<output_dir>/sample_predictions.png``.

    Args:
        model: detection model, called as ``model([image])`` in eval mode; each
            result must provide 'boxes', 'labels' and 'scores'.
        dataset: indexable dataset returning ``(image_tensor, target)`` and
            exposing a ``classes`` list used for label names.
        device: device the image is moved to before inference.
        num_images: number of samples to show; capped at ``len(dataset)``.
        output_dir: destination directory (default ``"predictions"``), created
            if missing.
        score_threshold: minimum confidence for a prediction to be drawn.
    """
    if output_dir is None:
        output_dir = "predictions"
    os.makedirs(output_dir, exist_ok=True)

    # Sampling without replacement cannot draw more items than the dataset holds.
    num_images = min(num_images, len(dataset))
    if num_images <= 0:
        print("No images available for prediction visualization")
        return

    # Set model to evaluation mode
    model.eval()

    # squeeze=False keeps ``axes`` 2-D, so axes[row, col] also works for a single row.
    fig, axes = plt.subplots(num_images, 2, figsize=(12, 3 * num_images), squeeze=False)
    indices = np.random.choice(len(dataset), num_images, replace=False)

    for row, idx in enumerate(indices):
        image, target = dataset[idx]
        image_vis = np.array(transforms.ToPILImage()(image))

        with torch.no_grad():
            prediction = model([image.to(device)])[0]

        gt_ax, pred_ax = axes[row, 0], axes[row, 1]

        # Ground truth (green)
        gt_ax.imshow(image_vis)
        gt_ax.set_title("Ground Truth")
        for box, label in zip(target["boxes"], target["labels"]):
            xmin, ymin, xmax, ymax = box.cpu().numpy()
            gt_ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                          fill=False, edgecolor='green', linewidth=2))
            gt_ax.text(xmin, ymin - 5, dataset.classes[label.item()], color='green',
                       backgroundcolor='white', fontsize=8)

        # Predictions above the confidence threshold (red)
        pred_ax.imshow(image_vis)
        pred_ax.set_title("Prediction")
        keep = prediction["scores"] > score_threshold
        boxes = prediction["boxes"][keep].cpu().numpy()
        labels = prediction["labels"][keep].cpu().numpy()
        scores = prediction["scores"][keep].cpu().numpy()
        for (xmin, ymin, xmax, ymax), label, score in zip(boxes, labels, scores):
            pred_ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                            fill=False, edgecolor='red', linewidth=2))
            pred_ax.text(xmin, ymin - 5, f"{dataset.classes[label]}: {score:.2f}",
                         color='red', backgroundcolor='white', fontsize=8)

        for ax in (gt_ax, pred_ax):
            ax.set_xticks([])
            ax.set_yticks([])

    plt.tight_layout()
    save_path = os.path.join(output_dir, 'sample_predictions.png')
    plt.savefig(save_path)
    plt.close(fig)

    print(f"Prediction visualization saved to {save_path}")

# ============================
# DATA LOADERS
# ============================

def create_data_loaders(args):
    """Build the training and validation DataLoaders for the VOC data at ``args.voc_path``.

    The validation split directory is ``val`` if it exists, otherwise ``valid``.

    Returns:
        (train_loader, val_loader, num_classes), where num_classes comes from the
        training dataset's labelmap (background included).
    """
    def build_dataset(split_name, is_train):
        return PascalVOCDataset(
            root=args.voc_path,
            split=split_name,
            transforms=get_transform(train=is_train, img_size=args.image_size),
        )

    train_dataset = build_dataset('train', True)

    has_val_dir = os.path.exists(os.path.join(args.voc_path, 'val'))
    val_dataset = build_dataset('val' if has_val_dir else 'valid', False)

    # num_workers=0 keeps loading in the main process (avoids multiprocessing issues in Colab).
    loader_options = dict(
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=True,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

    return train_loader, val_loader, train_dataset.num_classes

# ============================
# MAIN TRAINING FUNCTION
# ============================

def _copy_to_gdrive(args, local_path):
    """Copy ``local_path`` into ``args.gdrive_dir`` when running in Colab.

    Returns:
        The destination path, or None if nothing was copied (not in Colab, no
        Drive directory configured, or ``local_path`` does not exist).
    """
    if not (IN_COLAB and args.gdrive_dir) or not os.path.exists(local_path):
        return None
    destination = os.path.join(args.gdrive_dir, os.path.basename(local_path))
    shutil.copy(local_path, destination)
    return destination


def _write_training_report(args, num_classes, best_map, total_time, train_history, val_history):
    """Write ``training_report.txt`` into ``args.output_dir`` and return its path.

    ``train_history`` and ``val_history`` are lists of ``(epoch_number, value)``
    pairs. Each value is therefore labelled with the epoch it actually belongs
    to, even when an epoch failed or the final evaluation was off-schedule.
    """
    report_path = os.path.join(args.output_dir, 'training_report.txt')
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("SSD Model Training Report (NO AUGMENTATION)\n")
        f.write("=======================================\n\n")
        f.write(f"Dataset: {args.voc_path}\n")
        f.write(f"Number of classes: {num_classes}\n")
        f.write(f"Training epochs: {args.epochs}\n")
        f.write(f"Batch size: {args.batch_size}\n")
        f.write(f"Learning rate: {args.lr}\n")
        f.write(f"Image size: {args.image_size}\n\n")
        f.write("Data augmentation: None\n\n")

        f.write("Results:\n")
        f.write(f"Best validation mAP: {best_map:.4f}\n")
        f.write(f"Training time: {total_time/60:.2f} minutes\n\n")

        f.write("Training Loss:\n")
        for epoch_num, loss in train_history:
            f.write(f"Epoch {epoch_num}: {loss:.4f}\n")

        f.write("\nValidation mAP:\n")
        for epoch_num, mAP in val_history:
            f.write(f"Epoch {epoch_num}: {mAP:.4f}\n")
    return report_path


def train_ssd_model(args):
    """Train SSD300 on ``args.voc_path`` and write artefacts to ``args.output_dir``.

    Each epoch runs one training pass and one LR-scheduler step. Every
    ``args.eval_freq`` epochs, and on the last epoch, it also:
      - evaluates validation mAP,
      - saves ``best_model.pth`` on the first evaluation and on every improvement,
      - refreshes ``training_curves.png``.
    Checkpoints are written every ``args.save_freq`` epochs and on the last one.
    An exception inside an epoch is printed and training continues.

    Afterwards the best weights from this run are reloaded, re-evaluated and
    visualised, and ``training_report.txt`` is written. In Colab each artefact is
    also copied to ``args.gdrive_dir``.

    Returns:
        The model, or None if the data loaders or the model could not be created.
    """
    print("\n==== TRAINING SSD MODEL (NO AUGMENTATION) ====")
    print(f"Dataset path: {args.voc_path}")
    print(f"Output directory: {args.output_dir}")
    print(f"Training for {args.epochs} epochs with batch size {args.batch_size}")

    # Create data loaders
    try:
        train_loader, val_loader, num_classes = create_data_loaders(args)
    except Exception as e:
        print(f"Error creating data loaders: {e}")
        import traceback
        traceback.print_exc()
        return

    print(f"Created data loaders: {len(train_loader)} training batches, {len(val_loader)} validation batches")
    print(f"Number of classes (including background): {num_classes}")

    # Create model
    try:
        model = create_ssd_model(num_classes, pretrained=args.pretrained)
        model.to(device)
    except Exception as e:
        print(f"Error creating model: {e}")
        import traceback
        traceback.print_exc()
        return

    print("Created SSD300 model with VGG16 backbone")

    # Optimizer and step decay (lr x0.1 every 15 epochs)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=args.lr)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

    train_losses = []    # mean loss per completed epoch (plotting)
    val_maps = []        # mAP per evaluation (plotting)
    train_history = []   # (1-based epoch, mean loss) for the report
    val_history = []     # (1-based epoch, mAP) for the report
    best_map = 0.0
    best_epoch = None    # epoch whose weights this run saved to best_model.pth
    best_model_path = os.path.join(args.output_dir, 'best_model.pth')

    # Training loop
    print("\nStarting training...")
    start_time = time.time()

    for epoch in range(args.epochs):
        try:
            is_last_epoch = epoch == args.epochs - 1
            should_evaluate = (epoch + 1) % args.eval_freq == 0 or is_last_epoch

            metric_logger = train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=10)
            lr_scheduler.step()

            epoch_loss = metric_logger.loss.global_avg
            train_losses.append(epoch_loss)
            train_history.append((epoch + 1, epoch_loss))

            if should_evaluate:
                mAP = evaluate(model, val_loader, device, epoch)
                val_maps.append(mAP)
                val_history.append((epoch + 1, mAP))

                # The first evaluation always saves, so a stale best_model.pth left
                # in output_dir by an earlier run is never mistaken for this one.
                if best_epoch is None or mAP > best_map:
                    best_map = mAP
                    best_epoch = epoch + 1
                    torch.save(model.state_dict(), best_model_path)
                    print(f"Saved best model with mAP: {mAP:.4f}")

                    gdrive_best_path = _copy_to_gdrive(args, best_model_path)
                    if gdrive_best_path is not None:
                        print(f"Copied best model to Google Drive: {gdrive_best_path}")

            if (epoch + 1) % args.save_freq == 0 or is_last_epoch:
                checkpoint_path = os.path.join(args.output_dir, f'checkpoint_{epoch+1}.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': lr_scheduler.state_dict(),
                    'best_map': best_map
                }, checkpoint_path)
                print(f"Saved checkpoint at epoch {epoch+1}")

            # Refresh training curves after each evaluation
            if should_evaluate:
                plot_loss_curve(train_losses, val_maps, args.output_dir)
                _copy_to_gdrive(args, os.path.join(args.output_dir, 'training_curves.png'))

        except Exception as e:
            print(f"Error in epoch {epoch}: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Training complete
    total_time = time.time() - start_time
    print(f"\nTraining complete in {total_time/60:.2f} minutes")
    print(f"Best validation mAP: {best_map:.4f}")

    # Final model evaluation
    try:
        # Only reload weights that this run actually saved.
        if best_epoch is not None and os.path.exists(best_model_path):
            try:
                state_dict = torch.load(best_model_path, map_location=device, weights_only=True)
            except TypeError:  # torch versions without the weights_only argument
                state_dict = torch.load(best_model_path, map_location=device)
            model.load_state_dict(state_dict)
            print("Loaded best model for final evaluation")

            final_map = evaluate(model, val_loader, device, epoch=args.epochs)
            print(f"Final validation mAP: {final_map:.4f}")

            print("Generating prediction visualizations...")
            visualize_predictions(model, val_loader.dataset, device,
                                  num_images=5, output_dir=args.output_dir)

            gdrive_pred_path = _copy_to_gdrive(args, os.path.join(args.output_dir, 'sample_predictions.png'))
            if gdrive_pred_path is not None:
                print(f"Copied prediction visualization to Google Drive: {gdrive_pred_path}")
        else:
            print(f"Warning: Best model not found at {best_model_path}")

    except Exception as e:
        print(f"Error in final evaluation: {e}")
        import traceback
        traceback.print_exc()

    # Generate final report
    report_path = _write_training_report(args, num_classes, best_map, total_time, train_history, val_history)
    print(f"Training report saved to {report_path}")

    gdrive_report_path = _copy_to_gdrive(args, report_path)
    if gdrive_report_path is not None:
        print(f"Copied training report to Google Drive: {gdrive_report_path}")

    print("\n==== SSD MODEL TRAINING COMPLETE ====")
    return model

Running in Google Colab environment
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully
Using device: cuda
Running in notebook/Colab mode with default arguments

==== VERIFYING DATASET ====
Dataset verification completed successfully.

==== TRAINING SSD MODEL (NO AUGMENTATION) ====
Dataset path: /content/ssd_dataset
Output directory: /content/ssd_model
Training for 50 epochs with batch size 8
Found 7 classes: ['__background__', 'Anthracnose', 'Bacterial-Black-spot', 'Damaged-mango', 'Fruitly', 'Mechanical-damage', 'Others']
Loaded 2306 images for train split
Found 7 classes: ['__background__', 'Anthracnose', 'Bacterial-Black-spot', 'Damaged-mango', 'Fruitly', 'Mechanical-damage', 'Others']
Loaded 493 images for valid split
Created data loaders: 289 training batches, 62 validation batches
Number of classes (including background): 7


Downloading: "https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth" to /root/.cache/torch/hub/checkpoints/ssd300_vgg16_coco-b556d3b4.pth
100%|██████████| 136M/136M [00:01<00:00, 93.2MB/s]


Created SSD300 model with pretrained VGG16 backbone
Created SSD300 model with VGG16 backbone

Starting training...
Epoch: [0]
Epoch: [0] [  0/289]  eta: 0:00:25  time: 0.0897  loss: 29.2995 (29.2995)  bbox_regression: 6.1576 (6.1576)  classification: 23.1419 (23.1419)
Epoch: [0] [ 10/289]  eta: 0:00:23  time: 0.0845  loss: 35.3502 (35.3502)  bbox_regression: 5.7341 (5.7341)  classification: 29.6161 (29.6161)
Epoch: [0] [ 20/289]  eta: 0:00:21  time: 0.0796  loss: 23.2520 (22.9497)  bbox_regression: 3.8103 (3.6929)  classification: 19.4418 (19.2568)
Epoch: [0] [ 30/289]  eta: 0:00:20  time: 0.0793  loss: 18.6134 (9.4081)  bbox_regression: 3.1207 (1.6833)  classification: 15.4927 (7.7248)
Epoch: [0] [ 40/289]  eta: 0:00:19  time: 0.0790  loss: 15.9811 (8.3467)  bbox_regression: 2.6163 (1.3627)  classification: 13.3648 (6.9840)
Epoch: [0] [ 50/289]  eta: 0:00:18  time: 0.0793  loss: 14.4511 (7.9997)  bbox_regression: 2.3987 (1.2796)  classification: 12.0524 (6.7201)
Epoch: [0] [ 60/289]  

  model.load_state_dict(torch.load(best_model_path))


Validation: [50] [10/62]  eta: 0:00:03  time: 0.0696  
Validation: [50] [20/62]  eta: 0:00:02  time: 0.0685  
Validation: [50] [30/62]  eta: 0:00:02  time: 0.0685  
Validation: [50] [40/62]  eta: 0:00:01  time: 0.0686  
Validation: [50] [50/62]  eta: 0:00:00  time: 0.0686  
Validation: [50] [60/62]  eta: 0:00:00  time: 0.0686  
Validation: [50] [61/62]  eta: 0:00:00  time: 0.0681  
Validation: [50] Time: 0:00:04 (0.0682 s / it)
Epoch 50: mAP = 0.7666
Final validation mAP: 0.7666
Generating prediction visualizations...
Prediction visualization saved to /content/ssd_model/sample_predictions.png
Copied prediction visualization to Google Drive: /content/drive/MyDrive/MANGO/PROJECT/mango_ssd/sample_predictions.png
Training report saved to /content/ssd_model/training_report.txt
Copied training report to Google Drive: /content/drive/MyDrive/MANGO/PROJECT/mango_ssd/training_report.txt

==== SSD MODEL TRAINING COMPLETE ====
SSD model training completed successfully.

==== ALL STEPS COMPLETED SU

In [None]:
import os
import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import cv2
import argparse
import shutil
import glob
import xml.etree.ElementTree as ET
from pathlib import Path
from tqdm.auto import tqdm
import pandas as pd
from PIL import Image
import torchvision
from torchvision import transforms
from torchvision.models.detection import ssd300_vgg16
from torchvision.models.detection.ssd import SSDHead
from torchvision.models.detection.anchor_utils import DefaultBoxGenerator
from torch.utils.data import Dataset, DataLoader
from torchvision.ops import box_iou
import datetime

# Detect Google Colab: the google.colab package is importable only there.
# The try body is limited to the import, so an error raised while mounting
# is not misreported as "not Colab".
try:
    from google.colab import drive
except ImportError:
    IN_COLAB = False
    print("Running in local environment (not Colab)")
else:
    IN_COLAB = True
    print("Running in Google Colab environment")
    # Mount Google Drive so outputs can be backed up there
    drive.mount('/content/drive')
    print("Google Drive mounted successfully")

# Set device based on availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# ===========================
# DATASET CLASS FOR PASCAL VOC
# ===========================

class PascalVOCDataset(Dataset):
    """Pascal VOC-style detection dataset for one split.

    Layout read by this class::

        root/labelmap.txt                          "<index> <class name>" per line
        root/<split>/ImageSets/Main/<split>.txt    one image id per line
        root/<split>/JPEGImages/<id>.<ext>
        root/<split>/Annotations/<id>.xml

    Label 0 is reserved for ``'__background__'``; labelmap classes follow in
    file order.
    """

    # Extensions tried, in order, when looking up the image file for an id.
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg', '.JPG', '.JPEG', '.PNG')

    def __init__(self, root, split='train', transforms=None):
        """
        Args:
            root (string): Root directory of the VOC Dataset.
            split (string): split directory name, e.g. 'train', 'val' or 'valid'.
            transforms (callable, optional): called as ``transforms(img, target)``
                and must return the transformed ``(img, target)`` pair.

        Raises:
            FileNotFoundError: if labelmap.txt or the split file is missing.
        """
        self.root = root
        self.split = split
        self.transforms = transforms

        # Load class names from labelmap
        self.classes = self._load_class_names()
        self.num_classes = len(self.classes)
        print(f"Found {self.num_classes} classes: {self.classes}")

        # Map class names to indices
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        # Load image IDs
        split_file = os.path.join(root, split, 'ImageSets', 'Main', f'{split}.txt')
        if not os.path.exists(split_file):
            raise FileNotFoundError(f"Split file not found: {split_file}")

        with open(split_file, 'r') as f:
            # Skip blank lines so they don't become empty ids that fail later in __getitem__.
            self.ids = [line.strip() for line in f if line.strip()]

        print(f"Loaded {len(self.ids)} images for {split} split")

    def _load_class_names(self):
        """Return ``['__background__', <labelmap classes>...]`` read from labelmap.txt.

        Lines with fewer than two whitespace-separated fields are ignored.
        """
        labelmap_file = os.path.join(self.root, 'labelmap.txt')
        if not os.path.exists(labelmap_file):
            raise FileNotFoundError(f"Labelmap file not found: {labelmap_file}")

        classes = []
        with open(labelmap_file, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 2:
                    # Format is 'index class_name'; re-join in case the name has spaces.
                    classes.append(' '.join(parts[1:]))

        # Add background class as index 0
        return ['__background__'] + classes

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``, after ``self.transforms`` if set.

        Raises:
            FileNotFoundError: if no image file with a supported extension exists.
        """
        img_id = self.ids[idx]

        image_dir = os.path.join(self.root, self.split, 'JPEGImages')
        img_path = None
        for ext in self.IMAGE_EXTENSIONS:
            candidate = os.path.join(image_dir, f'{img_id}{ext}')
            if os.path.exists(candidate):
                img_path = candidate
                break
        if img_path is None:
            raise FileNotFoundError(f"Image not found: {img_id}")

        # Close the underlying file once the RGB image has been produced.
        with Image.open(img_path) as source:
            img = source.convert("RGB")

        # Load annotations; the dataset index doubles as a unique integer image_id.
        anno_path = os.path.join(self.root, self.split, 'Annotations', f'{img_id}.xml')
        target = self._parse_voc_xml(ET.parse(anno_path).getroot(), img_id=idx)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def _parse_voc_xml(self, node, img_id):
        """Build a target dict from a VOC ``<annotation>`` element.

        Objects are skipped, with a printed warning, when either:
          - the whitespace-stripped name is not a known class, or
          - the box has xmin >= xmax or ymin >= ymax.

        Returns:
            dict with 'boxes' (N, 4) float32 [xmin, ymin, xmax, ymax],
            'labels' (N,) int64, 'image_id' (1,) int64, 'area' (N,) float32
            and 'iscrowd' (N,) int64 zeros.
        """
        boxes = []
        labels = []

        for obj in node.findall('object'):
            # Strip so names written with surrounding whitespace still match the labelmap.
            name = (obj.findtext('name') or '').strip()

            if name not in self.class_to_idx:
                print(f"Warning: Class '{name}' not in class map, skipping")
                continue

            bbox = obj.find('bndbox')
            xmin, ymin, xmax, ymax = (float(bbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax'))

            # Validate box coordinates
            if xmin >= xmax or ymin >= ymax:
                print(f"Warning: Invalid box coordinates {xmin, ymin, xmax, ymax} in {node.findtext('filename')}, skipping")
                continue

            # class_to_idx already maps labelmap classes to 1..K ('__background__' is 0).
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_to_idx[name])

        target = {}
        if boxes:
            target["boxes"] = torch.tensor(boxes, dtype=torch.float32)
            target["labels"] = torch.tensor(labels, dtype=torch.int64)
        else:
            # Empty tensors keep the expected shapes when no valid boxes remain.
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros(0, dtype=torch.int64)

        target["image_id"] = torch.tensor([img_id], dtype=torch.int64)
        target["area"] = (target["boxes"][:, 3] - target["boxes"][:, 1]) * (target["boxes"][:, 2] - target["boxes"][:, 0])
        target["iscrowd"] = torch.zeros(len(target["boxes"]), dtype=torch.int64)

        return target

# ============================
# TRANSFORMS AND DATA LOADING
# ============================

class Compose:
    """Apply a sequence of detection transforms in order.

    Each transform receives ``(image, target)`` and must return a new
    ``(image, target)`` pair, which is fed to the next transform.
    """

    def __init__(self, transforms):
        # Kept under the public name ``transforms`` so callers can inspect it.
        self.transforms = transforms

    def __call__(self, image, target):
        for step in self.transforms:
            result = step(image, target)
            image, target = result
        return image, target

class ToTensor:
    """Convert the sample's PIL image to a tensor; the target is returned unchanged."""

    def __call__(self, image, target):
        converter = transforms.ToTensor()
        return converter(image), target

class Resize:
    """Resize a PIL image to a ``size`` x ``size`` square and rescale its boxes.

    The aspect ratio is not preserved: x and y coordinates are scaled
    independently. When the target has boxes, ``target["boxes"]`` is replaced
    by the rescaled boxes and ``target["area"]`` is recomputed; the target
    dict is updated in place and also returned.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        # PIL reports size as (width, height).
        orig_w, orig_h = image.size
        side = self.size

        image = transforms.Resize((side, side))(image)

        if target["boxes"].shape[0] == 0:
            return image, target

        x_ratio = side / orig_w
        y_ratio = side / orig_h

        boxes = target["boxes"].clone()
        boxes[:, 0::2] *= x_ratio  # xmin, xmax
        boxes[:, 1::2] *= y_ratio  # ymin, ymax

        target["boxes"] = boxes
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        return image, target

# Create data transforms (with NO augmentation)
def get_transform(train, img_size=300):
    """Return the sample transform pipeline: square resize, then tensor conversion.

    Args:
        train: accepted for interface compatibility but currently unused;
            training and evaluation share the same pipeline (no augmentation).
        img_size: side length, in pixels, of the square output image.

    Returns:
        A ``Compose`` that maps ``(PIL image, target)`` to ``(tensor, target)``.
    """
    # Named ``pipeline`` so the module-level ``transforms`` import is not shadowed.
    pipeline = [Resize(img_size), ToTensor()]
    return Compose(pipeline)

# Custom collate function for batching
def collate_fn(batch):
    """Transpose a list of ``(image, target)`` samples into ``(images, targets)``.

    Detection samples have different numbers of boxes, so they are grouped
    into tuples instead of being stacked into a single tensor.
    """
    regrouped = zip(*batch)
    return tuple(regrouped)

# ============================
# UTILITY CLASSES
# ============================

class SmoothedValue:
    """Running statistics over a stream of scalar values.

    Keeps the last ``window_size`` values for windowed median/mean, plus a
    running total and count over every value seen for the global average.
    All statistics return ``0.0`` when no data is available.
    """

    def __init__(self, window_size=20):
        self.window_size = window_size
        self.reset()

    def reset(self):
        """Forget every value recorded so far."""
        self.values = []
        self.total = 0.0
        self.count = 0

    def update(self, value):
        """Record ``value``, evicting the oldest windowed value when the window is full."""
        self.values.append(value)
        if len(self.values) > self.window_size:
            del self.values[0]
        self.total += value
        self.count += 1

    @property
    def median(self):
        """Median of the current window."""
        if not self.values:
            return 0.0
        return np.median(self.values).item()

    @property
    def avg(self):
        """Mean of the current window."""
        if not self.values:
            return 0.0
        return np.mean(self.values).item()

    @property
    def global_avg(self):
        """Mean of every value recorded since the last reset."""
        if self.count == 0:
            return 0.0
        return self.total / self.count

    def __str__(self):
        return f"{self.global_avg:.4f} ({self.avg:.4f})"

class MetricLogger:
    """Collect named ``SmoothedValue`` meters and print periodic progress lines.

    Meters are created lazily by ``update`` and can be read back as
    attributes (``logger.loss`` returns the ``loss`` meter).
    """

    def __init__(self, delimiter="\t"):
        self.meters = {}
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed each keyword value to the meter of the same name, creating it on first use."""
        for k, v in kwargs.items():
            if k not in self.meters:
                self.meters[k] = SmoothedValue()
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails. ``meters`` is read
        # straight from the instance dict: on an instance created without
        # __init__ (copy.copy / unpickling do this), ``self.meters`` would
        # re-enter __getattr__ and recurse until RecursionError.
        meters = self.__dict__.get("meters", {})
        if attr in meters:
            return meters[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        return self.delimiter.join(f"{name}: {meter}" for name, meter in self.meters.items())

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable``, printing progress every ``print_freq`` items.

        ``iterable`` must support ``len()`` (e.g. a DataLoader). Each progress
        line shows the step index, an ETA based on the mean step time so far,
        and the current meters. A timing summary is printed at the end.

        Args:
            iterable: sized iterable to wrap.
            print_freq: print a progress line every this many steps (and on
                the final step).
            header: optional prefix printed once up front and on every line.
        """
        if header is not None:
            print(header)
        # Without this, a missing header was rendered as the literal text "None".
        prefix = f"{header} " if header else ""

        total = len(iterable)
        width = len(str(total))
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue()

        for i, obj in enumerate(iterable):
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    f"{prefix}[{i:{width}d}/{total}]  "
                    f"eta: {eta_string}  "
                    f"time: {iter_time.global_avg:.4f}  "
                    f"{self}"
                )
            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) avoids ZeroDivisionError for an empty iterable.
        print(f"{prefix}Time: {total_time_str} ({total_time / max(total, 1):.4f} s / it)")

# ============================
# MODEL DEFINITION
# ============================

def create_ssd_model(num_classes, pretrained=True):
    """Build a torchvision SSD300-VGG16 detector whose head predicts ``num_classes``.

    Args:
        num_classes: number of output classes, including the background class
            at index 0 (the convention used throughout this file's metrics).
        pretrained: if True, load torchvision's pretrained SSD300-VGG16 weights
            (through ``SSD300_VGG16_Weights.DEFAULT`` when the weights enum exists,
            otherwise through the legacy ``pretrained=True`` flag). The
            classification/regression head is always replaced with a freshly
            initialised ``SSDHead``.

    Returns:
        The SSD model with the new head.
    """
    weights = None
    if pretrained:
        try:
            # torchvision >= 0.13 exposes weight enums.
            from torchvision.models.detection.ssd import SSD300_VGG16_Weights
            weights = SSD300_VGG16_Weights.DEFAULT
        except ImportError:
            # Older torchvision: fall back to the legacy ``pretrained`` flag below.
            weights = None

    if weights is not None:
        model = ssd300_vgg16(weights=weights)
    else:
        model = ssd300_vgg16(pretrained=pretrained)

    # A torchvision SSD always carries an ``SSDHead`` at ``model.head``. The
    # previous fallback branch edited ``model.roi_heads`` (a Faster R-CNN
    # attribute SSD does not have), so it could only ever raise.
    num_anchors = model.anchor_generator.num_anchors_per_location()
    # The VGG feature extractor may not expose ``out_channels``; fall back to
    # the channel counts of the six SSD300-VGG16 feature maps.
    in_channels = getattr(model.backbone, 'out_channels', [512, 1024, 512, 256, 256, 256])
    model.head = SSDHead(in_channels, num_anchors, num_classes)

    print(f"Created SSD300 model with {'pretrained' if pretrained else 'random'} VGG16 backbone")
    return model

# ============================
# TRAINING FUNCTIONS
# ============================

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10):
    """Run one optimisation pass over ``data_loader``.

    The model is called in training mode as ``model(images, targets)`` and
    must return a dict of loss tensors; their sum is back-propagated.

    Returns:
        The ``MetricLogger`` holding the total ``loss`` meter and one meter
        per entry of the model's loss dict.
    """
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    header = f'Epoch: [{epoch}]'

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = [img.to(device) for img in images]
        targets = [{key: value.to(device) for key, value in tgt.items()} for tgt in targets]

        loss_dict = model(images, targets)
        total_loss = sum(loss_dict.values())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        metric_logger.update(loss=total_loss.item())
        # Individual loss components (e.g. regression / classification terms).
        metric_logger.update(**{name: value.item() for name, value in loss_dict.items()})

    return metric_logger

# ============================
# ENHANCED EVALUATION FUNCTIONS
# ============================

def calculate_mAP(predictions, targets, iou_threshold=0.5, confidence_threshold=0.5, return_per_class=False):
    """Compute a per-image-averaged mean Average Precision.

    For each image and each foreground class that has ground truth in that
    image, predictions with score >= ``confidence_threshold`` are sorted by
    score and greedily matched to ground-truth boxes at ``iou_threshold``
    (each ground-truth box can be matched once). The image's AP for the class
    is the trapezoidal area under its precision/recall points, starting at
    (recall=0, precision=1). A class's AP is the mean over the images where it
    has ground truth. mAP is the mean over classes that have ground truth
    somewhere.

    Note: unlike the VOC/COCO protocols, detections are not pooled across the
    dataset, and predictions in images with no ground truth for that class are
    not counted as false positives.

    Args:
        predictions: per-image dicts with ``boxes`` (xyxy), ``scores``, ``labels`` tensors.
        targets: per-image dicts with ``boxes`` and ``labels`` tensors.
        iou_threshold: minimum IoU for a prediction to count as a true positive.
        confidence_threshold: predictions scoring below this are discarded.
        return_per_class: also return the per-class AP list.

    Returns:
        ``mAP`` as a float (0.0 when no ground truth exists), or
        ``(mAP, class_aps)`` if ``return_per_class``. ``class_aps[c]`` is the AP
        of class ``c`` (index 0 is background and always 0.0). Classes without
        any ground truth report 0.0 but are excluded from ``mAP``.
    """
    label_maxima = [int(t['labels'].max().item()) for t in targets if len(t['labels']) > 0]
    n_classes = max(label_maxima, default=0) + 1
    average_precisions = [[] for _ in range(n_classes)]

    for pred, target in zip(predictions, targets):
        keep = pred['scores'] >= confidence_threshold
        pred_boxes = pred['boxes'][keep]
        pred_scores = pred['scores'][keep]
        pred_labels = pred['labels'][keep]

        target_boxes = target['boxes']
        target_labels = target['labels']

        for cls in range(1, n_classes):  # skip background (0)
            mask_target = target_labels == cls
            if not mask_target.any():
                continue  # no ground truth for this class in this image

            mask_pred = pred_labels == cls
            if not mask_pred.any():
                average_precisions[cls].append(0.0)
                continue

            order = torch.argsort(pred_scores[mask_pred], descending=True)
            cls_pred_boxes = pred_boxes[mask_pred][order]
            cls_target_boxes = target_boxes[mask_target]

            ious = box_iou(cls_pred_boxes, cls_target_boxes)
            tp = torch.zeros(len(cls_pred_boxes))
            fp = torch.zeros(len(cls_pred_boxes))

            for i in range(len(cls_pred_boxes)):
                max_iou, max_idx = torch.max(ious[i], dim=0)
                if max_iou >= iou_threshold:
                    tp[i] = 1
                    # Zero the matched column so it cannot be matched again.
                    ious[:, max_idx] = 0
                else:
                    fp[i] = 1

            tp_cumsum = torch.cumsum(tp, dim=0)
            fp_cumsum = torch.cumsum(fp, dim=0)
            precisions = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6)
            recalls = tp_cumsum / len(cls_target_boxes)

            # Prepend the (recall=0, precision=1) start point with matching dtype/device.
            precisions = torch.cat([torch.ones(1, dtype=precisions.dtype, device=precisions.device), precisions])
            recalls = torch.cat([torch.zeros(1, dtype=recalls.dtype, device=recalls.device), recalls])

            ap = torch.trapz(precisions, recalls)
            average_precisions[cls].append(ap.item())

    class_aps = [float(np.mean(aps)) if aps else 0.0 for aps in average_precisions]

    # Only classes that had ground truth somewhere contribute; previously
    # absent classes were averaged in as 0.0, biasing mAP downwards, and an
    # empty set produced nan.
    evaluated = [class_aps[c] for c in range(1, n_classes) if average_precisions[c]]
    mAP = float(np.mean(evaluated)) if evaluated else 0.0

    if return_per_class:
        return mAP, class_aps
    return mAP

def calculate_pr_curves(predictions, targets, num_classes, iou_threshold=0.5):
    """Compute a dataset-level precision/recall curve for each foreground class.

    All predictions of a class are pooled across images and sorted by score.
    Each prediction is matched only against ground truth in its own image; it
    is a true positive when its best IoU reaches ``iou_threshold`` and that
    ground-truth box is still unused. Precision is then made monotonically
    non-increasing (VOC-style envelope).

    Args:
        predictions: per-image dicts with ``boxes``, ``scores``, ``labels`` tensors.
        targets: per-image dicts with ``boxes`` and ``labels`` tensors.
        num_classes: number of classes including background (index 0).
        iou_threshold: IoU needed for a match.

    Returns:
        A list with one ``(precisions, recalls)`` pair per class
        ``1 .. num_classes - 1``. Both are numpy arrays padded with sentinels
        (recall 0 -> 1, precision 1 -> 0), or empty lists when the class has
        no predictions or no ground truth.
    """
    # Convert every image once, rather than once per class.
    per_image = [
        (
            pred['boxes'].cpu().numpy(),
            pred['scores'].cpu().numpy(),
            pred['labels'].cpu().numpy(),
            target['boxes'].cpu().numpy(),
            target['labels'].cpu().numpy(),
        )
        for pred, target in zip(predictions, targets)
    ]

    pr_curves = []
    for cls in range(1, num_classes):  # skip background
        det_boxes, det_scores, det_image_ids = [], [], []
        gt_per_image = []
        for img_idx, (p_boxes, p_scores, p_labels, t_boxes, t_labels) in enumerate(per_image):
            pred_mask = p_labels == cls
            if pred_mask.any():
                det_boxes.append(p_boxes[pred_mask])
                det_scores.append(p_scores[pred_mask])
                det_image_ids.append(np.full(int(pred_mask.sum()), img_idx))
            gt_per_image.append(t_boxes[t_labels == cls].reshape(-1, 4))

        precisions, recalls = [], []
        total_gt = sum(len(gt) for gt in gt_per_image)

        if det_boxes and total_gt > 0:
            boxes = np.vstack(det_boxes)
            scores = np.concatenate(det_scores)
            image_ids = np.concatenate(det_image_ids)

            order = np.argsort(-scores, kind='stable')
            boxes = boxes[order]
            image_ids = image_ids[order]

            used = [np.zeros(len(gt), dtype=bool) for gt in gt_per_image]
            tp = np.zeros(len(boxes))
            fp = np.zeros(len(boxes))

            for i, (box, img_idx) in enumerate(zip(boxes, image_ids)):
                # Match only against ground truth from the SAME image; the old
                # code searched every image and could match across images.
                gt = gt_per_image[img_idx]
                if len(gt) > 0:
                    ious = calculate_iou_numpy(box, gt)
                    best = int(np.argmax(ious))
                    if ious[best] >= iou_threshold and not used[img_idx][best]:
                        used[img_idx][best] = True
                        tp[i] = 1
                        continue
                fp[i] = 1

            cum_tp = np.cumsum(tp)
            cum_fp = np.cumsum(fp)
            rec = cum_tp / total_gt
            prec = cum_tp / (cum_tp + cum_fp)

            # Sentinel end points for plotting.
            recalls = np.concatenate([[0], rec, [1]])
            precisions = np.concatenate([[1], prec, [0]])

            # Precision envelope: each point becomes the max precision at any higher recall.
            precisions = np.maximum.accumulate(precisions[::-1])[::-1]

        pr_curves.append((precisions, recalls))

    return pr_curves

def calculate_iou_numpy(box, boxes):
    """Return the IoU of a single box against every row of ``boxes``.

    Args:
        box: array of 4 values ``[xmin, ymin, xmax, ymax]``.
        boxes: array of shape ``(N, 4)`` in the same format.

    Returns:
        Array of ``N`` IoU values; a ``1e-6`` epsilon in the denominator
        guards against zero-area unions.
    """
    ref = box.reshape(1, 4)
    x1, y1, x2, y2 = ref[0, 0], ref[0, 1], ref[0, 2], ref[0, 3]

    overlap_w = np.maximum(np.minimum(boxes[:, 2], x2) - np.maximum(boxes[:, 0], x1), 0)
    overlap_h = np.maximum(np.minimum(boxes[:, 3], y2) - np.maximum(boxes[:, 1], y1), 0)
    overlap = overlap_w * overlap_h

    ref_area = (x2 - x1) * (y2 - y1)
    other_areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    return overlap / (ref_area + other_areas - overlap + 1e-6)

def calculate_confusion_matrix(predictions, targets, num_classes, iou_threshold=0.5, confidence_threshold=0.5):
    """Build a detection confusion matrix (rows = ground truth, columns = prediction).

    Index 0 stands for "background / nothing":
      * a prediction whose best unmatched ground truth has IoU below
        ``iou_threshold`` is counted in row 0 (false positive);
      * a ground-truth box that no prediction matched is counted in column 0
        (false negative).
    Matching is greedy in prediction order, and each ground-truth box is
    matched at most once.

    Args:
        predictions: per-image dicts with ``boxes``, ``scores``, ``labels`` tensors.
        targets: per-image dicts with ``boxes`` and ``labels`` tensors.
        num_classes: number of classes including background.
        iou_threshold: IoU needed for a match.
        confidence_threshold: predictions scoring below this are ignored.

    Returns:
        ``np.int32`` array of shape ``(num_classes, num_classes)``.
    """
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=np.int32)

    for pred, target in zip(predictions, targets):
        keep = pred['scores'] >= confidence_threshold
        # Work on CPU with plain ints: indexing a CPU bool tensor with a CUDA
        # 0-d index tensor (as before) can fail with a device mismatch.
        pred_boxes = pred['boxes'][keep].detach().cpu()
        pred_labels = pred['labels'][keep].detach().cpu().tolist()
        target_boxes = target['boxes'].detach().cpu()
        target_labels = target['labels'].detach().cpu().tolist()

        if not pred_labels or not target_labels:
            # Every prediction is a false positive / every target a false negative.
            for pred_label in pred_labels:
                confusion_matrix[0, pred_label] += 1
            for gt_label in target_labels:
                confusion_matrix[gt_label, 0] += 1
            continue

        ious = box_iou(pred_boxes, target_boxes)
        matched = [False] * len(target_labels)

        for i, pred_label in enumerate(pred_labels):
            max_iou, max_idx = torch.max(ious[i], dim=0)
            j = int(max_idx)
            if max_iou.item() >= iou_threshold and not matched[j]:
                confusion_matrix[target_labels[j], pred_label] += 1
                matched[j] = True
                # Remove this target from consideration for later predictions.
                ious[:, j] = 0
            else:
                confusion_matrix[0, pred_label] += 1

        for gt_label, is_matched in zip(target_labels, matched):
            if not is_matched:
                confusion_matrix[gt_label, 0] += 1

    return confusion_matrix

def enhanced_evaluate(model, data_loader, device, epoch, confidence_thresholds=None, output_dir=None):
    """Evaluate the model on a validation loader and report detection metrics.

    Runs inference over ``data_loader``. It then computes mAP@IoU0.5 at each
    confidence threshold and per-class PR curves. If ``output_dir`` is given,
    it also saves the threshold sweep, PR curves, per-class AP plot and (best
    effort) a confusion matrix.

    Args:
        model: detection model; called as ``model(images)`` in eval mode.
        data_loader: validation loader whose dataset exposes ``classes``.
        device: device the images are moved to for inference.
        epoch: epoch number used in log lines and output file names.
        confidence_thresholds: thresholds to evaluate; defaults to
            ``(0.5, 0.6, 0.7, 0.8, 0.9)``. 0.5 is always added, because the
            summary and return value are reported at that threshold.
        output_dir: where to write plots/CSVs; nothing is saved if None.

    Returns:
        ``(mAP at 0.5, {threshold: (mAP, class_aps)}, pr_curves)``.
    """
    standard_threshold = 0.5
    if confidence_thresholds is None:
        confidence_thresholds = (0.5, 0.6, 0.7, 0.8, 0.9)
    thresholds = list(confidence_thresholds)
    if standard_threshold not in thresholds:
        # Previously a list without 0.5 caused a KeyError further down.
        thresholds.append(standard_threshold)

    model.eval()
    metric_logger = MetricLogger(delimiter="  ")
    header = f'Validation: [{epoch}]'

    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for images, targets in metric_logger.log_every(data_loader, 10, header):
            images = [img.to(device) for img in images]
            outputs = model(images)

            # Keep accumulated results on the CPU so device memory does not
            # grow with the size of the validation set.
            all_predictions.extend({k: v.detach().cpu() for k, v in out.items()} for out in outputs)
            all_targets.extend({k: v.cpu() for k, v in t.items()} for t in targets)

    class_names = data_loader.dataset.classes

    mAP_by_threshold = {}
    for threshold in thresholds:
        mAP, class_ap = calculate_mAP(all_predictions, all_targets, iou_threshold=0.5,
                                      confidence_threshold=threshold, return_per_class=True)
        mAP_by_threshold[threshold] = (mAP, class_ap)

    pr_curves = calculate_pr_curves(all_predictions, all_targets, len(class_names))

    if output_dir:
        save_mAP_threshold_results(mAP_by_threshold, class_names, output_dir, epoch)
        plot_pr_curves(pr_curves, class_names, output_dir, epoch)
        plot_per_class_metrics(mAP_by_threshold[standard_threshold][1], class_names, output_dir, epoch)

        # Best effort: a failure here should not abort evaluation.
        try:
            conf_matrix = calculate_confusion_matrix(all_predictions, all_targets, len(class_names))
            plot_confusion_matrix(conf_matrix, class_names, output_dir, epoch)
        except Exception as e:
            print(f"Could not create confusion matrix: {e}")

    mAP, class_aps = mAP_by_threshold[standard_threshold]
    print(f"Epoch {epoch}: mAP@0.5 = {mAP:.4f}")

    print("\nPer-class Average Precision:")
    for i, ap in enumerate(class_aps):
        if i == 0:  # Skip background class
            continue
        print(f"  {class_names[i]}: {ap:.4f}")

    return mAP, mAP_by_threshold, pr_curves

# ============================
# ENHANCED VISUALIZATION FUNCTIONS
# ============================

def plot_loss_curve(train_losses, val_maps, output_dir):
    """Save ``training_curves.png``: training loss (left) and validation mAP (right).

    Args:
        train_losses: loss values, plotted against their index as the epoch.
        val_maps: validation mAP values. When there are fewer of them than
            losses, they are spread at a regular stride across the epochs. May
            be empty, in which case the mAP panel is left blank.
        output_dir: directory the PNG is written to.
    """
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, 'b-')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)

    plt.subplot(1, 2, 2)
    if len(val_maps) == 0:
        # Nothing evaluated yet; previously this divided by zero below.
        pass
    elif len(val_maps) < len(train_losses):
        eval_freq = len(train_losses) // len(val_maps)
        # NOTE(review): evaluations are placed at indices 0, f, 2f, ...;
        # confirm this matches the training loop's evaluation schedule.
        eval_epochs = list(range(0, len(train_losses), eval_freq))[:len(val_maps)]
        plt.plot(eval_epochs, val_maps, 'r-')
    else:
        plt.plot(val_maps, 'r-')

    plt.title('Validation mAP')
    plt.xlabel('Epoch')
    plt.ylabel('mAP')
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_curves.png'))
    plt.close()

def plot_pr_curves(pr_curves, class_names, output_dir, epoch):
    """Draw all per-class precision-recall curves on one figure and save it as a PNG.

    ``pr_curves[k]`` is a ``(precisions, recalls)`` pair for class ``k + 1``
    (background is not included). Curves with fewer than two points are skipped.
    """
    plt.figure(figsize=(12, 8))

    for offset, (precisions, recalls) in enumerate(pr_curves):
        if len(precisions) <= 1:
            continue  # no usable data for this class
        plt.step(recalls, precisions, where='post', label=class_names[offset + 1])

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curves (Epoch {epoch})')
    plt.xlim([0.0, 1.05])
    plt.ylim([0.0, 1.05])
    plt.legend(loc='lower left')
    plt.grid(True)

    target_file = os.path.join(output_dir, f'pr_curves_epoch_{epoch}.png')
    plt.savefig(target_file)
    plt.close()
    print(f"Saved PR curves to {target_file}")

def save_mAP_threshold_results(mAP_by_threshold, class_names, output_dir, epoch):
    """Save mAP-vs-confidence-threshold results as a plot and CSV files.

    Files written to ``output_dir``:
      * ``mAP_vs_threshold_epoch_{epoch}.png``: mAP plotted against threshold;
      * ``mAP_vs_threshold_epoch_{epoch}.csv``: ``threshold,mAP`` rows;
      * ``per_class_mAP_epoch_{epoch}.csv``: ``class,AP`` rows at threshold
        0.5, background excluded. Written only when a 0.5 entry exists.

    Args:
        mAP_by_threshold: maps a confidence threshold to ``(mAP, class_aps)``,
            where ``class_aps[i]`` is the AP of ``class_names[i]``.
        class_names: class names indexed by label (0 = background).
        output_dir: destination directory.
        epoch: epoch number used in titles and file names.
    """
    import csv

    thresholds = sorted(mAP_by_threshold)
    mAPs = [mAP_by_threshold[t][0] for t in thresholds]

    plt.figure(figsize=(10, 6))
    plt.plot(thresholds, mAPs, 'o-', linewidth=2)
    plt.xlabel('Confidence Threshold')
    plt.ylabel('mAP')
    plt.title(f'mAP vs Confidence Threshold (Epoch {epoch})')
    plt.grid(True)

    save_path = os.path.join(output_dir, f'mAP_vs_threshold_epoch_{epoch}.png')
    plt.savefig(save_path)
    plt.close()
    print(f"Saved mAP vs threshold plot to {save_path}")

    # csv.writer quotes fields containing commas or quotes (e.g. class names),
    # which the previous hand-built lines did not. lineterminator='\n' keeps
    # the original line endings.
    csv_path = os.path.join(output_dir, f'mAP_vs_threshold_epoch_{epoch}.csv')
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(["threshold", "mAP"])
        writer.writerows(zip(thresholds, mAPs))

    if 0.5 not in mAP_by_threshold:
        print("No results at confidence threshold 0.5; skipping per-class CSV")
        return

    per_class_aps = mAP_by_threshold[0.5][1]
    per_class_path = os.path.join(output_dir, f'per_class_mAP_epoch_{epoch}.csv')
    with open(per_class_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(["class", "AP"])
        # Index 0 is background and is skipped.
        writer.writerows((class_names[i], per_class_aps[i]) for i in range(1, len(per_class_aps)))

def plot_per_class_metrics(class_aps, class_names, output_dir, epoch):
    """Save a horizontal bar chart of per-class AP, ordered from lowest to highest.

    ``class_aps[i]`` is the AP of ``class_names[i]``; index 0 (background) is
    left out. The figure height grows with the number of classes.
    """
    foreground = range(1, len(class_aps))
    aps = [class_aps[c] for c in foreground]
    names = [class_names[c] for c in foreground]

    order = np.argsort(aps)
    ranked_aps = [aps[k] for k in order]
    ranked_names = [names[k] for k in order]
    positions = range(len(ranked_names))

    plt.figure(figsize=(10, max(6, len(foreground) * 0.4)))
    plt.barh(positions, ranked_aps, align='center')
    plt.yticks(positions, ranked_names)
    plt.xlabel('Average Precision')
    plt.title(f'Per-Class Average Precision (Epoch {epoch})')
    plt.grid(True, axis='x')
    plt.tight_layout()

    target_file = os.path.join(output_dir, f'per_class_ap_epoch_{epoch}.png')
    plt.savefig(target_file)
    plt.close()
    print(f"Saved per-class AP plot to {target_file}")

def plot_confusion_matrix(confusion_matrix, class_names, output_dir, epoch):
    """Render a confusion matrix as an annotated heatmap and save it as a PNG.

    Rows are ground-truth classes and columns predicted classes, both indexed
    like ``class_names``. When there are more than 20 classes, only the first
    20 rows/columns are drawn, to keep the plot legible. seaborn is installed
    with pip if it is not importable.
    """
    try:
        import seaborn as sns
    except ImportError:
        print("seaborn not available, installing...")
        import importlib
        import subprocess
        import sys
        # ``pip.main`` is not a supported programmatic API; pip documents
        # running it as a subprocess of the current interpreter instead.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "seaborn"])
        # Make the freshly installed package visible to the import system.
        importlib.invalidate_caches()
        import seaborn as sns

    max_shown = 20
    if len(class_names) > max_shown:
        print(f"Confusion matrix has {len(class_names)} classes; showing only the first {max_shown}")
        confusion_matrix = confusion_matrix[:max_shown, :max_shown]
        display_names = class_names[:max_shown]
    else:
        display_names = class_names

    plt.figure(figsize=(12, 10))
    sns.heatmap(confusion_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=display_names, yticklabels=display_names)

    plt.xlabel('Predicted Label')
    plt.ylabel('Ground Truth Label')
    plt.title(f'Confusion Matrix (Epoch {epoch})')
    plt.tight_layout()

    save_path = os.path.join(output_dir, f'confusion_matrix_epoch_{epoch}.png')
    plt.savefig(save_path)
    plt.close()
    print(f"Saved confusion matrix to {save_path}")

def plot_mAP_threshold_curves(val_maps_by_threshold, eval_freq, output_dir):
    """Plot one mAP-vs-confidence-threshold line per evaluation and save the figure.

    Args:
        val_maps_by_threshold: list of ``{threshold: (mAP, class_aps)}`` dicts,
            one per evaluation, in order. The thresholds of the first entry are
            used for every line.
        eval_freq: epochs between evaluations; entry ``i`` is labelled as
            epoch ``(i + 1) * eval_freq``.
        output_dir: directory receiving ``mAP_threshold_curves.png``.

    Returns:
        None. Nothing is plotted when there are no evaluations.
    """
    if len(val_maps_by_threshold) == 0:
        # Indexing [0] below would raise IndexError.
        print("No mAP-by-threshold results to plot")
        return

    plt.figure(figsize=(12, 8))

    thresholds = sorted(val_maps_by_threshold[0])

    for i, mAP_by_threshold in enumerate(val_maps_by_threshold):
        epoch = (i + 1) * eval_freq
        mAPs = [mAP_by_threshold[t][0] for t in thresholds]
        plt.plot(thresholds, mAPs, 'o-', linewidth=2, label=f'Epoch {epoch}')

    plt.xlabel('Confidence Threshold')
    plt.ylabel('mAP')
    plt.title('mAP vs Confidence Threshold Across Training')
    plt.legend()
    plt.grid(True)

    save_path = os.path.join(output_dir, 'mAP_threshold_curves.png')
    plt.savefig(save_path)
    plt.close()
    print(f"Saved mAP threshold curves to {save_path}")

def _draw_labeled_boxes(ax, boxes, captions, color):
    """Draw each ``[xmin, ymin, xmax, ymax]`` box on ``ax`` with its caption just above it."""
    for (xmin, ymin, xmax, ymax), caption in zip(boxes, captions):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, edgecolor=color, linewidth=2))
        ax.text(xmin, ymin - 5, caption, color=color, backgroundcolor='white', fontsize=8)


def _accumulate_detection_counts(pred_boxes, pred_labels, gt_boxes, gt_labels,
                                 classes, class_metrics, iou_threshold=0.5):
    """Add one image's per-class TP/FP/FN counts into ``class_metrics`` (mutated in place).

    Each prediction, in the given order, takes the ground-truth box with the
    highest IoU. It is a TP only if that IoU reaches ``iou_threshold``, the
    labels agree, and the box was not already claimed; otherwise it is an FP.
    Unclaimed ground-truth boxes are FNs.
    """
    pred_ids = pred_labels.tolist()
    gt_ids = gt_labels.tolist()
    ious = box_iou(pred_boxes, gt_boxes) if pred_ids and gt_ids else None
    matched_gt = set()

    for j, pred_id in enumerate(pred_ids):
        if ious is not None:
            max_iou, max_idx = torch.max(ious[j], dim=0)
            gt_idx = int(max_idx)
            if (max_iou.item() >= iou_threshold
                    and pred_id == gt_ids[gt_idx]
                    and gt_idx not in matched_gt):
                class_metrics[classes[pred_id]]['TP'] += 1
                matched_gt.add(gt_idx)
                continue
        class_metrics[classes[pred_id]]['FP'] += 1

    for k, gt_id in enumerate(gt_ids):
        if k not in matched_gt:
            class_metrics[classes[gt_id]]['FN'] += 1


def enhanced_visualize_predictions(model, dataset, device, num_images=8, confidence_threshold=0.5, output_dir=None):
    """Visualise predictions on random samples and save simple per-class detection metrics.

    Writes into ``output_dir`` (default ``"predictions"``):
      * ``enhanced_predictions.png``: ground truth (left) vs predictions with
        score >= ``confidence_threshold`` (right) for each sampled image;
      * ``per_class_metrics.png``: precision/recall/F1 bars sorted by F1;
      * ``detection_metrics.csv``: per-class TP/FP/FN, precision, recall, F1.

    Metrics use only the sampled images and an IoU threshold of 0.5.

    Args:
        model: detection model, called as ``model([image])`` in eval mode.
        dataset: dataset returning ``(image_tensor, target)`` and exposing
            ``classes`` (index 0 = background).
        device: device used for inference.
        num_images: number of images to sample; capped at ``len(dataset)``.
        confidence_threshold: minimum score for a prediction to be drawn and counted.
        output_dir: destination directory.
    """
    if output_dir is None:
        output_dir = "predictions"
    os.makedirs(output_dir, exist_ok=True)

    model.eval()

    # np.random.choice(..., replace=False) raises when asked for more samples than exist.
    num_images = min(num_images, len(dataset))
    if num_images < 1:
        print("No images available to visualize")
        return

    # squeeze=False keeps ``axes`` 2-D even when num_images == 1.
    fig, axes = plt.subplots(num_images, 2, figsize=(16, 4 * num_images), squeeze=False)
    indices = np.random.choice(len(dataset), num_images, replace=False)

    class_metrics = {cls_name: {'TP': 0, 'FP': 0, 'FN': 0} for cls_name in dataset.classes[1:]}

    for row, idx in enumerate(indices):
        image, target = dataset[idx]
        image_vis = np.array(transforms.ToPILImage()(image))

        with torch.no_grad():
            prediction = model([image.to(device)])[0]

        gt_boxes = target["boxes"].cpu()
        gt_labels = target["labels"].cpu()

        # Same ">=" rule as the panel title and the other metrics in this file.
        keep = prediction["scores"] >= confidence_threshold
        pred_boxes = prediction["boxes"][keep].cpu()
        pred_labels = prediction["labels"][keep].cpu()
        pred_scores = prediction["scores"][keep].cpu()

        gt_ax, pred_ax = axes[row, 0], axes[row, 1]

        gt_ax.imshow(image_vis)
        gt_ax.set_title("Ground Truth")
        _draw_labeled_boxes(gt_ax, gt_boxes.numpy(),
                            [dataset.classes[lbl] for lbl in gt_labels.tolist()], 'green')

        pred_ax.imshow(image_vis)
        pred_ax.set_title(f"Prediction (conf >= {confidence_threshold})")
        _draw_labeled_boxes(pred_ax, pred_boxes.numpy(),
                            [f"{dataset.classes[lbl]}: {score:.2f}"
                             for lbl, score in zip(pred_labels.tolist(), pred_scores.tolist())],
                            'red')

        for ax in (gt_ax, pred_ax):
            ax.set_xticks([])
            ax.set_yticks([])

        # Counted for every image, including those with no predictions or no
        # ground truth (previously FPs/FNs in such images were dropped).
        _accumulate_detection_counts(pred_boxes, pred_labels, gt_boxes, gt_labels,
                                     dataset.classes, class_metrics)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'enhanced_predictions.png'))
    plt.close(fig)

    class_names = list(class_metrics)
    precision, recall, f1_score = [], [], []
    for counts in class_metrics.values():
        tp, fp, fn = counts['TP'], counts['FP'], counts['FN']
        cls_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        cls_recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        denom = cls_precision + cls_recall
        cls_f1 = 2 * cls_precision * cls_recall / denom if denom > 0 else 0
        precision.append(cls_precision)
        recall.append(cls_recall)
        f1_score.append(cls_f1)

    order = np.argsort(f1_score)
    sorted_names = [class_names[i] for i in order]
    sorted_precision = [precision[i] for i in order]
    sorted_recall = [recall[i] for i in order]
    sorted_f1 = [f1_score[i] for i in order]

    x = np.arange(len(sorted_names))
    width = 0.25

    plt.figure(figsize=(12, 8))
    plt.bar(x - width, sorted_precision, width, label='Precision')
    plt.bar(x, sorted_recall, width, label='Recall')
    plt.bar(x + width, sorted_f1, width, label='F1 Score')
    plt.xlabel('Class')
    plt.ylabel('Score')
    plt.title('Per-Class Detection Metrics')
    plt.xticks(x, sorted_names, rotation=45, ha='right')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'per_class_metrics.png'))
    plt.close()

    metrics_df = pd.DataFrame({
        'Class': class_names,
        'TP': [class_metrics[cls]['TP'] for cls in class_names],
        'FP': [class_metrics[cls]['FP'] for cls in class_names],
        'FN': [class_metrics[cls]['FN'] for cls in class_names],
        'Precision': precision,
        'Recall': recall,
        'F1': f1_score
    })
    metrics_df.to_csv(os.path.join(output_dir, 'detection_metrics.csv'), index=False)

    print(f"Enhanced prediction visualization and metrics saved to {output_dir}")

def save_to_google_drive(local_dir, gdrive_dir, IN_COLAB):
    """Mirror the contents of ``local_dir`` into ``gdrive_dir``.

    Nothing happens unless ``IN_COLAB`` is truthy and a destination directory
    is given. Regular files are copied with ``shutil.copy`` (one message is
    printed per file); subdirectories are recreated and mirrored recursively.
    Entries are discovered with a ``*`` glob, so dot-files are not copied.

    Args:
        local_dir: Source directory.
        gdrive_dir: Destination directory (created if missing).
        IN_COLAB: Whether the code is running inside Google Colab.
    """
    if IN_COLAB and gdrive_dir:
        os.makedirs(gdrive_dir, exist_ok=True)

        for entry in glob.glob(os.path.join(local_dir, '*')):
            name = os.path.basename(entry)
            target = os.path.join(gdrive_dir, name)

            if os.path.isfile(entry):
                shutil.copy(entry, target)
                print(f"Copied {name} to Google Drive: {target}")
            elif os.path.isdir(entry):
                # Recreate the subdirectory, then mirror its contents.
                os.makedirs(target, exist_ok=True)
                save_to_google_drive(entry, target, IN_COLAB)

def visualize_predictions(model, dataset, device, num_images=5, output_dir=None,
                          confidence_threshold=0.5):
    """Save a side-by-side ground-truth vs. prediction figure for random samples.

    Each row shows one randomly chosen image from ``dataset``. Ground-truth
    boxes are drawn in green on the left. Predictions whose score is above
    ``confidence_threshold`` are drawn in red on the right. The figure is
    written to ``<output_dir>/sample_predictions.png``.

    Args:
        model: Detection model, called as ``model([image])``; the first
            element of the result must provide ``boxes``, ``labels`` and
            ``scores`` tensors.
        dataset: Indexable dataset returning ``(image_tensor, target_dict)``
            and exposing a ``classes`` list that maps label ids to names.
        device: Device the image is moved to before inference.
        num_images: Number of rows to draw; capped at ``len(dataset)``.
        output_dir: Directory for the PNG (default ``"predictions"``),
            created if missing.
        confidence_threshold: Predictions with a score strictly greater than
            this value are drawn (default 0.5, the previous fixed value).

    Note:
        Puts ``model`` into eval mode and leaves it there.
    """
    if output_dir is None:
        output_dir = "predictions"
    os.makedirs(output_dir, exist_ok=True)

    # np.random.choice(..., replace=False) raises if asked for more samples
    # than the dataset holds, and plt.subplots cannot build a 0-row grid.
    num_images = min(num_images, len(dataset))
    if num_images <= 0:
        print("No images to visualize; skipping prediction visualization")
        return

    model.eval()

    # squeeze=False keeps `axes` 2-D even for a single row, so axes[row, col]
    # indexing works for every sample count.
    fig, axes = plt.subplots(num_images, 2, figsize=(12, 3 * num_images), squeeze=False)
    save_path = os.path.join(output_dir, 'sample_predictions.png')

    try:
        indices = np.random.choice(len(dataset), num_images, replace=False)

        for row, idx in enumerate(indices):
            image, target = dataset[idx]
            image_vis = np.array(transforms.ToPILImage()(image))

            with torch.no_grad():
                prediction = model([image.to(device)])[0]

            gt_ax, pred_ax = axes[row, 0], axes[row, 1]

            # Ground truth (green)
            gt_ax.imshow(image_vis)
            gt_ax.set_title("Ground Truth")
            for box, label in zip(target["boxes"], target["labels"]):
                xmin, ymin, xmax, ymax = box.cpu().numpy()
                gt_ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                              fill=False, edgecolor='green', linewidth=2))
                gt_ax.text(xmin, ymin - 5, dataset.classes[label.item()], color='green',
                           backgroundcolor='white', fontsize=8)

            # Predictions above the confidence threshold (red)
            pred_ax.imshow(image_vis)
            pred_ax.set_title("Prediction")
            keep = prediction["scores"] > confidence_threshold
            boxes = prediction["boxes"][keep].cpu().numpy()
            labels = prediction["labels"][keep].cpu().numpy()
            scores = prediction["scores"][keep].cpu().numpy()
            for box, label, score in zip(boxes, labels, scores):
                xmin, ymin, xmax, ymax = box
                pred_ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                                fill=False, edgecolor='red', linewidth=2))
                pred_ax.text(xmin, ymin - 5, f"{dataset.classes[label]}: {score:.2f}",
                             color='red', backgroundcolor='white', fontsize=8)

            for ax in (gt_ax, pred_ax):
                ax.set_xticks([])
                ax.set_yticks([])

        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if inference or drawing fails.
        plt.close(fig)

    print(f"Prediction visualization saved to {save_path}")

# ============================
# DATA LOADERS
# ============================

def create_data_loaders(args):
    """Build the training and validation DataLoaders for a Pascal VOC dataset.

    The validation split is read from ``<voc_path>/val`` when that directory
    exists, otherwise from the ``valid`` split.

    Args:
        args: Parsed arguments; ``voc_path``, ``image_size`` and
            ``batch_size`` are read.

    Returns:
        tuple: ``(train_loader, val_loader, num_classes)`` where
        ``num_classes`` comes from the training dataset.
    """
    def build_dataset(split, train):
        return PascalVOCDataset(
            root=args.voc_path,
            split=split,
            transforms=get_transform(train=train, img_size=args.image_size),
        )

    def build_loader(dataset, shuffle):
        # num_workers=0 loads data in the main process to avoid
        # multiprocessing issues in Colab.
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            num_workers=0,
            pin_memory=True,
        )

    train_dataset = build_dataset('train', train=True)

    has_val_dir = os.path.exists(os.path.join(args.voc_path, 'val'))
    val_dataset = build_dataset('val' if has_val_dir else 'valid', train=False)

    train_loader = build_loader(train_dataset, shuffle=True)
    val_loader = build_loader(val_dataset, shuffle=False)

    return train_loader, val_loader, train_dataset.num_classes

# ============================
# MAIN TRAINING FUNCTION
# ============================

def _copy_file_to_gdrive(local_path, gdrive_dir):
    """Copy ``local_path`` into ``gdrive_dir`` under the same file name.

    The destination directory is created first, so the copy also works when
    ``train_ssd_model`` is called without the ``__main__`` setup that creates it.

    Returns:
        str: Path of the copy on Google Drive.
    """
    os.makedirs(gdrive_dir, exist_ok=True)
    gdrive_path = os.path.join(gdrive_dir, os.path.basename(local_path))
    shutil.copy(local_path, gdrive_path)
    return gdrive_path


def _write_training_report(report_path, args, num_classes, best_map, total_time,
                           loss_history, map_history, final_maps_by_threshold,
                           class_names):
    """Write the plain-text training summary to ``report_path`` (overwritten).

    Args:
        report_path: Destination text file.
        args: Parsed arguments; ``voc_path``, ``epochs``, ``batch_size``,
            ``lr`` and ``image_size`` are read.
        num_classes: Number of classes including background.
        best_map: Best validation mAP observed during training.
        total_time: Wall-clock training time in seconds.
        loss_history: ``(epoch_number, mean_loss)`` pairs with 1-based epochs.
        map_history: ``(epoch_number, mAP)`` pairs with 1-based epochs.
        final_maps_by_threshold: Threshold-keyed mapping from the final
            evaluation, or ``None`` if it did not run. Each value is indexed
            as ``[0]`` -> mAP and ``[1]`` -> per-class APs.
        class_names: Class names indexed like the per-class AP list.
    """
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("SSD Model Training Report with Enhanced Evaluation\n")
        f.write("===============================================\n\n")
        f.write(f"Dataset: {args.voc_path}\n")
        f.write(f"Number of classes: {num_classes}\n")
        f.write(f"Training epochs: {args.epochs}\n")
        f.write(f"Batch size: {args.batch_size}\n")
        f.write(f"Learning rate: {args.lr}\n")
        f.write(f"Image size: {args.image_size}\n\n")
        f.write("Data augmentation: None\n\n")

        f.write("Results:\n")
        f.write(f"Best validation mAP@0.5: {best_map:.4f}\n")
        f.write(f"Training time: {total_time/60:.2f} minutes\n\n")

        f.write("Training Loss:\n")
        for epoch_num, loss in loss_history:
            f.write(f"Epoch {epoch_num}: {loss:.4f}\n")

        f.write("\nValidation mAP@0.5:\n")
        for epoch_num, mAP in map_history:
            f.write(f"Epoch {epoch_num}: {mAP:.4f}\n")

        if final_maps_by_threshold:
            # Per-class APs are only reported when threshold 0.5 was evaluated.
            at_half = final_maps_by_threshold.get(0.5)
            if at_half is not None:
                f.write("\nPer-Class AP@0.5 for Final Model:\n")
                # NOTE(review): index 0 is assumed to be background, matching
                # class_names[0] == '__background__'; confirm against enhanced_evaluate.
                for i, ap in enumerate(at_half[1]):
                    if i == 0 or i >= len(class_names):
                        continue
                    f.write(f"{class_names[i]}: {ap:.4f}\n")

            f.write("\nmAP at Different Confidence Thresholds for Final Model:\n")
            for threshold in sorted(final_maps_by_threshold):
                f.write(f"Threshold {threshold}: {final_maps_by_threshold[threshold][0]:.4f}\n")


def train_ssd_model(args):
    """Train an SSD detector with periodic evaluation, then write a report.

    Reads ``voc_path``, ``output_dir``, ``gdrive_dir``, ``epochs``,
    ``batch_size``, ``lr``, ``pretrained``, ``image_size``, ``eval_freq`` and
    ``save_freq`` from ``args``. Relies on the module-level ``device`` and
    ``IN_COLAB``.

    Training and saving:
        * Every ``eval_freq`` epochs, and on the last epoch, the model is
          evaluated.
        * The weights with the highest mAP so far go to
          ``<output_dir>/best_model.pth``. The first evaluation always saves,
          so a best model exists even if mAP stays at 0.
        * Checkpoints are written every ``save_freq`` epochs and on the last
          epoch.
        * An exception inside an epoch is printed and training moves on to
          the next epoch.

    After training:
        * The best weights are reloaded for a final evaluation and
          visualizations.
        * A report is written to ``<output_dir>/training_report.txt``.
        * In Colab with ``gdrive_dir`` set, artifacts are mirrored to Drive.

    Returns:
        The model (holding the best weights if they were reloaded), or
        ``None`` if the data loaders or the model could not be created.
    """
    import traceback

    print("\n==== TRAINING SSD MODEL WITH ENHANCED EVALUATION ====")
    print(f"Dataset path: {args.voc_path}")
    print(f"Output directory: {args.output_dir}")
    print(f"Training for {args.epochs} epochs with batch size {args.batch_size}")

    # Create data loaders
    try:
        train_loader, val_loader, num_classes = create_data_loaders(args)
    except Exception as e:
        print(f"Error creating data loaders: {e}")
        traceback.print_exc()
        return None

    print(f"Created data loaders: {len(train_loader)} training batches, {len(val_loader)} validation batches")
    print(f"Number of classes (including background): {num_classes}")

    # Create model
    try:
        model = create_ssd_model(num_classes, pretrained=args.pretrained)
        model.to(device)
    except Exception as e:
        print(f"Error creating model: {e}")
        traceback.print_exc()
        return None

    print("Created SSD300 model with VGG16 backbone")

    use_gdrive = bool(IN_COLAB and args.gdrive_dir)

    # Optimizer and learning-rate schedule (LR x0.1 every 15 epochs)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=args.lr)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

    # Training history. The *_epochs lists hold the 1-based epoch number for
    # each entry, so reports stay correct when an epoch fails or the last
    # epoch is not a multiple of eval_freq.
    train_losses = []
    train_epochs = []
    val_maps = []
    val_epochs = []
    val_maps_by_threshold = []
    best_map = 0.0
    best_model_saved = False  # only trust best_model.pth written by this run
    final_maps_by_threshold = None
    best_model_path = os.path.join(args.output_dir, 'best_model.pth')

    eval_dir = os.path.join(args.output_dir, 'evaluation')
    os.makedirs(eval_dir, exist_ok=True)

    print("\nStarting training...")
    start_time = time.time()

    for epoch in range(args.epochs):
        epoch_num = epoch + 1
        is_last_epoch = epoch == args.epochs - 1
        is_eval_epoch = epoch_num % args.eval_freq == 0 or is_last_epoch
        is_save_epoch = epoch_num % args.save_freq == 0 or is_last_epoch

        try:
            metric_logger = train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=10)
            lr_scheduler.step()

            train_losses.append(metric_logger.loss.global_avg)
            train_epochs.append(epoch_num)

            if is_eval_epoch:
                epoch_eval_dir = os.path.join(eval_dir, f'epoch_{epoch_num}')
                os.makedirs(epoch_eval_dir, exist_ok=True)

                mAP, mAP_by_threshold, _ = enhanced_evaluate(
                    model, val_loader, device, epoch,
                    confidence_thresholds=[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
                    output_dir=epoch_eval_dir
                )

                val_maps.append(mAP)
                val_epochs.append(epoch_num)
                val_maps_by_threshold.append(mAP_by_threshold)

                # The first evaluation always saves, so final evaluation has
                # weights to load even if mAP never rises above 0.0.
                # max() keeps best_map at 0.0 when that first mAP is 0 or NaN.
                if mAP > best_map or not best_model_saved:
                    best_map = max(best_map, mAP)
                    torch.save(model.state_dict(), best_model_path)
                    best_model_saved = True
                    print(f"Saved best model with mAP: {mAP:.4f}")

                    if use_gdrive:
                        gdrive_best_path = _copy_file_to_gdrive(best_model_path, args.gdrive_dir)
                        print(f"Copied best model to Google Drive: {gdrive_best_path}")

                if use_gdrive:
                    gdrive_eval_dir = os.path.join(args.gdrive_dir, 'evaluation', f'epoch_{epoch_num}')
                    save_to_google_drive(epoch_eval_dir, gdrive_eval_dir, IN_COLAB)

            if is_save_epoch:
                checkpoint_path = os.path.join(args.output_dir, f'checkpoint_{epoch_num}.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': lr_scheduler.state_dict(),
                    'best_map': best_map
                }, checkpoint_path)
                print(f"Saved checkpoint at epoch {epoch_num}")

            # Refresh training curves after each evaluation
            if is_eval_epoch:
                plot_loss_curve(train_losses, val_maps, args.output_dir)
                if val_maps_by_threshold:
                    plot_mAP_threshold_curves(val_maps_by_threshold, args.eval_freq, args.output_dir)

                if use_gdrive:
                    for plot_name in ('training_curves.png', 'mAP_threshold_curves.png'):
                        plot_path = os.path.join(args.output_dir, plot_name)
                        if os.path.exists(plot_path):
                            _copy_file_to_gdrive(plot_path, args.gdrive_dir)

        except Exception as e:
            print(f"Error in epoch {epoch}: {e}")
            traceback.print_exc()
            continue

    total_time = time.time() - start_time
    print(f"\nTraining complete in {total_time/60:.2f} minutes")
    print(f"Best validation mAP: {best_map:.4f}")

    # Final model evaluation
    try:
        if best_model_saved:
            # map_location lets a GPU-saved state dict load on the current
            # device; weights_only avoids unpickling arbitrary objects.
            state_dict = torch.load(best_model_path, map_location=device, weights_only=True)
            model.load_state_dict(state_dict)
            print("Loaded best model for final evaluation")

            final_eval_dir = os.path.join(eval_dir, 'final')
            os.makedirs(final_eval_dir, exist_ok=True)

            final_map, final_maps_by_threshold, _ = enhanced_evaluate(
                model, val_loader, device, epoch=args.epochs, output_dir=final_eval_dir
            )
            print(f"Final validation mAP: {final_map:.4f}")

            print("Generating enhanced prediction visualizations...")
            val_dataset = val_loader.dataset

            # One visualization set per confidence threshold
            for threshold in [0.3, 0.5, 0.7]:
                threshold_dir = os.path.join(final_eval_dir, f'conf_{threshold}')
                os.makedirs(threshold_dir, exist_ok=True)

                enhanced_visualize_predictions(
                    model, val_dataset, device, num_images=8,
                    confidence_threshold=threshold, output_dir=threshold_dir
                )

            if use_gdrive:
                gdrive_final_eval_dir = os.path.join(args.gdrive_dir, 'evaluation', 'final')
                save_to_google_drive(final_eval_dir, gdrive_final_eval_dir, IN_COLAB)
        else:
            print(f"Warning: Best model not found at {best_model_path}")

    except Exception as e:
        print(f"Error in final evaluation: {e}")
        traceback.print_exc()

    # Final report
    report_path = os.path.join(args.output_dir, 'training_report.txt')
    _write_training_report(
        report_path, args, num_classes, best_map, total_time,
        list(zip(train_epochs, train_losses)),
        list(zip(val_epochs, val_maps)),
        final_maps_by_threshold,
        getattr(val_loader.dataset, 'classes', []),
    )
    print(f"Enhanced training report saved to {report_path}")

    if use_gdrive:
        gdrive_report_path = _copy_file_to_gdrive(report_path, args.gdrive_dir)
        print(f"Copied training report to Google Drive: {gdrive_report_path}")

        # Mirror the whole output directory to Google Drive
        save_to_google_drive(args.output_dir, args.gdrive_dir, IN_COLAB)

    print("\n==== SSD MODEL TRAINING WITH ENHANCED EVALUATION COMPLETE ====")
    return model

# ============================
# MAIN EXECUTION
# ============================

if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser(description='Train SSD model for mango disease detection with ENHANCED EVALUATION')
    parser.add_argument('--voc-path', type=str, default='/content/ssd_dataset',
                        help='Path to Pascal VOC format dataset')
    parser.add_argument('--output-dir', type=str, default='/content/ssd_model',
                        help='Path to save model outputs')
    parser.add_argument('--gdrive-dir', type=str, default='/content/drive/MyDrive/MANGO/PROJECT/mango_ssd_enhanced',
                        help='Google Drive directory to save model (for Colab users)')
    parser.add_argument('--epochs', type=int, default=50,
                        help='Number of epochs for training')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='Batch size for training')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate')
    parser.add_argument('--pretrained', action='store_true',
                        help='Use pretrained VGG backbone')
    parser.add_argument('--image-size', type=int, default=300,
                        help='Image size for SSD300')
    parser.add_argument('--eval-freq', type=int, default=5,
                        help='Frequency of evaluation during training')
    parser.add_argument('--save-freq', type=int, default=10,
                        help='Frequency of saving model checkpoints')

    # In IPython/Jupyter/Colab, sys.argv belongs to the kernel, so ignore it
    # and use the defaults instead.
    if 'ipykernel' in sys.modules or 'IPython' in sys.modules or IN_COLAB:
        args = parser.parse_args([])
        args.pretrained = True  # Default to using pretrained backbone in Colab
        print("Running in notebook/Colab mode with default arguments")
    else:
        args = parser.parse_args()

    # Create output directories
    os.makedirs(args.output_dir, exist_ok=True)
    if IN_COLAB and args.gdrive_dir:
        os.makedirs(args.gdrive_dir, exist_ok=True)

    # Verify dataset existence
    print("\n==== VERIFYING DATASET ====")
    if not os.path.exists(args.voc_path) or not os.path.exists(os.path.join(args.voc_path, 'labelmap.txt')):
        print(f"Error: Pascal VOC dataset not found at {args.voc_path}")
        print("Please ensure you have converted your YOLO dataset to Pascal VOC format.")
        sys.exit(1)

    # Check if training split exists
    train_dir = os.path.join(args.voc_path, 'train')
    if not os.path.exists(train_dir):
        print(f"Error: Training directory not found at {train_dir}")
        sys.exit(1)

    # Check if a validation split exists under either accepted name
    valid_found = any(
        os.path.exists(os.path.join(args.voc_path, val_name)) for val_name in ['val', 'valid']
    )
    if not valid_found:
        print(f"Error: Validation directory not found at {args.voc_path}/val or {args.voc_path}/valid")
        sys.exit(1)

    print("Dataset verification completed successfully.")

    # Train the model
    try:
        model = train_ssd_model(args)
    except Exception as e:
        print(f"Error during SSD model training: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    # train_ssd_model returns None (after printing the cause) when the data
    # loaders or the model could not be created; that is not a success.
    if model is None:
        print("Error during SSD model training: training did not run (see errors above).")
        sys.exit(1)
    print("SSD model training completed successfully.")

    print("\n==== ALL STEPS COMPLETED SUCCESSFULLY ====")
    best_model_path = os.path.join(args.output_dir, 'best_model.pth')
    if os.path.exists(best_model_path):
        print(f"SSD model trained and saved to {best_model_path}")
        if IN_COLAB and args.gdrive_dir:
            gdrive_model_path = os.path.join(args.gdrive_dir, 'best_model.pth')
            if os.path.exists(gdrive_model_path):
                print(f"Model also backed up to Google Drive at {gdrive_model_path}")
    else:
        print("Warning: Best model file not found. Check logs for errors during training.")

Running in Google Colab environment
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully
Using device: cuda
Running in notebook/Colab mode with default arguments

==== VERIFYING DATASET ====
Dataset verification completed successfully.

==== TRAINING SSD MODEL WITH ENHANCED EVALUATION ====
Dataset path: /content/ssd_dataset
Output directory: /content/ssd_model
Training for 50 epochs with batch size 8
Found 7 classes: ['__background__', 'Anthracnose', 'Bacterial-Black-spot', 'Damaged-mango', 'Fruitly', 'Mechanical-damage', 'Others']
Loaded 2306 images for train split
Found 7 classes: ['__background__', 'Anthracnose', 'Bacterial-Black-spot', 'Damaged-mango', 'Fruitly', 'Mechanical-damage', 'Others']
Loaded 493 images for valid split
Created data loaders: 289 training batches, 62 validation batches
Number of classes (including background): 7


Downloading: "https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth" to /root/.cache/torch/hub/checkpoints/ssd300_vgg16_coco-b556d3b4.pth
100%|██████████| 136M/136M [00:12<00:00, 11.8MB/s]


Created SSD300 model with pretrained VGG16 backbone
Created SSD300 model with VGG16 backbone

Starting training...
Epoch: [0]
Epoch: [0] [  0/289]  eta: 0:11:39  time: 2.4206  loss: 29.2995 (29.2995)  bbox_regression: 6.1576 (6.1576)  classification: 23.1419 (23.1419)
Epoch: [0] [ 10/289]  eta: 0:01:23  time: 0.2979  loss: 35.3502 (35.3502)  bbox_regression: 5.7341 (5.7341)  classification: 29.6161 (29.6161)
Epoch: [0] [ 20/289]  eta: 0:00:52  time: 0.1938  loss: 23.2482 (22.9456)  bbox_regression: 3.8086 (3.6911)  classification: 19.4396 (19.2545)
Epoch: [0] [ 30/289]  eta: 0:00:40  time: 0.1574  loss: 18.5814 (9.3586)  bbox_regression: 3.1084 (1.6642)  classification: 15.4731 (7.6944)
Epoch: [0] [ 40/289]  eta: 0:00:34  time: 0.1389  loss: 15.8900 (8.1639)  bbox_regression: 2.5711 (1.2718)  classification: 13.3189 (6.8921)
Epoch: [0] [ 50/289]  eta: 0:00:30  time: 0.1272  loss: 14.0374 (6.9941)  bbox_regression: 2.2056 (0.8064)  classification: 11.8317 (6.1877)
Epoch: [0] [ 60/289]  

  model.load_state_dict(torch.load(best_model_path))


Validation: [50] [10/62]  eta: 0:00:03  time: 0.0713  
Validation: [50] [20/62]  eta: 0:00:02  time: 0.0703  
Validation: [50] [30/62]  eta: 0:00:02  time: 0.0706  
Validation: [50] [40/62]  eta: 0:00:01  time: 0.0705  
Validation: [50] [50/62]  eta: 0:00:00  time: 0.0712  
Validation: [50] [60/62]  eta: 0:00:00  time: 0.0716  
Validation: [50] [61/62]  eta: 0:00:00  time: 0.0711  
Validation: [50] Time: 0:00:04 (0.0712 s / it)
Saved mAP vs threshold plot to /content/ssd_model/evaluation/final/mAP_vs_threshold_epoch_50.png
Saved PR curves to /content/ssd_model/evaluation/final/pr_curves_epoch_50.png
Saved per-class AP plot to /content/ssd_model/evaluation/final/per_class_ap_epoch_50.png
Saved confusion matrix to /content/ssd_model/evaluation/final/confusion_matrix_epoch_50.png
Epoch 50: mAP@0.5 = 0.5146

Per-class Average Precision:
  Anthracnose: 0.6203
  Bacterial-Black-spot: 0.6026
  Damaged-mango: 0.0000
  Fruitly: 0.3149
  Mechanical-damage: 0.7571
  Others: 0.7928
Final validatio

In [None]:
import os
import shutil
import glob
import sys
from tqdm.auto import tqdm

# Check if running in Google Colab
try:
    from google.colab import drive
    IN_COLAB = True
    print("Running in Google Colab environment")
except ImportError:
    IN_COLAB = False
    print("ERROR: This script must be run in Google Colab")
    sys.exit(1)

def copy_with_progress(src_dir, dst_dir):
    """Recursively copy every file under ``src_dir`` into ``dst_dir``.

    The relative directory layout is preserved and missing destination
    directories are created. Files are copied with ``shutil.copy2``, which
    also copies metadata such as timestamps. A tqdm bar reports progress.
    """
    # Gather all (source, destination) pairs first so the bar knows its total.
    pairs = []
    for root, _, names in os.walk(src_dir):
        for name in names:
            source = os.path.join(root, name)
            destination = os.path.join(dst_dir, os.path.relpath(source, src_dir))
            pairs.append((source, destination))

    with tqdm(total=len(pairs), desc="Copying files") as bar:
        for source, destination in pairs:
            os.makedirs(os.path.dirname(destination), exist_ok=True)
            shutil.copy2(source, destination)
            bar.update(1)

def push_evaluation_to_drive(local_eval_dir, drive_base_dir, model_name="ssd_model"):
    """Copy ``local_eval_dir`` to ``<drive_base_dir>/<model_name>/evaluation``.

    Mounts Google Drive at ``/content/drive`` first if it does not appear to
    be mounted. After copying, every source file is checked at its
    destination.

    Args:
        local_eval_dir: Local evaluation directory to back up.
        drive_base_dir: Base directory on the mounted Drive.
        model_name: Subfolder name for this model (default ``"ssd_model"``).

    Returns:
        bool: True if every local file exists at its destination with the
        same size. False if the source directory is missing or any file is
        missing or mismatched after the copy.
    """
    if not os.path.exists(local_eval_dir):
        print(f"ERROR: Local evaluation directory not found: {local_eval_dir}")
        return False

    # /content/drive can exist without Drive being mounted; MyDrive only
    # appears once the mount is active.
    if not os.path.isdir(os.path.join("/content/drive", "MyDrive")):
        print("Mounting Google Drive...")
        drive.mount('/content/drive')
        print("Google Drive mounted successfully")

    drive_eval_dir = os.path.join(drive_base_dir, model_name, "evaluation")
    os.makedirs(drive_eval_dir, exist_ok=True)
    print(f"Target directory: {drive_eval_dir}")

    # Relative paths of all files to copy; reused for verification below.
    rel_paths = []
    dir_count = 0
    for root, dirs, files in os.walk(local_eval_dir):
        dir_count += len(dirs)
        rel_paths.extend(os.path.relpath(os.path.join(root, name), local_eval_dir) for name in files)
    file_count = len(rel_paths)
    print(f"Found {file_count} files in {dir_count} directories to copy")

    print("Copying evaluation results to Google Drive...")
    copy_with_progress(local_eval_dir, drive_eval_dir)

    # Verify only the files we copied. The destination may already hold
    # files from earlier runs, so comparing total file counts would misreport.
    copied = 0
    for rel in rel_paths:
        src_path = os.path.join(local_eval_dir, rel)
        dst_path = os.path.join(drive_eval_dir, rel)
        if os.path.isfile(dst_path) and os.path.getsize(dst_path) == os.path.getsize(src_path):
            copied += 1

    if copied == file_count:
        print(f"SUCCESS: {copied} files copied to Google Drive")
        return True

    print(f"WARNING: Only {copied} of {file_count} files were copied")
    return False

# Backup configuration: edit these to change what is copied and where.
LOCAL_EVAL_DIR = '/content/ssd_model/evaluation'  # Source directory
DRIVE_BASE_DIR = '/content/drive/MyDrive/MANGO/PROJECT'  # Destination base directory
MODEL_NAME = 'mango_ssd_enhanced'  # Model name subfolder

# Banner describing the backup about to run (one line per item).
print(
    "=" * 50,
    "PUSHING EVALUATION RESULTS TO GOOGLE DRIVE",
    "=" * 50,
    f"Source: {LOCAL_EVAL_DIR}",
    f"Destination: {DRIVE_BASE_DIR}/{MODEL_NAME}/evaluation",
    sep="\n",
)

success = push_evaluation_to_drive(LOCAL_EVAL_DIR, DRIVE_BASE_DIR, MODEL_NAME)

if not success:
    print("\nWARNING: There may have been issues copying the evaluation results.")
    print("Please check the logs above for more details.")
else:
    print("\nAll evaluation results successfully copied to Google Drive!")
    print("You can safely close this notebook or continue with other tasks.")

Running in Google Colab environment
PUSHING EVALUATION RESULTS TO GOOGLE DRIVE
Source: /content/ssd_model/evaluation
Destination: /content/drive/MyDrive/MANGO/PROJECT/mango_ssd_enhanced/evaluation
Target directory: /content/drive/MyDrive/MANGO/PROJECT/mango_ssd_enhanced/evaluation
Found 75 files in 14 directories to copy
Copying evaluation results to Google Drive...


Copying files:   0%|          | 0/75 [00:00<?, ?it/s]

SUCCESS: 75 files copied to Google Drive

All evaluation results successfully copied to Google Drive!
You can safely close this notebook or continue with other tasks.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models.detection.ssd import SSDHead
from torchvision.models.detection.anchor_utils import DefaultBoxGenerator
from torchvision.ops import box_iou, nms
import math
import numpy as np
from collections import OrderedDict
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
import cv2
import json
from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score

# ===========================
# INCEPTION V2 BACKBONE
# ===========================

class BasicConv2d(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d (eps=1e-3) -> in-place ReLU.

    Extra keyword arguments (``kernel_size``, ``stride``, ``padding``, ...)
    are passed straight to ``nn.Conv2d``. The convolution has no bias because
    the BatchNorm that follows applies its own learnable shift.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class InceptionModule(nn.Module):
    """
    Four-branch Inception block built from BatchNorm conv units (InceptionV2 style).

    Every branch preserves the spatial size (1x1 convs, padded 3x3 convs and a
    stride-1 padded max-pool). The outputs are concatenated along the channel
    dimension, giving ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels.
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super().__init__()

        # Branch 1: a single 1x1 projection.
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 channel reduction, then a 3x3 conv.
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
        )

        # Branch 3: 1x1 reduction, then two stacked 3x3 convs, which together
        # cover a 5x5 receptive field more cheaply than one 5x5 conv.
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1),
            BasicConv2d(ch5x5, ch5x5, kernel_size=3, padding=1),
        )

        # Branch 4: 3x3 max-pool (stride 1), then a 1x1 projection.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)

class InceptionV2Backbone(nn.Module):
    """
    InceptionV2-style backbone network for SSD.

    ``forward`` returns three feature maps, taken after the inception3c,
    inception4e and inception5b stages. They have 384, 832 and 1024 channels.
    For a 300x300 input their spatial sizes are 38x38, 19x19 and 10x10.

    Args:
        pretrained: Accepted for API compatibility. No pretrained weights
            exist for this custom layout, so when True a notice is printed and
            the network is still randomly initialized.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        # Stem: 7x7/2 conv, pool, 1x1 + 3x3 convs, pool
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Inception blocks (output channels = ch1x1 + ch3x3 + ch5x5 + pool_proj)
        self.inception3a = InceptionModule(192, 64, 64, 64, 64, 96, 32)       # -> 256
        self.inception3b = InceptionModule(256, 64, 64, 96, 64, 96, 64)       # -> 320
        self.inception3c = InceptionModule(320, 128, 64, 96, 64, 96, 64)      # -> 384

        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception4a = InceptionModule(384, 192, 96, 208, 16, 48, 64)     # -> 512
        self.inception4b = InceptionModule(512, 160, 112, 224, 24, 64, 64)    # -> 512
        self.inception4c = InceptionModule(512, 128, 128, 256, 24, 64, 64)    # -> 512
        self.inception4d = InceptionModule(512, 112, 144, 288, 32, 64, 64)    # -> 528
        self.inception4e = InceptionModule(528, 256, 160, 320, 32, 128, 128)  # -> 832

        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception5a = InceptionModule(832, 256, 160, 320, 32, 128, 128)  # -> 832
        self.inception5b = InceptionModule(832, 384, 192, 384, 48, 128, 128)  # -> 1024

        if pretrained:
            # torchvision offers no InceptionV2 checkpoint, and InceptionV3's
            # architecture differs from the modules above, so nothing can be
            # loaded. Say so rather than downloading an unused model.
            print("No pretrained InceptionV2 weights are available for this backbone; "
                  "using randomly initialized weights.")

        # Applied regardless of `pretrained`, because no weights are loaded above.
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out) convs with zero bias; BatchNorm weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``[stage3_features, stage4_features, stage5_features]`` for input ``x``."""
        features = []

        # Stem
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)

        # Stage 3
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.inception3c(x)
        features.append(x)  # feature 1 (role similar to conv4_3 in VGG16-SSD)

        x = self.maxpool3(x)

        # Stage 4
        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        features.append(x)  # feature 2 (role similar to conv7 in VGG16-SSD)

        x = self.maxpool4(x)

        # Stage 5
        x = self.inception5a(x)
        x = self.inception5b(x)
        features.append(x)  # feature 3

        return features

# ===========================
# ENHANCED ANCHOR BOX GENERATION
# ===========================

class CustomBoxGenerator(DefaultBoxGenerator):
    """
    Default-box generator with finer scales and 2:1 / 3:1 aspect ratios.

    Args:
        aspect_ratios: one list of ratios per feature map. DefaultBoxGenerator
            adds the reciprocal of every ratio itself, so ``[2, 3]`` already
            yields boxes with ratios 2, 1/2, 3 and 1/3 (plus two square boxes).
            Listing 1/2 or 1/3 explicitly would only create duplicate anchors.
        min_ratio, max_ratio: forwarded to DefaultBoxGenerator. They only
            take effect when ``scales`` is None there; since this class
            always supplies scales, they are kept for interface compatibility.
        scales: len(aspect_ratios) + 1 box scales relative to the image size.
    """

    def __init__(self, aspect_ratios=None, min_ratio=0.05, max_ratio=0.95, scales=None):
        if aspect_ratios is None:
            # One independent list per feature map (avoid [[...]] * 6 aliasing).
            aspect_ratios = [[2, 3] for _ in range(6)]

        # Use finer scale steps for better coverage
        if scales is None:
            scales = [0.04, 0.1, 0.26, 0.42, 0.58, 0.74, 0.9]

        super().__init__(
            aspect_ratios=aspect_ratios,
            min_ratio=min_ratio,
            max_ratio=max_ratio,
            scales=scales,
        )

# ===========================
# FOCAL LOSS FOR SSD
# ===========================

class FocalLossForSSD(nn.Module):
    """
    Sigmoid focal loss for SSD object detection.

    Down-weights easy examples to address the foreground/background imbalance.
    Labels follow the SSD matcher's convention: 0 = background, > 0 = object
    class, -1 = ignored anchor (contributes nothing to the loss).

    Args:
        alpha: weight for positive one-hot entries (negatives get 1 - alpha).
        gamma: focusing exponent applied to the per-entry error term.
        reduction: 'sum' or 'mean' per image; any other value keeps the
            per-entry losses, which are then averaged along with the batch.
    """
    def __init__(self, alpha=0.25, gamma=2.0, reduction='sum'):
        super(FocalLossForSSD, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, classifications, targets):
        """
        Arguments:
            classifications: logits, shape (batch_size, num_anchors, num_classes)
            targets: integer labels, shape (batch_size, num_anchors); -1 = ignore

        Returns:
            Scalar tensor: per-image loss (reduced per ``self.reduction``)
            averaged over the batch.
        """
        num_classes = classifications.size(2)
        classification_losses = []

        for classification, target in zip(classifications, targets):
            target = target.long()

            # F.one_hot rejects negative indices: encode ignored anchors as
            # class 0 and zero out their contribution with `valid` below.
            valid = (target >= 0).to(classification.dtype)
            target_onehot = F.one_hot(target.clamp(min=0), num_classes=num_classes)
            target_onehot = target_onehot.to(classification.dtype)

            prob = torch.sigmoid(classification)

            # Per-element weighting; every factor has shape (num_anchors, num_classes).
            alpha_factor = target_onehot * self.alpha + (1 - target_onehot) * (1 - self.alpha)
            focal_weight = torch.where(target_onehot > 0, 1 - prob, prob)
            focal_weight = alpha_factor * focal_weight.pow(self.gamma)

            bce = F.binary_cross_entropy_with_logits(
                classification,
                target_onehot,
                reduction='none'
            )

            cls_loss = focal_weight * bce * valid.unsqueeze(1)

            if self.reduction == 'sum':
                cls_loss = cls_loss.sum()
            elif self.reduction == 'mean':
                # Average over non-ignored entries only; clamp avoids 0/0.
                denom = (valid.sum() * num_classes).clamp(min=1)
                cls_loss = cls_loss.sum() / denom

            classification_losses.append(cls_loss)

        return torch.stack(classification_losses).mean()

# ===========================
# L2Norm Module
# ===========================

class L2Norm(nn.Module):
    """
    L2 normalization with a learnable per-channel scale, as used in SSD.

    The channel vector at every spatial position is normalised to unit L2
    norm, then multiplied by a learnable per-channel weight that is
    initialised to ``scale``.

    Args:
        n_channels: number of input channels (length of the weight vector).
        scale: initial value of every per-channel weight (20 in SSD).
    """
    def __init__(self, n_channels, scale=20):
        super(L2Norm, self).__init__()
        self.n_channels = n_channels
        self.scale = scale
        self.eps = 1e-10
        # Initialise to `scale`: with a weight of 1 the normalised map would
        # have unit norm, far smaller than the other detection feature maps.
        self.weight = nn.Parameter(torch.full((n_channels,), float(scale)))

    def forward(self, x):
        """Normalise ``x`` over dim 1 and rescale; the weight is broadcast as (1, C, 1, 1)."""
        norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
        x = x / norm
        return x * self.weight.view(1, -1, 1, 1)

# ===========================
# SSD Classification and Regression Head
# ===========================

class SSDClassificationHead(nn.Module):
    """
    One 3x3 conv per feature level predicting ``num_classes`` logits per anchor.

    ``forward`` returns a (batch, total_anchors, num_classes) tensor. Levels
    are concatenated in input order; within a level the order is
    (row, col, anchor).
    """
    def __init__(self, in_channels, num_anchors, num_classes):
        super(SSDClassificationHead, self).__init__()
        self.num_classes = num_classes
        self.cls_heads = nn.ModuleList(
            nn.Conv2d(ch, n_anchor * num_classes, kernel_size=3, padding=1)
            for ch, n_anchor in zip(in_channels, num_anchors)
        )

    def forward(self, features):
        """Apply each level's conv and flatten all predictions into one anchor axis."""
        n = features[0].shape[0]
        flat = [
            conv(fmap).permute(0, 2, 3, 1).reshape(n, -1, self.num_classes)
            for fmap, conv in zip(features, self.cls_heads)
        ]
        return torch.cat(flat, dim=1)

class SSDRegressionHead(nn.Module):
    """
    One 3x3 conv per feature level predicting 4 box offsets per anchor.

    ``forward`` returns a (batch, total_anchors, 4) tensor. Levels are
    concatenated in input order; within a level the order is
    (row, col, anchor).
    """
    def __init__(self, in_channels, num_anchors):
        super(SSDRegressionHead, self).__init__()
        self.reg_heads = nn.ModuleList(
            nn.Conv2d(ch, n_anchor * 4, kernel_size=3, padding=1)
            for ch, n_anchor in zip(in_channels, num_anchors)
        )

    def forward(self, features):
        """Apply each level's conv and flatten all offsets into one anchor axis."""
        n = features[0].shape[0]
        flat = [
            conv(fmap).permute(0, 2, 3, 1).reshape(n, -1, 4)
            for fmap, conv in zip(features, self.reg_heads)
        ]
        return torch.cat(flat, dim=1)

# ===========================
# SSD WITH INCEPTION V2 BACKBONE
# ===========================

class SSD300_InceptionV2(nn.Module):
    """
    SSD300 model with InceptionV2 backbone

    Key components:
    1. InceptionV2 backbone instead of VGG16
    2. 1x1 lateral projections bringing every detection map to 256 channels
       (FPN-style lateral convs; there is no top-down pathway)
    3. Sigmoid focal loss for class imbalance (inference also uses sigmoid)
    4. Custom anchor boxes (CustomBoxGenerator)

    ``forward`` returns a loss dict in training mode when targets are given,
    otherwise a list of per-image detection dicts (boxes, scores, labels).
    """

    def __init__(self, num_classes, pretrained=True):
        super(SSD300_InceptionV2, self).__init__()

        # Create InceptionV2 backbone
        self.backbone = InceptionV2Backbone(pretrained=pretrained)

        # Channel counts of the three backbone feature maps
        self.backbone_out_channels = [384, 832, 1024]  # [inception3c, inception4e, inception5b]

        # Additional SSD layers extending the feature hierarchy: each pair is a
        # 1x1 channel reduction followed by a stride-2 3x3 conv.
        self.additional_blocks = nn.ModuleList([
            BasicConv2d(1024, 512, kernel_size=1),  # Reduction layer
            BasicConv2d(512, 512, kernel_size=3, stride=2, padding=1),  # conv8_2

            BasicConv2d(512, 256, kernel_size=1),   # Reduction layer
            BasicConv2d(256, 512, kernel_size=3, stride=2, padding=1),  # conv9_2

            BasicConv2d(512, 128, kernel_size=1),   # Reduction layer
            BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1),  # conv10_2

            BasicConv2d(256, 128, kernel_size=1),   # Reduction layer
            BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1),  # conv11_2
        ])

        # 1x1 lateral projections, one per detection map. NOTE: 3 backbone maps
        # + 4 extra maps = 7, but only 6 projections exist, so conv11_2 is
        # computed in extract_features and then dropped.
        self.lateral_connections = nn.ModuleList([
            nn.Conv2d(384, 256, kernel_size=1),    # For inception3c
            nn.Conv2d(832, 256, kernel_size=1),    # For inception4e
            nn.Conv2d(1024, 256, kernel_size=1),   # For inception5b
            nn.Conv2d(512, 256, kernel_size=1),    # For conv8_2
            nn.Conv2d(512, 256, kernel_size=1),    # For conv9_2
            nn.Conv2d(256, 256, kernel_size=1),    # For conv10_2
        ])

        # L2 Normalization for the first feature map (inception3c)
        self.l2_norm = L2Norm(384, scale=20)

        # Use custom anchor box generator
        self.anchor_generator = CustomBoxGenerator()

        # Take the anchors-per-location count from the generator itself so each
        # head emits exactly one prediction per generated anchor; a hard-coded
        # list silently drifts out of sync with the generator's aspect ratios.
        self.num_anchors = self.anchor_generator.num_anchors_per_location()

        # Classification head (predicts class logits)
        self.classification_head = SSDClassificationHead(
            in_channels=[256] * 6,  # All feature maps have 256 channels after lateral connections
            num_anchors=self.num_anchors,
            num_classes=num_classes
        )

        # Regression head (predicts bounding box offsets)
        self.regression_head = SSDRegressionHead(
            in_channels=[256] * 6,
            num_anchors=self.num_anchors
        )

        # Initialize the weights
        self._initialize_weights()

        # Use focal loss
        self.focal_loss = FocalLossForSSD(alpha=0.25, gamma=2.0)

    def _initialize_weights(self):
        """Xavier-init convs and reset BatchNorm in the non-backbone modules."""
        for module in [self.additional_blocks, self.lateral_connections,
                       self.classification_head, self.regression_head]:
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)

    def extract_features(self, x):
        """Return the 256-channel detection feature maps, largest first."""
        features = self.backbone(x)

        # Apply L2 norm to the first feature map
        features[0] = self.l2_norm(features[0])

        # Extra SSD maps, starting from the last backbone feature
        x = features[-1]
        for i in range(0, len(self.additional_blocks), 2):
            x = self.additional_blocks[i](x)
            x = self.additional_blocks[i + 1](x)
            features.append(x)

        # zip truncates to the number of lateral projections (see __init__)
        return [lateral(feature) for lateral, feature in zip(self.lateral_connections, features)]

    def forward(self, images, targets=None):
        """
        Run the detector.

        Args:
            images: list of equally sized (C, H, W) tensors, or an (N, C, H, W) tensor.
            targets: optional list (one per image) of dicts with 'boxes' (xyxy,
                pixel coordinates) and 'labels'; used only in training mode.

        Returns:
            Training mode with targets: dict with 'classification',
            'bbox_regression' and 'total_loss'. Otherwise: a list of dicts
            with 'boxes', 'scores' and 'labels', one per image.
        """
        # torchvision's DefaultBoxGenerator.forward takes (ImageList, feature_maps)
        from torchvision.models.detection.image_list import ImageList

        if isinstance(images, list):
            images = torch.stack(images)

        # Get image sizes for anchor generation
        image_sizes = [(img.shape[-2], img.shape[-1]) for img in images]

        features = self.extract_features(images)

        cls_logits = self.classification_head(features)
        bbox_regression = self.regression_head(features)

        # Per torchvision, one xyxy anchor tensor per image, in the same
        # (level, row, col, anchor) order as the head outputs.
        anchors = self.anchor_generator(ImageList(images, image_sizes), features)

        if self.training and targets is not None:
            return self._compute_losses(cls_logits, bbox_regression, anchors, targets)

        detections = []
        for img_idx in range(len(images)):
            boxes, scores, labels = self.postprocess_detections(
                cls_logits[img_idx],
                bbox_regression[img_idx],
                anchors[img_idx],
                image_sizes[img_idx]
            )
            detections.append({
                'boxes': boxes,
                'scores': scores,
                'labels': labels
            })
        return detections

    def _match_anchors(self, anchors_per_image, targets_per_image):
        """
        Assign a GT index and a label to every anchor of one image.

        Returns:
            (matches, labels): ``matches`` indexes the image's GT boxes;
            ``labels`` is 0 for background (IoU < 0.4), -1 for ignored anchors
            (0.4 <= IoU < 0.5) and the GT class otherwise.
        """
        gt_boxes = targets_per_image["boxes"]
        if gt_boxes.numel() == 0:
            zeros = torch.zeros_like(anchors_per_image[:, 0], dtype=torch.int64)
            return zeros, zeros.clone()

        iou = box_iou(gt_boxes, anchors_per_image)  # (num_gt, num_anchors)
        matched_vals, matches = iou.max(dim=0)

        # SSD matching strategy: every GT box is also matched to its single
        # best anchor, so small/odd-shaped objects still get a positive.
        best_anchor_per_gt = iou.argmax(dim=1)
        matches[best_anchor_per_gt] = torch.arange(
            best_anchor_per_gt.numel(), device=matches.device)
        matched_vals[best_anchor_per_gt] = 1.0

        labels = targets_per_image["labels"][matches]
        labels[matched_vals < 0.4] = 0
        labels[(matched_vals >= 0.4) & (matched_vals < 0.5)] = -1
        return matches, labels

    def _compute_losses(self, cls_logits, bbox_regression, anchors, targets):
        """Match anchors to targets and return the focal + smooth-L1 loss dict."""
        all_labels = []
        all_reg_targets = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            matches, labels = self._match_anchors(anchors_per_image, targets_per_image)
            if targets_per_image["boxes"].numel() == 0:
                reg_targets = torch.zeros_like(anchors_per_image)
            else:
                matched_gt_boxes = targets_per_image["boxes"][matches]
                reg_targets = self.encode_boxes(matched_gt_boxes, anchors_per_image)
            all_labels.append(labels)
            all_reg_targets.append(reg_targets)

        matched_labels = torch.stack(all_labels)
        regression_targets = torch.stack(all_reg_targets)

        # Focal loss handles the -1 (ignore) labels itself
        cls_loss = self.focal_loss(cls_logits, matched_labels)

        # Regression loss - only for positive anchors
        positive_mask = matched_labels > 0
        num_pos = positive_mask.sum()
        if num_pos > 0:
            reg_loss = F.smooth_l1_loss(
                bbox_regression[positive_mask],
                regression_targets[positive_mask],
                reduction='sum'
            ) / num_pos.float()
        else:
            reg_loss = torch.tensor(0.0, device=cls_loss.device)

        return {
            'classification': cls_loss,
            'bbox_regression': reg_loss,
            'total_loss': cls_loss + reg_loss
        }

    def encode_boxes(self, gt_boxes, anchors):
        """
        Encode ground-truth boxes to regression targets

        Args:
            gt_boxes: Tensor of shape (N, 4) with ground-truth boxes (xyxy)
            anchors: Tensor of shape (N, 4) with anchor boxes (xyxy)

        Returns:
            Tensor (N, 4) of (dx, dy, dw, dh) offsets; inverse of decode_boxes.
        """
        anchor_widths = anchors[:, 2] - anchors[:, 0]
        anchor_heights = anchors[:, 3] - anchors[:, 1]
        anchor_ctr_x = anchors[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchors[:, 1] + 0.5 * anchor_heights

        gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0]
        gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1]
        gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_widths
        gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_heights

        # Prevent division by zero and log(0) for degenerate boxes
        eps = 1e-7
        anchor_widths = torch.clamp(anchor_widths, min=eps)
        anchor_heights = torch.clamp(anchor_heights, min=eps)
        gt_widths = torch.clamp(gt_widths, min=eps)
        gt_heights = torch.clamp(gt_heights, min=eps)

        targets_dx = (gt_ctr_x - anchor_ctr_x) / anchor_widths
        targets_dy = (gt_ctr_y - anchor_ctr_y) / anchor_heights
        targets_dw = torch.log(gt_widths / anchor_widths)
        targets_dh = torch.log(gt_heights / anchor_heights)

        return torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)

    def postprocess_detections(self, cls_logits, bbox_regression, anchors, image_size):
        """
        Turn one image's raw head outputs into final detections.

        Scores use a sigmoid, matching the sigmoid focal loss used in training.
        Boxes scoring <= 0.05 are dropped, the rest are clipped to
        ``image_size`` (height, width) and reduced with per-class NMS (IoU 0.5).
        """
        from torchvision.ops import batched_nms

        scores = torch.sigmoid(cls_logits)

        # Highest scoring class, excluding background class 0
        scores, labels = scores[:, 1:].max(dim=1)
        labels = labels + 1  # shift back because column 0 was excluded

        boxes = self.decode_boxes(bbox_regression, anchors)

        keep_idxs = scores > 0.05  # Confidence threshold
        boxes = boxes[keep_idxs]
        scores = scores[keep_idxs]
        labels = labels[keep_idxs]

        # Clip boxes to image boundaries
        boxes[:, 0].clamp_(min=0, max=image_size[1])
        boxes[:, 1].clamp_(min=0, max=image_size[0])
        boxes[:, 2].clamp_(min=0, max=image_size[1])
        boxes[:, 3].clamp_(min=0, max=image_size[0])

        # Per-class NMS: overlapping boxes of different classes are kept
        keep = batched_nms(boxes, scores, labels, iou_threshold=0.5)

        return boxes[keep], scores[keep], labels[keep]

    def decode_boxes(self, box_regression, anchors):
        """
        Decode regression values to box coordinates

        Args:
            box_regression: Tensor of shape (N, 4) with encoded offsets
            anchors: Tensor of shape (N, 4) with anchor boxes

        Returns:
            Decoded boxes in (x1, y1, x2, y2) format
        """
        import math

        # Cap the log-scale offsets so exp() cannot overflow to inf
        max_log_scale = math.log(1000.0 / 16)

        anchor_widths = anchors[:, 2] - anchors[:, 0]
        anchor_heights = anchors[:, 3] - anchors[:, 1]
        anchor_ctr_x = anchors[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchors[:, 1] + 0.5 * anchor_heights

        dx = box_regression[:, 0]
        dy = box_regression[:, 1]
        dw = box_regression[:, 2].clamp(max=max_log_scale)
        dh = box_regression[:, 3].clamp(max=max_log_scale)

        pred_ctr_x = dx * anchor_widths + anchor_ctr_x
        pred_ctr_y = dy * anchor_heights + anchor_ctr_y
        pred_w = torch.exp(dw) * anchor_widths
        pred_h = torch.exp(dh) * anchor_heights

        pred_boxes = torch.zeros_like(box_regression)
        pred_boxes[:, 0] = pred_ctr_x - 0.5 * pred_w
        pred_boxes[:, 1] = pred_ctr_y - 0.5 * pred_h
        pred_boxes[:, 2] = pred_ctr_x + 0.5 * pred_w
        pred_boxes[:, 3] = pred_ctr_y + 0.5 * pred_h

        return pred_boxes

# ===========================
# DATA LOADING AND PREPROCESSING
# ===========================

class VOCDataset(Dataset):
    """
    Pascal VOC dataset for SSD training.

    Reads image ids from ``<root>/VOC<year>/ImageSets/Main/<image_set>.txt``
    and pairs ``JPEGImages/<id>.jpg`` with ``Annotations/<id>.xml``.

    Each item is ``(image, target)`` where target holds:
        'boxes'     float32, shape (K, 4): xmin, ymin, xmax, ymax as in the XML
                    (shape (0, 4) when the image has no usable objects)
        'labels'    int64, shape (K,); index 0 is reserved for background
        'image_id'  tensor([index])
        'orig_size' tensor([height, width])
    """
    def __init__(self, root, year='2012', image_set='train', transform=None):
        self.root = os.path.join(root, f'VOC{year}')
        self.image_set = image_set
        self.transform = transform

        # Class names; background is explicitly index 0
        self.classes = [
            'background',
            'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
            'bus', 'car', 'cat', 'chair', 'cow',
            'diningtable', 'dog', 'horse', 'motorbike', 'person',
            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
        ]
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        self.images = []
        self.annotations = []

        image_set_file = os.path.join(self.root, 'ImageSets', 'Main', f'{image_set}.txt')

        # open() raises FileNotFoundError for a missing split file
        with open(image_set_file, 'r') as f:
            # Skip blank lines (e.g. a trailing empty line) so they never
            # turn into bogus '.jpg' / '.xml' paths.
            file_ids = [line.strip() for line in f if line.strip()]

        for file_id in file_ids:
            self.images.append(os.path.join(self.root, 'JPEGImages', f'{file_id}.jpg'))
            self.annotations.append(os.path.join(self.root, 'Annotations', f'{file_id}.xml'))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = self.images[index]
        # The context manager closes the file; convert() returns a new Image.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        boxes, labels = self._parse_voc_xml(self.annotations[index])
        # reshape keeps boxes 2-D ((0, 4)) for images without usable objects
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([index]),
            'orig_size': torch.as_tensor([img.height, img.width])
        }
        if self.transform:
            img, target = self.transform(img, target)
        return img, target

    def _parse_voc_xml(self, xml_path):
        """
        Parse a Pascal VOC XML annotation file.

        Returns:
            (boxes, labels): lists of [xmin, ymin, xmax, ymax] floats and class
            indices. Objects with unknown class names or with
            xmax <= xmin / ymax <= ymin are skipped.
        """
        root = ET.parse(xml_path).getroot()

        boxes = []
        labels = []

        for obj in root.findall('object'):
            # Strip surrounding whitespace so ' dog ' still maps to a class
            class_name = (obj.findtext('name') or '').strip()
            label = self.class_to_idx.get(class_name)
            if label is None:
                continue  # Skip objects with unknown class

            bbox = obj.find('bndbox')
            xmin = float(bbox.find('xmin').text)
            ymin = float(bbox.find('ymin').text)
            xmax = float(bbox.find('xmax').text)
            ymax = float(bbox.find('ymax').text)

            # Skip degenerate boxes
            if xmax <= xmin or ymax <= ymin:
                continue

            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(label)

        return boxes, labels

# ===========================
# TRANSFORMS FOR SSD TRAINING
# ===========================


class Compose:
    """Chain (image, target) transforms, feeding each one's output to the next."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        current_image, current_target = image, target
        for transform_fn in self.transforms:
            current_image, current_target = transform_fn(current_image, current_target)
        return current_image, current_target

class ToTensor:
    """Convert the PIL image with torchvision's ToTensor; the target passes through unchanged."""
    def __call__(self, image, target):
        to_tensor = transforms.ToTensor()
        return to_tensor(image), target

class Resize:
    """
    Resize the image to a fixed size and rescale its boxes to match.

    Args:
        size: (height, width) of the output image.
    """
    def __init__(self, size):
        self.size = size  # (height, width)

    def __call__(self, image, target):
        src_w, src_h = image.size
        image = transforms.Resize(self.size)(image)

        if target['boxes'].shape[0] == 0:
            return image, target

        # x coordinates scale with width, y coordinates with height
        sx = self.size[1] / src_w
        sy = self.size[0] / src_h
        scaled = target['boxes'].clone()
        scaled[:, 0::2] *= sx
        scaled[:, 1::2] *= sy
        target['boxes'] = scaled
        return image, target

class RandomHorizontalFlip:
    """
    With probability ``prob``, mirror the image left-right and mirror the boxes.

    A box [x1, y1, x2, y2] becomes [W - x2, y1, W - x1, y2], where W is the
    image width.
    """
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if not random.random() < self.prob:
            return image, target

        image = transforms.functional.hflip(image)

        boxes = target['boxes']
        if boxes.shape[0] > 0:
            img_w = image.width
            flipped = boxes.clone()
            flipped[:, 0] = img_w - boxes[:, 2]
            flipped[:, 2] = img_w - boxes[:, 0]
            target['boxes'] = flipped

        return image, target

class ColorJitter:
    """Randomly jitter brightness/contrast/saturation/hue; boxes are untouched."""
    def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1):
        jitter_kwargs = dict(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )
        self.transform = transforms.ColorJitter(**jitter_kwargs)

    def __call__(self, image, target):
        return self.transform(image), target

class Normalize:
    """Per-channel (x - mean) / std normalisation of an image tensor; the target is untouched."""
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.mean, self.std = mean, std

    def __call__(self, image, target):
        normalized = transforms.functional.normalize(image, mean=self.mean, std=self.std)
        return normalized, target

# ===========================
# DATA LOADERS
# ===========================

def create_data_loaders(args):
    """
    Create data loaders for training and validation.

    Args:
        args: namespace providing ``image_size``, ``voc_path`` and
            ``batch_size``. Optionally ``num_workers`` (default 4) and
            ``pin_memory`` (default True), e.g. to lower the worker count on
            machines with few CPUs.

    Returns:
        (train_loader, val_loader, num_classes), where num_classes counts the
        background class.

    If the VOC split files are missing (FileNotFoundError), small random dummy
    datasets are used instead so the pipeline can still be smoke-tested.
    """
    # Configurable loader settings; defaults keep the previous behaviour.
    num_workers = getattr(args, 'num_workers', 4)
    pin_memory = getattr(args, 'pin_memory', True)

    train_transform = Compose([
        Resize((args.image_size, args.image_size)),
        RandomHorizontalFlip(),
        ColorJitter(),
        ToTensor(),
        Normalize()
    ])

    val_transform = Compose([
        Resize((args.image_size, args.image_size)),
        ToTensor(),
        Normalize()
    ])

    try:
        train_dataset = VOCDataset(
            root=args.voc_path,
            image_set='train',
            transform=train_transform
        )
        val_dataset = VOCDataset(
            root=args.voc_path,
            image_set='val',
            transform=val_transform
        )
    except FileNotFoundError:
        print(f"Could not find VOC dataset at {args.voc_path}")
        print("Using dummy dataset for testing purposes...")

        train_dataset = create_dummy_dataset(num_samples=100, image_size=args.image_size, transform=train_transform)
        val_dataset = create_dummy_dataset(num_samples=20, image_size=args.image_size, transform=val_transform)

    loader_kwargs = dict(
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, len(train_dataset.classes)

def collate_fn(batch):
    """
    Split a detection batch into parallel lists instead of stacking.

    Images and targets stay as Python lists because targets hold a variable
    number of boxes per image.
    """
    images = [sample[0] for sample in batch for _ in (None,)] if False else [img for img, _ in batch]
    targets = [tgt for _, tgt in batch]
    return images, targets

def create_dummy_dataset(num_samples=100, image_size=300, transform=None):
    """
    Build a synthetic VOC-like dataset of random-noise images for smoke tests.

    Each item is ``(image, target)``: a random RGB PIL image of
    ``image_size`` x ``image_size`` with 1-3 random boxes and random
    non-background labels, passed through ``transform`` when one is given.
    """
    voc_classes = [
        'background',
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]

    class DummyDataset(Dataset):
        def __init__(self, num_samples, image_size, transform):
            self.num_samples = num_samples
            self.image_size = image_size
            self.transform = transform
            self.classes = list(voc_classes)

        def __len__(self):
            return self.num_samples

        def _random_box(self):
            # Draw order (left, top, right, bottom) matters for reproducibility.
            size = self.image_size
            left = random.uniform(0, size - 100)
            top = random.uniform(0, size - 100)
            right = random.uniform(left + 50, min(left + 150, size))
            bottom = random.uniform(top + 50, min(top + 150, size))
            return [left, top, right, bottom]

        def __getitem__(self, index):
            noise = torch.rand(3, self.image_size, self.image_size)
            img = transforms.ToPILImage()(noise)

            box_count = random.randint(1, 3)
            boxes = [self._random_box() for _ in range(box_count)]
            # Labels start at 1: class 0 is background
            labels = [random.randint(1, len(self.classes) - 1) for _ in range(box_count)]

            target = {
                'boxes': torch.as_tensor(boxes, dtype=torch.float32),
                'labels': torch.as_tensor(labels, dtype=torch.int64),
                'image_id': torch.tensor([index]),
                'orig_size': torch.as_tensor([self.image_size, self.image_size])
            }

            if self.transform:
                img, target = self.transform(img, target)
            return img, target

    return DummyDataset(num_samples, image_size, transform)

# ===========================
# EVALUATION METRICS
# ===========================

def calculate_mAP(predictions, targets, iou_threshold=0.5, confidence_threshold=0.05, return_per_class=False):
    """
    Calculate mean Average Precision (mAP) for object detection.

    Within each image and class, predictions are visited from highest to lowest
    score and greedily matched to the same-class ground-truth box with the
    largest IoU. A match is a true positive when that IoU is >= ``iou_threshold``
    and the box has not been claimed yet; otherwise it is a false positive.
    Per-class AP uses PASCAL VOC 2007 11-point interpolation on the raw
    precision/recall curve.

    Args:
        predictions: List of prediction dictionaries with 'boxes', 'scores', 'labels'
        targets: List of target dictionaries with 'boxes', 'labels'
        iou_threshold: IoU threshold for considering a detection as correct
        confidence_threshold: Detections with score <= this value are discarded
        return_per_class: Whether to also return per-class AP

    Returns:
        mAP: Mean AP over the classes that have at least one ground-truth box
            (0.0 when the targets contain no foreground boxes at all).
        class_aps: Only if return_per_class=True. Entry ``k`` is the AP of class
            ``k + 1`` for classes 1..max_label; classes without ground truth hold
            a 0.0 placeholder and are NOT included in the mean.
    """
    # Foreground classes present in the ground truth (label 0 is background)
    all_labels = set()
    for target in targets:
        if len(target['labels']) > 0:
            all_labels.update(target['labels'].tolist())
    all_labels = [label for label in all_labels if label > 0]

    if not all_labels:
        # Without ground truth AP is undefined; avoid crashing on max([])
        return (0.0, []) if return_per_class else 0.0

    num_classes = max(all_labels) + 1

    # Per-class accumulators: one tp/fp/score entry per kept prediction
    class_metrics = {cls: {'tp': [], 'fp': [], 'scores': [], 'num_gt': 0} for cls in range(1, num_classes)}

    for pred, target in zip(predictions, targets):
        target_boxes = target['boxes']
        target_labels = target['labels']

        # Count ground truth instances per class
        for cls in range(1, num_classes):
            class_metrics[cls]['num_gt'] += (target_labels == cls).sum().item()

        # Filter by confidence threshold
        keep = pred['scores'] > confidence_threshold
        pred_boxes = pred['boxes'][keep]
        pred_scores = pred['scores'][keep]
        pred_labels = pred['labels'][keep]

        for cls in range(1, num_classes):
            cls_pred_mask = pred_labels == cls
            cls_pred_boxes = pred_boxes[cls_pred_mask]
            cls_pred_scores = pred_scores[cls_pred_mask]
            cls_target_boxes = target_boxes[target_labels == cls]

            # Greedy matching must visit the most confident predictions first
            if len(cls_pred_scores) > 0:
                sort_idx = torch.argsort(cls_pred_scores, descending=True)
                cls_pred_boxes = cls_pred_boxes[sort_idx]
                cls_pred_scores = cls_pred_scores[sort_idx]

            # Ground-truth boxes already claimed by a higher-scoring prediction
            target_flags = [False] * len(cls_target_boxes)

            for pred_box, pred_score in zip(cls_pred_boxes, cls_pred_scores):
                class_metrics[cls]['scores'].append(pred_score.item())
                is_tp = False
                if len(cls_target_boxes) > 0:
                    ious = box_iou(pred_box.unsqueeze(0), cls_target_boxes)[0]
                    max_iou, max_idx = torch.max(ious, dim=0)
                    gt_idx = int(max_idx)
                    # A GT box can be claimed once; duplicates become false positives
                    if max_iou >= iou_threshold and not target_flags[gt_idx]:
                        target_flags[gt_idx] = True
                        is_tp = True
                class_metrics[cls]['tp'].append(int(is_tp))
                class_metrics[cls]['fp'].append(int(not is_tp))

    class_aps = []
    evaluated_aps = []  # APs of classes that actually have ground truth
    for cls in range(1, num_classes):
        metrics = class_metrics[cls]
        if metrics['num_gt'] == 0:
            # Placeholder keeps class_aps indexable by (class - 1); excluded from the mean
            class_aps.append(0.0)
            continue

        if len(metrics['scores']) == 0:
            ap = 0.0
        else:
            scores = np.array(metrics['scores'])
            order = np.argsort(-scores)
            tp_cumsum = np.cumsum(np.array(metrics['tp'])[order])
            fp_cumsum = np.cumsum(np.array(metrics['fp'])[order])

            precision = tp_cumsum / (tp_cumsum + fp_cumsum)
            recall = tp_cumsum / metrics['num_gt']

            # VOC07 11-point interpolation. No (recall=0, precision=1) sentinel:
            # it would force the t=0 point to 1.0 even when every detection is wrong.
            ap = 0.0
            for t in np.linspace(0, 1, 11):
                mask = recall >= t
                p = precision[mask].max() if mask.any() else 0.0
                ap += p / 11.0

        class_aps.append(ap)
        evaluated_aps.append(ap)

    mAP = sum(evaluated_aps) / len(evaluated_aps) if evaluated_aps else 0.0

    if return_per_class:
        return mAP, class_aps
    return mAP

def calculate_confusion_matrix(predictions, targets, num_classes, iou_threshold=0.5, confidence_threshold=0.5):
    """
    Calculate a detection confusion matrix.

    Rows index the ground-truth class and columns the predicted class; index 0
    stands for background. Predictions with score > ``confidence_threshold``
    are visited from highest to lowest score and greedily matched, regardless
    of class, to the ground-truth box with the largest IoU:

    - match (IoU >= ``iou_threshold``, box not yet claimed): ``[gt_label, pred_label] += 1``
      (so a well-localised box with the wrong label lands off-diagonal);
    - no match: false positive, ``[0, pred_label] += 1``;
    - ground-truth box left unmatched: miss, ``[gt_label, 0] += 1``.

    Args:
        predictions: List of prediction dictionaries with 'boxes', 'scores', 'labels'
        targets: List of target dictionaries with 'boxes', 'labels'
        num_classes: Number of classes (including background)
        iou_threshold: IoU threshold for considering a detection as correct
        confidence_threshold: Confidence threshold for filtering detections

    Returns:
        (num_classes, num_classes) int64 CPU tensor.
    """
    conf_matrix = torch.zeros((num_classes, num_classes), dtype=torch.int64)

    for pred, target in zip(predictions, targets):
        target_boxes = target['boxes']
        target_labels = target['labels']

        # Filter by confidence threshold
        keep = pred['scores'] > confidence_threshold
        pred_boxes = pred['boxes'][keep]
        pred_scores = pred['scores'][keep]
        pred_labels = pred['labels'][keep]

        # Greedy matching must let the most confident predictions claim boxes first
        order = torch.argsort(pred_scores, descending=True)
        pred_boxes = pred_boxes[order]
        pred_labels = pred_labels[order]

        # Ground-truth boxes already claimed by a prediction
        target_flags = [False] * len(target_boxes)

        for pred_box, pred_label in zip(pred_boxes, pred_labels):
            # Plain ints keep indexing independent of the tensors' device
            pred_cls = int(pred_label)

            if len(target_boxes) == 0:
                # No ground truth in this image: false positive
                conf_matrix[0, pred_cls] += 1
                continue

            ious = box_iou(pred_box.unsqueeze(0), target_boxes)[0]
            max_iou, max_idx = torch.max(ious, dim=0)
            gt_idx = int(max_idx)

            if max_iou >= iou_threshold and not target_flags[gt_idx]:
                conf_matrix[int(target_labels[gt_idx]), pred_cls] += 1
                target_flags[gt_idx] = True
            else:
                # No (unclaimed) match: false positive
                conf_matrix[0, pred_cls] += 1

        # Unmatched ground truth counts as "predicted background"
        for matched, label in zip(target_flags, target_labels):
            if not matched:
                conf_matrix[int(label), 0] += 1

    return conf_matrix

def calculate_precision_recall_curve(predictions, targets, class_idx, iou_threshold=0.5):
    """
    Build the precision-recall curve of a single class over a set of images.

    Every prediction of ``class_idx`` is used (no score filtering). Inside each
    image, predictions are visited from highest to lowest score and greedily
    matched to the same-class ground-truth box with the largest IoU; the match
    counts as a true positive only if that IoU reaches ``iou_threshold`` and
    the box has not been claimed yet.

    Args:
        predictions: List of prediction dictionaries with 'boxes', 'scores', 'labels'
        targets: List of target dictionaries with 'boxes', 'labels'
        class_idx: Class index to calculate PR curve for
        iou_threshold: IoU threshold for considering a detection as correct

    Returns:
        (precision, recall, thresholds): NumPy arrays ordered by decreasing
        score, entry ``k`` describing the top ``k + 1`` detections. Three empty
        lists are returned when the class has no predictions or no ground truth.
    """
    gathered_scores = []
    hit_flags = []  # 1 = true positive, 0 = false positive, parallel to gathered_scores
    total_gt = 0

    for pred, target in zip(predictions, targets):
        gt_for_class = target['labels'] == class_idx
        total_gt += gt_for_class.sum().item()
        gt_boxes = target['boxes'][gt_for_class]

        det_for_class = pred['labels'] == class_idx
        det_boxes = pred['boxes'][det_for_class]
        det_scores = pred['scores'][det_for_class]

        if len(det_scores) > 0:
            ranking = torch.argsort(det_scores, descending=True)
            det_boxes = det_boxes[ranking]
            det_scores = det_scores[ranking]

        claimed = [False] * len(gt_boxes)
        for det_box, det_score in zip(det_boxes, det_scores):
            gathered_scores.append(det_score.item())
            is_hit = 0
            if len(gt_boxes) != 0:
                overlap = box_iou(det_box.unsqueeze(0), gt_boxes)[0]
                best_iou, best_gt = torch.max(overlap, dim=0)
                if best_iou >= iou_threshold and not claimed[best_gt]:
                    claimed[best_gt] = True
                    is_hit = 1
            hit_flags.append(is_hit)

    if not gathered_scores or total_gt == 0:
        return [], [], []

    score_array = np.array(gathered_scores)
    order = np.argsort(-score_array)
    hits = np.array(hit_flags)[order]

    cum_tp = np.cumsum(hits)
    cum_fp = np.cumsum(1 - hits)

    precision = cum_tp / (cum_tp + cum_fp)
    recall = cum_tp / total_gt
    return precision, recall, score_array[order]

def visualize_predictions(model, dataset, device, num_images=5, output_dir=None):
    """
    Draw ground-truth (green) and predicted (red) boxes on random dataset images.

    Predictions with score > 0.5 are drawn with "<class>: <score>" below the box.
    The model's train/eval mode is restored afterwards.

    Args:
        model: Trained SSD model (called as ``model([image])``)
        dataset: Dataset to sample images from; must expose ``classes``
        device: Device to run model on
        num_images: Number of images to visualize (capped at len(dataset))
        output_dir: Directory to save ``viz_<i>.png`` files into (created if
            missing); nothing is written when None

    Returns:
        List of the annotated PIL images (callers ignoring the result are unaffected).
    """
    indices = random.sample(range(len(dataset)), min(num_images, len(dataset)))

    # De-normalisation constants. NOTE(review): assumes the dataset normalises
    # with the ImageNet mean/std — confirm against the dataset transform.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    was_training = model.training
    model.eval()
    rendered = []
    try:
        with torch.no_grad():
            for i, idx in enumerate(indices):
                img, target = dataset[idx]
                pred = model([img.to(device)])[0]

                # Tensor (C, H, W) -> uint8 (H, W, C); clip first so values outside
                # [0, 255] saturate instead of wrapping around in the uint8 cast
                img_np = img.permute(1, 2, 0).cpu().numpy()
                img_np = np.clip((img_np * std + mean) * 255, 0, 255).astype(np.uint8)
                img_pil = Image.fromarray(img_np)
                draw = ImageDraw.Draw(img_pil)

                # Ground-truth boxes (green)
                for box, label in zip(target['boxes'], target['labels']):
                    box = box.tolist()
                    draw.rectangle(box, outline='green', width=2)
                    draw.text((box[0], box[1]), dataset.classes[label.item()], fill='green')

                # Predicted boxes (red), high-confidence only
                for box, score, label in zip(pred['boxes'], pred['scores'], pred['labels']):
                    if score > 0.5:
                        box = box.tolist()
                        label_text = f"{dataset.classes[label.item()]}: {score:.2f}"
                        draw.rectangle(box, outline='red', width=2)
                        draw.text((box[0], box[3] + 10), label_text, fill='red')

                rendered.append(img_pil)
                if output_dir:
                    img_pil.save(os.path.join(output_dir, f'viz_{i}.png'))
    finally:
        # Don't silently leave a model that was training in eval mode
        model.train(was_training)

    return rendered

# ===========================
# TEST-TIME AUGMENTATION
# ===========================

class TestTimeAugmentation:
    """
    Apply test-time augmentation (TTA) to improve detection performance
    without modifying the training procedure.

    The wrapped model is run on the original images and on up to
    ``num_augmentations`` augmented copies (horizontal flip, small rotation,
    brightness, contrast, small downward shift — in that order). Boxes
    predicted on geometrically transformed images are mapped back to the
    original image frame, then all passes are merged with per-class NMS.

    NOTE(review): torchvision's ``adjust_brightness``/``adjust_contrast`` clamp
    float tensors to [0, 1]; if the images are mean/std-normalised this
    distorts them — confirm the preprocessing used by the caller.
    """

    # Rotation (degrees, counter-clockwise) applied by the rotation augmentation
    ROTATION_ANGLE = 5.0
    # Number of rows the translation augmentation shifts the image down by
    SHIFT_PIXELS = 1
    # IoU threshold of the per-class NMS that merges all passes
    MERGE_IOU_THRESHOLD = 0.5

    def __init__(self, model, device, num_augmentations=5):
        """
        Args:
            model: Detection model called as ``model(list_of_images)`` and
                returning one dict with 'boxes', 'scores', 'labels' per image.
            device: Device the images are moved to before inference.
            num_augmentations: How many augmentations to use (prefix of the list above).
        """
        self.model = model
        self.device = device
        self.num_augmentations = num_augmentations

    def __call__(self, images):
        """
        Run the model on the original and augmented images and ensemble the results.

        Args:
            images: List of (C, H, W) image tensors.

        Returns:
            List with one merged detection dict ('boxes', 'scores', 'labels') per image.
        """
        tf = torchvision.transforms.functional
        images = [img.to(self.device) for img in images]
        angle = self.ROTATION_ANGLE
        shift = self.SHIFT_PIXELS

        # (image transform, box un-mapping) pairs. The un-mapping takes
        # (boxes, width, height) and returns boxes in the original frame; it is
        # None for photometric augmentations that leave geometry unchanged.
        augmentations = [
            # Horizontal flip
            (lambda img: torch.flip(img, dims=[-1]),
             self._unflip_boxes),
            # Slight rotation about the image centre (done on the tensor directly,
            # avoiding a lossy uint8 PIL round trip)
            (lambda img: tf.rotate(img, angle=angle),
             lambda boxes, w, h: self._unrotate_boxes(boxes, w, h, angle)),
            # Brightness adjustment
            (lambda img: tf.adjust_brightness(img, brightness_factor=0.9),
             None),
            # Contrast adjustment
            (lambda img: tf.adjust_contrast(img, contrast_factor=1.1),
             None),
            # Small translation (shift down, replicating the top row)
            (lambda img: self._shift_down(img, shift),
             lambda boxes, w, h: self._unshift_boxes(boxes, h, shift)),
        ]

        with torch.no_grad():
            # Original prediction (no augmentation)
            all_detections = [self.model(images)]

            for transform, unmap_boxes in augmentations[:self.num_augmentations]:
                aug_detections = self.model([transform(img) for img in images])

                if unmap_boxes is not None:
                    # Use each image's own size: images in a batch may differ
                    for img, det in zip(images, aug_detections):
                        if len(det['boxes']) > 0:
                            height, width = img.shape[-2:]
                            det['boxes'] = unmap_boxes(det['boxes'], width, height)

                all_detections.append(aug_detections)

        return self._merge_detections(all_detections)

    @staticmethod
    def _unflip_boxes(boxes, width, height=None):
        """Map boxes predicted on a horizontally flipped image back to the original frame."""
        restored = boxes.clone()
        restored[:, 0] = width - boxes[:, 2]
        restored[:, 2] = width - boxes[:, 0]
        return restored

    @staticmethod
    def _unrotate_boxes(boxes, width, height, angle):
        """
        Map boxes predicted on an image rotated by ``angle`` degrees
        (counter-clockwise about the image centre) back to the original frame.

        The four corners of each box are rotated back and the axis-aligned box
        enclosing them is returned, clipped to the image.
        """
        import math

        if boxes.numel() == 0:
            return boxes
        theta = math.radians(angle)
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        cx, cy = width / 2.0, height / 2.0

        x1, y1, x2, y2 = boxes.unbind(dim=1)
        xs = torch.stack([x1, x2, x2, x1], dim=1) - cx
        ys = torch.stack([y1, y1, y2, y2], dim=1) - cy
        # Inverse of a counter-clockwise rotation in image coordinates (y points down)
        orig_x = xs * cos_t - ys * sin_t + cx
        orig_y = xs * sin_t + ys * cos_t + cy

        return torch.stack([
            orig_x.min(dim=1).values.clamp(0, width),
            orig_y.min(dim=1).values.clamp(0, height),
            orig_x.max(dim=1).values.clamp(0, width),
            orig_y.max(dim=1).values.clamp(0, height),
        ], dim=1)

    @staticmethod
    def _shift_down(img, pixels=1):
        """Shift a (C, H, W) image down by ``pixels`` rows, replicating the top row."""
        cropped = img[:, :-pixels, :].unsqueeze(0)  # replicate padding needs a batch dim
        return torch.nn.functional.pad(cropped, (0, 0, pixels, 0), mode='replicate').squeeze(0)

    @staticmethod
    def _unshift_boxes(boxes, height, pixels=1):
        """Undo ``_shift_down`` on predicted boxes (move them up by ``pixels``)."""
        restored = boxes.clone()
        restored[:, [1, 3]] = (boxes[:, [1, 3]] - pixels).clamp(0, height)
        return restored

    def _merge_detections(self, all_detections):
        """
        Merge detections from multiple augmentation passes with per-class NMS.

        Args:
            all_detections: One list of per-image detection dicts per pass; all
                lists have the same length and image order.

        Returns:
            List with one dict ('boxes', 'scores', 'labels') per image. Passes
            that found nothing simply contribute nothing; the result is empty
            only when every pass is empty.
        """
        merged_detections = []

        for img_idx in range(len(all_detections[0])):
            per_pass = [aug_detections[img_idx] for aug_detections in all_detections]

            boxes = torch.cat([det['boxes'].reshape(-1, 4) for det in per_pass], dim=0)
            scores = torch.cat([det['scores'].reshape(-1) for det in per_pass], dim=0)
            labels = torch.cat([det['labels'].reshape(-1).to(torch.int64) for det in per_pass], dim=0)
            device = boxes.device

            if boxes.numel() == 0:
                merged_detections.append({
                    'boxes': torch.empty((0, 4), device=device),
                    'scores': torch.empty((0,), device=device),
                    'labels': torch.empty((0,), dtype=torch.int64, device=device),
                })
                continue

            result_boxes, result_scores, result_labels = [], [], []
            # Plain (not weighted) NMS, run separately for each class
            for cls in torch.unique(labels):
                cls_mask = labels == cls
                cls_boxes = boxes[cls_mask]
                cls_scores = scores[cls_mask]

                keep = nms(cls_boxes, cls_scores, iou_threshold=self.MERGE_IOU_THRESHOLD)

                result_boxes.append(cls_boxes[keep])
                result_scores.append(cls_scores[keep])
                result_labels.append(torch.full((len(keep),), int(cls), dtype=torch.int64, device=device))

            merged_detections.append({
                'boxes': torch.cat(result_boxes, dim=0),
                'scores': torch.cat(result_scores, dim=0),
                'labels': torch.cat(result_labels, dim=0),
            })

        return merged_detections

# ===========================
# ENHANCED OPTIMIZER AND SCHEDULER
# ===========================


def create_enhanced_optimizer(model, lr=0.0005, weight_decay=0.0001):
    """
    Build an AdamW optimizer with separate backbone and head learning rates.

    Parameters whose qualified name contains ``'backbone'`` train at one tenth
    of ``lr``; every other parameter trains at ``lr``. Both groups share
    ``weight_decay``.

    Args:
        model: Module whose ``named_parameters()`` are optimised.
        lr: Learning rate of the head group.
        weight_decay: Decoupled weight decay for both groups.

    Returns:
        torch.optim.AdamW whose param_groups are ordered [backbone, head].
    """
    named = list(model.named_parameters())
    slow_group = [param for name, param in named if 'backbone' in name]
    fast_group = [param for name, param in named if 'backbone' not in name]

    param_groups = [
        {'params': slow_group, 'lr': lr * 0.1},  # pretrained backbone moves slower
        {'params': fast_group, 'lr': lr},
    ]
    return torch.optim.AdamW(param_groups, weight_decay=weight_decay)

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, cycles=0.5):
    """
    Linear warmup followed by cosine decay, as a multiplicative LR factor.

    For ``step < num_warmup_steps`` the factor rises linearly from 0 to 1.
    Afterwards it follows ``0.5 * (1 + cos(2 * pi * cycles * progress))`` with
    ``progress`` going from 0 to 1 over the remaining steps, floored at 0
    (the default ``cycles=0.5`` is a single half-cosine from 1 down to 0).

    Returns:
        torch.optim.lr_scheduler.LambdaLR wrapping ``optimizer``.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # Linear ramp 0 -> 1
            return float(current_step) / warmup_span
        progress = float(current_step - num_warmup_steps) / decay_span
        cosine = math.cos(math.pi * float(cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# ===========================
# HARD NEGATIVE MINING
# ===========================

def hard_negative_mining(cls_loss, pos_mask, neg_pos_ratio=3, num_neg_if_no_pos=100):
    """
    Hard negative mining: keep all positives plus the highest-loss negatives.

    Args:
        cls_loss: Classification loss per anchor (same shape as ``pos_mask``)
        pos_mask: Boolean mask indicating positive (foreground) anchors
        neg_pos_ratio: Number of negatives kept per positive
        num_neg_if_no_pos: Number of negatives kept when there are no positives

    Returns:
        Boolean mask selecting every positive anchor and the hardest
        ``int(num_positives * neg_pos_ratio)`` negatives.

    NOTE(review): positives are counted over the whole tensor while the ranking
    below runs along the last dimension; if callers pass a (batch, anchors)
    loss, every row gets the batch-wide budget — confirm that is intended.
    """
    pos_count = pos_mask.sum().item()
    neg_count = int(pos_count * neg_pos_ratio) if pos_count > 0 else num_neg_if_no_pos

    neg_mask = ~pos_mask
    # Rank negatives only: positives go to -inf so they can never occupy a
    # top-K slot. (Zeroing them let a positive tie with zero-loss negatives and
    # crowd them out.) Detached because only the ordering is needed.
    neg_losses = cls_loss.detach().masked_fill(pos_mask, float('-inf'))

    # Rank of each anchor when sorted by loss, highest first
    _, indices = neg_losses.sort(descending=True)
    _, orders = indices.sort()
    neg_mask_top_k = orders < neg_count

    return pos_mask | (neg_mask & neg_mask_top_k)

# ===========================
# ADVANCED EVALUATION METRICS
# ===========================

def calculate_ap_per_iou(predictions, targets, class_idx, iou_thresholds=None):
    """
    Calculate Average Precision averaged over several IoU thresholds (COCO-style).

    For each threshold, predictions of ``class_idx`` are matched greedily per
    image (highest score first) to unclaimed same-class ground truth, and AP is
    computed with PASCAL VOC 2007 11-point interpolation.

    Args:
        predictions: List of prediction dictionaries with 'boxes', 'scores', 'labels'
        targets: List of target dictionaries with 'boxes', 'labels'
        class_idx: Class index to calculate AP for
        iou_thresholds: IoU thresholds, defaults to [0.5, 0.55, 0.6, ..., 0.95]

    Returns:
        Mean AP across the IoU thresholds (0.0 for an empty threshold list; a
        threshold with no predictions or no ground truth contributes 0.0).
    """
    if iou_thresholds is None:
        iou_thresholds = np.linspace(0.5, 0.95, 10)

    aps = []
    for iou_threshold in iou_thresholds:
        all_scores, all_tp, all_fp = [], [], []
        num_gt = 0

        for pred, target in zip(predictions, targets):
            target_labels = target['labels']
            cls_target_mask = target_labels == class_idx
            num_gt += cls_target_mask.sum().item()
            cls_target_boxes = target['boxes'][cls_target_mask]

            cls_pred_mask = pred['labels'] == class_idx
            cls_pred_boxes = pred['boxes'][cls_pred_mask]
            cls_pred_scores = pred['scores'][cls_pred_mask]

            # Greedy matching must visit the most confident predictions first
            if len(cls_pred_scores) > 0:
                sort_idx = torch.argsort(cls_pred_scores, descending=True)
                cls_pred_boxes = cls_pred_boxes[sort_idx]
                cls_pred_scores = cls_pred_scores[sort_idx]

            # Ground-truth boxes already claimed by a higher-scoring prediction
            target_flags = [False] * len(cls_target_boxes)

            for pred_box, pred_score in zip(cls_pred_boxes, cls_pred_scores):
                all_scores.append(pred_score.item())
                is_tp = False
                if len(cls_target_boxes) > 0:
                    ious = box_iou(pred_box.unsqueeze(0), cls_target_boxes)[0]
                    max_iou, max_idx = torch.max(ious, dim=0)
                    gt_idx = int(max_idx)
                    if max_iou >= iou_threshold and not target_flags[gt_idx]:
                        target_flags[gt_idx] = True
                        is_tp = True
                all_tp.append(int(is_tp))
                all_fp.append(int(not is_tp))

        if not all_scores or num_gt == 0:
            aps.append(0.0)
            continue

        scores = np.array(all_scores)
        order = np.argsort(-scores)
        tp_cumsum = np.cumsum(np.array(all_tp)[order])
        fp_cumsum = np.cumsum(np.array(all_fp)[order])

        precision = tp_cumsum / (tp_cumsum + fp_cumsum)
        recall = tp_cumsum / num_gt

        # VOC07 11-point interpolation. No (recall=0, precision=1) sentinel:
        # it would force the t=0 point to 1.0 even when every detection is wrong.
        ap = 0.0
        for t in np.linspace(0, 1, 11):
            mask = recall >= t
            p = precision[mask].max() if mask.any() else 0.0
            ap += p / 11.0

        aps.append(ap)

    # np.mean([]) would return nan (with a warning)
    return np.mean(aps) if aps else 0.0

def calculate_coco_metrics(predictions, targets, num_classes):
    """
    Calculate COCO-style metrics:
    - mAP@[0.5:0.95]
    - mAP@0.5
    - mAP@0.75
    - mAP for small, medium, large objects (IoU 0.5)

    As in COCOeval, a class without any ground truth has undefined AP and is
    left out of the averages instead of counting as 0. Likewise, a class is
    left out of a size-specific average when it has no ground truth of that
    size. Per-class AP comes from ``calculate_ap_per_iou`` and
    ``calculate_ap_by_size``, which use 11-point interpolation (not COCO's
    101-point).

    Args:
        predictions: List of prediction dictionaries with 'boxes', 'scores', 'labels'
        targets: List of target dictionaries with 'boxes', 'labels'
        num_classes: Number of classes (including background at index 0)

    Returns:
        Dictionary with keys 'mAP', 'mAP_50', 'mAP_75', 'mAP_small',
        'mAP_medium', 'mAP_large'; each is 0.0 when no class qualifies.
    """
    # Same area buckets (in pixels^2) as calculate_ap_by_size
    small_threshold = 32 * 32
    medium_threshold = 96 * 96

    # gt_counts[cls] = [all, small, medium, large] ground-truth box counts
    gt_counts = np.zeros((max(num_classes, 1), 4), dtype=np.int64)
    for target in targets:
        labels = target['labels']
        if len(labels) == 0:
            continue
        boxes = target['boxes']
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        for label, area in zip(labels.tolist(), areas.tolist()):
            if not 0 < label < num_classes:
                continue  # background or out-of-range label
            gt_counts[label, 0] += 1
            if area < small_threshold:
                gt_counts[label, 1] += 1
            elif area < medium_threshold:
                gt_counts[label, 2] += 1
            else:
                gt_counts[label, 3] += 1

    per_class = {key: [] for key in ('mAP', 'mAP_50', 'mAP_75', 'mAP_small', 'mAP_medium', 'mAP_large')}

    for cls in range(1, num_classes):  # Skip background class
        if gt_counts[cls, 0] == 0:
            continue  # AP undefined without ground truth; skip the expensive matching too

        per_class['mAP'].append(calculate_ap_per_iou(predictions, targets, cls))
        per_class['mAP_50'].append(calculate_ap_per_iou(predictions, targets, cls, [0.5]))
        per_class['mAP_75'].append(calculate_ap_per_iou(predictions, targets, cls, [0.75]))

        ap_small, ap_medium, ap_large = calculate_ap_by_size(predictions, targets, cls)
        for key, ap, size_col in (('mAP_small', ap_small, 1),
                                  ('mAP_medium', ap_medium, 2),
                                  ('mAP_large', ap_large, 3)):
            if gt_counts[cls, size_col] > 0:
                per_class[key].append(ap)

    return {key: float(np.mean(values)) if values else 0.0 for key, values in per_class.items()}

def calculate_ap_by_size(predictions, targets, class_idx, iou_threshold=0.5):
    """
    Calculate Average Precision for different object sizes:
    - Small: area < 32²
    - Medium: 32² <= area < 96²
    - Large: area >= 96²

    Ground-truth boxes are bucketed by their own area. Each prediction is
    bucketed by ITS area and can only match ground truth of the same bucket.
    This approximates COCO's area ranges, which filter ground truth rather
    than predictions. Within each image, predictions are matched greedily
    from highest to lowest score. A prediction in an image with no ground
    truth of ``class_idx`` is a false positive of its bucket.

    Args:
        predictions: List of prediction dictionaries with 'boxes', 'scores', 'labels'
        targets: List of target dictionaries with 'boxes', 'labels'
        class_idx: Class index to calculate AP for
        iou_threshold: IoU threshold for considering a detection as correct

    Returns:
        ap_small, ap_medium, ap_large: AP for each size category
    """
    small_threshold = 32 * 32
    medium_threshold = 96 * 96
    bucket_names = ('small', 'medium', 'large')

    # Per-bucket accumulators consumed by calculate_ap_from_metrics
    metrics = {name: {'tp': [], 'fp': [], 'scores': [], 'num_gt': 0} for name in bucket_names}

    def bucket_of(area):
        """Return the size bucket name for a box area."""
        if area < small_threshold:
            return 'small'
        if area < medium_threshold:
            return 'medium'
        return 'large'

    for pred, target in zip(predictions, targets):
        cls_pred_mask = pred['labels'] == class_idx
        cls_pred_boxes = pred['boxes'][cls_pred_mask]
        cls_pred_scores = pred['scores'][cls_pred_mask]

        # reshape keeps an (N, 4) layout even when the image has no boxes
        cls_target_boxes = target['boxes'][target['labels'] == class_idx].reshape(-1, 4)
        target_areas = (cls_target_boxes[:, 2] - cls_target_boxes[:, 0]) * (cls_target_boxes[:, 3] - cls_target_boxes[:, 1])

        bucket_targets = {
            'small': cls_target_boxes[target_areas < small_threshold],
            'medium': cls_target_boxes[(target_areas >= small_threshold) & (target_areas < medium_threshold)],
            'large': cls_target_boxes[target_areas >= medium_threshold],
        }
        # Ground-truth boxes already claimed, per bucket
        bucket_flags = {name: [False] * len(bucket_targets[name]) for name in bucket_names}
        for name in bucket_names:
            metrics[name]['num_gt'] += len(bucket_targets[name])

        # Greedy matching must visit the most confident predictions first
        if len(cls_pred_scores) > 0:
            order = torch.argsort(cls_pred_scores, descending=True)
            cls_pred_boxes = cls_pred_boxes[order]
            cls_pred_scores = cls_pred_scores[order]

        for pred_box, pred_score in zip(cls_pred_boxes, cls_pred_scores):
            pred_area = float((pred_box[2] - pred_box[0]) * (pred_box[3] - pred_box[1]))
            name = bucket_of(pred_area)
            candidates = bucket_targets[name]
            flags = bucket_flags[name]

            metrics[name]['scores'].append(pred_score.item())
            is_tp = False
            if len(candidates) > 0:
                ious = box_iou(pred_box.unsqueeze(0), candidates)[0]
                max_iou, max_idx = torch.max(ious, dim=0)
                gt_idx = int(max_idx)
                if max_iou >= iou_threshold and not flags[gt_idx]:
                    flags[gt_idx] = True
                    is_tp = True
            metrics[name]['tp'].append(int(is_tp))
            metrics[name]['fp'].append(int(not is_tp))

    ap_small, ap_medium, ap_large = (calculate_ap_from_metrics(metrics[name]) for name in bucket_names)
    return ap_small, ap_medium, ap_large
def calculate_ap_from_metrics(metrics):
    """
    Compute 11-point interpolated AP (PASCAL VOC 2007) from accumulated matches.

    Args:
        metrics: Dict with parallel per-prediction lists 'tp', 'fp', 'scores'
            (tp/fp are 0/1 flags) and 'num_gt', the number of ground-truth boxes.

    Returns:
        AP in [0, 1]; 0.0 when there is no ground truth or no prediction.
    """
    if metrics['num_gt'] == 0 or len(metrics['scores']) == 0:
        return 0.0

    # Order detections by decreasing score
    scores = np.asarray(metrics['scores'])
    order = np.argsort(-scores)
    tp_cumsum = np.cumsum(np.asarray(metrics['tp'])[order])
    fp_cumsum = np.cumsum(np.asarray(metrics['fp'])[order])

    precision = tp_cumsum / (tp_cumsum + fp_cumsum)
    recall = tp_cumsum / metrics['num_gt']

    # No (recall=0, precision=1) sentinel: with it the t=0 point would always
    # be 1.0, so even a set of pure false positives scored 1/11.
    ap = 0.0
    for t in np.linspace(0, 1, 11):
        mask = recall >= t
        p = precision[mask].max() if mask.any() else 0.0
        ap += p / 11.0
    return ap