# BananaTracker - Multi-Object Tracking with SAM2.1 + Cutie

This notebook demonstrates the complete MOT tracking pipeline using:
- **YOLOv8** for object detection
- **ByteTrack-based tracker** for multi-object tracking
- **SAM2.1** for high-quality mask generation from bounding boxes
- **Cutie** for temporal mask propagation

The mask module adds pixel-level object masks, which improve association when objects are close together or their bounding boxes overlap.
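To see why pixel-level masks help, consider two players whose boxes overlap heavily while their silhouettes do not touch. The sketch below is illustrative only. Masks are modeled as sets of `(row, col)` pixels, the shapes are made up, and none of it uses BananaTracker's API.

```python
def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2), with x2/y2 exclusive."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def mask_iou(m1, m2):
    """IoU of two masks represented as sets of (row, col) pixels."""
    union = m1 | m2
    return len(m1 & m2) / len(union) if union else 0.0


def rect_pixels(x1, y1, x2, y2):
    """All (row, col) pixels of an axis-aligned rectangle (used to fake silhouettes)."""
    return {(r, c) for r in range(y1, y2) for c in range(x1, x2)}


# Two boxes that overlap heavily ...
box_a = (0, 0, 10, 20)
box_b = (4, 0, 14, 20)
# ... but thin silhouettes inside them that do not touch.
mask_a = rect_pixels(1, 0, 4, 20)    # columns 1-3
mask_b = rect_pixels(10, 0, 13, 20)  # columns 10-12

print(f"box IoU:  {box_iou(box_a, box_b):.2f}")    # box IoU:  0.43
print(f"mask IoU: {mask_iou(mask_a, mask_b):.2f}")  # mask IoU: 0.00
```

Box overlap alone suggests the two detections could be the same object; the masks make it clear they are not.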

## Cell 1: Install Dependencies

In [None]:
# Install required packages
!pip install ultralytics opencv-python-headless tqdm
!pip install lap cython_bbox  # For ByteTrack tracker core

# Install SAM2.1 dependencies (HuggingFace transformers)
# Quote the version specifier so the shell does not treat ">" as output redirection
!pip install "transformers>=4.35.0" huggingface_hub

# Install Cutie dependencies
!pip install omegaconf hydra-core
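After the installs finish, a quick check can confirm that each package is importable and that `transformers` meets the minimum version pinned above. This is a generic, standard-library-only sketch; the pip-name to import-name mapping (for example `opencv-python-headless` to `cv2`, `hydra-core` to `hydra`) is an assumption about these packages, and the version parser is deliberately simple.

```python
import importlib.util
from importlib.metadata import version, PackageNotFoundError

# pip name -> import name (assumed mapping for the packages installed above)
REQUIRED_MODULES = {
    "ultralytics": "ultralytics",
    "opencv-python-headless": "cv2",
    "tqdm": "tqdm",
    "lap": "lap",
    "cython_bbox": "cython_bbox",
    "transformers": "transformers",
    "huggingface_hub": "huggingface_hub",
    "omegaconf": "omegaconf",
    "hydra-core": "hydra",
}


def missing_modules(modules=REQUIRED_MODULES):
    """Return the pip names whose import module cannot be found."""
    return [pip_name for pip_name, mod in modules.items()
            if importlib.util.find_spec(mod) is None]


def version_tuple(v):
    """Parse the leading numeric fields of a version, e.g. '4.35.0.dev0' -> (4, 35, 0)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def transformers_ok(minimum="4.35.0"):
    """True if transformers is installed and at least `minimum`."""
    try:
        return version_tuple(version("transformers")) >= version_tuple(minimum)
    except PackageNotFoundError:
        return False


print("Missing packages:", missing_modules() or "none")
print("transformers >= 4.35.0:", transformers_ok())
```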

## Cell 2: Clone Repositories

In [None]:
import os

# Clone BananaTracker repository
if not os.path.exists('bananatracker'):
    !git clone https://github.com/USER/bananatracker.git

# Clone Cutie for temporal mask propagation
if not os.path.exists('Cutie'):
    !git clone https://github.com/hkchengrex/Cutie.git

# Create symlink for Cutie in mask_propagation folder
os.makedirs('bananatracker/bananatracker/mask_propagation', exist_ok=True)
if not os.path.exists('bananatracker/bananatracker/mask_propagation/Cutie'):
    os.symlink('/content/Cutie', 'bananatracker/bananatracker/mask_propagation/Cutie')

# Download Cutie weights
os.makedirs('Cutie/weights', exist_ok=True)
if not os.path.exists('Cutie/weights/cutie-base-mega.pth'):
    !wget -P Cutie/weights https://github.com/hkchengrex/Cutie/releases/download/v1.0/cutie-base-mega.pth

# Install BananaTracker in development mode
%cd bananatracker
!pip install -e .
%cd ..
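Before loading any models, it can help to confirm that Cell 2 left the expected files in place. A minimal sketch with `pathlib`, assuming the Colab `/content` layout used above; `check_setup` is a hypothetical helper, not part of BananaTracker.

```python
from pathlib import Path


def check_setup(root="/content"):
    """Return a list of problems with the Cell 2 layout (an empty list means OK)."""
    root = Path(root)
    problems = []
    weights = root / "Cutie" / "weights" / "cutie-base-mega.pth"
    link = root / "bananatracker" / "bananatracker" / "mask_propagation" / "Cutie"

    # A failed or interrupted wget can leave a zero-byte file behind.
    if not weights.is_file() or weights.stat().st_size == 0:
        problems.append(f"missing or empty weights: {weights}")

    # is_symlink() does not follow the link; exists() does, so it detects dangling links.
    if not link.is_symlink():
        problems.append(f"not a symlink: {link}")
    elif not link.exists():
        problems.append(f"broken symlink: {link}")
    return problems


print(check_setup() or "Setup looks complete")
```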

## Cell 3: SAM2.1 Model Configuration

Configure SAM2.1 model settings. You can choose different model sizes:
- `facebook/sam2.1-hiera-tiny` - Fastest, lower quality
- `facebook/sam2.1-hiera-small` - Good balance
- `facebook/sam2.1-hiera-base-plus` - Better quality
- `facebook/sam2.1-hiera-large` - Best quality (recommended)
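If you switch models programmatically (for example, when moving to a smaller GPU), a small lookup keeps the choice within the four IDs listed above. The helper names and the preference keywords are hypothetical and not part of BananaTracker.

```python
# Model IDs from the list above, ordered from fastest to highest quality.
SAM2_MODELS = [
    "facebook/sam2.1-hiera-tiny",
    "facebook/sam2.1-hiera-small",
    "facebook/sam2.1-hiera-base-plus",
    "facebook/sam2.1-hiera-large",
]


def pick_sam2_model(priority: str) -> str:
    """Map a speed/quality preference to one of the listed model IDs."""
    index = {"speed": 0, "balanced": 1, "quality": 2, "best": 3}[priority]
    return SAM2_MODELS[index]


def validate_sam2_model(model_id: str) -> str:
    """Raise early on a mistyped model ID instead of failing during download."""
    if model_id not in SAM2_MODELS:
        raise ValueError(f"Unknown SAM2.1 model {model_id!r}; expected one of {SAM2_MODELS}")
    return model_id
```

For example, calling `validate_sam2_model(SAM2_MODEL_ID)` after the form below would catch a typo in a hand-edited ID.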

In [None]:
#@title SAM2.1 Configuration { display-mode: "form" }

# SAM2.1 Model Settings
SAM2_MODEL_ID = "facebook/sam2.1-hiera-large"  #@param ["facebook/sam2.1-hiera-tiny", "facebook/sam2.1-hiera-small", "facebook/sam2.1-hiera-base-plus", "facebook/sam2.1-hiera-large"]
SAM2_CHECKPOINT = ""  #@param {type:"string"} # Optional: local checkpoint path (leave empty to use HuggingFace)

# HuggingFace Token (required for gated models, optional for SAM2)
HF_TOKEN = ""  #@param {type:"string"}

# Cutie Weights Path
CUTIE_WEIGHTS = "/content/Cutie/weights/cutie-base-mega.pth"  #@param {type:"string"}

print(f"SAM2.1 Model: {SAM2_MODEL_ID}")
print(f"SAM2.1 Checkpoint: {SAM2_CHECKPOINT if SAM2_CHECKPOINT else 'Using HuggingFace download'}")
print(f"HF Token: {'Set' if HF_TOKEN else 'Not set (optional)'}")
print(f"Cutie Weights: {CUTIE_WEIGHTS}")

## Cell 4: Tracker Configuration

Configure the full tracking pipeline with detection, tracking, and mask settings.

In [None]:
import sys
sys.path.insert(0, '/content/bananatracker')
sys.path.insert(0, '/content/Cutie')

from bananatracker import BananaTrackerConfig

# Full configuration for hockey/sports tracking with mask enhancement
config = BananaTrackerConfig(
    # Detection Settings
    yolo_weights="/content/HockeyAI_model_weight.pt",  # Update with your model path
    class_names=["Center Ice", "Faceoff", "Goalpost", "Goaltender", "Player", "Puck", "Referee"],
    track_classes=[3, 4, 5, 6],  # Goaltender, Player, Puck, Referee
    special_classes=[5],          # Puck: keep only the highest-confidence detection
    detection_conf_thresh=0.5,    # General confidence threshold
    detection_iou_thresh=0.7,     # IoU threshold for YOLO NMS

    # Post-processing: Centroid-based deduplication
    centroid_dedup_enabled=True,
    centroid_dedup_max_distance=36,

    # Tracker Settings (ByteTrack)
    track_thresh=0.6,
    track_buffer=30,
    match_thresh=0.8,
    fps=30,
    cmc_method="orb",  # Camera motion compensation

    # Mask Module Settings (SAM2.1 + Cutie)
    enable_masks=True,
    sam2_model_id=SAM2_MODEL_ID,
    sam2_checkpoint=SAM2_CHECKPOINT if SAM2_CHECKPOINT else None,
    cutie_weights_path=CUTIE_WEIGHTS,
    hf_token=HF_TOKEN if HF_TOKEN else None,
    mask_start_frame=1,
    mask_bbox_overlap_threshold=0.6,

    # Visualization Settings
    class_colors={
        "Goaltender": (0, 165, 255),   # Orange (BGR)
        "Player": (255, 0, 0),          # Blue (BGR)
        "Puck": (0, 255, 0),            # Green
        "Referee": (0, 0, 255),         # Red
    },
    show_track_id=True,
    show_masks=True,     # Enable mask overlay in visualization
    mask_alpha=0.5,      # Mask transparency
    line_thickness=2,

    # Output Settings
    output_video_path="/content/output_tracked_with_masks.mp4",
    output_txt_path="/content/results.txt",
    device="cuda:0",
)

print("Configuration created!")
print(f"\nDetection:")
print(f"  - Confidence threshold: {config.detection_conf_thresh}")
print(f"  - Tracking classes: {config.track_classes}")
print(f"\nMask Module:")
print(f"  - Enabled: {config.enable_masks}")
print(f"  - SAM2.1 Model: {config.sam2_model_id}")
print(f"  - Cutie Weights: {config.cutie_weights_path}")
print(f"\nVisualization:")
print(f"  - Show masks: {config.show_masks}")
print(f"  - Mask alpha: {config.mask_alpha}")
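The detection options in this config (`detection_conf_thresh`, `track_classes`, `special_classes`, and the centroid dedup settings) can be read as a per-frame filter. The sketch below is one plausible interpretation in plain Python. The row layout `[x1, y1, x2, y2, conf, cls]` follows the architecture diagram at the end of this notebook, while the exact dedup rule (same class, higher confidence wins) is an assumption rather than BananaTracker's documented behavior.

```python
import math


def filter_detections(dets, track_classes, special_classes, conf_thresh=0.5,
                      dedup=True, max_distance=36.0):
    """Hypothetical re-implementation of the detection post-processing options.

    dets: list of [x1, y1, x2, y2, conf, cls] rows.
    Assumptions (not taken from BananaTracker's source):
      * special classes keep only their single highest-confidence detection;
      * centroid dedup drops a detection whose centre lies within `max_distance`
        pixels of an already-kept, higher-confidence detection of the same class.
    """
    kept = [d for d in dets if int(d[5]) in track_classes and d[4] >= conf_thresh]
    kept.sort(key=lambda d: d[4], reverse=True)  # highest confidence first

    result, seen_special = [], set()
    for d in kept:
        cls = int(d[5])
        if cls in special_classes:
            if cls not in seen_special:  # first one seen is the max-confidence one
                seen_special.add(cls)
                result.append(d)
            continue
        cx, cy = (d[0] + d[2]) / 2, (d[1] + d[3]) / 2
        is_duplicate = dedup and any(
            int(k[5]) == cls
            and math.hypot(cx - (k[0] + k[2]) / 2, cy - (k[1] + k[3]) / 2) <= max_distance
            for k in result
        )
        if not is_duplicate:
            result.append(d)
    return result


dets = [
    [100, 100, 140, 200, 0.90, 4],  # Player
    [105, 102, 143, 198, 0.70, 4],  # near-duplicate Player (centres 4 px apart)
    [400, 120, 440, 220, 0.80, 4],  # another Player
    [300, 300, 306, 306, 0.65, 5],  # Puck
    [500, 310, 506, 316, 0.55, 5],  # weaker Puck candidate
    [0, 0, 50, 50, 0.95, 0],        # Center Ice (not tracked)
    [200, 100, 240, 200, 0.30, 6],  # Referee below threshold
]
for d in filter_detections(dets, track_classes=[3, 4, 5, 6], special_classes=[5]):
    print(d)
```

With these made-up rows, three detections survive: the two distinct Players and the stronger Puck.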

## Cell 5: Initialize Pipeline with SAM2.1 + Cutie

In [None]:
from bananatracker import BananaTrackerPipeline

# Initialize the pipeline (this will load SAM2.1 and Cutie models)
print("Initializing pipeline...")
print("This may take a moment while the SAM2.1 model downloads from HuggingFace...")

pipeline = BananaTrackerPipeline(config)

print("\nPipeline initialized!")
print(f"  - Detector: YOLOv8")
print(f"  - Tracker: BananaTracker (ByteTrack-based with mask enhancement)")
print(f"  - Mask Module: {'SAM2.1 + Cutie' if pipeline.mask_manager else 'Disabled'}")
print(f"  - CMC Method: {config.cmc_method}")
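`track_thresh` and `match_thresh` from Cell 4 follow ByteTrack's conventions. The simplified sketch below shows the two-pass idea: high-confidence detections are associated first, then leftover tracks get a second chance with low-confidence ones (in the reference ByteTrack code that second pass uses a stricter 0.5 cost threshold). It substitutes greedy matching for the Hungarian solver and ignores Kalman prediction, CMC, and the mask-enhanced cost, so it illustrates the thresholds rather than reproducing BananaTracker's tracker.

```python
def iou(a, b):
    """IoU of two boxes; only the first four values (x1, y1, x2, y2) are used."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def split_by_score(dets, track_thresh=0.6, low_thresh=0.1):
    """ByteTrack-style split: score > track_thresh is 'high',
    low_thresh < score < track_thresh is 'low'."""
    high = [d for d in dets if d[4] > track_thresh]
    low = [d for d in dets if low_thresh < d[4] < track_thresh]
    return high, low


def greedy_match(tracks, dets, match_thresh=0.8):
    """Greedy stand-in for ByteTrack's Hungarian assignment on cost = 1 - IoU.

    A pair is accepted only if its cost is below match_thresh.
    Returns (matches, unmatched_track_idx, unmatched_det_idx).
    """
    pairs = sorted((1.0 - iou(t, d), ti, di)
                   for ti, t in enumerate(tracks) for di, d in enumerate(dets))
    matches, used_t, used_d = [], set(), set()
    for cost, ti, di in pairs:
        if cost >= match_thresh:
            break  # pairs are sorted, so every remaining cost is at least as high
        if ti in used_t or di in used_d:
            continue
        matches.append((ti, di))
        used_t.add(ti)
        used_d.add(di)
    unmatched_t = [i for i in range(len(tracks)) if i not in used_t]
    unmatched_d = [i for i in range(len(dets)) if i not in used_d]
    return matches, unmatched_t, unmatched_d


tracks = [(100, 100, 140, 200), (400, 120, 440, 220)]  # predicted track boxes
dets = [(102, 101, 142, 201, 0.90),   # high score, overlaps track 0
        (600, 100, 640, 200, 0.85),   # high score, no track nearby: new-track candidate
        (398, 118, 438, 218, 0.40)]   # low score (e.g. occluded), overlaps track 1
high, low = split_by_score(dets, track_thresh=0.6)
m1, left_tracks, new_dets = greedy_match(tracks, high, match_thresh=0.8)
remaining = [tracks[i] for i in left_tracks]
m2, lost, _ = greedy_match(remaining, low, match_thresh=0.5)  # stricter second pass
print(m1, new_dets, m2, lost)  # [(0, 0)] [1] [(0, 0)] []
```

Indices in `m2` refer to `remaining`, not the original `tracks` list. For unmatched tracks, the reference ByteTrack implementation keeps them as "lost" for `int(fps / 30 * track_buffer)` frames, which is 30 frames (about one second) with the settings above; whether BananaTracker uses exactly this rule is not stated in this notebook.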

## Cell 6: Validate SAM2.1 Mask Generation (Test Cell)

This cell validates that SAM2.1 is working correctly by generating masks on a random frame from the video.

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import random

# Input video for testing
TEST_VIDEO = "/content/sample_video.mp4"  # Update with your video path

# Open video and get a random frame
cap = cv2.VideoCapture(TEST_VIDEO)
if not cap.isOpened():
    print(f"Error: Could not open video {TEST_VIDEO}")
else:
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # CAP_PROP_FRAME_COUNT can be 0 or inaccurate for some containers; keep the range valid
    random_frame_idx = random.randint(0, max(total_frames - 1, 0))
    
    cap.set(cv2.CAP_PROP_POS_FRAMES, random_frame_idx)
    ret, frame = cap.read()
    cap.release()
    
    if ret:
        print(f"Testing SAM2.1 on frame {random_frame_idx} of {total_frames}")
        
        # Convert to RGB for SAM2.1
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        
        # Run detection to get bounding boxes
        detections = pipeline.detector.detect(frame)
        
        if len(detections) > 0:
            # Get bounding boxes (first 4 columns: x1, y1, x2, y2)
            boxes_xyxy = detections[:, :4].tolist()
            
            print(f"Detected {len(boxes_xyxy)} objects")
            
            # Generate masks using SAM2.1
            if pipeline.mask_manager is not None:
                masks = pipeline.mask_manager._sam2_predict_boxes(frame_rgb, boxes_xyxy)
                
                print(f"Generated {len(masks)} masks")
                print(f"Mask shape: {masks.shape}")
                
                # Visualize results
                fig, axes = plt.subplots(1, 3, figsize=(18, 6))
                
                # Original frame
                axes[0].imshow(frame_rgb)
                axes[0].set_title(f'Original Frame #{random_frame_idx}')
                axes[0].axis('off')
                
                # Frame with bounding boxes
                frame_with_boxes = frame_rgb.copy()
                for i, box in enumerate(boxes_xyxy):
                    x1, y1, x2, y2 = map(int, box)
                    cv2.rectangle(frame_with_boxes, (x1, y1), (x2, y2), (255, 0, 0), 2)
                    cv2.putText(frame_with_boxes, f'{i}', (x1, y1-10), 
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
                axes[1].imshow(frame_with_boxes)
                axes[1].set_title(f'Detected Objects ({len(boxes_xyxy)})')
                axes[1].axis('off')
                
                # Combined masks overlay (colors stay in 0-255 to match frame_rgb)
                colors = plt.cm.tab10(np.linspace(0, 1, 10))[:, :3] * 255
                combined_mask = np.zeros_like(frame_rgb, dtype=np.float32)
                for i, mask in enumerate(masks):
                    color = colors[i % len(colors)]
                    mask_3d = np.stack([mask, mask, mask], axis=-1).astype(np.float32)
                    combined_mask += mask_3d * color
                
                alpha = 0.5
                frame_with_masks = frame_rgb.astype(np.float32)
                mask_overlay = np.clip(combined_mask, 0, 255)
                binary_mask = np.any(combined_mask > 0, axis=-1, keepdims=True)
                frame_with_masks = np.where(
                    binary_mask,
                    frame_with_masks * alpha + mask_overlay * (1 - alpha),
                    frame_with_masks
                )
                axes[2].imshow(frame_with_masks.astype(np.uint8))
                axes[2].set_title('SAM2.1 Mask Overlay')
                axes[2].axis('off')
                
                plt.tight_layout()
                plt.savefig('/content/sam2_mask_validation.png', dpi=150, bbox_inches='tight')
                plt.show()
                
                print("\n✓ SAM2.1 mask generation validated successfully!")
                print(f"  Validation image saved to: /content/sam2_mask_validation.png")
            else:
                print("Warning: Mask module not initialized")
        else:
            print("No objects detected in this frame. Re-run this cell to sample a different frame.")
    else:
        print(f"Error: Could not read frame {random_frame_idx}")
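Beyond inspecting the overlay by eye, a numeric check can catch masks that leak far outside their prompt box. The sketch below computes the share of each mask's foreground pixels that lie inside its box and flags masks below a threshold. Reusing `mask_bbox_overlap_threshold` (0.6) as that threshold is an assumption about what the setting means, not documented behavior, and the pure-Python loop is slow on full frames (a NumPy slice-and-sum would be the practical version).

```python
def fraction_inside_box(mask, box):
    """Share of a binary mask's foreground pixels inside box (x1, y1, x2, y2).

    `mask` is any 2D sequence indexed as mask[row][col]; a NumPy array works too.
    Returns 0.0 for an empty mask.
    """
    x1, y1, x2, y2 = box
    total = inside = 0
    for y, row in enumerate(mask):
        for x, value in enumerate(row):
            if value:
                total += 1
                if x1 <= x < x2 and y1 <= y < y2:
                    inside += 1
    return inside / total if total else 0.0


def flag_suspicious_masks(masks, boxes, threshold=0.6):
    """Indices of masks whose in-box fraction is below `threshold` (empty masks included)."""
    return [i for i, (m, b) in enumerate(zip(masks, boxes))
            if fraction_inside_box(m, b) < threshold]


# Usage after the validation cell (assumes `masks` is an (N, H, W) array):
# print(flag_suspicious_masks(masks, boxes_xyxy, threshold=config.mask_bbox_overlap_threshold))
```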

## Cell 7: Run Full Tracking Pipeline

In [None]:
# Path to input video
INPUT_VIDEO = "/content/sample_video.mp4"  # Update with your video path

# Run tracking with mask enhancement
print(f"Processing video: {INPUT_VIDEO}")
print(f"Mask module: {'Enabled' if config.enable_masks else 'Disabled'}")
print("")

all_tracks = pipeline.process_video(INPUT_VIDEO)

print(f"\nProcessed {len(all_tracks)} frames")
print(f"Output video: {config.output_video_path}")
print(f"MOT results: {config.output_txt_path}")
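To sanity-check the results file, the sketch below counts frames and track IDs. It assumes the file uses the common MOTChallenge comma-separated layout (`frame, id, x, y, w, h, conf, ...`); this notebook does not state the exact column order, so check a few lines of your output first.

```python
from collections import defaultdict


def summarize_mot(lines):
    """Summarize MOTChallenge-style rows (frame, id, x, y, w, h, conf, ...).

    Returns (frames with at least one track, unique track IDs,
    {track_id: number of rows for that ID}).
    """
    frames = set()
    lengths = defaultdict(int)
    for line in lines:
        line = line.strip()
        if not line:
            continue
        fields = line.split(",")
        frame_id, track_id = int(float(fields[0])), int(float(fields[1]))
        frames.add(frame_id)
        lengths[track_id] += 1
    return len(frames), len(lengths), dict(lengths)


sample = """1,1,100,100,40,100,0.9,-1,-1,-1
1,2,400,120,40,100,0.8,-1,-1,-1
2,1,102,101,40,100,0.9,-1,-1,-1
"""
print(summarize_mot(sample.splitlines()))  # (2, 2, {1: 2, 2: 1})

# On the real output:
# with open(config.output_txt_path) as f:
#     n_frames, n_ids, lengths = summarize_mot(f)
```

Many very short tracks in `lengths` usually point to ID switches worth inspecting in the output video.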

## Cell 8: Compress and Display Output Video

In [None]:
# Compress video for notebook display (-loglevel error keeps ffmpeg quiet without hiding the print below)
OUTPUT_COMPRESSED = "/content/output_compressed.mp4"
!ffmpeg -y -loglevel error -i {config.output_video_path} -vcodec libx264 -crf 28 {OUTPUT_COMPRESSED}

print(f"Compressed video saved to: {OUTPUT_COMPRESSED}")

In [None]:
from IPython.display import HTML
from base64 import b64encode

OUTPUT_COMPRESSED = "/content/output_compressed.mp4"

# Read and encode video (the with-block closes the file handle)
with open(OUTPUT_COMPRESSED, 'rb') as f:
    mp4 = f.read()
data_url = f"data:video/mp4;base64,{b64encode(mp4).decode()}"

# Display video
HTML(f'''
<video width="800" controls>
  <source src="{data_url}" type="video/mp4">
</video>
''')
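Embedding a video as a base64 data URL inflates it by about 4/3 and stores it in the notebook itself, so large files can make the browser sluggish. A small helper that checks the file size first might look like this; the 50 MB cutoff and the name `video_html` are arbitrary choices for this sketch, not a Colab or IPython limit.

```python
import os
from base64 import b64encode


def video_html(path, width=800, max_mb=50.0):
    """Return an HTML <video> tag with the MP4 embedded as a base64 data URL.

    Refuses files larger than `max_mb` megabytes, since the encoded payload is
    roughly 4/3 of the file size and lives inside the notebook output.
    """
    size_mb = os.path.getsize(path) / 1e6
    if size_mb > max_mb:
        raise ValueError(f"{path} is {size_mb:.1f} MB; compress it further before embedding")
    with open(path, "rb") as f:
        data = b64encode(f.read()).decode("ascii")
    return (f'<video width="{width}" controls>'
            f'<source src="data:video/mp4;base64,{data}" type="video/mp4">'
            f'</video>')


# Usage in the notebook:
# from IPython.display import HTML
# HTML(video_html("/content/output_compressed.mp4"))
```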

## Cell 9: Frame-by-Frame Processing (Optional)

For more control, process frames individually using the generator API.

In [None]:
# Example: Process frame-by-frame with generator and access mask info
# Uncomment to run

# from bananatracker import BananaTrackerPipeline
# 
# pipeline = BananaTrackerPipeline(config)
# 
# for frame_id, frame, tracks, vis_frame in pipeline.process_video_generator(INPUT_VIDEO):
#     # Get track info as dictionaries
#     track_info = pipeline.get_track_info(tracks)
#     
#     # Access mask data
#     prediction_mask = pipeline.prediction_mask
#     tracklet_mask_dict = pipeline.tracklet_mask_dict
#     
#     # Process each track
#     for info in track_info:
#         track_id = info['track_id']
#         mask_id = tracklet_mask_dict.get(track_id, None)
#         print(f"Frame {frame_id}: Track {track_id} - {info['class_name']} - Mask ID: {mask_id}")
#     
#     # Stop after 10 frames for demo
#     if frame_id >= 10:
#         break
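Instead of breaking on a frame number (which depends on whether `frame_id` starts at 0 or 1), `itertools.islice` can cap the number of frames. The sketch wraps any generator and closes it afterwards so its `finally` cleanup runs; whether BananaTracker's generator releases the video capture on close depends on its implementation. The stand-in generator only mimics the `(frame_id, frame, tracks, vis_frame)` tuples used above and is hypothetical.

```python
from itertools import islice


def first_n(generator, n):
    """Yield at most n items, then close the generator so its cleanup code runs."""
    try:
        yield from islice(generator, n)
    finally:
        generator.close()


def fake_process_video_generator(num_frames):
    """Hypothetical stand-in mimicking (frame_id, frame, tracks, vis_frame) tuples."""
    try:
        for frame_id in range(1, num_frames + 1):
            tracks = [{"track_id": 1, "class_name": "Player"}]
            yield frame_id, None, tracks, None
    finally:
        print("generator closed")


for frame_id, frame, tracks, vis_frame in first_n(fake_process_video_generator(100), 3):
    print(frame_id, [t["track_id"] for t in tracks])

# With the real pipeline:
# for frame_id, frame, tracks, vis_frame in first_n(pipeline.process_video_generator(INPUT_VIDEO), 10):
#     ...
```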

## Cell 10: Disable Masks (Performance Mode)

If you need faster processing without mask enhancement, you can disable the mask module.

In [None]:
# Create config without masks for faster processing
config_fast = BananaTrackerConfig(
    yolo_weights="/content/HockeyAI_model_weight.pt",
    class_names=["Center Ice", "Faceoff", "Goalpost", "Goaltender", "Player", "Puck", "Referee"],
    track_classes=[3, 4, 5, 6],
    special_classes=[5],
    detection_conf_thresh=0.5,
    
    # Tracker settings
    track_thresh=0.6,
    track_buffer=30,
    cmc_method="orb",
    
    # DISABLE mask module
    enable_masks=False,
    
    # Output
    output_video_path="/content/output_fast.mp4",
    device="cuda:0",
)

# pipeline_fast = BananaTrackerPipeline(config_fast)
# all_tracks_fast = pipeline_fast.process_video(INPUT_VIDEO)

print("Fast mode config created (masks disabled)")
print("Uncomment the lines above to run without mask enhancement")
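To quantify what disabling masks buys you, time both pipelines on the same clip. A minimal sketch, assuming both pipelines expose `process_video_generator` as in Cell 9; the warm-up count of 5 frames is an arbitrary choice.

```python
import time

_SENTINEL = object()


def measure_fps(frames, warmup=5):
    """Consume an iterable and return items processed per second.

    The first `warmup` items are skipped before timing starts, because model
    warm-up and lazy CUDA initialization tend to make early frames slow.
    Returns 0.0 if nothing is left to time.
    """
    it = iter(frames)
    for _ in range(warmup):
        if next(it, _SENTINEL) is _SENTINEL:
            return 0.0
    count = 0
    start = time.perf_counter()
    for _ in it:
        count += 1
    elapsed = time.perf_counter() - start
    if count == 0:
        return 0.0
    return count / elapsed if elapsed > 0 else float("inf")


# Usage:
# fps_masks = measure_fps(pipeline.process_video_generator(INPUT_VIDEO))
# fps_fast = measure_fps(pipeline_fast.process_video_generator(INPUT_VIDEO))
# print(f"with masks: {fps_masks:.1f} FPS, without: {fps_fast:.1f} FPS")
```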

## Architecture Overview

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                           BananaTracker Pipeline                             │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  Input Frame                                                                 │
│      │                                                                       │
│      ▼                                                                       │
│  ┌─────────┐                                                                 │
│  │ YOLOv8  │ ─────────────────────────────────────────────────────┐         │
│  │Detector │                                                       │         │
│  └────┬────┘                                                       │         │
│       │ Detections [x1,y1,x2,y2,conf,class]                       │         │
│       ▼                                                            │         │
│  ┌──────────────┐                                                  │         │
│  │ BananaTracker│◄─────────── Mask-Enhanced ───────────────────┐   │         │
│  │  (ByteTrack) │             Cost Matrix                      │   │         │
│  └──────┬───────┘                                              │   │         │
│         │ Tracks, New Tracks, Removed IDs                      │   │         │
│         ▼                                                      │   │         │
│  ┌──────────────────────────────────────────────────────────┐  │   │         │
│  │                    MaskManager                            │  │   │         │
│  │  ┌───────────┐         ┌─────────┐                       │  │   │         │
│  │  │  SAM2.1   │─────────│  Cutie  │                       │  │   │         │
│  │  │ (Initial  │  Seed   │(Temporal│                       │  │   │         │
│  │  │  Masks)   │  Masks  │ Propag.)│                       │  │   │         │
│  │  └───────────┘         └────┬────┘                       │  │   │         │
│  │                             │                             │  │   │         │
│  │                     prediction_mask                       │──┘   │         │
│  │                    tracklet_mask_dict                     │      │         │
│  │                    mask_avg_prob_dict                     │      │         │
│  └──────────────────────────────────────────────────────────┘      │         │
│                                                                     │         │
│  ┌─────────────┐                                                   │         │
│  │ Visualizer  │◄──────────────────────────────────────────────────┘         │
│  │ (with mask  │                                                             │
│  │  overlay)   │                                                             │
│  └──────┬──────┘                                                             │
│         │                                                                    │
│         ▼                                                                    │
│    Output Frame                                                              │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘
```
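The MaskManager loop in the diagram can be summarized in a few lines: seed a mask once when a tracklet appears, propagate existing masks every frame, and drop masks for removed IDs. The sketch uses stand-in callables instead of SAM2.1 and Cutie; the class name, the callables, and the ordering of steps within a frame are assumptions, not BananaTracker's real API.

```python
class MaskLoopSketch:
    """Hypothetical control flow for the MaskManager box in the diagram.

    `seed_fn(frame, box)` stands in for SAM2.1 (box -> mask) and
    `propagate_fn(frame, masks)` stands in for Cutie (previous masks -> new masks).
    """

    def __init__(self, seed_fn, propagate_fn):
        self.seed_fn = seed_fn
        self.propagate_fn = propagate_fn
        self.masks = {}  # track_id -> current mask

    def step(self, frame, tracks, removed_ids):
        """tracks: {track_id: box}; removed_ids: IDs the tracker dropped this frame."""
        # 1. Cutie-like propagation of every existing mask to the new frame.
        if self.masks:
            self.masks = self.propagate_fn(frame, self.masks)
        # 2. Forget masks of removed tracklets.
        for tid in removed_ids:
            self.masks.pop(tid, None)
        # 3. SAM2.1-like seeding, only for tracklets that have no mask yet.
        for tid, box in tracks.items():
            if tid not in self.masks:
                self.masks[tid] = self.seed_fn(frame, box)
        return self.masks


calls = {"seed": 0, "propagate": 0}


def seed(frame, box):
    calls["seed"] += 1
    return ("mask", box)


def propagate(frame, masks):
    calls["propagate"] += 1
    return dict(masks)


loop = MaskLoopSketch(seed, propagate)
loop.step(0, {1: (0, 0, 10, 10)}, removed_ids=[])
loop.step(1, {1: (1, 0, 11, 10), 2: (50, 50, 60, 60)}, removed_ids=[])
loop.step(2, {2: (51, 50, 61, 60)}, removed_ids=[1])
print(calls, sorted(loop.masks))  # {'seed': 2, 'propagate': 2} [2]
```

The expensive box-to-mask call runs only twice (once per tracklet), while the cheaper propagation runs on every frame that has masks to carry forward.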

### Key Components:

1. **SAM2.1**: Called ONCE per new tracklet to create a pixel-precise mask from its bounding box
2. **Cutie**: Called EVERY frame to propagate masks temporally
3. **Mask-Enhanced Cost Matrix**: Uses `mc` (mask coverage) and `mf` (mask fill) metrics to improve association when box IoU alone is ambiguous, such as when objects overlap
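This notebook does not spell out how `mc` and `mf` are computed. One plausible reading, used here purely for illustration, is: `mc` is the share of a track's mask that falls inside a detection box, and `mf` is the share of that box filled by the mask. The blend weight `w_mask` and `fused_cost` are hypothetical and are not BananaTracker's formula.

```python
def mask_metrics(mask, box):
    """Hypothetical mc/mf for one track mask against one detection box.

    mask: set of (row, col) pixels of a track's propagated mask.
    box:  detection (x1, y1, x2, y2) in pixels, x2/y2 exclusive.
    mc:   share of the mask that lies inside the box.
    mf:   share of the box area filled by the mask.
    """
    x1, y1, x2, y2 = box
    inside = sum(1 for r, c in mask if x1 <= c < x2 and y1 <= r < y2)
    area = max(0, x2 - x1) * max(0, y2 - y1)
    mc = inside / len(mask) if mask else 0.0
    mf = inside / area if area else 0.0
    return mc, mf


def fused_cost(iou, mc, mf, w_mask=0.5):
    """Hypothetical blend of box IoU and mask metrics; lower cost means a better match."""
    return 1.0 - ((1.0 - w_mask) * iou + w_mask * (mc + mf) / 2.0)


# A 4-px-wide, 20-px-tall silhouette occupying columns 2-5.
mask = {(r, c) for r in range(20) for c in range(2, 6)}
print(mask_metrics(mask, (0, 0, 8, 20)))   # (1.0, 0.5)
print(mask_metrics(mask, (5, 0, 13, 20)))  # (0.25, 0.125)
```

Even if both boxes had the same IoU with the track's predicted box, the first detection would get the lower fused cost because it actually contains the tracked silhouette.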