# Object Detection Demo

This notebook demonstrates how to use a trained object detection model to perform inference on new images. It covers:
1. Loading the trained model checkpoint.
2. Preprocessing an input image.
3. Performing inference to get predictions (class logits, box coordinates).
4. Post-processing predictions (softmax, converting boxes to corner format, per-class NMS).
5. Drawing the detected bounding boxes on the original image and displaying it.

In [None]:
import torch
import torchvision
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt
import os
import sys

%matplotlib inline

# Add the project root to sys.path to allow importing from src
# Assuming the notebook is in object_detector/notebooks/
if '../' not in sys.path:
    sys.path.append('../')

from src.model import ObjectDetectionModel
from src.transforms import get_transform
from src.dataset import PASCAL_VOC_CLASSES # Tuple of foreground class names
from src import utils

In [None]:
# Configuration
# Every path below is written relative to the notebook's own location
# (e.g. the notebook lives in 'object_detector/notebooks').

# --- Input files -----------------------------------------------------------
# Trained weights, e.g. stored under 'object_detector/models_checkpoints'.
# Adjust this if the checkpoint has a different name or location.
CHECKPOINT_PATH = "../models_checkpoints/checkpoint_best.pth"

# Folder with the images to test (e.g. object_detector/test_images/); create
# it in the project root and drop the test images inside.
IMAGE_DIR = "../test_images/"

# File name, inside IMAGE_DIR, of the image processed by default.
DEFAULT_TEST_IMAGE = "example.jpg"

# --- Model and post-processing settings ------------------------------------
IMAGE_SIZE = 300               # Must match the image size the model was trained on
NUM_CLASSES_FG = 20            # Number of foreground classes (e.g., 20 for Pascal VOC)
SCORE_THRESHOLD_DISPLAY = 0.5  # Minimum confidence for a box to be shown
IOU_THRESHOLD_NMS = 0.45       # IoU threshold for Non-Maximum Suppression

# --- Runtime ---------------------------------------------------------------
# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Using device: {DEVICE}")
print(f"PASCAL VOC Classes (foreground): {PASCAL_VOC_CLASSES}")

In [None]:
# Load Model
# Build the network with the configured class count and image size, then
# restore its weights from CHECKPOINT_PATH. On any failure `model` is set to
# None so that later cells can detect it and skip inference.
model = ObjectDetectionModel(
    num_classes_fg=NUM_CLASSES_FG,
    image_size_for_default_boxes=(IMAGE_SIZE, IMAGE_SIZE)
)

if os.path.exists(CHECKPOINT_PATH):
    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

        # Handle potential 'module.' prefix if model was saved using DataParallel
        model_state_dict = checkpoint['model_state_dict']
        if model_state_dict and all(key.startswith('module.') for key in model_state_dict):
            print("Removing 'module.' prefix from checkpoint keys.")
            # Strip only the LEADING prefix. str.replace would also delete any
            # later 'module.' inside a key (e.g. 'module.head.module.weight')
            # and silently produce keys that do not match the model.
            corrected_state_dict = {
                key[len('module.'):]: value
                for key, value in model_state_dict.items()
            }
        else:
            corrected_state_dict = model_state_dict

        model.load_state_dict(corrected_state_dict)
        model.to(DEVICE)
        model.eval()  # Set model to evaluation mode
        print(f"Model loaded successfully from {CHECKPOINT_PATH}.")
    except Exception as e:
        # Deliberately broad: this is the notebook's top-level boundary, and any
        # failure (bad file, missing key, shape mismatch) should be reported
        # to the user rather than raise a traceback.
        print(f"Error loading checkpoint: {e}")
        print("Please ensure the checkpoint is valid and compatible with the model.")
        model = None  # Ensure model is None if loading failed
else:
    print(f"Checkpoint not found at {CHECKPOINT_PATH}.")
    print("Please place your trained model checkpoint at the specified path or update CHECKPOINT_PATH.")
    model = None

In [None]:
# Helper Functions

