<a href="https://colab.research.google.com/github/Arbaaz-Tanveer/Cricket-Highlight-Generator/blob/main/generator_with_yolo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Cricket Highlights Generator - Combined Audio and OCR Approach with Scene Change Adjustment and YOLO Scoreboard Cropping

# Install dependencies
!pip install pydub opencv-python-headless matplotlib numpy tqdm paddleocr paddlepaddle ultralytics -q
!apt-get install ffmpeg -qq

import cv2
import os
import glob
import re
import logging
import numpy as np
import matplotlib.pyplot as plt
from pydub import AudioSegment
from tqdm import tqdm
import time
from google.colab import files
from paddleocr import PaddleOCR
from ultralytics import YOLO  # Import YOLO from ultralytics

# -------------------- PATH SETUP --------------------
# Working directories live under the Colab content root: finished clips go
# to `highlights`, intermediate artefacts (audio dump, concat list) to `temp`.
content_dir = '/content'
output_dir, temp_dir = (
    os.path.join(content_dir, leaf) for leaf in ('highlights', 'temp')
)
for _workdir in (output_dir, temp_dir):
    os.makedirs(_workdir, exist_ok=True)

# Start every run from empty working directories so stale files from a
# previous run cannot be picked up by the merge step.
for _workdir in (output_dir, temp_dir):
    for f in glob.glob(os.path.join(_workdir, "*")):
        os.remove(f)

# Logging: timestamped INFO-level records for the whole script.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger('CricketHighlights')

# Path to uploaded video
video_path = '/content/clip.mp4'
if not os.path.exists(video_path):
    print(f"❌ Error: {video_path} not found!")
else:
    print(f"✅ Found video: {video_path}")
    cap = cv2.VideoCapture(video_path)
    try:
        if cap.isOpened():
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            fc = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # A reported FPS of 0 would otherwise raise ZeroDivisionError
            # here; report a zero duration instead of crashing the cell.
            dur = fc / fps if fps > 0 else 0.0
            print(f"Dimensions: {w}x{h}, Duration: {dur:.1f}s, FPS: {fps}")
        else:
            print(f"❌ Error: could not open {video_path}")
    finally:
        # Release the handle whether or not the file could be opened.
        cap.release()

# -------------------- OCR FUNCTIONS --------------------
def clean_ocr_text(text):
    """Reduce raw OCR output to the characters a score line is built from.

    Two passes are applied:

    1. Text from an opening parenthesis up to the end of the string is
       removed (the match does not cross a line break).
    2. Every character that is not a digit, an uppercase ASCII letter, a
       dash, a slash or whitespace is removed. Lowercase letters are dropped
       by this pass.

    Returns the filtered string; whitespace is not collapsed or trimmed.
    """
    without_trailer = re.sub(r'\(.*$', '', text)
    return re.sub(r'[^0-9A-Z\-\/\s]', '', without_trailer)

def preprocess_for_ocr(cropped):
    """Upscale and binarize a scoreboard crop before handing it to OCR.

    Args:
        cropped: Image array whose first two dimensions are (height, width).
            It is converted with ``COLOR_BGR2GRAY``, so a 3-channel BGR image
            is expected.

    Returns:
        A single-channel image binarized with Otsu's threshold (values 0/255).

    Raises:
        ValueError: If the crop has zero height or width.
    """
    h, w = cropped.shape[:2]
    min_side = min(h, w)
    if min_side == 0:
        # An empty crop cannot be scaled (division by zero) or thresholded.
        raise ValueError("Cannot preprocess an empty image for OCR")

    # Enlarge small crops so the shortest side is at least 500 px.
    # Crops that are already large enough are left at their original size.
    if min_side < 500:
        scale = 500 / min_side
        cropped = cv2.resize(
            cropped,
            (int(w * scale), int(h * scale)),
            interpolation=cv2.INTER_LINEAR,
        )

    gray = cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY)

    # Otsu picks the threshold automatically; the explicit 0 is ignored.
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return thresh


