In [1]:
import torch
import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm
from ultralytics import YOLO
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from collections import deque
from segment_anything import sam_model_registry, SamPredictor

  from .autonotebook import tqdm as notebook_tqdm


In [19]:
# --- Model checkpoints (change these to load a different training run) ---
SEGMENTATION_RUN_NAME = "binary_unet_2025-08-03_22-53-32"  # binary U-Net run to load
DETECTION_RUN_NAME = "finetuned"  # fine-tuned detector run to load
SAM_CHECKPOINT_PATH = Path("sam_vit_b.pth")  # locally downloaded SAM weights
SAM_MODEL_TYPE = "vit_b"  # registry key; must match the SAM checkpoint above

# --- Input video ---
VIDEO_FILENAME = "Clip_2_34s.mp4"
# Run the models on every Nth frame only (e.g. a 60 fps clip with N=6 -> 10 model passes per second);
# the frames in between reuse the most recent predictions.
FPS_REDUCTION_FACTOR = 6

# --- Thresholds ---
CONFIDENCE_THRESHOLD = 0.5  # minimum detection confidence for a box to be drawn
MIN_PATH_AREA_THRESHOLD = 2000  # largest path-blob contour area (pixels) required before SAM refinement

# --- Derived paths ---
SEG_MODEL_PATH = Path("metrics", "segmentation", SEGMENTATION_RUN_NAME, "best_binary_model.pth")
DET_MODEL_PATH = Path("metrics", "detection", DETECTION_RUN_NAME, "weights", "best.pt")
VIDEO_PATH = Path("..", "data", "video", "splits", VIDEO_FILENAME)
OUTPUT_DIR = Path("final_video_output")
OUTPUT_DIR.mkdir(exist_ok=True)
OUTPUT_VIDEO_PATH = OUTPUT_DIR / ("showcase_" + VIDEO_FILENAME)

# --- Compute device: prefer CUDA, then Apple MPS, then CPU ---
DEVICE = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print(
    f"Device: {DEVICE}",
    f"Loading Segmentation Model: {SEG_MODEL_PATH}",
    f"Loading Detection Model: {DET_MODEL_PATH}",
    sep="\n",
)

Device: mps
Loading Segmentation Model: metrics/segmentation/binary_unet_2025-08-03_22-53-32/best_binary_model.pth
Loading Detection Model: metrics/detection/finetuned/weights/best.pt


In [16]:
# --- Segmentation U-Net (2 output classes; see PATH_CLASS_ID for the path label) ---
seg_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=2).to(DEVICE)
# NOTE(security): torch.load unpickles the checkpoint, which can execute arbitrary code.
# Only load checkpoints you trust (torch >= 1.13 also accepts weights_only=True for plain state dicts).
seg_model.load_state_dict(torch.load(SEG_MODEL_PATH, map_location=torch.device(DEVICE)))
seg_model.eval()
# Inference preprocessing; presumably mirrors the training transform — TODO confirm.
seg_transform = A.Compose([
    A.Resize(height=480, width=640),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
    ToTensorV2(),
])
print("Segmentation model loaded.")

# --- SAM (prompted with points derived from the U-Net mask) ---
sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)
print("SAM model loaded.")


# --- Detection YOLO ---
def _class_id_by_name(names, class_name):
    """Return the first class id whose label equals ``class_name``.

    Args:
        names: mapping of class id -> label (the detector's ``names`` attribute).
        class_name: label to look up.

    Raises:
        KeyError: if no class carries that label, listing the labels that do exist,
            instead of the bare IndexError a ``[...][0]`` lookup would give.
    """
    for class_id, label in names.items():
        if label == class_name:
            return class_id
    raise KeyError(
        f"Class {class_name!r} not found in detection model classes: {list(names.values())}"
    )


detection_model = YOLO(DET_MODEL_PATH)
BENCH_CLASS_ID = _class_id_by_name(detection_model.names, 'bench')
CONE_CLASS_ID = _class_id_by_name(detection_model.names, 'cone')
print("Detection model loaded.")

# --- Depth MiDaS (downloaded / cached via torch.hub) ---
midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
midas.to(DEVICE)
midas.eval()
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
midas_transform = midas_transforms.small_transform  # preprocessing matching MiDaS_small
print("Depth model loaded.")

Segmentation model loaded.
SAM model loaded.
Detection model loaded.


