In [1]:
import sys
import os
sys.path.append("D:/tennis_v2/")

import cv2
import torch
import torchvision.transforms as transforms
import numpy as np
import pandas as pd

from model.wasb import HRNet
from utils.interpolater import TrajectoryInterpolator
from utils.kalman_filter import KalmanFilter
from utils.draw import Draw_video
from utils.utils import prepare_json, load_config
from utils.camera_home_graph import undistort_image
from assets.bounce import detect_bounces, draw_cross
from assets.person_detector import PersonDetector
import websocket
import requests

%load_ext autoreload
%autoreload 2

In [2]:
# --- Constants and Configuration ---
# Project root used to locate config.yaml and the model checkpoint. Defaults to the
# original location; set the TENNIS_V2_DIR environment variable to run from another
# checkout. (The sys.path entry in the import cell is configured separately.)
CURRENT_DIR = os.environ.get("TENNIS_V2_DIR", 'D:/tennis_v2/')
# BOUNCE_DETECTOR_PATH = f"{CURRENT_DIR}/weights/ctb_regr_bounce.cbm"

In [None]:
# Lowercase alias kept in case other cells still use this name; derived from
# CURRENT_DIR so the two paths cannot drift apart. Prefer CURRENT_DIR.
current_dir = CURRENT_DIR

def preprocess_frame(frame, transform):
    """Run a single frame through ``transform`` and hand back whatever it produces."""
    transformed = transform(frame)
    return transformed

def predict_ball_position(prev_positions, width, height):
    """Extrapolate the next ball position from the three most recent positions.

    The second difference of the last three points is used as an acceleration
    term, the last step plus that acceleration as the velocity, and the
    extrapolated point is clamped into ``[0, width] x [0, height]``.

    Args:
        prev_positions: sequence of ``[x, y]`` numpy arrays, oldest first; only the
            last three are used.
        width: upper clamp bound for x.
        height: upper clamp bound for y.

    Returns:
        A numpy array ``[x, y]``, or ``None`` if fewer than three positions exist.
    """
    if not len(prev_positions) >= 3:
        return None
    oldest, previous, latest = prev_positions[-3], prev_positions[-2], prev_positions[-1]
    accel = latest - 2 * previous + oldest
    velocity = latest - previous + accel
    return np.clip(latest + velocity + 0.5 * accel, [0, 0], [width, height])

def _resolve_output_video_path(input_path, output_path):
    """Return the file path the annotated video is written to.

    A non-empty ``output_path`` is used as-is. Otherwise the name is derived from
    ``input_path`` as ``<base name>_inference.mp4``: next to the input for a local
    file, or in the current directory for a URL (``scheme://...``). Any URL query
    string is dropped because it is not a valid part of a file name.
    """
    if output_path:
        return output_path
    source = input_path.split("?", 1)[0]
    base_name = os.path.splitext(os.path.basename(source))[0] or "stream"
    out_dir = "" if "://" in source else os.path.dirname(source)
    return os.path.join(out_dir, f"{base_name}_inference.mp4")


def _log_camera_address(cap_url, timeout_s=10):
    """Query the camera-info endpoint and print the stream address it reports.

    The address is only logged (frames are read from ``run_inference``'s
    ``input_path``), so request failures are printed and otherwise ignored. The
    query string is left out of the log because it may carry an access token.
    """
    try:
        response = requests.get(cap_url, timeout=timeout_s)
    except requests.RequestException as exc:
        print(f'Error: {exc}')
        return
    if response.status_code == 200:
        rtmp_addr = response.json()['data']['videoUrl']
        print(f"success get the camera: {rtmp_addr.split('?', 1)[0]}")
    else:
        print(f'Error: {response.status_code}')


def _find_blob_centers(heatmap):
    """Return ``[(x, y, mass), ...]`` for each 8-connected region of positive values.

    ``x``/``y`` are the heatmap-weighted centroid (column, row) of the region and
    ``mass`` is the sum of its heatmap values; values <= 0 are background.
    """
    num_labels, labels_im, _, _ = cv2.connectedComponentsWithStats(
        (heatmap > 0).astype(np.uint8), connectivity=8)
    blob_centers = []
    for label in range(1, num_labels):  # label 0 is the background
        mask = labels_im == label
        weights = heatmap[mask]
        blob_sum = weights.sum()
        if blob_sum > 0:
            rows, cols = np.nonzero(mask)
            blob_centers.append((np.sum(cols * weights) / blob_sum,
                                 np.sum(rows * weights) / blob_sum,
                                 blob_sum))
    return blob_centers


def _select_blob(blob_centers, prev_positions, width, height):
    """Pick the blob treated as the ball and return its ``(x, y, mass)`` tuple.

    When ``predict_ball_position`` can extrapolate from ``prev_positions``, the blob
    closest to that prediction wins; otherwise the blob with the largest mass
    (first one on ties).
    """
    predicted = predict_ball_position(prev_positions, width, height)
    if predicted is None:
        return max(blob_centers, key=lambda blob: blob[2])
    distances = [np.sqrt((x - predicted[0]) ** 2 + (y - predicted[1]) ** 2)
                 for x, y, _ in blob_centers]
    return blob_centers[int(np.argmin(distances))]


