# BananaTracker - Optimized for Ice Hockey

This notebook demonstrates the optimized multi-object tracking (MOT) pipeline with:
- **Batch SAM2.1 Inference** - Process all boxes in single forward pass
- **Pre-computed Mask Statistics** - Avoid np.unique() inside nested loops
- **Vectorized Kalman Operations** - Batch velocity zeroing for lost tracks
- **Memory Management** - Periodic GPU cache clearing to prevent OOM
- **ECC Motion Compensation** - Better for fast motion and camera jitter
- **Auto FPS Detection** - Buffer scales automatically with video fps
- **Extended Track Buffer** - 1.5 seconds for occlusion handling (memory-balanced)
- **Grace Period for Unconfirmed Tracks** - 4 frames before removal
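
The "pre-computed mask statistics" idea above can be illustrated with a small sketch. It assumes the masks are stored as a single integer label image (0 = background, `k` = track `k`), which is an assumption about the data layout and not something confirmed by BananaTracker's source. The helper names `mask_label_stats` and `box_label_overlap` are hypothetical. Instead of calling `np.unique()` for every track/detection pair, `np.bincount` counts all labels in one pass.

```python
import numpy as np

def mask_label_stats(label_mask: np.ndarray, num_labels: int) -> np.ndarray:
    """Pixel count for every label id in one pass (index 0 is background)."""
    return np.bincount(label_mask.ravel(), minlength=num_labels + 1)

def box_label_overlap(label_mask: np.ndarray, box, num_labels: int) -> np.ndarray:
    """Pixel count of every label inside an (x1, y1, x2, y2) box."""
    x1, y1, x2, y2 = box
    crop = label_mask[y1:y2, x1:x2]
    return np.bincount(crop.ravel(), minlength=num_labels + 1)

# Toy label image with two masks (hypothetical data).
mask = np.zeros((6, 8), dtype=np.int64)
mask[1:4, 1:3] = 1   # track 1 covers 3 x 2 pixels
mask[2:6, 5:8] = 2   # track 2 covers 4 x 3 pixels

# Computed once per frame, outside any track/detection loop.
areas = mask_label_stats(mask, num_labels=2)

# Per box: one bincount on the crop instead of np.unique per pair.
inside = box_label_overlap(mask, (0, 0, 4, 4), num_labels=2)
coverage = inside / np.maximum(areas, 1)  # fraction of each mask inside the box
```

Here `coverage[k]` is the fraction of mask `k` that falls inside the box, which could then be compared against a threshold such as `mask_bbox_overlap_threshold`.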

## Performance Targets
- Throughput: 15-30+ FPS (2-3x improvement over the non-optimized pipeline)
- Better detection (catch more objects)
- Stable tracking (fewer ID switches)
- Track through occlusions (ice hockey player collisions)
- No GPU OOM crashes on long videos
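
As a rough illustration of the periodic GPU cache clearing behind the last target, the sketch below releases cached memory every `interval` frames. The function name `maybe_clear_gpu_cache` and the default interval of 100 frames are assumptions for illustration, not values taken from BananaTracker. The PyTorch import is guarded so the snippet also runs on machines without PyTorch.

```python
import gc

try:
    import torch
except ImportError:  # lets the sketch run without PyTorch installed
    torch = None

def maybe_clear_gpu_cache(frame_id: int, interval: int = 100) -> bool:
    """Free cached GPU memory every `interval` frames.

    Returns True when a clear was due on this frame, False otherwise.
    """
    if interval <= 0 or frame_id % interval != 0:
        return False
    gc.collect()  # drop unreachable Python objects that may still hold tensors
    if torch is not None and torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand unused cached blocks back to the driver
    return True
```

Clearing too often can slow inference because the allocator has to request memory again, so the interval is a trade-off between speed and peak memory.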

## Cell 1: Install Dependencies

In [None]:
# Install required packages
!pip install ultralytics opencv-python-headless tqdm
!pip install lap cython_bbox  # For ByteTrack tracker core

# Install SAM2.1 dependencies (HuggingFace transformers)
# Quote the version spec; unquoted, the shell treats ">" as an output redirect
!pip install "transformers>=4.35.0" huggingface_hub

# Install Cutie dependencies
!pip install omegaconf hydra-core

## Cell 2: Clone Repositories

In [None]:
import os

# Clone BananaTracker repository
if not os.path.exists('bananatracker'):
    !git clone https://github.com/USER/bananatracker.git

# Clone Cutie for temporal mask propagation
if not os.path.exists('Cutie'):
    !git clone https://github.com/hkchengrex/Cutie.git

# Create symlink for Cutie in mask_propagation folder
os.makedirs('bananatracker/bananatracker/mask_propagation', exist_ok=True)
if not os.path.exists('bananatracker/bananatracker/mask_propagation/Cutie'):
    # Use an absolute target so the link also resolves outside Colab's /content
    os.symlink(os.path.abspath('Cutie'), 'bananatracker/bananatracker/mask_propagation/Cutie')

# Download Cutie weights
os.makedirs('Cutie/weights', exist_ok=True)
if not os.path.exists('Cutie/weights/cutie-base-mega.pth'):
    !wget -P Cutie/weights https://github.com/hkchengrex/Cutie/releases/download/v1.0/cutie-base-mega.pth

# Install BananaTracker in development mode
%cd bananatracker
!pip install -e .
%cd ..

## Cell 3: Model Configuration

In [None]:
#@title Model Configuration { display-mode: "form" }

# SAM2.1 Model Settings
SAM2_MODEL_ID = "facebook/sam2.1-hiera-large"  #@param ["facebook/sam2.1-hiera-tiny", "facebook/sam2.1-hiera-small", "facebook/sam2.1-hiera-base-plus", "facebook/sam2.1-hiera-large"]
SAM2_CHECKPOINT = ""  #@param {type:"string"}

# HuggingFace Token (optional)
HF_TOKEN = ""  #@param {type:"string"}

# Cutie Weights Path
CUTIE_WEIGHTS = "/content/Cutie/weights/cutie-base-mega.pth"  #@param {type:"string"}

# Test Video Path
TEST_VIDEO = "/Users/harry/final/sample.mp4"  #@param {type:"string"}

# YOLO Weights Path (use glob pattern to find .pt files)
import glob
yolo_candidates = sorted(glob.glob("/Users/harry/final/*.pt"))  # sorted for a deterministic pick
YOLO_WEIGHTS = yolo_candidates[0] if yolo_candidates else ""

# Test Duration (seconds)
TEST_DURATION_SECONDS = 5  #@param {type:"integer"}

print(f"SAM2.1 Model: {SAM2_MODEL_ID}")
print(f"Cutie Weights: {CUTIE_WEIGHTS}")
print(f"Test Video: {TEST_VIDEO}")
print(f"YOLO Weights: {YOLO_WEIGHTS}")
print(f"Test Duration: {TEST_DURATION_SECONDS} seconds")

## Cell 4: OPTIMIZED Tracker Configuration

Configuration with all performance optimizations for ice hockey:
- **detection_conf_thresh: 0.4** (was 0.5) - Catch more objects
- **track_thresh: 0.5** (was 0.6) - Match more detections in first pass
- **track_buffer: 45** - 1.5 seconds at 30fps for occlusion recovery (memory-balanced)
- **cmc_method: "ecc"** - Better for fast motion and camera jitter (ice hockey)
- **fps: Auto-detected** - Buffer scales automatically with video fps
- **debug_tracking: True** - Log track lifecycle events
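
To make the relationship between `track_buffer` and the auto-detected fps concrete, here is a minimal sketch of the scaling convention used by ByteTrack-style trackers, where the buffer is specified at 30 fps and scaled linearly. Whether BananaTracker applies exactly this formula is an assumption, and `scaled_track_buffer` is a hypothetical helper name.

```python
def scaled_track_buffer(track_buffer: int, fps: float, reference_fps: float = 30.0) -> int:
    """Number of frames a lost track is kept, scaled to the video's frame rate."""
    if fps <= 0:
        # OpenCV reports 0 when the container has no fps metadata; fall back.
        fps = reference_fps
    return max(1, int(fps / reference_fps * track_buffer))

# track_buffer=45 corresponds to 1.5 seconds at any frame rate.
buffer_30fps = scaled_track_buffer(45, 30.0)
buffer_60fps = scaled_track_buffer(45, 60.0)
buffer_unknown = scaled_track_buffer(45, 0.0)
```

With this convention the occlusion window stays at 1.5 seconds whether the video runs at 30 or 60 fps.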

In [None]:
import sys
sys.path.insert(0, '/content/bananatracker')
sys.path.insert(0, '/content/Cutie')

# For local testing
sys.path.insert(0, '/Users/harry/final/bananatracker')

from bananatracker import BananaTrackerConfig

# OPTIMIZED configuration for ice hockey (fast motion, jitter)
config = BananaTrackerConfig(
    # Detection Settings - OPTIMIZED
    yolo_weights=YOLO_WEIGHTS,
    class_names=["Center Ice", "Faceoff", "Goalpost", "Goaltender", "Player", "Puck", "Referee"],
    track_classes=[3, 4, 5, 6],  # Goaltender, Player, Puck, Referee
    special_classes=[5],          # Puck - max-conf only
    detection_conf_thresh=0.4,    # LOWERED from 0.5 - catch more objects
    detection_iou_thresh=0.7,

    # Post-processing
    centroid_dedup_enabled=True,
    centroid_dedup_max_distance=36,

    # Tracker Settings - OPTIMIZED
    # NOTE: fps is auto-detected from video, no need to set it
    track_thresh=0.5,    # LOWERED from 0.6 - match more detections
    track_buffer=45,     # 1.5 seconds at 30fps (balanced: memory vs occlusion handling)
    match_thresh=0.8,
    cmc_method="ecc",    # ECC for fast motion/jitter (ice hockey)

    # Lost track recovery
    lost_track_buffer_scale=0.3,  # Expand bbox by 30% for lost track matching

    # Mask Module Settings - ENABLED with batch processing
    enable_masks=True,
    sam2_model_id=SAM2_MODEL_ID,
    sam2_checkpoint=SAM2_CHECKPOINT if SAM2_CHECKPOINT else None,
    cutie_weights_path=CUTIE_WEIGHTS,
    hf_token=HF_TOKEN if HF_TOKEN else None,
    mask_start_frame=1,
    mask_bbox_overlap_threshold=0.6,

    # Visualization Settings
    class_colors={
        "Goaltender": (255, 165, 0),
        "Player": (255, 0, 0),
        "Puck": (0, 255, 0),
        "Referee": (0, 0, 255),
    },
    show_track_id=True,
    show_masks=True,
    mask_alpha=0.5,
    line_thickness=2,

    # Output Settings
    output_video_path="/content/output_optimized.mp4",
    output_txt_path="/content/results_optimized.txt",
    device="cuda:0",

    # Debug - ENABLED for testing
    debug_tracking=True,
)

print("=" * 60)
print("OPTIMIZED Configuration for Ice Hockey!")
print("=" * 60)
print(f"\nDetection:")
print(f"  - Confidence threshold: {config.detection_conf_thresh} (was 0.5)")
print(f"  - Tracking classes: {config.track_classes}")
print(f"\nTracker:")
print(f"  - track_thresh: {config.track_thresh} (was 0.6)")
print(f"  - track_buffer: {config.track_buffer} frames")
print(f"  - cmc_method: {config.cmc_method} (best for fast motion)")
print(f"  - fps: Auto-detected from video")
print(f"  - lost_track_buffer_scale: {config.lost_track_buffer_scale}")
print(f"\nMask Module:")
print(f"  - Enabled: {config.enable_masks}")
print(f"  - SAM2.1 Model: {config.sam2_model_id}")
print(f"  - Batch inference: ENABLED")
print(f"\nDebug:")
print(f"  - debug_tracking: {config.debug_tracking}")

## Cell 5: Initialize Optimized Pipeline

In [None]:
import logging

# Enable debug logging for BananaTracker
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("BananaTracker")
logger.setLevel(logging.DEBUG)

from bananatracker import BananaTrackerPipeline

print("Initializing OPTIMIZED pipeline...")
print("This may take a moment to download SAM2.1 model from HuggingFace...")

pipeline = BananaTrackerPipeline(config)

# Enable debug tracking on tracker
pipeline.tracker.debug_tracking = config.debug_tracking

print("\nPipeline initialized!")
print(f"  - Detector: YOLOv8")
print(f"  - Tracker: BananaTracker (OPTIMIZED)")
print(f"  - Mask Module: {'SAM2.1 + Cutie (BATCH)' if pipeline.mask_manager else 'Disabled'}")
print(f"  - CMC Method: {config.cmc_method}")

## Cell 6: Performance Benchmark - First N Seconds

Run tracking on the first N seconds to measure performance metrics:
- FPS (target: 15-30+)
- Unique track IDs (fewer = better stability)
- Track continuity
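
The stability metrics listed above can be computed from nothing more than the list of track IDs seen in each frame. The sketch below is one possible formulation: `track_stability` is a hypothetical helper, and counting a "gap" whenever an ID disappears and later reappears is an assumed definition of continuity, not one taken from the tracker.

```python
from collections import defaultdict

def track_stability(ids_per_frame, fps: float = 30.0) -> dict:
    """Summarize ID stability from per-frame lists of track IDs."""
    lifetimes = defaultdict(int)  # track_id -> number of frames observed
    last_seen = {}                # track_id -> last frame index observed
    gaps = 0                      # times an ID vanished and came back
    for frame_idx, ids in enumerate(ids_per_frame):
        for tid in set(ids):
            lifetimes[tid] += 1
            if tid in last_seen and frame_idx - last_seen[tid] > 1:
                gaps += 1
            last_seen[tid] = frame_idx
    min_long = max(1, int(round(fps)))  # "long-lived" means at least one second
    values = list(lifetimes.values())
    return {
        "unique_ids": len(values),
        "avg_lifetime": sum(values) / len(values) if values else 0.0,
        "long_tracks": sum(1 for n in values if n >= min_long),
        "gaps": gaps,
    }

# Toy example: 4 frames at 2 fps; track 2 disappears for one frame.
stats = track_stability([[1, 2], [1], [1, 2], [3]], fps=2.0)
```

Fewer unique IDs and fewer gaps for the same footage suggest more stable tracking.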

In [None]:
import cv2
import time
import numpy as np
from collections import defaultdict

# Open video
cap = cv2.VideoCapture(TEST_VIDEO)
if not cap.isOpened():
    print(f"Error: Could not open video {TEST_VIDEO}")
else:
    fps_video = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    # Calculate frames to process
    frames_to_process = min(int(fps_video * TEST_DURATION_SECONDS), total_frames)
    
    print(f"Video: {TEST_VIDEO}")
    print(f"  - Resolution: {width}x{height}")
    print(f"  - FPS: {fps_video}")
    print(f"  - Total frames: {total_frames}")
    print(f"  - Testing first {frames_to_process} frames ({TEST_DURATION_SECONDS} seconds)")
    print()
    
    # Reset tracker
    pipeline.tracker.reset()
    pipeline._reset_mask_state()
    
    # Metrics tracking
    all_track_ids = set()
    track_lifetimes = defaultdict(int)
    frame_times = []
    detection_counts = []
    track_counts = []
    
    frame_id = 0
    
    print("Processing...")
    start_time = time.time()
    
    while frame_id < frames_to_process:
        ret, frame = cap.read()
        if not ret:
            break
        
        frame_id += 1
        frame_start = time.time()
        
        # Detect
        detections = pipeline.detector.detect(frame)
        detection_counts.append(len(detections))
        
        # Track with mask parameters
        img_info = (height, width)
        tracks, removed_ids, new_tracks = pipeline.tracker.update(
            detections_array=detections,
            img_info=img_info,
            prediction_mask=pipeline.prediction_mask,
            tracklet_mask_dict=pipeline.tracklet_mask_dict,
            mask_avg_prob_dict=pipeline.mask_avg_prob_dict,
            frame_img=frame
        )
        
        # Update masks
        if pipeline.mask_manager is not None:
            pipeline._update_masks(frame, frame_id, tracks, new_tracks, removed_ids)
        
        frame_time = time.time() - frame_start
        frame_times.append(frame_time)
        
        # Track metrics
        track_counts.append(len(tracks))
        for track in tracks:
            all_track_ids.add(track.track_id)
            track_lifetimes[track.track_id] += 1
        
        # Progress
        if frame_id % 30 == 0:
            avg_fps = frame_id / (time.time() - start_time)
            print(f"  Frame {frame_id}/{frames_to_process} - FPS: {avg_fps:.1f} - Tracks: {len(tracks)}")
    
    total_time = time.time() - start_time
    cap.release()
    
    # Calculate metrics
    avg_fps = frame_id / total_time
    avg_frame_time = np.mean(frame_times) * 1000
    avg_detections = np.mean(detection_counts)
    avg_tracks = np.mean(track_counts)
    
    # Track stability metrics
    lifetimes = list(track_lifetimes.values())
    avg_lifetime = np.mean(lifetimes) if lifetimes else 0
    long_tracks = sum(1 for n in lifetimes if n >= 30)  # Tracks lasting >= 30 frames (1 second at 30 fps)
    
    print()
    print("=" * 60)
    print("PERFORMANCE RESULTS")
    print("=" * 60)
    print(f"\nSpeed:")
    print(f"  - Total time: {total_time:.2f}s for {frame_id} frames")
    print(f"  - Average FPS: {avg_fps:.1f}")
    print(f"  - Average frame time: {avg_frame_time:.1f}ms")
    print(f"\nDetection:")
    print(f"  - Average detections/frame: {avg_detections:.1f}")
    print(f"  - Max detections in frame: {max(detection_counts, default=0)}")  # default avoids ValueError when no frame was read
    print(f"\nTracking Stability:")
    print(f"  - Total unique track IDs: {len(all_track_ids)}")
    print(f"  - Average tracks/frame: {avg_tracks:.1f}")
    print(f"  - Average track lifetime: {avg_lifetime:.1f} frames")
    print(f"  - Long-lived tracks (>=30 frames): {long_tracks}")
    print()
    print("=" * 60)

## Cell 7: Compare with Non-Optimized Config (Optional)

Run the same test with original thresholds to compare performance.
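
One way to reuse the benchmark for both configurations is to move the timing loop into a helper that takes any iterable of frames and a per-frame step function. This is only a sketch: `run_benchmark` and `step_fn` are hypothetical names, and the step function is assumed to return the list of active tracks for the frame.

```python
import itertools
import time

def run_benchmark(frames, step_fn, max_frames=None) -> dict:
    """Time step_fn on each frame and return summary metrics."""
    frame_times = []
    track_counts = []
    start = time.perf_counter()
    for frame in itertools.islice(frames, max_frames):
        t0 = time.perf_counter()
        tracks = step_fn(frame)  # detect + track (+ masks) for one frame
        frame_times.append(time.perf_counter() - t0)
        track_counts.append(len(tracks))
    total = time.perf_counter() - start
    n = len(frame_times)
    return {
        "frames": n,
        "total_time_s": total,
        "avg_fps": n / total if total > 0 else 0.0,
        "avg_frame_ms": 1000.0 * sum(frame_times) / n if n else 0.0,
        "avg_tracks": sum(track_counts) / n if n else 0.0,
    }

# Stand-in frames and a fake step function, just to show the call shape.
result = run_benchmark(range(10), lambda frame: ["a", "b"], max_frames=5)
```

In this notebook, `step_fn` would wrap the `pipeline.detector.detect(...)` and `pipeline.tracker.update(...)` calls from Cell 6, once for `pipeline` and once for the baseline pipeline, so both runs are measured the same way.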

In [None]:
# Create BASELINE config (original non-optimized settings)
config_baseline = BananaTrackerConfig(
    yolo_weights=YOLO_WEIGHTS,
    class_names=["Center Ice", "Faceoff", "Goalpost", "Goaltender", "Player", "Puck", "Referee"],
    track_classes=[3, 4, 5, 6],
    special_classes=[5],
    
    # ORIGINAL VALUES (non-optimized)
    detection_conf_thresh=0.5,  # Original
    track_thresh=0.6,           # Original
    track_buffer=30,            # Original
    cmc_method="orb",           # Original (ECC is better for ice hockey)
    
    # Masks DISABLED for speed comparison
    enable_masks=False,
    
    device="cuda:0",
)

print("Baseline config (non-optimized) created.")
print("Uncomment the code below to run comparison.")

# Uncomment to run baseline comparison:
# pipeline_baseline = BananaTrackerPipeline(config_baseline)
# # ... run same benchmark code ...

## Cell 8: Full Video Processing

In [None]:
# Run full video tracking with optimized settings
# Uncomment to process full video:

# print(f"Processing full video: {TEST_VIDEO}")
# print(f"Mask module: {'Enabled' if config.enable_masks else 'Disabled'}")
# 
# all_tracks = pipeline.process_video(TEST_VIDEO)
# 
# print(f"\nProcessed {len(all_tracks)} frames")
# print(f"Output video: {config.output_video_path}")
# print(f"MOT results: {config.output_txt_path}")

## Optimization Summary

| Category | Change | Expected Impact |
|----------|--------|----------------|
| **Speed** | Batch SAM2.1 | 2-4x faster mask generation |
| **Speed** | Pre-compute mask stats | 5-10x faster conditioned_assignment |
| **Speed** | Vectorize Kalman ops | Reduced loop overhead |
| **Speed** | Eliminate frame copies | 18-27MB/frame memory saved |
| **Memory** | Periodic GPU cache clear | Prevents OOM on long videos |
| **Memory** | track_buffer 90->45 | Reduced memory, still handles occlusion |
| **Reliability** | det_thresh 0.7->0.5 | More tracks created |
| **Reliability** | Grace period 1->4 frames | Stabilize new tracks |
| **Reliability** | Relaxed cost thresholds | Better matching |
| **Motion** | cmc_method: ECC | Better for fast motion/jitter (ice hockey) |
| **Flexibility** | fps auto-detected | Works with any video fps |
| **Detection** | conf_thresh 0.5->0.4 | Catch more objects |
| **Detection** | track_thresh 0.6->0.5 | Match more in first pass |
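
The "Vectorize Kalman ops" row can be sketched as follows. The example assumes an 8-dimensional state `[x, y, w, h, vx, vy, vw, vh]` per track and assumes that the size velocities (the last two components) are zeroed for lost tracks, as in BoT-SORT-style trackers. Both the layout and the choice of components are assumptions about BananaTracker, and `zero_lost_velocities` is a hypothetical name.

```python
import numpy as np

def zero_lost_velocities(means, is_tracked, velocity_slice=slice(6, 8)) -> np.ndarray:
    """Zero selected velocity components for all lost tracks in one operation.

    means:      (N, 8) array of stacked Kalman state means.
    is_tracked: length-N booleans, False for lost tracks.
    """
    out = np.array(means, dtype=float, copy=True)
    lost = ~np.asarray(is_tracked, dtype=bool)
    out[lost, velocity_slice] = 0.0  # single masked assignment, no Python loop
    return out

# Two tracks: the first is tracked, the second is lost.
means = np.arange(16, dtype=float).reshape(2, 8)
updated = zero_lost_velocities(means, [True, False])
```

Zeroing the size velocities keeps a lost track's predicted box from growing or shrinking while it is unobserved; the stacked array form replaces a per-track `if` inside a loop.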