Using cache found in /Users/stahlma/.cache/torch/hub/intel-isl_MiDaS_master
Using cache found in /Users/stahlma/.cache/torch/hub/rwightman_gen-efficientnet-pytorch_master


Loading weights:  None
Depth model loaded.


Using cache found in /Users/stahlma/.cache/torch/hub/intel-isl_MiDaS_master


In [17]:
# Label index of "path" pixels in the U-Net output (compared via `mask == PATH_CLASS_ID`).
PATH_CLASS_ID = 1
# Outline colour for the refined path: a light purple, given in OpenCV's BGR channel order.
PATH_COLOR_BGR = (227, 124, 240)
# Hand-tuned scale turning a MiDaS output value into a distance via `factor / value`
# (presumably tuned by eye for this camera — TODO confirm).
DISTANCE_CALIBRATION_FACTOR = 500

def get_smoothed_path_mask(frame, logits_history):
    """Predict the path label map for ``frame``, temporally smoothed over recent frames.

    The U-Net logits for the current frame are appended to ``logits_history``
    (mutated in place). All stored logits are then blended with linearly
    increasing weights, from 0.1 for the oldest to 1.0 for the newest,
    normalised to sum to 1. The per-pixel argmax of the blend is resized back
    to the frame resolution with nearest-neighbour interpolation.

    Args:
        frame: BGR image (as returned by ``cv2.VideoCapture.read``).
        logits_history: deque of previous logits tensors. The caller uses
            ``deque(maxlen=10)`` so the smoothing window stays bounded.

    Returns:
        uint8 array of shape (frame_height, frame_width) holding class indices.
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    batch = seg_transform(image=rgb)['image'].unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits_now = seg_model(batch)

    logits_history.append(logits_now)
    raw_weights = torch.linspace(0.1, 1.0, len(logits_history)).to(DEVICE)
    normalised = raw_weights / raw_weights.sum()

    # Accumulate oldest -> newest into a zero tensor.
    blended = torch.zeros_like(logits_now)
    for weight, past_logits in zip(normalised, logits_history):
        blended += past_logits * weight

    labels = blended.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
    height, width = frame.shape[0], frame.shape[1]
    return cv2.resize(labels, (width, height), interpolation=cv2.INTER_NEAREST)

def get_depth_map(frame):
    """Run MiDaS on a BGR frame and return its prediction at the frame's resolution.

    Args:
        frame: BGR image (as returned by ``cv2.VideoCapture.read``).

    Returns:
        Float numpy array of shape (frame_height, frame_width) with the raw MiDaS
        output, bicubically upsampled. Callers convert it to distance as
        ``DISTANCE_CALIBRATION_FACTOR / value``, i.e. larger values are treated as closer.
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    target_size = rgb.shape[:2]
    batch = midas_transform(rgb).to(DEVICE)
    with torch.no_grad():
        raw_prediction = midas(batch)

    upsampled = torch.nn.functional.interpolate(
        raw_prediction.unsqueeze(1),  # add a channel axis for interpolate
        size=target_size,
        mode="bicubic",
        align_corners=False,
    )
    return upsampled.squeeze().cpu().numpy()

def _interior_point(contour, path_mask):
    """Return an ``[x, y]`` point lying inside ``contour``, or None for a zero-area contour."""
    moments = cv2.moments(contour)
    if moments["m00"] <= 0:
        # Degenerate (zero-area) contour: skipped, as before.
        return None
    cx = int(moments["m10"] / moments["m00"])
    cy = int(moments["m01"] / moments["m00"])
    if cv2.pointPolygonTest(contour, (float(cx), float(cy)), False) >= 0:
        return [cx, cy]

    # The centroid of a non-convex blob (e.g. a curving path) can fall outside it.
    # Fall back to the blob pixel farthest from the blob boundary. Restricting to
    # real path pixels keeps holes inside the outer contour from being chosen.
    blob = np.zeros_like(path_mask)
    cv2.drawContours(blob, [contour], -1, 1, thickness=cv2.FILLED)
    blob &= path_mask
    distance = cv2.distanceTransform(blob, cv2.DIST_L2, 3)
    y, x = np.unravel_index(int(np.argmax(distance)), distance.shape)
    if distance[y, x] <= 0:
        return [cx, cy]
    return [int(x), int(y)]