def _draw_ball_box(frame, kf, x, y, ratio, box_width, color):
    """Draw a 2px box for the ball at ``(x, y)`` onto ``frame`` in place.

    Corners come from ``kf.xyah_to_xyxy([x, y, ratio, box_width])``, truncated to
    int. Returns ``(x1, y1, x2, y2)``.
    """
    x1, y1, x2, y2 = (int(v) for v in kf.xyah_to_xyxy([x, y, ratio, box_width]))
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    return x1, y1, x2, y2


def run_inference(input_path, output_path="", output_csv_path="", overlay=False):
    """Track the tennis ball with HRNet, annotate frames, detect bounces/shots and
    stream results over a websocket.

    For each window of ``config['frames_in']`` frames the model produces
    ``config['frames_out']`` heatmaps. Each heatmap is thresholded and split into
    blobs, and one blob per frame is chosen as the ball. Gaps in detection are
    filled by ``TrajectoryInterpolator`` once enough detections follow the gap.
    Every 10 frames (with >30 detections since the last run) ``detect_bounces``
    runs on the whole trajectory and ``prepare_json(...)`` is sent over the socket.
    Frames are previewed with ``cv2.imshow`` 30 frames behind the newest, and all
    frames are written to the output video after the loop. Press ``q`` in the
    preview to stop.

    Args:
        input_path: anything ``cv2.VideoCapture`` accepts (file path or stream URL).
        output_path: annotated video path; derived from ``input_path`` when empty.
        output_csv_path: currently unused; kept for call compatibility.
        overlay: currently unused; kept for call compatibility.

    Raises:
        RuntimeError: if ``input_path`` cannot be opened.

    Note:
        Every frame read is kept in memory until the end of the run, so memory use
        grows with the length of the input.
    """
    config = load_config(f"{CURRENT_DIR}/config.yaml")

    # Ball trajectory and detected events are pushed to this socket during the run.
    ws = websocket.WebSocket()
    ws.connect(config['socket_url'])

    cap = None
    out = None
    try:
        _log_camera_address(config['cap_url'])

        device = torch.device('cuda' if torch.cuda.is_available()
                              else 'mps' if torch.backends.mps.is_available() else 'cpu')

        # Model input: resize to the configured size, then ImageNet mean/std normalisation.
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((config['inp_height'], config['inp_width'])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Ball-tracking model. NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model = HRNet(cfg=config).to(device)
        checkpoint = torch.load(CURRENT_DIR + config['model_path'], map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        model.eval()

        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise RuntimeError(f'cannot open video source: {input_path}')
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        if fps <= 0:
            # Some sources report no frame rate; give the writer a usable value.
            fps = 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f'success connect camera, the video： {width} X {height}, FPS: {fps}')

        output_video_path = _resolve_output_video_path(input_path, output_path)
        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        print(output_video_path)

        frames_in = config['frames_in']
        frames_out = config['frames_out']

        # Index of the next output frame. It is used to index frame_history, which
        # assumes frames_in == frames_out (TODO confirm against config.yaml).
        frame_number = 0
        frames_buffer = []  # frames collected for the next model window

        # Person detection (PersonDetector.track_players) is disabled, so these stay
        # empty and prepare_json receives empty player lists.
        player1_data = []
        player2_data = []

        # Fixed-size ball box, produced via kf.xyah_to_xyxy.
        kf = KalmanFilter()
        kf_ratio_range = 1
        kf_width_range = 8

        # Gap filling: after a miss, wait for `interpolator_range` further detections
        # (grows by one per additional miss, capped at 20), then interpolate the gap.
        interpolator = TrajectoryInterpolator()
        need_interpolator = False
        interpolator_range = 10
        interpolator_count = []
        interpolator_start_frame = None

        stable_detect = []       # detections since the last bounce-detector run
        coordinate_history = []  # [x, y] per output frame; [None, None] for misses
        frame_history = []       # every frame read; annotated later (interpolation, events)
        prev_positions = []      # last 3 chosen ball positions, used for blob selection
        visited = {}             # latest events from detect_bounces: {frame index: kind}

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # frame = undistort_image(frame)
            frames_buffer.append(frame)
            frame_history.append(frame)

            if len(frames_buffer) < frames_in:
                continue

            frames_processed = [preprocess_frame(f, transform) for f in frames_buffer]
            input_tensor = torch.cat(frames_processed, dim=0).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model(input_tensor)[0]  # raw logits

            for i in range(frames_out):
                # Frame this output belongs to: the last `frames_out` frames read.
                target_frame = frame_history[-frames_out + i]

                # Reset per output frame so a miss is never reported at a stale position.
                detected = False
                center_x, center_y, confidence = 0, 0, 0

                heatmap = torch.sigmoid(outputs[0][i]).squeeze().cpu().numpy()
                heatmap = cv2.resize(heatmap, (width, height), interpolation=cv2.INTER_LINEAR)
                heatmap = (heatmap > 0.5).astype(np.float32) * heatmap  # keep only confident pixels

                blob_centers = _find_blob_centers(heatmap)
                if blob_centers:
                    center_x, center_y, confidence = _select_blob(blob_centers, prev_positions, width, height)
                    detected = True
                    stable_detect.append(detected)
                    prev_positions.append(np.array([center_x, center_y]))
                    if len(prev_positions) > 3:
                        prev_positions.pop(0)

                if detected:
                    coordinate_history.append([center_x, center_y])
                    if need_interpolator:
                        interpolator_count.append([center_x, center_y])
                    else:
                        color = (0, 255, 0)
                        x1, y1, _, _ = _draw_ball_box(target_frame, kf, center_x, center_y,
                                                      kf_ratio_range, kf_width_range, color)
                        cv2.putText(target_frame, f"{confidence:.2f}", (x1, y1 - 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                else:
                    coordinate_history.append([None, None])
                    if need_interpolator:
                        interpolator_range = min(interpolator_range + 1, 20)
                    else:
                        need_interpolator = True
                        interpolator_start_frame = frame_number

                # Enough detections after the gap: fill it in and draw the filled frames (yellow).
                if need_interpolator and len(interpolator_count) >= interpolator_range:
                    color = (0, 255, 255)
                    inter_trac = interpolator.update_detection_history(
                        coordinate_history, interpolator_start_frame, frame_number)
                    for f in range(interpolator_start_frame, frame_number):
                        x, y = inter_trac.loc[f, 'x'], inter_trac.loc[f, 'y']
                        coordinate_history[f][0], coordinate_history[f][1] = x, y
                        _draw_ball_box(frame_history[f], kf, x, y, kf_ratio_range, kf_width_range, color)
                    interpolator_count = []
                    need_interpolator = False
                    interpolator_range = 10

                # Bounce/shot detection on the whole trajectory every 10 frames, once
                # more than 30 detections have accumulated; the latest window is sent.
                if frame_number > 0 and frame_number % 10 == 0 and len(stable_detect) > 30:
                    trajectory = pd.DataFrame(coordinate_history, columns=['x', 'y'])
                    _, _, _, _, ball_events = detect_bounces(trajectory)
                    js_data = prepare_json(frame_number=frame_number,
                                           player1=player1_data[-30:],
                                           player2=player2_data[-30:],
                                           ball_coordinate=coordinate_history[-30:],
                                           ball_event=ball_events)
                    ws.send(js_data)
                    visited = ball_events
                    stable_detect = []

                # Mark events up to the current frame (BGR): shot green, bounce red, other blue.
                for event_frame, event_kind in visited.items():
                    if frame_number < event_frame:
                        continue
                    event_x, event_y = coordinate_history[event_frame]
                    if event_x is None or event_y is None:
                        continue  # no position known for that frame; nothing to mark
                    if event_kind == 'shot':
                        color = (0, 255, 0)
                    elif event_kind == 'bounce':
                        color = (0, 0, 255)
                    else:
                        color = (255, 0, 0)
                    draw_cross(frame_history[frame_number], event_x, event_y, color=color)

                # Preview lags 30 frames behind the newest, presumably so that
                # interpolation/event marks drawn onto recent frames are visible.
                if frame_number > 31:
                    cv2.imshow('frame', frame_history[-30])

                print(f'frame_number: {frame_number}: detected: {detected}')
                frame_number += 1

            frames_buffer = []  # windows do not overlap
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        # Frames read after the last full window are written without annotations.
        for frame in frame_history:
            out.write(frame)
    finally:
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()
        ws.close()

In [45]:

# Run inference on a local test clip; outputs go to the inference folder with a model/window prefix.
video_input_path = 'd:/tennis_v2/iphone_test_02.mp4'
output_prefix = 'd:/tennis_v2/inference/HRnetv3_30f_'
file_name = video_input_path.rsplit('/', 1)[-1]
video_output_path = f"{output_prefix}{file_name}"
output_csv_path = f"{video_output_path}.csv"

run_inference(video_input_path, video_output_path, output_csv_path)

success get the camera: rtmp://rtmp03open.ys7.com:1935/v3/openlive/BB8121883_1_1?expire=<redacted>&id=<redacted>&t=<redacted>&ev=100
success connect camera, the video： 1280 X 720, FPS: 29
d:/tennis_v2/inference/HRnetv3_30f_iphone_test_02.mp4
frame_number: 0: detected: False
frame_number: 1: detected: False
frame_number: 2: detected: False
frame_number: 3: detected: False
frame_number: 4: detected: False
frame_number: 5: detected: False
frame_number: 6: detected: False
frame_number: 7: detected: False
frame_number: 8: detected: False
frame_number: 9: detected: False
frame_number: 10: detected: False
frame_number: 11: detected: False
frame_number: 12: detected: False
frame_number: 13: detected: False
frame_number: 14: detected: False
frame_number: 15: detected: False
frame_number: 16: detected: False
frame_number: 17: detected: False
frame_number: 18: detected: False
frame_number: 19: detected: False
frame_number: 20: detected