def extract_score(text):
    """Pull the first "runs-wickets" (or "runs/wickets") pair out of OCR text.

    The text is first filtered through ``clean_ocr_text``. Runs may have up
    to three digits and wickets up to two; a wicket count outside 0..10 is
    rejected.

    Returns:
        The score normalized to ``"<runs>-<wickets>"`` (e.g. ``"10-1"``), or
        ``None`` when no plausible score is present.
    """
    cleaned = clean_ocr_text(text)
    found = re.search(r'(\d{1,3})\s*[-\/]\s*(\d{1,2})', cleaned)
    if found is None:
        return None

    runs, wkts = (int(part) for part in found.groups())
    if runs < 0 or not (0 <= wkts <= 10):
        return None
    return f"{runs}-{wkts}"

def extract_team_names(text):
    """Identify the team(s) mentioned in a piece of scoreboard text.

    A fixture written as two 2-3 letter uppercase codes separated by a
    lowercase "v" or "vs" wins and is returned normalized as ``"AAA v BBB"``.
    Otherwise the first code from a fixed list of known teams that occurs
    anywhere in the text (plain substring test, in list order) is returned.

    Returns:
        ``"AAA v BBB"``, a single team code, or ``None`` if nothing matches.
    """
    fixture = re.search(r'([A-Z]{2,3})\s*(?:v|vs)\s*([A-Z]{2,3})', text)
    if fixture is not None:
        home, away = fixture.group(1), fixture.group(2)
        return f"{home} v {away}"

    known_codes = ['IND','PAK','AUS','ENG','NZ','SA','WI','SL','BAN','AFG','ZIM','IRE']
    return next((code for code in known_codes if code in text), None)

def extract_score_and_teams(text):
    """Extract the fixture and the score from text such as "PAK v IND 10 - 1".

    The combined pattern is tried first so that the score taken is the one
    written directly after the fixture. It is matched against text that
    still contains lowercase letters: running it on the output of
    ``clean_ocr_text`` would strip the "v"/"vs" separator the pattern needs,
    so it could never match.

    Args:
        text: Raw OCR text; line breaks are treated as spaces.

    Returns:
        A tuple ``(score, teams)`` where ``score`` is ``"<runs>-<wickets>"``
        and ``teams`` is ``"AAA v BBB"``. If the combined pattern does not
        match (or the wicket count is outside 0..10) each element comes from
        ``extract_score`` / ``extract_team_names`` and may be ``None``.
    """
    flat = " ".join(text.splitlines())

    # Drop a parenthesised trailer (e.g. an overs count) but keep lowercase
    # characters so the team separator survives.
    candidate = re.sub(r'\(.*$', '', flat)

    # Two team codes around "v"/"vs", optional punctuation/space, then
    # runs and wickets around a dash or slash with optional whitespace.
    combined = re.search(
        r'([A-Z]{2,3})\s*(?:vs|v)\s*([A-Z]{2,3})\W*(\d{1,3})\s*[-/]\s*(\d{1,2})',
        candidate,
    )
    if combined:
        runs = int(combined.group(3))
        wkts = int(combined.group(4))
        if 0 <= wkts <= 10:
            # Same "AAA v BBB" normalization as extract_team_names.
            teams = f"{combined.group(1)} v {combined.group(2)}"
            return f"{runs}-{wkts}", teams

    # Fallback to separate extraction if pattern matching fails
    return extract_score(flat), extract_team_names(flat)