def prepare_image(image_path, image_size, device):
    """
    Loads an image from disk, applies the evaluation transform, and prepares
    it for the model.

    Args:
        image_path: Path of the image file to open.
        image_size: Side length passed to get_transform as a square
            (image_size, image_size) resize target.
        device: Device the resulting batch is moved to.

    Returns:
        (batch, pil_image): the transformed image with a leading batch
        dimension added and moved to `device`, plus the RGB PIL image as
        loaded. Returns (None, None) when the file is missing or cannot be
        read as an image; a message is printed in that case.
    """
    try:
        # The context manager closes the underlying file handle right away;
        # convert() returns a separate image object that is used from here on.
        with Image.open(image_path) as opened_image:
            pil_image = opened_image.convert('RGB')
    except FileNotFoundError:
        print(f"Error: Image not found at {image_path}")
        return None, None
    except OSError as e:
        # Covers unreadable, truncated, or non-image files.
        print(f"Error: Could not read image at {image_path}: {e}")
        return None, None

    # train=False selects the non-training variant of the project transform.
    transform = get_transform(train=False, resize_size=(image_size, image_size))
    img_tensor = transform(pil_image)
    return img_tensor.unsqueeze(0).to(device), pil_image

def draw_boxes_on_image(pil_img, boxes_xyxy_norm, labels, scores, class_names, score_thresh):
    """
    Draws bounding boxes, labels, and scores on a PIL image.

    The image is modified in place and the same object is returned, so pass a
    copy if the original must stay untouched.

    Args:
        pil_img: PIL image to draw on.
        boxes_xyxy_norm: Normalized [xmin, ymin, xmax, ymax] coordinates [0,1],
            indexable as boxes_xyxy_norm[i, j] with a .shape attribute.
        labels: Integer class indices (foreground, 0 to N-1), one per box.
        scores: Confidence scores for each box.
        class_names: Tuple/list of foreground class names.
        score_thresh: Boxes scoring below this value are not drawn.

    Returns:
        pil_img, with the boxes and labels drawn on it.
    """
    draw = ImageDraw.Draw(pil_img)

    # Fixed palette; class indices beyond its length wrap around via modulo.
    colors = [
        "red", "green", "blue", "yellow", "purple", "orange", "cyan", "magenta",
        "brown", "pink", "gray", "olive", "navy", "teal", "maroon", "lime",
        "aqua", "fuchsia", "silver", "gold",
    ]

    try:
        # Try to load a default font, fallback if not available
        font = ImageFont.load_default()
    except IOError:
        print("Default font not found. Using a basic font representation.")
        font = None

    img_width, img_height = pil_img.size

    for i in range(boxes_xyxy_norm.shape[0]):
        # Convert to plain Python numbers once so Pillow never receives
        # tensor/array scalars.
        score = float(scores[i])
        if score < score_thresh:
            continue

        # Denormalize box coordinates to pixels.
        x0 = float(boxes_xyxy_norm[i, 0]) * img_width
        y0 = float(boxes_xyxy_norm[i, 1]) * img_height
        x1 = float(boxes_xyxy_norm[i, 2]) * img_width
        y1 = float(boxes_xyxy_norm[i, 3]) * img_height

        # Recent Pillow versions raise ValueError for a rectangle whose second
        # corner is above/left of the first, so order the corners explicitly.
        xmin, xmax = min(x0, x1), max(x0, x1)
        ymin, ymax = min(y0, y1), max(y0, y1)

        label_idx = int(labels[i])
        # Guard both ends: a negative index would otherwise wrap around and
        # silently pick a wrong class name.
        if 0 <= label_idx < len(class_names):
            class_name = class_names[label_idx]
        else:
            class_name = f"Class_{label_idx}"

        label_text = f"{class_name}: {score:.2f}"
        box_color = colors[label_idx % len(colors)]

        draw.rectangle([(xmin, ymin), (xmax, ymax)], outline=box_color, width=2)

        if font is None:
            # Basic text drawing if font loading failed
            draw.text((xmin + 2, ymin + 2), label_text, fill=box_color)
            continue

        # Measure the text at the origin. left/top are the offsets between the
        # drawing anchor and the actual ink, which some fonts make non-zero.
        try:
            left, top, right, bottom = draw.textbbox((0, 0), label_text, font=font)
            text_width = right - left
            text_height = bottom - top
        except AttributeError:
            # Fallback for older Pillow versions without textbbox
            text_width, text_height = draw.textsize(label_text, font=font)
            left, top = 0, 0

        # Position text above the box, or inside if it would leave the image.
        text_x = xmin
        text_y = ymin - text_height - 2
        if text_y < 0:
            text_y = ymin + 2

        # Background for text for better readability
        draw.rectangle(
            [(text_x, text_y), (text_x + text_width, text_y + text_height)],
            fill=box_color,
        )
        # Shift the anchor by the measured offsets so the ink lands exactly on
        # the background rectangle.
        draw.text((text_x - left, text_y - top), label_text, fill="black", font=font)

    return pil_img

