# Instruction Before Running
This project uses two pre-trained models.
Download them from the links below:


1. YOLO Model :

https://drive.google.com/file/d/1wL-OiTJXFwb4BUTvV6JZOSXcgP5wWAOZ/view?usp=drive_link


2. U-Net Model :

https://drive.google.com/file/d/1dFIb0wo3yMqmd9kIzyLGfPU5NvetYLyN/view?usp=sharing



 ## After downloading the models, upload them to Google Drive, mount the drive, and set the `model_path` variable to the corresponding file path for both the U-Net and the YOLO model.

In [None]:
!pip install ultralytics opencv-python-headless matplotlib numpy mediapipe --quiet

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
"""
Chess Move Detection Pipeline from Video
=========================================
A complete computer vision pipeline for detecting chess moves from video footage.

This module provides functions to:
1. Extract and filter video frames (removing frames with hands)
2. Detect chessboard corners using a pre-trained U-Net model
3. Determine board orientation using OCR + VLM reasoning
4. Detect chess pieces using a pre-trained YOLO model
5. Track board state changes and generate PGN notation

Prerequisites:
- Pre-trained YOLO model for chess piece detection (loaded externally)
- Pre-trained U-Net model for corner detection (loaded externally)
- MediaPipe for hand detection
- OpenCV, NumPy, Matplotlib for image processing and visualization

Author: Chess Vision Pipeline
Version: 1.0
"""

import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle, FancyArrow
from matplotlib.lines import Line2D
import mediapipe as mp
from typing import List, Tuple, Dict, Optional, Any, Union
from dataclasses import dataclass, field
from enum import Enum
import re
from collections import OrderedDict
import warnings
import torch
import torch.nn as nn
from pathlib import Path
import base64

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')


# Helper Functions (cropping, padding, attack detection, sharpening)

In [None]:
def crop_center_square(frame: np.ndarray) -> np.ndarray:
    """
    Return the largest square region centered in the frame.

    The square's side equals the smaller of the frame's height and
    width, so the longer dimension is trimmed equally on both ends
    (e.g. to drop letterbox / pillarbox bars).

    Parameters
    ----------
    frame : np.ndarray
        Input image; only the first two axes (rows, columns) are cropped.

    Returns
    -------
    np.ndarray
        A slice (view) of ``frame`` with equal height and width.
    """
    height, width = frame.shape[:2]
    edge = min(height, width)

    # Offsets that center the square along each axis.
    top = (height - edge) // 2
    left = (width - edge) // 2

    return frame[top:top + edge, left:left + edge]