# ------------- NEW: YOLO SCOREBOARD CROPPING FUNCTION -------------
def crop_scoreboard(frame, yolo_model):
    """Crop the frame to the union of all boxes detected by the YOLO model.

    Args:
        frame: Image array indexed as ``frame[y, x]``.
        yolo_model: Callable returning a sequence of results; the first
            result's ``boxes`` are iterated and each box exposes its corners
            as ``box.xyxy[0]`` in ``[x1, y1, x2, y2]`` order.

    Returns:
        The cropped region, or the original ``frame`` when nothing is
        detected or the union box is empty after clamping to the frame.
    """
    results = yolo_model(frame)
    # If no detections or no boxes, fallback to full frame
    if not results or not results[0].boxes:
        return frame

    corners = [box.xyxy[0].tolist() for box in results[0].boxes]
    frame_h, frame_w = frame.shape[:2]

    # Union of all boxes, clamped to the frame. Without clamping a negative
    # coordinate would be interpreted as a from-the-end index by slicing and
    # silently select the wrong region.
    x_min = max(0, int(min(c[0] for c in corners)))
    y_min = max(0, int(min(c[1] for c in corners)))
    x_max = min(frame_w, int(max(c[2] for c in corners)))
    y_max = min(frame_h, int(max(c[3] for c in corners)))

    # A degenerate union would yield an empty crop that downstream image
    # processing cannot handle; use the whole frame instead.
    if x_max <= x_min or y_max <= y_min:
        return frame

    return frame[y_min:y_max, x_min:x_max]


def detect_scoreboard(frame, ocr_model, yolo_model):
    """Read the score and team(s) shown in a video frame.

    The frame is cropped with ``crop_scoreboard``, binarized with
    ``preprocess_for_ocr`` and passed to ``ocr_model.ocr(..., cls=True)``.
    Each recognized entry is expected to carry ``(text, confidence)`` at
    index 1. Only entries that yield both a score and a team are candidates;
    the one with the highest confidence wins.

    Returns:
        ``(score, teams, confidence)``; ``(None, None, 0)`` when OCR found
        nothing usable.
    """
    region = crop_scoreboard(frame, yolo_model)
    binarized = preprocess_for_ocr(region)

    ocr_pages = ocr_model.ocr(binarized, cls=True)
    if not ocr_pages or ocr_pages[0] is None:
        return None, None, 0

    best_score = best_team = None
    best_conf = 0
    for page in ocr_pages:
        # Empty/None pages contribute nothing.
        for entry in page or ():
            if not entry or len(entry) < 2:
                continue
            txt, conf = entry[1]
            print(txt)

            # Text containing a "v" may be a full "AAA v BBB 10-1" line.
            if 'v' in txt.lower():
                score, teams = extract_score_and_teams(txt)
            else:
                score, teams = extract_score(txt), extract_team_names(txt)

            if score and teams and conf > best_conf:
                best_score, best_team, best_conf = score, teams, conf

    if best_score:
        print(f"OCR detected score {best_score}, team(s) {best_team}")
    return best_score, best_team, best_conf


def analyze_video_for_scores(video_path, ocr_model, yolo_model, interval=5):
    """Sample the video and report timestamps where the score jumps.

    One frame every ``interval`` seconds is passed to ``detect_scoreboard``.
    Readings with confidence above 0.5 are compared with the previous
    accepted reading; a timestamp is recorded when the runs increase by
    exactly 4 or 6, or the wickets increase by exactly 1.

    Args:
        video_path: Path of the video to scan.
        ocr_model: OCR model forwarded to ``detect_scoreboard``.
        yolo_model: Detection model forwarded to ``detect_scoreboard``.
        interval: Seconds between sampled frames.

    Returns:
        List of event timestamps in seconds (empty if the video cannot be
        opened or sampled).
    """
    print("Analyzing video for score changes…")
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("❌ Cannot open video for OCR")
        return []

    events = []
    pbar = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if fps <= 0 or interval <= 0:
            # Without a valid FPS/interval neither timestamps nor a sampling
            # step can be computed (would divide by zero).
            print("❌ Cannot sample video for OCR: invalid FPS or interval")
            return []

        duration = total / fps
        # A step of 0 would never advance and loop forever.
        step = max(1, int(fps * interval))

        last_score, last_runs, last_wkts = None, 0, 0
        pbar = tqdm(total=int(duration / interval), desc="Score OCR")

        for frame in range(0, total, step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame)
            ret, img = cap.read()
            if not ret:
                break
            t = frame / fps

            # Pass the YOLO model along with the OCR model
            sc, tm, conf = detect_scoreboard(img, ocr_model, yolo_model)
            if sc and conf > 0.5:
                runs, wkts = map(int, sc.split('-'))
                if last_score is None:
                    # First trusted reading becomes the baseline.
                    last_score, last_runs, last_wkts = sc, runs, wkts
                else:
                    dr = runs - last_runs
                    dw = wkts - last_wkts
                    if dr < 0 or dr > 7:
                        # Implausible jump (misread or new innings): reset
                        # the baseline without reporting an event.
                        last_score, last_runs, last_wkts = sc, runs, wkts
                    elif dr in (4, 6) or dw == 1:
                        print(f"🏏 Valid change at {t:.1f}s: {last_score} → {sc} (+{dr} runs, +{dw} wkts)")
                        events.append(t)
                        last_score, last_runs, last_wkts = sc, runs, wkts
                    elif dr > 0:
                        # Minor run change: advance the baseline only.
                        last_score, last_runs, last_wkts = sc, runs, wkts

            pbar.update(1)
    finally:
        # Release resources even if OCR/detection raises mid-scan.
        if pbar is not None:
            pbar.close()
        cap.release()

    print(f"Found {len(events)} valid score events")
    return events