In [None]:
# Perform Inference and Display
#
# Steps: make sure a test image is available, run the model on it, turn the
# raw outputs into per-class detections (softmax -> score filter -> NMS),
# then draw the detections on a copy of the original image and show it.

if model is not None:  # Proceed only if model was loaded successfully
    # Work on a local name so the DEFAULT_TEST_IMAGE setting is never
    # overwritten; re-running this cell after a failure then behaves the same
    # as the first run.
    test_image_name = DEFAULT_TEST_IMAGE

    # Check if IMAGE_DIR exists, create if not, and provide instructions
    if not os.path.exists(IMAGE_DIR):
        print(f"Test image directory '{IMAGE_DIR}' not found.")
        print(f"Please create it in the project root (relative to this notebook: {os.path.abspath(IMAGE_DIR)}) and place test images there.")
        # As a fallback, try to create the directory with a placeholder image
        try:
            os.makedirs(IMAGE_DIR, exist_ok=True)
            dummy_img = Image.new('RGB', (600, 400), color='skyblue')
            draw_dummy = ImageDraw.Draw(dummy_img)
            dummy_text = "Please replace with your test image!"
            try:
                draw_dummy.text((50, 180), dummy_text, font=ImageFont.load_default(), fill='black')
            except Exception:
                # Best effort: retry without an explicit font. `Exception`
                # (not a bare except) lets KeyboardInterrupt/SystemExit through.
                draw_dummy.text((50, 180), dummy_text, fill='black')
            dummy_img_path = os.path.join(IMAGE_DIR, test_image_name)
            dummy_img.save(dummy_img_path)
            print(f"Created a dummy image at {dummy_img_path}. Please replace it with a real test image.")
        except Exception as e:
            print(f"Could not create dummy image directory/file: {e}")
            test_image_name = None  # Cannot proceed with image loading

    if test_image_name:
        test_image_path = os.path.join(IMAGE_DIR, test_image_name)

        if not os.path.exists(test_image_path):
            print(f"Test image '{test_image_name}' not found in '{IMAGE_DIR}'.")
            print("Please place an image there or update DEFAULT_TEST_IMAGE in Cell 3.")
        else:
            img_tensor, original_pil = prepare_image(test_image_path, IMAGE_SIZE, DEVICE)

            if img_tensor is not None and original_pil is not None:
                with torch.no_grad():
                    # Model returns: cls_logits, bbox_pred_cxcywh, default_boxes_xyxy
                    cls_logits, bbox_pred_cxcywh, _ = model(img_tensor)

                # Process predictions for the single image (index 0 of batch)
                pred_scores_softmax = torch.softmax(cls_logits[0], dim=-1)  # (num_default_boxes, num_classes_loss)
                pred_boxes_cxcywh = bbox_pred_cxcywh[0]                    # (num_default_boxes, 4)

                # Convert predicted boxes from [cx, cy, w, h] to [xmin, ymin, xmax, ymax] (still normalized)
                pred_boxes_xyxy_norm = utils.box_cxcywh_to_xyxy(pred_boxes_cxcywh)

                all_final_pred_boxes = []
                all_final_pred_labels = []
                all_final_pred_scores = []

                # Iterate through foreground classes to apply NMS per class
                for cls_idx in range(NUM_CLASSES_FG):
                    # Column 0 is skipped as background, so foreground class
                    # `cls_idx` lives in column `cls_idx + 1`.
                    class_scores = pred_scores_softmax[:, cls_idx + 1]

                    # Filter by score threshold before NMS
                    score_filter_mask = class_scores >= SCORE_THRESHOLD_DISPLAY
                    if not score_filter_mask.any():
                        continue

                    current_class_scores = class_scores[score_filter_mask]
                    current_class_boxes_xyxy = pred_boxes_xyxy_norm[score_filter_mask]

                    # Apply Non-Maximum Suppression on the normalized boxes.
                    # IoU is a ratio of areas, so it is unchanged by rescaling
                    # the coordinates to pixels.
                    # NOTE(review): assumes utils.non_max_suppression accepts
                    # normalized xyxy boxes and returns indices into its inputs
                    # - confirm against its implementation.
                    keep_indices = utils.non_max_suppression(
                        current_class_boxes_xyxy,
                        current_class_scores,
                        IOU_THRESHOLD_NMS
                    )

                    kept_scores = current_class_scores[keep_indices]
                    all_final_pred_boxes.append(current_class_boxes_xyxy[keep_indices])
                    all_final_pred_labels.append(torch.full_like(kept_scores, cls_idx, dtype=torch.long))
                    all_final_pred_scores.append(kept_scores)

                if all_final_pred_boxes:
                    final_boxes = torch.cat(all_final_pred_boxes).cpu()    # Normalized xyxy
                    final_labels = torch.cat(all_final_pred_labels).cpu()  # Foreground class indices (0 to N-1)
                    final_scores = torch.cat(all_final_pred_scores).cpu()

                    # Draw on a copy so original_pil stays unmodified.
                    result_img_pil = draw_boxes_on_image(
                        original_pil.copy(),
                        final_boxes,
                        final_labels,
                        final_scores,
                        PASCAL_VOC_CLASSES,  # Pass the tuple of fg class names
                        SCORE_THRESHOLD_DISPLAY
                    )
                    print(f"Detected {final_boxes.shape[0]} objects above threshold.")
                else:
                    result_img_pil = original_pil  # No detections above threshold
                    print("No objects detected above the threshold.")

                # Display using Matplotlib
                plt.figure(figsize=(12, 10))
                plt.imshow(result_img_pil)
                plt.title(f"Detections on: {os.path.basename(test_image_path)}")
                plt.axis('off')
                plt.show()
