In [None]:
import cv2
import itertools
import numpy as np
from ultralytics import YOLO

### Load Trained YOLO Models

In [2]:
# ball_model = YOLO("models/ball_best_v8_30e.pt")
table_model = YOLO("models/table_yolo_v11_20e.pt")
# table_model = YOLO("models/table_yolo_v11_10e.pt")
ball_model = YOLO("models/ball_best_v11_15e.pt")

In [4]:
def correct_frame_orientation(frame):
    h, w = frame.shape[:2]
    if h > w:
        return frame  # already portrait
    else:
        return cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
    
def order_points_from_short_sides_adaptive(pts):
    """
    Hybrid version:
    - If the table is clearly vertical (height >> width), use top/bottom quadrant logic.
    - Else, use the robust short-side vector sorting.
    Returns corners in TL, TR, BR, BL order.
    """
    pts = np.array(pts, dtype=np.float32)

    # Step 1: Compute bounding box
    min_x, max_x = np.min(pts[:, 0]), np.max(pts[:, 0])
    min_y, max_y = np.min(pts[:, 1]), np.max(pts[:, 1])
    width = max_x - min_x
    height = max_y - min_y

    if height > 1.2 * width:
        # ⚠️ Table is vertical
        center_y = (min_y + max_y) / 2
        upper = []
        lower = []
        for pt in pts:
            if pt[1] < center_y:
                upper.append(pt)
            else:
                lower.append(pt)

        if len(upper) != 2 or len(lower) != 2:
            raise ValueError("Unexpected corner distribution in vertical mode.")

        upper = sorted(upper, key=lambda p: p[0])  # left to right
        lower = sorted(lower, key=lambda p: p[0])  # left to right

        tl, tr = upper
        bl, br = lower
        return np.array([tl, tr, br, bl], dtype=np.float32)

    else:
        # ✅ Use the reliable short-side vector-based ordering
        dists = []
        for i, j in itertools.combinations(range(4), 2):
            d = np.linalg.norm(pts[i] - pts[j])
            dists.append((d, i, j))
        dists.sort()
        (_, i1, j1), (_, i2, j2) = dists[:2]

        vec1 = pts[j1] - pts[i1]
        vec2 = pts[j2] - pts[i2]
        if vec1[0] > 0: vec1 = -vec1
        if vec2[0] > 0: vec2 = -vec2
        x1, x2 = vec1[0], vec2[0]

        if x1 < x2:
            left_pair = [pts[i1], pts[j1]]
            right_pair = [pts[i2], pts[j2]]
        else:
            left_pair = [pts[i2], pts[j2]]
            right_pair = [pts[i1], pts[j1]]

        left_pair = sorted(left_pair, key=lambda p: p[1])
        right_pair = sorted(right_pair, key=lambda p: p[1])

        tl, bl = left_pair
        tr, br = right_pair
        return np.array([tl, tr, br, bl], dtype=np.float32)
    
def order_points_from_short_sides(pts):
    """
    Orders 4 corner points of a rectangle into: [top-left, top-right, bottom-right, bottom-left].
    """
    pts = np.asarray(pts, dtype=np.float32)
    if pts.shape != (4, 2):
        raise ValueError("pts must be shape (4, 2)")

    # Compute all pairwise distances, store pairs
    pairs = list(itertools.combinations(range(4), 2))
    dists = np.linalg.norm(pts[pairs][:, 0] - pts[pairs][:, 1], axis=1)
    idx = np.argsort(dists)[:2]
    short_pairs = [pairs[i] for i in idx]

    # Gather unique indices of short side points (should be 3)
    short_idxs = list(set(i for pair in short_pairs for i in pair))
    if len(short_idxs) != 3:
        raise ValueError("Could not determine two distinct short sides.")

    # Count occurrences to find the shared (corner) index
    counts = {i: sum(i in pair for pair in short_pairs) for i in short_idxs}
    shared = [i for i, c in counts.items() if c == 2][0]
    neighbors = [i for i in short_idxs if i != shared]

    v1, v2 = pts[neighbors[0]] - pts[shared], pts[neighbors[1]] - pts[shared]
    cross = np.cross(v1, v2)
    if cross > 0:
        tl, tr, bl = pts[shared], pts[neighbors[0]], pts[neighbors[1]]
    else:
        tl, tr, bl = pts[shared], pts[neighbors[1]], pts[neighbors[0]]

    br_idx = next(i for i in range(4) if i not in [shared] + neighbors)
    br = pts[br_idx]

    return np.array([tl, tr, br, bl], dtype=np.float32)

