In [100]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output


In [33]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import os
from multiprocessing import Pool
from loky import get_reusable_executor

# Load templates from specified folders
def load_templates(template_folder, suffix):
    templates = []
    for filename in os.listdir(template_folder):
        if filename.endswith(suffix):
            path = os.path.join(template_folder, filename)
            template = cv2.imread(path, 0)
            templates.append(template)
    return templates

def draw_rectangles(frame, matched_points):
    for (x, y, w, h) in matched_points:
        cv2.rectangle(frame, (x - w // 2, y - h // 2), (x + w // 2, y + h // 2), (255, 0, 0), 2)
    return frame

# Function to get the initial socket positions from the first frame
def get_initial_socket_positions(frame, pocket_template_folder, suffix='.png'):
    pocket_templates = load_templates(pocket_template_folder, suffix)
    pocket_positions = match_templates(frame, pocket_templates, 0.85)
    return pocket_positions

def detect_collision(white_ball_position, other_balls_positions, ball_radius):
    collisions = []
    for index, pos in enumerate(other_balls_positions):
        distance = np.linalg.norm(np.array(white_ball_position[:2]) - np.array(pos[:2]))
        if distance <= ball_radius:  # Assuming all balls, including the white ball, have the same radius
            collisions.append(index)
    return collisions

# Match templates
def match_templates(frame, templates, threshold=0.8):
    gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    matched_points = []
    for template in templates:
        w, h = template.shape[::-1]
        res = cv2.matchTemplate(gray_frame, template, cv2.TM_CCOEFF_NORMED)
        loc = np.where(res >= threshold)
        for pt in zip(*loc[::-1]):  # Reverse tuple to get (x, y) coordinates
            center_point = (pt[0] + w // 2, pt[1] + h // 2, w, h)
            matched_points.append(center_point)
    return matched_points

def get_socket_positions_from_first_frame(frames, pocket_template_folder):
    if not frames:
        return None, []  # Return None if there are no frames
    
    first_frame = frames[0]
    pocket_templates = load_templates(pocket_template_folder, '.png')
    pocket_positions = match_templates(first_frame, pocket_templates, 0.85)
    return first_frame, pocket_positions

def calculate_middle_sockets(pocket_positions):
    # Assuming pocket_positions have been sorted by their x-coordinate or are in a fixed known order
    leftmost, rightmost = pocket_positions[0], pocket_positions[-1]
    middle_top = ((leftmost[0] + rightmost[0]) // 2, min(leftmost[1], rightmost[1]))
    middle_bottom = ((leftmost[0] + rightmost[0]) // 2, max(leftmost[1], rightmost[1]))
    return middle_top, middle_bottom


def read_video(video_path):
    cap = cv2.VideoCapture(video_path)
    frames = []
    i = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if i < 100:
            i += 1
            continue
        elif i > 150:
            break
        frames.append(frame)
        i += 1
        # print(len(frames))
    cap.release()
    return frames

def draw_text_with_background(image, text, font_scale=0.8, thickness=2, max_width=1000):
    font = cv2.FONT_HERSHEY_COMPLEX
    text_color = (255, 255, 255)
    bg_color = (0, 0, 255)
    # Break the text into multiple lines if it's too long
    words = text.split()
    lines = ['']
    for word in words:
        if cv2.getTextSize(lines[-1] + word, font, font_scale, thickness)[0][0] > max_width:
            lines.append(word)
        else:
            lines[-1] += ' ' + word

    y_offset = 30  # Start slightly lower to avoid being too close to the top edge
    for line in lines:
        (line_width, line_height), baseline = cv2.getTextSize(line, font, font_scale, thickness)
        cv2.rectangle(image, (10, y_offset - line_height), (10 + line_width, y_offset + baseline), bg_color, cv2.FILLED)
        cv2.putText(image, line, (10, y_offset), font, font_scale, text_color, thickness, cv2.LINE_AA)
        y_offset += line_height + baseline

    return image

def process_frame(frame_args):
    frame, i, ball_template_folder, socket_positions = frame_args
    print(f"Processing frame {i}")
    
    ball_templates = load_templates(ball_template_folder, '.png')
    
    # Detect balls in the frame
    ball_positions = match_templates(frame, ball_templates, 0.9)
    print(f"Detected ball positions: {ball_positions}")
    
    # Draw balls and sockets on the frame
    frame_with_balls = draw_rectangles(frame, ball_positions)
    frame_with_pockets = draw_rectangles(frame_with_balls, socket_positions)
    
    # Detect collisions between the white ball and other balls
    white_ball_position = ball_positions[0]
    other_balls_positions = ball_positions[1:]  # All other balls
    collisions = detect_collision(white_ball_position, other_balls_positions, ball_radius)
    print(f"Collisions: {collisions}")

    # Draw middle sockets if there are at least two sockets detected
    if len(socket_positions) >= 2:
        middle_top, middle_bottom = calculate_middle_sockets(socket_positions)
        cv2.circle(frame_with_pockets, middle_top, 10, (0, 255, 255), -1)
        cv2.circle(frame_with_pockets, middle_bottom, 10, (0, 255, 255), -1)

    # Check if any ball is in the pocket
    message = None
    for ball_pos in ball_positions:
        for pocket_pos in socket_positions:
            x, y, w, h = pocket_pos
            pocket_center = (x, y)
            distance = np.linalg.norm(np.array(ball_pos[:2]) - np.array(pocket_center))
            if distance < 100:  # Proximity threshold
                message = f"Ball in the socket at position: {ball_pos[:2]} in frame {i}"
                cv2.putText(frame_with_pockets, message, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
                break  # Stop after the first detection
        if message:
            break  # Stop if a message has been set

    # Draw collisions between balls
    for collision_index in collisions:
        x, y, w, h = ball_positions[collision_index]
        cv2.rectangle(frame_with_balls, (x - w // 2, y - h // 2), (x + w // 2, y + h // 2), (0, 0, 255), 2)
        cv2.putText(frame_with_balls, "Collision Detected", (x - 50, y - 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)

    # Convert the frame to RGB for display
    img = cv2.cvtColor(frame_with_balls, cv2.COLOR_BGR2RGB)
    
    # Save the frame as an image file
    cv2.imwrite(f"results/frame_{i}.jpg", img)

    return None  # Return None after processing the frame

ball_radius = 11
ball_template_folder = 'template'
pocket_template_folder = 'sockets'
video_path = 'bilard_1.mp4'
frames = read_video(video_path)
# Get the first frame and detect socket positions
initial_frame, initial_socket_positions = get_socket_positions_from_first_frame(frames, pocket_template_folder)

if initial_frame is not None:
    executor = get_reusable_executor(max_workers=16, timeout=2)
    frame_args = [(frame, i, ball_template_folder, initial_socket_positions) for i, frame in enumerate(frames)]
    results = executor.map(process_frame, frame_args)
    results_list = list(results)  # Collecting results
else:
    print("No frames available to process.")

Processing frame 0
Processing frame 1
Processing frame 4
Processing frame 5
Processing frame 6
Processing frame 3
Processing frame 2
Processing frame 11
Processing frame 9
Processing frame 8
Processing frame 10
Processing frame 7
Processing frame 12
Processing frame 15
Processing frame 14
Processing frame 13
Detected ball positions: [(1291, 395, 74, 58), (1292, 395, 74, 58), (1289, 396, 74, 58), (1290, 396, 74, 58), (1291, 396, 74, 58), (1292, 396, 74, 58), (1293, 396, 74, 58), (1294, 396, 74, 58), (1289, 397, 74, 58), (1290, 397, 74, 58), (1291, 397, 74, 58), (1292, 397, 74, 58), (1293, 397, 74, 58), (1294, 397, 74, 58), (1288, 398, 74, 58), (1289, 398, 74, 58), (1290, 398, 74, 58), (1291, 398, 74, 58), (1292, 398, 74, 58), (1293, 398, 74, 58), (1294, 398, 74, 58), (1288, 399, 74, 58), (1289, 399, 74, 58), (1290, 399, 74, 58), (1291, 399, 74, 58), (1292, 399, 74, 58), (1293, 399, 74, 58), (1294, 399, 74, 58), (1288, 400, 74, 58), (1289, 400, 74, 58), (1290, 400, 74, 58), (1291, 400, 7

In [8]:
import cv2
import os
import glob

def create_video_from_images(image_folder, output_video, frame_rate=30):
    image_files = glob.glob(os.path.join(image_folder, '*.jpg'))
    image_files = sorted(image_files, key=lambda x: int(x.split("_")[1].split(".")[0]))

    first_image = cv2.imread(image_files[0])
    height, width, layers = first_image.shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video, fourcc, frame_rate, (width, height))

    for filename in image_files:
        img = cv2.imread(filename)
        out.write(img)

    out.release()
    cv2.destroyAllWindows()

create_video_from_images('results', 'movie.mp4')