# Hierarchical Inference Pipeline

This notebook demonstrates the complete inference pipeline for our Hierarchical Object Detector. It uses a **waterfall rescue model** to enhance the predictions of a baseline YOLO L3 model.

### Key Features of this Pipeline:
1.  **Waterfall Rescue Logic**:
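    -   High-confidence L3 detections (confidence `>= 0.5`) are automatically confirmed.
    -   Low-confidence L3 candidates (confidence from `0.1` up to, but not including, `0.5`) are passed to the L2 head for a rescue attempt. A candidate is rescued when its L2 score is at least `0.6` and the predicted L2 group is consistent with the candidate's L3 class.
    -   If L2 fails, the L1 head gets a final chance to rescue the detection as a generic "vehicle" (L1 score of at least `0.5`).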
2.  **`HierarchicalDetector` Class**: A single, powerful class that encapsulates all models and logic.
3.  **Run-Once Efficiency**: The expensive YOLO backbone is run only **once** per image.
4.  **High-Quality Visualization**: Produces clear, color-coded bounding boxes, with an option to show or hide text labels.

In [1]:
# --- Core Imports ---
import json
import random
import sys
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Tuple

import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
from ultralytics import YOLO
from tqdm import tqdm

# --- Configuration ---
# Run on the first CUDA device when one is available, otherwise fall back to CPU.
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda:0" if USE_GPU else "cpu")

# --- PATHS (Update these to match your project) ---
# Relative paths are resolved against the notebook's working directory.
YOLO_L3_PATH = Path("models/yolov11_baseline.pt")       # baseline YOLO (L3) weights
AUX_HEADS_PATH = Path("models/hierarchical_heads.pt")   # checkpoint read via its 'model_state_dict' key
MANIFEST_PATH = Path("manifest.json")                   # JSON with 'hierarchy' and 'feature_extraction' sections

# --- FOLDERS ---
# NOTE(review): IMAGE_INPUT_DIR is an absolute, machine-specific path; the notebook
# will not run on another machine until it is changed.
IMAGE_INPUT_DIR = Path("C:/Users/Mika/Desktop/New_Training_Run_Post_Bulgaria/mio_tcd_yolo_vehicles_only/images/val") # <--- UPDATE THIS
OUTPUT_DIR = Path("inference_results_waterfall")
# Side effect: the output folder is created as soon as this cell runs.
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# --- WATERFALL INFERENCE THRESHOLDS (Tune these to adjust performance) ---
L3_CONFIRM_THRESH = 0.5    # YOLO confidence >= this: detection is accepted as 'confirmed'
L3_CANDIDATE_THRESH = 0.1  # passed to YOLO as `conf`; boxes below it are never considered
L2_RESCUE_THRESH = 0.6     # minimum L2 score (with a consistent L2 group) for an L2 rescue
L1_RESCUE_THRESH = 0.5     # minimum L1 score for a generic "vehicle" rescue

# --- Visualization Settings ---
# NOTE(review): Windows-only font path; the visualization cell falls back to
# 'sans-serif' when the font cannot be loaded.
FONT_PATH = "C:/Windows/Fonts/bahnschrift.ttf"
NUM_IMAGES_TO_PROCESS = 50  # upper bound on the random sample of images
# Set this to False to only draw boxes without text labels.
DRAW_LABELS = True

print("Configuration loaded.")
# Warn early (without raising) when the input folder is missing.
if not IMAGE_INPUT_DIR.exists():
    print(f"[ERROR] The specified image input directory does not exist: {IMAGE_INPUT_DIR}")
    print("Please update the `IMAGE_INPUT_DIR` variable in this cell.")

Configuration loaded.


## The `HierarchicalDetector` Class

This class is the core engine of our inference pipeline. It has been updated to use the **waterfall rescue model**.

In [2]:
# --- Helper Models (Must match training definitions) ---
class FeatureCompressor(nn.Module):
    """Turn a feature-map crop into a flat ``out_dim`` vector.

    The crop passes through two ``3x3 conv -> BatchNorm -> SiLU`` stages and is
    then globally average-pooled to one value per channel.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_dim: Length of the produced feature vector.
        hidden_dim_factor: Width of the first stage, as a multiple of ``out_dim``.
    """

    def __init__(self, in_channels: int, out_dim: int, hidden_dim_factor: int = 2):
        super().__init__()
        width = out_dim * hidden_dim_factor
        # Layer order (and therefore the state-dict keys trunk.0 ... trunk.5)
        # has to stay aligned with the training definition.
        stages = [
            nn.Conv2d(in_channels, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.SiLU(inplace=True),
            nn.Conv2d(width, out_dim, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.SiLU(inplace=True),
        ]
        self.trunk = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one flattened feature vector per input sample."""
        activations = self.trunk(x)
        pooled = self.pool(activations)
        return torch.flatten(pooled, 1)

# --- REVISED Model Definition ---
class AuxHeadsMLP(nn.Module):
    """Auxiliary classifier with a shared trunk and separate L1 / L2 branches.

    The trunk is shared; each hierarchy level then gets its own neck and its own
    final linear head. Attribute names and layer order must stay in sync with
    the training script, otherwise the checkpoint will not load.

    Args:
        in_dim: Length of the input feature vector.
        num_l2_classes: Number of L2 classes (width of the L2 logits).
        hidden_dim: Width of the first trunk layer; the second is half of it.
        neck_dim: Width of each neck.
        dropout: Dropout probability used inside the trunk.
    """

    def __init__(self, in_dim: int, num_l2_classes: int, hidden_dim: int = 512, neck_dim: int = 128, dropout: float = 0.3):
        super().__init__()
        trunk_out = hidden_dim // 2

        # Shared trunk: two Linear -> LayerNorm -> SiLU -> Dropout stages.
        self.trunk = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, trunk_out),
            nn.LayerNorm(trunk_out),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout),
        )

        # Decoupled necks, built in the same order as the original (L2 first).
        self.l2_neck = nn.Sequential(
            nn.Linear(trunk_out, neck_dim),
            nn.LayerNorm(neck_dim),
            nn.SiLU(inplace=True),
        )
        self.l1_neck = nn.Sequential(
            nn.Linear(trunk_out, neck_dim),
            nn.LayerNorm(neck_dim),
            nn.SiLU(inplace=True),
        )

        # Final heads: multi-class logits for L2, a single logit for L1.
        self.l2_head = nn.Linear(neck_dim, num_l2_classes)
        self.l1_head = nn.Linear(neck_dim, 1)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return ``{"l1_logits": ..., "l2_logits": ...}`` for a batch of features."""
        shared = self.trunk(x)
        l2_logits = self.l2_head(self.l2_neck(shared))
        l1_logits = self.l1_head(self.l1_neck(shared))
        return {"l1_logits": l1_logits, "l2_logits": l2_logits}

# --- Main Inference Class (Unchanged, but included for completeness) ---
class HierarchicalDetector:
    """YOLO L3 detector combined with auxiliary L1/L2 heads in a waterfall rescue scheme.

    One YOLO forward pass per image produces the candidate boxes. Forward hooks
    capture three neck feature maps during that same pass, so the auxiliary
    heads reuse them instead of re-running the backbone.

    Args:
        yolo_path: Path to the YOLO weights.
        aux_heads_path: Path to a checkpoint whose ``'model_state_dict'`` entry
            holds the ``AuxHeadsMLP`` weights.
        manifest_path: Path to a JSON manifest with a ``'hierarchy'`` section
            (``L3_NAMES``, ``L2_NAMES``, ``CLASS_TO_L2``) and a
            ``'feature_extraction'`` section (``compressed_dim``,
            ``roi_align_size``, ``pyramid_thresholds``).
    """

    # Names under which the three hooked feature maps are stored.
    _LEVEL_NAMES = ('p3', 'p4', 'p5')

    def __init__(self, yolo_path: Path, aux_heads_path: Path, manifest_path: Path):
        print("Initializing HierarchicalDetector...")
        self.device = DEVICE

        with open(manifest_path) as f:
            self.manifest = json.load(f)

        self.hierarchy = self.manifest['hierarchy']
        self.l3_names = self.hierarchy['L3_NAMES']
        self.l2_names = self.hierarchy['L2_NAMES']
        self.class_to_l2 = self.hierarchy['CLASS_TO_L2']

        feat_config = self.manifest['feature_extraction']
        self.compressed_dim = feat_config['compressed_dim']
        self.roi_align_size = tuple(feat_config['roi_align_size'])
        self.pyramid_thresholds = tuple(feat_config['pyramid_thresholds'])

        self.yolo_model = YOLO(yolo_path)
        self.feature_maps = {}
        self._attach_hooks()
        self._instantiate_aux_models()

        checkpoint = torch.load(aux_heads_path, map_location=self.device)
        self.aux_heads_mlp.load_state_dict(checkpoint['model_state_dict'])
        # NOTE(review): only the MLP weights are restored here. The feature
        # compressors keep their freshly initialised weights; if they were
        # trained together with the heads, their state must be loaded as well.
        # Confirm against the training script / checkpoint contents.
        self.aux_heads_mlp.eval()
        self.feature_compressors.eval()
        print("HierarchicalDetector initialized and ready.")

    def _attach_hooks(self):
        """Register forward hooks on three neck layers and record the model strides."""
        def get_features_hook(name):
            def hook(model, input, output):
                # Some modules return tuples; keep only the first element.
                self.feature_maps[name] = output[0] if isinstance(output, tuple) else output
            return hook

        try:
            # NOTE(review): in the stock Ultralytics yolo11.yaml the Detect head
            # consumes layers [16, 19, 22] as P3/8, P4/16 and P5/32 respectively.
            # The mapping below (16 -> 'p5', 22 -> 'p3') looks reversed relative
            # to `self.strides`. It is left unchanged because it has to match
            # whatever the training script did - confirm before changing.
            p5_module_idx, p4_module_idx, p3_module_idx = 16, 19, 22
            layers = self.yolo_model.model.model
            self.hook_p5 = layers[p5_module_idx].register_forward_hook(get_features_hook('p5'))
            self.hook_p4 = layers[p4_module_idx].register_forward_hook(get_features_hook('p4'))
            self.hook_p3 = layers[p3_module_idx].register_forward_hook(get_features_hook('p3'))
            print(f"Attached forward hooks to layers {p3_module_idx}(P3), {p4_module_idx}(P4), {p5_module_idx}(P5).")
        except Exception as e:
            print(f"Hook attachment failed: {e}. Please verify model architecture and indices.")
            raise  # bare raise keeps the original traceback intact

        # Run a dummy forward pass so the hooks populate `feature_maps` once.
        self.yolo_model.predict(torch.zeros(1, 3, 640, 640).to(self.device), verbose=False)
        stride = self.yolo_model.model.stride
        # Stored as plain floats so `spatial_scale` below is a float, not a tensor.
        self.strides = {'p3': float(stride[0]), 'p4': float(stride[1]), 'p5': float(stride[2])}

    def _instantiate_aux_models(self):
        """Build the per-level feature compressors and the auxiliary MLP."""
        # `_attach_hooks` normally has already populated the feature maps; only
        # run another dummy pass when that is not the case.
        if not all(name in self.feature_maps for name in self._LEVEL_NAMES):
            self.yolo_model.predict(torch.zeros(1, 3, 640, 640).to(self.device), verbose=False)

        # Channel counts are read from the hooked feature maps.
        self.feature_compressors = nn.ModuleDict({
            name: FeatureCompressor(self.feature_maps[name].shape[1], self.compressed_dim)
            for name in self._LEVEL_NAMES
        }).to(self.device)
        self.aux_heads_mlp = AuxHeadsMLP(self.compressed_dim, len(self.l2_names)).to(self.device)

    @staticmethod
    def _choose_pyramid_level(boxes_xyxy: torch.Tensor, thresholds: Tuple[float, float]) -> torch.Tensor:
        """Assign each box to level 3, 4 or 5 based on its longest side.

        Boxes whose longest side is ``<= thresholds[0]`` go to level 3, those
        ``> thresholds[1]`` go to level 5, everything else to level 4.
        """
        if boxes_xyxy.numel() == 0:
            return torch.empty(0, dtype=torch.long, device=boxes_xyxy.device)
        widths = boxes_xyxy[:, 2] - boxes_xyxy[:, 0]
        heights = boxes_xyxy[:, 3] - boxes_xyxy[:, 1]
        max_side = torch.maximum(widths, heights)
        levels = torch.full_like(max_side, 4, dtype=torch.long)
        levels[max_side <= thresholds[0]] = 3
        levels[max_side > thresholds[1]] = 5
        return levels

    def _extract_compact_features(self, candidate_boxes: torch.Tensor) -> torch.Tensor:
        """Pool and compress features for every candidate box.

        Uses the feature maps captured by the hooks during the most recent YOLO
        forward pass, so it must be called right after that pass.

        Args:
            candidate_boxes: Boxes in ``xyxy`` form, one row per candidate.

        Returns:
            Tensor with one ``compressed_dim``-long row per candidate.
        """
        # Local import: torchvision is not part of the notebook's import cell.
        from torchvision.ops import roi_align

        # NOTE(review): `candidate_boxes` are scaled by 1/stride and applied to
        # the hooked feature maps directly. Whether the boxes are expressed in
        # the same coordinate frame as the network input that produced those
        # maps cannot be verified from here - confirm against the training
        # feature extractor.
        levels = self._choose_pyramid_level(candidate_boxes, self.pyramid_thresholds)
        final_features = torch.zeros(len(candidate_boxes), self.compressed_dim, device=self.device)
        for level_idx in (3, 4, 5):
            mask = (levels == level_idx)
            if not mask.any():
                continue
            level_name = f'p{level_idx}'
            boxes_on_level = candidate_boxes[mask]
            # roi_align takes a list with one box tensor per image; batch size is 1 here.
            pooled_feats = roi_align(
                self.feature_maps[level_name],
                [boxes_on_level],
                output_size=self.roi_align_size,
                spatial_scale=1.0 / self.strides[level_name],
            )
            if pooled_feats.numel() > 0:
                final_features[mask] = self.feature_compressors[level_name](pooled_feats)
        return final_features

    @torch.no_grad()
    def predict(self, image_path: str, l3_confirm_thresh: float, l3_candidate_thresh: float, l2_rescue_thresh: float, l1_rescue_thresh: float, verbose: bool = False):
        """Run the waterfall inference on a single image.

        Args:
            image_path: Image passed straight to the YOLO model.
            l3_confirm_thresh: YOLO confidence at or above which a box is accepted as-is.
            l3_candidate_thresh: YOLO ``conf`` cut-off; boxes below it are never produced.
            l2_rescue_thresh: Minimum L2 score for an L2 rescue (the L2 group must
                also match the group of the candidate's L3 class).
            l1_rescue_thresh: Minimum L1 score for a generic "vehicle" rescue.
            verbose: Print the per-candidate decision trail.

        Returns:
            ``(detections, result)`` where ``detections`` is a list of dicts with
            keys ``'bbox'`` (numpy xyxy), ``'label'``, ``'confidence'`` and
            ``'type'`` (``'confirmed'``, ``'rescued_l2'`` or ``'rescued_l1'``),
            and ``result`` is the first element returned by the YOLO call.
        """
        results = self.yolo_model(image_path, conf=l3_candidate_thresh, device=self.device, verbose=False)
        result = results[0]
        if result.boxes is None or len(result.boxes) == 0:
            return [], result

        candidates = result.boxes
        candidate_boxes = candidates.xyxy
        compact_features = self._extract_compact_features(candidate_boxes)
        aux_outputs = self.aux_heads_mlp(compact_features)
        l1_probs = torch.sigmoid(aux_outputs['l1_logits']).squeeze(-1)
        l2_probs, l2_indices = torch.max(torch.softmax(aux_outputs['l2_logits'], dim=1), dim=1)

        # Move everything to host memory once instead of once per candidate.
        boxes_np = candidate_boxes.cpu().numpy()
        l3_conf_list = candidates.conf.tolist()
        l3_cls_list = candidates.cls.int().tolist()
        l1_prob_list = l1_probs.tolist()
        l2_prob_list = l2_probs.tolist()
        l2_idx_list = l2_indices.tolist()

        final_detections = []
        if verbose:
            print("\n--- Processing Candidates ---")

        for i in range(len(candidates)):
            l3_conf, l3_cls_idx = l3_conf_list[i], l3_cls_list[i]
            l3_cls_name = self.l3_names[l3_cls_idx]
            bbox = boxes_np[i].copy()  # independent array per detection

            # Stage 1: confident L3 detections are accepted directly.
            if l3_conf >= l3_confirm_thresh:
                final_detections.append({'bbox': bbox, 'label': l3_cls_name, 'confidence': l3_conf, 'type': 'confirmed'})
                continue

            l1_prob, l2_prob, l2_idx = l1_prob_list[i], l2_prob_list[i], l2_idx_list[i]
            l2_name = self.l2_names[l2_idx]
            # The L2 prediction only counts when it agrees with the L3 class's group.
            is_consistent = self.class_to_l2.get(l3_cls_name) == l2_name

            if verbose:
                print(f"  Candidate {i} (L3: {l3_cls_name} @ {l3_conf:.2f}):")
                print(f"    - L2 Score: {l2_prob:.2f} ({l2_name}, consistent: {is_consistent})")
                print(f"    - L1 Score: {l1_prob:.2f}")

            # Stage 2: rescue at the L2 (group) level.
            if l2_prob >= l2_rescue_thresh and is_consistent:
                if verbose:
                    print("    - RESCUED by L2")
                final_detections.append({'bbox': bbox, 'label': l2_name, 'confidence': l2_prob, 'type': 'rescued_l2'})
                continue

            # Stage 3: last chance, rescue as a generic vehicle.
            if l1_prob >= l1_rescue_thresh:
                if verbose:
                    print("    - RESCUED by L1")
                final_detections.append({'bbox': bbox, 'label': "vehicle", 'confidence': l1_prob, 'type': 'rescued_l1'})
                continue

            if verbose:
                print("    - REJECTED")

        return final_detections, result

## Hierarchical NMS & Visualization Functions

This cell contains our key post-processing and visualization logic.

### Hierarchical Non-Maximum Suppression (NMS)
To solve the problem of multiple overlapping boxes for a single object, we implement a custom NMS function. It sorts detections not just by confidence, but by a **priority score** based on the waterfall logic (`confirmed` > `rescued_l2` > `rescued_l1`). This ensures the most specific and reliable detection is always chosen.

In [None]:
# --- NMS and Visualization Helpers ---

def calculate_iou(box1, box2):
    """Return the Intersection over Union of two ``(x1, y1, x2, y2)`` boxes.

    Args:
        box1: First box as an indexable of four numbers (x1, y1, x2, y2).
        box2: Second box in the same format.

    Returns:
        IoU as a Python float in ``[0, 1]``; ``0.0`` when the boxes do not
        overlap or the union has no area.
    """
    x1_inter = max(box1[0], box2[0])
    y1_inter = max(box1[1], box2[1])
    x2_inter = min(box1[2], box2[2])
    y2_inter = min(box1[3], box2[3])

    inter_area = max(0, x2_inter - x1_inter) * max(0, y2_inter - y1_inter)
    if inter_area <= 0:
        return 0.0

    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    # The height of box2 must use box2's own top edge (the original used box1[1]).
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = box1_area + box2_area - inter_area
    if union_area <= 0:
        # Degenerate boxes: avoid dividing by zero.
        return 0.0

    return float(inter_area / union_area)

def hierarchical_nms(detections: List[Dict], iou_threshold: float) -> List[Dict]:
    """Greedy, class-agnostic NMS ordered by waterfall stage, then by confidence.

    Detections are ranked ``confirmed`` > ``rescued_l2`` > ``rescued_l1``
    (unknown types rank last), with confidence as the tie-breaker. The top
    ranked detection is kept and every remaining detection whose IoU with it
    reaches ``iou_threshold`` is dropped; this repeats until none are left.

    Args:
        detections: Dicts with at least ``'bbox'``, ``'type'`` and ``'confidence'``.
        iou_threshold: Overlap at or above which a lower-ranked box is suppressed.

    Returns:
        The surviving detections, highest rank first.
    """
    if not detections:
        return []

    stage_rank = {'confirmed': 3, 'rescued_l2': 2, 'rescued_l1': 1}

    def ranking(det):
        return (stage_rank.get(det['type'], 0), det['confidence'])

    pending = sorted(detections, key=ranking, reverse=True)
    survivors = []
    while pending:
        winner, rivals = pending[0], pending[1:]
        survivors.append(winner)
        # Only rivals that overlap the winner less than the threshold stay in play.
        pending = [
            rival for rival in rivals
            if calculate_iou(winner['bbox'], rival['bbox']) < iou_threshold
        ]

    return survivors

# --- Advanced Visualization Code with Collision Avoidance ---
# Imported outside the `try` so the fallback branch below can always use it.
from matplotlib.font_manager import FontProperties

# Try the configured font first; fall back to the generic sans-serif family
# when the font file cannot be loaded (e.g. on a non-Windows machine).
try:
    font_prop = FontProperties(fname=FONT_PATH, size=11)
    FONT_NAME = font_prop.get_name()
    print(f"Using font: {FONT_NAME}")
except Exception:
    print(f"Font at {FONT_PATH} not found. Using default 'sans-serif'.")
    FONT_NAME = 'sans-serif'
    font_prop = FontProperties(family=FONT_NAME, size=11, weight='bold')

# Side effect: changes the default font family for every later matplotlib figure.
plt.rcParams['font.family'] = FONT_NAME

# Box / label colour per detection type.
COLORS = {
    'baseline_high_conf': '#3498db', # Blue
    'confirmed': '#2ecc71',          # Green
    'rescued_l2': '#e67e22',          # Orange
    'rescued_l1': '#e74c3c',          # Red
}

def _get_text_bbox(ax, x, y, text, font_prop):
    """
    Internal helper to calculate the bounding box of a text label in data coordinates
    before it is drawn. This is essential for collision detection.
    """
    text_obj = ax.text(x, y, text, fontproperties=font_prop, 
                       bbox=dict(facecolor='red', alpha=1, pad=3, edgecolor='none', boxstyle='round,pad=0.4'))
    # Use the figure's renderer to calculate the pixel bounding box
    renderer = ax.get_figure().canvas.get_renderer()
    bbox_pixel = text_obj.get_window_extent(renderer=renderer)
    # Transform the pixel bounding box back into data coordinates
    bbox_data = ax.transData.inverted().transform(bbox_pixel)
    text_obj.remove() # Clean up the temporary text object
    return bbox_data

def _check_overlap(box1, box2):
    """Checks if two bounding boxes (x1, y1, x2, y2) overlap."""
    return not (box1[2] < box2[0] or box1[0] > box2[2] or box1[3] < box2[1] or box1[1] > box2[3])

def draw_detections_with_labels(ax, detections, color_map_key, draw_labels=True):
    """Draw detection boxes and, optionally, non-overlapping text labels.

    Args:
        ax: Axes to draw on.
        detections: Dicts with ``'bbox'`` (x1, y1, x2, y2) and ``'label'``; an
            optional ``'type'`` selects the colour from ``COLORS``.
        color_map_key: ``COLORS`` key used for detections without a ``'type'``.
        draw_labels: When ``False`` only the boxes are drawn.

    Returns:
        None. The axes are modified in place.
    """
    # Upper bound on collision shifts per label, so a label that can never be
    # placed (e.g. one with zero measured height) cannot loop forever.
    max_shifts = 100
    drawn_label_bboxes = []

    def measure(x, y, text):
        """Measured label extent as a flat (x_min, y_min, x_max, y_max) tuple."""
        corners = _get_text_bbox(ax, x, y, text, font_prop)
        (xa, ya), (xb, yb) = corners[0], corners[1]
        # Corner order depends on the axis direction, so normalise it here.
        return (min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb))

    for det in detections:
        # Draw the primary detection bounding box.
        bbox = det['bbox']
        color = COLORS[det.get('type', color_map_key)]
        ax.add_patch(patches.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                                       linewidth=2.5, edgecolor=color, facecolor='none', alpha=0.9))

        if not draw_labels:
            continue

        # --- Label placement ---
        label_text = det['label']

        # 1. Label height in data units. Taken from the normalised extent, so
        #    it is non-negative even when the y-axis is inverted (image axes).
        probe = measure(0, 0, label_text)
        text_height = probe[3] - probe[1]

        # Initial position: top-left corner of the detection box.
        x, y = bbox[0], bbox[1]

        # Boundary check: too close to the top edge -> use the bottom-left corner.
        if y < text_height:
            y = bbox[3]

        # 2. De-confliction: shift by one label height (plus 10 %) while the
        #    label collides with one that has already been drawn.
        current_label_bbox = measure(x, y, label_text)
        for _ in range(max_shifts):
            if not any(_check_overlap(current_label_bbox, drawn) for drawn in drawn_label_bboxes):
                break
            y += text_height * 1.1
            current_label_bbox = measure(x, y, label_text)

        # 3. Draw the label at its final position.
        ax.text(x, y, label_text, fontproperties=font_prop, color='white',
                verticalalignment='top', horizontalalignment='left',
                bbox=dict(facecolor=color, alpha=0.8, pad=3, edgecolor='none', boxstyle='round,pad=0.4'))

        # 4. Remember its extent (already measured at the final position).
        drawn_label_bboxes.append(current_label_bbox)

print("Hierarchical NMS and Advanced Visualization helpers defined.")

Using font: Bahnschrift
Hierarchical NMS and Visualization helpers defined.


## 🚀 Cell 4: Main Inference Loop

This is the main execution cell. It performs the following steps:
1.  Instantiates our `HierarchicalDetector`.
2.  Selects a random sample of images from the specified input directory.
3.  For each image, it gets the **raw hierarchical detections**.
4.  It then runs our **`hierarchical_nms`** function to clean up overlapping boxes.
5.  Finally, it generates the side-by-side comparison plot (baseline vs. hierarchical), saves it to the output directory, and displays it.

In [None]:
# --- NMS Threshold ---
NMS_IOU_THRESH = 0.5 # IoU threshold for suppressing overlapping boxes

# Text appended to a label depending on how the detection was accepted.
LABEL_SUFFIX_BY_TYPE = {
    'confirmed': '',
    'rescued_l2': ' [L2 RESCUED]',
    'rescued_l1': ' [L1 RESCUED]',
}

# --- Instantiate the Detector ---
# Reset first so a detector left over from an earlier run of this cell is
# never reused silently when initialization fails this time.
detector = None
try:
    detector = HierarchicalDetector(YOLO_L3_PATH, AUX_HEADS_PATH, MANIFEST_PATH)
except FileNotFoundError as e:
    print(f"\n[FATAL ERROR] Could not initialize detector. A required file was not found: {e}")
    print("Please check the paths in Cell 1.")
except Exception as e:
    print(f"\n[FATAL ERROR] An unexpected error occurred during initialization: {e}")

# --- Get Images and Run Inference ---
if detector is not None:
    # Sorted so the candidate list does not depend on filesystem ordering.
    image_files = sorted(p for p in IMAGE_INPUT_DIR.glob("**/*") if p.suffix.lower() in ('.jpg', '.jpeg', '.png'))
    if not image_files:
        print(f"[ERROR] No images found in {IMAGE_INPUT_DIR}. Please check the path.")
    else:
        # NOTE(review): the sample is not seeded, so every run picks different
        # images; call random.seed(...) beforehand if reproducibility matters.
        sample_images = random.sample(image_files, k=min(NUM_IMAGES_TO_PROCESS, len(image_files)))
        print(f"\nStarting inference on {len(sample_images)} images...")

        saved_count = 0
        for img_path in tqdm(sample_images, desc="Processing images"):
            # cv2.imread returns None instead of raising on unreadable files.
            img_bgr = cv2.imread(str(img_path))
            if img_bgr is None:
                tqdm.write(f"[WARN] Could not read image, skipping: {img_path}")
                continue
            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

            # --- Run Inference to get raw detections ---
            raw_hierarchical_detections, baseline_result = detector.predict(
                image_path=str(img_path),
                l3_confirm_thresh=L3_CONFIRM_THRESH,
                l3_candidate_thresh=L3_CANDIDATE_THRESH,
                l2_rescue_thresh=L2_RESCUE_THRESH,
                l1_rescue_thresh=L1_RESCUE_THRESH,
                verbose=False # Set to True for detailed per-candidate logs
            )

            # --- Apply Hierarchical NMS ---
            final_detections = hierarchical_nms(raw_hierarchical_detections, NMS_IOU_THRESH)

            # --- Plotting ---
            fig, axs = plt.subplots(1, 2, figsize=(28, 14))
            fig.patch.set_facecolor('white')

            # --- Plot Baseline Results ---
            axs[0].imshow(img_rgb)
            # ">=" matches the comparison used when filtering the boxes below.
            axs[0].set_title(f"Baseline YOLO (conf >= {L3_CONFIRM_THRESH})", fontsize=18, weight='bold', pad=20)
            axs[0].axis('off')

            baseline_detections_to_draw = []
            if baseline_result.boxes is not None:
                for box in baseline_result.boxes:
                    box_conf = box.conf.item()
                    if box_conf >= L3_CONFIRM_THRESH:
                        label_name = detector.l3_names[int(box.cls.item())]
                        baseline_detections_to_draw.append({
                            'bbox': box.xyxy[0].cpu().numpy(),
                            'label': f"{label_name.upper()}: {box_conf:.2f}",
                            'type': 'baseline_high_conf' # Key for color mapping
                        })
            draw_detections_with_labels(axs[0], baseline_detections_to_draw, 'baseline_high_conf', draw_labels=DRAW_LABELS)

            # --- Plot Hierarchical Results (after NMS) ---
            axs[1].imshow(img_rgb)
            axs[1].set_title("Hierarchical Model (after Waterfall Rescue + NMS)", fontsize=18, weight='bold', pad=20)
            axs[1].axis('off')

            # Prepare labels for hierarchical detections
            hierarchical_detections_to_draw = []
            for det in final_detections:
                label_text = det['label'].replace('_', ' ').upper()
                det_type = det['type']
                # Unknown types still get a well-defined label instead of
                # reusing (or missing) the label of a previous detection.
                suffix = LABEL_SUFFIX_BY_TYPE.get(det_type, f" [{str(det_type).upper()}]")
                det_copy = det.copy()
                det_copy['label'] = f"{label_text}{suffix}: {det['confidence']:.2f}"
                hierarchical_detections_to_draw.append(det_copy)
            draw_detections_with_labels(axs[1], hierarchical_detections_to_draw, 'confirmed', draw_labels=DRAW_LABELS)

            fig.tight_layout(pad=1.5)
            save_path = OUTPUT_DIR / f"comparison_{img_path.stem}.png"
            fig.savefig(save_path, dpi=120, bbox_inches='tight')
            saved_count += 1

            plt.show()
            # Close explicitly: many large figures are created in this loop.
            plt.close(fig)

        print(f"\n✅ Inference complete. {saved_count} comparison images saved to '{OUTPUT_DIR}'.")