# Libraries

In [None]:
!pip install ultralytics
!pip install dill

In [None]:
# Libraries for image processing, machine learning, and utilities

import cv2

import numpy as np

import math

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision import models

from scipy.spatial import distance

from ultralytics import YOLO

# Video Utils

These functions read_video and write_video are designed to read frames from a video file and write a list of images (frames) back to a video file, respectively.

In [None]:
def read_video(path_video):
    """Read every frame of a video file into memory.

    Args:
        path_video: Path of the video, passed to ``cv2.VideoCapture``.

    Returns:
        Tuple ``(frames, fps, original_width, original_height)``. ``frames``
        is the list of frames in reading order; the other three values are
        the capture's FPS / frame-width / frame-height properties converted
        with ``int()`` (so a fractional FPS is truncated).

    Raises:
        IOError: If the capture cannot be opened. Previously this case
            silently produced ``([], 0, 0, 0)``.
    """
    cap = cv2.VideoCapture(path_video)
    if not cap.isOpened():
        cap.release()
        raise IOError(f'Could not open video: {path_video}')

    try:
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                # read() reports failure once no further frame is available.
                break
            frames.append(frame)
    finally:
        # Release the capture even if reading raises.
        cap.release()

    return frames, fps, original_width, original_height

def write_video(imgs, fps, path_out='output.mp4'):
    """Write a sequence of frames to a video file using the 'mp4v' codec.

    Args:
        imgs: Non-empty sequence of frames. The output size is taken from the
            three-dimensional shape of the first frame.
        fps: Frame rate passed to ``cv2.VideoWriter``.
        path_out: Destination file path.

    Raises:
        ValueError: If ``imgs`` is empty (there is no frame to take the
            output size from).
    """
    if len(imgs) == 0:
        raise ValueError('write_video needs at least one frame to write')

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    h, w, _ = imgs[0].shape
    out = cv2.VideoWriter(path_out, fourcc, fps, (w, h))
    try:
        for img in imgs:
            out.write(img)
    finally:
        # Always release the writer so the file handle is not leaked.
        out.release()

# Ball Detection

In this part, I'm implementing a ball detection and tracking system for video analysis. The core of the detection model is based on TrackNet, an open-source repository specializing in object tracking. The model architecture and initial implementation are adapted from the TrackNet GitHub repository by [yastrebksv](https://github.com/yastrebksv/TrackNet), with modifications and adjustments made to suit specific project requirements.

The following code defines a convolutional block (`ConvBlock`) as a PyTorch `nn.Module`. It bundles a 2D convolution, a ReLU activation and batch normalization into a single reusable unit.

In [None]:
class ConvBlock(nn.Module):
    """Conv2d -> ReLU -> BatchNorm2d, wrapped as a single module.

    The three layers live in ``self.block`` (an ``nn.Sequential``) in that
    order.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pad=1, stride=1, bias=True):
        """Build the three layers.

        Args:
            in_channels: Channels of the input tensor.
            out_channels: Channels produced by the convolution.
            kernel_size: Convolution kernel size.
            pad: Padding passed to the convolution.
            stride: Convolution stride.
            bias: Whether the convolution has a bias term.
        """
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=pad,
            bias=bias,
        )
        activation = nn.ReLU()
        normalisation = nn.BatchNorm2d(out_channels)
        self.block = nn.Sequential(convolution, activation, normalisation)

    def forward(self, x):
        """Apply convolution, activation and normalisation to ``x``."""
        result = self.block(x)
        return result

This BallTrackerNet class defines a convolutional neural network (CNN) architecture for ball tracking.

The network consists of:

* Encoder: Convolutional layers followed by max pooling to downsample the input image.
* Bottleneck: Further convolutional layers to capture complex features.
* Decoder: Upsampling layers followed by convolutional layers to reconstruct the spatial dimensions.
* Output: Final convolutional layer to produce the desired number of output channels.


In [None]:
import torch.nn as nn

class BallTrackerNet(nn.Module):
    """Encoder/decoder CNN built from ``ConvBlock`` units.

    Three 2x max-pooling stages reduce the spatial size, three 2x upsampling
    stages restore it, and a final ``ConvBlock`` maps to ``out_channels``.
    Layer attribute names (``conv1`` .. ``conv18``, ``pool1`` .. ``pool3``,
    ``ups1`` .. ``ups3``) are part of the checkpoint format.
    """

    # Order in which the layers are applied by forward(); it is also the
    # order in which __init__ creates them.
    _FORWARD_ORDER = (
        'conv1', 'conv2', 'pool1',
        'conv3', 'conv4', 'pool2',
        'conv5', 'conv6', 'conv7', 'pool3',
        'conv8', 'conv9', 'conv10',
        'ups1', 'conv11', 'conv12', 'conv13',
        'ups2', 'conv14', 'conv15',
        'ups3', 'conv16', 'conv17',
        'conv18',
    )

    def __init__(self, input_channels=3, out_channels=14):
        """Create all layers and initialise their weights.

        Args:
            input_channels: Channels of the input tensor.
            out_channels: Channels of the output tensor.
        """
        super().__init__()
        self.out_channels = out_channels
        self.input_channels = input_channels

        def conv(c_in, c_out):
            return ConvBlock(in_channels=c_in, out_channels=c_out)

        def pool():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        def upsample():
            return nn.Upsample(scale_factor=2)

        # Encoder
        self.conv1 = conv(self.input_channels, 64)
        self.conv2 = conv(64, 64)
        self.pool1 = pool()
        self.conv3 = conv(64, 128)
        self.conv4 = conv(128, 128)
        self.pool2 = pool()
        self.conv5 = conv(128, 256)
        self.conv6 = conv(256, 256)
        self.conv7 = conv(256, 256)
        self.pool3 = pool()

        # Bottleneck
        self.conv8 = conv(256, 512)
        self.conv9 = conv(512, 512)
        self.conv10 = conv(512, 512)

        # Decoder
        self.ups1 = upsample()
        self.conv11 = conv(512, 256)
        self.conv12 = conv(256, 256)
        self.conv13 = conv(256, 256)

        self.ups2 = upsample()
        self.conv14 = conv(256, 128)
        self.conv15 = conv(128, 128)

        self.ups3 = upsample()
        self.conv16 = conv(128, 64)
        self.conv17 = conv(64, 64)

        # Output head
        self.conv18 = conv(64, self.out_channels)

        self._init_weights()

    def forward(self, x):
        """Run ``x`` through every layer in ``_FORWARD_ORDER``."""
        for layer_name in self._FORWARD_ORDER:
            x = getattr(self, layer_name)(x)
        return x

    def _init_weights(self):
        """Initialise conv weights uniformly in [-0.05, 0.05] with zero bias,
        and batch-norm layers with weight 1 and bias 0."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.uniform_(layer.weight, -0.05, 0.05)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
                continue

            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