def mask_to_prompt_points(mask, max_points=3):
    """Pick up to ``max_points`` positive SAM prompt points from the largest path blobs.

    The external contours of the ``PATH_CLASS_ID`` pixels in ``mask`` are sorted by
    area, largest first. One point is chosen for each of the ``max_points`` largest
    blobs:
    - the blob's centroid, when the centroid lies inside the blob;
    - otherwise the blob pixel deepest inside it.

    This matters because a prompt point that falls off the path would steer SAM
    towards the wrong region.

    Args:
        mask: 2-D label array (e.g. the output of ``get_smoothed_path_mask``).
        max_points: maximum number of blobs turned into prompt points.

    Returns:
        ``(points, largest_area)``:
        - ``points`` is an ``(N, 2)`` int array of ``(x, y)`` pixel coordinates,
          or None if no point could be produced.
        - ``largest_area`` is the contour area of the largest blob, or 0 when
          ``mask`` has no path pixels.
    """
    path_mask = (mask == PATH_CLASS_ID).astype(np.uint8)
    contours, _ = cv2.findContours(path_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None, 0

    sorted_contours = sorted(contours, key=cv2.contourArea, reverse=True)
    largest_contour_area = cv2.contourArea(sorted_contours[0])

    points = []
    for contour in sorted_contours[:max_points]:
        point = _interior_point(contour, path_mask)
        if point is not None:
            points.append(point)

    return (np.array(points) if points else None), largest_contour_area

def estimate_width_in_meters(pixel_width, distance_m, frame_width, fov_degrees=70, correction_factor=1.5):
    """Estimate a real-world width from a width in pixels using the pinhole camera model.

    The focal length in pixels comes from the horizontal field of view:
    ``f = (frame_width / 2) / tan(fov / 2)``. At ``distance_m`` one pixel then spans
    ``distance_m / f`` metres.

    Args:
        pixel_width: width of the object in pixels.
        distance_m: distance to the object in metres.
        frame_width: frame width in pixels.
        fov_degrees: assumed horizontal field of view of the camera. The default
            of 70 is an assumption, not a measured value — TODO confirm for the
            recording camera.
        correction_factor: empirical multiplier applied to the pinhole estimate.
            The default of 1.5 is the previously hard-coded value from visual
            calibration.

    Returns:
        Estimated width in metres.
    """
    focal_length_pixels = (frame_width / 2) / np.tan(np.deg2rad(fov_degrees) / 2)
    meters_per_pixel = distance_m / focal_length_pixels
    return pixel_width * meters_per_pixel * correction_factor

In [20]:
WIDTH_MEASURE_ROW_FRACTION = 0.8  # measure path width on the row 80% of the way down the frame
PATH_WIDTH_COLOR_BGR = (255, 255, 0)
DETECTION_COLOR_BGR = (0, 255, 0)


def _select_best_sam_mask(masks_sam, unet_path_mask_bool):
    """Return the SAM candidate mask with the highest IoU against the U-Net path mask.

    Ties keep the earlier candidate. If every IoU is 0 the first candidate is
    returned; an empty ``masks_sam`` returns None.
    """
    best_iou, best_mask = -1, None
    for mask in masks_sam:
        intersection = np.logical_and(unet_path_mask_bool, mask).sum()
        union = np.logical_or(unet_path_mask_bool, mask).sum()
        iou = intersection / union if union > 0 else 0
        if iou > best_iou:
            best_iou, best_mask = iou, mask
    return best_mask


def _distance_from_depth(depth_value):
    """Convert a depth-map value to a calibrated distance, or None if it is unusable.

    Non-finite or non-positive values would give inf or negative distances. Such
    values can occur because the depth map is bicubically upsampled, which can
    overshoot.
    """
    depth_value = float(depth_value)
    if not np.isfinite(depth_value) or depth_value <= 0:
        return None
    return DISTANCE_CALIBRATION_FACTOR / depth_value


def _draw_path(annotated_frame, refined_mask, depth_map):
    """Outline the refined path mask and, when depth is available, annotate its estimated width."""
    contours, _ = cv2.findContours(refined_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(annotated_frame, contours, -1, PATH_COLOR_BGR, 3)

    frame_h, frame_w = annotated_frame.shape[:2]
    row = int(frame_h * WIDTH_MEASURE_ROW_FRACTION)
    path_pixels = np.flatnonzero(refined_mask[row, :])
    if len(path_pixels) < 2 or depth_map is None:
        return
    # Plain ints: OpenCV drawing functions can reject numpy integer coordinates.
    x_start, x_end = int(path_pixels[0]), int(path_pixels[-1])
    distance_m = _distance_from_depth(np.median(depth_map[row, x_start:x_end]))
    if distance_m is None:
        return
    path_width_meters = estimate_width_in_meters(x_end - x_start, distance_m, frame_w)

    cv2.line(annotated_frame, (x_start, row), (x_end, row), PATH_WIDTH_COLOR_BGR, 3)
    label = f"Path Width: {path_width_meters:.2f}m"
    cv2.putText(annotated_frame, label, (x_start, row - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, PATH_WIDTH_COLOR_BGR, 2)


def _draw_detections(annotated_frame, boxes, depth_map):
    """Draw detector boxes at or above CONFIDENCE_THRESHOLD.

    Each box is labelled with its class and confidence, plus its calibrated
    distance when the depth value under the box is usable.
    """
    for box in boxes:
        confidence = float(box.conf[0])
        if confidence < CONFIDENCE_THRESHOLD:
            continue
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        class_name = detection_model.names[int(box.cls[0])]
        label = f"{class_name} ({confidence:.2f})"

        # Clamp the slice start so an out-of-frame coordinate cannot wrap around.
        box_depth = depth_map[max(y1, 0):y2, max(x1, 0):x2]
        if box_depth.size > 0:
            distance = _distance_from_depth(np.median(box_depth))
            if distance is not None:
                label += f" | Dist: {distance:.1f}m"

        cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), DETECTION_COLOR_BGR, 2)
        cv2.putText(annotated_frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, DETECTION_COLOR_BGR, 2)


cap = cv2.VideoCapture(str(VIDEO_PATH))
if not cap.isOpened():
    print(f"Error opening video file {VIDEO_PATH}")
else:
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep FPS as a float: int() would turn e.g. 29.97 into 29 and make the output drift.
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Container frame counts can be approximate, so this only sizes the progress bar.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(str(OUTPUT_VIDEO_PATH), fourcc, fps, (frame_width, frame_height))
    if not out.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open video writer for {OUTPUT_VIDEO_PATH} (fps={fps})")

    last_refined_mask, last_boxes, last_depth_map, last_known_points = None, [], None, None
    frame_count = 0
    logits_history = deque(maxlen=10)

    try:
        with tqdm(total=total_frames if total_frames > 0 else None, desc="Creating final video") as progress:
            # Read until the stream is exhausted rather than trusting CAP_PROP_FRAME_COUNT.
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Run the models only on every FPS_REDUCTION_FACTOR-th frame; other frames reuse the results.
                if frame_count % FPS_REDUCTION_FACTOR == 0:
                    unet_mask = get_smoothed_path_mask(frame, logits_history)
                    current_points, path_area = mask_to_prompt_points(unet_mask)
                    # Only prompt SAM when the U-Net found a sufficiently large path.
                    last_known_points = current_points if path_area > MIN_PATH_AREA_THRESHOLD else None

                    if last_known_points is not None:
                        sam_predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                        masks_sam, _, _ = sam_predictor.predict(
                            point_coords=last_known_points,
                            point_labels=np.ones(len(last_known_points)),
                            multimask_output=True,
                        )
                        last_refined_mask = _select_best_sam_mask(masks_sam, unet_mask == PATH_CLASS_ID)
                    else:
                        last_refined_mask = None

                    last_boxes = detection_model(frame, verbose=False)[0].boxes
                    last_depth_map = get_depth_map(frame)

                annotated_frame = frame.copy()
                if last_refined_mask is not None:
                    _draw_path(annotated_frame, last_refined_mask, last_depth_map)
                if last_boxes is not None and last_depth_map is not None:
                    _draw_detections(annotated_frame, last_boxes, last_depth_map)

                out.write(annotated_frame)
                frame_count += 1
                progress.update(1)
    finally:
        # Release even on error/interrupt so the partial MP4 is finalised and handles are freed.
        cap.release()
        out.release()

    print(f"\nFinal showcase video saved to: {OUTPUT_VIDEO_PATH}")


Creating final video: 100%|██████████| 1032/1032 [02:10<00:00,  7.94it/s]


Final showcase video saved to: final_video_output/showcase_Clip_2_34s.mp4