# -------------------- AUDIO ANALYSIS FUNCTIONS --------------------
def analyze_audio(video_path, threshold_db=-30, interval=0.5):
    """Find loud stretches in the video's audio track.

    The audio is dumped to ``audio.wav`` in ``temp_dir`` with ffmpeg, split
    into chunks of ``interval`` seconds and the dBFS level of each chunk is
    compared with ``threshold_db``. A run of consecutive loud chunks lasting
    at least one second contributes the timestamp of its loudest chunk. A
    debug plot of the volume curve is shown. If more than 30 moments are
    found, only the 30 loudest are kept.

    Args:
        video_path: Path of the video to analyse.
        threshold_db: Chunks louder than this (dBFS) count as exciting.
        interval: Chunk length in seconds.

    Returns:
        Sorted list of timestamps in seconds; empty on any failure.
    """
    # Local import, matching the file's habit of in-function imports.
    import subprocess

    print("Analyzing audio to find exciting moments...")
    audio_file = os.path.join(temp_dir, "audio.wav")

    # Argument lists (no shell) so paths containing quotes, spaces or `$`
    # are passed to ffmpeg verbatim.
    quiet = ["-y", "-hide_banner", "-loglevel", "error"]
    attempts = (
        ["-vn", "-acodec", "pcm_s16le", "-ar", "44100", "-ac", "2"],
        ["-vn", "-acodec", "pcm_s16le"],  # retry without forced rate/channels
    )
    for attempt, codec_args in enumerate(attempts):
        if attempt:
            print("⚠ Audio extraction failed, retrying...")
        try:
            subprocess.run(
                ["ffmpeg", "-i", video_path, *codec_args, audio_file, *quiet],
                check=False,
            )
        except OSError as exc:
            # ffmpeg missing or not executable: keep the best-effort contract.
            print(f"❌ Could not run ffmpeg: {exc}")
            return []
        if os.path.exists(audio_file) and os.path.getsize(audio_file) > 0:
            break
    else:
        print("❌ Could not extract audio")
        return []

    try:
        audio = AudioSegment.from_file(audio_file)
        segment_ms = int(interval * 1000)
        volumes, timestamps = [], []
        for i in range(0, len(audio), segment_ms):
            seg = audio[i:i+segment_ms]
            if len(seg) > 0:
                volumes.append(seg.dBFS)
                timestamps.append(i / 1000.0)

        # Debug plot for audio analysis
        plt.figure(figsize=(12, 4))
        plt.plot(timestamps, volumes)
        plt.axhline(y=threshold_db, color='r', linestyle='--', label=f'Threshold ({threshold_db} dB)')
        plt.title('Audio Volume Analysis')
        plt.xlabel('Time (s)')
        plt.ylabel('Volume (dB)')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

        min_dur = 1.0
        peaks = []      # (peak_time, peak_volume) per qualifying loud run
        current = None  # loud run currently being tracked, if any

        for t, v in zip(timestamps, volumes):
            if v > threshold_db:
                if current is None:
                    current = {'start': t, 'peak_time': t, 'peak_volume': v}
                elif v > current['peak_volume']:
                    current['peak_time'], current['peak_volume'] = t, v
            elif current is not None:
                # First quiet chunk closes the run.
                if (t - current['start']) >= min_dur:
                    peaks.append((current['peak_time'], current['peak_volume']))
                current = None

        # A run still open at the end of the audio ends at the last chunk.
        if current is not None and (timestamps[-1] - current['start']) >= min_dur:
            peaks.append((current['peak_time'], current['peak_volume']))

        print(f"Found {len(peaks)} exciting audio moments")
        if len(peaks) > 30:
            # Keep the 30 loudest moments; the volume travels with the
            # timestamp, so no list search is needed to recover it.
            peaks = sorted(peaks, key=lambda item: item[1], reverse=True)[:30]
            return sorted(pt for pt, _ in peaks)

        return [pt for pt, _ in peaks]

    except Exception as e:
        print(f"❌ Error analyzing audio: {e}")
        import traceback; traceback.print_exc()
        return []