The `BallDetector` class detects and tracks the ball in a video using a pre-trained neural network model (`BallTrackerNet`). It can also estimate the ball speed from consecutive tracked positions.

In [None]:
class BallDetector:
    def __init__(self, path_model=None, device='cuda'):
        self.model = BallTrackerNet(input_channels=9, out_channels=256)
        self.device = device

        if path_model:
            self.model.load_state_dict(torch.load(path_model, map_location=device))
            self.model = self.model.to(device)
            self.model.eval()

        self.original_width = None
        self.original_height = None
        self.resized_width = 640
        self.resized_height = 360
        self.scale_factor = None
        self.scaler = torch.cuda.amp.GradScaler()

    def set_scale_factor(self, original_width, original_height):
        self.original_width = original_width
        self.original_height = original_height
        self.scale_factor = self.original_width / self.resized_width

    def infer_model(self, frames):
        ball_track = [(None, None)] * 2
        prev_pred = [None, None]
        batch_size = 16

        for start in range(2, len(frames), batch_size):
            batch_frames = frames[start:start + batch_size]
            imgs_batch = []

            for num in range(len(batch_frames)):
                img = cv2.resize(batch_frames[num], (self.resized_width, self.resized_height))
                img_prev = cv2.resize(frames[start + num - 1], (self.resized_width, self.resized_height))
                img_preprev = cv2.resize(frames[start + num - 2], (self.resized_width, self.resized_height))
                imgs = np.concatenate((img, img_prev, img_preprev), axis=2)
                imgs = imgs.astype(np.float32) / 255.0
                imgs = np.rollaxis(imgs, 2, 0)
                imgs_batch.append(imgs)

            inp = np.stack(imgs_batch, axis=0)
            inp = torch.from_numpy(inp).half().to(self.device)

            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    out = self.model(inp)

            output = out.argmax(dim=1).detach().cpu().numpy()
            for num in range(len(output)):
                x_pred, y_pred = self.postprocess(output[num], prev_pred)
                prev_pred = [x_pred, y_pred]
                ball_track.append((x_pred, y_pred))

        return ball_track

    def postprocess(self, feature_map, prev_pred, max_dist=80):
        feature_map *= 255
        feature_map = feature_map.reshape((self.resized_height, self.resized_width))
        feature_map = feature_map.astype(np.uint8)
        ret, heatmap = cv2.threshold(feature_map, 127, 255, cv2.THRESH_BINARY)
        circles = cv2.HoughCircles(heatmap, cv2.HOUGH_GRADIENT, dp=1, minDist=1, param1=50, param2=2, minRadius=2,
                                   maxRadius=7)
        x, y = None, None
        if circles is not None:
            if prev_pred[0] is not None:
                for i in range(len(circles[0])):
                    x_temp = circles[0][i][0] * self.scale_factor
                    y_temp = circles[0][i][1] * self.scale_factor
                    dist = distance.euclidean((x_temp, y_temp), prev_pred)
                    if dist < max_dist:
                        x, y = x_temp, y_temp
                        break
            else:
                x = circles[0][0][0] * self.scale_factor
                y = circles[0][0][1] * self.scale_factor
        return x, y

    def calculate_ball_speed(self, ball_tracks, fps, meters_per_pixel=0.0145):
        ball_speeds_kmph = []
        for i in range(1, len(ball_tracks)):
            if ball_tracks[i][0] is not None and ball_tracks[i-1][0] is not None:
                dist_pixels = distance.euclidean(ball_tracks[i], ball_tracks[i-1])
                speed_mps = dist_pixels * meters_per_pixel * fps
                speed_kmph = speed_mps * 3.6
                ball_speeds_kmph.append(speed_kmph)
            else:
                ball_speeds_kmph.append(None)
        return ball_speeds_kmph