def pad_frame_for_unet(
    frame: np.ndarray,
    pad_ratio: float = 0.5
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """
    Surround a frame with a black (zero) border before sending it to U-Net.

    With ``pad_ratio = 0.5`` the border on each side is half the side
    length, so the padded image is ``2 * side`` wide and tall.

    If the frame is not square, its CENTER square (side = min(h, w)) is
    the part copied onto the canvas. The returned offset is adjusted for
    that crop, so ``padded_coord - offset`` is always a coordinate in
    the ORIGINAL frame. For square frames the offset is simply
    ``(pad, pad)``.

    Parameters
    ----------
    frame : np.ndarray
        Input image of shape (H, W) or (H, W, C).
    pad_ratio : float
        Border width on each side, as a fraction of the square side.
        Must be non-negative.

    Returns
    -------
    Tuple[np.ndarray, Tuple[int, int]]
        - padded_frame: zero-filled square canvas containing the frame.
        - (offset_x, offset_y): position of the original frame's
          top-left pixel in padded-image coordinates.

    Raises
    ------
    ValueError
        If ``pad_ratio`` is negative.
    """
    if pad_ratio < 0:
        raise ValueError(f"pad_ratio must be non-negative, got {pad_ratio}")

    h, w = frame.shape[:2]
    side = min(h, w)

    pad = int(side * pad_ratio)
    new_side = side + 2 * pad

    # Top-left corner of the centered square inside the original frame
    # (both are 0 when the frame is already square).
    crop_y0 = (h - side) // 2
    crop_x0 = (w - side) // 2
    frame_square = frame[crop_y0:crop_y0 + side, crop_x0:crop_x0 + side]

    # Keep any trailing channel axis so grayscale and multi-channel
    # frames work as well as 3-channel BGR.
    padded = np.zeros((new_side, new_side) + frame.shape[2:], dtype=frame.dtype)
    padded[pad:pad + side, pad:pad + side] = frame_square

    return padded, (pad - crop_x0, pad - crop_y0)

# Attack-detection helpers, used to decide whether a move gives check (the "+" suffix in notation)
def _square_to_coords(square: str) -> Tuple[int, int]:
    """Convert algebraic square like 'e4' to (file, rank) = (0..7, 0..7)."""
    file_idx = ord(square[0]) - ord('a')       # a -> 0, ..., h -> 7
    rank_idx = int(square[1]) - 1              # '1' -> 0, ..., '8' -> 7
    return file_idx, rank_idx


def _coords_to_square(file_idx: int, rank_idx: int) -> str:
    """
    Build the algebraic square name for zero-based (file, rank) indices.

    The file letter is looked up in the module-level ``FILES`` list and
    the rank is printed one-based, so (4, 3) becomes 'e4'.
    """
    return f"{FILES[file_idx]}{rank_idx + 1}"


def _can_piece_attack(
    piece: str,
    from_square: str,
    to_square: str,
    board: Dict[str, str]
) -> bool:
    """
    Decide whether ``piece`` standing on ``from_square`` attacks ``to_square``.

    Parameters
    ----------
    piece : str
        Piece code; the letter (case-insensitive) selects the movement
        rule and, for pawns, uppercase means the pawn attacks toward
        higher ranks while lowercase attacks toward lower ranks.
    from_square, to_square : str
        Algebraic square names such as 'e4'.
    board : Dict[str, str]
        Occupied squares (square -> piece code); only key membership is
        used, to detect pieces blocking a bishop/rook/queen ray.

    Returns
    -------
    bool
        True if the geometry matches the piece's attack pattern and, for
        sliding pieces, every square strictly between the two squares is
        empty. A piece never attacks its own square, and unknown piece
        letters attack nothing.
    """
    if from_square == to_square:
        return False

    kind = piece.upper()
    src_file, src_rank = _square_to_coords(from_square)
    dst_file, dst_rank = _square_to_coords(to_square)

    d_file = dst_file - src_file
    d_rank = dst_rank - src_rank
    abs_file, abs_rank = abs(d_file), abs(d_rank)

    if kind == 'N':
        # An L-shape is exactly one step on one axis and two on the other.
        return {abs_file, abs_rank} == {1, 2}

    if kind == 'K':
        return max(abs_file, abs_rank) == 1

    if kind == 'P':
        # Pawns attack only diagonally forward; forward depends on color.
        forward = 1 if piece.isupper() else -1
        return d_rank == forward and abs_file == 1

    if kind not in ('B', 'R', 'Q'):
        return False

    on_diagonal = abs_file == abs_rank and d_file != 0
    on_straight = (d_file == 0) != (d_rank == 0)  # exactly one axis moves
    reachable = {
        'B': on_diagonal,
        'R': on_straight,
        'Q': on_diagonal or on_straight,
    }[kind]
    if not reachable:
        return False

    # Unit step along the ray: -1, 0 or +1 per axis.
    step_file = (d_file > 0) - (d_file < 0)
    step_rank = (d_rank > 0) - (d_rank < 0)

    # Visit every square strictly between source and target; any
    # occupied square blocks the line of attack.
    cur_file = src_file + step_file
    cur_rank = src_rank + step_rank
    while (cur_file, cur_rank) != (dst_file, dst_rank):
        if _coords_to_square(cur_file, cur_rank) in board:
            return False
        cur_file += step_file
        cur_rank += step_rank

    return True


def is_square_attacked(
    board: Dict[str, str],
    target_square: str,
    by_white: bool
) -> bool:
    """
    Report whether any piece of one side attacks ``target_square``.

    Parameters
    ----------
    board : Dict[str, str]
        Board state (square -> piece code). Uppercase codes are treated
        as white pieces, everything else as black.
    target_square : str
        Square to test (e.g. 'e4').
    by_white : bool
        True to consider attacks by white pieces, False for black.

    Returns
    -------
    bool
        True as soon as one piece of the chosen side can attack the
        target according to ``_can_piece_attack``; False otherwise.
    """
    return any(
        _can_piece_attack(piece, square, target_square, board)
        for square, piece in board.items()
        if piece.isupper() == by_white  # skip the other side's pieces
    )

def compute_piece_center_with_rotation(
    x1: int, y1: int, x2: int, y2: int, rotation_deg: int
) -> Tuple[float, float]:
    """
    Compute a reference point for a detection box, shifted off-center
    along one axis depending on the board rotation.

    The point starts at the geometric center of the box. One coordinate
    is then replaced by a 75% / 25% blend of the box edges:

      - 0   : cy = 0.25 * max(y) + 0.75 * min(y)   (toward smaller y)
      - 180 : cy = 0.75 * max(y) + 0.25 * min(y)   (toward larger y)
      - 90  : cx = 0.25 * min(x) + 0.75 * max(x)   (toward larger x)
      - 270 : cx = 0.75 * min(x) + 0.25 * max(x)   (toward smaller x)

    Any other rotation value returns the plain geometric center.

    NOTE(review): the original notes describe rotation 0 as "White at
    bottom of image (bottom = +y)" and 180 as "bottom = -y", yet the
    weights used for 0/180 move the point toward smaller/larger y
    respectively, while the 90/270 cases do follow their labels. This
    may be intentional tuning -- confirm before changing.

    Parameters
    ----------
    x1, y1, x2, y2 : int
        Bounding box corner coordinates (in any order).
    rotation_deg : int
        Board rotation; 0, 90, 180 and 270 are handled specially.

    Returns
    -------
    Tuple[float, float]
        The (cx, cy) reference point.
    """
    heavy = 0.75           # weight of the edge the point is pulled toward
    light = 1 - heavy      # weight of the opposite edge

    x_small, x_large = min(x1, x2), max(x1, x2)
    y_small, y_large = min(y1, y2), max(y1, y2)

    # Start from the geometric center; at most one axis gets overridden.
    center_x = (x1 + x2) / 2.0
    center_y = (y1 + y2) / 2.0

    if rotation_deg == 0:
        center_y = light * y_large + heavy * y_small
    elif rotation_deg == 180:
        center_y = heavy * y_large + light * y_small
    elif rotation_deg == 90:
        center_x = light * x_small + heavy * x_large
    elif rotation_deg == 270:
        center_x = heavy * x_small + light * x_large

    return center_x, center_y
def sharpen_laplacian(img, amount: float = 1.0):
    """
    Sharpen an image by subtracting a scaled Laplacian from it.

    Computes ``img - amount * Laplacian(img)`` in float64 (3x3 Laplacian
    aperture), then clips the result to [0, 255] and returns it as uint8.

    Parameters
    ----------
    img : np.ndarray
        Input image.
    amount : float
        Strength of the sharpening (weight of the negated Laplacian).

    Returns
    -------
    np.ndarray
        Sharpened uint8 image.
    """
    edge_response = cv2.Laplacian(img, cv2.CV_64F, ksize=3)
    img_f64 = img.astype(np.float64)

    # The weighted sum 1.0 * img + amount * (-laplacian) is the sharpening step.
    blended = cv2.addWeighted(img_f64, 1.0, -edge_response, amount, 0)

    return np.clip(blended, 0, 255).astype(np.uint8)

def sharpen_ultra(img):
    """
    Apply a two-stage sharpening: unsharp mask plus high-pass edge boost.

    1. Unsharp mask: ``1.8 * img - 0.8 * GaussianBlur(img, sigma=1.0)``.
    2. High-pass edges: filter ``img`` with a 3x3 kernel whose center is
       8 and whose neighbors are -1 (output keeps the input depth).
    3. Blend: ``unsharp + 0.6 * edges``.

    Parameters
    ----------
    img : np.ndarray
        Input image.

    Returns
    -------
    np.ndarray
        Sharpened image as produced by ``cv2.addWeighted``.
    """
    # Stage 1: unsharp mask.
    softened = cv2.GaussianBlur(img, (0, 0), sigmaX=1.0)
    unsharp_masked = cv2.addWeighted(img, 1.8, softened, -0.8, 0)

    # Stage 2: high-pass response of the ORIGINAL image.
    high_pass = np.full((3, 3), -1, dtype=np.float32)
    high_pass[1, 1] = 8
    edge_map = cv2.filter2D(img, -1, high_pass)

    # Stage 3: add 60% of the edge response on top of the unsharp result.
    return cv2.addWeighted(unsharp_masked, 1.0, edge_map, 0.6, 0)



# U-NET Architecture

In [None]:
# Device string used for the U-Net model and its input tensors:
# "cuda" when a CUDA device is available, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------------
# U-NET ARCHITECTURE USED FOR TRAINING
# -------------------------------

class DoubleConv(nn.Module):
    """
    Two consecutive (3x3 Conv -> BatchNorm -> ReLU) stages.

    The convolutions use padding=1 (spatial size is preserved) and no
    bias. All six layers live in ``self.block`` in that order, so the
    parameter names are ``block.0.*`` ... ``block.4.*``.

    Parameters
    ----------
    in_ch : int
        Number of input channels for the first convolution.
    out_ch : int
        Number of output channels of both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input through both conv stages."""
        return self.block(x)

class UNet(nn.Module):
    """
    U-Net encoder/decoder producing a sigmoid-activated output map.

    Parameters
    ----------
    in_ch : int
        Number of input channels.
    out_ch : int
        Number of output channels.
    features : sequence of int
        Channel widths of the encoder levels, shallowest first. The
        decoder mirrors them in reverse. The default is an immutable
        tuple (previously a mutable list default).

    Notes
    -----
    Submodule names (``downs``, ``ups``, ``pool``, ``bottleneck``,
    ``final_conv``) and their construction order are kept unchanged so
    that checkpoints saved for this architecture load as before.
    """

    def __init__(self, in_ch=3, out_ch=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Materialize once: the widths are iterated forward and in reverse.
        features = tuple(features)

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        prev = in_ch

        # Encoder: one DoubleConv per level.
        for f in features:
            self.downs.append(DoubleConv(prev, f))
            prev = f

        self.pool = nn.MaxPool2d(2, 2)
        self.bottleneck = DoubleConv(prev, prev * 2)
        prev = prev * 2

        # Decoder: ``ups`` alternates [upsample, DoubleConv, upsample, ...].
        # The DoubleConv takes ``prev`` (= 2 * f) channels because the
        # upsampled tensor (f channels) is concatenated with the skip
        # connection (f channels).
        for f in reversed(features):
            self.ups.append(nn.ConvTranspose2d(prev, f, 2, 2))
            self.ups.append(DoubleConv(prev, f))
            prev = f

        self.final_conv = nn.Conv2d(prev, out_ch, 1)

    def forward(self, x):
        """Return the sigmoid of the final 1x1 convolution's output."""
        skips = []
        out = x

        for d in self.downs:
            out = d(out)
            skips.append(out)   # saved before pooling for the decoder
            out = self.pool(out)

        out = self.bottleneck(out)

        for i in range(0, len(self.ups), 2):
            out = self.ups[i](out)
            skip = skips[-1 - (i // 2)]  # deepest skip first

            # Pooling rounds odd sizes down, so the upsampled tensor can
            # be smaller than the skip; resize it to match before concat.
            if out.size()[2:] != skip.size()[2:]:
                out = nn.functional.interpolate(out, size=skip.shape[2:])

            out = torch.cat((skip, out), dim=1)
            out = self.ups[i + 1](out)

        return torch.sigmoid(self.final_conv(out))


# -------------------------------
# LOAD MODEL
# -------------------------------

def load_unet(path):
    """
    Build a default ``UNet`` on ``DEVICE`` and load weights from ``path``.

    The file is expected to contain a state dict compatible with
    ``UNet()``; the model is switched to eval mode before it is returned.

    Parameters
    ----------
    path : str
        Location of the saved weights.

    Returns
    -------
    UNet
        The model, on ``DEVICE``, in evaluation mode.
    """
    net = UNet().to(DEVICE)
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source (consider weights_only=True where supported).
    net.load_state_dict(torch.load(path, map_location=DEVICE))
    net.eval()
    print("U-Net Loaded.")
    return net


# -------------------------------
# IMAGE ENHANCEMENT METHODS
# -------------------------------

def apply_sobel(img):
    """
    Enhance edges by adding a Sobel response to the grayscale image.

    The horizontal and vertical 3x3 Sobel derivatives (16-bit signed)
    are averaged, converted to 8-bit absolute values, and added to the
    grayscale image with weight 0.7.

    Parameters
    ----------
    img : np.ndarray
        Input image in BGR format.

    Returns
    -------
    np.ndarray
        Edge-enhanced image, converted back to 3-channel BGR.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    grad_x = cv2.Sobel(gray, cv2.CV_16S, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray, cv2.CV_16S, 0, 1, ksize=3)

    # Average of the two signed derivatives, then |.| saturated to uint8.
    mixed = 0.5 * grad_x + 0.5 * grad_y
    edge_strength = cv2.convertScaleAbs(mixed)

    enhanced = cv2.addWeighted(gray, 1.0, edge_strength, 0.7, 0)
    return cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)


def apply_log(img):
    """
    Enhance an image with a Laplacian-of-Gaussian (LoG) response.

    The grayscale image is blurred with a 5x5 Gaussian, a 3x3 Laplacian
    is taken, and its 8-bit absolute value is blended with the grayscale
    image as ``1.2 * gray - 0.5 * |LoG|``.

    Parameters
    ----------
    img : np.ndarray
        Input image in BGR format.

    Returns
    -------
    np.ndarray
        Enhanced image, converted back to 3-channel BGR.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
    log_abs = cv2.convertScaleAbs(cv2.Laplacian(smoothed, cv2.CV_64F, ksize=3))

    enhanced = cv2.addWeighted(gray, 1.2, log_abs, -0.5, 0)
    return cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)


# -------------------------------
# U-NET INFERENCE
# -------------------------------

def unet_predict(model, img):
    """
    Run the U-Net on an image and return a binary mask at image size.

    The image is resized to 512x512, scaled to [0, 1], and passed to the
    model as a (1, C, 512, 512) float tensor on ``DEVICE``. The first
    output channel is thresholded at 0.5 and the resulting 0/1 mask is
    resized back to the input resolution with nearest-neighbor
    interpolation.

    Parameters
    ----------
    model : callable
        Network returning a tensor indexable as ``[0, 0]`` (batch 0,
        channel 0).
    img : np.ndarray
        Input image of shape (H, W, C).

    Returns
    -------
    np.ndarray
        uint8 mask of shape (H, W) with values 0 or 1.
    """
    orig_h, orig_w = img.shape[:2]

    net_input = cv2.resize(img, (512, 512))
    batch = (
        torch.tensor(net_input / 255.0)
        .float()
        .permute(2, 0, 1)   # HWC -> CHW
        .unsqueeze(0)       # add batch axis
        .to(DEVICE)
    )

    with torch.no_grad():
        probability = model(batch)[0, 0].cpu().numpy()

    binary_512 = (probability > 0.5).astype(np.uint8)
    return cv2.resize(binary_512, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)


# -------------------------------
# CORNER EXTRACTION
# -------------------------------

def order_corners(pts):
    """
    Arrange four (x, y) points as [top-left, top-right, bottom-right,
    bottom-left] in image coordinates.

    The top-left point has the smallest ``x + y`` and the bottom-right
    the largest; the top-right has the smallest ``y - x`` and the
    bottom-left the largest.

    Parameters
    ----------
    pts : array-like
        Points of shape (N, 2).

    Returns
    -------
    np.ndarray
        Float array of shape (4, 2) in the order TL, TR, BR, BL.
    """
    pts = np.array(pts)
    coord_sum = pts.sum(axis=1)
    y_minus_x = np.diff(pts, axis=1).reshape(-1)

    picks = [
        np.argmin(coord_sum),   # TL
        np.argmin(y_minus_x),   # TR
        np.argmax(coord_sum),   # BR
        np.argmax(y_minus_x),   # BL
    ]
    return pts[picks].astype(float)

def extract_corners(mask):
    """
    Find the four corners of the largest region in a binary mask.

    The largest external contour is approximated by a polygon (tolerance
    2% of its perimeter). If the approximation does not have exactly
    four vertices, the corners of the minimum-area bounding rectangle
    are used instead.

    Parameters
    ----------
    mask : np.ndarray
        Binary mask (non-zero = foreground after the uint8 * 255 scaling).

    Returns
    -------
    np.ndarray or None
        Corners ordered by ``order_corners`` (TL, TR, BR, BL), or None
        when the mask contains no contour.
    """
    binary = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        return None

    largest = max(contours, key=cv2.contourArea)
    tolerance = 0.02 * cv2.arcLength(largest, True)
    quad = cv2.approxPolyDP(largest, tolerance, True)

    if len(quad) != 4:
        # Not a clean quadrilateral: fall back to the rotated bounding box.
        quad = cv2.boxPoints(cv2.minAreaRect(largest))

    return order_corners(quad.reshape(-1, 2))


# Main Pipeline (constants, data classes, frame filtering, preprocessing, corner detection)

In [None]:
# ============================================================================
# CONSTANTS AND CONFIGURATION
# ============================================================================

# Standard chess piece notation: detector class name -> one-letter code
# (uppercase = white, lowercase = black).
# The misspelled 'balck-knight' key is kept as an alias so existing
# lookups keep working; the correctly spelled 'black-knight' was missing
# and is listed AFTER the alias so the reverse mapping below resolves
# 'n' to the correct spelling.
PIECE_SYMBOLS = {
    'white-king': 'K', 'white-queen': 'Q', 'white-rook': 'R',
    'white-bishop': 'B', 'white-knight': 'N', 'white-pawn': 'P',
    'black-king': 'k', 'black-queen': 'q', 'black-rook': 'r',
    'black-bishop': 'b', 'balck-knight': 'n', 'black-knight': 'n',
    'black-pawn': 'p'
}

# Reverse mapping for display (one-letter code -> class name); when two
# names share a code, the later entry of PIECE_SYMBOLS wins.
PIECE_NAMES = {v: k for k, v in PIECE_SYMBOLS.items()}

# Files and ranks for algebraic notation
FILES = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
RANKS = ['1', '2', '3', '4', '5', '6', '7', '8']

# Standard board size for warping
DEFAULT_BOARD_SIZE = 800

# Visualization colors (BGR format for OpenCV)
COLORS = {
    'corner': (0, 255, 0),      # Green
    'piece': (255, 0, 0),       # Blue
    'grid': (128, 128, 128),    # Gray
    'move_from': (0, 0, 255),   # Red
    'move_to': (0, 255, 0),     # Green
    'text': (255, 255, 255),    # White
    'white_piece': (255, 255, 200),
    'black_piece': (50, 50, 50)
}


# ============================================================================
# DATA CLASSES
# ============================================================================

@dataclass
class PieceDetection:
    """
    Represents a detected chess piece.

    The first five fields are required; the last two default to None
    and are meant to be filled in later.
    """
    bbox: Tuple[int, int, int, int]  # (x1, y1, x2, y2)
    center: Tuple[float, float]       # (cx, cy)
    class_id: int                     # numeric class of the detection
    class_name: str                   # presumably a key of PIECE_SYMBOLS -- confirm against the detector
    confidence: float                 # detection score
    # Presumably the center mapped into board coordinates -- TODO confirm.
    projected_center: Optional[Tuple[float, float]] = None
    # Algebraic square name (e.g. 'e4') once assigned; None until then.
    assigned_square: Optional[str] = None


@dataclass
class BoardState:
    """
    Snapshot of the chess board at one frame.

    Attributes are the frame index, the occupied squares, and the
    detections that snapshot was built from (empty by default).
    """
    frame_index: int
    pieces: Dict[str, str]  # square -> piece code (e.g., 'e4' -> 'P')
    detections: List[PieceDetection] = field(default_factory=list)

    def copy(self):
        """
        Return a new BoardState with its own ``pieces`` and
        ``detections`` containers.

        The containers are shallow copies: the detection objects
        themselves are shared with the original.
        """
        pieces_clone = self.pieces.copy()
        detections_clone = self.detections.copy()
        return BoardState(self.frame_index, pieces_clone, detections_clone)


@dataclass
class Move:
    """
    Represents a chess move.

    Only the two squares and the moving piece are required; every flag
    defaults to "not set" (None / False).
    """
    from_square: str  # origin square
    to_square: str    # destination square
    piece: str        # code of the moving piece (presumably as in PIECE_SYMBOLS values -- confirm)
    captured_piece: Optional[str] = None  # code of the captured piece, if any
    is_castle: bool = False
    castle_side: Optional[str] = None  # 'kingside' or 'queenside'
    is_en_passant: bool = False
    promotion_piece: Optional[str] = None  # piece promoted to, if any
    is_check: bool = False
    is_checkmate: bool = False


@dataclass
class GlobalBoardParameters:
    """
    Global parameters computed from the first clean frame.

    Bundles the board geometry (orientation, corners, homography) with
    the grid of square centers and their algebraic names. All fields
    are required.
    """
    rotation_degrees: int  # 0, 90, 180, or 270
    corners_ordered: np.ndarray  # Shape (4, 2): [TL, TR, BR, BL]
    homography_matrix: np.ndarray  # 3x3 homography matrix
    board_size: int  # presumably the warped board's side in pixels (cf. DEFAULT_BOARD_SIZE) -- confirm
    grid_centers: np.ndarray  # Shape (64, 2)
    square_names: List[str]  # 64 algebraic names


# ============================================================================
# FRAME EXTRACTION AND FILTERING
# ============================================================================

def extract_frames_from_video(video_path: str, output_fps: Optional[float] = None) -> List[np.ndarray]:
    """
    Read a video and return a list of center-cropped square frames.

    Every kept frame is passed through ``crop_center_square`` to remove
    black bars.

    Parameters
    ----------
    video_path : str
        Path to the input video file.
    output_fps : float, optional
        If specified (and lower than the video's FPS), keep frames at
        approximately this rate. If None, or not lower than the original
        FPS, every frame is kept. Must be positive when given.

    Returns
    -------
    List[np.ndarray]
        List of frames as BGR numpy arrays.

    Raises
    ------
    ValueError
        If ``output_fps`` is given but not positive.
    FileNotFoundError
        If the video file cannot be opened.

    Example
    -------
    >>> frames = extract_frames_from_video("chess_game.mp4", output_fps=2.0)
    >>> print(f"Extracted {len(frames)} frames")
    """
    if output_fps is not None and output_fps <= 0:
        raise ValueError(f"output_fps must be positive, got {output_fps}")

    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        cap.release()
        raise FileNotFoundError(f"Could not open video file: {video_path}")

    frames = []

    # Release the capture even if reading or cropping raises.
    try:
        # Get video properties
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        print("Video properties:")
        print(f"  - Resolution: {width}x{height}")
        print(f"  - Original FPS: {original_fps:.2f}")
        print(f"  - Total frames: {total_frames}")
        # Some containers report an FPS of 0; avoid dividing by it.
        if original_fps > 0:
            print(f"  - Duration: {total_frames/original_fps:.2f} seconds")
        else:
            print("  - Duration: unknown (FPS not reported)")

        if output_fps is None or output_fps >= original_fps:
            # Extract all frames
            frame_interval = 1
            print("Extracting all frames...")
        else:
            # Fractional interval between kept frames for resampling
            frame_interval = original_fps / output_fps
            print(f"Resampling to {output_fps} FPS (every {frame_interval:.2f} frames)")

        frame_idx = 0
        next_frame_to_extract = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Keep the first frame whose index reaches the (fractional)
            # target, then advance the target by one interval.
            if frame_idx >= next_frame_to_extract:
                # Crop only frames that are kept.
                frames.append(crop_center_square(frame))
                next_frame_to_extract += frame_interval

            frame_idx += 1
    finally:
        cap.release()

    print(f"Extracted {len(frames)} frames")
    return frames


def detect_hands_mediapipe(
    frame: np.ndarray,
    mp_hands: Any,
    min_detection_confidence: float = 0.5
) -> Tuple[bool, Any]:
    """
    Check a single frame for hands with a MediaPipe Hands detector.

    Parameters
    ----------
    frame : np.ndarray
        Input frame in BGR format (converted to RGB before processing).
    mp_hands : mediapipe.solutions.hands.Hands
        Initialized MediaPipe Hands object; its ``process`` method is
        called on the RGB frame.
    min_detection_confidence : float
        Accepted for API symmetry but not used here; the threshold in
        effect is the one ``mp_hands`` was constructed with.

    Returns
    -------
    Tuple[bool, Any]
        - bool: True when ``multi_hand_landmarks`` of the result is not
          None, i.e. at least one hand was reported.
        - results: the raw MediaPipe result object (for visualization).

    Example
    -------
    >>> mp_hands_instance = mp.solutions.hands.Hands(min_detection_confidence=0.5)
    >>> has_hand, results = detect_hands_mediapipe(frame, mp_hands_instance)
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    results = mp_hands.process(rgb)
    return results.multi_hand_landmarks is not None, results


def filter_frames_no_hands(
    frames: List[np.ndarray],
    min_detection_confidence: float = 0.2,
    visualize_progress: bool = True
) -> Tuple[List[np.ndarray], List[int]]:
    """
    Run hand detection on all frames and return frames without hands.

    Frames in which a hand is detected (e.g. while a player is making a
    move) are dropped, keeping only "clean" frames that show a stable
    board. An empty input list yields two empty lists.

    Parameters
    ----------
    frames : List[np.ndarray]
        List of video frames in BGR format.
    min_detection_confidence : float
        Minimum confidence for hand detection.
    visualize_progress : bool
        If True, print progress information.

    Returns
    -------
    Tuple[List[np.ndarray], List[int]]
        - clean_frames: List of frames without detected hands.
        - clean_indices: Corresponding indices in the original frame list.

    Example
    -------
    >>> clean_frames, indices = filter_frames_no_hands(all_frames)
    >>> print(f"Kept {len(clean_frames)} clean frames out of {len(all_frames)}")
    """
    clean_frames = []
    clean_indices = []
    frames_with_hands = 0

    if visualize_progress:
        print(f"Filtering {len(frames)} frames for hand presence...")

    # The context manager closes the detector even if a frame raises.
    with mp.solutions.hands.Hands(
        static_image_mode=True,  # Process each frame independently
        max_num_hands=2,
        min_detection_confidence=min_detection_confidence,
        min_tracking_confidence=0.5
    ) as hands:
        for idx, frame in enumerate(frames):
            has_hands, _ = detect_hands_mediapipe(frame, hands, min_detection_confidence)

            if has_hands:
                frames_with_hands += 1
            else:
                clean_frames.append(frame)
                clean_indices.append(idx)

            # Progress update every 50 frames
            if visualize_progress and (idx + 1) % 50 == 0:
                print(f"  Processed {idx + 1}/{len(frames)} frames...")

    if visualize_progress:
        # Guard against an empty input list (would divide by zero).
        retention = len(clean_frames) / len(frames) * 100 if frames else 0.0
        print("Filtering complete:")
        print(f"  - Frames with hands (discarded): {frames_with_hands}")
        print(f"  - Clean frames (kept): {len(clean_frames)}")
        print(f"  - Retention rate: {retention:.1f}%")

    return clean_frames, clean_indices


def visualize_hand_detection(
    frame: np.ndarray,
    results: Any,
    title: str = "Hand Detection"
) -> np.ndarray:
    """
    Draw MediaPipe hand landmarks and a title on a copy of the frame.

    Parameters
    ----------
    frame : np.ndarray
        Input frame in BGR format; it is not modified.
    results : Any
        MediaPipe hand detection results; every entry of
        ``multi_hand_landmarks`` (if any) is drawn.
    title : str
        Text written near the top-left corner in white.

    Returns
    -------
    np.ndarray
        Annotated copy of the frame.
    """
    canvas = frame.copy()

    drawer = mp.solutions.drawing_utils
    hands_module = mp.solutions.hands

    # ``multi_hand_landmarks`` is None (or empty) when no hand was found.
    for landmarks in results.multi_hand_landmarks or []:
        landmark_style = drawer.DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=4)
        connection_style = drawer.DrawingSpec(color=(255, 0, 0), thickness=2)
        drawer.draw_landmarks(
            canvas,
            landmarks,
            hands_module.HAND_CONNECTIONS,
            landmark_style,
            connection_style
        )

    cv2.putText(
        canvas, title,
        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2
    )

    return canvas


# ============================================================================
# IMAGE PREPROCESSING
# ============================================================================

def preprocess_sobel_log(
    frame: np.ndarray,
    sobel_ksize: int = 3,
    log_sigma: float = 1.0,
    log_ksize: int = 5
) -> Dict[str, np.ndarray]:
    """
    Apply Sobel + Laplacian of Gaussian (LoG) preprocessing.

    All edge maps are computed in float64 and converted to uint8 only at
    the end: magnitude and LoG are rescaled so their maximum becomes 255
    (an all-zero map stays all-zero), and the per-axis Sobel maps are
    saturated to [0, 255] instead of wrapping around.

    Parameters
    ----------
    frame : np.ndarray
        Input frame in BGR format (a 2-D array is used as grayscale).
    sobel_ksize : int
        Kernel size for Sobel operator (must be 1, 3, 5, or 7).
    log_sigma : float
        Standard deviation for Gaussian blur in LoG.
    log_ksize : int
        Kernel size for Gaussian blur in LoG.

    Returns
    -------
    Dict[str, np.ndarray]
        Dictionary containing:
        - 'gray': Grayscale image
        - 'sobel_x': Absolute Sobel gradient in X direction (saturated)
        - 'sobel_y': Absolute Sobel gradient in Y direction (saturated)
        - 'sobel_magnitude': Sobel magnitude rescaled to 0..255
        - 'gaussian_blur': Gaussian blurred image
        - 'log': Absolute Laplacian of Gaussian rescaled to 0..255
        - 'combined': 50/50 blend of 'sobel_magnitude' and 'log'

    Example
    -------
    >>> preprocessed = preprocess_sobel_log(frame)
    >>> plt.imshow(preprocessed['combined'], cmap='gray')
    """

    def _rescale_to_uint8(values: np.ndarray) -> np.ndarray:
        # Map [0, max] to [0, 255]; a flat (all-zero) map has no edges
        # and must not be divided by its zero maximum.
        peak = values.max()
        if peak > 0:
            return np.uint8(255 * values / peak)
        return np.zeros(values.shape, dtype=np.uint8)

    def _saturate_to_uint8(values: np.ndarray) -> np.ndarray:
        # Clip before casting so gradients above 255 do not wrap around.
        return np.clip(np.abs(values), 0, 255).astype(np.uint8)

    # Convert to grayscale
    if frame.ndim == 3:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    else:
        gray = frame.copy()

    # Sobel gradients (float64 keeps sign and full range)
    sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=sobel_ksize)
    sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=sobel_ksize)

    # Sobel magnitude
    sobel_magnitude = _rescale_to_uint8(np.sqrt(sobel_x**2 + sobel_y**2))

    # Gaussian blur for LoG
    gaussian_blur = cv2.GaussianBlur(gray, (log_ksize, log_ksize), log_sigma)

    # Laplacian of Gaussian; normalize the float response directly so
    # strong responses are not lost to an early uint8 cast.
    log_abs = np.abs(cv2.Laplacian(gaussian_blur, cv2.CV_64F))
    log_normalized = _rescale_to_uint8(log_abs)

    # Combined result: weighted sum of Sobel and LoG
    combined = cv2.addWeighted(sobel_magnitude, 0.5, log_normalized, 0.5, 0)

    return {
        'gray': gray,
        'sobel_x': _saturate_to_uint8(sobel_x),
        'sobel_y': _saturate_to_uint8(sobel_y),
        'sobel_magnitude': sobel_magnitude,
        'gaussian_blur': gaussian_blur,
        'log': log_normalized,
        'combined': combined
    }


def visualize_preprocessing(
    preprocessed: Dict[str, np.ndarray],
    original_frame: np.ndarray,
    figsize: Tuple[int, int] = (16, 10)
) -> None:
    """
    Show the original frame and the seven preprocessing stages in a
    2 x 4 grid of images.

    Parameters
    ----------
    preprocessed : Dict[str, np.ndarray]
        Output from preprocess_sobel_log(); the keys 'gray', 'sobel_x',
        'sobel_y', 'sobel_magnitude', 'gaussian_blur', 'log' and
        'combined' are displayed in that order.
    original_frame : np.ndarray
        Original BGR frame (shown first, converted to RGB).
    figsize : Tuple[int, int]
        Figure size for matplotlib.
    """
    fig, axes = plt.subplots(2, 4, figsize=figsize)
    panels = axes.flat  # row-major: top row first, then bottom row

    # First panel: the color frame.
    first = next(panels)
    first.imshow(cv2.cvtColor(original_frame, cv2.COLOR_BGR2RGB))
    first.set_title('Original Frame')
    first.axis('off')

    # Remaining panels: grayscale stages, in display order.
    stages = (
        ('gray', 'Grayscale'),
        ('sobel_x', 'Sobel X'),
        ('sobel_y', 'Sobel Y'),
        ('sobel_magnitude', 'Sobel Magnitude'),
        ('gaussian_blur', 'Gaussian Blur'),
        ('log', 'Laplacian of Gaussian'),
        ('combined', 'Combined (Sobel + LoG)'),
    )
    for panel, (key, caption) in zip(panels, stages):
        panel.imshow(preprocessed[key], cmap='gray')
        panel.set_title(caption)
        panel.axis('off')

    plt.tight_layout()
    plt.suptitle('Preprocessing Pipeline: Sobel + Laplacian of Gaussian', y=1.02)
    plt.show()


# ============================================================================
# CORNER DETECTION (U-Net)
# ============================================================================

def detect_corners_unet(
    frame: np.ndarray,
    unet_model: Any
) -> np.ndarray:
    """
    Locate the four board corners with the given U-Net model.

    The frame is first surrounded by a black border (half the side
    length on each side) so the network sees a zoomed-out view. Two
    edge-enhanced variants of the padded frame (Sobel and LoG) are each
    segmented by the U-Net, four corners are extracted from each mask,
    and the two corner sets are averaged. Finally the padding offset is
    subtracted so the result is expressed in the coordinates of the
    input frame.

    Parameters
    ----------
    frame : np.ndarray
        Input BGR frame.
    unet_model : Any
        Model passed to ``unet_predict``.

    Returns
    -------
    np.ndarray
        float32 array of shape (4, 2) with the averaged corners, in the
        order produced by ``extract_corners``.

    Raises
    ------
    RuntimeError
        If either mask yields no corners, or a corner set is not (4, 2).
    """
    # Pad for the U-Net only; remember where the frame sits in the canvas.
    padded, (shift_x, shift_y) = pad_frame_for_unet(frame, pad_ratio=0.5)

    # Two enhanced views of the same padded frame.
    sobel_view = apply_sobel(padded)
    log_view = apply_log(padded)

    # One mask per view.
    mask_sobel = unet_predict(unet_model, sobel_view)
    mask_log = unet_predict(unet_model, log_view)

    # Corners per mask, still in padded coordinates.
    corners_sobel = extract_corners(mask_sobel)
    corners_log = extract_corners(mask_log)

    if corners_sobel is None or corners_log is None:
        raise RuntimeError("U-Net corner detection failed: one method returned None")

    if corners_sobel.shape != (4, 2) or corners_log.shape != (4, 2):
        raise RuntimeError(
            f"Expected (4,2) corners, got {corners_sobel.shape} and {corners_log.shape}"
        )

    # Mean of the two estimates.
    mean_corners = (corners_sobel.astype(np.float32) + corners_log.astype(np.float32)) / 2.0

    # Shift back from padded coordinates to input-frame coordinates.
    shift = np.array([shift_x, shift_y], dtype=np.float32)
    return mean_corners - shift



def extract_corners_from_heatmap(
    heatmap: np.ndarray,
    threshold: float,
    input_size: Tuple[int, int]
) -> np.ndarray:
    """
    Convert a 4-channel corner heatmap into four (x, y) positions.

    Singleton axes are squeezed away first. If the first remaining axis
    has length 4 the array is treated as channels-first (4, H, W) and
    transposed to (H, W, 4).

    For each channel, values below ``threshold`` are zeroed and the
    corner is the value-weighted centroid of the remaining positive
    pixels. If nothing survives the threshold, the position of the
    channel's maximum is used instead.

    Parameters
    ----------
    heatmap : np.ndarray
        Heatmap with shape (1, H, W, 4) or (1, 4, H, W).
    threshold : float
        Minimum heatmap value that contributes to the centroid.
    input_size : Tuple[int, int]
        Size of the input image. Not used by this function.

    Returns
    -------
    np.ndarray
        Array of shape (4, 2); row i is (x, y) for channel i, in heatmap
        pixel coordinates.
    """
    maps = np.squeeze(heatmap)
    if maps.shape[0] == 4:  # channels-first -> channels-last
        maps = np.transpose(maps, (1, 2, 0))

    def locate(channel: np.ndarray) -> list:
        # Work on a copy so the caller's heatmap is left untouched.
        kept = channel.copy()
        kept[kept < threshold] = 0

        if kept.max() > 0:
            # Weighted centroid gives sub-pixel accuracy.
            rows, cols = np.nonzero(kept > 0)
            strength = kept[rows, cols]
            return [
                np.average(cols, weights=strength),
                np.average(rows, weights=strength),
            ]

        # Nothing above threshold: fall back to the raw peak location.
        peak_row, peak_col = np.unravel_index(np.argmax(channel), channel.shape)
        return [peak_col, peak_row]

    return np.array([locate(maps[:, :, k]) for k in range(4)])


def visualize_unet_corners(
    frame: np.ndarray,
    corners: np.ndarray,
    title: str = "U-Net Corner Detection"
) -> None:
    """
    Plot the detected corners, and the quadrilateral they span, over a frame.

    Parameters
    ----------
    frame : np.ndarray
        Original BGR frame (converted to RGB for display).
    corners : np.ndarray
        Array of shape (4, 2) with (x, y) corner positions.
    title : str
        Title for the plot.
    """
    palette = ('red', 'green', 'blue', 'orange')
    label_box = dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7)

    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    ax.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    # One coloured marker plus a coordinate label per corner, in input order.
    for number, (point, shade) in enumerate(zip(corners, palette), start=1):
        px, py = point[0], point[1]
        ax.scatter(px, py, c=shade, s=200, marker='o',
                   edgecolors='white', linewidths=2, zorder=5)
        ax.annotate(f'Corner {number}\n({px:.0f}, {py:.0f})',
                    (px, py),
                    textcoords="offset points",
                    xytext=(10, 10),
                    fontsize=10,
                    color=shade,
                    fontweight='bold',
                    bbox=label_box)

    # The outline must follow the perimeter, so the corners are ordered
    # first and each one is then joined to its successor.
    ring = order_points_clockwise(corners)
    for k, start in enumerate(ring):
        end = ring[(k + 1) % 4]
        ax.plot([start[0], end[0]], [start[1], end[1]],
                'c-', linewidth=2, alpha=0.7)

    ax.set_title(title, fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()


def order_points_clockwise(pts: np.ndarray) -> np.ndarray:
    """
    Order four points clockwise (as seen in the image), starting top-left.

    The primary rule is the sum/difference heuristic: the point with the
    smallest ``x + y`` is top-left, the largest ``x + y`` is bottom-right,
    the smallest ``y - x`` is top-right and the largest ``y - x`` is
    bottom-left.

    For strongly rotated quadrilaterals (e.g. a diamond) that heuristic can
    select the same input point for two different slots, silently dropping a
    corner. When that happens for a 4-point input, the points are instead
    sorted by angle around their centroid (clockwise in image coordinates,
    where y grows downward) and the sequence is started at the point with the
    smallest ``x + y``. Inputs for which the heuristic already yields four
    distinct points produce exactly the same result as before.

    Parameters
    ----------
    pts : np.ndarray
        Array of shape (4, 2) with unordered (x, y) points.

    Returns
    -------
    np.ndarray
        float32 array of shape (4, 2) with points ordered as [TL, TR, BR, BL].
    """
    # Compute in float so that unsigned integer input cannot wrap in y - x.
    pts = np.asarray(pts, dtype=np.float64)

    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1).flatten()  # y - x for every point

    picks = [
        int(np.argmin(coord_sum)),   # Top-left
        int(np.argmin(coord_diff)),  # Top-right
        int(np.argmax(coord_sum)),   # Bottom-right
        int(np.argmax(coord_diff)),  # Bottom-left
    ]

    if len(set(picks)) < 4 and len(pts) == 4:
        # The heuristic reused a point: fall back to an angular ordering,
        # which always returns every input point exactly once.
        centre = pts.mean(axis=0)
        angles = np.arctan2(pts[:, 1] - centre[1], pts[:, 0] - centre[0])
        ring = np.argsort(angles, kind="stable")  # ascending = clockwise on screen
        start = int(np.argmin(coord_sum[ring]))
        picks = [int(i) for i in np.roll(ring, -start)]

    ordered = np.zeros((4, 2), dtype=np.float32)
    for slot, index in enumerate(picks):
        ordered[slot] = pts[index]

    return ordered


# ============================================================================
# ROTATION ESTIMATION (OCR + VLM)
# ============================================================================

def estimate_rotation_with_vlm(
    frame: np.ndarray,
    corners: np.ndarray,
    ocr_model: Any,
    vlm_client: Any,
    visualize: bool = True
) -> int:
    """
    Use OCR + VLM reasoning to estimate the board rotation angle.

    This function:
    1. Crops regions around the board edges where file letters (a-h) and
       rank numbers (1-8) should appear.
    2. Runs OCR on these regions to detect any visible text.
    3. Hands the per-edge findings to ``determine_rotation_from_analysis``,
       which scores them and may consult the VLM.

    Parameters
    ----------
    frame : np.ndarray
        Input frame in BGR format.
    corners : np.ndarray
        Array of shape (4, 2) with the four board corners.
    ocr_model : Any
        Pre-initialized OCR model (e.g., EasyOCR reader, Tesseract, etc.).
    vlm_client : Any
        Pre-initialized VLM client, passed through unchanged.
    visualize : bool
        If True, display visualization of cropped regions.

    Returns
    -------
    int
        Estimated rotation in degrees: 0, 90, 180, or 270. Following the
        square naming in ``compute_board_grid`` and the VLM prompt:
        - 0°: White at bottom (standard orientation)
        - 90°: White at left
        - 180°: White at top
        - 270°: White at right

    Notes
    -----
    Every edge ('top', 'bottom', 'left', 'right') always gets an entry in the
    OCR results and in the analysis, even when its crop is unavailable or OCR
    fails; such edges simply contribute no characters. This keeps the
    downstream scoring from failing on a missing edge.

    Example
    -------
    >>> rotation = estimate_rotation_with_vlm(frame, corners, ocr_model, vlm_client)
    >>> print(f"Board rotation: {rotation}°")
    """

    # -------------------------------------------------------------------------
    # Step 1: Order corners and define the four edges
    # -------------------------------------------------------------------------

    corners_ordered = order_points_clockwise(corners)
    tl, tr, br, bl = corners_ordered

    # Depth (in pixels) of the strip cropped outward from each board edge
    border_width = 60

    edge_endpoints = {
        'top': (tl, tr),
        'bottom': (bl, br),
        'left': (tl, bl),
        'right': (tr, br),
    }

    # -------------------------------------------------------------------------
    # Step 2: Crop the four edge regions
    # -------------------------------------------------------------------------

    edge_regions = {
        edge_name: crop_edge_region(frame, start, end, edge_name, border_width)
        for edge_name, (start, end) in edge_endpoints.items()
    }

    # -------------------------------------------------------------------------
    # Step 3: Run OCR on each edge region
    # -------------------------------------------------------------------------

    ocr_results = {}

    for edge_name, region in edge_regions.items():
        # Default to "nothing detected" so the edge is never missing later on
        ocr_results[edge_name] = []

        if region is None or region.size == 0:
            continue

        try:
            # Run OCR (adapt based on your OCR model's API)
            ocr_results[edge_name] = run_ocr(region, ocr_model)
        except Exception as e:
            print(f"OCR failed for {edge_name} edge: {e}")

    # -------------------------------------------------------------------------
    # Step 4: Analyze OCR results
    # -------------------------------------------------------------------------

    file_letters = set('abcdefgh')
    rank_numbers = set('12345678')

    edge_analysis = {}
    for edge_name, texts in ocr_results.items():
        files_seen = []
        ranks_seen = []

        for text in texts:
            # str() guards against OCR backends that return non-string items
            for char in str(text).lower().strip():
                if char in file_letters:
                    files_seen.append(char)
                if char in rank_numbers:
                    ranks_seen.append(char)

        edge_analysis[edge_name] = {
            'files': files_seen,
            'ranks': ranks_seen,
            'raw': texts
        }

    # -------------------------------------------------------------------------
    # Step 5: Use VLM for reasoning (or rule-based fallback)
    # -------------------------------------------------------------------------

    rotation = determine_rotation_from_analysis(
        edge_analysis, frame, corners_ordered, vlm_client
    )

    # -------------------------------------------------------------------------
    # Step 6: Visualization
    # -------------------------------------------------------------------------

    if visualize:
        visualize_rotation_estimation(
            frame, corners_ordered, edge_regions, ocr_results, rotation
        )

    return rotation


def crop_edge_region(
    frame: np.ndarray,
    corner1: np.ndarray,
    corner2: np.ndarray,
    edge_type: str,
    border_width: int
) -> Optional[np.ndarray]:
    """
    Crop the strip just outside a board edge, where labels might appear.

    The strip is the quadrilateral spanned by the edge and the edge shifted
    ``border_width`` pixels outward; the axis-aligned bounding box of that
    quadrilateral, clipped to the frame, is returned.

    The outward normal is chosen by ``edge_type`` rather than by the order of
    the two corners: 'top' points toward smaller y, 'bottom' toward larger y,
    'left' toward smaller x and 'right' toward larger x. (Previously the
    'left' and 'right' strips were taken on the inner side of the edge, i.e.
    over the board squares instead of the label margin.)

    Parameters
    ----------
    frame : np.ndarray
        Input frame.
    corner1, corner2 : np.ndarray
        Two (x, y) corners defining the edge.
    edge_type : str
        One of 'top', 'bottom', 'left', 'right'. Any other value is treated
        as 'right'.
    border_width : int
        Width of the border region to crop, in pixels.

    Returns
    -------
    Optional[np.ndarray]
        Cropped region (a slice of ``frame``), or None if the edge is shorter
        than 10 pixels or the clipped box is empty.
    """
    h, w = frame.shape[:2]

    # Direction vector along the edge
    edge_vec = corner2 - corner1
    edge_length = np.linalg.norm(edge_vec)

    if edge_length < 10:  # Too small
        return None

    edge_unit = edge_vec / edge_length

    # Nominal outward direction in image coordinates (y grows downward)
    outward_by_edge = {
        'top': (0.0, -1.0),
        'bottom': (0.0, 1.0),
        'left': (-1.0, 0.0),
        'right': (1.0, 0.0),
    }
    outward = np.array(outward_by_edge.get(edge_type, (1.0, 0.0)))

    # One of the two unit normals of the edge; flip it if it points inward
    perp = np.array([edge_unit[1], -edge_unit[0]])
    if float(np.dot(perp, outward)) < 0:
        perp = -perp

    # Corners of the strip: the edge itself plus the edge shifted outward
    strip = np.array([
        corner1,
        corner2,
        corner2 + perp * border_width,
        corner1 + perp * border_width,
    ])

    # Bounding box, clipped to the frame
    x_min = max(0, int(np.min(strip[:, 0])))
    x_max = min(w, int(np.max(strip[:, 0])))
    y_min = max(0, int(np.min(strip[:, 1])))
    y_max = min(h, int(np.max(strip[:, 1])))

    if x_max <= x_min or y_max <= y_min:
        return None

    return frame[y_min:y_max, x_min:x_max]


def run_ocr(region: np.ndarray, ocr_model: Any) -> List[str]:
    """
    Run OCR on a cropped region.

    An EasyOCR-like ``ocr_model.readtext(region)`` call is tried first and
    the second element of every result is taken as the detected text. If
    that raises ``AttributeError`` (typically because the object has no
    ``readtext`` method), ``pytesseract.image_to_string`` is used instead and
    its output is split on whitespace.

    Parameters
    ----------
    region : np.ndarray
        Cropped image region.
    ocr_model : Any
        OCR model instance.

    Returns
    -------
    List[str]
        List of detected text strings; empty if the pytesseract fallback is
        unavailable or fails.
    """
    try:
        # EasyOCR style
        results = ocr_model.readtext(region)
    except AttributeError:
        try:
            # Tesseract style (pytesseract); imported lazily because it is
            # only needed for this fallback.
            import pytesseract
            text = pytesseract.image_to_string(region)
        except Exception as e:
            # Best-effort fallback: report the reason instead of hiding it.
            # Unlike a bare ``except``, this lets KeyboardInterrupt and
            # SystemExit propagate.
            print(f"OCR fallback (pytesseract) failed: {e}")
            return []
        # str.split() with no argument already drops empty/whitespace tokens
        return text.split()

    return [result[1] for result in results]


def determine_rotation_from_analysis(
    edge_analysis: Dict,
    frame: np.ndarray,
    corners: np.ndarray,
    vlm_client: Any
) -> int:
    """
    Determine board rotation from edge analysis and VLM reasoning.

    Each (edge, label kind, characters) rule below adds 2 points to one
    candidate rotation when any of its characters was detected on that edge.
    If no rotation reaches 2 points, or the best score is shared by several
    rotations, ``query_vlm_for_rotation`` is consulted and its answer is
    returned when it is not None. Otherwise the highest-scoring rotation is
    returned, with ties resolved in the order 0, 90, 180, 270.

    Parameters
    ----------
    edge_analysis : Dict
        Per-edge OCR analysis: ``{edge: {'files': [...], 'ranks': [...]}}``
        for the edges 'top', 'bottom', 'left', 'right'. Missing edges or
        missing 'files'/'ranks' entries are treated as "nothing detected".
    frame : np.ndarray
        Original frame (only forwarded to the VLM query).
    corners : np.ndarray
        Ordered corners [TL, TR, BR, BL] (only forwarded to the VLM query).
    vlm_client : Any
        VLM client for reasoning.

    Returns
    -------
    int
        Rotation in degrees (0, 90, 180, or 270).
    """

    # -------------------------------------------------------------------------
    # Rule-based analysis first
    # -------------------------------------------------------------------------

    # (edge, label kind, characters to look for, rotation receiving the vote)
    voting_rules = (
        # File letters on the horizontal edges
        ('bottom', 'files', 'ab', 0),
        ('bottom', 'files', 'hg', 180),
        ('top', 'files', 'hg', 0),
        ('top', 'files', 'ab', 180),
        # Rank numbers on the vertical edges
        ('left', 'ranks', '12', 0),
        ('left', 'ranks', '87', 180),
        ('right', 'ranks', '87', 0),
        ('right', 'ranks', '12', 180),
        # Rank numbers on the horizontal edges indicate a quarter turn
        ('bottom', 'ranks', '12', 90),
        ('bottom', 'ranks', '87', 270),
        ('top', 'ranks', '87', 90),
        ('top', 'ranks', '12', 270),
    )

    # Insertion order matters: it is the tie-break order of the final max()
    scores = {0: 0, 90: 0, 180: 0, 270: 0}

    for edge_name, kind, wanted, rotation in voting_rules:
        # Tolerate absent edges / kinds instead of raising KeyError
        seen = (edge_analysis.get(edge_name) or {}).get(kind) or []
        if any(char in seen for char in wanted):
            scores[rotation] += 2

    # -------------------------------------------------------------------------
    # VLM reasoning for ambiguous cases
    # -------------------------------------------------------------------------

    max_score = max(scores.values())

    # If scores are tied or low confidence, use VLM
    if max_score < 2 or list(scores.values()).count(max_score) > 1:
        vlm_rotation = query_vlm_for_rotation(frame, corners, vlm_client)
        if vlm_rotation is not None:
            return vlm_rotation

    # Return the rotation with highest score
    return max(scores, key=scores.get)


def query_vlm_for_rotation(
    frame: np.ndarray,
    corners: np.ndarray,
    vlm_client: Any
) -> Optional[int]:
    """
    Ask the VLM which way the board is rotated.

    The frame is JPEG-encoded, base64-encoded and passed together with a
    fixed prompt to ``vlm_client.analyze(image_base64, prompt)``. The first
    standalone occurrence of 0, 90, 180 or 270 in the reply is returned.

    Parameters
    ----------
    frame : np.ndarray
        Input frame.
    corners : np.ndarray
        Ordered corners. Not used by this function.
    vlm_client : Any
        VLM client exposing ``analyze(image_base64, prompt)``.

    Returns
    -------
    Optional[int]
        Rotation in degrees, or None if the client is None, the call fails,
        or no valid angle can be found in the reply.
    """
    import base64

    if vlm_client is None:
        print("[VLM] vlm_client is None, skipping VLM")
        return None

    print("[VLM] Calling Gemini for rotation...")

    # Prompt sent verbatim to the model
    prompt = """
    You are given an image of a real chessboard. Your ONLY task is to locate the coordinate labels
    (the printed letters a–h and numbers 1–8) around the edges of the board and determine which edge
    corresponds to White's home side (rank 1).

    CRITICAL RULES (follow strictly):
    1. Identify the edge where the FILE LETTERS appear in STRICT INCREASING ORDER:
          a b c d e f g h
      - These letters must appear left-to-right in that order ON THAT EDGE.
      - If the letters appear reversed (h g f ... a), then that is NOT the White side.

    2. Completely IGNORE:
      - Chess pieces
      - Board colors or orientation
      - Perspective distortion
      - Lighting or reflections
      - Any visual cues OTHER than the printed coordinate letters/numbers

    3. After identifying the correct White-side edge, output the rotation angle based on its
      position IN THE IMAGE (not in real world):

      - Output **0**   if the a→b→c→d→e→f→g→h sequence is on the BOTTOM edge of the image.
      - Output **90**  if the a→b→c→d→e→f→g→h sequence is on the LEFT edge of the image.
      - Output **180** if the a→b→c→d→e→f→g→h sequence is on the TOP edge of the image.
      - Output **270** if the a→b→c→d→e→f→g→h sequence is on the RIGHT edge of the image.

    SANITY CHECK (mandatory):
    - If the letters appear top-to-bottom or bottom-to-top, treat that as LEFT or RIGHT edge.
    - If the letters appear reversed (h→a) on the bottom (h→g→f→e→d→c→b→a), then the correct edge is the opposite side.

    FINAL INSTRUCTIONS:
    - Respond with ONLY ONE INTEGER: 0, 90, 180, or 270.
    - No words. No punctuation. No explanation. Only the number.
    """

    try:
        # Ship the frame as a base64-encoded JPEG
        _, jpeg_bytes = cv2.imencode('.jpg', frame)
        payload = base64.b64encode(jpeg_bytes).decode('utf-8')

        reply = vlm_client.analyze(payload, prompt)
        reply = (reply or "").strip()
        print("[VLM] Raw Gemini response:", repr(reply))

        # Accept only a standalone 0 / 90 / 180 / 270 token
        found = re.search(r'\b(0|90|180|270)\b', reply)
        if found is None:
            print("[VLM] Could not parse rotation from response")
            return None

        angle = int(found.group(1))
        print("[VLM] Parsed rotation:", angle)
        return angle

    except Exception as e:
        print(f"VLM query failed: {e}")
        return None


def visualize_rotation_estimation(
    frame: np.ndarray,
    corners: np.ndarray,
    edge_regions: Dict[str, np.ndarray],
    ocr_results: Dict[str, List[str]],
    rotation: int
) -> None:
    """
    Show a 2x3 summary figure of the rotation estimation.

    Panel 1 is the frame with the board outline, panels 2-5 are the four
    edge crops with their OCR output, and panel 6 is an arrow drawn at the
    estimated angle.

    Parameters
    ----------
    frame : np.ndarray
        Original frame (BGR).
    corners : np.ndarray
        Ordered corners, labelled TL, TR, BR, BL in that order.
    edge_regions : Dict[str, np.ndarray]
        Cropped edge regions keyed by 'top', 'bottom', 'left', 'right'.
    ocr_results : Dict[str, List[str]]
        OCR results for each edge.
    rotation : int
        Estimated rotation in degrees.
    """
    fig = plt.figure(figsize=(16, 12))

    # --- Panel 1: frame with labelled corners and board outline -------------
    overview = fig.add_subplot(2, 3, 1)
    overview.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    corner_styles = zip(corners,
                        ('red', 'green', 'blue', 'orange'),
                        ('TL', 'TR', 'BR', 'BL'))
    for point, shade, tag in corner_styles:
        overview.scatter(point[0], point[1], c=shade, s=100, zorder=5)
        overview.annotate(tag, point, fontsize=10, color=shade)

    # Repeat the first corner at the end so the polyline closes
    outline = np.vstack([corners, corners[0]])
    overview.plot(outline[:, 0], outline[:, 1], 'c-', linewidth=2)
    overview.set_title(f'Detected Board\nEstimated Rotation: {rotation}°')
    overview.axis('off')

    # --- Panels 2-5: one panel per edge crop ---------------------------------
    panel_slots = {'top': 2, 'bottom': 5, 'left': 4, 'right': 3}

    for edge_name, slot in panel_slots.items():
        panel = fig.add_subplot(2, 3, slot)
        region = edge_regions.get(edge_name)
        heading = f'{edge_name.capitalize()} Edge'

        if region is None or region.size <= 0:
            panel.text(0.5, 0.5, 'No Region', ha='center', va='center')
        else:
            panel.imshow(cv2.cvtColor(region, cv2.COLOR_BGR2RGB))
            detected = ocr_results.get(edge_name, [])
            heading = f'{heading}\nOCR: {detected}'

        panel.set_title(heading)
        panel.axis('off')

    # --- Panel 6: arrow at the estimated angle -------------------------------
    dial = fig.add_subplot(2, 3, 6)
    angle = np.radians(rotation)
    dial.arrow(0, 0, np.cos(angle) * 0.8, np.sin(angle) * 0.8,
               head_width=0.1, head_length=0.1, fc='blue', ec='blue')
    dial.set_xlim(-1, 1)
    dial.set_ylim(-1, 1)
    dial.set_aspect('equal')
    dial.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    dial.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
    dial.set_title(f'Rotation: {rotation}°\n(0° = White at bottom)')

    plt.tight_layout()
    plt.show()


# ============================================================================
# CORNER ORDERING AND HOMOGRAPHY
# ============================================================================

def order_corners_tl_tr_br_bl(
    corners: np.ndarray,
    rotation_deg: int
) -> np.ndarray:
    """
    Order corners as [TL, TR, BR, BL] by their position in the image.

    The ordering is purely geometric and is delegated to
    ``order_points_clockwise``. ``rotation_deg`` is accepted so that callers
    can pass the board rotation along, but it does not influence the result:
    the rotation is applied later, when algebraic names are assigned to the
    squares of the warped board (see ``compute_board_grid``).

    Parameters
    ----------
    corners : np.ndarray
        Array of shape (4, 2) with corner coordinates in arbitrary order.
    rotation_deg : int
        Board rotation in degrees (0, 90, 180, or 270). Currently unused.
        Under the naming used by ``compute_board_grid`` the rotation places
        square a1 at:
        - 0°: bottom-left of the image
        - 90°: top-left of the image
        - 180°: top-right of the image
        - 270°: bottom-right of the image

    Returns
    -------
    np.ndarray
        Array of shape (4, 2) with corners ordered as [TL, TR, BR, BL] in
        image coordinates.

    Example
    -------
    >>> corners = np.array([[100, 100], [500, 100], [500, 500], [100, 500]])
    >>> ordered = order_corners_tl_tr_br_bl(corners, rotation_deg=0)
    """
    # The homography wants a consistent image-space order regardless of how
    # the board is turned, so the rotation is deliberately not applied here.
    return order_points_clockwise(corners)


def compute_homography(
    corners_tl_tr_br_bl: np.ndarray,
    board_size: int = DEFAULT_BOARD_SIZE
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute the homography that maps the board onto a square image.

    Parameters
    ----------
    corners_tl_tr_br_bl : np.ndarray
        Array of shape (4, 2) with corners ordered as [TL, TR, BR, BL].
    board_size : int
        Side length of the output square image in pixels.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        - H: 3x3 homography matrix
        - dst_points: float32 destination points in the warped image

    Raises
    ------
    ValueError
        If ``cv2.findHomography`` returns no matrix.

    Notes
    -----
    With ``m = board_size - 1`` (the last pixel index) the mapping is:
    - TL (source) -> (0, 0)
    - TR (source) -> (m, 0)
    - BR (source) -> (m, m)
    - BL (source) -> (0, m)

    Example
    -------
    >>> H, dst = compute_homography(corners, board_size=800)
    >>> warped = cv2.warpPerspective(frame, H, (800, 800))
    """
    last_px = board_size - 1

    # Destination corners in the same TL, TR, BR, BL order as the source
    dst_points = np.array(
        [[0, 0], [last_px, 0], [last_px, last_px], [0, last_px]],
        dtype=np.float32,
    )
    src_points = corners_tl_tr_br_bl.astype(np.float32)

    # The inlier mask returned alongside the matrix is not needed here
    H, _inlier_mask = cv2.findHomography(src_points, dst_points, cv2.RANSAC, 5.0)
    if H is None:
        raise ValueError("Failed to compute homography matrix")

    return H, dst_points


def warp_board(
    frame: np.ndarray,
    H: np.ndarray,
    board_size: int = DEFAULT_BOARD_SIZE
) -> np.ndarray:
    """
    Warp a frame with a homography to get a top-down view of the board.

    Parameters
    ----------
    frame : np.ndarray
        Input frame in BGR format.
    H : np.ndarray
        3x3 homography matrix.
    board_size : int
        Side length of the output square image.

    Returns
    -------
    np.ndarray
        Warped image of size board_size x board_size. Pixels that map from
        outside the frame are filled with black.

    Example
    -------
    >>> H, _ = compute_homography(corners)
    >>> warped = warp_board(frame, H, board_size=800)
    """
    output_dims = (board_size, board_size)  # (width, height) of the result
    black = (0, 0, 0)

    return cv2.warpPerspective(
        frame,
        H,
        output_dims,
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=black,
    )


# ============================================================================
# BOARD GRID COMPUTATION
# ============================================================================

def compute_board_grid(
    board_size: int = DEFAULT_BOARD_SIZE,
    n: int = 8,
    rotation_deg: int = 0
) -> Tuple[np.ndarray, List[str]]:
    """
    Precompute the centers and names of the n x n squares of the warped board.

    Parameters
    ----------
    board_size : int
        Size of the warped board image.
    n : int
        Number of squares per side (8 for standard chess). File and rank
        indices are mirrored with ``n - 1``, so the naming stays consistent
        for other values as long as ``FILES`` and ``RANKS`` have at least
        ``n`` entries.
    rotation_deg : int
        Board rotation in degrees (0, 90, 180, or 270).
        This affects how algebraic names are assigned to squares.

    Returns
    -------
    Tuple[np.ndarray, List[str]]
        - centers: Array of shape (n*n, 2) with (x, y) coordinates of square
          centers, row by row from the top-left of the warped image.
        - square_names: List of n*n algebraic names in the same order.

    Raises
    ------
    ValueError
        If ``rotation_deg`` is not one of 0, 90, 180, 270. The check is done
        once up front, before any square is processed.

    Notes
    -----
    The warped image has (0, 0) at top-left and (board_size, board_size) at bottom-right.

    Top-left square of the warped image for each rotation (n = 8):
    - 0°: a8 (a1 at bottom-left, h1 at bottom-right)
    - 90°: a1
    - 180°: h1
    - 270°: h8

    Example
    -------
    >>> centers, names = compute_board_grid(board_size=800, rotation_deg=0)
    >>> print(f"Square a1 center: {centers[names.index('a1')]}")
    """
    if rotation_deg not in (0, 90, 180, 270):
        raise ValueError(f"Invalid rotation: {rotation_deg}")

    square_size = board_size / n
    last = n - 1  # highest file/rank index; replaces the hard-coded 7

    centers = []
    square_names = []

    for row in range(n):  # 0 = top row in image
        for col in range(n):  # 0 = left column in image
            # Center of this square in pixel coordinates
            centers.append([(col + 0.5) * square_size, (row + 0.5) * square_size])

            # Map the image (row, col) to (file, rank) indices
            if rotation_deg == 0:
                file_idx, rank_idx = col, last - row
            elif rotation_deg == 90:
                file_idx, rank_idx = row, col
            elif rotation_deg == 180:
                file_idx, rank_idx = last - col, row
            else:  # 270
                file_idx, rank_idx = last - row, last - col

            square_names.append(f"{FILES[file_idx]}{RANKS[rank_idx]}")

    return np.array(centers), square_names


def visualize_board_grid(
    warped_board: np.ndarray,
    centers: np.ndarray,
    square_names: List[str],
    board_size: int = DEFAULT_BOARD_SIZE,
    figsize: Tuple[int, int] = (12, 12)
) -> None:
    """
    Overlay an 8x8 grid, the square centers and their names on the warped board.

    Parameters
    ----------
    warped_board : np.ndarray
        Warped board image (BGR).
    centers : np.ndarray
        Array of (x, y) square centers.
    square_names : List[str]
        Algebraic names, in the same order as ``centers``.
    board_size : int
        Size of the warped board; grid lines are spaced ``board_size / 8``.
    figsize : Tuple[int, int]
        Figure size.
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(cv2.cvtColor(warped_board, cv2.COLOR_BGR2RGB))

    # Nine lines in each direction bound the eight rows and columns
    cell = board_size / 8
    grid_style = dict(color='cyan', linewidth=1, alpha=0.7)
    for k in range(9):
        offset = k * cell
        ax.axvline(x=offset, **grid_style)
        ax.axhline(y=offset, **grid_style)

    # Red dot and a name tag at every square center
    tag_box = dict(boxstyle='round,pad=0.1', facecolor='black', alpha=0.5)
    for (px, py), name in zip(centers, square_names):
        ax.scatter(px, py, c='red', s=30, zorder=5)
        ax.annotate(
            name,
            (px, py),
            fontsize=8,
            color='yellow',
            ha='center',
            va='center',
            fontweight='bold',
            bbox=tag_box,
        )

    ax.set_title('Warped Board with Grid Centers and Algebraic Names', fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()


# ============================================================================
# PIECE DETECTION (YOLO)
# ============================================================================

def detect_pieces_yolo(
    frame: np.ndarray,
    yolo_model: Any,
    conf_thres: float = 0.3,
    iou_thres: float = 0.45,
    rotation_deg: int = 0
) -> List[PieceDetection]:
    """
    Run the pre-trained YOLO model on the original frame to detect chess pieces.

    Parameters
    ----------
    frame : np.ndarray
        Input frame in BGR format.
    yolo_model : Any
        Pre-trained YOLO model (already loaded and ready for inference).
        Expected to be a YOLOv5/YOLOv8 model or similar.
    conf_thres : float
        Confidence threshold for detections.
    iou_thres : float
        IoU threshold for NMS.
    rotation_deg : int
        Board rotation in degrees; forwarded to
        ``compute_piece_center_with_rotation`` when deriving each piece's
        center from its bounding box.

    Returns
    -------
    List[PieceDetection]
        List of detected pieces with bounding boxes, centers, class info, etc.
        ``confidence`` is always a Python float. If inference or parsing
        fails, the error is printed and the detections collected so far
        (possibly none) are returned.

    Notes
    -----
    The function assumes the YOLO model is already loaded. Common class names
    for chess pieces might include:
    - 'white_king', 'white_queen', 'white_rook', etc.
    - Or indexed classes that map to piece types.

    Two result layouts are understood: a result object exposing ``boxes``
    (with ``xyxy``, ``conf`` and ``cls`` tensors), and the legacy layout in
    which ``results.xyxy[0]`` holds rows of [x1, y1, x2, y2, confidence, class].

    Example
    -------
    >>> detections = detect_pieces_yolo(frame, yolo_model, conf_thres=0.5)
    >>> for det in detections:
    ...     print(f"{det.class_name} at {det.center} (conf: {det.confidence:.2f})")
    """
    detections = []

    # NOTE(review): a sharpened copy of the frame used to be computed here but
    # was never passed to the model; inference has always run on the raw
    # frame, so the unused computation was removed. Confirm whether the
    # sharpened frame was meant to be used instead.
    try:
        results = yolo_model(frame, conf=conf_thres, iou=iou_thres, verbose=False)

        # Newer APIs return a list of result objects; take the first one
        result = results[0] if hasattr(results, '__iter__') else results

        if hasattr(result, 'boxes'):
            # ---------------------------------------------------------------
            # Option A: result object with a ``boxes`` attribute
            # ---------------------------------------------------------------
            boxes = result.boxes
            names = getattr(result, 'names', None)
            # Move each tensor to host memory once instead of once per box
            rows = zip(
                boxes.xyxy.cpu().numpy(),
                boxes.conf.cpu().numpy(),
                boxes.cls.cpu().numpy(),
            )
        else:
            # ---------------------------------------------------------------
            # Option B: Legacy YOLOv5 style
            # ---------------------------------------------------------------
            names = getattr(results, 'names', None)
            pred = results.xyxy[0].cpu().numpy()
            rows = ((det[:4], det[4], det[5]) for det in pred)

        for box, conf, cls_value in rows:
            x1, y1, x2, y2 = (int(v) for v in box[:4])
            cls_id = int(cls_value)
            cls_name = names[cls_id] if names is not None else str(cls_id)

            cx, cy = compute_piece_center_with_rotation(x1, y1, x2, y2, rotation_deg)

            detections.append(PieceDetection(
                bbox=(x1, y1, x2, y2),
                center=(cx, cy),
                class_id=cls_id,
                class_name=cls_name,
                confidence=float(conf)
            ))

    except Exception as e:
        # Deliberately best-effort: a failed frame must not stop the pipeline
        print(f"YOLO detection error: {e}")
        print("Returning empty detection list.")

    return detections


def visualize_yolo_detections(
    frame: np.ndarray,
    detections: List[PieceDetection],
    title: str = "YOLO Chess Piece Detections",
    figsize: Tuple[int, int] = (14, 10)
) -> np.ndarray:
    """
    Draw every detection on a copy of the frame, show it, and return the copy.

    Parameters
    ----------
    frame : np.ndarray
        Original BGR frame. It is copied, never modified.
    detections : List[PieceDetection]
        Detections to draw; each one contributes its ``bbox``, ``center``,
        ``class_name`` and ``confidence``.
    title : str
        First line of the figure title (the detection count is appended).
    figsize : Tuple[int, int]
        Matplotlib figure size.

    Returns
    -------
    np.ndarray
        The annotated BGR copy of ``frame``.
    """
    canvas = frame.copy()

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_thickness = 1

    # Tuples are handed directly to OpenCV drawing calls on the BGR canvas.
    palette = {
        'white': (255, 255, 200),
        'black': (100, 100, 100),
        'default': (0, 255, 0),
    }

    for piece in detections:
        left, top, right, bottom = piece.bbox

        # The side is inferred from the class name text.
        name_lc = piece.class_name.lower()
        is_white = 'white' in name_lc
        if is_white:
            box_color = palette['white']
        elif 'black' in name_lc:
            box_color = palette['black']
        else:
            box_color = palette['default']

        # Box outline plus a filled red dot on the reported centre.
        cv2.rectangle(canvas, (left, top), (right, bottom), box_color, 2)
        center_px = (int(piece.center[0]), int(piece.center[1]))
        cv2.circle(canvas, center_px, 5, (0, 0, 255), -1)

        # Caption sits on a filled strip just above the box's top edge.
        caption = f"{piece.class_name}: {piece.confidence:.2f}"
        (text_w, text_h), _ = cv2.getTextSize(
            caption, font, font_scale, font_thickness
        )
        cv2.rectangle(
            canvas,
            (left, top - text_h - 10),
            (left + text_w + 5, top),
            box_color, -1
        )

        # Dark text on the light strip, light text otherwise.
        ink = (0, 0, 0) if is_white else (255, 255, 255)
        cv2.putText(
            canvas, caption,
            (left + 2, top - 5),
            font, font_scale, ink, font_thickness
        )

    # Matplotlib expects RGB, so convert only for display.
    _, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
    ax.set_title(f"{title}\nDetected {len(detections)} pieces", fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

    return canvas


# ============================================================================
# POINT PROJECTION AND SQUARE ASSIGNMENT
# ============================================================================

def project_points_with_homography(
    points: np.ndarray,
    H: np.ndarray
) -> np.ndarray:
    """
    Apply homography H to transform points from original to warped coordinates.

    Parameters
    ----------
    points : np.ndarray
        Array-like of points with shape (N, 2) where each row is (x, y).
    H : np.ndarray
        3x3 homography matrix.

    Returns
    -------
    np.ndarray
        Transformed points with shape (N, 2) and dtype float32. Empty input
        yields an empty array of shape (0, 2), so callers can always index
        columns or unpack rows without special-casing.

    Notes
    -----
    Uses OpenCV's perspectiveTransform for the projection.

    Example
    -------
    >>> original_points = np.array([[100, 200], [300, 400]])
    >>> projected = project_points_with_homography(original_points, H)
    """
    if len(points) == 0:
        # Keep the documented (N, 2) shape even when there is nothing to project.
        return np.empty((0, 2), dtype=np.float32)

    # OpenCV's perspectiveTransform expects float points shaped (N, 1, 2).
    pts = np.asarray(points, dtype=np.float32).reshape(-1, 1, 2)

    # A floating-point matrix is required; coerce integer-typed input.
    matrix = np.asarray(H, dtype=np.float64)

    transformed = cv2.perspectiveTransform(pts, matrix)

    # Reshape back to (N, 2)
    return transformed.reshape(-1, 2)


def assign_pieces_to_squares(
    projected_centers: np.ndarray,
    piece_types: List[str],
    grid_centers: np.ndarray,
    square_names: List[str],
    max_distance: Optional[float] = None
) -> Dict[str, str]:
    """
    Assign each detected piece to its nearest grid square.

    Parameters
    ----------
    projected_centers : np.ndarray
        Array of shape (N, 2) with projected piece centers in warped coordinates.
    piece_types : List[str]
        List of piece type codes (e.g., ['P', 'N', 'B', ...]) for each detection.
    grid_centers : np.ndarray
        Array of shape (M, 2) with precomputed square centers.
    square_names : List[str]
        Algebraic names for the squares, index-aligned with ``grid_centers``.
    max_distance : float, optional
        Maximum allowed distance for assignment. If None, 0.6 of the estimated
        square size is used.

    Returns
    -------
    Dict[str, str]
        Dictionary mapping square name to piece code (e.g., {'e4': 'P', 'a1': 'R'}).

    Notes
    -----
    The square size is estimated from the first two grid centres as the larger
    of their x and y offsets, so it works whether the grid is stored row-major
    or column-major. If the two centres coincide, 100 is used as a fallback.

    If two pieces are assigned to the same square, the one with smaller distance
    is kept (could indicate detection error or piece overlap).

    Example
    -------
    >>> board = assign_pieces_to_squares(proj_centers, types, grid_centers, names)
    >>> print(board)  # {'e4': 'P', 'd4': 'p', ...}
    """
    if len(projected_centers) == 0:
        return {}

    grid = np.asarray(grid_centers, dtype=float).reshape(-1, 2)
    if len(grid) == 0:
        # No squares to assign to.
        return {}
    centers = np.asarray(projected_centers, dtype=float).reshape(-1, 2)

    # Estimate the square size from two consecutive centres. Taking the larger
    # axis offset makes this independent of row- vs column-major grid ordering.
    square_size = 100.0
    if len(grid) >= 2:
        step = float(np.abs(grid[1] - grid[0]).max())
        if step > 0:
            square_size = step

    if max_distance is None:
        max_distance = square_size * 0.6  # Allow some margin

    # Pairwise distances: rows are pieces, columns are grid squares.
    offsets = centers[:, np.newaxis, :] - grid[np.newaxis, :, :]
    distances = np.sqrt(np.sum(offsets ** 2, axis=2))
    nearest = distances.argmin(axis=1)
    nearest_dist = distances[np.arange(len(centers)), nearest]

    board: Dict[str, str] = {}
    assignment_distances: Dict[str, float] = {}

    for piece_type, idx, dist in zip(piece_types, nearest, nearest_dist):
        # Pieces too far from every square centre are dropped as off-board.
        if dist <= max_distance:
            square_name = square_names[idx]

            # On a conflict the piece closest to the square centre wins.
            if (square_name not in assignment_distances
                    or dist < assignment_distances[square_name]):
                board[square_name] = piece_type
                assignment_distances[square_name] = dist

    return board


def visualize_piece_assignments(
    warped_board: np.ndarray,
    projected_centers: np.ndarray,
    piece_types: List[str],
    grid_centers: np.ndarray,
    square_names: List[str],
    board_dict: Dict[str, str],
    figsize: Tuple[int, int] = (12, 12)
) -> None:
    """
    Plot projected pieces and their assigned squares over the warped board.

    Parameters
    ----------
    warped_board : np.ndarray
        Warped board image (converted from BGR to RGB for display).
    projected_centers : np.ndarray
        Projected piece centers, one (x, y) row per detection.
    piece_types : List[str]
        Piece codes, paired with ``projected_centers``; uppercase codes are
        drawn as white pieces, everything else as black.
    grid_centers : np.ndarray
        Grid square centers, indexable as ``grid_centers[:, 0]``.
    square_names : List[str]
        Square names, index-aligned with ``grid_centers``.
    board_dict : Dict[str, str]
        Assignment result (square name -> piece code).
    figsize : Tuple[int, int]
        Figure size.
    """
    _, axes = plt.subplots(1, 1, figsize=figsize)

    # Background: the rectified board.
    axes.imshow(cv2.cvtColor(warped_board, cv2.COLOR_BGR2RGB))

    # Layer 1: every square centre as a faint dot.
    axes.scatter(
        grid_centers[:, 0], grid_centers[:, 1],
        c='cyan', s=20, alpha=0.5, label='Grid Centers'
    )

    # Layer 2: an 'x' with the piece code at each projected detection.
    for point, code in zip(projected_centers, piece_types):
        tint = 'yellow' if code.isupper() else 'magenta'
        px, py = point[0], point[1]

        axes.scatter(px, py, c=tint, s=100, marker='x', linewidth=2)
        axes.annotate(
            code,
            (px, py),
            textcoords="offset points",
            xytext=(5, 5),
            fontsize=10,
            color=tint,
            fontweight='bold'
        )

    # Layer 3: a labelled green marker on every square that received a piece.
    badge_style = dict(boxstyle='round,pad=0.2', facecolor='green', alpha=0.7)
    for name, code in board_dict.items():
        anchor = grid_centers[square_names.index(name)]
        ax_x, ax_y = anchor[0], anchor[1]

        axes.scatter(ax_x, ax_y, c='green', s=150, marker='s', alpha=0.3)
        axes.annotate(
            f"{name}\n{code}",
            (ax_x, ax_y),
            fontsize=8,
            color='white',
            ha='center',
            va='center',
            fontweight='bold',
            bbox=badge_style
        )

    # Hand-built legend so each marker style gets a single entry.
    origin = [0]
    legend_entries = [
        Line2D(origin, origin, marker='o', color='w', markerfacecolor='cyan',
               markersize=8, label='Grid Centers'),
        Line2D(origin, origin, marker='x', color='yellow', markersize=10,
               linestyle='None', label='White Pieces'),
        Line2D(origin, origin, marker='x', color='magenta', markersize=10,
               linestyle='None', label='Black Pieces'),
        Line2D(origin, origin, marker='s', color='green', markersize=10,
               alpha=0.3, linestyle='None', label='Assigned Squares'),
    ]
    axes.legend(handles=legend_entries, loc='upper right')

    axes.set_title(f'Piece Assignments\n{len(board_dict)} pieces on board', fontsize=14)
    axes.axis('off')
    plt.tight_layout()
    plt.show()


# ============================================================================
# BOARD STATE AND FEN CONVERSION
# ============================================================================

def board_to_fen(board_dict: Dict[str, str]) -> str:
    """
    Serialize a board dictionary into the piece-placement field of a FEN.

    Parameters
    ----------
    board_dict : Dict[str, str]
        Dictionary mapping square name to piece code.
        - Uppercase for white pieces: K, Q, R, B, N, P
        - Lowercase for black pieces: k, q, r, b, n, p

    Returns
    -------
    str
        Eight '/'-separated rank strings (piece placement only, not full FEN).

    Notes
    -----
    Ranks are emitted from 8 down to 1 and files from a to h within each
    rank. Runs of empty squares are collapsed into their length.

    Example
    -------
    >>> board = {'e1': 'K', 'e8': 'k', 'd1': 'Q', 'd8': 'q'}
    >>> fen = board_to_fen(board)
    >>> print(fen)  # '3qk3/8/8/8/8/8/8/3QK3'
    """
    def _encode_rank(rank: int) -> str:
        # Walk the rank left to right, flushing the pending gap before a piece.
        tokens = []
        gap = 0
        for file_letter in 'abcdefgh':
            key = f'{file_letter}{rank}'
            if key not in board_dict:
                gap += 1
                continue
            if gap:
                tokens.append(str(gap))
                gap = 0
            tokens.append(board_dict[key])
        if gap:
            tokens.append(str(gap))
        return ''.join(tokens)

    return '/'.join(_encode_rank(rank) for rank in range(8, 0, -1))


def fen_to_board(fen: str) -> Dict[str, str]:
    """
    Convert a FEN position string to board dictionary.

    Parameters
    ----------
    fen : str
        FEN string. Only the piece-placement field is read, so both a bare
        placement ('rnbqkbnr/pppppppp/...') and a full FEN record with the
        trailing side-to-move / castling / en-passant / clock fields are
        accepted. Surrounding whitespace is ignored.

    Returns
    -------
    Dict[str, str]
        Dictionary mapping square name to piece code. An empty or
        whitespace-only string yields an empty dictionary.
    """
    # A full FEN separates its fields with whitespace; the placement is first.
    # Parsing the whole string would misread later fields as pieces.
    fields = fen.split()
    placement = fields[0] if fields else ''

    board = {}

    for rank_idx, row in enumerate(placement.split('/')):
        rank = 8 - rank_idx  # FEN starts from rank 8
        file_idx = 0

        for char in row:
            if char.isdigit():
                # A digit is a run of that many empty squares.
                file_idx += int(char)
            else:
                board[f'{FILES[file_idx]}{rank}'] = char
                file_idx += 1

    return board


def get_initial_board() -> Dict[str, str]:
    """
    Build the standard chess starting position.

    Returns
    -------
    Dict[str, str]
        A fresh dictionary with all 32 pieces: white (uppercase) on ranks 1-2
        and black (lowercase) on ranks 7-8.
    """
    files = 'abcdefgh'
    back_rank = 'RNBQKBNR'

    # (rank, pieces along files a..h), listed in insertion order.
    layout = (
        (1, back_rank),
        (2, 'P' * 8),
        (7, 'p' * 8),
        (8, back_rank.lower()),
    )

    return {
        f'{file_letter}{rank}': piece
        for rank, pieces in layout
        for file_letter, piece in zip(files, pieces)
    }


# ============================================================================
# MOVE DETECTION
# ============================================================================

def detect_move_from_states(
    prev_board: Dict[str, str],
    curr_board: Dict[str, str]
) -> Optional[Move]:
    """
    Compare two board states and infer the move that was made.

    The square-by-square difference between the boards is matched against
    exact patterns first (castling, plain move / promotion, capture,
    en passant). If none fits, a fallback heuristic picks the geometrically
    plausible vacated -> occupied pair that leaves the fewest other changed
    squares unexplained, which tolerates some noise from missed / flickering
    detections.

    Parameters
    ----------
    prev_board : Dict[str, str]
        Square name -> piece code before the move.
    curr_board : Dict[str, str]
        Square name -> piece code after the move.

    Returns
    -------
    Optional[Move]
        The inferred move, or None when no pattern or candidate fits.

    Notes
    -----
    Squares are visited in sorted order so that neither en-passant detection
    nor tie-breaking in the fallback heuristic depends on set iteration order.
    """

    # -------------------------------------------------------------------------
    # Compute differences
    # -------------------------------------------------------------------------
    vacated = []   # squares that had a piece but now don't
    occupied = []  # squares that now have a piece but didn't before
    changed = []   # squares where piece changed to a different one

    all_squares = sorted(set(prev_board) | set(curr_board))

    for square in all_squares:
        prev_piece = prev_board.get(square)
        curr_piece = curr_board.get(square)

        if prev_piece and not curr_piece:
            vacated.append((square, prev_piece))
        elif not prev_piece and curr_piece:
            occupied.append((square, curr_piece))
        elif prev_piece != curr_piece and prev_piece and curr_piece:
            changed.append((square, prev_piece, curr_piece))

    # -------------------------------------------------------------------------
    # Case 1: Castling (exact patterns)
    # -------------------------------------------------------------------------
    if len(vacated) == 2 and len(occupied) == 2:
        vacated_squares = {sq for sq, _ in vacated}
        occupied_squares = {sq for sq, _ in occupied}

        # Kingside / queenside castling for white
        if vacated_squares == {'e1', 'h1'} and occupied_squares == {'g1', 'f1'}:
            return Move('e1', 'g1', 'K', is_castle=True, castle_side='kingside')
        if vacated_squares == {'e1', 'a1'} and occupied_squares == {'c1', 'd1'}:
            return Move('e1', 'c1', 'K', is_castle=True, castle_side='queenside')

        # For black
        if vacated_squares == {'e8', 'h8'} and occupied_squares == {'g8', 'f8'}:
            return Move('e8', 'g8', 'k', is_castle=True, castle_side='kingside')
        if vacated_squares == {'e8', 'a8'} and occupied_squares == {'c8', 'd8'}:
            return Move('e8', 'c8', 'k', is_castle=True, castle_side='queenside')

    # -------------------------------------------------------------------------
    # Case 2: Simple non-capture or promotion move
    # -------------------------------------------------------------------------
    if len(vacated) == 1 and len(occupied) == 1:
        from_sq, piece = vacated[0]
        to_sq, arriving_piece = occupied[0]

        if piece.upper() == arriving_piece.upper() or (
            piece.upper() == 'P' and arriving_piece.upper() in 'QRBN'
        ):
            move = Move(from_sq, to_sq, piece)
            # Promotion
            if piece.upper() == 'P' and arriving_piece.upper() != 'P':
                move.promotion_piece = arriving_piece
            return move

    # NOTE: a capture never shows up as "two vacated + one occupied with the
    # target among the vacated squares": a square cannot be both vacated and
    # newly occupied. Ordinary captures appear as one vacated square plus one
    # changed square, which is the next case.

    # -------------------------------------------------------------------------
    # Case 3: Capture where attacker replaces defender on same square
    # -------------------------------------------------------------------------
    if len(vacated) == 1 and len(occupied) == 0 and len(changed) == 1:
        from_sq, piece = vacated[0]
        to_sq, captured, arriving_piece = changed[0]

        if piece.upper() == arriving_piece.upper() or (
            piece.upper() == 'P' and arriving_piece.upper() in 'QRBN'
        ):
            move = Move(from_sq, to_sq, piece, captured_piece=captured)
            if piece.upper() == 'P' and arriving_piece.upper() != 'P':
                move.promotion_piece = arriving_piece
            return move

    # -------------------------------------------------------------------------
    # Case 4: En passant (still fairly strict)
    # -------------------------------------------------------------------------
    if len(vacated) == 2 and len(occupied) == 1:
        occ_sq, arriving_piece = occupied[0]
        both_pawns = all(vac_piece.upper() == 'P' for _, vac_piece in vacated)

        if arriving_piece.upper() == 'P' and both_pawns:
            to_file = ord(occ_sq[0])

            # Try both role assignments. The geometry test (mover one file
            # away from the target, victim on the target's file) can hold for
            # at most one of them, so the outcome is order-independent.
            for mover, victim in (
                (vacated[0], vacated[1]),
                (vacated[1], vacated[0]),
            ):
                from_sq, piece = mover
                captured_sq, captured_piece = victim
                from_file = ord(from_sq[0])
                cap_file = ord(captured_sq[0])
                if abs(from_file - to_file) == 1 and to_file == cap_file:
                    return Move(
                        from_square=from_sq,
                        to_square=occ_sq,
                        piece=piece,
                        captured_piece=captured_piece,
                        is_en_passant=True
                    )

    # -------------------------------------------------------------------------
    # Case 5: Robust heuristic – pick best single candidate move
    # -------------------------------------------------------------------------
    candidates: List[Tuple[Move, int]] = []

    # Generate all plausible (from_sq -> to_sq) pairs
    for from_sq, prev_piece in vacated:
        for to_sq, curr_piece in occupied:
            # Must be the same colour ...
            if prev_piece.isupper() != curr_piece.isupper():
                continue
            # ... and the same type, except for a pawn turning into a
            # promotion target.
            is_promotion = (
                prev_piece.upper() == 'P' and curr_piece.upper() in 'QRBN'
            )
            if prev_piece.upper() != curr_piece.upper() and not is_promotion:
                continue

            # Basic geometric legality
            if not could_piece_move_to(prev_piece, from_sq, to_sq, prev_board):
                continue

            # Count the changed squares this move does not account for.
            explained_squares = {from_sq, to_sq}
            unexplained = sum(
                1 for sq in all_squares
                if sq not in explained_squares
                and prev_board.get(sq) != curr_board.get(sq)
            )

            move = Move(from_sq, to_sq, prev_piece)
            # Record the promotion so it is not lost in the generated PGN.
            if prev_piece.upper() == 'P' and curr_piece.upper() != 'P':
                move.promotion_piece = curr_piece
            candidates.append((move, unexplained))

    if candidates:
        # Candidate with the fewest unexplained changes; the earliest one
        # wins ties, which is deterministic thanks to the sorted squares.
        best_move, best_noise = min(candidates, key=lambda item: item[1])

        # Heuristic threshold: accept if at most 4 other squares changed
        # (tune if needed)
        if best_noise <= 4:
            return best_move

    # -------------------------------------------------------------------------
    # Fallback: Could not determine move
    # -------------------------------------------------------------------------
    return None



def get_piece_symbol_for_pgn(piece_code: str) -> str:
    """
    Map a piece code to the letter used for it in PGN move text.

    Parameters
    ----------
    piece_code : str
        Piece code in either case (e.g., 'K', 'q', 'P').

    Returns
    -------
    str
        The uppercased code, or an empty string for a pawn (pawn moves carry
        no piece letter).
    """
    symbol = piece_code.upper()
    return '' if symbol == 'P' else symbol


def move_to_pgn_notation(
    move: Move,
    board_before: Dict[str, str]
) -> str:
    """
    Convert a Move object to PGN notation.

    Parameters
    ----------
    move : Move
        The move to convert.
    board_before : Dict[str, str]
        Board state before the move (for disambiguation).

    Returns
    -------
    str
        PGN notation for the move (e.g., 'Nf3', 'exd5', 'O-O').

    Notes
    -----
    Disambiguation only considers pieces of the same type AND colour as the
    mover (an exact piece-code match); enemy pieces can never make a move
    ambiguous. Reachability uses ``could_piece_move_to``, which is geometric
    only, so blocked or pinned pieces may still cause a (harmless) extra
    disambiguator. The origin file is preferred, then the origin rank, then
    both.

    Example
    -------
    >>> move = Move(from_square='e2', to_square='e4', piece='P')
    >>> pgn = move_to_pgn_notation(move, board_before)
    >>> print(pgn)  # 'e4'
    """

    # Castling
    if move.is_castle:
        if move.castle_side == 'kingside':
            return 'O-O'
        else:
            return 'O-O-O'

    piece_symbol = get_piece_symbol_for_pgn(move.piece)

    # Base notation
    notation = piece_symbol

    # Disambiguation applies to pieces only; pawn captures carry their file.
    if piece_symbol:
        # Squares of other identical pieces that could reach the same target.
        rivals = [
            square
            for square, piece in board_before.items()
            if piece == move.piece
            and square != move.from_square
            and could_piece_move_to(piece, square, move.to_square, board_before)
        ]

        if rivals:
            from_file = move.from_square[0]
            from_rank = move.from_square[1]

            if all(square[0] != from_file for square in rivals):
                # The origin file alone identifies the mover.
                notation += from_file
            elif all(square[1] != from_rank for square in rivals):
                # Some rival shares the file, but the rank is unique.
                notation += from_rank
            else:
                notation += from_file + from_rank

    # Capture notation
    if move.captured_piece or move.is_en_passant:
        if move.piece.upper() == 'P' and not piece_symbol:
            # Pawn capture includes the file
            notation += move.from_square[0]
        notation += 'x'

    # Destination square
    notation += move.to_square

    # Promotion
    if move.promotion_piece:
        notation += '=' + move.promotion_piece.upper()

    # Check / checkmate flags are taken as-is from the move; they are not
    # computed here.
    if move.is_checkmate:
        notation += '#'
    elif move.is_check:
        notation += '+'

    return notation


def could_piece_move_to(
    piece: str,
    from_square: str,
    to_square: str,
    board: Dict[str, str]
) -> bool:
    """
    Test whether a piece's movement pattern connects two squares.

    This is purely geometric: blocking pieces, pins, and other chess rules
    are not considered, and ``board`` is not consulted.

    Parameters
    ----------
    piece : str
        Piece code; the case selects the pawn direction (uppercase moves
        towards higher ranks, lowercase towards lower ranks).
    from_square : str
        Starting square, e.g. 'e2'.
    to_square : str
        Target square, e.g. 'e4'.
    board : Dict[str, str]
        Current board state (unused by this simplified check).

    Returns
    -------
    bool
        True if the displacement matches the piece's movement pattern;
        False otherwise, including for unknown piece codes.
    """
    kind = piece.upper()

    # Zero-based (column, row) coordinates of both squares.
    src_col = ord(from_square[0]) - ord('a')
    src_row = int(from_square[1]) - 1
    dst_col = ord(to_square[0]) - ord('a')
    dst_row = int(to_square[1]) - 1

    step_x = dst_col - src_col
    step_y = dst_row - src_row
    span_x, span_y = abs(step_x), abs(step_y)

    moved = step_x != 0 or step_y != 0
    along_diagonal = span_x == span_y
    along_line = step_x == 0 or step_y == 0

    if kind == 'N':
        # An L-shape is exactly one step on one axis and two on the other.
        return {span_x, span_y} == {1, 2}

    if kind == 'B':
        return along_diagonal and step_x != 0

    if kind == 'R':
        return along_line and moved

    if kind == 'Q':
        return (along_diagonal or along_line) and moved

    if kind == 'K':
        return span_x <= 1 and span_y <= 1 and moved

    if kind == 'P':
        forward = 1 if piece.isupper() else -1
        if step_x != 0:
            # Sideways displacement is only allowed as a one-step diagonal.
            return span_x == 1 and step_y == forward
        if step_y == forward:
            return True
        # Double step is only available from the pawn's starting row.
        on_home_row = (
            (piece.isupper() and src_row == 1)
            or (piece.islower() and src_row == 6)
        )
        return step_y == 2 * forward and on_home_row

    return False


def append_move_to_pgn(
    move: Move,
    pgn_so_far: str,
    move_number: int,
    side_to_move: str,
    board_before: Dict[str, str]
) -> Tuple[str, int, str]:
    """
    Render a move as PGN and append it to the running PGN text.

    Parameters
    ----------
    move : Move
        The move to append.
    pgn_so_far : str
        PGN text accumulated so far (may be empty).
    move_number : int
        Number of the current full move.
    side_to_move : str
        'white' for a white move; any other value is treated as black.
    board_before : Dict[str, str]
        Board state before the move, passed on for notation.

    Returns
    -------
    Tuple[str, int, str]
        ``(updated_pgn, next_move_number, next_side)``. The move number only
        advances after a black move. A black move appended to empty text is
        written as 'N... move' rather than 'N. move'.
    """
    san = move_to_pgn_notation(move, board_before)

    if side_to_move != 'white':
        # Black: either open the record with 'N...' or answer White's move.
        if pgn_so_far:
            updated = pgn_so_far + f' {san}'
        else:
            updated = f'{move_number}... {san}'
        return updated, move_number + 1, 'white'

    # White: always 'N. move', separated by a space from any earlier text.
    prefix = pgn_so_far + ' ' if pgn_so_far else pgn_so_far
    updated = prefix + f'{move_number}. {san}'
    return updated, move_number, 'black'



# ============================================================================
# VISUALIZATION HELPERS
# ============================================================================

def visualize_frame(
    title: str,
    frame_bgr: np.ndarray,
    figsize: Tuple[int, int] = (10, 8)
) -> None:
    """
    Display a single BGR frame with matplotlib.

    Parameters
    ----------
    title : str
        Text shown above the image.
    frame_bgr : np.ndarray
        Frame in BGR channel order; converted to RGB for display.
    figsize : Tuple[int, int]
        Size of the new figure.
    """
    plt.figure(figsize=figsize)

    # Matplotlib expects RGB, so swap the channel order before drawing.
    rgb_view = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    plt.imshow(rgb_view)

    plt.title(title, fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.show()


def visualize_move_overlay(
    frame: np.ndarray,
    move: Move,
    grid_centers: np.ndarray,
    square_names: List[str],
    board_size: int = DEFAULT_BOARD_SIZE,
    figsize: Tuple[int, int] = (10, 10)
) -> None:
    """
    Show a move on the warped board as an arrow between highlighted squares.

    Parameters
    ----------
    frame : np.ndarray
        Warped board image (converted from BGR to RGB for display).
    move : Move
        The move to visualize; its origin and destination squares must both
        appear in ``square_names``.
    grid_centers : np.ndarray
        Grid square centers, index-aligned with ``square_names``.
    square_names : List[str]
        Square names.
    board_size : int
        Side length of the board image; one eighth of it is used as the
        highlight size.
    figsize : Tuple[int, int]
        Figure size.
    """
    _, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    # Look up the image positions of both ends of the move.
    origin = grid_centers[square_names.index(move.from_square)]
    target = grid_centers[square_names.index(move.to_square)]

    # Arrow from origin to target.
    arrow_style = dict(
        arrowstyle='->,head_width=0.5,head_length=0.3',
        color='red',
        lw=3
    )
    ax.annotate(
        '',
        xy=(target[0], target[1]),
        xytext=(origin[0], origin[1]),
        arrowprops=arrow_style
    )

    cell = board_size / 8

    def _highlight(center, tint):
        # Translucent square of one cell, centred on the given point.
        corner = (center[0] - cell/2, center[1] - cell/2)
        ax.add_patch(Rectangle(
            corner,
            cell, cell,
            linewidth=3, edgecolor=tint, facecolor=tint, alpha=0.3
        ))

    _highlight(origin, 'red')    # where the piece came from
    _highlight(target, 'green')  # where it landed

    # Assemble the title text; castling replaces the generic description.
    mover_label = PIECE_NAMES.get(move.piece, move.piece)
    summary = f"{mover_label}: {move.from_square} → {move.to_square}"
    if move.captured_piece:
        victim_label = PIECE_NAMES.get(move.captured_piece, move.captured_piece)
        summary += f" (captures {victim_label})"
    if move.is_castle:
        summary = f"Castling {move.castle_side}"
    if move.promotion_piece:
        summary += f" (promotes to {move.promotion_piece})"

    ax.set_title(f'Move: {summary}', fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()


def visualize_board_state(
    board_dict: Dict[str, str],
    title: str = "Board State",
    figsize: Tuple[int, int] = (8, 8)
) -> None:
    """
    Draw a board state as a chess diagram seen from White's side.

    Parameters
    ----------
    board_dict : Dict[str, str]
        Board state dictionary (square name -> piece code).
    title : str
        Title for the diagram.
    figsize : Tuple[int, int]
        Figure size.

    Notes
    -----
    Rank 1 is drawn at the bottom and file a on the left, with a1 as a dark
    square. Rank labels are derived from the same row index used to place the
    squares, so labels and pieces always agree.
    """
    # Unicode chess symbols, keyed by piece code.
    unicode_pieces = {
        'K': '♔', 'Q': '♕', 'R': '♖', 'B': '♗', 'N': '♘', 'P': '♙',
        'k': '♚', 'q': '♛', 'r': '♜', 'b': '♝', 'n': '♞', 'p': '♟'
    }

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # The axes' y direction points up, so row 0 (rank 1) placed at y = 0 ends
    # up at the bottom of the diagram.
    for row in range(8):       # row 0 is rank 1
        for col in range(8):   # col 0 is file a
            # a1 (row 0, col 0) is a dark square, so even sums are dark.
            is_light = (row + col) % 2 == 1
            color = '#F0D9B5' if is_light else '#B58863'

            rect = Rectangle(
                (col, row), 1, 1,
                facecolor=color, edgecolor='black', linewidth=0.5
            )
            ax.add_patch(rect)

            square = f'{FILES[col]}{row + 1}'

            if square in board_dict:
                piece = board_dict[square]
                piece_color = 'white' if piece.isupper() else 'black'

                # Unknown codes are shown verbatim.
                symbol = unicode_pieces.get(piece, piece)
                ax.text(
                    col + 0.5, row + 0.5, symbol,
                    fontsize=32, ha='center', va='center',
                    color=piece_color
                )

    # Add file labels below the board
    for col, file in enumerate(FILES):
        ax.text(col + 0.5, -0.3, file, fontsize=12, ha='center')

    # Add rank labels to the left of the board, aligned with the rows above
    for row in range(8):
        ax.text(-0.3, row + 0.5, str(row + 1), fontsize=12, va='center')

    ax.set_xlim(-0.5, 8.5)
    ax.set_ylim(-0.5, 8.5)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title(title, fontsize=14)
    plt.tight_layout()
    plt.show()


def create_summary_visualization(
    original_frame: np.ndarray,
    warped_board: np.ndarray,
    detections: List[PieceDetection],
    board_dict: Dict[str, str],
    move: Optional[Move],
    pgn_so_far: str,
    figsize: Tuple[int, int] = (18, 8)
) -> None:
    """
    Show a three-panel overview of the current processing state.

    Panel 1 is a copy of the original frame with one rectangle per
    detection, panel 2 is the warped board, and panel 3 is a monospace
    text dump of the board, the last move (if any) and the PGN so far.

    Parameters
    ----------
    original_frame : np.ndarray
        Original video frame; it is copied before drawing and converted
        BGR -> RGB for display.
    warped_board : np.ndarray
        Warped board image, converted BGR -> RGB for display.
    detections : List[PieceDetection]
        Detections whose ``bbox`` and ``class_name`` are drawn on panel 1.
    board_dict : Dict[str, str]
        Current board state; squares missing from it are printed as '.'.
    move : Optional[Move]
        Detected move (if any).
    pgn_so_far : str
        Current PGN string; a placeholder is printed when it is empty.
    figsize : Tuple[int, int]
        Figure size.
    """
    figure = plt.figure(figsize=figsize)

    # --- Panel 1: original frame with detection boxes --------------------
    frame_axis = figure.add_subplot(1, 3, 1)
    annotated = original_frame.copy()
    for detection in detections:
        left, top, right, bottom = detection.bbox
        # Class names starting with 'w' get one box colour, all others another.
        starts_with_w = detection.class_name[0].lower() == 'w'
        box_color = (255, 255, 0) if starts_with_w else (255, 0, 255)
        cv2.rectangle(annotated, (left, top), (right, bottom), box_color, 2)
    frame_axis.imshow(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
    frame_axis.set_title(f'Original Frame\n{len(detections)} pieces detected')
    frame_axis.axis('off')

    # --- Panel 2: warped board -------------------------------------------
    warped_axis = figure.add_subplot(1, 3, 2)
    warped_axis.imshow(cv2.cvtColor(warped_board, cv2.COLOR_BGR2RGB))
    warped_axis.set_title('Warped Board')
    warped_axis.axis('off')

    # --- Panel 3: textual board state, last move and PGN -----------------
    info_axis = figure.add_subplot(1, 3, 3)
    info_axis.axis('off')

    file_header = "  a b c d e f g h"
    text_rows = ["Board State:", file_header]
    for rank in range(8, 0, -1):
        cells = [
            format(board_dict.get(file_letter + str(rank), '.'))
            for file_letter in 'abcdefgh'
        ]
        # Rank label on both sides of the row, cells separated by spaces.
        text_rows.append(f"{rank} {' '.join(cells)} {rank}")
    text_rows.append(file_header)
    board_text = "\n".join(text_rows) + "\n"

    if move:
        board_text += f"\nLast Move: {move.from_square} → {move.to_square}"
        if move.captured_piece:
            board_text += " (capture)"

    board_text += f"\n\nPGN: {pgn_so_far or '(no moves yet)'}"

    info_axis.text(
        0.1, 0.9, board_text,
        transform=info_axis.transAxes,
        fontsize=10,
        family='monospace',
        verticalalignment='top',
    )
    info_axis.set_title('Current State')

    plt.tight_layout()
    plt.show()


# ============================================================================
# MAIN PIPELINE
# ============================================================================

def main_pipeline(
    video_path: str,
    yolo_model: Any,
    unet_model: Any,
    ocr_model: Any = None,
    vlm_client: Any = None,
    output_fps: float = 1.0,
    visualize_steps: bool = True,
    visualize_every_n_frames: int = 5
) -> Tuple[List[BoardState], List[Move], str]:
    """
    Main pipeline for chess move detection from video.

    This function orchestrates the entire process:
    1. Extract frames from video
    2. Filter out frames with hands
    3. Detect board corners and compute homography
    4. Estimate board rotation
    5. Process each frame with YOLO detection
    6. Track board state changes
    7. Generate PGN notation

    Board geometry (corners, rotation, homography, grid) is computed once
    from the first hand-free frame and reused for every later frame.
    Moves are detected against a "trusted" board: the last board that was
    either the initial one, the result of a detected move, or differed
    from the previous trusted board by at most two squares.

    Parameters
    ----------
    video_path : str
        Path passed to ``extract_frames_from_video``.
    yolo_model : Any
        Piece-detection model passed to ``detect_pieces_yolo``.
    unet_model : Any
        Corner-detection model passed to ``detect_corners_unet``.
    ocr_model : Any, optional
        Forwarded to ``estimate_rotation_with_vlm``.
    vlm_client : Any, optional
        Forwarded to ``estimate_rotation_with_vlm``.
    output_fps : float
        Forwarded to ``extract_frames_from_video`` as ``output_fps``.
    visualize_steps : bool
        When True, intermediate results are displayed.
    visualize_every_n_frames : int
        With ``visualize_steps`` enabled, every frame whose position in
        the clean-frame list is a multiple of this value is visualized
        (frames with a detected move are always visualized).

    Returns
    -------
    Tuple[List[BoardState], List[Move], str]
        ``board_states`` holds one entry per clean frame (including frames
        that were later ignored as noisy), ``moves`` holds the detected
        moves in order, and the string is the accumulated PGN.

    Raises
    ------
    ValueError
        If every extracted frame contains hands, or if ``visualize_steps``
        is True while ``visualize_every_n_frames`` is 0.
    """
    # Fail fast instead of hitting a ZeroDivisionError in the frame loop
    # after the (expensive) extraction and corner detection already ran.
    if visualize_steps and visualize_every_n_frames == 0:
        raise ValueError(
            "visualize_every_n_frames must be non-zero when visualize_steps is True"
        )

    print("=" * 70)
    print("CHESS MOVE DETECTION PIPELINE")
    print("=" * 70)

    # =========================================================================
    # STEP 1: Extract frames from video
    # =========================================================================
    print("\n[STEP 1] Extracting frames from video...")
    print("-" * 40)

    frames = extract_frames_from_video(video_path, output_fps=output_fps)

    if visualize_steps and len(frames) > 0:
        visualize_frame("First Extracted Frame", frames[0])

    # =========================================================================
    # STEP 2: Filter frames without hands
    # =========================================================================
    print("\n[STEP 2] Filtering frames (removing frames with hands)...")
    print("-" * 40)

    clean_frames, clean_indices = filter_frames_no_hands(frames)

    if len(clean_frames) == 0:
        raise ValueError("No clean frames found! All frames contain hands.")

    print(f"Using {len(clean_frames)} clean frames for analysis")

    # =========================================================================
    # STEP 3: Process FIRST clean frame for global parameters
    # =========================================================================
    print("\n[STEP 3] Computing global board parameters from first clean frame...")
    print("-" * 40)

    first_frame = clean_frames[0]

    # 3.1: Preprocess with Sobel + LoG (result is only used for display)
    print("  3.1: Applying Sobel + LoG preprocessing...")
    preprocessed = preprocess_sobel_log(first_frame)

    if visualize_steps:
        visualize_preprocessing(preprocessed, first_frame)

    # 3.2: Detect corners with U-Net
    print("  3.2: Detecting board corners with U-Net...")
    corners_raw = detect_corners_unet(first_frame, unet_model)
    print(f"       Raw corners detected: {corners_raw}")

    if visualize_steps:
        visualize_unet_corners(first_frame, corners_raw, "U-Net Corner Detection")

    # 3.3: Estimate rotation with OCR + VLM
    print("  3.3: Estimating board rotation...")
    rotation_deg = estimate_rotation_with_vlm(
        first_frame, corners_raw, ocr_model, vlm_client,
        visualize=visualize_steps
    )
    print(f"       Estimated rotation: {rotation_deg}°")

    # 3.4: Order corners based on rotation
    print("  3.4: Ordering corners [TL, TR, BR, BL]...")
    corners_ordered = order_corners_tl_tr_br_bl(corners_raw, rotation_deg)
    print(f"       Ordered corners: {corners_ordered}")

    # 3.5: Compute homography
    print("  3.5: Computing homography matrix...")
    board_size = DEFAULT_BOARD_SIZE
    H, dst_points = compute_homography(corners_ordered, board_size)
    print(f"       Homography matrix computed (destination: {board_size}x{board_size})")

    # Warp first frame for visualization
    warped_first = warp_board(first_frame, H, board_size)

    if visualize_steps:
        visualize_frame("Warped Board (First Frame)", warped_first)

    # 3.6: Compute board grid
    print("  3.6: Computing board grid centers and square names...")
    grid_centers, square_names = compute_board_grid(board_size, 8, rotation_deg)
    print(f"       Grid computed: 64 squares from {square_names[0]} to {square_names[-1]}")

    if visualize_steps:
        visualize_board_grid(warped_first, grid_centers, square_names, board_size)

    # Store global parameters
    # NOTE(review): this bundle is not read again inside this function.
    global_params = GlobalBoardParameters(
        rotation_degrees=rotation_deg,
        corners_ordered=corners_ordered,
        homography_matrix=H,
        board_size=board_size,
        grid_centers=grid_centers,
        square_names=square_names
    )

    # =========================================================================
    # STEP 4: Process each clean frame
    # =========================================================================
    print("\n[STEP 4] Processing each clean frame...")
    print("-" * 40)

    board_states: List[BoardState] = []
    moves: List[Move] = []
    pgn = ""
    move_number = 1
    side_to_move = 'white'

    # Last reliable board (ignores noisy frames)
    trusted_board: Optional[Dict[str, str]] = None

    for i, (frame, orig_idx) in enumerate(zip(clean_frames, clean_indices)):
        print(f"  Processing frame {i+1}/{len(clean_frames)} (original index: {orig_idx})...")

        # True on the frames that are shown because of the N-frame cadence.
        is_periodic_frame = visualize_steps and i % visualize_every_n_frames == 0

        # 4.1: Preprocessing is only ever displayed, never fed to detection,
        # so it is computed only for frames that are actually shown.
        if is_periodic_frame:
            visualize_preprocessing(preprocess_sobel_log(frame), frame)

        # 4.2: Run YOLO detection on ORIGINAL frame
        detections = detect_pieces_yolo(
            frame,
            yolo_model,
            conf_thres=0.3,
            iou_thres=0.45,
            rotation_deg=rotation_deg
        )
        print(f"       Detected {len(detections)} pieces")

        # 4.3 & 4.4: Project piece centers using homography
        projected_centers = None
        piece_types: List[str] = []
        board_dict: Dict[str, str] = {}

        if detections:
            piece_centers = np.array([det.center for det in detections])
            projected_centers = project_points_with_homography(piece_centers, H)

            # Update detections with projected centers
            for det, proj_center in zip(detections, projected_centers):
                det.projected_center = tuple(proj_center)

            # Convert YOLO class names to piece codes ('?' when unknown)
            piece_types = [
                PIECE_SYMBOLS.get(det.class_name.lower(), '?')
                for det in detections
            ]

            # 4.5: Assign pieces to squares
            board_dict = assign_pieces_to_squares(
                projected_centers, piece_types, grid_centers, square_names
            )

        # 4.6: Build board state (recorded for every frame, trusted or not)
        board_states.append(
            BoardState(
                frame_index=orig_idx,
                pieces=board_dict,
                detections=detections
            )
        )

        # 4.7: Detect move using TRUSTED board (not just last frame)
        detected_move: Optional[Move] = None

        if trusted_board is None:
            # First frame: initialize baseline
            trusted_board = board_dict.copy()
        else:
            # Compare with last reliable board
            detected_move = detect_move_from_states(trusted_board, board_dict)

            if detected_move:
                # The first detected move decides which side starts the PGN
                if not pgn:
                    side_to_move = 'black' if detected_move.piece.islower() else 'white'

                # Check if this move gives check: locate the opposing king
                # on the new board and test whether the mover attacks it.
                moving_is_white = detected_move.piece.isupper()
                enemy_king_code = 'k' if moving_is_white else 'K'
                enemy_king_square = next(
                    (sq for sq, pc in board_dict.items() if pc == enemy_king_code),
                    None
                )

                if enemy_king_square is not None and is_square_attacked(
                    board_dict, enemy_king_square, by_white=moving_is_white
                ):
                    detected_move.is_check = True

                moves.append(detected_move)
                check_suffix = "+" if detected_move.is_check else ""
                print(
                    f"       Move detected (trusted): "
                    f"{detected_move.from_square} → {detected_move.to_square}{check_suffix}"
                )

                # Use TRUSTED board as "board_before" for PGN
                pgn, move_number, side_to_move = append_move_to_pgn(
                    detected_move, pgn, move_number, side_to_move, trusted_board
                )

                # Update reliable baseline
                trusted_board = board_dict.copy()

            else:
                # No clear move: decide whether frame is just minor jitter or trash
                all_squares = trusted_board.keys() | board_dict.keys()
                num_diff = sum(
                    1 for sq in all_squares
                    if trusted_board.get(sq) != board_dict.get(sq)
                )

                if num_diff <= 2:
                    # Small differences (e.g. one pawn flicker) – accept as updated baseline
                    print(
                        f"       No clear move, but only {num_diff} squares differ "
                        f"→ minor noise, updating trusted board."
                    )
                    trusted_board = board_dict.copy()
                else:
                    # Many differences – likely bad YOLO frame; keep old baseline
                    print(
                        f"       No reliable move and {num_diff} squares differ "
                        f"→ IGNORING this frame as noisy."
                    )

        # Visualization (every N frames or when move detected)
        should_visualize = is_periodic_frame or (
            visualize_steps and detected_move is not None
        )

        if should_visualize:
            warped_frame = warp_board(frame, H, board_size)

            # Show YOLO detections
            visualize_yolo_detections(
                frame, detections,
                f"YOLO Detections - Frame {orig_idx}"
            )

            # Show piece assignments
            if detections and projected_centers is not None:
                visualize_piece_assignments(
                    warped_frame, projected_centers, piece_types,
                    grid_centers, square_names, board_dict
                )

            # Show move overlay if move detected
            if detected_move:
                visualize_move_overlay(
                    warped_frame, detected_move,
                    grid_centers, square_names, board_size
                )

            # Show summary
            create_summary_visualization(
                frame, warped_frame, detections, board_dict, detected_move, pgn
            )

    # =========================================================================
    # STEP 5: Final output
    # =========================================================================
    print("\n[STEP 5] Final Results")
    print("-" * 40)
    print(f"Total frames processed: {len(clean_frames)}")
    print(f"Total board states recorded: {len(board_states)}")
    print(f"Total moves detected: {len(moves)}")
    print("\nPGN Output:")
    print(f"  {pgn if pgn else '(no moves detected)'}")

    # Final board visualization
    # NOTE(review): this shows the last recorded frame, which may be one
    # that was ignored as noisy rather than the trusted board - confirm
    # that this is the intended final picture.
    if visualize_steps and board_states:
        final_board = board_states[-1].pieces
        visualize_board_state(final_board, "Final Board Position")

    print("\n" + "=" * 70)
    print("PIPELINE COMPLETE")
    print("=" * 70)

    return board_states, moves, pgn



# ============================================================================
# UTILITY FUNCTIONS FOR COLAB
# ============================================================================

def setup_colab_environment():
    """
    Prepare the notebook display environment.

    Reports whether Google Colab is available (by probing for
    ``google.colab.patches``) and, when an IPython shell is active,
    switches matplotlib to the inline backend. Both steps are
    best-effort: outside Colab/IPython the function only prints status
    messages. No packages are installed here.
    """
    print("Setting up Colab environment...")

    # Probe for Colab; the import itself is only an availability check.
    try:
        from google.colab.patches import cv2_imshow  # noqa: F401
    except ImportError:
        print("Not running in Colab - using matplotlib for display")
    else:
        print("Google Colab detected - using cv2_imshow")

    # Set matplotlib to inline mode when an IPython shell is running.
    try:
        from IPython import get_ipython
    except ImportError:
        shell = None
    else:
        # get_ipython() returns None when no interactive shell is active.
        shell = get_ipython()

    if shell is not None:
        try:
            shell.run_line_magic('matplotlib', 'inline')
        except Exception as exc:
            # Best-effort: report the problem instead of hiding it, but do
            # not abort the setup. KeyboardInterrupt/SystemExit propagate.
            print(f"Could not enable inline matplotlib backend: {exc}")

    print("Environment setup complete!")


def load_sample_video(url: str, output_path: str = "chess_game.mp4") -> str:
    """
    Fetch a video from a URL and store it on disk for testing.

    Parameters
    ----------
    url : str
        Location of the video to fetch.
    output_path : str
        Destination file for the downloaded bytes.

    Returns
    -------
    str
        ``output_path``, i.e. where the video was written.
    """
    from urllib.request import urlretrieve

    start_message = f"Downloading video from {url}..."
    print(start_message)

    urlretrieve(url, output_path)

    done_message = f"Video saved to {output_path}"
    print(done_message)

    return output_path



In [None]:
# Install EasyOCR (OCR) and the Google GenAI SDK (Gemini client)
!pip install easyocr
!pip install -q google-genai
!pip install -U google-genai




In [None]:
import os
import base64
# @title Set Gemini API Key
# SECURITY: the key is read from the GOOGLE_API_KEY environment variable
# instead of being hard-coded, so it never ends up in the notebook source
# or its revision history. Set the variable (for example from your secret
# store, in a cell you do not share) before running this cell.
# Surrounding whitespace is stripped because pasted keys often carry a
# trailing newline.
GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY", "").strip()
if GEMINI_API_KEY:
    # Write the normalized value back so later reads see the same key.
    os.environ["GOOGLE_API_KEY"] = GEMINI_API_KEY
else:
    print("WARNING: GOOGLE_API_KEY is not set; Gemini calls need an API key.")

import google.genai as genai
from google.genai import types as genai_types


class GeminiVLMClient:
    """
    Simple wrapper so the pipeline can call:
        vlm_client.analyze(image_base64, prompt) -> text answer

    Uses the ``gemini-2.5-flash`` model by default; pass ``model_name``
    to select a different one.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "gemini-2.5-flash",
    ) -> None:
        """
        Parameters
        ----------
        api_key : str
            Gemini API key; must be non-empty.
        model_name : str
            Model identifier sent with every request.

        Raises
        ------
        ValueError
            If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("Gemini API key is required")

        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name

    def analyze(
        self,
        image: str,
        prompt: str,
        mime_type: str = "image/jpeg",
    ) -> str:
        """
        Send one image plus a text prompt to the model and return its answer.

        Parameters
        ----------
        image : str
            Base64-encoded image bytes (as produced in query_vlm_for_rotation).
        prompt : str
            Text instruction.
        mime_type : str
            MIME type declared for the decoded image bytes. Defaults to
            JPEG, which is what the original implementation always sent.

        Returns
        -------
        str
            The model's text response with surrounding whitespace removed,
            or ``str(response)`` when the response carries no text.
        """
        # Decode base64 to raw bytes
        img_bytes = base64.b64decode(image)

        # Build the content for the Gemini multimodal call: image part
        # first, then the text prompt.
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=[
                genai_types.Part.from_bytes(
                    data=img_bytes,
                    mime_type=mime_type
                ),
                prompt,
            ],
        )

        return self._extract_text(response)

    @staticmethod
    def _extract_text(response: object) -> str:
        """Return the stripped text of ``response``, else its ``str()`` form."""
        try:
            text = response.text
        except (AttributeError, ValueError):
            # No usable text accessor on this response object.
            text = None

        # A response without text parts can expose ``text`` as None; fall
        # back to the raw representation so the caller still gets a string.
        if isinstance(text, str):
            return text.strip()
        return str(response)


# Run Here

In [None]:
# Driver cell: loads the models, builds the OCR/VLM helpers and runs the
# pipeline on one video. It relies on names defined by earlier cells
# (setup_colab_environment, load_unet, GeminiVLMClient, main_pipeline,
# GEMINI_API_KEY, os); run those cells first, otherwise this cell fails
# with a NameError.

# 1. Setup environment (optional helper)
setup_colab_environment()

# 2. Load your models
from ultralytics import YOLO
import torch

# YOLO model (weights read from the mounted Google Drive)
yolo_model = YOLO("/content/drive/MyDrive/yolo/best.pt")

# U-Net model checkpoint path on the mounted Google Drive
MODEL_PATH = "/content/drive/MyDrive/Dig_Image/U-Net/U-Net Best.pt"

# Choose device: CUDA when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Instantiate the U-Net, load the checkpoint as a state dict and switch
# to inference mode.
# NOTE(review): load_unet(MODEL_PATH) is given the checkpoint path too, so
# it may already load the weights - confirm whether the explicit
# load_state_dict below is still needed.
# NOTE(review): the model is not moved to `device` in this cell; only the
# checkpoint tensors are mapped there - confirm against load_unet.
unet_model = load_unet(MODEL_PATH)
state_dict = torch.load(MODEL_PATH, map_location=device)
unet_model.load_state_dict(state_dict)
unet_model.eval()

# 3. Optional: OCR + VLM (both are forwarded to main_pipeline)
import easyocr
ocr_model = easyocr.Reader(['en'])   # English-only reader; swap in your own OCR if needed

vlm_client = GeminiVLMClient(
    api_key=os.environ.get("GOOGLE_API_KEY", GEMINI_API_KEY)  # env var first, GEMINI_API_KEY as fallback
)  # or your actual VLM client wrapper

# 4. Run pipeline
video_path = "/content/drive/MyDrive/Chess Detection Competition/test_videos/8_Move_student.mp4"
# Alternative test video:
# video_path = "/content/drive/MyDrive/Chess Detection Competition/test_videos/2_Move_rotate_student.mp4"

board_states, moves, pgn = main_pipeline(
    video_path=video_path,
    yolo_model=yolo_model,
    unet_model=unet_model,
    ocr_model=ocr_model,
    vlm_client=vlm_client,          # or None if you temporarily disable
    output_fps=1.0,                 # lower for faster processing
    visualize_steps=True,
    visualize_every_n_frames=5
)

print("Final PGN:")
print(pgn)


NameError: name 'setup_colab_environment' is not defined