else:
    print("Model not loaded. Cannot perform inference.")

## Instructions for Use

1.  **Place Trained Checkpoint:** 
    *   Ensure your trained model checkpoint (e.g., `checkpoint_best.pth`) is located in the `models_checkpoints/` directory in the project root (e.g., `object_detector/models_checkpoints/`).
    *   If your checkpoint file has a different name or path, update the `CHECKPOINT_PATH` variable in **Cell 3 (Configuration)**.

2.  **Create Test Image Directory:**
    *   Create a directory named `test_images` in the project root (e.g., `object_detector/test_images/`).
    *   Place the images you want to test your model on into this `test_images/` directory.

3.  **Set Test Image Name:**
    *   In **Cell 3 (Configuration)**, modify the `DEFAULT_TEST_IMAGE` variable to the filename of the image you want to process (e.g., `"my_test_image.jpg"`). This image should be inside the `IMAGE_DIR` specified.

4.  **Run All Cells:**
    *   From the Jupyter Notebook menu, select "Cell" -> "Run All" or use the corresponding toolbar button.
    *   The notebook will load the model, process the specified image, and display the image with detected bounding boxes.

**Notes:**
*   The `IMAGE_SIZE` in **Cell 3** should match the image size your model was trained with (default is 300x300).
*   You can adjust `SCORE_THRESHOLD_DISPLAY` to control how sensitive the display of detections is. Lower values will show more (potentially less confident) detections.
*   If the model or image files are not found, the notebook will print error messages. Please check the paths and filenames carefully.