def order_points_clockwise(pts):
    pts = np.array(pts, dtype="float32")
    center = np.mean(pts, axis=0)

    def angle(p):
        return np.arctan2(p[1] - center[1], p[0] - center[0])

    pts_sorted = sorted(pts, key=angle)
    return np.array(pts_sorted, dtype="float32")

def order_points_stable(pts):
    pts = np.array(pts, dtype="float32")
    s = pts.sum(axis=1)
    diff = np.diff(pts, axis=1)

    return np.array([
        pts[np.argmin(s)],      # top-left
        pts[np.argmin(diff)],   # top-right
        pts[np.argmax(s)],      # bottom-right
        pts[np.argmax(diff)]    # bottom-left
    ], dtype="float32")



### Class list

In [5]:
ball_label_to_name = {
    0: "cue_ball",
    1: "8_ball",
    2: "stripe",
    3: "solid"
}

### Mapping a Video (New Method)

In [15]:
def detect_table_and_map_balls_from_video(video_path, output_path='video_output/', max_frames=1000):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        return

    output_path +=  "out_" + video_path.split("/")[-1]

    # Video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    width  = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames_to_process = min(total_frames, max_frames)

    # Set output video writer
    combined_width = width * 2
    combined_height = height
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (combined_width, combined_height))

    homography, dst_width, dst_height, mask_resized = None, None, None, None
    dst_width = 400
    dst_height = 800
    table_set = False
    frame_idx = 0

    while frame_idx < frames_to_process:
            ret, frame = cap.read()
            if not ret:
                break
            frame = correct_frame_orientation(frame)

            # Update width/height AFTER orientation fix
            height, width = frame.shape[:2]

            # Re-initialize VideoWriter if first frame
            if frame_idx == 0:
                combined_width = width * 2
                combined_height = height
                out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (combined_width, combined_height))
       
            if frame_idx % 15 == 0 or not table_set:
                table_results = table_model(frame)[0]
                if table_results.masks:
                    mask = table_results.masks[0].data[0].cpu().numpy()
                    mask_resized = cv2.resize(mask, (width, height))
                    #========EXPANSION=============
                    # Binarize the mask
                    mask_binary = (mask_resized > 0.5).astype(np.uint8)

                    # --- ADD DILATION HERE (5% expansion) ---
                    # Calculate kernel size as 5% of the larger mask dimension
                    kernel_size = int(0.05 * max(width, height))
                    kernel_size = max(3, kernel_size | 1)  # Ensure kernel size is odd and >= 3

                    # Create dilation kernel (elliptical is more natural for blobs)
                    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

                    # Dilate the binary mask
                    mask_dilated = cv2.dilate(mask_binary, kernel, iterations=1)
                    #========EXPANSION=============

                    contours, _ = cv2.findContours((mask_dilated * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                    if contours:
                        contour = max(contours, key=cv2.contourArea)
                        epsilon = 0.05 * cv2.arcLength(contour, True)
                        approx = cv2.approxPolyDP(contour, epsilon, True)
                        
                        src_pts = order_points_stable([pt[0] for pt in approx])
                        dst_pts = np.array([
                        [0, 0],         # top-left
                        [400, 0],       # top-right
                        [400, 800],     # bottom-right
                        [0, 800]        # bottom-left
                        ], dtype=np.float32)

                        homography, _ = cv2.findHomography(src_pts, dst_pts)
                        table_set = True

            # ========== 1. Detect TABLE once ==========
            # Assume models accept NumPy arrays now
            table_results = table_model(frame)[0]
            if not table_results.masks:
                frame_idx += 1
                continue

            mask = table_results.masks[0].data[0].cpu().numpy()
            mask_resized = cv2.resize(mask, (width, height))

            #========EXPANSION=============
            # Binarize the mask
            mask_binary = (mask_resized > 0.5).astype(np.uint8)

            # --- ADD DILATION HERE (5% expansion) ---
            # Calculate kernel size as 5% of the larger mask dimension
            kernel_size = int(0.05 * max(width, height))
            kernel_size = max(3, kernel_size | 1)  # Ensure kernel size is odd and >= 3

            # Create dilation kernel (elliptical is more natural for blobs)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

            # Dilate the binary mask
            mask_dilated = cv2.dilate(mask_binary, kernel, iterations=1)
            #========EXPANSION=============

            contours, _ = cv2.findContours((mask_dilated * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            
            contour = max(contours, key=cv2.contourArea)
            epsilon = 0.05 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)

            src_pts = order_points_stable([pt[0] for pt in approx])
            dst_pts = np.array([
                [0, 0],         # top-left
                [400, 0],       # top-right
                [400, 800],     # bottom-right
                [0, 800]        # bottom-left
                ], dtype=np.float32)
            homography, _ = cv2.findHomography(src_pts, dst_pts)

            if homography is None:
                print("Failed to compute homography matrix.")
                return

            table_set = True
            print("✅ Table detected and homography set.")

            # ========== 2. Detect BALLS ==========
            ball_results = ball_model(frame)[0]
            mapped_positions, ball_labels = [], []

            for box in ball_results.boxes:
                x_c = box.xywh[0][0].item()
                y_c = box.xywh[0][1].item()
                x_int = int(np.clip(x_c, 0, width - 1))
                y_int = int(np.clip(y_c, 0, height - 1))

                if mask_dilated[y_int, x_int] < 0.5:
                    continue

                pt = np.array([[[x_c, y_c]]], dtype=np.float32)
                dst = cv2.perspectiveTransform(pt, homography)[0][0]

                if 0 <= dst[0] <= dst_width and 0 <= dst[1] <= dst_height:
                    mapped_positions.append(dst)
                    ball_labels.append(int(box.cls[0].item()))
                # else:
                #     print(f"⚠️ Mapped out-of-bound: ({dst[0]:.1f}, {dst[1]:.1f})")

            # ========== 3. Draw 2D Top-Down View ==========
            table_view = np.zeros((int(dst_height), int(dst_width), 3), dtype=np.uint8)
            table_view[:] = (26, 99, 15) # dark green

            for (x, y), label in zip(mapped_positions, ball_labels):
                cv2.circle(table_view, (int(x), int(y)), 8, (255, 255, 255), -1)
                cv2.putText(table_view, ball_label_to_name[label], (int(x) - 10, int(y) - 12),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

            table_view_resized = cv2.resize(table_view, (frame.shape[1], frame.shape[0]))
            combined = np.hstack((table_view_resized, frame.copy()))

            # ========== 4. Combine Views with Titles ==========
            cv2.putText(combined, "2D Table View", (50, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
            cv2.putText(combined, "Original Video", (width + 50, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)

            out.write(combined)
            frame_idx += 1
            print(f"{frame_idx}/{frames_to_process} : {(frame_idx/frames_to_process) *100}%")

    cap.release()
    out.release()
    print(f"✅ Output video saved to: {output_path}")


In [16]:
detect_table_and_map_balls_from_video("video_input/vid1.mp4")


0: 640x384 (no detections), 148.6ms
Speed: 3.1ms preprocess, 148.6ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 384)

0: 640x384 (no detections), 143.4ms
Speed: 2.0ms preprocess, 143.4ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 384)

0: 640x384 1 table_play_area, 140.3ms
Speed: 2.0ms preprocess, 140.3ms inference, 1.7ms postprocess per image at shape (1, 3, 640, 384)

0: 640x384 1 table_play_area, 163.4ms
Speed: 4.3ms preprocess, 163.4ms inference, 3.5ms postprocess per image at shape (1, 3, 640, 384)
✅ Table detected and homography set.

0: 640x384 1 8_ball, 4 stripes, 7 solids, 99.3ms
Speed: 2.0ms preprocess, 99.3ms inference, 1.3ms postprocess per image at shape (1, 3, 640, 384)
2/253 : 0.7905138339920948%

0: 640x384 2 table_play_areas, 133.7ms
Speed: 2.1ms preprocess, 133.7ms inference, 2.7ms postprocess per image at shape (1, 3, 640, 384)
✅ Table detected and homography set.

0: 640x384 7 8_balls, 2 stripes, 8 solids, 123.1ms
Speed: 2.3ms prep

### Show predictions on Video

In [None]:
def visualize_table_and_ball_predictions(video_path, output_path, max_frames=700):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"❌ Cannot open video: {video_path}")
        return

    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames_to_process = min(total_frames, max_frames)

    frame_idx = 0
    out = None  # Will be initialized after first frame

    while frame_idx < frames_to_process:
        ret, frame = cap.read()
        if not ret:
            break

        frame = correct_frame_orientation(frame)

        # Initialize VideoWriter AFTER getting true dimensions
        if out is None:
            height, width = frame.shape[:2]
            combined_width = width  # Only using single view here
            combined_height = height
            out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'),
                                  fps, (combined_width, combined_height))

        # 1. Predict table
        table_results = table_model(frame)[0]
        if table_results.masks:
            mask = table_results.masks[0].data[0].cpu().numpy()
            mask_resized = cv2.resize(mask, (width, height))
            contours, _ = cv2.findContours((mask_resized * 255).astype(np.uint8),
                                           cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                cv2.drawContours(frame, contours, -1, (0, 255, 0), 2)

        # 2. Predict balls
        ball_results = ball_model(frame)[0]
        for box in ball_results.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            cls_id = int(box.cls[0].item())
            conf = float(box.conf[0])
            label = ball_label_to_name.get(cls_id, f"Ball {cls_id}")
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
            cv2.putText(frame, f"{label} ({conf:.2f})", (x1, y1 - 8),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        # 3. Frame index
        cv2.putText(frame, f"Frame: {frame_idx}", (10, frame.shape[0] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)

        out.write(frame)
        frame_idx += 1
        print(f"{frame_idx}/{frames_to_process} : {(frame_idx / frames_to_process) * 100:.1f}%")

    cap.release()
    if out is not None:
        out.release()

    print(f"✅ Prediction video saved: {output_path}")


In [14]:
visualize_table_and_ball_predictions("video_input/tes.mp4", "video_output/tes_with_preds.mp4")


0: 640x384 1 table_play_area, 187.8ms
Speed: 5.0ms preprocess, 187.8ms inference, 2.1ms postprocess per image at shape (1, 3, 640, 384)

0: 640x384 5 8_balls, 1 stripe, 9 solids, 115.9ms
Speed: 2.6ms preprocess, 115.9ms inference, 1.3ms postprocess per image at shape (1, 3, 640, 384)
1/700 : 0.1%

0: 640x384 1 table_play_area, 192.8ms
Speed: 2.7ms preprocess, 192.8ms inference, 1.7ms postprocess per image at shape (1, 3, 640, 384)

0: 640x384 3 8_balls, 5 stripes, 8 solids, 102.0ms
Speed: 2.6ms preprocess, 102.0ms inference, 0.8ms postprocess per image at shape (1, 3, 640, 384)
2/700 : 0.3%

0: 640x384 1 table_play_area, 140.6ms
Speed: 2.4ms preprocess, 140.6ms inference, 2.0ms postprocess per image at shape (1, 3, 640, 384)

0: 640x384 1 8_ball, 3 stripes, 10 solids, 105.3ms
Speed: 2.2ms preprocess, 105.3ms inference, 0.8ms postprocess per image at shape (1, 3, 640, 384)
3/700 : 0.4%

0: 640x384 1 table_play_area, 188.2ms
Speed: 2.6ms preprocess, 188.2ms inference, 2.5ms postprocess 

### Mapping a Video (deprecated)

In [None]:
def detect_table_and_map_balls_video(input_video_path, output_video_path, max_frames=1000):
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        print(f"Error: Cannot open video {input_video_path}")
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames_to_process = min(total_frames, max_frames)

    out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width * 2, height))

    frame_idx = 0
    while frame_idx < frames_to_process:
        ret, frame = cap.read()
        if not ret:
            break

        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # --- 1. Table detection ---
        table_results = table_model(frame)[0]
        if not table_results.masks:
            out.write(np.hstack((frame, frame)))
            frame_idx += 1
            continue

        mask = table_results.masks[0].data[0].cpu().numpy()
        mask_resized = cv2.resize(mask, (width, height))

        contours, _ = cv2.findContours((mask * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contour = max(contours, key=cv2.contourArea)
        epsilon = 0.05 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)

        if len(approx) != 4:
            out.write(np.hstack((frame, frame)))
            frame_idx += 1
            continue

        src_pts_raw = np.array([pt[0] for pt in approx], dtype=np.float32)
        src_pts = order_points_stable(src_pts_raw)
        
        dst_pts, dst_width, dst_height = compute_dst_pts_strict(src_pts)
        H, _ = cv2.findHomography(src_pts, dst_pts)

        # --- 2. Ball detection ---
        ball_results = ball_model(frame)[0]
        mapped_positions = []
        ball_labels = []

        for box in ball_results.boxes:
            x_c = box.xywh[0][0].item()
            y_c = box.xywh[0][1].item()
            x_int = int(np.clip(x_c, 0, width - 1))
            y_int = int(np.clip(y_c, 0, height - 1))

            if mask_resized[y_int, x_int] < 0.5:
                continue

            pt = np.array([[[x_c, y_c]]], dtype=np.float32)
            dst = cv2.perspectiveTransform(pt, H)[0][0]
            if 0 <= dst[0] <= dst_width and 0 <= dst[1] <= dst_height:
                mapped_positions.append(dst)
                ball_labels.append(int(box.cls[0].item()))

        # --- 3. Visualization ---
        # A. Original view
        vis_left = img_rgb.copy()
        for box in ball_results.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            cls = int(box.cls[0].item())
            conf = float(box.conf[0])
            cv2.rectangle(vis_left, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.putText(vis_left, f"{cls} ({conf:.2f})", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)

        # B. 2D top-down view
        vis_right = np.ones((dst_height, dst_width, 3), dtype=np.uint8) * 30
        for (x, y), label in zip(mapped_positions, ball_labels):
            y_flipped = dst_height - int(y)
            cv2.circle(vis_right, (int(x), y_flipped), 10, (255, 255, 255), -1)
            cv2.putText(vis_right, str(label), (int(x), y_flipped - 15), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        vis_left = cv2.cvtColor(vis_left, cv2.COLOR_RGB2BGR)
        vis_right = cv2.resize(vis_right, (width, height))
        combined = np.hstack((vis_left, vis_right))

        out.write(combined)
        frame_idx += 1
        print(f"{frame_idx}/{frames_to_process} : {(frame_idx/frames_to_process) *100}%")

    cap.release()
    out.release()
    print(f"✅ Video saved to {output_video_path}")


In [None]:
detect_table_and_map_balls_video("video_input/vid6.mp4", "video_output/vid6_out.mp4")