The next function draws the tracked ball position on a copy of each frame and labels it.

In [None]:
def draw_ball_trace(frames, ball_track):
    """Return copies of ``frames`` with the ball position marked.

    For every frame whose entry in ``ball_track`` has a non-None x value, a
    green circle and the label 'ball' are drawn on the copy. The input frames
    themselves are not modified.

    Args:
        frames: Sequence of frames.
        ball_track: Sequence of ``(x, y)`` positions, indexed like ``frames``.

    Returns:
        List of annotated frame copies, one per input frame.
    """
    green = (0, 255, 0)
    annotated = []
    for idx, frame in enumerate(frames):
        canvas = frame.copy()
        position = ball_track[idx]
        if position[0] is not None:
            center_x = int(position[0])
            center_y = int(position[1])
            canvas = cv2.circle(canvas, (center_x, center_y), radius=5, color=green, thickness=2)
            canvas = cv2.putText(canvas, 'ball', org=(center_x + 8, center_y + 8),
                                 fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.8, thickness=2, color=green)
        annotated.append(canvas)
    return annotated

# Player Detection

Here, I've implemented a custom player detection model using YOLOv8. The training code and details of this model are available in the notebook "Training.ipynb".



The PlayerDetector class is designed to detect and track players in a video using a YOLO model.