# -------------------- SCENE CHANGE FUNCTIONS --------------------
def get_frame_feature(frame):
    """Describe a frame by a flattened hue/saturation histogram.

    The frame is converted from BGR to HSV, a 2-D histogram over channels 0
    and 1 (50 x 60 bins, ranges 0-180 and 0-256) is computed, normalized with
    ``cv2.normalize`` using its default settings, and flattened to 1-D.
    """
    hsv_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    bins = [50, 60]
    value_ranges = [0, 180, 0, 256]
    raw_hist = cv2.calcHist([hsv_frame], [0, 1], None, bins, value_ranges)
    normalized = cv2.normalize(raw_hist, raw_hist)
    return normalized.flatten()

def chi_square_distance(histA, histB, eps=1e-10):
    """Return the Chi-Square distance between two histograms.

    Computed as ``0.5 * sum((A - B)^2 / (A + B + eps))``; ``eps`` keeps the
    denominator non-zero for bins that are empty in both histograms.
    """
    squared_gap = (histA - histB) ** 2
    bin_mass = histA + histB + eps
    return 0.5 * np.sum(squared_gap / bin_mass)

def merge_close_timestamps(timestamps, merge_threshold=0.2):
    """Collapse runs of nearby timestamps into their averages.

    Timestamps are consumed in the order given. A timestamp joins the current
    cluster when it is at most ``merge_threshold`` seconds after the cluster's
    most recent member; otherwise it starts a new cluster.

    Returns:
        One value per cluster (the mean of its members); ``[]`` for empty
        input.
    """
    if not timestamps:
        return []

    clusters = [[timestamps[0]]]
    for stamp in timestamps[1:]:
        latest = clusters[-1]
        if stamp - latest[-1] <= merge_threshold:
            latest.append(stamp)
        else:
            clusters.append([stamp])

    return [sum(group) / len(group) for group in clusters]

