# Binary Segmentation Inference Pipeline for Path Detection

## Overview
This notebook runs path detection on peatland video. A binary U-Net produces a coarse, temporally smoothed path mask for each frame. The Segment Anything Model (SAM) then refines the mask boundaries, using point prompts derived from that mask.

## Core Components

### 1. Binary U-Net Segmentation
- **Purpose**: Primary path detection
- **Features**:
  * One forward pass per frame (ResNet34 encoder)
  * Temporal smoothing: weighted moving average of logits over the last 10 frames
  * Source of the point prompts used by SAM
- **Output**: Coarse per-pixel path masks

### 2. Segment Anything Model (SAM) Refinement
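- **Purpose**: Boundary refinement of the U-Net path mask
- **Features**:
  * Zero-shot: used without fine-tuning on peatland data
  * Prompted with up to three foreground points derived from the U-Net mask
  * Multi-mask prediction; the candidate that best matches the U-Net mask (highest IoU) is kept
- **Output**: Refined path masks with finer boundaries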

## Pipeline Integration

### Processing Flow
1. Coarse path detection via U-Net
2. Temporal smoothing through a weighted moving average of U-Net logits
3. SAM boundary refinement, prompted with points from the smoothed mask
4. Visualization as a colored overlay on each frame of an output video
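
Steps 1 to 3 can be sketched as a single per-frame function. This is a minimal sketch, not the notebook's code. `process_frame` and its three callables are hypothetical stand-ins for the U-Net, prompt-extraction, and SAM steps defined later. Visualization (step 4) is left out. The sketch shows where the fallback to the last known prompt points sits between detection and refinement.

```python
from collections import deque
from typing import Callable, Optional

import numpy as np

# Hypothetical signatures for the three stages
CoarseFn = Callable[[np.ndarray, deque], np.ndarray]
PromptFn = Callable[[np.ndarray], Optional[np.ndarray]]
RefineFn = Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray]


def process_frame(frame: np.ndarray, state: dict, coarse_mask_fn: CoarseFn,
                  prompt_fn: PromptFn, refine_fn: RefineFn) -> np.ndarray:
    """Return a boolean path mask for one frame (hypothetical wrapper)."""
    # Steps 1-2: coarse, temporally smoothed mask (1 = path)
    coarse = coarse_mask_fn(frame, state["history"])
    # Prompt points from this frame; otherwise fall back to the last known ones
    points = prompt_fn(coarse)
    if points is not None:
        state["last_points"] = points
    else:
        points = state["last_points"]
    # Step 3: refine only when some prompt is available, else return an empty mask
    if points is None:
        return np.zeros(coarse.shape, dtype=bool)
    return refine_fn(frame, points, coarse)
```

The state dictionary plays the role of `logits_history` and `last_known_points` in the video loop.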

### Key Features
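- Fallback to the last known prompt points when the U-Net finds no path in a frame
- Prompt points recomputed from each frame's mask
- Less frame-to-frame flicker, thanks to temporal smoothing
- Offline processing: about 1.8 frames/s on Apple MPS in the run recorded in this notebook, well short of real time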

The system targets path detection for autonomous navigation in complex peatland environments.

In [27]:
from pathlib import Path
import torch
import numpy as np
from tqdm import tqdm
import cv2 # OpenCV for video processing
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from segment_anything import sam_model_registry, SamPredictor
from collections import deque

## 2. System Configuration and Parameters

### Model Configuration

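1. **Binary U-Net Settings**
   - **Model Selection**:
     * `RUN_NAME` selects the training run
     * Weights come from that run's `best_binary_model.pth`
     * The architecture is rebuilt in Section 3 and must match the one used in training
   
   - **Processing Parameters**:
     * Input resolution of 480×640 (set in Section 3)
     * One frame per forward pass (batch size 1)
     * Inference under `torch.no_grad()`, so no gradients are stored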

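2. **SAM Configuration**
   - **Model Variant**:
     * ViT-B (`SAM_MODEL_TYPE = "vit_b"`), the smallest of the three original SAM image encoders
     * Checkpoint read from `SAM_CHECKPOINT_PATH`; it must be downloaded separately
     * Placed on the same device as the U-Net
   
   - **Runtime Settings**:
     * Point prompts only (no box or mask prompts)
     * Up to three foreground points per frame
     * `multimask_output=True`, which yields three candidate masks per frame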

### Processing Pipeline Configuration

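1. **Video Input Settings**
   - **Source Management**:
     * `VIDEO_FILENAME` selects a clip from `../data/video/splits`
     * The frame rate is read from the source and reused for the output video
     * The output resolution matches the source frames
   
   - **Format Handling**:
     * Output encoded with the `mp4v` FourCC
     * Frames converted from OpenCV's BGR order to RGB before U-Net and SAM inference
     * Frames read and processed one at a time rather than loaded all at once

2. **Output Organization**
   - **Directory Structure**:
     * Outputs stored under `metrics/<RUN_NAME>/video_predictions/`, next to the model checkpoint
     * Output files named `predicted_binary_<VIDEO_FILENAME>`
   
   - **File Management**:
     * All paths derived from `RUN_NAME` and `VIDEO_FILENAME`
     * Output directory created if it does not exist
     * An existing output video with the same name is overwritten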

### Device Selection
- Automatic device selection: CUDA if available, then Apple MPS, then CPU
- Both models placed on the selected device

Because all settings live in one cell, a different run or clip can be processed by changing only `RUN_NAME` and `VIDEO_FILENAME`. Outputs also stay grouped with the run that produced them.
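
One way to keep these derived paths in a single place is a small frozen dataclass. The `InferenceConfig` class below is a hypothetical sketch, not part of the notebook. It reproduces the path layout of the configuration cell. `prepare_output_dir` uses `mkdir(parents=True, exist_ok=True)`, so a missing run directory is created as well.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class InferenceConfig:
    """Hypothetical grouping of the settings from the configuration cell."""
    run_name: str
    video_filename: str
    metrics_root: Path = Path("metrics")
    video_root: Path = Path("../data/video/splits")

    @property
    def metrics_dir(self) -> Path:
        return self.metrics_root / self.run_name

    @property
    def model_path(self) -> Path:
        return self.metrics_dir / "best_binary_model.pth"

    @property
    def video_path(self) -> Path:
        return self.video_root / self.video_filename

    @property
    def output_video_path(self) -> Path:
        return self.metrics_dir / "video_predictions" / f"predicted_binary_{self.video_filename}"

    def prepare_output_dir(self) -> None:
        # parents=True also creates metrics/<run_name>/ if it is missing
        self.output_video_path.parent.mkdir(parents=True, exist_ok=True)
```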

In [30]:
# Specify the name of the binary U-Net training run you want to use.
RUN_NAME = "binary_unet_2025-08-03_22-53-32" # <--- CHANGE THIS

# --- SAM Model Configuration ---
SAM_CHECKPOINT_PATH = Path("./metrics/sam_vit_b.pth") # <--- Path to your downloaded SAM model
SAM_MODEL_TYPE = "vit_b"

# Specify which video clip to process
VIDEO_FILENAME = "Clip_1_42s.mp4" # <--- CHANGE THIS

# --- Paths and settings derived automatically ---
METRICS_DIR = Path("metrics") / RUN_NAME
MODEL_PATH = METRICS_DIR / "best_binary_model.pth"
VIDEO_PATH = Path("../data/video/splits") / VIDEO_FILENAME
OUTPUT_DIR = METRICS_DIR / "video_predictions"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # parents=True also creates METRICS_DIR if missing
OUTPUT_VIDEO_PATH = OUTPUT_DIR / f"predicted_binary_{VIDEO_FILENAME}"

if torch.cuda.is_available(): DEVICE = "cuda"
elif torch.backends.mps.is_available(): DEVICE = "mps"
else: DEVICE = "cpu"

print(f"Loading model from run: {RUN_NAME}")
print(f"Using SAM checkpoint: {SAM_CHECKPOINT_PATH}")
print(f"Processing video: {VIDEO_PATH}")
print(f"Using device: {DEVICE}")

Loading model from run: binary_unet_2025-08-03_22-53-32
Using SAM checkpoint: metrics/sam_vit_b.pth
Processing video: ../data/video/splits/Clip_1_42s.mp4
Using device: mps


## 3. Model Loading and Preprocessing Pipeline

### Image Processing Framework

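1. **Input Standardization**
   - **Resolution Control**:
     * Fixed U-Net input size of 480×640 (height × width)
     * `A.Resize` does not preserve aspect ratio, so frames with a different aspect ratio are stretched
     * The predicted mask is resized back to the original frame size
   
   - **Color Processing**:
     * BGR→RGB conversion before the transform (OpenCV reads BGR)
     * ImageNet mean/std normalization
     * Pixel values scaled by `max_pixel_value=255.0` before normalization

2. **Transform Pipeline**
   - **Albumentations Integration**:
     * Only `Resize`, `Normalize`, and `ToTensorV2`; no random augmentations at inference
     * Transforms run on the CPU on NumPy arrays; the resulting tensor is then moved to `DEVICE`
     * `ToTensorV2` converts the HWC image to a CHW tensor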
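
The normalization part of the transform pipeline reduces to simple per-channel arithmetic. The NumPy sketch below omits resizing, and `normalize_to_chw` is a hypothetical helper. It applies the formula used by Albumentations' `Normalize`, `(x - mean * max_pixel_value) / (std * max_pixel_value)`, followed by the HWC→CHW transpose that `ToTensorV2` performs.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_to_chw(image_rgb: np.ndarray, max_pixel_value: float = 255.0) -> np.ndarray:
    """Per-channel ImageNet normalization, then HWC -> CHW (resizing omitted)."""
    x = image_rgb.astype(np.float32)
    # Equivalent to (x / 255 - mean) / std for max_pixel_value = 255
    x = (x - IMAGENET_MEAN * max_pixel_value) / (IMAGENET_STD * max_pixel_value)
    return np.ascontiguousarray(x.transpose(2, 0, 1))
```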

### Model Architecture Setup

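1. **Binary U-Net Configuration**
   - **Architecture Details**:
     * ResNet34 encoder; `encoder_weights=None` because all weights come from the checkpoint
     * U-Net decoder with skip connections from the encoder
     * Two output channels (background, path), resolved per pixel with `argmax`
   
   - **Initialization Process**:
     * `load_state_dict` is strict by default, so mismatched checkpoint keys raise an error
     * `map_location` loads the weights directly onto `DEVICE`
     * `eval()` puts BatchNorm layers into inference mode

2. **SAM Integration**
   - **Model Setup**:
     * `sam_model_registry[SAM_MODEL_TYPE]` builds ViT-B SAM and loads the checkpoint
     * The model is moved to `DEVICE`
   
   - **Predictor Configuration**:
     * `SamPredictor` wraps the model: `set_image` computes the image embedding, and `predict` decodes masks from point prompts
     * Multi-mask output
     * Used zero-shot, without fine-tuning on peatland data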
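
The U-Net has two output channels, so taking `argmax` is the same as thresholding the softmax probability of the path class at 0.5. That probability equals a sigmoid of the logit difference. The NumPy sketch below checks this equivalence. The helper names are hypothetical, and the sketch assumes channel 0 is background and channel 1 is path, matching `PATH_CLASS_ID = 1`. `path_probability` could be used if a threshold other than 0.5 were ever needed.

```python
import numpy as np


def path_mask_from_two_class_logits(logits: np.ndarray) -> np.ndarray:
    """logits: (2, H, W), channel 0 = background, channel 1 = path."""
    return np.argmax(logits, axis=0).astype(np.uint8)


def path_probability(logits: np.ndarray) -> np.ndarray:
    """Softmax probability of the path class, written as sigmoid(l1 - l0)."""
    diff = logits[1] - logits[0]
    return 1.0 / (1.0 + np.exp(-diff))
```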

### System Integration
- Both models are loaded once, before any frames are processed
- Both are placed on the same `DEVICE`
- The U-Net is switched to evaluation mode with `eval()`

This setup keeps model loading out of the per-frame loop, which then only runs inference.

In [31]:
IMG_HEIGHT = 480
IMG_WIDTH = 640
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# We only need the basic transforms for inference
unet_transform = A.Compose([
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0),
    ToTensorV2(),
])

unet_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=2).to(DEVICE)
unet_model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(DEVICE)))
unet_model.eval()
print("U-Net model loaded successfully.")

sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)
print("SAM model loaded successfully.")

U-Net model loaded successfully.
SAM model loaded successfully.


## 4. Core Processing Functions

### Path Detection Pipeline Components

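1. **Smoothed U-Net Prediction**
   
   **Function**: `get_smoothed_unet_prediction`
   - **Input Processing**:
     * BGR→RGB conversion (OpenCV reads frames as BGR)
     * Resize to 480×640 and ImageNet normalization via `unet_transform`
     * Batch dimension added and tensor moved to `DEVICE`
   
   - **Prediction Pipeline**:
     * U-Net forward pass under `torch.no_grad()`
     * Current logits appended to `logits_history`
     * Weighted average of the stored logits, with weights rising linearly from 0.1 (oldest stored frame) to 1.0 (newest), normalized to sum to 1
   
   - **Output Processing**:
     * `argmax` over the two class channels
     * Resize back to the original frame size with nearest-neighbor interpolation, so labels stay 0 or 1
     * Returns a `uint8` mask in which `PATH_CLASS_ID` (1) marks path pixels

2. **Prompt Generation**
   
   **Function**: `mask_to_prompt_points`
   - **Contour Analysis**:
     * External contours of the path pixels (`cv2.RETR_EXTERNAL`)
     * Contours sorted by area, largest first
     * At most `max_points` (3) contours kept
   
   - **Point Generation**:
     * One centroid per kept contour, from the image moments `m10/m00` and `m01/m00`
     * Contours with zero area (`m00 == 0`) skipped
     * Returns an `(N, 2)` array of `(x, y)` points, or `None` if no usable contour exists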
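
The centroid used as a SAM prompt comes from image moments. `m00` is the area, and `(m10/m00, m01/m00)` is the mean `(x, y)` position. For a non-convex region, that point can lie outside the region itself, which would place a foreground prompt on background. The NumPy sketch below computes pixel-based moments on a hypothetical L-shaped path. `cv2.moments` on a contour integrates over the polygon instead, so its values can differ slightly.

```python
from typing import Optional, Tuple

import numpy as np


def region_centroid(region: np.ndarray) -> Optional[Tuple[int, int]]:
    """Integer (x, y) centroid of a binary region from its raw moments, or None if empty."""
    ys, xs = np.nonzero(region)
    m00 = xs.size          # area: number of region pixels
    if m00 == 0:
        return None        # same role as the notebook's `M["m00"] > 0` check
    m10 = xs.sum()         # first moment in x
    m01 = ys.sum()         # first moment in y
    return int(m10 / m00), int(m01 / m00)


# Hypothetical L-shaped path: a vertical strip on the left joined to a horizontal strip at the bottom
mask = np.zeros((10, 10), dtype=np.uint8)
mask[0:10, 0:2] = 1
mask[8:10, 0:10] = 1

cx, cy = region_centroid(mask)
print((cx, cy), mask[cy, cx])  # (2, 6) 0 -> the centroid is not a path pixel
```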

### Advanced Features

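1. **Temporal Consistency**
   - **Moving Average Implementation**:
     * Recent frames weigh more than older ones
     * Logits are averaged before `argmax`, so a single-frame misdetection is damped rather than copied into the mask
     * The mask changes more smoothly from frame to frame than a per-frame `argmax` would
   
   - **State Management**:
     * `logits_history` is a `deque(maxlen=10)`, so the oldest logits are dropped automatically
     * The history is updated on every frame
     * Stored tensors stay on `DEVICE`; ten 2×480×640 float32 logit maps take about 25 MB

2. **Robustness Mechanisms**
   - **Fallback Strategies**:
     * `mask_to_prompt_points` returns `None` when no usable contour is found
     * The video loop then reuses `last_known_points` from the most recent successful frame
     * These points persist until a new detection replaces them; there is no expiry
   
   - **Quality Checks**:
     * Zero-area contours are skipped
     * Contours are ranked by area, but there is no minimum-area threshold, so a small spurious region can still become a prompt if it is among the largest three
     * Consistency across frames comes only from the logit smoothing; there is no explicit check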
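
`get_smoothed_unet_prediction` computes its weighted moving average with a Python loop, but the same average can be written as a single broadcast. This NumPy sketch (`weighted_logit_average` is a hypothetical name) uses the same weights, `linspace(0.1, 1.0, n)` normalized to sum to 1, on per-frame logits of shape `(C, H, W)`.

```python
from typing import List

import numpy as np


def weighted_logit_average(history: List[np.ndarray]) -> np.ndarray:
    """Weighted mean of per-frame logits (each (C, H, W)), newest frame last."""
    weights = np.linspace(0.1, 1.0, len(history))
    weights = weights / weights.sum()                # normalize so the weights sum to 1
    stacked = np.stack(history)                      # (n, C, H, W)
    # Reshape weights to (n, 1, 1, 1) so each frame's weight broadcasts over C, H, W
    return (weights[:, None, None, None] * stacked).sum(axis=0)


# Nine frames favor "path" (channel 1 higher); the newest frame favors "background"
path_vote = np.array([[[0.0]], [[1.0]]])
background_vote = np.array([[[1.0]], [[0.0]]])
avg = weighted_logit_average([path_vote] * 9 + [background_vote])
print(np.argmax(avg, axis=0)[0, 0])  # 1: the pixel stays labeled as path
```

With a full 10-frame history, the newest frame's weight is 1.0 / 5.5 ≈ 0.18. One contradicting frame is therefore not enough to flip a pixel that the previous nine frames agreed on.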

### System Integration
- Both functions are called once per frame from the video loop in Section 5
- The only state carried between frames is `logits_history` and `last_known_points`

Together, these two functions turn each frame into a coarse path mask and a set of SAM prompt points.

In [32]:
PATH_COLOR_BGR = (152, 16, 60) # Purple
PATH_CLASS_ID = 1

def get_smoothed_unet_prediction(frame: np.ndarray, logits_history: deque):
    """
    Gets a U-Net prediction and applies a weighted moving average over the
    last N frames for temporal smoothing.
    """
    image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    augmented = unet_transform(image=image_rgb)
    image_tensor = augmented['image'].unsqueeze(0).to(DEVICE)
    
    with torch.no_grad():
        current_logits = unet_model(image_tensor)
    
    # Add the latest prediction to our history
    logits_history.append(current_logits)
    
    # --- Weighted Moving Average ---
    # Create weights that give more importance to recent frames
    weights = torch.linspace(0.1, 1.0, len(logits_history)).to(DEVICE)
    weights = weights / weights.sum() # Normalize weights to sum to 1
    
    # Weighted sum of the logits in the history. Each weights[i] is a scalar
    # tensor, so it broadcasts over the (1, C, H, W) logits without reshaping.
    weighted_sum = torch.zeros_like(current_logits)
    for i, logits in enumerate(logits_history):
        weighted_sum += logits * weights[i]
        
    preds = torch.argmax(weighted_sum, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
    final_mask = cv2.resize(preds, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
    
    return final_mask

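def mask_to_prompt_points(mask: np.ndarray, max_points: int = 3):
    """
    Finds the external contours of the path in the U-Net mask and returns the
    centroids of up to `max_points` largest ones as (x, y) points, or None if
    no usable contour exists. These points are the prompts for SAM.
    """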
    # Find all contours of the path
    contours, _ = cv2.findContours((mask == PATH_CLASS_ID).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    if not contours:
        return None # No path found by U-Net
        
    sorted_contours = sorted(contours, key=cv2.contourArea, reverse=True)
    
    points = []
    for contour in sorted_contours[:max_points]:
        M = cv2.moments(contour)
        if M["m00"] > 0:
            center_x = int(M["m10"] / M["m00"])
            center_y = int(M["m01"] / M["m00"])
            points.append([center_x, center_y])
            
    return np.array(points) if points else None
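
The centroid of a non-convex contour, such as an L-shaped or curving path, can fall outside the path. SAM would then receive a foreground prompt on background. One possible guard is to snap such a point to the nearest path pixel. The sketch below is a hypothetical helper, not used by this notebook. For simplicity it snaps to any path pixel; snapping to the pixels of the contour the point came from would keep each prompt on its own region. `cv2.distanceTransform` is another option: it can pick the pixel deepest inside the region.

```python
from typing import Tuple

import numpy as np


def snap_to_region(region: np.ndarray, point: Tuple[int, int]) -> Tuple[int, int]:
    """Return `point` (x, y) if it is a region pixel, otherwise the nearest region pixel.

    Assumes `region` contains at least one nonzero pixel.
    """
    x, y = point
    h, w = region.shape
    if 0 <= x < w and 0 <= y < h and region[y, x]:
        return x, y
    ys, xs = np.nonzero(region)
    dist_sq = (xs - x) ** 2 + (ys - y) ** 2   # squared Euclidean distance to every region pixel
    i = int(np.argmin(dist_sq))
    return int(xs[i]), int(ys[i])


# Example: the pixel centroid (2, 6) of this L-shaped region lies outside it
mask = np.zeros((10, 10), dtype=np.uint8)
mask[0:10, 0:2] = 1
mask[8:10, 0:10] = 1
print(snap_to_region(mask, (2, 6)))  # (1, 6)
```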

## 5. Video Processing Pipeline

### Processing Framework

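1. **Frame Acquisition and Management**
   - **Video Stream Handling**:
     * `cv2.VideoCapture` reads one frame at a time
     * Width, height, frame rate, and frame count are read from the source
     * The output `cv2.VideoWriter` uses the same resolution and frame rate
     * The loop stops early if a frame cannot be read, since container frame counts can be inaccurate
   
   - **State Management**:
     * `logits_history = deque(maxlen=10)` for temporal smoothing
     * `last_known_points = None` until the first detection

2. **Path Detection Workflow**

   **Initial Detection Phase**:
   - Smoothed U-Net mask from `get_smoothed_unet_prediction`
   - Up to three prompt points from `mask_to_prompt_points`
   - Fallback to `last_known_points` if the current frame yields none
   - `last_known_points` updated whenever new points are found
   
   **Refinement Phase**:
   - `sam_predictor.set_image` on the RGB frame
   - `sam_predictor.predict` with all points labeled as foreground (label 1)
   - `multimask_output=True`, returning three candidate masks
   - The candidate with the highest IoU against the U-Net path mask is kept

3. **Robustness Mechanisms**

   **State Handling**:
   - Prompt points carried over from earlier frames when the U-Net misses the path
   - If no points have been found yet, SAM is skipped and the frame is written with an empty mask
   
   **Quality Assurance**:
   - IoU-based selection keeps the SAM candidate most consistent with the U-Net prediction
   - `tqdm` reports progress and throughput

4. **Visualization Pipeline**

   **Mask Rendering**:
   - Path pixels set to `PATH_COLOR_BGR` (purple) in an otherwise black overlay
   
   **Frame Composition**:
   - `cv2.addWeighted(frame, 1.0, overlay, 0.6, 0)` adds 60% of the overlay color to path pixels and leaves other pixels unchanged
   - Because the frame keeps weight 1.0, this is an additive tint that saturates on bright pixels rather than a true alpha blend
   - Composed frames are written to the output video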
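
The IoU-based selection loop in the video cell can be vectorized over the candidate masks. The NumPy sketch below (`select_by_iou` is a hypothetical name) computes intersection over union between each of the K candidates and the U-Net path mask. `np.argmax` returns the first maximum, which matches the loop's strict `>` comparison on ties.

```python
from typing import Tuple

import numpy as np


def select_by_iou(candidates: np.ndarray, reference: np.ndarray) -> Tuple[int, float]:
    """Index and IoU of the candidate mask (K, H, W) that best overlaps reference (H, W)."""
    ref = reference.astype(bool)
    cand = candidates.astype(bool)
    inter = np.logical_and(cand, ref).sum(axis=(1, 2))   # per-candidate intersection
    union = np.logical_or(cand, ref).sum(axis=(1, 2))    # per-candidate union
    iou = np.where(union > 0, inter / np.maximum(union, 1), 0.0)
    best = int(np.argmax(iou))
    return best, float(iou[best])
```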

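### System Integration
- Frames are processed and written one at a time, so memory use does not grow with video length beyond the fixed 10-frame logits history
- SAM's `set_image` runs the ViT-B image encoder on every frame that has prompt points, which is likely the main per-frame cost
- In the run below, the full pipeline averaged about 1.8 frames/s on Apple MPS, so it processes recorded video offline rather than in real time

The output is a copy of the input clip with the refined path mask overlaid on each frame.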
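
The arithmetic below relates the measured throughput to a real-time budget. The 1287 frames and 1.80 it/s come from the run log in this section. The 30 fps target is an assumption about the camera rate, not a value read from the notebook. `throughput_summary` is a hypothetical helper.

```python
def throughput_summary(n_frames: int, frames_per_s: float, target_fps: float) -> dict:
    """Per-frame latency, total runtime, and the speed-up needed to reach target_fps."""
    return {
        "ms_per_frame": 1000.0 / frames_per_s,
        "total_s": n_frames / frames_per_s,
        "budget_ms": 1000.0 / target_fps,
        "speedup_needed": target_fps / frames_per_s,
    }


summary = throughput_summary(1287, 1.80, 30.0)  # 30 fps is an assumed target rate
print(summary)
```

715 s is 11 min 55 s. This is consistent with the 11:56 shown by `tqdm`, since the displayed rate is rounded.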

In [33]:
cap = cv2.VideoCapture(str(VIDEO_PATH))
if not cap.isOpened():
    print(f"Error: Could not open video file {VIDEO_PATH}")
else:
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)  # float, so fractional rates such as 29.97 are kept
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(str(OUTPUT_VIDEO_PATH), fourcc, fps, (frame_width, frame_height))

    print(f"Processing {total_frames} frames...")

    last_known_points = None
    logits_history = deque(maxlen=10) 
    for _ in tqdm(range(total_frames)):
        ret, frame = cap.read()
        if not ret:
            break
            
        # 1. Get coarse prediction from U-Net
        unet_mask = get_smoothed_unet_prediction(frame, logits_history)
        
        # 2. Convert mask into multiple point prompts
        current_points = mask_to_prompt_points(unet_mask, max_points=3)
        
        # 3. Fallback logic and state update
        prompt_points_for_sam = None
        if current_points is not None:
            prompt_points_for_sam = current_points
            last_known_points = current_points # Update our state
        elif last_known_points is not None:
            prompt_points_for_sam = last_known_points # Use the fallback
        
        final_mask = np.zeros_like(unet_mask, dtype=bool)

        # 4. If we have any points to use, run SAM
        if prompt_points_for_sam is not None:
            sam_predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            
            # Create a label for each point (all are foreground prompts)
            input_labels = np.ones(len(prompt_points_for_sam))
            
            masks_sam, scores, _ = sam_predictor.predict(
                point_coords=prompt_points_for_sam,
                point_labels=input_labels,
                multimask_output=True,
            )

            # Select the best mask based on IoU with the U-Net prediction
            best_iou = -1
            best_mask_idx = -1
            unet_path_mask = (unet_mask == PATH_CLASS_ID)

            for i, mask in enumerate(masks_sam):
                intersection = np.logical_and(unet_path_mask, mask).sum()
                union = np.logical_or(unet_path_mask, mask).sum()
                iou = intersection / union if union > 0 else 0
                
                if iou > best_iou:
                    best_iou = iou
                    best_mask_idx = i
            
            if best_mask_idx != -1:
                final_mask = masks_sam[best_mask_idx]

        # 5. Create the visualization
        overlay = np.zeros_like(frame, dtype=np.uint8)
        overlay[final_mask] = PATH_COLOR_BGR
        final_frame = cv2.addWeighted(frame, 1.0, overlay, 0.6, 0)
        
        out.write(final_frame)
        
    cap.release()
    out.release()
    print("\nProcessing complete.")
    print(f"Output video saved to: {OUTPUT_VIDEO_PATH}")

Processing 1287 frames...


100%|██████████| 1287/1287 [11:56<00:00,  1.80it/s]


Processing complete.
Output video saved to: metrics/binary_unet_2025-08-03_22-53-32/video_predictions/predicted_binary_Clip_1_42s.mp4
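
`cv2.addWeighted(frame, 1.0, overlay, 0.6, 0)` adds 60% of the path color on top of the unchanged frame, so bright pixels saturate. For example, a blue channel of 200 becomes 200 + 0.6·152 ≈ 291, which is clipped to 255. For a true semi-transparent overlay, the sketch below mixes frame and color as `(1 - alpha) * frame + alpha * color` on path pixels only. `blend_overlay` is a hypothetical NumPy helper, not the notebook's code.

```python
from typing import Tuple

import numpy as np


def blend_overlay(frame: np.ndarray, mask: np.ndarray,
                  color_bgr: Tuple[int, int, int], alpha: float = 0.6) -> np.ndarray:
    """Alpha-blend `color_bgr` into `frame` on mask pixels; other pixels are unchanged."""
    out = frame.astype(np.float32)
    color = np.asarray(color_bgr, dtype=np.float32)
    m = mask.astype(bool)
    out[m] = (1.0 - alpha) * out[m] + alpha * color
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)
```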



