In [1]:
import itertools
import os

import cv2
import numpy as np
from ultralytics import YOLO

### Load Trained YOLO Models

In [2]:
# Segmentation model whose masks outline the table's playing area.
table_model = YOLO("models/table_yolo_v11_20e.pt")
# Detector whose classes are listed in `ball_label_to_name` (cue ball, 8-ball, stripe, solid).
ball_model = YOLO("models/ball_best_v11_15e.pt")

In [None]:
def correct_frame_orientation(frame):
    """Return ``frame`` in portrait orientation.

    A frame that is strictly taller than it is wide is returned unchanged
    (the very same object). Any other frame (landscape or square) is rotated
    90 degrees clockwise.
    """
    frame_height, frame_width = frame.shape[:2]
    if frame_height <= frame_width:
        return cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
    return frame  # already portrait

def order_points_stable(pts):
    """Order quadrilateral corners as top-left, top-right, bottom-right, bottom-left.

    Works in image coordinates (y grows downward):
    - top-left has the smallest x + y; bottom-right has the largest;
    - top-right has the smallest y - x; bottom-left has the largest.

    Ties resolve to the first matching point, following ``np.argmin``/``np.argmax``.

    Args:
        pts: sequence of (x, y) points.

    Returns:
        float32 array of the four selected points in TL, TR, BR, BL order.
    """
    corners = np.array(pts, dtype="float32")
    coord_sum = corners.sum(axis=1)
    coord_diff = np.diff(corners, axis=1)  # y - x for each point

    top_left, bottom_right = corners[np.argmin(coord_sum)], corners[np.argmax(coord_sum)]
    top_right, bottom_left = corners[np.argmin(coord_diff)], corners[np.argmax(coord_diff)]

    return np.array([top_left, top_right, bottom_right, bottom_left], dtype="float32")

# Class id (index in this tuple) -> readable name used for on-screen labels.
ball_label_to_name = dict(enumerate(("cue_ball", "8_ball", "stripe", "solid")))

### Optimized mapping: project detected balls onto a top-down 2D table view

In [8]:
def _expand_mask(mask, width, height):
    """Binarize a float mask at 0.5 and dilate it by ~5% of the larger frame side."""
    mask_binary = (mask > 0.5).astype(np.uint8)
    kernel_size = max(3, int(0.05 * max(width, height)) | 1)  # odd, at least 3
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    return cv2.dilate(mask_binary, kernel, iterations=1)


def _locate_table(frame, dst_width, dst_height):
    """Segment the table in ``frame`` and fit a homography to a dst rectangle.

    Returns:
        ``None`` when this frame has no usable table outline (no mask, no
        contour, or fewer than 4 polygon corners). Otherwise a tuple
        ``(homography, mask_dilated)``:
        - ``homography`` is the ``cv2.findHomography`` result and may itself be ``None``;
        - ``mask_dilated`` is the dilated 0/1 table mask at frame resolution.
    """
    height, width = frame.shape[:2]
    # retina_masks=True returns masks at the original frame resolution. Without it,
    # masks are at the letterboxed model input size (e.g. 640x384), and stretching
    # them with cv2.resize ignores the letterbox padding, which shifts the outline.
    table_results = table_model(frame, retina_masks=True)[0]
    if not table_results.masks:
        return None

    mask = table_results.masks[0].data[0].cpu().numpy()
    mask_dilated = _expand_mask(cv2.resize(mask, (width, height)), width, height)

    contours, _ = cv2.findContours((mask_dilated * 255).astype(np.uint8),
                                   cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    contour = max(contours, key=cv2.contourArea)
    epsilon = 0.05 * cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, epsilon, True)
    if len(approx) < 4:
        return None

    src_pts = order_points_stable([pt[0] for pt in approx])
    dst_pts = np.array([[0, 0], [dst_width, 0], [dst_width, dst_height], [0, dst_height]],
                       dtype=np.float32)
    homography, _ = cv2.findHomography(src_pts, dst_pts)
    return homography, mask_dilated


def _map_balls(frame, mask_dilated, homography, dst_width, dst_height):
    """Detect balls in ``frame`` and project the on-table ones into the dst view.

    A detection is kept only if both of these hold:
    - its box centre lies on ``mask_dilated``;
    - its projected position falls inside the ``dst_width`` x ``dst_height`` rectangle.

    Returns:
        ``(mapped_positions, ball_labels)``: projected (x, y) points and the
        matching integer class ids.
    """
    height, width = frame.shape[:2]
    ball_results = ball_model(frame)[0]
    mapped_positions, ball_labels = [], []

    for box in ball_results.boxes:
        x_c, y_c = box.xywh[0][0].item(), box.xywh[0][1].item()
        # Clamp the centre so it can index the mask even if it lies outside the frame.
        x_int = int(np.clip(x_c, 0, width - 1))
        y_int = int(np.clip(y_c, 0, height - 1))
        if mask_dilated[y_int, x_int] < 0.5:
            continue

        pt = np.array([[[x_c, y_c]]], dtype=np.float32)
        dst = cv2.perspectiveTransform(pt, homography)[0][0]
        if 0 <= dst[0] <= dst_width and 0 <= dst[1] <= dst_height:
            mapped_positions.append(dst)
            ball_labels.append(int(box.cls[0].item()))

    return mapped_positions, ball_labels