def detect_scene_changes(video_path, threshold=0.5):
    """Detect scene changes by comparing consecutive frame histograms.

    Every frame is downscaled to 320x240 and described with
    ``get_frame_feature``. When the Euclidean distance between the features
    of two consecutive frames exceeds ``threshold``, the later frame's
    timestamp is recorded. Nearby detections are merged with
    ``merge_close_timestamps``.

    Args:
        video_path: Path of the video to scan.
        threshold: Minimum feature distance that counts as a scene change.

    Returns:
        List of scene-change timestamps in seconds (empty if the video
        cannot be opened or reports no valid FPS).
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error opening video file for scene detection:", video_path)
        return []

    detected_times = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # Frame indices cannot be converted to seconds without an FPS;
            # dividing by it would raise ZeroDivisionError.
            print("Error: video reports no valid FPS for scene detection:", video_path)
            return []

        prev_feature = None
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Resize frame to reduce computation time (adjust size as needed)
            frame = cv2.resize(frame, (320, 240))
            current_feature = get_frame_feature(frame)
            if prev_feature is not None:
                diff = np.linalg.norm(current_feature - prev_feature)
                if diff > threshold:
                    detected_times.append(frame_count / fps)
            prev_feature = current_feature
            frame_count += 1
    finally:
        # Release the capture even if decoding or feature extraction fails.
        cap.release()

    merged_times = merge_close_timestamps(detected_times)
    print(f"Detected {len(merged_times)} scene changes at: {merged_times}")
    return merged_times

def adjust_lower_bound(lower_time, scene_times, min_gap=3):
    """Snap a clip start forward to the next scene change, if it is stable.

    Only the first scene change at or after ``lower_time`` is considered. It
    is returned when a further scene change exists and follows it by at least
    ``min_gap`` seconds; in every other case ``lower_time`` is returned
    unchanged.
    """
    ordered = sorted(scene_times)
    first_at_or_after = next(
        (pos for pos, stamp in enumerate(ordered) if stamp >= lower_time),
        None,
    )
    # No candidate, or the candidate is the last scene change (no successor).
    if first_at_or_after is None or first_at_or_after + 1 >= len(ordered):
        return lower_time

    candidate = ordered[first_at_or_after]
    following = ordered[first_at_or_after + 1]
    return candidate if following - candidate >= min_gap else lower_time

def adjust_upper_bound(upper_time, scene_times, min_gap=3):
    """Snap a clip end back to the previous scene change, if it is stable.

    Only the last scene change at or before ``upper_time`` is considered. It
    is returned when an earlier scene change exists and precedes it by at
    least ``min_gap`` seconds; in every other case ``upper_time`` is returned
    unchanged.
    """
    ordered = sorted(scene_times)
    last_at_or_before = next(
        (pos for pos in range(len(ordered) - 1, -1, -1) if ordered[pos] <= upper_time),
        None,
    )
    # No candidate, or the candidate is the first scene change (no predecessor).
    if last_at_or_before is None or last_at_or_before == 0:
        return upper_time

    candidate = ordered[last_at_or_before]
    preceding = ordered[last_at_or_before - 1]
    return candidate if candidate - preceding >= min_gap else upper_time

# -------------------- CLIP EXTRACTION FUNCTIONS --------------------
def extract_robust_clip(video_path, start_time, duration, output_file):
    """Cut one clip out of the video with ffmpeg, trying three strategies.

    The strategies are tried in order until one produces a usable file:

    1. re-encode (libx264/aac) with ``-ss`` after ``-i``,
    2. re-encode with ``-ss`` before ``-i``,
    3. stream copy with ``-ss`` before ``-i``.

    Args:
        video_path: Source video.
        start_time: Clip start in seconds (fractions are honoured).
        duration: Clip length in seconds (fractions are honoured).
        output_file: Destination path; overwritten if present.

    Returns:
        ``True`` if ffmpeg exited successfully and the output file is larger
        than 10000 bytes, otherwise ``False``.
    """
    # Local import, matching the file's habit of in-function imports.
    import subprocess

    # Keep millisecond precision: truncating to whole seconds would move the
    # cut away from the scene boundary the caller aligned it to.
    start = f"{float(start_time):.3f}"
    length = f"{float(duration):.3f}"

    quiet = ["-y", "-hide_banner", "-loglevel", "error"]
    encode = ["-c:v", "libx264", "-preset", "ultrafast", "-c:a", "aac"]
    # Argument lists (no shell) so unusual characters in paths are safe.
    attempts = (
        ["ffmpeg", "-i", video_path, "-ss", start, "-t", length, *encode, output_file, *quiet],
        ["ffmpeg", "-ss", start, "-i", video_path, "-t", length, *encode, output_file, *quiet],
        ["ffmpeg", "-ss", start, "-i", video_path, "-t", length, "-c", "copy", output_file, *quiet],
    )

    for cmd in attempts:
        try:
            exit_code = subprocess.run(cmd, check=False).returncode
        except OSError:
            # ffmpeg missing or not executable: no strategy can succeed.
            return False
        if (
            exit_code == 0
            and os.path.exists(output_file)
            and os.path.getsize(output_file) > 10000
        ):
            return True
    return False

def extract_clips(video_path, event_timestamps, scene_times, pre_sec=15, post_sec=25):
    """Cut one clip per event, aligned to scene changes and de-overlapped.

    For each event a window ``[ts - pre_sec, ts + post_sec]`` (clamped to the
    video) is computed, then its edges are moved to scene changes via
    ``adjust_lower_bound`` / ``adjust_upper_bound``. An adjusted edge is only
    accepted if the event still lies inside the window. Overlapping or
    touching windows are merged, windows shorter than 5 seconds are skipped,
    and the rest are written to ``output_dir`` with ``extract_robust_clip``.

    Args:
        video_path: Source video.
        event_timestamps: Event times in seconds.
        scene_times: Scene-change times in seconds (may be empty).
        pre_sec: Seconds to include before each event.
        post_sec: Seconds to include after each event.

    Returns:
        Number of clips successfully written.
    """
    if not event_timestamps:
        print("❌ No event timestamps to extract clips for")
        return 0

    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    if fps <= 0:
        # Duration is unknown without an FPS (and would divide by zero).
        print("❌ Cannot determine video duration for clip extraction")
        return 0
    duration = total / fps

    windows = []
    for ts in event_timestamps:
        # Compute initial clip boundaries based on audio/OCR events
        lower = max(0, ts - pre_sec)
        upper = min(duration, ts + post_sec)
        # Adjust using scene change timestamps:
        adjusted_lower = adjust_lower_bound(lower, scene_times, min_gap=3)
        adjusted_upper = adjust_upper_bound(upper, scene_times, min_gap=3)

        # The next scene change after `lower` may lie beyond the event (and
        # the previous one before `upper` may precede it); such an edge
        # would cut the event itself out of the clip, so it is rejected.
        start = adjusted_lower if adjusted_lower <= ts else lower
        end = adjusted_upper if adjusted_upper >= ts else upper
        if start >= end:
            start, end = lower, upper
        # Print out the original and scene adjusted boundaries for debugging:
        print(f"Event at {ts:.1f}s: Original clip: {lower:.1f}s–{upper:.1f}s, Scene adjusted clip: {adjusted_lower:.1f}s–{adjusted_upper:.1f}s")
        windows.append((start, end))

    # Merge overlapping/touching windows with a sorted sweep, so a window
    # bridging two earlier ones merges all three, not just the first hit.
    segments = []
    for start, end in sorted(windows):
        if segments and start <= segments[-1][1]:
            prev_start, prev_end = segments[-1]
            segments[-1] = (prev_start, max(prev_end, end))
        else:
            segments.append((start, end))

    print(f"Extracting {len(segments)} clips (after merging)")
    success = 0
    for idx, (s, e) in enumerate(segments):
        dur_seg = e - s
        if dur_seg < 5:
            # Too short to be a useful highlight.
            continue
        out = os.path.join(output_dir, f"clip_{idx+1}_{int(s)}s.mp4")
        print(f" ▶ Clip {idx+1}: {s:.1f}s–{e:.1f}s (duration: {dur_seg:.1f}s)")
        if extract_robust_clip(video_path, s, dur_seg, out):
            print("   ✅")
            success += 1
        else:
            print("   ❌")
    print(f"✅ Extracted {success}/{len(segments)} clips")
    return success

def merge_clips():
    """Concatenate the extracted clips into one highlights video.

    Clips named ``clip_<n>_*.mp4`` in ``output_dir`` are ordered by ``<n>``,
    listed in a concat file under ``temp_dir`` and joined by ffmpeg with
    stream copy into ``cricket_highlights.mp4`` under ``content_dir``. On
    success the result is offered for download through ``files.download``.

    Returns:
        ``True`` when the merged file was produced, ``False`` otherwise.
    """
    def clip_number(path):
        # Numeric clip index embedded in the file name.
        return int(re.search(r'clip_(\d+)_', path).group(1))

    clips = glob.glob(os.path.join(output_dir, "clip_*.mp4"))
    clips.sort(key=clip_number)
    if not clips:
        print("❌ No clips to merge")
        return False

    # ffmpeg's concat demuxer reads one "file '<path>'" entry per line.
    list_file = os.path.join(temp_dir, "clips_list.txt")
    with open(list_file, 'w') as f:
        f.writelines(f"file '{c}'\n" for c in clips)

    final = os.path.join(content_dir, "cricket_highlights.mp4")
    res = os.system(
        f'ffmpeg -f concat -safe 0 -i "{list_file}" -c copy "{final}" -y -hide_banner -loglevel error'
    )
    if res != 0 or not os.path.exists(final):
        print("❌ Merge failed")
        return False

    print(f"✅ Highlights video ready: {final}")
    files.download(final)
    return True

# -------------------- MAIN PROCESSING FUNCTION --------------------
def process_video():
    """Run the whole highlight pipeline on ``video_path``.

    Steps, in order: load the YOLO weights, detect scene changes, ask the
    user for the audio threshold and clip padding, analyse the audio, run
    score OCR, combine both event lists, extract clips and merge them. Any
    exception raised along the way is printed with its traceback instead of
    propagating.
    """
    if not os.path.exists(video_path):
        print("❌ Please upload 'clip.mp4' to your Colab environment first.")
        return

    def ask_float(prompt, fallback):
        # An empty answer selects the default shown in the prompt.
        return float(input(prompt) or fallback)

    try:
        # Trained scoreboard-detection weights; adjust the path if needed.
        yolo_model = YOLO('/content/best (3).pt')

        # Step 0: scene changes, used later to align clip boundaries.
        print("\n🔍 Running scene change detection...")
        scene_times = detect_scene_changes(video_path, threshold=0.5)
        if scene_times:
            print(f"Scene changes detected at: {scene_times}")
        else:
            print("❌ No scene changes detected, clips will use original boundaries.")

        # Step 1: user-tunable parameters.
        print("\n⚙ Enter parameters for highlight extraction:")
        audio_threshold = ask_float("Audio threshold in dB (-40 to -20) [default -30]: ", "-30")
        pre_time = ask_float("Seconds BEFORE exciting moment [default 15]: ", "15")
        post_time = ask_float("Seconds AFTER exciting moment [default 25]: ", "25")

        # Step 2: loud moments from the audio track.
        audio_events = analyze_audio(video_path, threshold_db=audio_threshold, interval=0.5)

        # Step 3: score jumps read from the on-screen scoreboard.
        print("\n🏏 Initializing OCR model (this may take a moment)...")
        ocr_model = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)
        score_events = analyze_video_for_scores(video_path, ocr_model, yolo_model, interval=5)

        # Union of both sources, de-duplicated and in chronological order.
        all_events = sorted(set(audio_events + score_events))
        print(f"\n📊 Combined Analysis:")
        print(f"  • Audio events: {len(audio_events)}")
        print(f"  • Score events: {len(score_events)}")
        print(f"  • Total unique events: {len(all_events)}")

        if not all_events:
            print("❌ No events detected. Try adjusting thresholds.")
            return

        # Step 4: cut the clips, aligned to scene changes.
        print("\n✂ Extracting highlight clips with scene change adjustments...")
        clip_count = extract_clips(
            video_path, all_events, scene_times,
            pre_sec=pre_time, post_sec=post_time,
        )

        # Step 5: join the clips, if any were produced.
        if clip_count > 0:
            print("\n🔄 Merging clips into final highlights video...")
            merge_clips()
        else:
            print("❌ No clips extracted. Try adjusting the parameters.")

    except Exception as e:
        print(f"❌ An error occurred: {e}")
        import traceback
        traceback.print_exc()

# Run the pipeline when executed as a script or notebook cell. The dunder
# names need double underscores: `_name_` is an undefined name and raises
# NameError before the pipeline can start.
if __name__ == "__main__":
    process_video()