# Elite Single-Player Re-Identification (v9.0 - Cerebrus Tracker with Continuity-Guided Re-ID)

This version addresses the Re-ID system's failure to hold its target under fast motion. It introduces new core logic: **Continuity-Guided Re-Identification**, which weighs physical plausibility (how close a candidate is to the target's last known position) alongside appearance. As a result, the tracker can keep its lock on a target even when the target's appearance changes due to motion blur or a different viewing angle.

### Core Architecture Enhancements:
- **Detectors**: YOLOv8-Large (players) and YOLOv8n (ball)
- **Re-ID Model**: OpenCLIP ViT-H-14
- **Cerebrus Tracking Algorithm v9**:
  - **Continuity-Guided Re-ID (NEW)**: The tracker now calculates a 'continuity score' based on a candidate's proximity to the target's last known position. This score is fused with the visual similarity scores.
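  - **Intelligent Fused Score (NEW)**: The final matching decision is now based on a combination of visual similarity AND physical plausibility. Each candidate near the target's last box is scored as $\text{fused} = d_{\cos}(\mathbf{e}_{\text{target}}, \mathbf{e}_{\text{cand}}) + (1 - \rho_{\text{color}}) + (1 - \text{SSIM}) - 0.5\,e^{-d/h}$. Here $d$ is the distance between box centres and $h$ is the height of the target's last box. The lowest-scoring candidate is accepted if its score is below 1.5. The tracker therefore prefers a candidate in a plausible position, which makes it more resilient to motion blur and pose changes.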
  - **Stable Visual Lock**: The system remains free of unstable predictive models, relying on direct visual evidence guided by the new continuity logic.

In [None]:
# STEP 1: Install All Dependencies
# Installs the Tesseract OCR binary (pytesseract shells out to it for jersey
# OCR) and the Python stack, then imports each library as a smoke test.
# This cell uses IPython `!` shell escapes, so it only runs inside a notebook.
print("Installing system and Python dependencies...")
!sudo apt-get update -qq
!sudo apt-get install -y tesseract-ocr -qq

# Force-reinstall key libraries with compatible versions to avoid conflicts.
# NOTE(review): if numpy was already imported in this runtime before the
# reinstall, the version printed below may still be the old one until the
# runtime is restarted — confirm on the target Colab image.
!pip -q install "numpy==1.26.4" "numba==0.58.1" --force-reinstall
# PyTorch wheels from the CUDA 12.1 (cu121) index.
!pip -q install "torch>=2.3.0" "torchvision" --index-url https://download.pytorch.org/whl/cu121
!pip -q install "ultralytics>=8.0.0" "scipy>=1.10.0" "pytesseract>=0.3.10" "scikit-image>=0.21.0"
!pip -q install "open-clip-torch>=2.24.0"
!pip -q install "pillow>=10.0.0" "tqdm>=4.66.0" "opencv-python-headless>=4.9.0" "matplotlib"

print("\n--- Verifying Installations ---")
# Best-effort check: any import or version-query failure is printed rather
# than raised, so the rest of the notebook can still be attempted.
try:
    import numpy as np, cv2, torch, open_clip, ultralytics, scipy, pytesseract, skimage, matplotlib
    from ultralytics import YOLO
    print(f"NumPy version: {np.__version__}")
    print(f"OpenCV version: {cv2.__version__}")
    print(f"PyTorch version: {torch.__version__}")
    # Fails if the tesseract binary from apt is not on PATH.
    print(f"Tesseract OCR version: {pytesseract.get_tesseract_version()}")
    print(f"Ultralytics (YOLOv8) version: {ultralytics.__version__}")
    print("All libraries loaded successfully!")
except Exception as e:
    print(f"Error during verification: {e}")

In [None]:
# STEP 2: Initialize All Models & Advanced Utility Functions
import torch, open_clip, cv2, pytesseract
from ultralytics import YOLO
from PIL import Image
import numpy as np
from typing import Optional, List, Dict, Any
from skimage.metrics import structural_similarity as ssim

# Prefer the GPU whenever PyTorch can see one; every model below is placed on
# this device, and get_embedding() moves its input tensors here too.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Two YOLOv8 detectors: the large model is used for players, the nano model
# only for the ball.
print("Loading YOLOv8-Large model for player detection...")
player_detector = YOLO('yolov8l.pt')
player_detector.to(device)

print("Loading YOLOv8n model for ball detection...")
ball_detector = YOLO('yolov8n.pt')
ball_detector.to(device)

# OpenCLIP image encoder used for appearance (Re-ID) embeddings. The factory
# yields three objects; the middle one is discarded and only the model and
# the inference preprocessing transform are kept.
print("Loading OpenCLIP ViT-H/14 model for Re-ID...")
reid_model, _, reid_preprocess = open_clip.create_model_and_transforms(
    'ViT-H-14',
    pretrained='laion2b_s32b_b79k',
    device=device,
)
reid_model.eval()

# --- Advanced Feature Extraction Functions ---
def get_embedding(crop_bgr: np.ndarray) -> Optional[np.ndarray]:
    """Return an L2-normalised OpenCLIP image embedding for a BGR crop.

    Args:
        crop_bgr: image crop in OpenCV BGR channel order.

    Returns:
        The embedding as a float32 NumPy vector (the batch axis of size 1 is
        squeezed away), or None if the crop is empty or the model call fails.
        Failures are deliberately treated as "no embedding" so that one bad
        crop is skipped instead of aborting the whole video.
    """
    if crop_bgr.size == 0:
        return None
    try:
        rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
        img_tensor = reid_preprocess(Image.fromarray(rgb)).unsqueeze(0).to(device)
        # Half-precision autocast only on CUDA; on CPU, float16 autocast is
        # unsupported or slow depending on the torch version, so run in full
        # precision there.
        with torch.no_grad(), torch.amp.autocast(
            device_type=device, dtype=torch.float16, enabled=(device == 'cuda')
        ):
            features = reid_model.encode_image(img_tensor)
        # Normalise and return in float32: the tracker keeps a running average
        # of these vectors, and float16 would accumulate rounding error.
        features = features.float()
        features = features / features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().squeeze()
    except Exception:
        # Best effort: the caller drops detections whose embedding is None.
        return None

def get_color_hist(crop_bgr: np.ndarray) -> Optional[np.ndarray]:
    """Summarise a crop's colours as a flattened 8x8x8 histogram in Lab space.

    Args:
        crop_bgr: image crop in OpenCV BGR channel order.

    Returns:
        A 512-element vector, normalised in place by ``cv2.normalize`` with
        its default norm. Returns None for an empty crop or if OpenCV raises.
    """
    if not crop_bgr.size:
        return None

    bins_per_channel = [8, 8, 8]
    channel_ranges = [0, 256] * 3  # same [0, 256) range for L, a and b
    try:
        lab = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2LAB)
        histogram = cv2.calcHist([lab], [0, 1, 2], None, bins_per_channel, channel_ranges)
        cv2.normalize(histogram, histogram)
    except Exception:
        return None
    return histogram.flatten()

def get_ssim(crop1_bgr: np.ndarray, crop2_bgr: np.ndarray) -> float:
    """Structural similarity of two BGR crops, compared in grayscale.

    The second crop is resized to the first crop's width and height before
    comparison. Returns 0.0 when either crop is empty or when anything fails.
    For example, skimage rejects images smaller than its default SSIM window.
    """
    if 0 in (crop1_bgr.size, crop2_bgr.size):
        return 0.0
    try:
        ref_h, ref_w, _ = crop1_bgr.shape
        resized_candidate = cv2.resize(crop2_bgr, (ref_w, ref_h))
        gray_ref = cv2.cvtColor(crop1_bgr, cv2.COLOR_BGR2GRAY)
        gray_candidate = cv2.cvtColor(resized_candidate, cv2.COLOR_BGR2GRAY)
        score = ssim(gray_ref, gray_candidate)
    except Exception:
        score = 0.0
    return score

def ocr_jersey_number(crop_bgr: np.ndarray) -> Optional[str]:
    """Try to read a jersey number from a player crop with Tesseract.

    Crops smaller than 30 px in either dimension are skipped. The central
    60% x 60% region is upscaled, sharpened and Otsu-binarised, then OCR'd as a
    single word restricted to digits.

    Returns:
        The recognised digit string, or None if nothing digit-only was read or
        any step failed (including the 1-second Tesseract timeout).
    """
    if crop_bgr.size == 0 or min(crop_bgr.shape[0], crop_bgr.shape[1]) < 30:
        return None
    try:
        h, w, _ = crop_bgr.shape
        # Middle 60% of the box, where a chest or back number would presumably be.
        torso = crop_bgr[int(h * 0.2):int(h * 0.8), int(w * 0.2):int(w * 0.8)]
        # Upscale to 150 wide x 75 tall for Tesseract.
        gray = cv2.resize(
            cv2.cvtColor(torso, cv2.COLOR_BGR2GRAY), (150, 75), interpolation=cv2.INTER_CUBIC
        )
        # Unsharp mask: 1.5 * image - 0.5 * blurred image.
        sharpened = cv2.addWeighted(gray, 1.5, cv2.GaussianBlur(gray, (3, 3), 0), -0.5, 0)
        _, binary = cv2.threshold(sharpened, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        raw = pytesseract.image_to_string(
            binary, config="--psm 8 -c tessedit_char_whitelist=0123456789", timeout=1
        )
    except Exception:
        return None
    text = raw.strip()
    return text if text.isdigit() else None

# End-of-cell marker: cell statements run in order, so reaching this line
# means every model above loaded and the helper functions are defined.
print("All models and utility functions initialized successfully!")

In [None]:
# STEP 3: Define the Cerebrus Tracking System v9 with Continuity-Guided Re-ID
from scipy.spatial.distance import cdist
import collections

def iou(boxA, boxB):
    """Intersection-over-union of two axis-aligned boxes given as (x1, y1, x2, y2).

    Returns 0.0 when the union has no positive area (degenerate boxes). This
    avoids a ZeroDivisionError or a NaN result in that case.
    """
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    inter_area = max(0, xB - xA) * max(0, yB - yA)
    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = float(area_a + area_b - inter_area)
    # `not union > 0` also rejects NaN coordinates.
    if not union > 0:
        return 0.0
    return float(inter_area) / union

class CerebrusTrack:
    """State of the single tracked target: box, appearance model and jersey votes.

    Lifecycle:
    - A track starts 'confirmed' when manually initialised, otherwise 'tentative'.
    - A tentative track becomes 'confirmed' once it has more than 5 hits.
    - A tracker may set the state to 'lost'; the next update() restores 'confirmed'.
    """

    # A jersey number needs more than this many OCR votes to become confirmed.
    _JERSEY_CONFIRM_HITS = 5

    def __init__(self, bbox, track_id, data, is_manual_init=False):
        """Create a track from one detection dict.

        Args:
            bbox: (x1, y1, x2, y2) box of the detection.
            track_id: integer id to assign.
            data: detection dict with 'emb', 'hist', 'crop' and optional 'jersey'.
            is_manual_init: True when the user picked this player by hand.
        """
        self.id = track_id
        self.bbox = np.array(bbox)
        self.time_since_update = 0
        self.hits = 1
        # A user-selected target is trusted immediately; auto-acquired ones must earn it.
        self.state = 'confirmed' if is_manual_init else 'tentative'
        # Keep the embedding average in float32 even if the extractor returns
        # float16, so repeated EMA updates do not accumulate rounding error.
        self.feature_ema = np.asarray(data['emb'], dtype=np.float32)
        self.color_hist_ema = data['hist']
        self.last_known_crop = data['crop']
        self.jersey_number = data.get('jersey')
        self.confirmed_number = self.jersey_number if is_manual_init else None
        self.number_hits = collections.defaultdict(int)
        if self.jersey_number:
            # A manually chosen player's initial reading gets a head start of 5 votes.
            self.number_hits[self.jersey_number] += 5 if is_manual_init else 1

    def update(self, bbox, data):
        """Absorb a matched detection: move the box, refresh appearance, vote on jersey."""
        self.time_since_update = 0
        self.hits += 1
        self.bbox = np.array(bbox)
        if self.state == 'lost':
            self.state = 'confirmed'
        if self.hits > 5 and self.state == 'tentative':
            self.state = 'confirmed'
        # Exponential moving averages: 90% history, 10% new observation.
        self.feature_ema = 0.9 * self.feature_ema + 0.1 * np.asarray(data['emb'], dtype=np.float32)
        self.color_hist_ema = 0.9 * self.color_hist_ema + 0.1 * data['hist']
        self.last_known_crop = data['crop']

        jersey = data.get('jersey')
        if jersey:
            self.number_hits[jersey] += 1
            # Majority vote: a sporadic mis-read has to out-vote the established
            # number, not merely cross the threshold, before it replaces it.
            leader, votes = max(self.number_hits.items(), key=lambda kv: kv[1])
            if votes > self._JERSEY_CONFIRM_HITS:
                self.confirmed_number = leader

class CerebrusTracker:
    """Single-target tracker that re-identifies its target from appearance plus position.

    Each frame, detections overlapping an enlarged search box around the last
    target box are scored as follows (lower is better):

        fused = cosine_dist(emb) + (1 - colour_corr) + (1 - ssim)
                - continuity_weight * exp(-d / h)

    Here d is the distance between box centres and h is the last box height.
    The best candidate is accepted if its fused score is below fused_thresh.

    Args:
        max_age: frames the target may stay unmatched before it is dropped.
        fused_thresh: acceptance threshold on the fused score.
        continuity_weight: weight of the position-continuity bonus.
        search_scale: the search box extends this many box widths/heights
            beyond the last box on each side.
    """

    def __init__(self, max_age=90, fused_thresh=1.5, continuity_weight=0.5, search_scale=1.5):
        self.max_age = max_age
        self.fused_thresh = fused_thresh
        self.continuity_weight = continuity_weight
        self.search_scale = search_scale
        # Per-cue thresholds kept for compatibility; update() does not consult
        # them and uses the single fused threshold instead.
        self.reid_thresh = 0.7
        self.color_thresh = 0.6
        self.ssim_thresh = 0.2
        self.track_id_counter = 1
        self.target_track = None
        # Centre of the last detected ball; kept while no ball is detected.
        self.ball_pos = None

    def initialize_target(self, initial_data):
        """Lock onto a user-selected detection dict; the track starts 'confirmed'."""
        self.target_track = CerebrusTrack(initial_data['bbox'], self.track_id_counter, initial_data, is_manual_init=True)
        self.track_id_counter += 1

    def update(self, frame, p_data, b_detections, target_jersey_number=None):
        """Advance the tracker by one frame.

        Args:
            frame: current image; only its width is read (for central-player acquisition).
            p_data: player detection dicts with 'bbox', 'emb', 'hist', 'crop' and optional 'jersey'.
            b_detections: array of ball boxes (x1, y1, x2, y2); may be empty.
            target_jersey_number: controls automatic acquisition when there is no target.
                - None: disabled (manual mode).
                - '': take the most horizontally central player.
                - Any other string: take the first detection with that OCR'd jersey.

        Returns:
            (tracks, ball_pos). tracks is [target] when a non-tentative target
            exists, else []. ball_pos is the last known ball centre, or None.
        """
        self._update_ball(b_detections)

        if self.target_track is None and target_jersey_number is not None:
            self._try_acquire(frame, p_data, target_jersey_number)

        track = self.target_track
        if track is None:
            return [], self.ball_pos

        track.time_since_update += 1
        if track.time_since_update > self.max_age:
            self.target_track = None
            return [], self.ball_pos

        best = self._best_match(p_data)
        if best is not None:
            track.update(best['bbox'], best)
        elif track.state == 'tentative':
            # An unproven auto-acquired track that misses a frame is discarded,
            # so acquisition can retry. Otherwise it would be shown as 'lost'
            # and then promoted straight to 'confirmed' on its next hit.
            self.target_track = None
            return [], self.ball_pos
        else:
            track.state = 'lost'

        output_tracks = [track] if track.state != 'tentative' else []
        return output_tracks, self.ball_pos

    def _update_ball(self, b_detections):
        """Store the centre of the largest ball box; keep the old position if there is none."""
        if len(b_detections) == 0:
            return
        areas = [(b[2] - b[0]) * (b[3] - b[1]) for b in b_detections]
        best_ball = b_detections[int(np.argmax(areas))]
        self.ball_pos = (int((best_ball[0] + best_ball[2]) / 2), int((best_ball[1] + best_ball[3]) / 2))

    def _try_acquire(self, frame, p_data, target_jersey_number):
        """Start a tentative track on a detection chosen by jersey number or centrality."""
        if not p_data:
            return
        if target_jersey_number:
            found_player = next((d for d in p_data if d.get('jersey') == target_jersey_number), None)
        else:
            pitch_center_x = frame.shape[1] / 2
            found_player = min(p_data, key=lambda d: abs((d['bbox'][0] + d['bbox'][2]) / 2 - pitch_center_x))
        if found_player is not None:
            self.target_track = CerebrusTrack(found_player['bbox'], self.track_id_counter, found_player)
            self.track_id_counter += 1

    def _best_match(self, p_data):
        """Return the detection dict that best continues the target, or None."""
        if not p_data:
            return None
        track = self.target_track
        last_box = track.bbox
        w, h = last_box[2] - last_box[0], last_box[3] - last_box[1]
        s = self.search_scale
        search_area = [last_box[0] - w * s, last_box[1] - h * s, last_box[2] + w * s, last_box[3] + h * s]

        # Any overlap with the search box makes a detection a candidate.
        candidates = [d for d in p_data if iou(d['bbox'], search_area) > 0]
        if not candidates:
            return None

        reid_dists = cdist(np.array([track.feature_ema]), np.array([d['emb'] for d in candidates]), 'cosine').flatten()
        color_corrs = np.array([cv2.compareHist(track.color_hist_ema, d['hist'], cv2.HISTCMP_CORREL) for d in candidates])
        ssim_scores = np.array([get_ssim(track.last_known_crop, d['crop']) for d in candidates])

        # Continuity: centre distance normalised by player height (scale-invariant),
        # mapped to (0, 1] so nearer candidates score closer to 1.
        last_center = np.array([(last_box[0] + last_box[2]) / 2, (last_box[1] + last_box[3]) / 2])
        centers = np.array([((d['bbox'][0] + d['bbox'][2]) / 2, (d['bbox'][1] + d['bbox'][3]) / 2) for d in candidates])
        norm_h = h if h > 0 else 1.0  # guard a degenerate last box
        continuity_scores = np.exp(-np.linalg.norm(centers - last_center, axis=1) / norm_h)

        fused_scores = reid_dists + (1 - color_corrs) + (1 - ssim_scores) - self.continuity_weight * continuity_scores
        best_idx = int(np.argmin(fused_scores))
        if fused_scores[best_idx] < self.fused_thresh:
            return candidates[best_idx]
        return None

# End-of-cell marker: printed once the tracker classes above are defined.
print("Cerebrus tracking system v9 defined successfully!")

In [None]:
# STEP 4: Define the Main Video Processing Function
from tqdm.notebook import tqdm

# BGR box colours by track state; any other state ('confirmed') is drawn green.
_STATE_COLORS = {'lost': (0, 165, 255), 'tentative': (255, 255, 0)}


def _extract_player_data(frame: np.ndarray, conf: float = 0.4) -> List[Dict[str, Any]]:
    """Detect people in *frame* and build the feature dicts the tracker consumes.

    Detections whose embedding or colour histogram cannot be computed are skipped.
    """
    p_results = player_detector(frame, classes=[0], verbose=False, conf=conf)
    p_detections = p_results[0].boxes.xyxy.cpu().numpy()
    p_data = []
    for box in p_detections:
        x1, y1, x2, y2 = box
        # .copy(): the crop may become the track's SSIM reference image, and
        # the frame is annotated in place afterwards. A view into the frame
        # would pick up the drawn box and labels and corrupt the next comparison.
        crop = frame[int(y1):int(y2), int(x1):int(x2)].copy()
        emb, hist = get_embedding(crop), get_color_hist(crop)
        if emb is not None and hist is not None:
            p_data.append({'bbox': box, 'emb': emb, 'hist': hist,
                           'jersey': ocr_jersey_number(crop), 'crop': crop})
    return p_data


def _draw_annotations(frame: np.ndarray, tracked_players, ball_pos) -> bool:
    """Draw the target box/labels and the ball marker onto *frame* in place.

    Returns True when a target was drawn, i.e. the frame belongs in the
    highlight video. Ball proximity is not a separate criterion, because every
    frame that shows the target already qualifies.
    """
    drawn = False
    if tracked_players:
        track = tracked_players[0]
        x1, y1, x2, y2 = map(int, track.bbox)
        label = f"PLAYER {track.confirmed_number}" if track.confirmed_number else "TARGET PLAYER"
        state_label = f"STATE: {track.state.upper()}"
        color = _STATE_COLORS.get(track.state, (0, 255, 0))
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3)
        cv2.putText(frame, label, (x1, y1 - 35), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
        cv2.putText(frame, state_label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
        drawn = True
    if ball_pos:
        cv2.circle(frame, ball_pos, 15, (0, 255, 255), -1)
        cv2.circle(frame, ball_pos, 15, (0, 0, 0), 2)
    return drawn


def process_video_final(
    input_path: str, output_path: str, highlight_path: str, mode: str,
    initial_target_data: Optional[Dict[str, Any]] = None, target_jersey: Optional[str] = None
):
    """Track one player through a video and write annotated full and highlight videos.

    Args:
        input_path: video to read.
        output_path: every processed, annotated frame is written here.
        highlight_path: only frames in which the target is drawn are written here.
        mode: 'Manual' to track ``initial_target_data``, or 'Automatic' to
            acquire by ``target_jersey``, or the most central player when it
            is empty/None.
        initial_target_data: detection dict of the manually chosen player,
            detected on the first frame.
        target_jersey: jersey number to auto-acquire in Automatic mode.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print(f"Error: could not open video: {input_path}")
        return
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_full = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    out_high = cv2.VideoWriter(highlight_path, fourcc, fps, (width, height))

    try:
        # The target may go unmatched for about 3 seconds before being dropped.
        tracker = CerebrusTracker(max_age=int(fps * 3))
        start_frame = 0
        if mode == 'Manual' and initial_target_data:
            tracker.initialize_target(initial_target_data)
            # The first frame was already used to select the target; skip it.
            cap.read()
            start_frame = 1
        # None disables the tracker's auto-acquisition, so only Automatic mode
        # passes a string; an empty string means "most central player".
        auto_target = (target_jersey or '') if mode == 'Automatic' else None

        # Read until the stream ends rather than trusting CAP_PROP_FRAME_COUNT,
        # which some containers under-report; the count only sizes the bar.
        with tqdm(total=max(total_frames - start_frame, 0), desc='Processing Video') as progress:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                p_data = _extract_player_data(frame)
                b_results = ball_detector(frame, classes=[32], verbose=False, conf=0.5)
                b_detections = b_results[0].boxes.xyxy.cpu().numpy()

                tracked_players, ball_pos = tracker.update(
                    frame, p_data, b_detections, target_jersey_number=auto_target
                )
                if _draw_annotations(frame, tracked_players, ball_pos):
                    out_high.write(frame)
                out_full.write(frame)
                progress.update(1)
    finally:
        # Always finalise the containers, even if a frame fails mid-video.
        cap.release()
        out_full.release()
        out_high.release()

    print("\n--- Processing Complete ---")
    print(f"Full video saved to: {output_path}")
    print(f"Highlights video saved to: {highlight_path}")

In [None]:
# STEP 5: Upload Video and Select Tracking Mode & Target
from google.colab import files
import matplotlib.pyplot as plt
import io

# Upload a video through the Colab file picker. In Manual mode, also detect
# players on the first frame, compute their features, and show the frame
# with a numbered ID on each player so the target can be chosen below.
print('Upload your video file (.mp4, .mov, etc.)...')
uploaded_video = files.upload()
input_video_path = None
# Detection dicts for the first frame (Manual mode only); read by the run cell.
first_frame_data = []
selected_player_data = None
# Placeholder; re-assigned by the form field at the bottom of this cell.
TARGET_JERSEY_NUMBER = ''

#@markdown ### 1. Choose Tracking Mode
TRACKING_MODE = 'Manual' #@param ["Manual", "Automatic"]

if uploaded_video:
    # NOTE(review): assumes files.upload() saved the file under /content
    # (Colab's default working directory) — confirm if the cwd was changed.
    input_video_path = f'/content/{next(iter(uploaded_video))}'
    print(f"\nVideo ready: {input_video_path}")
    
    if TRACKING_MODE == 'Manual':
        cap = cv2.VideoCapture(input_video_path)
        ret, frame = cap.read()
        if ret:
            # Person detections (class 0) on the first frame, at conf=0.3.
            p_results = player_detector(frame, classes=[0], verbose=False, conf=0.3)
            p_detections = p_results[0].boxes.xyxy.cpu().numpy()
            for i, (x1, y1, x2, y2) in enumerate(p_detections):
                # Crops are views into `frame`; the frame itself is never drawn
                # on (annotation below uses a copy), so the stored crops stay clean.
                crop = frame[int(y1):int(y2), int(x1):int(x2)]
                emb, hist = get_embedding(crop), get_color_hist(crop)
                if emb is not None and hist is not None:
                    # 'id' is the 1-based number shown on the image; the rest is
                    # the feature dict the tracker is initialised from.
                    first_frame_data.append({'id': i + 1, 'bbox': p_detections[i], 'emb': emb, 'hist': hist, 'jersey': ocr_jersey_number(crop), 'crop': crop})
            
            annotated_frame = frame.copy()
            for player in first_frame_data:
                x1, y1, x2, y2 = map(int, player['bbox'])
                cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(annotated_frame, f"ID: {player['id']}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0,0,255), 3)
            
            print("\n--- SELECT YOUR TARGET PLAYER ---")
            annotated_frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
            plt.figure(figsize=(16, 9)); plt.imshow(annotated_frame_rgb); plt.axis('off'); plt.title('Enter the ID of the player you want to track below'); plt.show()
        cap.release()
    else:
        print("\n--- AUTOMATIC MODE SELECTED ---")
        print("Configure jersey number below. Leave blank to track the most central player.")
else:
    print("\nNo video file was uploaded. Please run this cell again.")

# NOTE(review): the form values below are assigned when this cell runs, but
# the player IDs are only shown after it runs. Picking a new ID therefore
# requires re-running this cell, which also re-prompts the upload; consider
# moving these form fields into their own cell.
#@markdown --- 
#@markdown ### 2. Configure Target (based on mode)
#@markdown - If **Manual**, enter the ID of the Target Player from the image above.
#@markdown - If **Automatic**, (optional) enter the Jersey Number to track.
TARGET_PLAYER_ID = 1 #@param {type:"integer"}
TARGET_JERSEY_NUMBER = "" #@param {type:"string"}

In [None]:
# STEP 6: Run Processing and Download Results

# Run the tracker with the options chosen in the upload cell, then offer the
# resulting videos for download.
import os

if input_video_path:
    output_video_path = '/content/tracked_output_cerebrus_v9.mp4'
    highlight_video_path = '/content/highlights_cerebrus_v9.mp4'
    # Only offer files produced by *this* run. Otherwise a selection error
    # below would still hand the user stale videos left from an earlier run.
    processing_ran = False

    if TRACKING_MODE == 'Manual':
        if not first_frame_data:
            print("Error: Manual mode selected but no players were detected in the first frame. Try Automatic mode.")
        else:
            selected_player_data = next((p for p in first_frame_data if p['id'] == TARGET_PLAYER_ID), None)
            if selected_player_data:
                print(f"Target locked: Player ID {TARGET_PLAYER_ID}. Starting tracking...")
                process_video_final(
                    input_path=input_video_path, output_path=output_video_path, highlight_path=highlight_video_path,
                    mode='Manual', initial_target_data=selected_player_data
                )
                processing_ran = True
            else:
                print(f"Error: Player with ID {TARGET_PLAYER_ID} not found. Please check the ID and run the cell again.")

    elif TRACKING_MODE == 'Automatic':
        if TARGET_JERSEY_NUMBER:
            print(f"Attempting to automatically track player with jersey number: {TARGET_JERSEY_NUMBER}")
        else:
            print("No jersey number specified. Will automatically track the most central player.")

        process_video_final(
            input_path=input_video_path, output_path=output_video_path, highlight_path=highlight_video_path,
            mode='Automatic', target_jersey=TARGET_JERSEY_NUMBER
        )
        processing_ran = True

    if processing_ran and os.path.exists(output_video_path):
        print("\n--- Download Your Files ---")
        try:
            # The highlights file is checked separately because it is a distinct output.
            if os.path.exists(highlight_video_path):
                print("Downloading highlights video...")
                files.download(highlight_video_path)
            print("Downloading full tracked video...")
            files.download(output_video_path)
        except Exception as e:
            print(f"Could not trigger automatic download. Error: {e}")
else:
    print("Cannot run processing. Please upload a video and select a mode in the cell above.")