# BananaTracker - Multi-Object Tracking with SAM2.1 Segmentation

This notebook demonstrates how to use BananaTracker for multi-object tracking with:
- **YOLOv8** detection
- **ByteTrack-based** tracking
- **SAM2.1** mask generation and temporal propagation

## Cell 1: Install Dependencies

In [None]:
# Install required packages
!pip install ultralytics opencv-python-headless tqdm
!pip install lap cython_bbox  # For ByteTrack tracker core

# SAM2.1 dependencies
!pip install hydra-core omegaconf

## Cell 2: Clone Repositories and Install

In [None]:
# Clone the BananaTracker repository
!git clone https://github.com/USER/bananatracker.git
%cd bananatracker

# Install in development mode
!pip install -e .

# Clone SAM2.1 real-time repository
%cd ..
!git clone https://github.com/USER/segment-anything-2-real-time.git
%cd segment-anything-2-real-time

# Download SAM2.1 checkpoints
%cd checkpoints
!bash download_ckpts.sh
%cd ../..

## Cell 3: Configuration

Configure the tracker with your model weights and SAM2.1 settings.

### SAM2.1 Model Options
| Model | Checkpoint | Config | Quality | Speed |
|-------|------------|--------|---------|-------|
| Large (Recommended) | `sam2.1_hiera_large.pt` | `sam2.1_hiera_l.yaml` | Best | Slower |
| Base+ | `sam2.1_hiera_base_plus.pt` | `sam2.1_hiera_b+.yaml` | Good | Medium |
| Small | `sam2.1_hiera_small.pt` | `sam2.1_hiera_s.yaml` | Fair | Fast |
| Tiny | `sam2.1_hiera_tiny.pt` | `sam2.1_hiera_t.yaml` | Basic | Fastest |

In [None]:
from bananatracker import BananaTrackerConfig

# =============================================================================
# SAM2.1 MODEL CONFIGURATION - Modify these to change the segmentation model
# =============================================================================
SAM2_CHECKPOINT = "checkpoints/sam2.1_hiera_large.pt"  # Path to SAM2.1 checkpoint
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"     # Path to SAM2.1 config
SAM2_REPO_PATH = "/content/segment-anything-2-real-time"  # Path to SAM2.1 repo

# =============================================================================
# TRACKER CONFIGURATION
# =============================================================================
config = BananaTrackerConfig(
    # Detection
    yolo_weights="/content/HockeyAI_model_weight.pt",  # Update with your model path
    class_names=["Center Ice", "Faceoff", "Goalpost", "Goaltender", "Player", "Puck", "Referee"],
    track_classes=[3, 4, 5, 6],  # Goaltender, Player, Puck, Referee
    special_classes=[5],          # Puck: keep only the highest-confidence detection
    detection_conf_thresh=0.5,    # General confidence threshold
    detection_iou_thresh=0.7,     # IoU threshold for YOLO NMS

    # Post-processing: Centroid-based deduplication
    centroid_dedup_enabled=True,     # Remove duplicate boxes for same player
    centroid_dedup_max_distance=36,  # Max pixel distance to consider duplicates

    # Tracker
    track_thresh=0.6,
    track_buffer=30,
    cmc_method="orb",  # Options: "orb", "ecc", "sift", "sparseOptFlow", "none"

    # SAM2.1 Mask Propagation
    sam2_enabled=True,                    # Enable mask generation with SAM2.1
    sam2_checkpoint=SAM2_CHECKPOINT,      # SAM2.1 model checkpoint
    sam2_config=SAM2_CONFIG,              # SAM2.1 config yaml
    sam2_repo_path=SAM2_REPO_PATH,        # Path to segment-anything-2-real-time repo
    mask_start_frame=1,                   # Frame to start mask creation (1-indexed)
    mask_overlap_threshold=0.6,           # Skip mask creation for heavily overlapping bboxes
    mask_alpha=0.4,                       # Mask overlay transparency for visualization

    # Visualization
    class_colors={
        "Goaltender": (0, 165, 255),   # Orange (BGR)
        "Player": (255, 0, 0),          # Blue (BGR)
        "Puck": (0, 255, 0),            # Green (BGR)
        "Referee": (0, 0, 255),         # Red (BGR)
    },
    show_track_id=True,
    line_thickness=2,

    # Output
    output_video_path="/content/output_tracked.mp4",
    output_txt_path="/content/results.txt",
    device="cuda:0",
)

print("Configuration created!")
print("\nDetection:")
print(f"  - Confidence threshold: {config.detection_conf_thresh}")
print(f"  - IoU threshold: {config.detection_iou_thresh}")
print(f"  - Tracking classes: {config.track_classes}")
print(f"  - Special classes (max-conf only): {config.special_classes}")
print("\nSAM2.1 Mask Propagation:")
print(f"  - Enabled: {config.sam2_enabled}")
print(f"  - Checkpoint: {config.sam2_checkpoint}")
print(f"  - Config: {config.sam2_config}")
print(f"  - Mask alpha: {config.mask_alpha}")
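
The `centroid_dedup_max_distance` setting above can be pictured with a small standalone sketch. This is only an illustration of greedy centroid-distance suppression, not BananaTracker's actual implementation: the function name `dedup_by_centroid` and the `(x1, y1, x2, y2, score)` box layout are assumptions made for the example.

```python
import math

def dedup_by_centroid(dets, max_distance):
    """Drop boxes whose centroid lies within max_distance of a higher-scoring box.

    dets: list of (x1, y1, x2, y2, score) tuples (assumed layout).
    Returns the surviving boxes, ordered from highest to lowest score.
    """
    def centroid(d):
        return ((d[0] + d[2]) / 2.0, (d[1] + d[3]) / 2.0)

    kept = []
    # Visit boxes from highest to lowest confidence so the best box wins
    for det in sorted(dets, key=lambda d: d[4], reverse=True):
        cx, cy = centroid(det)
        is_duplicate = any(
            math.hypot(cx - kx, cy - ky) <= max_distance
            for kx, ky in map(centroid, kept)
        )
        if not is_duplicate:
            kept.append(det)
    return kept

dets = [
    (100, 100, 140, 200, 0.90),  # player A
    (104, 102, 146, 204, 0.60),  # near-duplicate of player A
    (400, 120, 440, 220, 0.80),  # player B
]
deduped = dedup_by_centroid(dets, max_distance=36)
print([d[4] for d in deduped])
```

With a 36-pixel limit, the second box is treated as a duplicate of the first because their centroids are only a few pixels apart, while the third box is far enough away to survive.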

## Cell 4: Load Model and Setup Pipeline

In [None]:
from bananatracker import BananaTrackerPipeline

# Initialize the pipeline
pipeline = BananaTrackerPipeline(config)

print("Pipeline initialized!")
print("Detector: YOLOv8")
print("Tracker: BananaTracker (ByteTrack-based)")
print("Mask Manager: SAM2.1 Camera Predictor")
print(f"CMC Method: {config.cmc_method}")

## Cell 5: Run Tracking with Mask Propagation

In [None]:
# Path to input video
INPUT_VIDEO = "/content/sample_video.mp4"  # Update with your video path

# Run tracking with SAM2.1 mask propagation
print(f"Processing video: {INPUT_VIDEO}")
print(f"SAM2.1 mask propagation: {'Enabled' if config.sam2_enabled else 'Disabled'}")

all_tracks = pipeline.process_video(INPUT_VIDEO)

print(f"\nProcessed {len(all_tracks)} frames")
print(f"Output video: {config.output_video_path}")
print(f"MOT results: {config.output_txt_path}")
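
The results file can be loaded back for analysis. The sketch below assumes the file follows the common MOTChallenge text layout (`frame,id,left,top,width,height,conf,...`); the notebook only says "MOT results", so check a few lines of your own output before relying on this. `parse_mot_lines` is a hypothetical helper, not part of BananaTracker.

```python
import csv
from collections import defaultdict

def parse_mot_lines(lines):
    """Group MOT-style rows by frame number.

    lines: any iterable of text lines (an open file works too).
    Assumed column order: frame, id, left, top, width, height, conf, ...
    """
    per_frame = defaultdict(list)
    for row in csv.reader(lines):
        if len(row) < 7:
            continue  # skip blank or malformed rows
        frame, track_id = int(float(row[0])), int(float(row[1]))
        left, top, width, height, conf = (float(v) for v in row[2:7])
        per_frame[frame].append(
            {"track_id": track_id, "tlwh": (left, top, width, height), "conf": conf}
        )
    return dict(per_frame)

sample = [
    "1,1,100.0,50.0,40.0,80.0,0.91,-1,-1,-1",
    "1,2,300.0,60.0,38.0,82.0,0.88,-1,-1,-1",
    "2,1,102.0,51.0,40.0,80.0,0.90,-1,-1,-1",
]
frames = parse_mot_lines(sample)
print({frame: len(rows) for frame, rows in frames.items()})
```

To read the real output, pass an open file instead of the sample list, for example `with open(config.output_txt_path) as f: frames = parse_mot_lines(f)`.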

## Cell 6: Validation - Show Mask on Random Frame

This cell validates the SAM2.1 integration by displaying masks on a random frame from the video.

In [None]:
import cv2
import numpy as np
import random
import matplotlib.pyplot as plt

# Reload video and process a random frame with masks
cap = cv2.VideoCapture(INPUT_VIDEO)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

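# Select a random frame (after mask initialization).
# Clamp the lower bound so short videos do not produce an empty range,
# which would make random.randint raise ValueError.
first_candidate = max(10, config.mask_start_frame + 5)
last_candidate = min(total_frames - 1, 100)
random_frame_idx = random.randint(min(first_candidate, last_candidate), last_candidate)
print(f"Selected random frame: {random_frame_idx}")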

# Reset pipeline for fresh processing
pipeline.tracker.reset()
if pipeline.mask_manager:
    pipeline.mask_manager.reset()

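# Process frames up to the selected frame (frame_id is 1-indexed)
prev_frame = None
frame_id = 0
result_frame = None
result_mask = None
result_tracks = []
result_tracklet_mask_dict = {}

# frame_id is incremented inside the loop, so stop once it reaches the selected frame
while frame_id < random_frame_idx: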
    ret, frame = cap.read()
    if not ret:
        break
    
    frame_id += 1
    
    # Detect
    detections = pipeline.detector.detect(frame)
    
    # Track
    height, width = frame.shape[:2]
    tracks, removed_ids, new_tracks = pipeline.tracker.update(
        detections_array=detections,
        img_info=(height, width),
        frame_img=frame
    )
    
    # Update masks
    mask = None
    tracklet_mask_dict = {}
    
    if pipeline.mask_manager:
        online_tlwhs = [t.tlwh.tolist() for t in tracks]
        online_ids = [t.track_id for t in tracks]
        
        mask, tracklet_mask_dict, _, mask_colors = (
            pipeline.mask_manager.get_updated_masks(
                frame=frame,
                frame_prev=prev_frame,
                frame_id=frame_id,
                online_tlwhs=online_tlwhs,
                online_ids=online_ids,
                new_tracks=new_tracks,
                removed_tracks_ids=removed_ids,
            )
        )
        
        if mask_colors is not None:
            mask = mask_colors
    
    # Store for visualization
    if frame_id == random_frame_idx:
        result_frame = frame.copy()
        result_mask = mask
        result_tracks = tracks
        result_tracklet_mask_dict = tracklet_mask_dict
    
    prev_frame = frame.copy()

cap.release()

# Visualize results
if result_frame is not None:
    # Create visualization
    vis_frame = pipeline.visualizer.draw_tracks_with_masks(
        result_frame, result_tracks, result_mask, result_tracklet_mask_dict
    )
    
    # Display
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # Original frame
    axes[0].imshow(cv2.cvtColor(result_frame, cv2.COLOR_BGR2RGB))
    axes[0].set_title(f"Original Frame {random_frame_idx}")
    axes[0].axis('off')
    
    # Mask only
    if result_mask is not None:
        axes[1].imshow(result_mask, cmap='nipy_spectral')
        axes[1].set_title(f"SAM2.1 Mask ({len(np.unique(result_mask)) - 1} objects)")
    else:
        axes[1].text(0.5, 0.5, 'No mask available', ha='center', va='center')
        axes[1].set_title("SAM2.1 Mask")
    axes[1].axis('off')
    
    # Combined visualization
    axes[2].imshow(cv2.cvtColor(vis_frame, cv2.COLOR_BGR2RGB))
    axes[2].set_title(f"Tracking + Masks ({len(result_tracks)} tracks)")
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print(f"\nFrame {random_frame_idx} Statistics:")
    print(f"  - Active tracks: {len(result_tracks)}")
    print(f"  - Track IDs: {[t.track_id for t in result_tracks]}")
    if result_mask is not None:
        unique_masks = np.unique(result_mask)
        print(f"  - Unique mask IDs: {unique_masks.tolist()}")
        print(f"  - Tracklet-to-mask mapping: {result_tracklet_mask_dict}")
else:
    print("Could not process frame")
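
The combined panel relies on `mask_alpha` to control how strongly mask colors show through. One common interpretation is a per-pixel weighted blend, sketched below with NumPy. This is an assumption about how such an overlay typically works, not a copy of the visualizer's code; `overlay_mask` is a hypothetical helper.

```python
import numpy as np

def overlay_mask(frame, mask, color, alpha):
    """Blend `color` into `frame` wherever `mask` is True.

    frame: (H, W, 3) uint8 image
    mask:  (H, W) boolean array
    color: 3-tuple in the same channel order as `frame`
    alpha: weight of the mask color, between 0 (invisible) and 1 (opaque)
    """
    out = frame.astype(np.float32)
    out[mask] = (1.0 - alpha) * out[mask] + alpha * np.asarray(color, dtype=np.float32)
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)

frame = np.full((4, 4, 3), 100, dtype=np.uint8)   # uniform gray test image
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True                              # 2x2 masked region
blended = overlay_mask(frame, mask, color=(0, 255, 0), alpha=0.4)
print(blended[1, 1], blended[0, 0])
```

Pixels outside the mask are left untouched, so lowering `alpha` only fades the colored regions.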

## Cell 7: Compress Output Video

In [None]:
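# Compress video for notebook display.
# ffmpeg's own logging is silenced with flags rather than %%capture,
# because %%capture would also hide the print below.
OUTPUT_COMPRESSED = "/content/output_compressed.mp4"
!ffmpeg -y -hide_banner -loglevel error -i "{config.output_video_path}" -vcodec libx264 -crf 28 "{OUTPUT_COMPRESSED}"

print(f"Compressed video saved to: {OUTPUT_COMPRESSED}")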

## Cell 8: Display Video in Notebook

In [None]:
from IPython.display import HTML
from base64 import b64encode

OUTPUT_COMPRESSED = "/content/output_compressed.mp4"

# Read and encode video
with open(OUTPUT_COMPRESSED, 'rb') as f:
    mp4 = f.read()
data_url = f"data:video/mp4;base64,{b64encode(mp4).decode()}"

# Display video
HTML(f'''
<video width="800" controls>
  <source src="{data_url}" type="video/mp4">
</video>
''')

## Optional: Frame-by-Frame Processing with Masks

For more control, process frames individually with the generator API, which also yields the mask and the tracklet-to-mask mapping for each frame.

In [None]:
# Example: Process frame-by-frame with generator
# Uncomment to run

# from bananatracker import BananaTrackerPipeline
# 
# pipeline = BananaTrackerPipeline(config)
# 
# for frame_id, frame, tracks, vis_frame, mask, tracklet_mask_dict in pipeline.process_video_generator(INPUT_VIDEO):
#     # Get track info as dictionaries
#     track_info = pipeline.get_track_info(tracks)
#     
#     # Process each track
#     for info in track_info:
#         print(f"Frame {frame_id}: Track {info['track_id']} - {info['class_name']} at {info['bbox']}")
#     
#     # Access mask data
#     if mask is not None:
#         print(f"  Mask has {len(np.unique(mask)) - 1} segmented objects")
#     
#     # Stop after 10 frames for demo
#     if frame_id >= 10:
#         break
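
The per-frame dictionaries from `get_track_info` can be aggregated into a per-track summary. The sketch below runs on hand-written sample data and relies only on the `track_id` and `class_name` keys used in the loop above; `summarize_tracks` is a hypothetical helper, not part of the BananaTracker API.

```python
from collections import Counter, defaultdict

def summarize_tracks(per_frame_info):
    """Count frames per track and pick each track's most frequent class.

    per_frame_info: iterable of lists of dicts, one list per frame,
    each dict carrying at least 'track_id' and 'class_name'.
    """
    frames_seen = Counter()
    class_votes = defaultdict(Counter)
    for track_info in per_frame_info:
        for info in track_info:
            frames_seen[info["track_id"]] += 1
            class_votes[info["track_id"]][info["class_name"]] += 1
    return {
        tid: {"frames": n, "class_name": class_votes[tid].most_common(1)[0][0]}
        for tid, n in frames_seen.items()
    }

# Stand-in for values collected from pipeline.get_track_info(tracks)
per_frame_info = [
    [{"track_id": 1, "class_name": "Player"}, {"track_id": 2, "class_name": "Puck"}],
    [{"track_id": 1, "class_name": "Player"}],
    [{"track_id": 1, "class_name": "Player"}, {"track_id": 2, "class_name": "Puck"}],
]
summary = summarize_tracks(per_frame_info)
print(summary)
```

In the generator loop, appending each `track_info` list to `per_frame_info` would produce the input this function expects.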

## Configuration Reference

### SAM2.1 Model Options

```python
# Choose ONE pair below; if all four are executed in order, the last assignment wins.
# Large model (best quality, recommended for accuracy)
SAM2_CHECKPOINT = "checkpoints/sam2.1_hiera_large.pt"
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Base+ model (balanced quality/speed)
SAM2_CHECKPOINT = "checkpoints/sam2.1_hiera_base_plus.pt"
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_b+.yaml"

# Small model (faster, lower quality)
SAM2_CHECKPOINT = "checkpoints/sam2.1_hiera_small.pt"
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_s.yaml"

# Tiny model (fastest, basic quality)
SAM2_CHECKPOINT = "checkpoints/sam2.1_hiera_tiny.pt"
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_t.yaml"
```
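
To avoid editing two strings by hand, the pairs above can be kept in a lookup table. This is a convenience sketch: `SAM2_MODELS` and `sam2_paths` are hypothetical names introduced here, and the file names are copied from the table in Cell 3.

```python
# Model size -> (checkpoint file, config file), copied from the options table
SAM2_MODELS = {
    "large": ("sam2.1_hiera_large.pt", "sam2.1_hiera_l.yaml"),
    "base_plus": ("sam2.1_hiera_base_plus.pt", "sam2.1_hiera_b+.yaml"),
    "small": ("sam2.1_hiera_small.pt", "sam2.1_hiera_s.yaml"),
    "tiny": ("sam2.1_hiera_tiny.pt", "sam2.1_hiera_t.yaml"),
}

def sam2_paths(size):
    """Return (checkpoint_path, config_path) for a model size key."""
    try:
        checkpoint, cfg = SAM2_MODELS[size]
    except KeyError:
        raise ValueError(
            f"Unknown SAM2.1 model size {size!r}; choose from {sorted(SAM2_MODELS)}"
        ) from None
    return f"checkpoints/{checkpoint}", f"configs/sam2.1/{cfg}"

SAM2_CHECKPOINT, SAM2_CONFIG = sam2_paths("small")
print(SAM2_CHECKPOINT, SAM2_CONFIG)
```

Keeping the checkpoint and config in one tuple means a checkpoint can never be paired with the wrong config by accident.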

### Key Configuration Options

| Parameter | Description | Default |
|-----------|-------------|--------|
| `sam2_enabled` | Enable/disable SAM2.1 mask propagation | `True` |
| `mask_start_frame` | Frame to start mask creation | `1` |
| `mask_overlap_threshold` | Skip masks for heavily overlapping bboxes | `0.6` |
| `mask_alpha` | Mask overlay transparency | `0.4` |
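
To build intuition for `mask_overlap_threshold`, the sketch below flags boxes whose overlap with another box exceeds a threshold. It uses intersection-over-union on `(top-left x, top-left y, width, height)` boxes as the overlap measure. Both the measure and the helper names are assumptions for illustration; BananaTracker may compute overlap differently.

```python
def tlwh_iou(a, b):
    """Intersection-over-union of two (x, y, w, h) boxes."""
    ax, ay, aw, ah = a
    bx, by, bw, bh = b
    inter_w = max(0.0, min(ax + aw, bx + bw) - max(ax, bx))
    inter_h = max(0.0, min(ay + ah, by + bh) - max(ay, by))
    inter = inter_w * inter_h
    union = aw * ah + bw * bh - inter
    return inter / union if union > 0 else 0.0

def boxes_to_skip(tlwhs, threshold):
    """Indices of boxes that overlap some other box by more than `threshold`."""
    skip = set()
    for i in range(len(tlwhs)):
        for j in range(i + 1, len(tlwhs)):
            if tlwh_iou(tlwhs[i], tlwhs[j]) > threshold:
                skip.update((i, j))
    return sorted(skip)

tlwhs = [
    (100, 100, 40, 80),  # player A
    (105, 100, 40, 80),  # player B, almost on top of A
    (300, 100, 40, 80),  # player C, isolated
]
skipped = boxes_to_skip(tlwhs, threshold=0.6)
print(skipped)
```

With the default threshold of `0.6`, the first two boxes overlap too heavily to yield reliable separate mask prompts, while the isolated third box is unaffected.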