In [75]:
import sys
import os
sys.path.append("D:/tennis_v2/")

import cv2
import torch
import torchvision.transforms as transforms
import numpy as np
import pandas as pd

from model.wasb import HRNet
from utils.interpolater import TrajectoryInterpolator
from utils.kalman_filter import KalmanFilter
from utils.draw import Draw_video
from utils.to_json import prepare_json
from assets.bounce import detect_bounces, draw_cross
from assets.person_detector import PersonDetector

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# --- Constants and configuration ---
# Project root; every path below is built from it.
CURRENT_DIR = 'D:/tennis_v2/'

# HRNet (WASB) ball-detector weights.
# NOTE(review): CURRENT_DIR already ends in "/", so this value contains "//".
MODEL_PATH = CURRENT_DIR + "/weights/wasb_tennis_best.pth"

# Alternative bounce-detector weights (currently disabled):
# BOUNCE_DETECTOR_PATH = f"{CURRENT_DIR}/weights/ctb_regr_bounce.cbm"

# Delay, in frames. NOTE(review): not referenced by the code shown here.
DELAY = 15

In [None]:
# Project root that run_inference() uses to build the model-weights path.
# NOTE(review): duplicates CURRENT_DIR from the configuration cell.
current_dir = "D:/tennis_v2/"

def preprocess_frame(frame, transform):
    """Run ``transform`` on one video frame and hand back whatever it produces."""
    transformed = transform(frame)
    return transformed

def predict_ball_position(prev_positions, width, height):
    """Extrapolate the next ball position from the last three detections.

    Uses a constant-acceleration model with one-frame steps. The acceleration
    is the second difference ``a = p[t] - 2*p[t-1] + p[t-2]``. The velocity at
    ``t`` is ``v = (p[t] - p[t-1]) + a/2``. The next position is
    ``p[t] + v + a/2``, which is exact for positions quadratic in time.

    Args:
        prev_positions: Sequence of ``[x, y]`` positions (e.g. ``np.ndarray``),
            oldest first. Only the last three are used.
        width: Frame width; the predicted x is clipped to ``[0, width]``.
        height: Frame height; the predicted y is clipped to ``[0, height]``.

    Returns:
        ``np.ndarray`` ``[x, y]``, or ``None`` when fewer than three positions
        are available.
    """
    if len(prev_positions) < 3:
        return None
    p_t = np.asarray(prev_positions[-1], dtype=float)
    p_prev = np.asarray(prev_positions[-2], dtype=float)
    p_prev2 = np.asarray(prev_positions[-3], dtype=float)

    a_t = p_t - 2 * p_prev + p_prev2
    # The backward difference p_t - p_prev is the velocity half a step earlier.
    # Adding a_t / 2 gives the velocity at t. The previous code added the full
    # a_t here and a_t / 2 again below, so it overshot by a_t / 2.
    v_t = p_t - p_prev + 0.5 * a_t
    predicted_position = p_t + v_t + 0.5 * a_t
    return np.clip(predicted_position, [0, 0], [width, height])

def _hrnet_config():
    """Build the HRNet (WASB) configuration used by ``run_inference``.

    The architecture entries must match the checkpoint at ``model_path``,
    because the checkpoint is loaded with ``strict=True``. ``model_path`` is
    built from the module-level ``current_dir`` at call time.
    """
    return {
        "name": "hrnet",
        "frames_in": 3,
        "frames_out": 3,
        "inp_height": 288,
        "inp_width": 512,
        "out_height": 288,
        "out_width": 512,
        "rgb_diff": False,
        "out_scales": [0],
        "MODEL": {
            "EXTRA": {
                "FINAL_CONV_KERNEL": 1,
                "PRETRAINED_LAYERS": ['*'],
                "STEM": {
                    "INPLANES": 64,
                    "STRIDES": [1, 1]
                },
                "STAGE1": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 1,
                    "BLOCK": 'BOTTLENECK',
                    "NUM_BLOCKS": [1],
                    "NUM_CHANNELS": [32],
                    "FUSE_METHOD": 'SUM'
                },
                "STAGE2": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 2,
                    "BLOCK": 'BASIC',
                    "NUM_BLOCKS": [2, 2],
                    "NUM_CHANNELS": [16, 32],
                    "FUSE_METHOD": 'SUM'
                },
                "STAGE3": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 3,
                    "BLOCK": 'BASIC',
                    "NUM_BLOCKS": [2, 2, 2],
                    "NUM_CHANNELS": [16, 32, 64],
                    "FUSE_METHOD": 'SUM'
                },
                "STAGE4": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 4,
                    "BLOCK": 'BASIC',
                    "NUM_BLOCKS": [2, 2, 2, 2],
                    "NUM_CHANNELS": [16, 32, 64, 128],
                    "FUSE_METHOD": 'SUM'
                },
                "DECONV": {
                    "NUM_DECONVS": 0,
                    "KERNEL_SIZE": [],
                    "NUM_BASIC_BLOCKS": 2
                }
            },
            "INIT_WEIGHTS": True
        },
        "model_path": f"{current_dir}/weights/wasb_tennis_best.pth",  # Update with your model path
    }


