In [11]:
import os
from typing import List, Dict, Any, Tuple
from torch.utils.data import Dataset
import cv2
import numpy as np


class SimpleDetectionDataset(Dataset):
    """Minimal object-detection dataset for small experiments.

    records: list of dicts with keys:
      - image_path: str (absolute or relative to project root, e.g. 'data/sample.jpg')
      - bboxes: list of [x1,y1,x2,y2] in pascal_voc format (optional, defaults to [])
      - classes: list of int (optional, defaults to [])

    Relative image paths are resolved against the detected project root once,
    at construction time. Images are only read from disk in ``__getitem__``.
    """

    def __init__(self, records: List[Dict[str, Any]], transforms=None):
        """Normalise ``records`` and remember the optional transform pipeline.

        Args:
            records: Input records as described in the class docstring.
            transforms: Optional callable invoked as
                ``transforms(image=..., bboxes=..., class_labels=...)`` and
                expected to return a dict with the keys ``'image'``,
                ``'bboxes'`` and ``'class_labels'`` (Albumentations-style).

        Raises:
            KeyError: If a record has no ``'image_path'`` key.
        """
        self.transforms = transforms

        project_root = self._find_project_root()
        print(f"INFO: Determined project root as: {project_root}")

        self.records = []
        for rec in records:
            img_path = rec["image_path"]

            # Relative paths are interpreted relative to the project root,
            # not the current working directory.
            if not os.path.isabs(img_path):
                img_path = os.path.join(project_root, img_path)

            self.records.append({
                "image_path": os.path.abspath(img_path),
                "bboxes": rec.get("bboxes", []),
                "classes": rec.get("classes", []),
            })

    @staticmethod
    def _find_project_root() -> str:
        """Return the project root for both script and notebook use."""
        try:
            # Running as a module/script: this file is assumed to live one
            # directory below the project root (e.g. project_root/src/).
            base_dir = os.path.dirname(os.path.abspath(__file__))
            return os.path.abspath(os.path.join(base_dir, ".."))
        except NameError:
            # Interactive environment (Jupyter/Colab): __file__ is undefined,
            # so fall back to the current working directory.
            project_root = os.path.abspath(os.getcwd())
            # If the CWD is '.../<project>/src', step up to '<project>'.
            if os.path.basename(project_root).lower() == 'src':
                project_root = os.path.abspath(os.path.join(project_root, ".."))
            return project_root

    @staticmethod
    def _boxes_array(boxes) -> np.ndarray:
        """Convert ``boxes`` to a float32 array, using shape (0, 4) when empty."""
        arr = np.array(boxes, dtype=np.float32)
        if arr.size == 0:
            # np.array([]) has shape (0,); keep the box axis so that code
            # indexing columns (e.g. boxes[:, 0]) also works for images
            # without annotations.
            arr = arr.reshape(0, 4)
        return arr

    def __len__(self) -> int:
        """Return the number of records."""
        return len(self.records)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load and return the sample at ``idx``.

        Returns:
            Dict with keys ``'image'``, ``'bboxes'`` (float32 array, shape
            (0, 4) when there are no boxes) and ``'classes'`` (int64 array).
            Without transforms the image is an RGB float32 array in CHW
            layout scaled to [0, 1]; with transforms it is whatever the
            pipeline returns under ``'image'``.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        rec = self.records[idx]
        img = cv2.imread(rec['image_path'])

        # cv2.imread reports failure by returning None instead of raising.
        if img is None:
            # Include the final resolved path to ease debugging.
            raise FileNotFoundError(f"Image not found at final path: {rec['image_path']}")

        # BGR -> RGB. cvtColor returns a new contiguous array; slicing with
        # [:, :, ::-1] would give a negative-stride view, which
        # torch.from_numpy (used by tensor-conversion transforms) rejects.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        bboxes = rec['bboxes']
        classes = rec['classes']

        if self.transforms:
            # Assuming transforms are from Albumentations (need to install if used)
            transformed = self.transforms(image=img, bboxes=bboxes, class_labels=classes)
            img = transformed['image']
            bboxes = transformed['bboxes']
            classes = transformed['class_labels']
        else:
            # Default PyTorch-style transform: HWC -> CHW, scale to [0, 1].
            img = np.ascontiguousarray(img.transpose(2, 0, 1), dtype=np.float32) / 255.0

        return {
            'image': img,
            'bboxes': self._boxes_array(bboxes),
            'classes': np.array(classes, dtype=np.int64),
        }


def collate_fn(batch: List[Dict[str, Any]]) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    """Regroup a batch of sample dicts into per-field lists.

    Detection samples carry a variable number of boxes, so nothing is
    stacked: the result is a tuple ``(images, bboxes, classes)`` in which
    each element is a list holding that field from every sample, in batch
    order.
    """
    field_names = ('image', 'bboxes', 'classes')
    # Gather one field at a time, across the whole batch.
    images, bboxes, classes = (
        [sample[name] for sample in batch] for name in field_names
    )
    return images, bboxes, classes


if __name__ == '__main__':
    # Smoke test: build a two-record dataset and print the first sample.
    #
    # Prerequisites:
    #   * 'data/sample.jpg' must exist in the main project folder
    #     (e.g. 'C:/Users/Acer/Internship/yolo/data/sample.jpg').
    #   * pip install opencv-python "numpy<2"
    #     (quote the version spec so the shell does not treat '<' as a redirect)

    # One record per path flavour, so both resolution branches are exercised.
    relative_record = {
        # Relative path, should resolve to 'C:/.../yolo/data/sample.jpg'
        'image_path': 'data/sample.jpg',
        'bboxes': [[10, 10, 100, 120]],
        'classes': [0],
    }
    absolute_record = {
        # Absolute path (will be used as-is)
        'image_path': r'C:/Users/Acer/Internship/yolo/data/sample.jpg',
        'bboxes': [[30, 30, 150, 180]],
        'classes': [1],
    }
    recs = [relative_record, absolute_record]

    try:
        ds = SimpleDetectionDataset(recs)
        print('\n✅ Dataset initialized successfully.')
        print(f"len {len(ds)}")

        sample = ds[0]
        print("\n--- Sample 0 Output ---")
        # With the default transform the image is CHW, i.e. (3, H, W).
        print(f"Image shape: {sample['image'].shape!s}")
        print(f"BBoxes: {sample['bboxes']!s}")
        print(f"Classes: {sample['classes']!s}")

    except Exception as e:
        # A missing image gets a targeted hint; anything else is reported generically.
        if isinstance(e, FileNotFoundError):
            print(f"\n❌ FATAL ERROR: {e}")
            print("Please ensure the image 'data/sample.jpg' exists in your project root.")
        else:
            print(f"\n❌ An unexpected error occurred: {e}")

INFO: Determined project root as: C:\Users\Acer\Internship\yolo

✅ Dataset initialized successfully.
len 2

--- Sample 0 Output ---
Image shape: (3, 6306, 4204)
BBoxes: [[ 10.  10. 100. 120.]]
Classes: [0]
