# Binary Segmentation Inference with SAM Refinement

This notebook implements a two-stage path detection pipeline that combines:

1. **Binary U-Net Segmentation**:
   - Initial path detection
   - Temporally smoothed predictions
   - Fast, real-time capable processing

2. **Segment Anything Model (SAM)**:
   - Refinement of U-Net predictions
   - High-precision boundary detection
   - Point-prompt based segmentation

The pipeline processes video input and generates path annotations with:
- Temporal consistency through a weighted average of U-Net logits over the most recent frames
- Fallback to the last known prompt points when U-Net finds no path in a frame
- Semi-transparent path overlays for visualization

## 1. Required Libraries

In [27]:
from pathlib import Path
import torch
import numpy as np
from tqdm import tqdm
import cv2 # OpenCV for video processing
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from segment_anything import sam_model_registry, SamPredictor
from collections import deque

## 2. Configuration

Key settings for the inference pipeline:

### Model Settings
- **Binary U-Net**: Trained model selection via run name
- **SAM Model**: Pre-trained ViT-B checkpoint configuration
- **Device**: Automatic hardware acceleration selection

### Video Processing
- Input video selection from splits directory
- Output directory organization by run name
- Automatic file path handling

### Output Configuration
- Structured directory creation
- Consistent naming conventions
- MP4 format video output

In [30]:
# Specify the name of the binary U-Net training run you want to use.
RUN_NAME = "binary_unet_2025-08-03_22-53-32" # <--- CHANGE THIS

# --- SAM Model Configuration ---
SAM_CHECKPOINT_PATH = Path("./metrics/sam_vit_b.pth") # <--- Path to your downloaded SAM model
SAM_MODEL_TYPE = "vit_b"

# Specify which video clip to process
VIDEO_FILENAME = "Clip_1_42s.mp4" # <--- CHANGE THIS

# --- Paths and settings derived automatically ---
METRICS_DIR = Path("metrics") / RUN_NAME
MODEL_PATH = METRICS_DIR / "best_binary_model.pth"
VIDEO_PATH = Path("../data/video/splits") / VIDEO_FILENAME
OUTPUT_DIR = METRICS_DIR / "video_predictions"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # also creates missing parent directories
OUTPUT_VIDEO_PATH = OUTPUT_DIR / f"predicted_binary_{VIDEO_FILENAME}"

# Prefer CUDA, then Apple MPS, then fall back to CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

print(f"Loading model from run: {RUN_NAME}")
print(f"Using SAM checkpoint: {SAM_CHECKPOINT_PATH}")
print(f"Processing video: {VIDEO_PATH}")
print(f"Using device: {DEVICE}")

Loading model from run: binary_unet_2025-08-03_22-53-32
Using SAM checkpoint: metrics/sam_vit_b.pth
Processing video: ../data/video/splits/Clip_1_42s.mp4
Using device: mps
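
Every path in the configuration cell is derived from `RUN_NAME` and `VIDEO_FILENAME`, so a typo in either only surfaces later, when a model or the video fails to load. A small pre-flight check can report all missing files at once. The sketch below is one possible way to do this; the helper name `find_missing_inputs` is hypothetical and not part of the notebook.

```python
from pathlib import Path


def find_missing_inputs(paths: dict) -> list:
    """Return the labels of all entries that do not point to an existing file."""
    return [label for label, path in paths.items() if not Path(path).is_file()]


# Possible usage with the names defined in the configuration cell:
# missing = find_missing_inputs({
#     "U-Net weights": MODEL_PATH,
#     "SAM checkpoint": SAM_CHECKPOINT_PATH,
#     "input video": VIDEO_PATH,
# })
# if missing:
#     raise FileNotFoundError(f"Missing inputs: {', '.join(missing)}")
```

Collecting the labels instead of raising on the first failure means one run is enough to see everything that needs fixing.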


## 3. Model Loading and Preprocessing

### Image Processing Pipeline
- Fixed input dimensions (480x640)
- ImageNet normalization
- Albumentations-based transformations

### Model Initialization
1. **Binary U-Net**:
   - ResNet34 backbone
   - Two-class segmentation (background and path)
   - Weights loaded with `map_location`, so the same checkpoint runs on CUDA, MPS, or CPU

2. **SAM Model**:
   - ViT-B architecture
   - Zero-shot capability
   - Point prompt interface
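
To make the preprocessing concrete, the sketch below reproduces in plain NumPy what the normalization and tensor-conversion steps do to a frame that has already been resized to 480x640. It is an illustration only: the function name `normalize_to_chw` is hypothetical, the resize step is left out, and the notebook itself relies on Albumentations for all of this.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_to_chw(image_rgb: np.ndarray) -> np.ndarray:
    """Scale a uint8 HxWx3 RGB image to [0, 1], standardize each channel
    with the ImageNet statistics, and reorder to channels-first (3xHxW)."""
    scaled = image_rgb.astype(np.float32) / 255.0
    standardized = (scaled - IMAGENET_MEAN) / IMAGENET_STD
    return np.transpose(standardized, (2, 0, 1))


# A mid-gray frame stands in for an already-resized video frame.
dummy = np.full((480, 640, 3), 128, dtype=np.uint8)
chw = normalize_to_chw(dummy)
print(chw.shape, chw.dtype)
```

The channel order matters here: the statistics are listed as R, G, B, which is why the inference code converts OpenCV's BGR frames to RGB before applying the transform.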

In [31]:
IMG_HEIGHT = 480
IMG_WIDTH = 640
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# We only need the basic transforms for inference
unet_transform = A.Compose([
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0),
    ToTensorV2(),
])

unet_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=2).to(DEVICE)
unet_model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(DEVICE)))
unet_model.eval()
print("U-Net model loaded successfully.")

sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)
print("SAM model loaded successfully.")

U-Net model loaded successfully.
SAM model loaded successfully.


## 4. Utility Functions

### Path Detection Pipeline

1. **Smoothed U-Net Prediction** (`get_smoothed_unet_prediction`):
   - Processes RGB frames through U-Net
   - Implements temporal smoothing via weighted averaging
   - Handles resolution matching

2. **Prompt Generation** (`mask_to_prompt_points`):
   - Extracts the centroids of the largest path contours in the U-Net mask
   - Returns up to `max_points` point prompts
   - Returns `None` when no path contour is found, so the caller can fall back to earlier points

### Key Features
- Temporal consistency through moving average
- Adaptive prompt generation
- Robust contour analysis
- Automatic resolution handling
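
The weighting scheme used for temporal smoothing can be examined in isolation. The sketch below mirrors it in NumPy, with small arrays standing in for logits; the helper names `recency_weights` and `weighted_average` are hypothetical and exist only for this illustration.

```python
from collections import deque

import numpy as np


def recency_weights(n: int) -> np.ndarray:
    """Linearly increasing weights from 0.1 (oldest) to 1.0 (newest), normalized to sum to 1."""
    weights = np.linspace(0.1, 1.0, n)
    return weights / weights.sum()


def weighted_average(history: deque) -> np.ndarray:
    """Weighted average of equally shaped arrays stored oldest first."""
    weights = recency_weights(len(history))
    stacked = np.stack(list(history))               # shape: (n, ...)
    return np.tensordot(weights, stacked, axes=1)   # contracts the history axis


history = deque(maxlen=3)
for value in (0.0, 0.0, 0.0, 1.0):   # the fourth append evicts the oldest entry
    history.append(np.full((2, 2), value))

smoothed = weighted_average(history)
```

With a full window of 10 frames, the unnormalized weights sum to 5.5, so the newest frame contributes 1/5.5 (about 18%) of the average and the oldest about 1.8%. With a single frame in the history, the only weight normalizes to 1, so the first prediction passes through unchanged.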

In [32]:
PATH_COLOR_BGR = (152, 16, 60) # Purple
PATH_CLASS_ID = 1

def get_smoothed_unet_prediction(frame: np.ndarray, logits_history: deque):
    """
    Gets a U-Net prediction and applies a weighted moving average over the
    last N frames for temporal smoothing.
    """
    image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    augmented = unet_transform(image=image_rgb)
    image_tensor = augmented['image'].unsqueeze(0).to(DEVICE)
    
    with torch.no_grad():
        current_logits = unet_model(image_tensor)
    
    # Add the latest prediction to our history
    logits_history.append(current_logits)
    
    # --- Weighted Moving Average ---
    # Create weights that give more importance to recent frames
    weights = torch.linspace(0.1, 1.0, len(logits_history)).to(DEVICE)
    weights = weights / weights.sum() # Normalize weights to sum to 1
    
    # Calculate the weighted average of the logits in the history.
    # Each weights[i] is a scalar tensor, so it broadcasts over the full logits tensor.
    weighted_sum = torch.zeros_like(current_logits)
    for i, logits in enumerate(logits_history):
        weighted_sum += logits * weights[i]
        
    preds = torch.argmax(weighted_sum, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
    final_mask = cv2.resize(preds, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
    
    return final_mask

def mask_to_prompt_points(mask: np.ndarray, max_points: int = 3):
    """
    Finds the largest contours (up to max_points) in the U-Net mask and returns
    their centroids as an (N, 2) array of (x, y) points, or None if no path is found.
    These points are the prompts for SAM.
    """
    # Find all contours of the path
    contours, _ = cv2.findContours((mask == PATH_CLASS_ID).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    if not contours:
        return None # No path found by U-Net
        
    sorted_contours = sorted(contours, key=cv2.contourArea, reverse=True)
    
    points = []
    for contour in sorted_contours[:max_points]:
        M = cv2.moments(contour)
        if M["m00"] > 0:
            center_x = int(M["m10"] / M["m00"])
            center_y = int(M["m01"] / M["m00"])
            points.append([center_x, center_y])
            
    return np.array(points) if points else None
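
One property of centroid prompts is worth keeping in mind: for a curved or U-shaped region, the centroid can fall outside the region itself, in which case the foreground prompt lands off the path. The sketch below illustrates the effect and one possible remedy, snapping the point to the nearest mask pixel. Both helpers (`raster_centroid`, `snap_to_mask`) are hypothetical and not part of the notebook, and the centroid here is computed from pixels rather than from contour moments, so it only approximates what `cv2.moments` returns.

```python
from typing import Optional, Tuple

import numpy as np


def raster_centroid(mask: np.ndarray) -> Optional[Tuple[int, int]]:
    """Centroid (x, y) of the True pixels in a 2-D boolean mask, or None if empty."""
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return None
    return int(xs.mean()), int(ys.mean())


def snap_to_mask(point: Tuple[int, int], mask: np.ndarray) -> Tuple[int, int]:
    """Return `point` if it lies on the mask, otherwise the nearest mask pixel."""
    x, y = point
    if mask[y, x]:
        return point
    ys, xs = np.nonzero(mask)
    nearest = np.argmin((xs - x) ** 2 + (ys - y) ** 2)
    return int(xs[nearest]), int(ys[nearest])


# U-shaped region: its centroid falls in the empty middle.
u_mask = np.zeros((5, 5), dtype=bool)
u_mask[:, 0] = True
u_mask[:, 4] = True
u_mask[4, :] = True

centroid = raster_centroid(u_mask)
prompt = snap_to_mask(centroid, u_mask)
```

Whether this correction is needed depends on the footage; for mostly straight paths the centroid usually lies on the path already.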

## 5. Video Processing Pipeline

The main processing loop runs the following steps on every frame:

### Frame Processing Steps
1. **Initial Detection**:
   - U-Net prediction
   - Temporal smoothing application
   
2. **SAM Refinement**:
   - Point prompt generation
   - Multiple mask prediction
   - Best mask selection via IoU

3. **State Management**:
   - Path point tracking
   - Fallback handling
   - Temporal consistency

4. **Visualization**:
   - Semi-transparent overlays
   - Path highlighting
   - Frame composition

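When U-Net finds no path in a frame, the loop reuses the last known prompt points so that SAM can still produce a mask. If no points have been found yet, the frame is written without an overlay.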
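
The mask-selection step can be sketched on its own. SAM returns several candidate masks when `multimask_output=True`, and the loop keeps the one that overlaps best with the U-Net path mask. The helper names `iou` and `select_best_mask` below are hypothetical; the notebook performs the same comparison inline.

```python
import numpy as np


def iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection over union of two boolean masks; 0.0 if both are empty."""
    intersection = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return float(intersection / union) if union > 0 else 0.0


def select_best_mask(candidates: np.ndarray, reference: np.ndarray) -> int:
    """Index of the candidate with the highest IoU against the reference (first one on ties)."""
    scores = [iou(candidate, reference) for candidate in candidates]
    return int(np.argmax(scores))


# Reference: a 2x2 block in the corner of a 4x4 grid.
reference = np.zeros((4, 4), dtype=bool)
reference[:2, :2] = True

too_large = np.ones((4, 4), dtype=bool)   # covers everything
close = reference.copy()
close[2, 0] = True                        # one extra pixel
empty = np.zeros((4, 4), dtype=bool)

best = select_best_mask(np.stack([too_large, close, empty]), reference)
```

Selecting by agreement with U-Net rather than by SAM's own scores keeps the refined mask tied to the class the U-Net was trained on. One consequence: when the fallback points are used because U-Net found nothing, every candidate has an IoU of 0 and the first mask is chosen.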

In [33]:
cap = cv2.VideoCapture(str(VIDEO_PATH))
if not cap.isOpened():
    print(f"Error: Could not open video file {VIDEO_PATH}")
else:
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)  # keep fractional rates such as 29.97 instead of truncating
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(str(OUTPUT_VIDEO_PATH), fourcc, fps, (frame_width, frame_height))

    print(f"Processing {total_frames} frames...")

    last_known_points = None
    logits_history = deque(maxlen=10) 
    for _ in tqdm(range(total_frames)):
        ret, frame = cap.read()
        if not ret:
            break
            
        # 1. Get coarse prediction from U-Net
        unet_mask = get_smoothed_unet_prediction(frame, logits_history)
        
        # 2. Convert mask into multiple point prompts
        current_points = mask_to_prompt_points(unet_mask, max_points=3)
        
        # 3. Fallback logic and state update
        prompt_points_for_sam = None
        if current_points is not None:
            prompt_points_for_sam = current_points
            last_known_points = current_points # Update our state
        elif last_known_points is not None:
            prompt_points_for_sam = last_known_points # Use the fallback
        
        final_mask = np.zeros_like(unet_mask, dtype=bool)

        # 4. If we have any points to use, run SAM
        if prompt_points_for_sam is not None:
            sam_predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            
            # Create a label for each point (all are foreground prompts)
            input_labels = np.ones(len(prompt_points_for_sam))
            
            masks_sam, scores, _ = sam_predictor.predict(
                point_coords=prompt_points_for_sam,
                point_labels=input_labels,
                multimask_output=True,
            )

            # Select the best mask based on IoU with the U-Net prediction
            best_iou = -1
            best_mask_idx = -1
            unet_path_mask = (unet_mask == PATH_CLASS_ID)

            for i, mask in enumerate(masks_sam):
                intersection = np.logical_and(unet_path_mask, mask).sum()
                union = np.logical_or(unet_path_mask, mask).sum()
                iou = intersection / union if union > 0 else 0
                
                if iou > best_iou:
                    best_iou = iou
                    best_mask_idx = i
            
            if best_mask_idx != -1:
                final_mask = masks_sam[best_mask_idx]

        # 5. Create the visualization
        overlay = np.zeros_like(frame, dtype=np.uint8)
        overlay[final_mask] = PATH_COLOR_BGR
        final_frame = cv2.addWeighted(frame, 1.0, overlay, 0.6, 0)
        
        out.write(final_frame)

        
    cap.release()
    out.release()
    print("\nProcessing complete.")
    print(f"Output video saved to: {OUTPUT_VIDEO_PATH}")

Processing 1287 frames...


100%|██████████| 1287/1287 [11:56<00:00,  1.80it/s]


Processing complete.
Output video saved to: metrics/binary_unet_2025-08-03_22-53-32/video_predictions/predicted_binary_Clip_1_42s.mp4
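
A note on the visualization step: `cv2.addWeighted(frame, 1.0, overlay, 0.6, 0)` keeps the frame at full weight and adds 60% of the path color on top, so masked pixels get brighter rather than being mixed as in a conventional alpha blend. The NumPy sketch below shows the same arithmetic under the assumption that rounding details do not matter; the helper name `blend_path_overlay` is hypothetical.

```python
import numpy as np

PATH_COLOR_BGR = (152, 16, 60)


def blend_path_overlay(frame, mask, color=PATH_COLOR_BGR, strength=0.6):
    """Add `strength` times `color` to the masked pixels of a uint8 BGR frame,
    clipping the result to the valid 0..255 range."""
    overlay = np.zeros(frame.shape, dtype=np.float32)
    overlay[mask] = color
    blended = frame.astype(np.float32) + strength * overlay
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


frame = np.full((2, 2, 3), 100, dtype=np.uint8)
mask = np.zeros((2, 2), dtype=bool)
mask[0, 0] = True

result = blend_path_overlay(frame, mask)
```

Because the blend is additive, already bright pixels saturate at 255 and the tint becomes less visible there. If that turns out to be a problem, a conventional blend that weights the frame by `1 - strength` inside the mask is one alternative.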