def _resolve_output_video_path(input_path, output_path):
    """Return ``output_path``, or ``<input dir>/<input stem>_output.mp4`` when it is empty/None."""
    if output_path:
        return output_path
    base_name = os.path.splitext(os.path.basename(input_path))[0]
    return os.path.join(os.path.dirname(input_path), f"{base_name}_output.mp4")


def _blob_centers(heatmap):
    """Return ``(cx, cy, mass)`` for every 8-connected region where ``heatmap > 0``.

    The centre is the heatmap-weighted centroid of the region, and ``mass`` is
    the sum of its heatmap values. Regions with zero mass are skipped.
    """
    num_labels, labels_im, _stats, _centroids = cv2.connectedComponentsWithStats(
        (heatmap > 0).astype(np.uint8), connectivity=8)
    centers = []
    for label in range(1, num_labels):  # label 0 is the background
        mask = labels_im == label
        weights = heatmap[mask]
        blob_sum = weights.sum()
        if blob_sum > 0:
            rows, cols = np.where(mask)
            centers.append((np.sum(cols * weights) / blob_sum,
                            np.sum(rows * weights) / blob_sum,
                            blob_sum))
    return centers


def _select_blob(blob_centers, prev_positions, width, height):
    """Choose the ball among the candidate blobs (``blob_centers`` must be non-empty).

    Picks the blob nearest to ``predict_ball_position(prev_positions, ...)``.
    When there is no prediction, it picks the blob with the largest heatmap
    mass (the first one on ties).
    """
    predicted_position = predict_ball_position(prev_positions, width, height)
    if predicted_position is None:
        return max(blob_centers, key=lambda blob: blob[2])
    distances = [np.sqrt((x - predicted_position[0]) ** 2 + (y - predicted_position[1]) ** 2)
                 for x, y, _ in blob_centers]
    return blob_centers[int(np.argmin(distances))]


