# BananaTracker - Multi-Object Tracking Demo

This notebook demonstrates how to use BananaTracker for multi-object tracking with YOLOv8 detection and ByteTrack-based tracking.

## Cell 1: Install Dependencies

In [None]:
# Install required packages
!pip install ultralytics opencv-python-headless tqdm
!pip install lap cython_bbox  # For ByteTrack tracker core

## Cell 2: Clone Repository and Install

In [None]:
# Clone the repository
!git clone https://github.com/USER/bananatracker.git
%cd bananatracker

# Install in development mode
!pip install -e .

## Cell 3: Configuration

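Configure the tracker with your model weights and settings. The order of `class_names` must match the class indices your YOLO model was trained with, because `track_classes` and `special_classes` refer to positions in that list (e.g. `3` → `"Goaltender"`).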

In [None]:
import torch

from bananatracker import BananaTrackerConfig

# Prefer the first GPU; fall back to CPU so the demo still runs on CPU-only runtimes.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Example configuration for hockey tracking
config = BananaTrackerConfig(
    # Detection
    yolo_weights="/content/HockeyAI_model_weight.pt",  # Update with your model path
    # The order of class_names defines the class indices; track_classes and
    # special_classes below are indices into this list.
    class_names=["Center Ice", "Faceoff", "Goalpost", "Goaltender", "Player", "Puck", "Referee"],
    track_classes=[3, 4, 5, 6],  # Goaltender, Player, Puck, Referee
    special_classes=[5],          # Puck - max-conf only
    detection_conf_thresh=0.5,    # General confidence threshold
    detection_iou_thresh=0.7,     # IoU threshold for YOLO NMS

    # Post-processing: Centroid-based deduplication
    centroid_dedup_enabled=True,     # Remove duplicate boxes for same player
    centroid_dedup_max_distance=36,  # Max pixel distance to consider duplicates

    # Tracker
    track_thresh=0.6,
    track_buffer=30,
    cmc_method="orb",  # Options: "orb", "ecc", "sift", "sparseOptFlow", "none"

    # Visualization -- colors are BGR tuples (OpenCV channel order), not RGB.
    # Orange in BGR is (0, 165, 255); (255, 165, 0) would render as light blue,
    # nearly indistinguishable from the Player color.
    class_colors={
        "Goaltender": (0, 165, 255),   # Orange
        "Player": (255, 0, 0),          # Blue
        "Puck": (0, 255, 0),            # Green
        "Referee": (0, 0, 255),         # Red
    },
    show_track_id=True,
    line_thickness=2,

    # Output
    output_video_path="/content/output_tracked.mp4",
    output_txt_path="/content/results.txt",
    device=DEVICE,
)

print("Configuration created!")
print(f"Device: {DEVICE}")
print(f"Detection thresholds: conf={config.detection_conf_thresh}, iou={config.detection_iou_thresh}")
print(f"Centroid dedup: enabled={config.centroid_dedup_enabled}, max_distance={config.centroid_dedup_max_distance}px")
print(f"Tracking classes: {config.track_classes}")
print(f"Special classes (max-conf only): {config.special_classes}")

## Cell 4: Load Model and Setup Tracker

In [None]:
from bananatracker import BananaTrackerPipeline

# Build the pipeline from the configuration defined above.
pipeline = BananaTrackerPipeline(config)

print("Pipeline initialized!")
# Plain strings: these two lines interpolate nothing, so no f-prefix is needed.
print("Detector: YOLOv8")
print("Tracker: BananaTracker (ByteTrack-based)")
print(f"CMC Method: {config.cmc_method}")

## Cell 5: Run Tracking

In [None]:
from pathlib import Path

# Path to input video
INPUT_VIDEO = "/content/sample_video.mp4"  # Update with your video path

# Fail fast with a clear message instead of an opaque error from inside the pipeline.
if not Path(INPUT_VIDEO).is_file():
    raise FileNotFoundError(f"Input video not found: {INPUT_VIDEO} (update INPUT_VIDEO)")

# Run tracking
print(f"Processing video: {INPUT_VIDEO}")
all_tracks = pipeline.process_video(INPUT_VIDEO)

# NOTE(review): assumes process_video returns one entry per frame -- confirm.
print(f"\nProcessed {len(all_tracks)} frames")
print(f"Output video: {config.output_video_path}")
print(f"MOT results: {config.output_txt_path}")

## Cell 6: Compress Output Video

In [None]:
%%capture
# Compress video for notebook display
OUTPUT_COMPRESSED = "/content/output_compressed.mp4"
!ffmpeg -y -i {config.output_video_path} -vcodec libx264 -crf 28 {OUTPUT_COMPRESSED}

print(f"Compressed video saved to: {OUTPUT_COMPRESSED}")

## Cell 7: Display Video in Notebook

In [None]:
from base64 import b64encode
from pathlib import Path

from IPython.display import HTML

OUTPUT_COMPRESSED = "/content/output_compressed.mp4"

# Embed the video as a base64 data URL so it plays inline in the notebook.
# Path.read_bytes() opens and closes the file; a bare open(...).read() leaks the handle.
video_bytes = Path(OUTPUT_COMPRESSED).read_bytes()
data_url = f"data:video/mp4;base64,{b64encode(video_bytes).decode()}"

# Display video (last expression -> rendered as HTML)
HTML(f'''
<video width="800" controls>
  <source src="{data_url}" type="video/mp4">
</video>
''')

## Optional: Frame-by-Frame Processing

For more control, you can process frames individually using the generator API.

In [None]:
# Example: process frame-by-frame with the generator API.
# Disabled by default so "Run All" does not process the video a second time;
# set RUN_FRAME_DEMO = True to run it.
from itertools import islice

from bananatracker import BananaTrackerPipeline

RUN_FRAME_DEMO = False
DEMO_MAX_FRAMES = 10  # number of frames to inspect before stopping

if RUN_FRAME_DEMO:
    # A separate instance (instead of reassigning `pipeline`) keeps the
    # pipeline from the full run untouched; presumably it also avoids reusing
    # tracker state from that run -- verify.
    demo_pipeline = BananaTrackerPipeline(config)
    frames = demo_pipeline.process_video_generator(INPUT_VIDEO)

    # islice yields exactly DEMO_MAX_FRAMES items, whether frame_id starts at 0 or 1.
    for frame_id, frame, tracks, vis_frame in islice(frames, DEMO_MAX_FRAMES):
        # Get track info as dictionaries and report each track
        for info in demo_pipeline.get_track_info(tracks):
            print(f"Frame {frame_id}: Track {info['track_id']} - {info['class_name']} at {info['bbox']}")