In [None]:
class PlayerDetector:
    """Detects players with a YOLO model and assigns them persistent IDs.

    IDs are assigned by matching each detection's box centre against the
    last known boxes stored in ``self.players``.
    """

    def __init__(self, path_model, device='cuda'):
        """Load the YOLO model and reset the tracking state.

        Args:
            path_model: Path of the YOLO weights.
            device: Device the model and the input batches are placed on.
        """
        self.model = YOLO(path_model, verbose=True).to(device)
        self.device = device
        self.id_counter = 0
        self.players = {}        # player_id -> (last box, last score)
        self.player_colors = {}  # player_id -> BGR colour used for drawing

    def preprocess_frames(self, frames):
        """Convert frames to a float32 batch of shape (N, 3, 640, 640).

        Frames are resized to 640x640, converted from OpenCV's BGR channel
        order to RGB (the channel order Ultralytics documents for tensor
        inputs), scaled to [0, 1] and moved to channel-first layout.
        """
        batch = []
        for frame in frames:
            img = cv2.resize(frame, (640, 640))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = img.astype(np.float32) / 255.0
            img = np.rollaxis(img, 2, 0)
            batch.append(img)
        return np.array(batch)

    def infer_model(self, frames, original_width, original_height):
        """Detect and track players in a batch of frames.

        Args:
            frames: Sequence of frames. Boxes and labels are drawn on them in
                place by ``track_players``.
            original_width: Width used to scale boxes back from 640.
            original_height: Height used to scale boxes back from 640.

        Returns:
            A list with one ``{player_id: (box, score)}`` dict per frame;
            empty when ``frames`` is empty.
        """
        if len(frames) == 0:
            # Nothing to run the model on.
            return []

        batch = self.preprocess_frames(frames)
        batch_tensor = torch.from_numpy(batch).float().to(self.device)

        with torch.no_grad():
            results = self.model(batch_tensor)

        player_detections = []
        for i, result in enumerate(results):
            frame = frames[i]
            detections = self.postprocess(result, original_width, original_height)
            tracked_players = self.track_players(frame, detections)
            player_detections.append(tracked_players)

        return player_detections

    def postprocess(self, result, original_width, original_height):
        """Scale boxes to the original resolution and filter them.

        Only detections with score > 0.5 and class 0 are kept.

        Returns:
            List of ``(scaled_box, score, class)`` tuples.
        """
        boxes = result.boxes.xyxy.cpu().numpy()
        scores = result.boxes.conf.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy()

        scale_x = original_width / 640
        scale_y = original_height / 640
        scaled_boxes = boxes * [scale_x, scale_y, scale_x, scale_y]

        valid_indices = (scores > 0.5) & (classes == 0)
        return [(scaled_boxes[i], scores[i], classes[i]) for i in range(len(scores)) if valid_indices[i]]

    def track_players(self, frame, detections):
        """Assign IDs to detections and draw them on ``frame``.

        A detection is matched to the first known player whose last box
        centre is closer than 50 pixels. Unmatched detections create a new
        player unless two players are already known, in which case they are
        dropped.

        Args:
            frame: Frame the boxes and labels are drawn on.
            detections: Output of ``postprocess``.

        Returns:
            Dict ``{player_id: (box, score)}`` for this frame.
        """
        tracked_players = {}
        for box, score, _ in detections:
            x1, y1, x2, y2 = box
            box_center = ((x1 + x2) / 2, (y1 + y2) / 2)

            matched_id = None
            for player_id, (prev_box, _) in self.players.items():
                prev_x1, prev_y1, prev_x2, prev_y2 = prev_box
                prev_box_center = ((prev_x1 + prev_x2) / 2, (prev_y1 + prev_y2) / 2)
                # Named so it does not shadow the imported scipy `distance` module.
                center_dist = np.linalg.norm(np.array(box_center) - np.array(prev_box_center))

                if center_dist < 50:
                    matched_id = player_id
                    break

            if matched_id is None:
                if len(self.players) >= 2:
                    continue

                matched_id = self.id_counter
                self.id_counter += 1
                self.player_colors[matched_id] = (128, 128, 128) if matched_id % 2 == 0 else (147, 20, 255)

            tracked_players[matched_id] = (box, score)
            self.players[matched_id] = (box, score)

        for player_id, (box, score) in tracked_players.items():
            x1, y1, x2, y2 = box
            color = self.player_colors[player_id]
            frame = cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            frame = cv2.putText(frame, f'Player {player_id}', (int(x1), int(y1) - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        return tracked_players

# Key Points Detection

Here I'm implementing a video analysis system focused on detecting and tracking keypoints on a tennis court. The detection model is adapted from a repository by [abdullahtarek](https://github.com/abdullahtarek/tennis_analysis), which is specifically designed for tennis court analysis and keypoint detection. This model forms the core architecture of my project, customized and adjusted to fit the specific requirements of my application.

The CourtLineDetector class utilizes a pretrained ResNet50 model to detect court lines (keypoints) in images. It loads connections between keypoints from a file, refines keypoints to intersections of detected lines, and can draw these keypoints and connections on images or video frames.

In [None]:
class CourtLineDetector:
    """Predicts 14 court keypoints with a ResNet50 and snaps them to line
    intersections found in the frame.

    Keypoints are handled as a flat sequence ``[x0, y0, x1, y1, ...]``.
    """

    def __init__(self, model_path, connections_file):
        """Build the model, load its weights and the keypoint connections.

        Args:
            model_path: Path of the state dict for the ResNet50 with a
                28-output final layer.
            connections_file: Path of the ';'-separated connections file.
        """
        # The checkpoint loaded below replaces every weight, so the network
        # is created without pretrained weights (no download required).
        self.model = models.resnet50()
        self.model.fc = nn.Linear(self.model.fc.in_features, 14*2)
        self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        self.model.eval()
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.connections = self.load_connections(connections_file)
        self.stabilized_keypoints = None

    def load_connections(self, connections_file):
        """Read ``start;end;distance`` rows from a file.

        The first line is treated as a header and skipped; blank lines are
        ignored.

        Returns:
            List of ``(int start, int end, float distance)`` tuples.
        """
        connections = []
        with open(connections_file, 'r') as f:
            next(f, None)  # skip the header; tolerate an empty file
            for line in f:
                line = line.strip()
                if not line:
                    continue
                start, end, length = line.split(';')
                connections.append((int(start), int(end), float(length)))
        return connections

    def detect_keypoints(self, frame):
        """Predict keypoints for a BGR frame, in original-frame pixels.

        The model output is rescaled from the 224x224 input size to the frame
        size and then passed through ``refine_keypoints``.
        """
        image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image_tensor = self.transform(image_rgb).unsqueeze(0)
        with torch.no_grad():
            outputs = self.model(image_tensor)
        keypoints = outputs.squeeze().cpu().numpy()
        original_h, original_w = frame.shape[:2]
        keypoints[::2] *= original_w / 224.0
        keypoints[1::2] *= original_h / 224.0
        keypoints = self.refine_keypoints(frame, keypoints)
        return keypoints

    def get_keypoints(self, frame):
        """Return cached keypoints, detecting them on the first call only."""
        if self.stabilized_keypoints is None:
            self.stabilized_keypoints = self.detect_keypoints(frame)
        return self.stabilized_keypoints

    def refine_keypoints(self, frame, keypoints):
        """Snap each keypoint to the nearest intersection of detected lines.

        Returns ``keypoints`` unchanged when no line, or no pair of
        intersecting lines, is found in the frame.
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY)
        edges = cv2.Canny(thresh, 50, 150, apertureSize=3)
        lines = cv2.HoughLinesP(edges, 1, np.pi/180, threshold=100, minLineLength=50, maxLineGap=10)

        if lines is None:
            return keypoints

        # reshape (not squeeze) so a single detected line stays a 1x4 array.
        lines = lines.reshape(-1, 4)
        intersections = self.find_intersections(lines)
        if len(intersections) == 0:
            return keypoints

        adjusted_keypoints = []
        for i in range(0, len(keypoints), 2):
            x, y = int(keypoints[i]), int(keypoints[i + 1])
            nearest_intersection = self.find_nearest_intersection((x, y), intersections)
            adjusted_keypoints.extend(nearest_intersection)
        return np.array(adjusted_keypoints)

    def find_intersections(self, lines):
        """Return the intersections of every pair of lines that intersect."""
        intersections = []
        for i, line1 in enumerate(lines):
            for line2 in lines[i + 1:]:
                intersection = self.get_line_intersection(line1, line2)
                if intersection is not None:
                    intersections.append(intersection)
        return intersections

    def get_line_intersection(self, line1, line2):
        """Intersect the infinite lines through two segments.

        Args:
            line1: ``(x1, y1, x2, y2)``.
            line2: ``(x3, y3, x4, y4)``.

        Returns:
            ``(int x, int y)`` or ``None`` when the lines are parallel.
        """
        # Convert to Python floats first: with fixed-width integer inputs
        # (e.g. int32 arrays) the triple products below can overflow.
        x1, y1, x2, y2 = (float(v) for v in line1)
        x3, y3, x4, y4 = (float(v) for v in line2)

        denominator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
        if denominator == 0:
            return None

        intersect_x = ((x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)) / denominator
        intersect_y = ((x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)) / denominator
        return int(intersect_x), int(intersect_y)

    def find_nearest_intersection(self, point, intersections):
        """Return the intersection closest to ``point``.

        When ``intersections`` is empty, ``point`` itself is returned.
        """
        if len(intersections) == 0:
            return point
        x, y = point
        # Squared distance has the same minimiser as the distance itself.
        return min(intersections, key=lambda p: (p[0] - x)**2 + (p[1] - y)**2)

    def draw_keypoints(self, image, keypoints):
        """Draw every keypoint and its index on ``image`` and return it."""
        for i in range(0, len(keypoints), 2):
            x = int(keypoints[i])
            y = int(keypoints[i+1])
            cv2.circle(image, (x, y), 5, (0, 0, 255), -1)
            cv2.putText(image, str(i//2), (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
        return image

    def draw_connections(self, image, keypoints):
        """Draw each connection with its length in meters and in pixels."""
        for start, end, length_m in self.connections:
            x1, y1 = int(keypoints[start * 2]), int(keypoints[start * 2 + 1])
            x2, y2 = int(keypoints[end * 2]), int(keypoints[end * 2 + 1])
            cv2.line(image, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.putText(image, f"{length_m:.2f} m", ((x1 + x2) // 2, (y1 + y2) // 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 2)
            distance_pixels = math.sqrt((x2 - x1)**2 + (y2 - y1)**2)
            cv2.putText(image, f"{distance_pixels:.2f} px", ((x1 + x2) // 2, (y1 + y2) // 2 + 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 2)
        return image

    def draw_keypoints_and_connections_on_video(self, video_frames, keypoints_list):
        """Draw keypoints and connections on every frame/keypoints pair."""
        output_video_frames = []
        for frame, keypoints in zip(video_frames, keypoints_list):
            frame = self.draw_keypoints(frame, keypoints)
            frame = self.draw_connections(frame, keypoints)
            output_video_frames.append(frame)
        return output_video_frames

# Map

The MiniCourt class is designed to create and manage a minimap overlay on a video frame, typically for visualizing keypoints, connections, players, and other elements related to a sports scene.

In [None]:
class MiniCourt:
    """Minimap overlay for a frame.

    Frame coordinates are mapped onto a white
    ``drawing_rectangle_width`` x ``drawing_rectangle_height`` canvas by
    scaling with the frame size; the canvas is then pasted onto the frame.
    """

    def __init__(self, frame):
        """Create an empty minimap for ``frame`` (the frame is not copied)."""
        self.drawing_rectangle_width = 250
        self.drawing_rectangle_height = 500
        self.padding_court = 20
        self.frame = frame
        self.minimap = self.create_minimap()
        self.players = {}
        self.player_colors = {}
        self.ball_speed = None

    def create_minimap(self):
        """Return a white canvas of the minimap size."""
        minimap = np.ones((self.drawing_rectangle_height, self.drawing_rectangle_width, 3), dtype=np.uint8) * 255
        return minimap

    def draw_keypoints(self, keypoints):
        """Draw flat ``[x0, y0, x1, y1, ...]`` keypoints on the minimap."""
        for i in range(0, len(keypoints), 2):
            x = int(keypoints[i] * self.drawing_rectangle_width / self.frame.shape[1])
            y = int(keypoints[i+1] * self.drawing_rectangle_height / self.frame.shape[0])
            cv2.circle(self.minimap, (x, y), 5, (0, 0, 255), -1)
        return self.minimap

    def draw_connections(self, keypoints, connections):
        """Draw a line for every ``(start, end, _)`` keypoint connection."""
        for start, end, _ in connections:
            x1 = int(keypoints[start * 2] * self.drawing_rectangle_width / self.frame.shape[1])
            y1 = int(keypoints[start * 2 + 1] * self.drawing_rectangle_height / self.frame.shape[0])
            x2 = int(keypoints[end * 2] * self.drawing_rectangle_width / self.frame.shape[1])
            y2 = int(keypoints[end * 2 + 1] * self.drawing_rectangle_height / self.frame.shape[0])
            cv2.line(self.minimap, (x1, y1), (x2, y2), (255, 0, 0), 2)
        return self.minimap

    def draw_ball(self, ball_position):
        """Draw the ball on the minimap if both coordinates are known."""
        if ball_position is not None and ball_position[0] is not None and ball_position[1] is not None:
            x = int(ball_position[0] * self.drawing_rectangle_width / self.frame.shape[1])
            y = int(ball_position[1] * self.drawing_rectangle_height / self.frame.shape[0])
            cv2.circle(self.minimap, (x, y), 5, (0, 255, 0), -1)
        return self.minimap

    def draw_players(self, players, player_colors):
        """Draw a filled square per player.

        Args:
            players: Dict ``{player_id: (x, y)}`` in frame coordinates.
            player_colors: Dict ``{player_id: colour}``.
        """
        for player_id, (x, y) in players.items():
            player_color = player_colors[player_id]
            player_x = int(x * self.drawing_rectangle_width / self.frame.shape[1])
            player_y = int(y * self.drawing_rectangle_height / self.frame.shape[0])
            cv2.rectangle(self.minimap, (player_x - 10, player_y - 10), (player_x + 10, player_y + 10), player_color, -1)
        return self.minimap

    def add_title(self, title):
        """Write ``title`` horizontally centred at the top of the minimap."""
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1
        thickness = 2
        text_size = cv2.getTextSize(title, font, font_scale, thickness)[0]
        text_width = text_size[0]
        text_height = text_size[1]

        x = (self.drawing_rectangle_width - text_width) // 2
        y = text_height + 10

        cv2.putText(self.minimap, title, (x, y), font, font_scale, (0, 0, 0), thickness)

    def add_minimap_to_frame(self):
        """Paste the minimap onto the frame (in place) and return the frame.

        The minimap is vertically centred and placed ``padding_court`` pixels
        from the right edge. If the frame is smaller than the minimap, only
        the part that fits is pasted instead of raising a shape error.
        """
        height, width, _ = self.frame.shape
        y_start = max((height - self.drawing_rectangle_height) // 2, 0)
        x_start = max(width - self.drawing_rectangle_width - self.padding_court, 0)
        y_end = min(y_start + self.drawing_rectangle_height, height)
        x_end = min(x_start + self.drawing_rectangle_width, width)
        self.frame[y_start:y_end, x_start:x_end] = self.minimap[:y_end - y_start, :x_end - x_start]
        return self.frame

    def draw_ball_speed(self, ball_speed_kmph, max_speed_kmph=100):
        """Draw the speed text and a coloured gauge on the frame.

        Args:
            ball_speed_kmph: Speed to display, or ``None`` to draw nothing.
            max_speed_kmph: Speed corresponding to a full gauge.

        Returns:
            The frame (modified in place when a speed is given).
        """
        if ball_speed_kmph is not None:
            x_start = self.frame.shape[1] - self.drawing_rectangle_width - self.padding_court
            y_start = self.frame.shape[0] - 250

            speed_kmph_text = f'Speed (km/h):\n{ball_speed_kmph:.2f}'
            for idx, line in enumerate(speed_kmph_text.split('\n')):
                cv2.putText(self.frame, line, (x_start + 10, y_start + 20 + idx * 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)

            # Clamp the gauge fill to [0, 1] so speeds above max_speed_kmph
            # neither overflow the bar nor yield out-of-range colour values.
            # The text above still shows the actual speed.
            fill_ratio = min(max(ball_speed_kmph / max_speed_kmph, 0.0), 1.0)

            bar_width = int(fill_ratio * 300)
            bar_height = 30
            bar_x_start = x_start - 60
            bar_y_start = y_start + 90

            cv2.rectangle(self.frame, (bar_x_start, bar_y_start), (bar_x_start + 300, bar_y_start + bar_height), (200, 200, 200), -1)

            red_value = int(fill_ratio * 255)
            green_value = 255 - red_value
            bar_color = (0, green_value, red_value)

            cv2.rectangle(self.frame, (bar_x_start, bar_y_start), (bar_x_start + bar_width, bar_y_start + bar_height), bar_color, -1)

            for i in range(0, max_speed_kmph + 1, 10):
                line_x = bar_x_start + int((i / max_speed_kmph) * 300)
                cv2.line(self.frame, (line_x, bar_y_start - 5), (line_x, bar_y_start + bar_height + 5), (0, 0, 0), 1)
                cv2.putText(self.frame, f'{i}', (line_x - 10, bar_y_start + bar_height + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

        return self.frame

# Detector

This UnifiedDetector class integrates various components (BallDetector, PlayerDetector, CourtLineDetector, and MiniCourt) to process frames, detect objects (balls and players), annotate frames with detections and keypoints, and create a minimap overlay with additional annotations such as player positions and ball speed.

In [None]:
class UnifiedDetector:
    """Combines ball, player and court-line detection with frame annotation."""

    def __init__(self, path_ball_model, path_player_model, path_keypoints_model, path_connections_file, device='cuda'):
        """Create the three detectors.

        Args:
            path_ball_model: Weights for ``BallDetector``.
            path_player_model: Weights for ``PlayerDetector``.
            path_keypoints_model: Weights for ``CourtLineDetector``.
            path_connections_file: Keypoint connections file.
            device: Device passed to the ball and player detectors.
        """
        self.ball_detector = BallDetector(path_ball_model, device)
        self.player_detector = PlayerDetector(path_player_model, device)
        self.court_line_detector = CourtLineDetector(path_keypoints_model, path_connections_file)
        self.device = device
        self.stabilized_keypoints = None
        self.minimap = None  # cached static minimap (keypoints, lines, title)

    def process_frame_others(self, frame):
        """Run player and keypoint detection on a single frame.

        Returns:
            Tuple ``(player_detections, keypoints)``.
        """
        original_height, original_width = frame.shape[:2]
        player_detections = self.player_detector.infer_model([frame], original_width, original_height)[0]
        keypoints = self.court_line_detector.get_keypoints(frame)
        return player_detections, keypoints

    def process_video(self, frames, original_width, original_height, batch_size=16):
        """Run all detectors over a video.

        Args:
            frames: Sequence of frames.
            original_width: Frame width in pixels.
            original_height: Frame height in pixels.
            batch_size: Number of frames per player-detection batch.

        Returns:
            Tuple ``(ball_tracks, player_tracks, court_keypoints)``; the last
            two lists have one entry per frame.
        """
        self.ball_detector.set_scale_factor(original_width, original_height)
        ball_tracks = self.ball_detector.infer_model(frames)

        player_tracks = []
        court_keypoints = []

        for i in range(0, len(frames), batch_size):
            batch_frames = frames[i:i + batch_size]
            batch_player_detections = self.player_detector.infer_model(batch_frames, original_width, original_height)
            for frame, player_detections in zip(batch_frames, batch_player_detections):
                keypoints = self.court_line_detector.get_keypoints(frame)
                player_tracks.append(player_detections)
                court_keypoints.append(keypoints)

        return ball_tracks, player_tracks, court_keypoints

    def annotate_frames(self, frames, ball_tracks, player_tracks, court_keypoints, fps):
        """Draw detections, the minimap and the ball speed on every frame.

        Args:
            frames: Sequence of frames.
            ball_tracks: One ``(x, y)`` ball position per frame.
            player_tracks: One ``{player_id: (box, score)}`` dict per frame.
            court_keypoints: One flat keypoint array per frame.
            fps: Frame rate used for the speed estimate.

        Returns:
            List of annotated frames.
        """
        annotated_frames = []
        # calculate_ball_speed() yields one value per pair of consecutive
        # positions, i.e. one fewer than there are frames. Prepending None
        # aligns each speed with the frame the ball arrived in and keeps
        # zip() below from dropping the last frame.
        ball_speeds_kmph = [None] + list(self.ball_detector.calculate_ball_speed(ball_tracks, fps))

        for frame, ball_position, player_detections, keypoints, speed_kmph in zip(frames, ball_tracks, player_tracks, court_keypoints, ball_speeds_kmph):
            frame = self.court_line_detector.draw_keypoints(frame, keypoints)
            frame = self.court_line_detector.draw_connections(frame, keypoints)

            for player_id, (box, _) in player_detections.items():
                x1, y1, x2, y2 = box
                color = self.player_detector.player_colors[player_id]
                frame = cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
                frame = cv2.putText(frame, f'Player {player_id}', (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            frame = draw_ball_trace([frame], [ball_position])[0]

            if self.minimap is None:
                # Build the static part of the minimap once and cache it.
                minimap = MiniCourt(frame)
                minimap.draw_keypoints(keypoints)
                minimap.draw_connections(keypoints, self.court_line_detector.connections)
                minimap.add_title("Minimap")
                self.minimap = minimap.minimap.copy()

            # Start from a fresh copy of the cached background for each frame.
            minimap_frame = self.minimap.copy()
            minimap = MiniCourt(frame)
            minimap.minimap = minimap_frame
            minimap.draw_ball(ball_position)

            player_positions = {}
            for player_id, (box, _) in player_detections.items():
                x1, y1, x2, y2 = box
                player_positions[player_id] = ((x1 + x2) / 2, (y1 + y2) / 2)

            minimap.draw_players(player_positions, self.player_detector.player_colors)

            if speed_kmph is not None:
                minimap.draw_ball_speed(speed_kmph)

            frame_with_lines = minimap.add_minimap_to_frame()
            annotated_frames.append(frame_with_lines)

        return annotated_frames

# Main

This script is designed to analyze a tennis video by detecting and tracking the ball and players, annotating the frames with these detections, and saving the annotated video.

In [18]:
# Input/output locations for the analysis run.
path_video = 'Tennis-Video-Analysis/Videos/Inputs/Test_1.mp4'
path_ball_model = 'Tennis-Video-Analysis/Models/ball.pt'
path_player_model = 'Tennis-Video-Analysis/Models/best.pt'
# NOTE(review): unlike the other paths this one starts with '/', i.e. it is
# an absolute path — confirm this is intended for the target environment.
path_keypoints_model = '/Tennis-Video-Analysis/Models/keypoints_model.pth'
path_connections_file = 'Tennis-Video-Analysis/Configs/court_connections.txt'
# NOTE(review): the annotated video is written into the Inputs folder — confirm.
path_output_video = 'Tennis-Video-Analysis/Videos/Inputs/Test_output_1.mp4'

# 1. Load all frames of the input video into memory.
frames, fps, original_width, original_height = read_video(path_video)

# 2. Build the ball, player and court-line detectors (default device: 'cuda').
unified_detector = UnifiedDetector(path_ball_model, path_player_model, path_keypoints_model, path_connections_file)

# 3. Run detection over the whole video.
ball_tracks, player_tracks, court_keypoints = unified_detector.process_video(frames, original_width, original_height)

# 4. Draw detections, minimap and ball speed on the frames.
annotated_frames = unified_detector.annotate_frames(frames, ball_tracks, player_tracks, court_keypoints, fps)

# 5. Save the annotated frames as a video.
write_video(annotated_frames, fps, path_output_video)

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0



0: 640x640 2 Players, 3.1ms
1: 640x640 2 Players, 3.1ms
2: 640x640 2 Players, 3.1ms
3: 640x640 2 Players, 3.1ms
4: 640x640 2 Players, 3.1ms
5: 640x640 2 Players, 3.1ms
6: 640x640 2 Players, 3.1ms
7: 640x640 2 Players, 3.1ms
8: 640x640 2 Players, 3.1ms
9: 640x640 3 Players, 3.1ms
10: 640x640 2 Players, 3.1ms
11: 640x640 2 Players, 3.1ms
12: 640x640 2 Players, 3.1ms
13: 640x640 2 Players, 3.1ms
14: 640x640 2 Players, 3.1ms
15: 640x640 2 Players, 3.1ms
Speed: 0.0ms preprocess, 3.1ms inference, 1.9ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 2 Players, 3.0ms
1: 640x640 2 Players, 3.0ms
2: 640x640 2 Players, 3.0ms
3: 640x640 2 Players, 3.0ms
4: 640x640 2 Players, 3.0ms
5: 640x640 2 Players, 3.0ms
6: 640x640 2 Players, 3.0ms
7: 640x640 2 Players, 3.0ms
8: 640x640 2 Players, 3.0ms
9: 640x640 2 Players, 3.0ms
10: 640x640 2 Players, 3.0ms
11: 640x640 2 Players, 3.0ms
12: 640x640 2 Players, 3.0ms
13: 640x640 2 Players, 3.0ms
14: 640x640 2 Players, 3.0ms
15: 640x640 2 Players, 