def run_inference(input_path, output_path="", output_csv_path="", overlay=False):
    """Annotate a tennis video with player boxes, the ball track and ball events.

    For every frame read, the two players returned by
    ``PersonDetector.track_players`` are boxed in red (BGR ``(0, 0, 255)``).
    Frames are fed to HRNet in non-overlapping groups of ``frames_in``.
    The ball is taken from each output heatmap, drawn in green, and gaps are
    filled with ``TrajectoryInterpolator`` and drawn in yellow. Every 30
    processed frames, ``detect_bounces`` runs on the trajectory so far, and the
    events it returns are drawn with ``draw_cross``. All frames are kept in
    memory and written to the output video at the end.

    Args:
        input_path: Path of the input video.
        output_path: Path of the annotated output video. When empty (or None),
            ``<input dir>/<input stem>_output.mp4`` is used.
        output_csv_path: Currently unused; kept for interface compatibility.
        overlay: When True, a heatmap overlay is computed for each output
            frame. NOTE(review): that overlay is never drawn or written.

    Raises:
        IOError: If the input video cannot be opened or the output video
            cannot be created.
    """
    config = _hrnet_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((config['inp_height'], config['inp_width'])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    model = HRNet(cfg=config).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(config['model_path'], map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'], strict=True)
    model.eval()

    person_model = PersonDetector()

    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open input video: {input_path}")

    # Keep the fractional frame rate (e.g. 29.97). Truncating it to int makes
    # the output play at the wrong speed.
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    output_video_path = _resolve_output_video_path(input_path, output_path)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    if not out.isOpened():
        cap.release()
        raise IOError(f"Cannot create output video: {output_video_path}")
    print(output_video_path)

    frame_number = 0        # index of the next frame to receive a ball result
    frames_buffer = []      # frames waiting for the next HRNet call

    # detect people
    frame_number_outer = 0  # frame index passed to the person tracker
    player1_data = []
    player2_data = []

    # Box geometry for drawing the ball; kf is only used for xyah_to_xyxy here.
    kf = KalmanFilter()
    kf_ratio_range = 1
    kf_width_range = 8

    # Gap filling: after a miss, detections are collected in interpolator_count
    # until there are at least interpolator_ranage of them (10, +1 per further
    # miss, capped at 20). Then frames from interpolator_start_frame up to the
    # current frame are interpolated and drawn.
    interpolator = TrajectoryInterpolator()
    need_interpolator = False
    interpolator_ranage = 10
    interpolator_count = []
    interpolator_start_frame = None

    # bounce
    coordinate_history = []  # one [x, y] (or [None, None]) per processed frame
    frame_history = []       # every frame read; drawn on later, written at the end

    prev_positions = []      # last <= 3 ball positions, for blob selection
    visited = {}             # latest events from detect_bounces

    # draw (NOTE(review): currently unused)
    draw_video = Draw_video()

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frames_buffer.append(frame)
            frame_history.append(frame)

            # ----------------------------- detect people -----------------------------
            player_result = person_model.track_players(frame, frame_number_outer)
            x1, y1, x2, y2, x3, y3, x4, y4 = player_result[1:]
            player1_data.append([x1, y1, x2, y2])
            player2_data.append([x3, y3, x4, y4])
            frame_number_outer += 1

            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
            cv2.rectangle(frame, (x3, y3), (x4, y4), (0, 0, 255), 2)

            # --------------------------- detect tennis ball --------------------------
            if len(frames_buffer) != config['frames_in']:
                continue
            # NOTE(review): a trailing group shorter than frames_in is never
            # processed; those frames are written without ball annotations.

            frames_processed = [preprocess_frame(f, transform) for f in frames_buffer]
            input_tensor = torch.cat(frames_processed, dim=0).unsqueeze(0).to(device)

            with torch.no_grad():
                outputs = model(input_tensor)[0]  # raw logits

            for i in range(config['frames_out']):
                # Reset per output frame. These used to be reset once per group,
                # so a miss after a hit was reported as detected and reused the
                # earlier frame's coordinates.
                detected = False
                center_x, center_y, confidence = 0, 0, 0
                current_frame = frame_history[-config['frames_out'] + i]

                # presumably outputs[0][i] is batch 0, output frame i -- TODO confirm
                heatmap = torch.sigmoid(outputs[0][i]).squeeze().cpu().numpy()
                heatmap = cv2.resize(heatmap, (width, height), interpolation=cv2.INTER_LINEAR)
                heatmap = (heatmap > 0.5).astype(np.float32) * heatmap  # keep only p > 0.5

                if overlay:
                    # NOTE(review): overlayed_frame is computed but never used.
                    heatmap_normalized_visualization = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
                    heatmap_colored = cv2.applyColorMap(heatmap_normalized_visualization, cv2.COLORMAP_JET)
                    overlayed_frame = cv2.addWeighted(frames_buffer[i], 0.6, heatmap_colored, 0.4, 0)

                blob_centers = _blob_centers(heatmap)
                if blob_centers:
                    center_x, center_y, confidence = _select_blob(blob_centers, prev_positions, width, height)
                    detected = True
                    prev_positions.append(np.array([center_x, center_y]))
                    if len(prev_positions) > 3:
                        prev_positions.pop(0)

                if detected:
                    coordinate_history.append([center_x, center_y])
                    if not need_interpolator:
                        color = (0, 255, 0)
                        x1, y1, x2, y2 = kf.xyah_to_xyxy([center_x, center_y, kf_ratio_range, kf_width_range])
                        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                        text = f"{confidence:.2f}"
                        cv2.rectangle(current_frame, (x1, y1), (x2, y2), color, 2)
                        cv2.putText(current_frame, text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                    else:
                        interpolator_count.append([center_x, center_y])
                else:
                    coordinate_history.append([None, None])
                    if not need_interpolator:
                        need_interpolator = True
                        interpolator_start_frame = frame_number
                    else:
                        interpolator_ranage = min(interpolator_ranage + 1, 20)

                if need_interpolator and len(interpolator_count) >= interpolator_ranage:
                    # NOTE(review): the range excludes frame_number, so the frame
                    # that completes the gap is never drawn; a gap still open
                    # when the video ends is never drawn either.
                    color = (0, 255, 255)
                    inter_trac = interpolator.update_detection_history(coordinate_history, interpolator_start_frame, frame_number)
                    for f in range(interpolator_start_frame, frame_number):
                        x, y = inter_trac.loc[f, 'x'], inter_trac.loc[f, 'y']
                        coordinate_history[f][0], coordinate_history[f][1] = x, y
                        x1, y1, x2, y2 = kf.xyah_to_xyxy([x, y, kf_ratio_range, kf_width_range])
                        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                        cv2.rectangle(frame_history[f], (x1, y1), (x2, y2), color, 2)
                    interpolator_count = []
                    need_interpolator = False
                    interpolator_ranage = 10

                # ------------------------- detect bounce & shot --------------------------
                if frame_number > 0 and frame_number % 30 == 0:
                    trajectory = pd.DataFrame(coordinate_history, columns=['x', 'y'])
                    bounces, ix_5, x, y, ie = detect_bounces(trajectory)
                    # NOTE(review): the recorded output shows empty "players" lists.
                    # If prepare_json filters by list index against frame_number,
                    # passing only the last 30 rows would explain this -- verify.
                    js_data = prepare_json(frame_number=frame_number,
                                           player1=player1_data[-30:],
                                           player2=player2_data[-30:],
                                           ball_coordinate=coordinate_history[-30:],
                                           ball_event=ie)
                    visited = ie  # presumably {frame_index: event_name} -- TODO confirm

                # Redraw every event up to this frame on the current frame.
                for key, value in visited.items():
                    if frame_number < key:
                        continue
                    event_x, event_y = coordinate_history[key]
                    if event_x is None or event_y is None:
                        continue  # position missing (not interpolated yet)
                    if value == 'shot':
                        color = (0, 255, 0)
                    elif value == 'bounce':
                        color = (0, 0, 255)
                    else:
                        color = (255, 0, 0)
                    draw_cross(frame_history[frame_number], event_x, event_y, color=color)

                print(f'frame_number: {frame_number}/{frame_count}, detected: {detected}')
                frame_number += 1

            frames_buffer = []  # start the next group of frames

        for frame in frame_history:
            out.write(frame)
    finally:
        # Release everything even if processing fails part-way.
        cap.release()
        out.release()
        cv2.destroyAllWindows()

In [89]:

# Input clip and output locations for a single inference run.
video_input_path = 'd:/tennis_v2/test_02.mp4'
file_name = video_input_path.rsplit('/', 1)[-1]  # e.g. "test_02.mp4"
output_prefix = 'd:/tennis_v2/inference/HRnetv3_30f_'
video_output_path = output_prefix + file_name
# NOTE(review): this yields "...test_02.mp4.csv", and run_inference does not
# currently read output_csv_path.
output_csv_path = output_prefix + file_name + '.csv'

run_inference(video_input_path, video_output_path, output_csv_path)


The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.


Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1`. You can also use `weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT` to get the most up-to-date weights.



d:/tennis_v2/inference/HRnetv3_30f_test_02.mp4
frame_number: 0/361, detected: True
frame_number: 1/361, detected: True
frame_number: 2/361, detected: True
frame_number: 3/361, detected: True
frame_number: 4/361, detected: True
frame_number: 5/361, detected: True
frame_number: 6/361, detected: True
frame_number: 7/361, detected: True
frame_number: 8/361, detected: True
frame_number: 9/361, detected: True
frame_number: 10/361, detected: True
frame_number: 11/361, detected: True
frame_number: 12/361, detected: True
frame_number: 13/361, detected: True
frame_number: 14/361, detected: True
frame_number: 15/361, detected: True
frame_number: 16/361, detected: True
frame_number: 17/361, detected: True
frame_number: 18/361, detected: True
frame_number: 19/361, detected: True
frame_number: 20/361, detected: True
frame_number: 21/361, detected: False
frame_number: 22/361, detected: True
frame_number: 23/361, detected: True
frame_number: 24/361, detected: True
frame_number: 25/361, detected: True



invalid value encountered in arccos



{
    "frame_range": [
        151,
        180
    ],
    "players": {
        "player1": [],
        "player2": []
    },
    "ball_data": {
        "trajectory": [
            {
                "x": 428.8539771448889,
                "y": 165.8956786737845
            },
            {
                "x": 424.6231028370551,
                "y": 160.70073547753017
            },
            {
                "x": 419.4242851856777,
                "y": 154.5993298049647
            },
            {
                "x": 413.55545826256457,
                "y": 150.28718075027516
            },
            {
                "x": 409.3057674316017,
                "y": 146.07191977364118
            },
            {
                "x": 404.183889828872,
                "y": 143.38715091546808
            },
            {
                "x": 399.19779822997043,
                "y": 141.69865275274339
            },
            {
                "x": 394.32148839023745,
                

In [87]:
import json 
def prepare2_json(frame_number, player1, player2, window=30):
    """Serialize both players' boxes for the window of frames ending at ``frame_number``.

    Each element of ``player1`` / ``player2`` is a ``[x1, y1, x2, y2]`` box,
    and its list index is used as its frame number. Only boxes whose index
    lies in ``[frame_number - window + 1, frame_number]`` are emitted.

    Args:
        frame_number: Last frame of the window (inclusive).
        player1: Per-frame boxes for player 1.
        player2: Per-frame boxes for player 2.
        window: Number of frames in the window (default 30).

    Returns:
        str: Indented JSON with ``frame_range`` and ``players.player1`` /
        ``players.player2``.
    """
    first_frame = frame_number - (window - 1)

    def boxes_in_window(boxes):
        # Enumerate index == frame number; keep only frames inside the window.
        return [
            {"frame": f, "x1": box[0], "y1": box[1], "x2": box[2], "y2": box[3]}
            for f, box in enumerate(boxes)
            if first_frame <= f <= frame_number
        ]

    json_data = {
        "frame_range": [first_frame, frame_number],
        # Both lists used to be written under "player1", so player2's boxes
        # silently overwrote player1's and "player2" was never emitted.
        "players": {
            "player1": boxes_in_window(player1),
            "player2": boxes_in_window(player2),
        },
    }
    return json.dumps(json_data, indent=4)
# Sample per-frame [x1, y1, x2, y2] boxes (30 frames) for a quick prepare2_json check.
player1 = [
    [704, 478, 801, 630], [703, 481, 802, 631], [704, 480, 804, 631], [705, 479, 801, 630], [710, 478, 807, 630],
    [709, 476, 812, 632], [718, 475, 819, 630], [729, 474, 825, 630], [743, 472, 832, 629], [771, 468, 843, 630],
    [781, 468, 839, 629], [787, 464, 843, 624], [793, 464, 857, 623], [804, 465, 864, 624], [816, 463, 882, 620],
    [824, 466, 890, 623], [829, 466, 904, 622], [835, 467, 912, 624], [838, 468, 925, 622], [842, 467, 934, 622],
    [845, 465, 938, 622], [846, 467, 937, 620], [847, 467, 943, 620], [849, 464, 943, 620], [850, 463, 938, 619],
    [851, 459, 941, 617], [858, 454, 943, 616], [861, 449, 941, 612], [864, 445, 945, 610], [868, 444, 945, 605],
]
player2 = list(player1)  # shallow copy: a new outer list sharing the same inner boxes
z = prepare2_json(30, player1, player2)

In [88]:
# Display the JSON string built by prepare2_json in the previous cell.
z

'{\n    "frame_range": [\n        1,\n        30\n    ],\n    "players": {\n        "player1": [\n            {\n                "frame": 1,\n                "x1": 703,\n                "y1": 481,\n                "x2": 802,\n                "y2": 631\n            },\n            {\n                "frame": 2,\n                "x1": 704,\n                "y1": 480,\n                "x2": 804,\n                "y2": 631\n            },\n            {\n                "frame": 3,\n                "x1": 705,\n                "y1": 479,\n                "x2": 801,\n                "y2": 630\n            },\n            {\n                "frame": 4,\n                "x1": 710,\n                "y1": 478,\n                "x2": 807,\n                "y2": 630\n            },\n            {\n                "frame": 5,\n                "x1": 709,\n                "y1": 476,\n                "x2": 812,\n                "y2": 632\n            },\n            {\n                "frame": 6,\n   