def _draw_table_view(mapped_positions, ball_labels, dst_width, dst_height):
    """Render mapped balls as labelled white dots on a green dst_width x dst_height canvas."""
    table_view = np.full((dst_height, dst_width, 3), (26, 99, 15), dtype=np.uint8)  # BGR green
    for (x, y), label in zip(mapped_positions, ball_labels):
        center = (int(x), int(y))
        cv2.circle(table_view, center, 8, (255, 255, 255), -1)
        # .get() so an unexpected class id renders a fallback name instead of raising KeyError.
        name = ball_label_to_name.get(label, f"Ball {label}")
        cv2.putText(table_view, name, (center[0] - 10, center[1] - 12),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
    return table_view


def detect_table_and_map_balls_optimized(video_path, output_path='video_output/', max_frames=1200):
    """Write a side-by-side video: top-down 2D table map | original frame.

    The table homography is computed once, from the first frame where the
    table can be located. Frames before that are skipped and not written.
    Every later frame has its on-table ball detections projected into a
    400x800 top-down view.

    Args:
        video_path: input video path. The output is named ``out_<basename>``.
        output_path: directory the output video is written into.
        max_frames: upper bound on the number of frames read.

    Returns:
        None. Progress and status messages are printed.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        return

    output_path = os.path.join(output_path, "out_" + os.path.basename(video_path))

    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames_to_process = min(total_frames, max_frames)

    dst_width, dst_height = 400, 800
    homography = None
    mask_dilated = None
    frame_idx = 0
    out = None

    try:
        while frame_idx < frames_to_process:
            ret, frame = cap.read()
            if not ret:
                break

            frame = correct_frame_orientation(frame)
            height, width = frame.shape[:2]

            if out is None:
                # Output is 2D view + original side by side, hence double width.
                out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'),
                                      fps, (width * 2, height))
                if not out.isOpened():
                    print(f"Error: Could not open output video {output_path}")
                    return

            if homography is None:
                located = _locate_table(frame, dst_width, dst_height)
                if located is None:
                    frame_idx += 1
                    continue
                homography, mask_dilated = located
                if homography is None:
                    print("Failed to compute homography matrix.")
                    return
                print("✅ Table detected and homography set.")

            mapped_positions, ball_labels = _map_balls(
                frame, mask_dilated, homography, dst_width, dst_height)
            table_view = _draw_table_view(mapped_positions, ball_labels, dst_width, dst_height)

            # np.hstack already copies, so the frame itself is never modified.
            combined = np.hstack((cv2.resize(table_view, (width, height)), frame))
            cv2.putText(combined, "2D Table View", (50, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
            cv2.putText(combined, "Original Video", (width + 50, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)

            out.write(combined)
            frame_idx += 1
            print(f"{frame_idx}/{frames_to_process} : {(frame_idx / frames_to_process) * 100:.1f}%")
    finally:
        # Always release, including early returns, so the container is finalized.
        cap.release()
        if out is not None:
            out.release()

    print(f"✅ Output video saved to: {output_path}")


In [11]:
detect_table_and_map_balls_optimized(video_path="video_input/vid1.mp4")


0: 640x384 (no detections), 214.2ms
Speed: 3.5ms preprocess, 214.2ms inference, 0.9ms postprocess per image at shape (1, 3, 640, 384)

0: 640x384 1 table_play_area, 164.9ms
Speed: 2.3ms preprocess, 164.9ms inference, 2.0ms postprocess per image at shape (1, 3, 640, 384)
✅ Table detected and homography set.

0: 640x384 1 8_ball, 4 stripes, 7 solids, 150.2ms
Speed: 3.8ms preprocess, 150.2ms inference, 1.2ms postprocess per image at shape (1, 3, 640, 384)
2/253 : 0.8%

0: 640x384 7 8_balls, 2 stripes, 8 solids, 118.4ms
Speed: 2.1ms preprocess, 118.4ms inference, 0.8ms postprocess per image at shape (1, 3, 640, 384)
3/253 : 1.2%

0: 640x384 5 8_balls, 3 stripes, 7 solids, 114.4ms
Speed: 2.1ms preprocess, 114.4ms inference, 0.7ms postprocess per image at shape (1, 3, 640, 384)
4/253 : 1.6%

0: 640x384 7 8_balls, 4 stripes, 8 solids, 119.8ms
Speed: 2.4ms preprocess, 119.8ms inference, 0.9ms postprocess per image at shape (1, 3, 640, 384)
5/253 : 2.0%

0: 640x384 3 8_balls, 6 stripes, 8 soli

### Visualize table and ball predictions on video

In [None]:
def visualize_table_and_ball_predictions(video_path, output_path, max_frames=700):
    """Write a copy of ``video_path`` annotated with the raw model predictions.

    Each frame is first orientation-corrected with ``correct_frame_orientation``. Then:
    - the table segmentation outline is drawn in green;
    - each ball detection is drawn as a red box labelled with its class name
      and confidence;
    - the frame index is stamped in the bottom-left corner.

    Args:
        video_path: input video path.
        output_path: path of the annotated video to write (mp4v codec).
        max_frames: upper bound on the number of frames processed.

    Returns:
        None. Progress and status messages are printed.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"❌ Cannot open video: {video_path}")
        return

    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames_to_process = min(total_frames, max_frames)

    frame_idx = 0
    out = None  # created lazily: the size is only known after orientation correction

    try:
        while frame_idx < frames_to_process:
            ret, frame = cap.read()
            if not ret:
                break

            frame = correct_frame_orientation(frame)
            height, width = frame.shape[:2]

            if out is None:
                out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'),
                                      fps, (width, height))
                if not out.isOpened():
                    print(f"❌ Cannot open output video: {output_path}")
                    return

            # Run both models on the clean frame BEFORE drawing anything, so the
            # ball detector never sees the table-outline overlay.
            # retina_masks=True keeps masks at frame resolution (no letterbox padding).
            table_results = table_model(frame, retina_masks=True)[0]
            ball_results = ball_model(frame)[0]

            # 1. Table outline
            if table_results.masks:
                mask = table_results.masks[0].data[0].cpu().numpy()
                mask_resized = cv2.resize(mask, (width, height))
                contours, _ = cv2.findContours((mask_resized * 255).astype(np.uint8),
                                               cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                if contours:
                    cv2.drawContours(frame, contours, -1, (0, 255, 0), 2)

            # 2. Ball boxes
            for box in ball_results.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                cls_id = int(box.cls[0].item())
                conf = float(box.conf[0])
                label = ball_label_to_name.get(cls_id, f"Ball {cls_id}")
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
                # Clamp so labels of boxes touching the top edge stay visible.
                cv2.putText(frame, f"{label} ({conf:.2f})", (x1, max(y1 - 8, 12)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

            # 3. Frame index
            cv2.putText(frame, f"Frame: {frame_idx}", (10, height - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)

            out.write(frame)
            frame_idx += 1
            print(f"{frame_idx}/{frames_to_process} : {(frame_idx / frames_to_process) * 100:.1f}%")
    finally:
        # Release on every exit path so the output file is finalized.
        cap.release()
        if out is not None:
            out.release()

    print(f"✅ Prediction video saved: {output_path}")


In [14]:
visualize_table_and_ball_predictions(
    video_path="video_input/tes.mp4",
    output_path="video_output/tes_with_preds.mp4",
)


0: 640x384 1 table_play_area, 187.8ms
Speed: 5.0ms preprocess, 187.8ms inference, 2.1ms postprocess per image at shape (1, 3, 640, 384)

0: 640x384 5 8_balls, 1 stripe, 9 solids, 115.9ms
Speed: 2.6ms preprocess, 115.9ms inference, 1.3ms postprocess per image at shape (1, 3, 640, 384)
1/700 : 0.1%

0: 640x384 1 table_play_area, 192.8ms
Speed: 2.7ms preprocess, 192.8ms inference, 1.7ms postprocess per image at shape (1, 3, 640, 384)

0: 640x384 3 8_balls, 5 stripes, 8 solids, 102.0ms
Speed: 2.6ms preprocess, 102.0ms inference, 0.8ms postprocess per image at shape (1, 3, 640, 384)
2/700 : 0.3%

0: 640x384 1 table_play_area, 140.6ms
Speed: 2.4ms preprocess, 140.6ms inference, 2.0ms postprocess per image at shape (1, 3, 640, 384)

0: 640x384 1 8_ball, 3 stripes, 10 solids, 105.3ms
Speed: 2.2ms preprocess, 105.3ms inference, 0.8ms postprocess per image at shape (1, 3, 640, 384)
3/700 : 0.4%

0: 640x384 1 table_play_area, 188.2ms
Speed: 2.6ms preprocess, 188.2ms inference, 2.5ms postprocess 