# RIA Segmentation Pipeline - Google Colab

This notebook performs SAM2-based video segmentation of the RIA interneuron in C. elegans, tracking its dorsal (nrD) and ventral (nrV) compartments across frames.

## Key Features:
- **Cloud Storage Integration**: Works with Google Drive
- **Interactive Widgets**: User-friendly bounding box collection
- **Batch Processing**: Process multiple videos sequentially
- **GPU Acceleration**: Automatic GPU detection and usage
- **Result Visualization**: Preview videos with mask overlays

## Setup Instructions:
1. Upload your video data to Google Drive under `MyDrive/RIA_segmentation/input_videos/`, with one subdirectory of numbered JPG frames per video
2. Run the setup cell to install dependencies
3. Execute the pipeline cells in order
4. Download results when complete
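
The folder layout from step 1 can be sketched as follows. This is a minimal illustration, assuming a throwaway local root in place of Google Drive; the helper `make_layout` and the video names `worm_01` and `worm_02` are hypothetical and not part of the notebook.

```python
import tempfile
from pathlib import Path


def make_layout(root, video_names):
    """Create input_videos/<video>/ and output/ under root; return the base path."""
    base = Path(root) / "RIA_segmentation"
    for name in video_names:
        # One subdirectory per video; its JPG frames go inside it.
        (base / "input_videos" / name).mkdir(parents=True, exist_ok=True)
    (base / "output").mkdir(parents=True, exist_ok=True)
    return base


# Hypothetical video names, created in a temporary directory for illustration.
with tempfile.TemporaryDirectory() as tmp:
    base = make_layout(tmp, ["worm_01", "worm_02"])
    layout = sorted(p.relative_to(base).as_posix() for p in base.rglob("*"))
    print("\n".join(layout))
```

On Google Drive the same structure would sit under `MyDrive/RIA_segmentation/`, which is the path the storage configuration cell uses.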

## 🔧 Environment Setup

In [None]:
# Check if running in Colab and setup environment
try:
    import google.colab
    IN_COLAB = True
    print("✓ Running in Google Colab")
    
    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')
    
except ImportError:
    IN_COLAB = False
    print("Running locally")

# Install required packages
if IN_COLAB:
    print("Installing required packages...")
    !pip install -q tifffile h5py opencv-python matplotlib pandas scipy scikit-image ipywidgets tqdm pillow
    
    # Install SAM2 from GitHub (provides the `sam2` package imported below)
    !pip install -q git+https://github.com/facebookresearch/segment-anything-2.git
    
    print("✓ Package installation complete")

In [None]:
# Import required libraries
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
import h5py
import json
import random
from PIL import Image
from tqdm import tqdm
import shutil
from pathlib import Path

# Widget and display imports (the interactive cells use these in both environments)
import ipywidgets as widgets
from IPython.display import display, clear_output, Image as IPImage

# Colab-specific imports
if IN_COLAB:
    from google.colab import files

print("✓ All imports successful")

## 🤖 Model Setup

In [None]:
# Setup directories and download SAM2 model
work_dir = '/content/ria_segmentation' if IN_COLAB else './ria_segmentation'
model_dir = '/content/models' if IN_COLAB else './models'

os.makedirs(work_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

if IN_COLAB:
    os.chdir(work_dir)

# Download SAM2.1 model checkpoint (the 2.1 checkpoints are published under 092824)
checkpoint_path = f"{model_dir}/sam2.1_hiera_base_plus.pt"
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt"

if not os.path.exists(checkpoint_path):
    print("Downloading SAM2 model checkpoint...")
    if IN_COLAB:
        !wget -O {checkpoint_path} {checkpoint_url}
    else:
        import urllib.request
        urllib.request.urlretrieve(checkpoint_url, checkpoint_path)
    print("✓ Model downloaded")
else:
    print("✓ Model already exists")

print(f"Model path: {checkpoint_path}")

In [None]:
# Initialize SAM2 model
def setup_sam2_model():
    # Select device
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"✓ Using CUDA device: {torch.cuda.get_device_name()}")
    else:
        device = torch.device("cpu")
        print("⚠️  Using CPU device (processing will be slower)")
    
    try:
        from sam2.build_sam import build_sam2_video_predictor
        
        # Use the SAM 2.1 config that matches the sam2.1_hiera_base_plus checkpoint
        model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
        
        # Build predictor
        predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=device)
        print("✓ SAM2 model initialized successfully!")
        return predictor, device
        
    except Exception as e:
        print(f"❌ Error initializing SAM2: {e}")
        return None, device

# Initialize the model
predictor, device = setup_sam2_model()

## 📁 Storage Configuration

In [None]:
# Setup storage paths
if IN_COLAB:
    base_path = '/content/drive/MyDrive/RIA_segmentation'
    input_videos_dir = f'{base_path}/input_videos'
    output_dir = f'{base_path}/output'
    temp_dir = '/content/temp_processing'
else:
    base_path = './RIA_segmentation'
    input_videos_dir = f'{base_path}/input_videos'
    output_dir = f'{base_path}/output'
    temp_dir = './temp_processing'

# Create directories
os.makedirs(input_videos_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
os.makedirs(temp_dir, exist_ok=True)

print(f"📂 Storage setup complete:")
print(f"   Input videos: {input_videos_dir}")
print(f"   Output: {output_dir}")
print(f"   Temp: {temp_dir}")

# Check for existing videos. The directory always exists after makedirs above,
# so test whether it actually contains any video subdirectories.
available_videos = sorted(d for d in os.listdir(input_videos_dir)
                          if os.path.isdir(os.path.join(input_videos_dir, d)))
if available_videos:
    print(f"\n📹 Found {len(available_videos)} video directories:")
    for video in available_videos[:5]:  # Show first 5
        print(f"   - {video}")
    if len(available_videos) > 5:
        print(f"   ... and {len(available_videos) - 5} more")
else:
    print("\n⚠️  No video directories found. Please upload your data to Google Drive.")

## 🎛️ Interactive Widget Classes

In [None]:
class ColabBboxCollector:
    """Interactive bounding box collection for Colab environment."""
    
    def __init__(self, image_path, object_ids=None):
        self.image_path = image_path
        self.image = Image.open(image_path)
        self.object_ids = object_ids or [1, 2]
        self.bboxes = {}
        self.current_obj = self.object_ids[0]
        self.finished = False
        self.skipped = False
        
        # Object names
        self.object_names = {1: 'nrD (dorsal)', 2: 'nrV (ventral)'}
        
        self.setup_widgets()
    
    def setup_widgets(self):
        """Create IPython widgets for bbox collection."""
        
        # Object selector
        self.obj_selector = widgets.Dropdown(
            options=[(self.object_names.get(obj_id, f'Object {obj_id}'), obj_id) 
                    for obj_id in self.object_ids],
            value=self.current_obj,
            description='Target Object:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )
        
        # Coordinate inputs with better styling
        coord_style = {'description_width': '40px'}
        coord_layout = widgets.Layout(width='120px')
        
        self.x1_input = widgets.IntText(value=50, description='X1:', style=coord_style, layout=coord_layout)
        self.y1_input = widgets.IntText(value=50, description='Y1:', style=coord_style, layout=coord_layout)
        self.x2_input = widgets.IntText(value=150, description='X2:', style=coord_style, layout=coord_layout)
        self.y2_input = widgets.IntText(value=150, description='Y2:', style=coord_style, layout=coord_layout)
        
        # Buttons with styling
        self.save_btn = widgets.Button(
            description='💾 Save Bbox', 
            button_style='success',
            layout=widgets.Layout(width='120px')
        )
        self.clear_btn = widgets.Button(
            description='🗑️ Clear All', 
            button_style='warning',
            layout=widgets.Layout(width='120px')
        )
        self.finish_btn = widgets.Button(
            description='✅ Finish', 
            button_style='primary',
            layout=widgets.Layout(width='120px')
        )
        self.skip_btn = widgets.Button(
            description='⏭️ Skip Video', 
            button_style='danger',
            layout=widgets.Layout(width='120px')
        )
        
        # Status output
        self.output = widgets.Output()
        
        # Event handlers
        self.obj_selector.observe(self.on_object_change, names='value')
        self.save_btn.on_click(self.save_bbox)
        self.clear_btn.on_click(self.clear_all)
        self.finish_btn.on_click(self.finish)
        self.skip_btn.on_click(self.skip)
        
        # Layout
        coord_box = widgets.HBox([
            self.x1_input, self.y1_input, self.x2_input, self.y2_input
        ], layout=widgets.Layout(justify_content='space-around'))
        
        button_box = widgets.HBox([
            self.save_btn, self.clear_btn, self.finish_btn, self.skip_btn
        ], layout=widgets.Layout(justify_content='space-around'))
        
        instructions_html = widgets.HTML(
            value="""
            <div style='background-color: #f0f7ff; padding: 15px; border-radius: 5px; border-left: 4px solid #007acc;'>
                <h3>📋 Instructions:</h3>
                <ol>
                    <li>Select the target object (nrD or nrV) from the dropdown</li>
                    <li>Look at the reference image below and note the coordinates</li>
                    <li>Set the bounding box coordinates: (X1,Y1) = top-left, (X2,Y2) = bottom-right</li>
                    <li>Click 'Save Bbox' to store the bounding box</li>
                    <li>Repeat for other objects if needed</li>
                    <li>Click 'Finish' when done, or 'Skip Video' to skip this video</li>
                </ol>
                <p><strong>Note:</strong> Ensure X1 &lt; X2 and Y1 &lt; Y2 for valid coordinates.</p>
            </div>
            """
        )
        
        self.widget_box = widgets.VBox([
            widgets.HTML("<h2>🎯 Bounding Box Collection</h2>"),
            instructions_html,
            widgets.HTML("<h4>Object Selection:</h4>"),
            self.obj_selector,
            widgets.HTML("<h4>Coordinates (pixels):</h4>"),
            coord_box,
            widgets.HTML("<h4>Actions:</h4>"),
            button_box,
            widgets.HTML("<h4>Status:</h4>"),
            self.output
        ], layout=widgets.Layout(padding='20px'))
    
    def on_object_change(self, change):
        """Handle object selection change."""
        self.current_obj = change['new']
        with self.output:
            clear_output(wait=True)
            print(f"🎯 Selected: {self.object_names.get(self.current_obj, self.current_obj)}")
    
    def save_bbox(self, btn):
        """Save current bounding box."""
        x1, y1, x2, y2 = self.x1_input.value, self.y1_input.value, self.x2_input.value, self.y2_input.value
        
        # Validate coordinates
        if x1 >= x2 or y1 >= y2:
            with self.output:
                clear_output(wait=True)
                print("❌ Error: Invalid coordinates. Ensure X1 < X2 and Y1 < Y2")
            return
        
        # Validate within image bounds (warn but still save)
        img_width, img_height = self.image.size
        out_of_bounds = x1 < 0 or y1 < 0 or x2 > img_width or y2 > img_height
        
        # Save bbox
        self.bboxes[self.current_obj] = np.array([x1, y1, x2, y2], dtype=np.float32)
        
        with self.output:
            clear_output(wait=True)
            if out_of_bounds:
                # Printed after clear_output so the warning is not wiped immediately
                print(f"⚠️  Warning: Coordinates outside image bounds (0,0) to ({img_width},{img_height})")
            obj_name = self.object_names.get(self.current_obj, f'Object {self.current_obj}')
            print(f"✅ Saved bbox for {obj_name}: [{x1}, {y1}, {x2}, {y2}]")
            print(f"📊 Total bboxes saved: {len(self.bboxes)}")
            
            # List all saved bboxes
            for obj_id, bbox in self.bboxes.items():
                name = self.object_names.get(obj_id, f'Object {obj_id}')
                print(f"   • {name}: {bbox.astype(int)}")
    
    def clear_all(self, btn):
        """Clear all bounding boxes."""
        self.bboxes = {}
        with self.output:
            clear_output(wait=True)
            print("🗑️ Cleared all bounding boxes")
    
    def finish(self, btn):
        """Finish bbox collection."""
        self.finished = True
        with self.output:
            clear_output(wait=True)
            print(f"✅ Finished! Collected {len(self.bboxes)} bounding boxes")
            print("Proceeding to video processing...")
    
    def skip(self, btn):
        """Skip this video."""
        self.skipped = True
        with self.output:
            clear_output(wait=True)
            print("⏭️ Skipping this video")
    
    def show_image_with_bboxes(self):
        """Display the reference image with any existing bounding boxes."""
        fig, ax = plt.subplots(figsize=(14, 10))
        ax.imshow(self.image)
        ax.set_title("Reference Image - Use coordinates from this image", fontsize=16, pad=20)
        
        # Add coordinate annotations
        height, width = self.image.size[1], self.image.size[0]
        ax.set_xlim(0, width)
        ax.set_ylim(height, 0)  # Invert y-axis to match image coordinates
        
        # Add grid for reference
        ax.grid(True, alpha=0.3)
        ax.set_xlabel('X coordinate (pixels)', fontsize=12)
        ax.set_ylabel('Y coordinate (pixels)', fontsize=12)
        
        # Add existing bboxes
        colors = {1: 'red', 2: 'blue'}
        for obj_id, bbox in self.bboxes.items():
            x1, y1, x2, y2 = bbox
            color = colors.get(obj_id, 'green')
            rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
                               linewidth=3, 
                               edgecolor=color,
                               facecolor='none',
                               label=self.object_names.get(obj_id, f'Object {obj_id}'))
            ax.add_patch(rect)
            
            # Add text annotation
            ax.text(x1, y1-10, f'{self.object_names.get(obj_id, f"Obj {obj_id}")}\n[{x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f}]',
                   color=color, fontweight='bold', fontsize=10,
                   bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))
        
        if self.bboxes:
            ax.legend(loc='upper right', fontsize=12)
        
        plt.tight_layout()
        plt.show()

print("✓ Bounding box collector class defined")

## ⚙️ Processing Functions

In [None]:
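def get_video_frames(video_dir):
    """Get numerically sorted list of JPG frame files from video directory."""
    # Keep only numerically named JPG frames (e.g. 000000.jpg), as described
    # under "Data Format"; other files would break the integer sort below.
    frame_files = [f for f in os.listdir(video_dir) 
                  if f.lower().endswith(('.jpg', '.jpeg'))
                  and os.path.splitext(f)[0].isdigit()]
    frame_files.sort(key=lambda x: int(os.path.splitext(x)[0]))
    return frame_files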

def process_video_with_sam2(video_dir, predictor, prompts):
    """Process video with SAM2 using provided prompts."""
    
    # Initialize inference state
    inference_state = predictor.init_state(video_path=video_dir)
    
    # Add prompts to first frame
    frame_idx = 0
    for obj_id, prompt_data in prompts.items():
        if 'bbox' in prompt_data:
            # Bounding box prompt: [x1, y1, x2, y2], passed via the `box` keyword
            bbox = prompt_data['bbox']
            _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                box=bbox
            )
        elif 'points' in prompt_data:
            # Point prompts: (N, 2) coordinates with labels 1 = foreground, 0 = background
            points, labels = prompt_data['points']
            _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
                clear_old_points=True
            )
    
    # Propagate masks through video
    video_segments = {}
    
    print("🎬 Propagating masks through video...")
    for out_frame_idx, out_obj_ids, out_mask_logits in tqdm(
        predictor.propagate_in_video(inference_state), 
        desc="Processing frames"
    ):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    
    return video_segments

def save_results_h5(video_segments, output_path):
    """Save segmentation results to HDF5 file."""
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
    with h5py.File(output_path, 'w') as f:
        # Metadata
        f.attrs['num_frames'] = len(video_segments)
        if video_segments:
            first_frame = list(video_segments.keys())[0]
            f.attrs['object_ids'] = list(video_segments[first_frame].keys())
        
        # Create groups
        masks_group = f.create_group('masks')
        
        # Save masks
        sorted_frames = sorted(video_segments.keys())
        for frame_idx in sorted_frames:
            frame_group = masks_group.create_group(f'frame_{frame_idx:06d}')
            
            for obj_id, mask in video_segments[frame_idx].items():
                frame_group.create_dataset(
                    f'object_{obj_id}', 
                    data=mask.astype(bool),
                    compression='gzip'
                )
    
    print(f"💾 Results saved to {output_path}")

def create_preview_video(video_dir, video_segments, output_path, fps=10):
    """Create preview video with mask overlays."""
    
    frame_files = get_video_frames(video_dir)
    if not frame_files:
        print("No frames found for preview")
        return
    
    # Read first frame to get dimensions
    first_frame_path = os.path.join(video_dir, frame_files[0])
    first_frame = cv2.imread(first_frame_path)
    height, width = first_frame.shape[:2]
    
    # Video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    
    # Colors for objects, in OpenCV's BGR channel order
    colors = {1: (0, 0, 255), 2: (255, 0, 0)}  # Red for nrD, Blue for nrV
    
    print("🎥 Creating preview video...")
    for frame_idx, frame_file in enumerate(tqdm(frame_files, desc="Rendering frames")):
        frame_path = os.path.join(video_dir, frame_file)
        frame = cv2.imread(frame_path)
        
        # Add mask overlays if available
        if frame_idx in video_segments:
            overlay = np.zeros_like(frame)
            
            for obj_id, mask in video_segments[frame_idx].items():
                if obj_id in colors:
                    color = colors[obj_id]
                    mask_3d = np.stack([mask.squeeze()] * 3, axis=-1)
                    colored_mask = mask_3d * np.array(color)
                    overlay = cv2.addWeighted(overlay, 1, colored_mask.astype(np.uint8), 0.6, 0)
            
            frame = cv2.addWeighted(frame, 1, overlay, 0.4, 0)
        
        out.write(frame)
    
    out.release()
    print(f"🎬 Preview video saved to {output_path}")

print("✓ Processing functions defined")
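
The `video_segments` dictionary returned by `process_video_with_sam2` maps frame index to object ID to a boolean mask. As a rough sketch of how that structure might be summarized, the example below counts mask pixels per object and frame. The helper `mask_areas` is hypothetical, and the masks are synthetic 8x8 stand-ins rather than real SAM2 output.

```python
import numpy as np


def mask_areas(video_segments):
    """Return {obj_id: [mask area in pixels, one entry per frame in frame order]}."""
    areas = {}
    for frame_idx in sorted(video_segments):
        for obj_id, mask in video_segments[frame_idx].items():
            # count_nonzero works regardless of a leading channel axis, e.g. (1, H, W).
            areas.setdefault(obj_id, []).append(int(np.count_nonzero(mask)))
    return areas


# Synthetic stand-in for the pipeline output: 2 frames, 2 objects, 8x8 masks.
segments = {}
for frame_idx in range(2):
    nrd = np.zeros((1, 8, 8), dtype=bool)
    nrd[0, 1:3, 1:3 + frame_idx] = True   # 2x2 pixels, then 2x3 pixels
    nrv = np.zeros((1, 8, 8), dtype=bool)
    nrv[0, 5:7, 4:7] = True               # constant 2x3 pixels
    segments[frame_idx] = {1: nrd, 2: nrv}

print(mask_areas(segments))
```

The per-object lists only line up with frame indices if every object appears in every frame, which is an assumption of this sketch.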

## 🚀 Main Processing Pipeline

In [None]:
# Video selection interface
def create_video_selector():
    """Create interactive video selection widget."""
    
    # Refresh video list
    available_videos = []
    if os.path.exists(input_videos_dir):
        available_videos = [d for d in os.listdir(input_videos_dir) 
                           if os.path.isdir(os.path.join(input_videos_dir, d))]
    
    if not available_videos:
        print("⚠️  No video directories found!")
        print(f"Please upload video directories to: {input_videos_dir}")
        return None, None, None, None
    
    # Create widgets
    video_selector = widgets.SelectMultiple(
        options=available_videos,
        description='Select videos to process:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(height='200px', width='400px')
    )
    
    process_btn = widgets.Button(
        description='🚀 Start Processing', 
        button_style='primary',
        layout=widgets.Layout(width='200px', height='50px')
    )
    
    output_widget = widgets.Output()
    
    # Interface layout
    interface = widgets.VBox([
        widgets.HTML("<h2>📹 Video Selection</h2>"),
        widgets.HTML(f"<p>Found <strong>{len(available_videos)}</strong> video directories in your storage.</p>"),
        widgets.HTML("<p>Hold <kbd>Ctrl</kbd> (or <kbd>Cmd</kbd>) to select multiple videos.</p>"),
        video_selector,
        process_btn,
        output_widget
    ], layout=widgets.Layout(padding='20px'))
    
    return video_selector, process_btn, output_widget, interface

# Create the interface
if available_videos:
    video_selector, process_btn, output_widget, interface = create_video_selector()
    display(interface)
else:
    print("❌ Cannot create interface - no videos found")
    video_selector = None

In [None]:
def process_single_video(video_name):
    """Process a single video with interactive prompting."""
    
    print(f"\n{'='*60}")
    print(f"🎬 Processing video: {video_name}")
    print(f"{'='*60}")
    
    video_dir = os.path.join(input_videos_dir, video_name)
    
    # Get frame files
    frame_files = get_video_frames(video_dir)
    if not frame_files:
        print(f"❌ No frames found in {video_dir}")
        return None
    
    print(f"📊 Found {len(frame_files)} frames")
    
    # Get first frame for prompting
    first_frame_path = os.path.join(video_dir, frame_files[0])
    
    # Collect bounding box prompts
    print("\n🎯 Collecting bounding box prompts...")
    bbox_collector = ColabBboxCollector(first_frame_path, object_ids=[1, 2])
    
    # Show the image first
    bbox_collector.show_image_with_bboxes()
    
    # Display the widget interface
    display(bbox_collector.widget_box)
    
    # Widget callbacks only fire after this function returns, so we cannot
    # block here waiting for clicks. Instead, run the segmentation from a
    # callback once the user clicks 'Finish'.
    def on_finish(btn):
        with bbox_collector.output:
            complete_video_processing(video_name, bbox_collector.bboxes)
    
    bbox_collector.finish_btn.on_click(on_finish)
    
    print("⏳ Waiting for user input...")
    print("Use the widgets above to set bounding boxes, then click 'Finish'")
    
    return {
        'video_name': video_name,
        'status': 'pending_user_input',
        'collector': bbox_collector
    }

def complete_video_processing(video_name, bboxes):
    """Complete the processing after user provides bounding boxes."""
    
    video_dir = os.path.join(input_videos_dir, video_name)
    
    if not bboxes:
        print(f"⏭️ No bounding boxes provided for {video_name}, skipping...")
        return None
    
    print(f"✅ Processing {video_name} with {len(bboxes)} bounding boxes")
    
    # Convert bboxes to prompts
    prompts = {}
    for obj_id, bbox in bboxes.items():
        prompts[obj_id] = {'bbox': bbox}
    
    try:
        # Process video with SAM2
        video_segments = process_video_with_sam2(video_dir, predictor, prompts)
        
        # Save results
        output_h5_path = os.path.join(output_dir, f"{video_name}_segments.h5")
        save_results_h5(video_segments, output_h5_path)
        
        # Create preview video
        preview_path = os.path.join(output_dir, f"{video_name}_preview.mp4")
        create_preview_video(video_dir, video_segments, preview_path)
        
        result = {
            'video_name': video_name,
            'status': 'success',
            'num_frames': len(video_segments),
            'num_objects': len(prompts),
            'h5_path': output_h5_path,
            'preview_path': preview_path
        }
        
        print(f"✅ Successfully processed {video_name}")
        print(f"   📊 Frames: {result['num_frames']}")
        print(f"   🎯 Objects: {result['num_objects']}")
        print(f"   💾 Output: {output_h5_path}")
        print(f"   🎬 Preview: {preview_path}")
        
        return result
        
    except Exception as e:
        print(f"❌ Error processing {video_name}: {str(e)}")
        return {
            'video_name': video_name,
            'status': 'error',
            'error': str(e)
        }

# Define the main processing function
def process_selected_videos():
    """Process all selected videos."""
    if video_selector is None:
        print("❌ No video selector available")
        return
    
    selected_videos = list(video_selector.value)
    
    if not selected_videos:
        print("⚠️  No videos selected")
        return
    
    print(f"🚀 Starting processing of {len(selected_videos)} videos")
    
    results = []
    
    for i, video_name in enumerate(selected_videos, 1):
        print(f"\n--- Processing video {i}/{len(selected_videos)}: {video_name} ---")
        
        # Check if already processed
        output_path = os.path.join(output_dir, f"{video_name}_segments.h5")
        if os.path.exists(output_path):
            print(f"⚠️  {video_name} already processed, skipping...")
            continue
        
        result = process_single_video(video_name)
        if result:
            results.append(result)
    
    return results

# Connect the button to the processing function
if video_selector is not None:
    def on_process_click(btn):
        with output_widget:
            clear_output(wait=True)
            process_selected_videos()
    
    process_btn.on_click(on_process_click)
    
    print("✅ Processing interface ready!")
    print("\n📋 Next steps:")
    print("1. Select one or more videos from the list above")
    print("2. Click 'Start Processing' button")
    print("3. For each video, use the bounding box interface to mark objects")
    print("4. Results will be saved to your Google Drive")
else:
    print("❌ Cannot setup processing - video selector not available")

## 📊 Results and Download

In [None]:
# Check processed results
def view_processing_results():
    """Display summary of processed videos."""
    
    if not os.path.exists(output_dir):
        print("📂 No output directory found")
        return
    
    # Find all H5 files
    h5_files = [f for f in os.listdir(output_dir) if f.endswith('_segments.h5')]
    mp4_files = [f for f in os.listdir(output_dir) if f.endswith('_preview.mp4')]
    
    print(f"📊 Processing Results Summary")
    print(f"{'='*50}")
    print(f"📁 Output directory: {output_dir}")
    print(f"🗃️  Segmentation files: {len(h5_files)}")
    print(f"🎬 Preview videos: {len(mp4_files)}")
    
    if h5_files:
        print(f"\n📋 Processed Videos:")
        for i, h5_file in enumerate(h5_files, 1):
            video_name = h5_file.replace('_segments.h5', '')
            h5_path = os.path.join(output_dir, h5_file)
            
            # Try to read metadata
            try:
                with h5py.File(h5_path, 'r') as f:
                    num_frames = f.attrs.get('num_frames', 'Unknown')
                    object_ids = f.attrs.get('object_ids', [])
                    file_size = os.path.getsize(h5_path) / (1024*1024)  # MB
                    
                print(f"   {i}. {video_name}")
                print(f"      📊 Frames: {num_frames}")
                print(f"      🎯 Objects: {len(object_ids)} {list(object_ids)}")
                print(f"      💾 Size: {file_size:.1f} MB")
                
            except Exception as e:
                print(f"   {i}. {video_name} (Error reading file: {e})")
    else:
        print("\n⚠️  No processed videos found")
        print("Run the processing pipeline above to generate results.")

# View current results
view_processing_results()

In [None]:
# Download results (for Colab)
def download_all_results():
    """Create and download a zip file with all results."""
    
    if not IN_COLAB:
        print("Not in Colab - results are already local")
        return
    
    import zipfile
    from google.colab import files
    
    if not os.path.exists(output_dir) or not os.listdir(output_dir):
        print("❌ No results found to download")
        return
    
    zip_path = '/content/ria_segmentation_results.zip'
    
    print("📦 Creating results archive...")
    
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        # Use `filenames` so the loop does not shadow the google.colab `files` module
        for root, dirs, filenames in os.walk(output_dir):
            for file in filenames:
                file_path = os.path.join(root, file)
                arcname = os.path.relpath(file_path, output_dir)
                zipf.write(file_path, arcname)
                print(f"   📄 Added: {arcname}")
    
    # Get zip file size
    zip_size = os.path.getsize(zip_path) / (1024*1024)  # MB
    print(f"\n✅ Archive created: {zip_size:.1f} MB")
    
    print("⬇️  Starting download...")
    files.download(zip_path)
    print("✅ Download complete!")

# Create download button
if IN_COLAB:
    download_btn = widgets.Button(
        description='⬇️ Download Results', 
        button_style='success',
        layout=widgets.Layout(width='200px', height='40px')
    )
    
    def on_download_click(btn):
        download_all_results()
    
    download_btn.on_click(on_download_click)
    
    download_interface = widgets.VBox([
        widgets.HTML("<h3>📦 Download Results</h3>"),
        widgets.HTML("<p>Click the button below to download all segmentation results as a ZIP file.</p>"),
        download_btn
    ], layout=widgets.Layout(padding='20px'))
    
    display(download_interface)
else:
    print("💾 Results are stored locally in:", output_dir)

## 🛠️ Manual Processing (Advanced)

In [None]:
# Manual processing of a single video (for testing/debugging)
def manual_process_video(video_name, bbox_coords=None):
    """
    Manually process a single video with predefined bounding boxes.
    
    Args:
        video_name (str): Name of the video directory
        bbox_coords (dict): Dictionary with bounding box coordinates
                           e.g., {1: [x1, y1, x2, y2], 2: [x1, y1, x2, y2]}
    """
    
    if not available_videos or video_name not in available_videos:
        print(f"❌ Video '{video_name}' not found in available videos")
        print(f"Available: {available_videos}")
        return
    
    if bbox_coords is None:
        print("⚠️  No bounding box coordinates provided")
        print("Example usage:")
        print("manual_process_video('video_name', {1: [50, 50, 150, 150], 2: [200, 50, 300, 150]})")
        return
    
    print(f"🎬 Manually processing: {video_name}")
    print(f"📦 Bounding boxes: {bbox_coords}")
    
    # Convert to numpy arrays
    bboxes = {}
    for obj_id, coords in bbox_coords.items():
        bboxes[obj_id] = np.array(coords, dtype=np.float32)
    
    # Process the video
    result = complete_video_processing(video_name, bboxes)
    
    return result

# Example usage (uncomment and modify as needed):
# result = manual_process_video('your_video_name', {
#     1: [100, 100, 200, 200],  # nrD bounding box
#     2: [300, 100, 400, 200]   # nrV bounding box
# })

print("✓ Manual processing function available")
print("Use manual_process_video('video_name', bbox_dict) for direct processing")
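
For several videos, the bounding boxes passed to `manual_process_video` could be kept in a JSON file. The sketch below is one possible approach, not part of the notebook: the helper `parse_bbox_config`, the file layout, and the video name `worm_01` are all assumptions. JSON object keys are always strings, so they are converted back to the integer object IDs the pipeline uses.

```python
import json


def parse_bbox_config(text):
    """Parse {"video": {"1": [x1, y1, x2, y2], ...}} into int-keyed dicts."""
    raw = json.loads(text)
    config = {}
    for video_name, boxes in raw.items():
        parsed = {}
        for obj_id, coords in boxes.items():
            if len(coords) != 4:
                raise ValueError(f"{video_name}/{obj_id}: expected 4 coordinates")
            # JSON object keys are always strings; the pipeline uses int IDs.
            parsed[int(obj_id)] = [float(c) for c in coords]
        config[video_name] = parsed
    return config


# Hypothetical configuration for a single video with both objects.
example = '{"worm_01": {"1": [100, 100, 200, 200], "2": [300, 100, 400, 200]}}'
config = parse_bbox_config(example)
print(config)

# In the notebook, each entry could then be handed to the function above:
# for name, boxes in config.items():
#     manual_process_video(name, boxes)
```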

## 🔧 Troubleshooting & Tips

### Common Issues:

1. **"No videos found"**
   - Check that your video directories are uploaded to `/content/drive/MyDrive/RIA_segmentation/input_videos/`
   - Each video should be in its own subdirectory containing JPG frames

2. **"Model loading failed"**
   - Restart the runtime and re-run the setup cells
   - Check that you have sufficient GPU memory

3. **"Processing too slow"**
   - Ensure you're using GPU runtime (Runtime → Change runtime type → GPU)
   - Consider processing fewer frames or smaller videos

4. **"Invalid coordinates"**
   - Make sure X1 < X2 and Y1 < Y2
   - Check that coordinates are within image bounds
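
The coordinate rules in item 4 can also be applied automatically before a box is used as a prompt. The following is a minimal sketch under that assumption; `normalize_bbox` is a hypothetical helper and is not called anywhere in the notebook.

```python
def normalize_bbox(x1, y1, x2, y2, img_width, img_height):
    """Return (x1, y1, x2, y2) ordered and clamped to the image, or None if empty."""
    # Order the corners so (x1, y1) is top-left and (x2, y2) is bottom-right.
    left, right = sorted((x1, x2))
    top, bottom = sorted((y1, y2))
    # Clamp to the image rectangle [0, width] x [0, height].
    left = max(0, min(left, img_width))
    right = max(0, min(right, img_width))
    top = max(0, min(top, img_height))
    bottom = max(0, min(bottom, img_height))
    # A box with zero width or height cannot serve as a prompt.
    if left >= right or top >= bottom:
        return None
    return (left, top, right, bottom)


print(normalize_bbox(150, 50, 50, 150, 512, 512))   # corners given in swapped order
print(normalize_bbox(-20, 10, 600, 100, 512, 512))  # partly outside the image
```

Clamping silently changes what the user typed, so in an interactive setting it may be preferable to show the adjusted box rather than apply it unannounced.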

### Data Format:
- Videos should be directories containing sequential JPG frames
- Frame names should be numeric (e.g., 000000.jpg, 000001.jpg, ...)
- Recommended frame size: 512x512 to 1024x1024 pixels
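
A frame directory can be checked against these naming rules before uploading. The sketch below assumes a hypothetical helper, `check_frame_dir`, and demonstrates it on empty placeholder files in a temporary directory rather than real frames.

```python
import os
import tempfile


def check_frame_dir(video_dir):
    """Return (numeric_jpgs, problems) for a directory of frames."""
    numeric_jpgs, problems = [], []
    for name in sorted(os.listdir(video_dir)):
        stem, ext = os.path.splitext(name)
        if ext.lower() not in (".jpg", ".jpeg"):
            problems.append(f"{name}: not a JPG frame")
        elif not (stem.isascii() and stem.isdigit()):
            problems.append(f"{name}: name is not numeric")
        else:
            numeric_jpgs.append(name)
    # Sort by integer value so that 10.jpg comes after 9.jpg.
    numeric_jpgs.sort(key=lambda n: int(os.path.splitext(n)[0]))
    return numeric_jpgs, problems


# Demonstration on a throwaway directory with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("000000.jpg", "000001.jpg", "frame_2.jpg", "notes.txt"):
        open(os.path.join(tmp, name), "w").close()
    frames, problems = check_frame_dir(tmp)
    print(frames)
    print(problems)
```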

### Performance:
- GPU processing: ~1-5 fps depending on video size
- CPU processing: ~0.1-0.5 fps (much slower)
- Memory usage: ~2-8GB depending on video length and resolution
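
The throughput figures above translate into a rough time budget per video. The sketch below is a back-of-the-envelope calculation only: `estimate_minutes` is a hypothetical helper, the 600-frame video length is an assumed example, and the fps ranges are taken from the list above.

```python
def estimate_minutes(num_frames, fps_low, fps_high):
    """Return (best_case, worst_case) processing time in minutes."""
    return num_frames / fps_high / 60, num_frames / fps_low / 60


num_frames = 600  # hypothetical video length
for label, (fps_low, fps_high) in {"GPU": (1, 5), "CPU": (0.1, 0.5)}.items():
    best, worst = estimate_minutes(num_frames, fps_low, fps_high)
    print(f"{label}: {best:.0f}-{worst:.0f} min")
```

Under these assumptions, a 600-frame video would take roughly 2 to 10 minutes on a GPU and 20 to 100 minutes on a CPU, excluding model loading and preview rendering.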