# BananaTracker - Multi-Object Tracking with SAM2.1 + Cutie

This notebook demonstrates a complete multi-object tracking (MOT) pipeline built from:
- **YOLOv8** for object detection
- **ByteTrack-based tracker** for multi-object tracking
- **SAM2.1** for high-quality mask generation from bounding boxes
- **Cutie** for temporal mask propagation

The mask module enhances tracking by providing pixel-level precision that improves association when objects are close together.
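
To make the claim above concrete, here is a minimal sketch (toy data, not BananaTracker's actual association code) comparing box IoU with mask IoU for two objects standing close together. Masks are represented as sets of pixel coordinates purely to keep the example dependency-free; the helper names `box_iou` and `mask_iou` are illustrative assumptions.

```python
def box_iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def mask_iou(pixels_a, pixels_b):
    """IoU of two masks given as sets of (x, y) pixel coordinates."""
    union = len(pixels_a | pixels_b)
    return len(pixels_a & pixels_b) / union if union else 0.0


# Two boxes that overlap heavily (toy values)...
box_a, box_b = (0, 0, 10, 10), (4, 0, 14, 10)
# ...while the object pixels inside them never touch.
mask_a = {(x, y) for x in range(0, 4) for y in range(10)}
mask_b = {(x, y) for x in range(10, 14) for y in range(10)}

print(f"box IoU  = {box_iou(box_a, box_b):.2f}")
print(f"mask IoU = {mask_iou(mask_a, mask_b):.2f}")
```

In this toy case the boxes overlap substantially while the masks do not overlap at all, which is the kind of ambiguity a pixel-level cue can resolve.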

## Cell 1: Install Dependencies

In [None]:
# Install required packages
!pip install ultralytics opencv-python-headless tqdm
!pip install lap cython_bbox  # For ByteTrack tracker core

# Install SAM2.1 dependencies (HuggingFace transformers)
# Quote the version specifier; unquoted, the shell treats ">" as an output redirect
!pip install "transformers>=4.35.0" huggingface_hub

# Install Cutie dependencies
!pip install omegaconf hydra-core

# GTA-Link post-processing dependencies
!pip install scikit-learn loguru seaborn

## Cell 2: Clone Repositories

In [None]:
import os

# Clone BananaTracker repository
if not os.path.exists('bananatracker'):
    !git clone https://github.com/USER/bananatracker.git

# Clone Cutie for temporal mask propagation
if not os.path.exists('Cutie'):
    !git clone https://github.com/hkchengrex/Cutie.git

# Create symlink for Cutie in mask_propagation folder
os.makedirs('bananatracker/bananatracker/mask_propagation', exist_ok=True)
if not os.path.exists('bananatracker/bananatracker/mask_propagation/Cutie'):
    # Resolve the clone location instead of assuming the working directory is /content
    os.symlink(os.path.abspath('Cutie'), 'bananatracker/bananatracker/mask_propagation/Cutie')

# Download Cutie weights
os.makedirs('Cutie/weights', exist_ok=True)
if not os.path.exists('Cutie/weights/cutie-base-mega.pth'):
    !wget -P Cutie/weights https://github.com/hkchengrex/Cutie/releases/download/v1.0/cutie-base-mega.pth

# Install BananaTracker in development mode
%cd bananatracker
!pip install -e .
%cd ..

# Clone GTA-Link for post-processing track refinement
if not os.path.exists('/content/gta-link'):
    !git clone https://github.com/sjc042/gta-link.git /content/gta-link

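# Expose torchreid to child processes via PYTHONPATH. This does not change
# sys.path in the running kernel; GTA-Link is launched later through subprocess.
# No setup.py install is needed: the rank_cy Cython extension is eval-only, and
# the FeatureExtractor used by GTA-Link is pure Python.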
os.environ['PYTHONPATH'] = '/content/gta-link/reid:' + os.environ.get('PYTHONPATH', '')

## Cell 3: SAM2.1 Model Configuration

Configure SAM2.1 model settings. You can choose different model sizes:
- `facebook/sam2.1-hiera-tiny` - Fastest, lower quality
- `facebook/sam2.1-hiera-small` - Good balance
- `facebook/sam2.1-hiera-base-plus` - Better quality
- `facebook/sam2.1-hiera-large` - Best quality (recommended)

In [None]:
#@title SAM2.1 Configuration { display-mode: "form" }

# SAM2.1 Model Settings
SAM2_MODEL_ID = "facebook/sam2.1-hiera-large"  #@param ["facebook/sam2.1-hiera-tiny", "facebook/sam2.1-hiera-small", "facebook/sam2.1-hiera-base-plus", "facebook/sam2.1-hiera-large"]
SAM2_CHECKPOINT = ""  #@param {type:"string"} # Optional: local checkpoint path (leave empty to use HuggingFace)

# HuggingFace Token (required for gated models, optional for SAM2)
HF_TOKEN = ""  #@param {type:"string"}

# Cutie Weights Path
CUTIE_WEIGHTS = "/content/Cutie/weights/cutie-base-mega.pth"  #@param {type:"string"}

print(f"SAM2.1 Model: {SAM2_MODEL_ID}")
print(f"SAM2.1 Checkpoint: {SAM2_CHECKPOINT if SAM2_CHECKPOINT else 'Using HuggingFace download'}")
print(f"HF Token: {'Set' if HF_TOKEN else 'Not set (optional)'}")
print(f"Cutie Weights: {CUTIE_WEIGHTS}")

## Cell 4: Tracker Configuration

Configure the full tracking pipeline with detection, tracking, and mask settings.
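
As background for the tracker settings, the sketch below shows how a ByteTrack-style tracker typically uses `track_thresh` to split detections into a high-confidence pool (matched first) and a low-confidence pool (matched second), and how `track_buffer` translates into seconds. The `low_thresh=0.1` floor is the value used in the original ByteTrack code; whether BananaTracker uses the same floor is an assumption, and both function names are hypothetical.

```python
def split_by_confidence(scores, track_thresh=0.5, low_thresh=0.1):
    """Return (high, low) lists of detection indices, ByteTrack-style.

    high: candidates for the first association pass.
    low:  candidates for the second pass (recovering occluded objects).
    Detections at or below low_thresh are dropped.
    """
    high = [i for i, s in enumerate(scores) if s >= track_thresh]
    low = [i for i, s in enumerate(scores) if low_thresh < s < track_thresh]
    return high, low


def buffer_seconds(track_buffer, fps):
    """How long a lost track is kept alive, in seconds."""
    return track_buffer / fps


scores = [0.92, 0.45, 0.08, 0.61, 0.30]  # made-up detector confidences
high, low = split_by_confidence(scores)
print("first pass :", high)
print("second pass:", low)
print("buffer     :", buffer_seconds(90, 30), "s")
```

If the detector already discards everything below `detection_conf_thresh`, the second pool would only contain scores between that value and `track_thresh`, so the two thresholds are worth tuning together.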

In [None]:
import sys
sys.path.insert(0, '/content/bananatracker')
sys.path.insert(0, '/content/Cutie')

from bananatracker import BananaTrackerConfig

# Full configuration for hockey/sports tracking with mask enhancement
# NOTE: These values are optimized for better tracking performance:
#   - Lower detection thresholds to catch more objects
#   - Larger track_buffer for longer occlusion handling (3 seconds)
#   - Lower track_thresh for more lenient first-pass matching
config = BananaTrackerConfig(
    # Detection Settings
    yolo_weights="/content/HockeyAI_model_weight.pt",  # Update with your model path
    class_names=["Center Ice", "Faceoff", "Goalpost", "Goaltender", "Player", "Puck", "Referee"],
    track_classes=[3, 4, 5, 6],  # Goaltender, Player, Puck, Referee
    special_classes=[5],          # Puck - max-conf only
    detection_conf_thresh=0.4,    # Lowered to catch more objects
    detection_iou_thresh=0.7,     # IoU threshold for YOLO NMS

    # Post-processing: Centroid-based deduplication
    centroid_dedup_enabled=True,
    centroid_dedup_max_distance=36,

    # Tracker Settings (ByteTrack) - Optimized for stability
    track_thresh=0.5,    # Lowered for more first-pass matches
    track_buffer=90,     # 3 seconds at 30fps for better occlusion handling
    match_thresh=0.8,
    fps=30,
    cmc_method="orb",  # Camera motion compensation

    # Mask Module Settings (SAM2.1 + Cutie)
    enable_masks=True,
    sam2_model_id=SAM2_MODEL_ID,
    sam2_checkpoint=SAM2_CHECKPOINT if SAM2_CHECKPOINT else None,
    cutie_weights_path=CUTIE_WEIGHTS,
    hf_token=HF_TOKEN if HF_TOKEN else None,
    mask_start_frame=1,
    mask_bbox_overlap_threshold=0.6,

    # Visualization Settings
    class_colors={
        "Goaltender": (0, 165, 255),   # Orange (BGR)
        "Player": (255, 0, 0),          # Blue (BGR)
        "Puck": (0, 255, 0),            # Green
        "Referee": (0, 0, 255),         # Red
    },
    show_track_id=True,
    show_masks=True,     # Enable mask overlay in visualization
    mask_alpha=0.5,      # Mask transparency
    line_thickness=2,

    # Output Settings
    output_video_path="/content/output_tracked_with_masks.mp4",
    output_txt_path="/content/results.txt",
    device="cuda:0",
)

print("Configuration created!")
print(f"\nDetection:")
print(f"  - Confidence threshold: {config.detection_conf_thresh}")
print(f"  - Tracking classes: {config.track_classes}")
print(f"\nTracker:")
print(f"  - Track threshold: {config.track_thresh}")
print(f"  - Track buffer: {config.track_buffer} frames ({config.track_buffer/config.fps:.1f}s)")
print(f"\nMask Module:")
print(f"  - Enabled: {config.enable_masks}")
print(f"  - SAM2.1 Model: {config.sam2_model_id}")
print(f"  - Cutie Weights: {config.cutie_weights_path}")
print(f"\nVisualization:")
print(f"  - Show masks: {config.show_masks}")
print(f"  - Mask alpha: {config.mask_alpha}")

## Cell 4b: Enable Debug Logging (Optional)

Enable verbose debug logging to track LOST, REMOVED, CREATED, and REJECTED events during tracking. Useful for diagnosing ID switches and track loss issues.
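
When the debug stream gets long, a tally per event type can be easier to read than the raw log. The following is a minimal sketch of a counting handler built on the standard `logging` module. It assumes the event keywords (LOST, REMOVED, CREATED, REJECTED) appear verbatim in the log messages, which is not confirmed here; the `EventCounter` class, the demo logger name, and the sample messages are all hypothetical.

```python
import logging
from collections import Counter

# Assumed to appear verbatim in the tracker's log messages.
EVENT_KEYWORDS = ("LOST", "REMOVED", "CREATED", "REJECTED")


class EventCounter(logging.Handler):
    """Logging handler that tallies records by the event keyword they contain."""

    def __init__(self):
        super().__init__(level=logging.DEBUG)
        self.counts = Counter()

    def emit(self, record):
        message = record.getMessage()
        for keyword in EVENT_KEYWORDS:
            if keyword in message:
                self.counts[keyword] += 1
                break


counter = EventCounter()
demo_logger = logging.getLogger("BananaTracker.demo")  # hypothetical logger used only for this demo
demo_logger.setLevel(logging.DEBUG)
demo_logger.propagate = False  # keep demo messages out of the real tracker log
demo_logger.addHandler(counter)

# Made-up messages; the real tracker's wording may differ.
demo_logger.debug("Track LOST: id=7 frame=120")
demo_logger.debug("Track CREATED: id=12 frame=121")
demo_logger.debug("Track LOST: id=3 frame=140")

print(dict(counter.counts))
```

To use it on a real run, the handler would be attached to `logging.getLogger("BananaTracker")` after the configuration cell below has executed, since that cell clears existing handlers.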

In [None]:
#@title Debug Logging Configuration { display-mode: "form" }
import logging

# Enable/disable debug logging
ENABLE_DEBUG_LOGGING = True  #@param {type:"boolean"}

# Configure the BananaTracker logger
logger = logging.getLogger("BananaTracker")

if ENABLE_DEBUG_LOGGING:
    # Set up console handler with detailed formatting
    handler = logging.StreamHandler()
    handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%H:%M:%S'
    )
    handler.setFormatter(formatter)
    
    # Clear existing handlers and add new one
    logger.handlers.clear()
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    
    print("Debug logging ENABLED")
    print("  - Track LOST: When a track loses its detection")
    print("  - Track REMOVED: When an unconfirmed track exceeds grace period")
    print("  - Track CREATED: When a new track is initialized")
    print("  - Detection REJECTED: When detection confidence is below threshold")
    print("\nLogs will appear during pipeline.process_video() execution.")
else:
    logger.setLevel(logging.WARNING)
    print("Debug logging DISABLED (only warnings and errors will be shown)")

## Cell 5: Initialize Pipeline with SAM2.1 + Cutie

In [None]:
from bananatracker import BananaTrackerPipeline

# Initialize the pipeline (this will load SAM2.1 and Cutie models)
print("Initializing pipeline...")
print("This may take a moment to download SAM2.1 model from HuggingFace...")

pipeline = BananaTrackerPipeline(config)

print("\nPipeline initialized!")
print(f"  - Detector: YOLOv8")
print(f"  - Tracker: BananaTracker (ByteTrack-based with mask enhancement)")
print(f"  - Mask Module: {'SAM2.1 + Cutie' if pipeline.mask_manager else 'Disabled'}")
print(f"  - CMC Method: {config.cmc_method}")

## Cell 6: Validate SAM2.1 Mask Generation (Test Cell)

This cell validates that SAM2.1 is working correctly by generating masks on a random frame from the video.

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import random

# Input video for testing
TEST_VIDEO = "/content/sample_video.mp4"  # Update with your video path

# Open video and get a random frame
cap = cv2.VideoCapture(TEST_VIDEO)
if not cap.isOpened():
    print(f"Error: Could not open video {TEST_VIDEO}")
else:
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    random_frame_idx = random.randint(0, total_frames - 1)
    
    cap.set(cv2.CAP_PROP_POS_FRAMES, random_frame_idx)
    ret, frame = cap.read()
    cap.release()
    
    if ret:
        print(f"Testing SAM2.1 on frame {random_frame_idx} of {total_frames}")
        
        # Convert to RGB for SAM2.1
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        
        # Run detection to get bounding boxes
        detections = pipeline.detector.detect(frame)
        
        if len(detections) > 0:
            # Get bounding boxes (first 4 columns: x1, y1, x2, y2)
            boxes_xyxy = detections[:, :4].tolist()
            
            print(f"Detected {len(boxes_xyxy)} objects")
            
            # Generate masks using SAM2.1
            if pipeline.mask_manager is not None:
                print("Generating masks with SAM2.1...")
                try:
                    masks = pipeline.mask_manager._sam2_predict_boxes(frame_rgb, boxes_xyxy)
                    
                    print(f"Generated {len(masks)} masks")
                    print(f"Mask shape: {masks.shape}")
                    
                    # Visualize results
                    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
                    
                    # Original frame
                    axes[0].imshow(frame_rgb)
                    axes[0].set_title(f'Original Frame #{random_frame_idx}')
                    axes[0].axis('off')
                    
                    # Frame with bounding boxes
                    frame_with_boxes = frame_rgb.copy()
                    for i, box in enumerate(boxes_xyxy):
                        x1, y1, x2, y2 = map(int, box)
                        cv2.rectangle(frame_with_boxes, (x1, y1), (x2, y2), (255, 0, 0), 2)
                        cv2.putText(frame_with_boxes, f'{i}', (x1, y1-10), 
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
                    axes[1].imshow(frame_with_boxes)
                    axes[1].set_title(f'Detected Objects ({len(boxes_xyxy)})')
                    axes[1].axis('off')
                    
                    # Combined masks overlay
                    colors = plt.cm.tab10(np.linspace(0, 1, 10))[:, :3] * 255
                    combined_mask = np.zeros_like(frame_rgb, dtype=np.float32)
                    for i, mask in enumerate(masks):
                        color = colors[i % len(colors)]
                        mask_3d = np.stack([mask, mask, mask], axis=-1).astype(np.float32)
                        # color is already in the 0-255 range, matching frame_rgb
                        combined_mask += mask_3d * color
                    
                    alpha = 0.5
                    frame_with_masks = frame_rgb.astype(np.float32)
                    mask_overlay = np.clip(combined_mask, 0, 255)
                    binary_mask = np.any(combined_mask > 0, axis=-1, keepdims=True)
                    frame_with_masks = np.where(
                        binary_mask,
                        frame_with_masks * alpha + mask_overlay * (1 - alpha),
                        frame_with_masks
                    )
                    axes[2].imshow(frame_with_masks.astype(np.uint8))
                    axes[2].set_title('SAM2.1 Mask Overlay')
                    axes[2].axis('off')
                    
                    plt.tight_layout()
                    plt.savefig('/content/sam2_mask_validation.png', dpi=150, bbox_inches='tight')
                    plt.show()
                    
                    print("\n✓ SAM2.1 mask generation validated successfully!")
                    print(f"  Validation image saved to: /content/sam2_mask_validation.png")
                    
                except Exception as e:
                    print(f"Error during mask generation: {e}")
                    import traceback
                    traceback.print_exc()
            else:
                print("Warning: Mask module not initialized")
        else:
            print("No objects detected in this frame. Try a different frame.")
    else:
        print(f"Error: Could not read frame {random_frame_idx}")

## Cell 7: Run Full Tracking Pipeline

In [None]:
# Path to input video
INPUT_VIDEO = "/content/sample_video.mp4"  # Update with your video path

# Run tracking with mask enhancement
print(f"Processing video: {INPUT_VIDEO}")
print(f"Mask module: {'Enabled' if config.enable_masks else 'Disabled'}")
print("")

all_tracks = pipeline.process_video(INPUT_VIDEO)

print(f"\nProcessed {len(all_tracks)} frames")
print(f"Output video: {config.output_video_path}")
print(f"MOT results: {config.output_txt_path}")

---
## GTA-Link Post-Processing

Refine MOT results using GTA-Link's two-phase pipeline:
1. **Phase 1 — Generate Tracklets**: Extract per-track appearance features using OSNet re-ID model
2. **Phase 2 — Refine Tracklets**: Split identity switches (DBSCAN clustering) then merge fragmented tracks (hierarchical cosine-distance merging)

Output: refined MOT `.txt`, re-rendered video with refined tracks, before/after summary statistics.

In [None]:
#@title GTA-Link Configuration { display-mode: "form" }

# ─── Master toggle ───────────────────────────────────────────
GTA_ENABLED = True  #@param {type:"boolean"}

# ─── Classes to refine ────────────────────────────────────────
# Class IDs (from YOLO model):
#   0=Center Ice  1=Faceoff  2=Goalpost  3=Goaltender
#   4=Player      5=Puck     6=Referee
# Default "3,4,6" refines Goaltender, Player, Referee.
# Puck (5) is excluded by default — short tracklets confuse the splitter.
GTA_CLASSES_TO_REFINE = "3,4,6"  #@param {type:"string"}

# ─── Phase 2 — Split parameters ───────────────────────────────
GTA_USE_SPLIT = True   #@param {type:"boolean"}
GTA_EPS = 0.7          #@param {type:"number"}
GTA_MIN_SAMPLES = 10   #@param {type:"number"}
GTA_MAX_K = 3          #@param {type:"number"}
GTA_MIN_LEN = 100      #@param {type:"number"}

# ─── Phase 2 — Connect (merge) parameters ────────────────────
GTA_USE_CONNECT = True       #@param {type:"boolean"}
GTA_MERGE_DIST_THRES = 0.4   #@param {type:"number"}
GTA_SPATIAL_FACTOR = 1.0     #@param {type:"number"}

# ─── Output-path labels ──────────────────────────────────────
GTA_TRACKER_NAME = "BananaTracker"  #@param {type:"string"}
GTA_DATASET_NAME = "hockey"         #@param {type:"string"}
GTA_SEQ_NAME = "sequence"           #@param {type:"string"}

# ─── Derived ──────────────────────────────────────────────────
GTA_CLASSES_TO_REFINE_LIST = [int(x.strip()) for x in GTA_CLASSES_TO_REFINE.split(",") if x.strip()]

print("=" * 52)
print("GTA-Link Post-Processing Configuration")
print("=" * 52)
print(f"  Enabled            : {GTA_ENABLED}")
print(f"  Classes to refine  : {GTA_CLASSES_TO_REFINE_LIST}")
print(f"  Split              : {GTA_USE_SPLIT} (eps={GTA_EPS}, min_samples={GTA_MIN_SAMPLES}, max_k={GTA_MAX_K}, min_len={GTA_MIN_LEN})")
print(f"  Connect            : {GTA_USE_CONNECT} (merge_dist={GTA_MERGE_DIST_THRES}, spatial={GTA_SPATIAL_FACTOR})")
print(f"  Tracker / Dataset  : {GTA_TRACKER_NAME} / {GTA_DATASET_NAME}")
print("=" * 52)


### Step 1 — Prepare inputs
Frees GPU memory held by SAM2.1/Cutie/YOLO, extracts video frames to disk, and writes a class-filtered MOT `.txt` for GTA-Link to consume.
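
For reference, the MOT `.txt` rows written in this step follow the comma-separated layout `frame,id,left,top,width,height,score,-1,-1,-1`. The sketch below round-trips a few made-up detections through that layout and applies the same class filter; the helper names `to_mot_line` and `parse_mot_line` are illustrative and not part of BananaTracker or GTA-Link.

```python
def to_mot_line(frame_id, track_id, tlwh, score):
    """Format one detection as a MOT-style CSV row."""
    left, top, width, height = tlwh
    return (f"{frame_id},{track_id},{left:.2f},{top:.2f},"
            f"{width:.2f},{height:.2f},{score:.2f},-1,-1,-1")


def parse_mot_line(line):
    """Inverse of to_mot_line: return (frame_id, track_id, [l, t, w, h], score)."""
    parts = line.strip().split(",")
    return (int(float(parts[0])), int(float(parts[1])),
            [float(p) for p in parts[2:6]], float(parts[6]))


# Toy detections as (frame_id, track_id, class_id, tlwh, score); values are made up.
detections = [
    (1, 7, 4, (100.0, 50.0, 40.0, 80.0), 0.91),  # Player -> kept
    (1, 9, 5, (300.0, 200.0, 8.0, 8.0), 0.55),   # Puck   -> filtered out
    (2, 7, 4, (104.5, 51.0, 40.0, 80.0), 0.88),  # Player -> kept
]
classes_to_refine = {3, 4, 6}

lines = [to_mot_line(f, tid, box, s)
         for f, tid, cls, box, s in detections if cls in classes_to_refine]
print("\n".join(lines))
```

Note that the class ID is not part of the row, which is why the notebook keeps a separate mapping from track to class.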

In [None]:
import gc, os, sys
import cv2
import torch
from tqdm import tqdm

# ── guard ─────────────────────────────────────────────────────
if not GTA_ENABLED:
    print("GTA_ENABLED is False — skipping Step 1.")
else:
    # ── 1. Free GPU ───────────────────────────────────────────
    try:
        del pipeline
    except NameError:
        pass
    gc.collect()               # collect unreachable Python objects first...
    torch.cuda.empty_cache()   # ...so the CUDA blocks they held can be released
    print("GPU memory freed.")

    # ── 2. Directory layout ───────────────────────────────────
    GTA_DATA_PATH  = "/content/gta_input/frames"
    GTA_PRED_DIR   = "/content/gta_input/predictions"
    GTA_IMG_DIR    = os.path.join(GTA_DATA_PATH, GTA_SEQ_NAME, "img1")
    os.makedirs(GTA_IMG_DIR,  exist_ok=True)
    os.makedirs(GTA_PRED_DIR, exist_ok=True)
    print(f"Frame dir : {GTA_IMG_DIR}")
    print(f"Pred dir  : {GTA_PRED_DIR}")

    # ── 3. Extract frames ─────────────────────────────────────
    cap = cv2.VideoCapture(INPUT_VIDEO)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"Extracting {total_frames} frames from {INPUT_VIDEO} ...")
    for i in tqdm(range(total_frames), desc="Extracting frames"):
        ret, frame = cap.read()
        if not ret:
            break
        cv2.imwrite(os.path.join(GTA_IMG_DIR, f"{i+1:06d}.jpg"), frame)
    cap.release()
    actual_frames = len(os.listdir(GTA_IMG_DIR))
    print(f"Extracted {actual_frames} frames.")

    # ── 4. Build filtered MOT .txt ────────────────────────────
    # all_tracks: list of (frame_id, list[STrack])
    mot_path = os.path.join(GTA_PRED_DIR, f"{GTA_SEQ_NAME}.txt")
    REFINED_TRACK_IDS = set()          # original track_ids sent to GTA
    TRACK_CLASS_MAP   = {}             # (frame_id, track_id) -> class_id
    line_count = 0
    with open(mot_path, "w") as f:
        for frame_id, tracks in all_tracks:
            for t in tracks:
                if t.class_id not in GTA_CLASSES_TO_REFINE_LIST:
                    continue
                tlwh = t.tlwh
                f.write(f"{frame_id},{t.track_id},{tlwh[0]:.2f},{tlwh[1]:.2f},"
                        f"{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f},-1,-1,-1\n")
                REFINED_TRACK_IDS.add(t.track_id)
                TRACK_CLASS_MAP[(frame_id, t.track_id)] = t.class_id
                line_count += 1

    print(f"\nFiltered MOT .txt written to {mot_path}")
    print(f"  Lines written        : {line_count}")
    print(f"  Unique track IDs     : {len(REFINED_TRACK_IDS)}")
    print(f"  Classes in file      : {sorted(set(TRACK_CLASS_MAP.values()))}")


### Step 2 — Phase 1: Generate Tracklets
Runs `generate_tracklets.py` via subprocess. This is the expensive step: it performs OSNet re-ID inference on every detection crop across every frame. Subsequent tuning only requires re-running Phase 2 (Step 3).

In [None]:
import subprocess, glob as _glob

if not GTA_ENABLED:
    print("GTA_ENABLED is False — skipping Phase 1.")
else:
    OSNET_CKPT = "/content/gta-link/reid_checkpoints/sports_model.pth.tar-60"

    # ── pre-flight checks ─────────────────────────────────────
    if not os.path.isfile(OSNET_CKPT):
        raise FileNotFoundError(
            f"OSNet checkpoint not found: {OSNET_CKPT}\n"
            "Download it and place it at the path above, or update OSNET_CKPT.")
    if not os.path.isfile(os.path.join("/content/gta-link", "generate_tracklets.py")):
        raise FileNotFoundError("generate_tracklets.py not found in /content/gta-link")

    # ── patch torchreid for PyTorch ≥ 2.6 (weights_only default) ──
    # torch.load changed its default from weights_only=False to True in 2.6;
    # the bundled torchreid predates that change.  Patch is idempotent.
    _torchtools = "/content/gta-link/reid/torchreid/utils/torchtools.py"
    with open(_torchtools) as _f:
        _src = _f.read()
    if "weights_only=False" not in _src:
        _src = _src.replace(
            "torch.load(fpath, map_location=map_location)",
            "torch.load(fpath, map_location=map_location, weights_only=False)")
        with open(_torchtools, "w") as _f:
            _f.write(_src)
        print("Patched torchtools.py: added weights_only=False to torch.load")

    # ── explicit env so torchreid is importable in the child ──
    env = {**os.environ,
           "PYTHONPATH": "/content/gta-link/reid:" + os.environ.get("PYTHONPATH", "")}

    cmd = [
        sys.executable,
        "generate_tracklets.py",
        "--model_path", OSNET_CKPT,
        "--data_path",  GTA_DATA_PATH,
        "--pred_dir",   GTA_PRED_DIR,
        "--tracker",    GTA_TRACKER_NAME,
    ]
    print("Running Phase 1:", " ".join(cmd))
    result = subprocess.run(cmd, cwd="/content/gta-link", env=env,
                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    print(result.stdout)          # always surface full output
    if result.returncode != 0:
        raise RuntimeError(
            f"generate_tracklets.py exited with code {result.returncode}\n"
            f"--- captured output ---\n{result.stdout}")

    # ── locate output ─────────────────────────────────────────
    GTA_TRACKLETS_DIR = os.path.join(
        os.path.dirname(GTA_PRED_DIR),
        f"{GTA_TRACKER_NAME}_Tracklets_{os.path.basename(GTA_DATA_PATH)}"
    )
    pkl_files = _glob.glob(os.path.join(GTA_TRACKLETS_DIR, "*.pkl"))
    print(f"\nTracklets dir : {GTA_TRACKLETS_DIR}")
    print(f"  .pkl files  : {pkl_files}")
    if not pkl_files:
        raise FileNotFoundError(f"No .pkl files found in {GTA_TRACKLETS_DIR}")


### Step 3 — Phase 2: Refine Tracklets
Splits identity switches and merges fragmented tracks. This cell can be re-run independently after tuning the split / connect parameters in the config cell above.
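
To build intuition for `merge_dist_thres`, here is a deliberately simplified sketch of an appearance-based merge test: two tracklets are merge candidates when they never share a frame and the cosine distance between their mean embeddings falls below the threshold. This is an illustration of the idea only, not GTA-Link's actual implementation (which also applies spatial constraints); the dictionary layout and function names are assumptions.

```python
import math


def mean_vector(vectors):
    """Element-wise mean of equally sized feature vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def cosine_distance(u, v):
    """1 - cosine similarity: 0 means same direction, 1 means orthogonal."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / norm


def should_merge(tracklet_a, tracklet_b, merge_dist_thres=0.4):
    """Merge when the tracklets never co-occur and look alike on average."""
    if tracklet_a["frames"] & tracklet_b["frames"]:
        return False  # one identity cannot appear twice in the same frame
    dist = cosine_distance(mean_vector(tracklet_a["feats"]),
                           mean_vector(tracklet_b["feats"]))
    return dist < merge_dist_thres


# Toy 2-D embeddings; real re-ID features have hundreds of dimensions.
a = {"frames": {1, 2, 3}, "feats": [[1.0, 0.0], [0.9, 0.1]]}
b = {"frames": {10, 11}, "feats": [[0.95, 0.05]]}   # similar look, later in time
c = {"frames": {10, 11}, "feats": [[0.0, 1.0]]}     # different look
d = {"frames": {3, 4}, "feats": [[1.0, 0.0]]}       # similar look, overlaps a in frame 3

print(should_merge(a, b), should_merge(a, c), should_merge(a, d))
```

Under this simplified model, lowering the threshold makes merging stricter (fewer, safer merges), while raising it recovers more fragments at the risk of joining different identities.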

In [None]:
# sys is imported here so the cell really can be re-run on its own (it uses sys.executable)
import subprocess, glob as _glob, os, sys

if not GTA_ENABLED:
    print("GTA_ENABLED is False — skipping Phase 2.")
else:
    env = {**os.environ,
           "PYTHONPATH": "/content/gta-link/reid:" + os.environ.get("PYTHONPATH", "")}

    cmd = [
        sys.executable,
        "refine_tracklets.py",
        "--dataset",          GTA_DATASET_NAME,
        "--tracker",          GTA_TRACKER_NAME,
        "--track_src",        GTA_TRACKLETS_DIR,
        "--eps",              str(GTA_EPS),
        "--min_samples",      str(GTA_MIN_SAMPLES),
        "--max_k",            str(GTA_MAX_K),
        "--min_len",          str(GTA_MIN_LEN),
        "--merge_dist_thres", str(GTA_MERGE_DIST_THRES),
        "--spatial_factor",   str(GTA_SPATIAL_FACTOR),
    ]
    if GTA_USE_SPLIT:
        cmd.append("--use_split")
    if GTA_USE_CONNECT:
        cmd.append("--use_connect")

    print("Running Phase 2:", " ".join(cmd))
    result = subprocess.run(cmd, cwd="/content/gta-link", env=env,
                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    print(result.stdout)          # always surface full output
    if result.returncode != 0:
        raise RuntimeError(
            f"refine_tracklets.py exited with code {result.returncode}\n"
            f"--- captured output ---\n{result.stdout}")

    # ── locate refined .txt via glob + mtime ─────────────────
    # Output dir pattern: <parent>/<TRACKER>_<DATASET>_*/<SEQ_NAME>.txt
    parent = os.path.dirname(GTA_TRACKLETS_DIR)  # /content/gta_input
    pattern = os.path.join(parent, f"{GTA_TRACKER_NAME}_{GTA_DATASET_NAME}_*", f"{GTA_SEQ_NAME}.txt")
    candidates = _glob.glob(pattern)
    if not candidates:
        raise FileNotFoundError(f"No refined .txt matched pattern: {pattern}")
    # pick most recently modified
    GTA_REFINED_TXT = max(candidates, key=os.path.getmtime)
    print(f"\nRefined MOT .txt : {GTA_REFINED_TXT}")


### Step 4 — Merge, Re-render & Summarize
Merges refined tracks with passthrough tracks (classes not sent to GTA-Link), re-renders the video with the merged result, and prints a before/after summary.
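
Because refined tracks and passthrough tracks come from different ID spaces, their IDs can collide once merged. A minimal sketch of the offset-based remapping idea, with toy ID sets and a hypothetical helper name `build_remap`:

```python
def build_remap(original_ids, refined_ids):
    """Shift refined IDs past the largest original ID so the two sets cannot collide."""
    offset = max(original_ids) + 1 if original_ids else 1
    return {tid: tid + offset for tid in sorted(refined_ids)}


original_ids = {1, 2, 3, 9}   # every ID the tracker emitted, all classes (toy values)
refined_ids = {1, 2, 4}       # IDs assigned by the refinement step (toy values)

remap = build_remap(original_ids, refined_ids)
print(remap)
```

The offset is derived from all original IDs rather than only the passthrough ones, which keeps the remapped range disjoint from anything that appeared in the first tracking pass.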

In [None]:
import cv2
import numpy as np
from collections import defaultdict

if not GTA_ENABLED:
    print("GTA_ENABLED is False — skipping Step 4.")
else:
    # ── class metadata (mirrors config cell) ─────────────────
    CLASS_NAMES = ["Center Ice", "Faceoff", "Goalpost", "Goaltender",
                   "Player", "Puck", "Referee"]
    CLASS_COLORS_BGR = {                        # BGR, same as config
        3: (  0, 165, 255),   # Goaltender – orange
        4: (255,   0,   0),   # Player     – blue
        5: (  0, 255,   0),   # Puck       – green
        6: (  0,   0, 255),   # Referee    – red
    }

    # ── helper: IoU ───────────────────────────────────────────
    def _iou(box_a, box_b):
        """IoU between two [l, t, w, h] boxes."""
        a_x1, a_y1 = box_a[0], box_a[1]
        a_x2, a_y2 = a_x1 + box_a[2], a_y1 + box_a[3]
        b_x1, b_y1 = box_b[0], box_b[1]
        b_x2, b_y2 = b_x1 + box_b[2], b_y1 + box_b[3]
        inter_x1 = max(a_x1, b_x1);  inter_y1 = max(a_y1, b_y1)
        inter_x2 = min(a_x2, b_x2);  inter_y2 = min(a_y2, b_y2)
        inter   = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
        union   = (box_a[2]*box_a[3]) + (box_b[2]*box_b[3]) - inter
        return inter / union if union > 0 else 0.0

    # ── 1. Parse refined .txt ─────────────────────────────────
    refined_by_frame = defaultdict(list)   # frame_id -> [(tid, [l,t,w,h])]
    with open(GTA_REFINED_TXT) as f:
        for line in f:
            parts = line.strip().split(",")
            fid  = int(float(parts[0]))
            tid  = int(float(parts[1]))
            bbox = [float(parts[2]), float(parts[3]),
                    float(parts[4]), float(parts[5])]
            refined_by_frame[fid].append((tid, bbox))

    # ── 2. Build passthrough (classes NOT refined) ───────────
    passthrough_by_frame = defaultdict(list)  # frame_id -> [(tid, bbox, class_id)]
    orig_by_frame        = defaultdict(list)   # frame_id -> [(tid, bbox, class_id)] — ALL original
    all_orig_tids        = set()
    for frame_id, tracks in all_tracks:
        for t in tracks:
            tlwh = t.tlwh.tolist()
            orig_by_frame[frame_id].append((t.track_id, tlwh, t.class_id))
            all_orig_tids.add(t.track_id)
            if t.class_id not in GTA_CLASSES_TO_REFINE_LIST:
                passthrough_by_frame[frame_id].append((t.track_id, tlwh, t.class_id))

    # ── 3. Recover class labels for refined tracks via IoU ───
    # Build original refined-class detections per frame for matching
    orig_refined_by_frame = defaultdict(list)  # frame_id -> [(tid, bbox, class_id)]
    for frame_id, tracks in all_tracks:
        for t in tracks:
            if t.class_id in GTA_CLASSES_TO_REFINE_LIST:
                orig_refined_by_frame[frame_id].append((t.track_id, t.tlwh.tolist(), t.class_id))

    refined_tid_class = {}  # refined_tid -> class_id (cached after first match)
    for fid in sorted(refined_by_frame.keys()):
        for (rtid, rbbox) in refined_by_frame[fid]:
            if rtid in refined_tid_class:
                continue
            best_iou, best_class = 0.0, None
            for (_otid, obbox, ocls) in orig_refined_by_frame.get(fid, []):
                iou = _iou(rbbox, obbox)
                if iou > best_iou:
                    best_iou   = iou
                    best_class = ocls
            if best_iou > 0.1:
                refined_tid_class[rtid] = best_class

    print(f"Class labels recovered for {len(refined_tid_class)} / {len(set(tid for dets in refined_by_frame.values() for tid, _ in dets))} refined track IDs")

    # ── 4. Remap refined track IDs to avoid collision ────────
    id_offset   = max(all_orig_tids) + 1 if all_orig_tids else 1
    refined_tids_sorted = sorted(set(tid for dets in refined_by_frame.values() for tid, _ in dets))
    remap = {old: old + id_offset for old in refined_tids_sorted}
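    if refined_tids_sorted:   # min()/max() raise ValueError on an empty sequence
        print(f"Refined ID range remapped: [{min(refined_tids_sorted)}, {max(refined_tids_sorted)}] "
              f"-> [{min(remap.values())}, {max(remap.values())}]")
    else:
        print("No refined track IDs found; nothing to remap.")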

    # ── 5. Write merged MOT .txt ──────────────────────────────
    MERGED_MOT_PATH = "/content/results_gta_merged.txt"
    merged_lines = []
    # passthrough tracks — original IDs
    for fid, dets in passthrough_by_frame.items():
        for (tid, bbox, _cid) in dets:
            # the original score is not carried through; a constant confidence of 1 is written
            merged_lines.append(f"{fid},{tid},{bbox[0]:.2f},{bbox[1]:.2f},{bbox[2]:.2f},{bbox[3]:.2f},1,-1,-1,-1")
    # refined tracks — remapped IDs
    for fid, dets in refined_by_frame.items():
        for (tid, bbox) in dets:
            new_tid = remap[tid]
            merged_lines.append(f"{fid},{new_tid},{bbox[0]:.2f},{bbox[1]:.2f},{bbox[2]:.2f},{bbox[3]:.2f},1,-1,-1,-1")
    # sort by frame
    merged_lines.sort(key=lambda l: (int(l.split(',')[0]), int(l.split(',')[1])))
    with open(MERGED_MOT_PATH, "w") as f:
        f.write("\n".join(merged_lines) + "\n")
    print(f"Merged MOT .txt written to {MERGED_MOT_PATH}  ({len(merged_lines)} lines)")

    # ── 6. Re-render video ────────────────────────────────────
    OUTPUT_GTA_VIDEO = "/content/output_gta_refined.mp4"
    cap = cv2.VideoCapture(INPUT_VIDEO)
    fps  = cap.get(cv2.CAP_PROP_FPS)
    W    = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    H    = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(OUTPUT_GTA_VIDEO, fourcc, fps, (W, H))

    # index merged tracks by frame
    merged_by_frame = defaultdict(list)  # frame_id -> [(tid, bbox, class_id, is_refined)]
    for fid, dets in passthrough_by_frame.items():
        for (tid, bbox, cid) in dets:
            merged_by_frame[fid].append((tid, bbox, cid, False))
    for fid, dets in refined_by_frame.items():
        for (tid, bbox) in dets:
            cid = refined_tid_class.get(tid, -1)
            merged_by_frame[fid].append((remap[tid], bbox, cid, True))

    from tqdm import tqdm as _tqdm
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    for frame_idx in _tqdm(range(total_frames), desc="Re-rendering"):
        ret, frame = cap.read()
        if not ret:
            break
        fid = frame_idx + 1
        for (tid, bbox, cid, is_refined) in merged_by_frame.get(fid, []):
            l, t, w, h = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
            color = CLASS_COLORS_BGR.get(cid, (0, 255, 0))
            cv2.rectangle(frame, (l, t), (l+w, t+h), color, 2)
            cls_name = CLASS_NAMES[cid] if 0 <= cid < len(CLASS_NAMES) else f"cls{cid}"
            label = f"{cls_name} {tid}" + (" *" if is_refined else "")
            # label background
            (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
            cv2.rectangle(frame, (l, t - th - 2), (l + tw + 4, t), color, -1)
            # text — black or white for contrast
            b, g, r = color
            lum = 0.299*r + 0.587*g + 0.114*b
            txt_color = (0, 0, 0) if lum > 128 else (255, 255, 255)
            cv2.putText(frame, label, (l + 2, t - 3),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.55, txt_color, 1)
        out.write(frame)
    cap.release()
    out.release()
    print(f"Re-rendered video saved to {OUTPUT_GTA_VIDEO}")

    # ── 7. Summary table ──────────────────────────────────────
    # Before: count unique track_ids per class from all_tracks
    before_counts = defaultdict(set)  # class_id -> set of track_ids
    for frame_id, tracks in all_tracks:
        for t in tracks:
            before_counts[t.class_id].add(t.track_id)

    # After (refined classes): unique remapped IDs; passthrough classes unchanged
    after_refined_tids = defaultdict(set)
    for tid in refined_tids_sorted:
        cid = refined_tid_class.get(tid, -1)
        if cid >= 0:
            after_refined_tids[cid].add(remap[tid])

    print("\n" + "=" * 60)
    print(" GTA-Link Refinement Summary")
    print("=" * 60)
    print(f"{'Class':<14} {'Before':>8} {'After':>8} {'Delta':>7}  Tag")
    print("-" * 60)
    for cid in sorted(set(before_counts) | set(after_refined_tids)):
        name = CLASS_NAMES[cid] if 0 <= cid < len(CLASS_NAMES) else f"class_{cid}"
        before_n = len(before_counts.get(cid, set()))
        if cid in GTA_CLASSES_TO_REFINE_LIST:
            after_n = len(after_refined_tids.get(cid, set()))
            tag = "(refined)"
        else:
            after_n = before_n
            tag = "(passthrough)"
        delta = after_n - before_n
        # Format the sign together with the number so "+3" is not split by padding
        delta_str = f"{delta:+d}" if delta else "0"
        print(f"{name:<14} {before_n:>8} {after_n:>8} {delta_str:>7}  {tag}")
    print("=" * 60)


## Cell 8: Compress and Display Output Video

In [None]:
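# Compress video for notebook display
# (-loglevel error silences ffmpeg without hiding the confirmation printed below,
#  which %%capture would have suppressed)
OUTPUT_COMPRESSED = "/content/output_compressed.mp4"
!ffmpeg -y -loglevel error -i {config.output_video_path} -vcodec libx264 -crf 28 {OUTPUT_COMPRESSED}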

print(f"Compressed video saved to: {OUTPUT_COMPRESSED}")

In [None]:
from IPython.display import HTML
from base64 import b64encode

OUTPUT_COMPRESSED = "/content/output_compressed.mp4"

# Read and base64-encode the video so it can be embedded inline
with open(OUTPUT_COMPRESSED, 'rb') as f:
    mp4 = f.read()
data_url = f"data:video/mp4;base64,{b64encode(mp4).decode()}"

# Display video
HTML(f'''
<video width="800" controls>
  <source src="{data_url}" type="video/mp4">
</video>
''')

## Cell 9: Frame-by-Frame Processing (Optional)

For more control, process frames individually using the generator API.
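
One use of per-frame access is building a per-track history while the video streams. The sketch below assumes only what the next cell shows: the generator yields `(frame_id, frame, tracks, vis_frame)` tuples and `get_track_info` returns dictionaries with a `track_id` key. `collect_track_frames` and `fake_generator` are hypothetical helpers for illustration, not part of BananaTracker.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable, List, Tuple


def collect_track_frames(
    frames: Iterable[Tuple[int, object, list, object]],
    to_info: Callable[[list], List[dict]],
    max_frames: int = 10,
) -> Dict[int, List[int]]:
    """Map each track_id to the frame IDs in which it appeared.

    `frames` is any iterable of (frame_id, frame, tracks, vis_frame) tuples;
    `to_info` converts a tracks list into dicts with a 'track_id' key.
    """
    history = defaultdict(list)
    for frame_id, _frame, tracks, _vis in frames:
        for info in to_info(tracks):
            history[info["track_id"]].append(frame_id)
        if frame_id >= max_frames:
            break  # stop early, like the demo limit in the cell below
    return dict(history)


# Stand-in generator so the sketch runs without a video or model (hypothetical).
def fake_generator():
    for frame_id in range(1, 21):
        # Track 1 appears in every frame, track 2 only in odd frames.
        tracks = [{"track_id": 1}]
        if frame_id % 2:
            tracks.append({"track_id": 2})
        yield frame_id, None, tracks, None


history = collect_track_frames(fake_generator(), to_info=lambda t: t, max_frames=4)
print(history)
```

With the real pipeline, and assuming the signatures used in the next cell, the call would look like `collect_track_frames(pipeline.process_video_generator(INPUT_VIDEO), pipeline.get_track_info)`.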

In [None]:
# Example: Process frame-by-frame with generator and access mask info
# Uncomment to run

# from bananatracker import BananaTrackerPipeline
# 
# pipeline = BananaTrackerPipeline(config)
# 
# for frame_id, frame, tracks, vis_frame in pipeline.process_video_generator(INPUT_VIDEO):
#     # Get track info as dictionaries
#     track_info = pipeline.get_track_info(tracks)
#     
#     # Access mask data
#     prediction_mask = pipeline.prediction_mask
#     tracklet_mask_dict = pipeline.tracklet_mask_dict
#     
#     # Process each track
#     for info in track_info:
#         track_id = info['track_id']
#         mask_id = tracklet_mask_dict.get(track_id, None)
#         print(f"Frame {frame_id}: Track {track_id} - {info['class_name']} - Mask ID: {mask_id}")
#     
#     # Stop after 10 frames for demo
#     if frame_id >= 10:
#         break

## Cell 10: Disable Masks (Performance Mode)

If you need faster processing and can do without mask-enhanced association, disable the mask module by setting `enable_masks=False`.

In [None]:
# Create config without masks for faster processing
# Uses the same optimized tracking parameters
config_fast = BananaTrackerConfig(
    yolo_weights="/content/HockeyAI_model_weight.pt",
    class_names=["Center Ice", "Faceoff", "Goalpost", "Goaltender", "Player", "Puck", "Referee"],
    track_classes=[3, 4, 5, 6],
    special_classes=[5],
    detection_conf_thresh=0.4,    # Lowered to catch more objects
    
    # Tracker settings - Optimized for stability
    track_thresh=0.5,    # Lowered for more first-pass matches
    track_buffer=90,     # 3 seconds at 30fps for better occlusion handling
    cmc_method="orb",
    
    # DISABLE mask module for speed
    enable_masks=False,
    
    # Output
    output_video_path="/content/output_fast.mp4",
    device="cuda:0",
)

# pipeline_fast = BananaTrackerPipeline(config_fast)
# all_tracks_fast = pipeline_fast.process_video(INPUT_VIDEO)

print("Fast mode config created (masks disabled)")
print(f"  - detection_conf_thresh: {config_fast.detection_conf_thresh}")
print(f"  - track_thresh: {config_fast.track_thresh}")
print(f"  - track_buffer: {config_fast.track_buffer} frames")
print("Uncomment the lines above to run without mask enhancement")

## Architecture Overview

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                           BananaTracker Pipeline                             │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  Input Frame                                                                 │
│      │                                                                       │
│      ▼                                                                       │
│  ┌─────────┐                                                                 │
│  │ YOLOv8  │ ─────────────────────────────────────────────────────┐         │
│  │Detector │                                                       │         │
│  └────┬────┘                                                       │         │
│       │ Detections [x1,y1,x2,y2,conf,class]                       │         │
│       ▼                                                            │         │
│  ┌──────────────┐                                                  │         │
│  │ BananaTracker│◄─────────── Mask-Enhanced ───────────────────┐   │         │
│  │  (ByteTrack) │             Cost Matrix                      │   │         │
│  └──────┬───────┘                                              │   │         │
│         │ Tracks, New Tracks, Removed IDs                      │   │         │
│         ▼                                                      │   │         │
│  ┌──────────────────────────────────────────────────────────┐  │   │         │
│  │                    MaskManager                            │  │   │         │
│  │  ┌───────────┐         ┌─────────┐                       │  │   │         │
│  │  │  SAM2.1   │─────────│  Cutie  │                       │  │   │         │
│  │  │ (Initial  │  Seed   │(Temporal│                       │  │   │         │
│  │  │  Masks)   │  Masks  │ Propag.)│                       │  │   │         │
│  │  └───────────┘         └────┬────┘                       │  │   │         │
│  │                             │                             │  │   │         │
│  │                     prediction_mask                       │──┘   │         │
│  │                    tracklet_mask_dict                     │      │         │
│  │                    mask_avg_prob_dict                     │      │         │
│  └──────────────────────────────────────────────────────────┘      │         │
│                                                                     │         │
│  ┌─────────────┐                                                   │         │
│  │ Visualizer  │◄──────────────────────────────────────────────────┘         │
│  │ (with mask  │                                                             │
│  │  overlay)   │                                                             │
│  └──────┬──────┘                                                             │
│         │                                                                    │
│         ▼                                                                    │
│    Output Frame                                                              │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘
```
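
The diagram implies a per-frame bookkeeping pattern: the tracker reports new tracks and removed IDs, a mask is seeded for each new track, and all active masks are carried forward every frame. The toy sketch below reproduces only that bookkeeping. `ToyMaskManager` and its `update` method are hypothetical and do not mirror the real `MaskManager` API, and the order of steps inside `update` is an assumption.

```python
from typing import Iterable, List, Set


class ToyMaskManager:
    """Hypothetical stand-in that only records which model would be called."""

    def __init__(self) -> None:
        self.active: Set[int] = set()   # track IDs that currently own a mask
        self.sam_calls: List[int] = []  # track IDs seeded by the segmenter
        self.cutie_calls: int = 0       # number of propagation steps

    def update(self, new_ids: Iterable[int], removed_ids: Iterable[int]) -> None:
        # Carry existing masks forward one frame (Cutie's role).
        self.cutie_calls += 1
        # Drop masks for tracks the tracker has removed.
        self.active -= set(removed_ids)
        # Seed a mask once for each track not seen before (SAM2.1's role).
        for tid in new_ids:
            if tid not in self.active:
                self.sam_calls.append(tid)
                self.active.add(tid)


# Made-up tracker output per frame: (new track IDs, removed track IDs).
frames = [([1, 2], []), ([], []), ([3], [1]), ([], [])]

manager = ToyMaskManager()
for new_ids, removed_ids in frames:
    manager.update(new_ids, removed_ids)

print(manager.sam_calls, manager.cutie_calls, sorted(manager.active))
```

Over the four toy frames the segmenter is invoked three times (once per new ID) while the propagation step runs four times, which is why the per-frame cost is dominated by propagation rather than seeding.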

### Key Components:

1. **SAM2.1**: Called ONCE per new tracklet to create a pixel-precise mask from its bounding box
2. **Cutie**: Called EVERY frame to propagate masks temporally
3. **Mask-Enhanced Cost Matrix**: Uses `mc` (mask coverage) and `mf` (mask fill) metrics to improve association when IoU is ambiguous
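
The notebook does not define `mc` and `mf`, so the sketch below works from assumed definitions: `mc` is the share of a track's mask pixels that fall inside a detection box, and `mf` is the share of the box area covered by that mask. The `adjust_cost` rule and its thresholds are likewise illustrative guesses at how such metrics could refine an IoU cost, not BananaTracker's actual logic.

```python
import numpy as np


def mask_box_metrics(mask: np.ndarray, box_xyxy: tuple) -> tuple:
    """Return (mc, mf) for a boolean mask and an [x1, y1, x2, y2] box.

    Assumed definitions (not confirmed by the notebook):
      mc = mask pixels inside the box / all mask pixels
      mf = mask pixels inside the box / box area
    """
    h, w = mask.shape
    x1, y1, x2, y2 = (int(v) for v in box_xyxy)
    # Clip the box to the image so slicing never wraps around.
    x1, y1 = max(x1, 0), max(y1, 0)
    x2, y2 = min(x2, w), min(y2, h)
    inside = int(mask[y1:y2, x1:x2].sum())
    total = int(mask.sum())
    area = max((x2 - x1) * (y2 - y1), 1)
    mc = inside / total if total else 0.0
    mf = inside / area
    return mc, mf


def adjust_cost(iou_cost: float, mc: float, mf: float,
                min_mc: float = 0.9, min_mf: float = 0.05) -> float:
    """Lower the IoU cost when the mask strongly supports the match.

    The thresholds and the subtraction rule are illustrative assumptions.
    """
    if mc >= min_mc and mf >= min_mf:
        return max(iou_cost - mf, 0.0)
    return iou_cost


# Toy example: a 4x4 mask in a 10x10 frame and two candidate boxes.
mask = np.zeros((10, 10), dtype=bool)
mask[2:6, 2:6] = True

mc_a, mf_a = mask_box_metrics(mask, (2, 2, 6, 10))  # box contains the whole mask
mc_b, mf_b = mask_box_metrics(mask, (4, 4, 8, 8))   # box overlaps one corner

cost_a = adjust_cost(0.6, mc_a, mf_a)
cost_b = adjust_cost(0.6, mc_b, mf_b)
print(mc_a, mf_a, cost_a)
print(mc_b, mf_b, cost_b)
```

In the toy example both boxes start from the same IoU cost, but only the box that contains the whole mask gets its cost reduced, which is the kind of tie-breaking the mask metrics are meant to provide.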