In [1]:
import sys
sys.path.append('E:/tennis_v2/')
print(sys.path)
import os
import cv2
import torch
import pathlib
import torchvision.transforms as transforms
import numpy as np
import pandas as pd

from model.wasb import HRNet
from utils.kalman_filter import KalmanFilter
from utils.LoadData import resize_feature
from model.bounce_detector import BounceDetector
from collections import deque
from assets.bounce import detect_bounces

# from asset.tennis_v1.LoadData import resize_feature

%load_ext autoreload
%autoreload 2

['D:\\python\\Python3106-64', 'd:\\python\\Python3106-64\\python310.zip', 'd:\\python\\Python3106-64\\DLLs', 'd:\\python\\Python3106-64\\lib', '', 'C:\\Users\\10765\\AppData\\Roaming\\Python\\Python310\\site-packages', 'C:\\Users\\10765\\AppData\\Roaming\\Python\\Python310\\site-packages\\win32', 'C:\\Users\\10765\\AppData\\Roaming\\Python\\Python310\\site-packages\\win32\\lib', 'C:\\Users\\10765\\AppData\\Roaming\\Python\\Python310\\site-packages\\Pythonwin', 'd:\\python\\Python3106-64\\lib\\site-packages', 'd:\\python\\Python3106-64\\lib\\site-packages\\stable_diffusion-0.0.1-py3.10.egg', 'd:\\python\\Python3106-64\\lib\\site-packages\\k_diffusion-0.2.0.dev0-py3.10.egg', 'd:\\python\\Python3106-64\\lib\\site-packages\\wandb-0.17.5-py3.10.egg', 'd:\\python\\Python3106-64\\lib\\site-packages\\dctorch-0.1.2-py3.10.egg', 'd:\\python\\Python3106-64\\lib\\site-packages\\clip_anytorch-2.6.0-py3.10.egg', 'd:\\python\\Python3106-64\\lib\\site-packages\\setproctitle-1.3.3-py3.10-win-amd64.egg'

In [2]:
# --- Constants and configuration ---
# Project root. It already ends with '/', so the derived paths below contain
# a doubled slash ('E:/tennis_v2//weights/...').
CURRENT_DIR = 'E:/tennis_v2/'
# Ball-tracking model checkpoint.
MODEL_PATH = "{}/weights/wasb_tennis_best.pth".format(CURRENT_DIR)
# Bounce-detector weights.
BOUNCE_DETECTOR_PATH = "{}/weights/ctb_regr_bounce.cbm".format(CURRENT_DIR)
# Delay, expressed in frames.
DELAY = 15

In [None]:
# Project root read as a global by run_inference() to build the weight-file paths.
# Holds the same value as CURRENT_DIR from the configuration cell.
current_dir = 'E:/tennis_v2/'

def preprocess_frame(frame, transform):
    """Apply ``transform`` to a single frame and return whatever it produces.

    Args:
        frame: The frame to convert; passed to ``transform`` untouched.
        transform: A callable taking the frame as its only argument.

    Returns:
        The value returned by ``transform(frame)``.
    """
    processed = transform(frame)
    return processed

def predict_ball_position(prev_positions, width, height):
    """Extrapolate the next ball position from the three most recent ones.

    Uses finite differences on the last three positions:
        a_t = p_t - 2 * p_{t-1} + p_{t-2}
        v_t = (p_t - p_{t-1}) + a_t
        prediction = p_t + v_t + 0.5 * a_t
    and clips the result to ``[0, width]`` x ``[0, height]``.

    Args:
        prev_positions: Indexable sequence of (x, y) positions, oldest first.
            Each entry may be a numpy array, a tuple or a list; entries are
            converted to float arrays so that plain Python sequences are not
            subject to list/tuple arithmetic (repetition on ``2 * p``,
            ``TypeError`` on ``p - q``).
        width: Upper bound applied to the predicted x coordinate.
        height: Upper bound applied to the predicted y coordinate.

    Returns:
        A float numpy array ``[x, y]``, or ``None`` when fewer than three
        positions are available.
    """
    if len(prev_positions) < 3:
        return None

    # Index rather than slice so deque inputs keep working.
    p_t = np.asarray(prev_positions[-1], dtype=float)
    p_prev = np.asarray(prev_positions[-2], dtype=float)
    p_prev2 = np.asarray(prev_positions[-3], dtype=float)

    a_t = p_t - 2 * p_prev + p_prev2
    v_t = p_t - p_prev + a_t
    predicted_position = p_t + v_t + 0.5 * a_t
    return np.clip(predicted_position, [0, 0], [width, height])

def run_inference(input_path, output_path="", output_csv_path="", overlay=False, output_fps=1):
    """Track the ball through a video, write an annotated copy, then run bounce detection.

    Frames are consumed in groups of ``config['frames_in']`` (3). Each group is
    fed to the HRNet model once and yields one heatmap per frame. For every
    frame the heatmap is thresholded at 0.5 and split into connected blobs; the
    ball is taken to be the blob closest to ``predict_ball_position`` when three
    earlier positions exist, otherwise the blob with the largest summed heatmap
    value. A frame without any blob falls back to the Kalman filter prediction
    once the filter has been initialised, and to (0, 0) before that. The
    per-frame coordinates are finally handed to ``detect_bounces``.

    Frames left over at the end of the video (fewer than ``frames_in``) are
    never processed or written.

    Args:
        input_path: Path of the video to analyse.
        output_path: Destination of the annotated video. When empty, the video
            is written next to the input as ``<input stem>_tracked.mp4``.
        output_csv_path: Forwarded unchanged to ``detect_bounces``.
        overlay: When True, the written frames are the originals blended with a
            JET-coloured heatmap; otherwise the original frames are annotated.
        output_fps: Frame rate of the written video. Defaults to 1, the value
            that used to be hard-coded; pass ``None`` to reuse the source fps.

    Returns:
        None.

    Raises:
        IOError: If ``input_path`` cannot be opened by OpenCV.
    """
    config = {
        "name": "hrnet",
        "frames_in": 3,
        "frames_out": 3,
        "inp_height": 288,
        "inp_width": 512,
        "out_height": 288,
        "out_width": 512,
        "rgb_diff": False,
        "out_scales": [0],
        "MODEL": {
            "EXTRA": {
                "FINAL_CONV_KERNEL": 1,
                "PRETRAINED_LAYERS": ['*'],
                "STEM": {
                    "INPLANES": 64,
                    "STRIDES": [1, 1]
                },
                "STAGE1": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 1,
                    "BLOCK": 'BOTTLENECK',
                    "NUM_BLOCKS": [1],
                    "NUM_CHANNELS": [32],
                    "FUSE_METHOD": 'SUM'
                },
                "STAGE2": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 2,
                    "BLOCK": 'BASIC',
                    "NUM_BLOCKS": [2, 2],
                    "NUM_CHANNELS": [16, 32],
                    "FUSE_METHOD": 'SUM'
                },
                "STAGE3": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 3,
                    "BLOCK": 'BASIC',
                    "NUM_BLOCKS": [2, 2, 2],
                    "NUM_CHANNELS": [16, 32, 64],
                    "FUSE_METHOD": 'SUM'
                },
                "STAGE4": {
                    "NUM_MODULES": 1,
                    "NUM_BRANCHES": 4,
                    "BLOCK": 'BASIC',
                    "NUM_BLOCKS": [2, 2, 2, 2],
                    "NUM_CHANNELS": [16, 32, 64, 128],
                    "FUSE_METHOD": 'SUM'
                },
                "DECONV": {
                    "NUM_DECONVS": 0,
                    "KERNEL_SIZE": [],
                    "NUM_BASIC_BLOCKS": 2
                }
            },
            "INIT_WEIGHTS": True
        },
        "model_path": f"{current_dir}/weights/wasb_tennis_best.pth",  # Update with your model path
    }
    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((config['inp_height'], config['inp_width'])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    model = HRNet(cfg=config).to(device)
    checkpoint = torch.load(config['model_path'], map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'], strict=True)
    model.eval()

    # Loaded here but not used by the active code path below.
    bounce_detector = BounceDetector(f'{current_dir}/weights/ctb_regr_bounce.cbm')

    # Kalman filter state and tuning values.
    kf = KalmanFilter()
    kf_mean, kf_covariance = None, None
    stable_frames_threshold = 1
    stable_ious_threshold = 0.1
    kf_ratio_range = 1
    kf_width_range = 8
    stable_frames = 0
    # Recent detections waiting to be fed to the Kalman filter.
    detection_queue = deque(maxlen=stable_frames_threshold)

    # Per-frame (x, y) trajectory handed to the bounce detector.
    coordinate_history = []
    save = False

    # Up to three most recent detected positions, oldest first.
    prev_positions = []

    output_video_path = output_path
    cap = cv2.VideoCapture(input_path)
    out = None
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open input video: {input_path}")

        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if output_path == "":
            # No destination given: derive one from the input file name.
            base_name = os.path.splitext(os.path.basename(input_path))[0]
            output_video_path = os.path.join(os.path.dirname(input_path), f"{base_name}_tracked.mp4")

        writer_fps = fps if output_fps is None else output_fps
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, writer_fps, (width, height))
        print(output_video_path)

        frame_number = 0
        frames_buffer = []

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frames_buffer.append(frame)
            if len(frames_buffer) != config['frames_in']:
                continue

            # Preprocess the buffered frames and stack them along the channel axis.
            frames_processed = [preprocess_frame(f, transform) for f in frames_buffer]
            input_tensor = torch.cat(frames_processed, dim=0).unsqueeze(0).to(device)

            # Perform inference
            with torch.no_grad():
                outputs = model(input_tensor)[0]  # Get the raw logits

            for i in range(config['frames_out']):
                # Reset per frame so a frame without blobs does not inherit the
                # detection of an earlier frame in the same group.
                detected = False
                center_x, center_y, confidence = 0, 0, 0

                # Convert the logits of frame i to probabilities at video resolution.
                output = torch.sigmoid(outputs[0][i])
                heatmap = output.squeeze().cpu().numpy()
                heatmap = cv2.resize(heatmap, (width, height), interpolation=cv2.INTER_LINEAR)
                # Zero out everything at or below 0.5, keep the remaining values.
                heatmap = (heatmap > 0.5).astype(np.float32) * heatmap

                if overlay:
                    heatmap_normalized_visualization = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
                    heatmap_normalized_visualization = heatmap_normalized_visualization.astype(np.uint8)
                    # Apply color map to the heatmap
                    heatmap_colored = cv2.applyColorMap(heatmap_normalized_visualization, cv2.COLORMAP_JET)
                    # Overlay the heatmap on the original frame
                    overlayed_frame = cv2.addWeighted(frames_buffer[i], 0.6, heatmap_colored, 0.4, 0)

                # Find connected components of the thresholded heatmap.
                num_labels, labels_im, _stats, _centroids = cv2.connectedComponentsWithStats((heatmap > 0).astype(np.uint8), connectivity=8)

                # Heatmap-weighted centre and total weight of every blob.
                blob_centers = []
                for label in range(1, num_labels):  # Skip the background label 0
                    ys, xs = np.nonzero(labels_im == label)
                    weights = heatmap[ys, xs]
                    blob_sum = weights.sum()
                    if blob_sum > 0:
                        blob_x = np.sum(xs * weights) / blob_sum
                        blob_y = np.sum(ys * weights) / blob_sum
                        blob_centers.append((blob_x, blob_y, blob_sum))

                if blob_centers:
                    predicted_position = predict_ball_position(prev_positions, width, height)
                    if predicted_position is not None:
                        # Select the blob closest to the predicted position
                        distances = [np.sqrt((x - predicted_position[0]) ** 2 + (y - predicted_position[1]) ** 2) for x, y, _ in blob_centers]
                        center_x, center_y, confidence = blob_centers[np.argmin(distances)]
                    else:
                        # No prediction available: take the blob with the largest summed heatmap value.
                        center_x, center_y, confidence = max(blob_centers, key=lambda blob: blob[2])
                    detected = True
                    prev_positions.append(np.array([center_x, center_y]))
                    if len(prev_positions) > 3:
                        prev_positions.pop(0)

                if detected:
                    detection_queue.append((center_x, center_y, confidence))
                    color = (0, 255, 0)
                elif kf_mean is not None and kf_covariance is not None:
                    # No detection in this frame: use the Kalman filter prediction.
                    kf_mean, kf_covariance = kf.predict(kf_mean, kf_covariance)
                    center_x, center_y = kf_mean[:2]
                    color = (0, 255, 255)
                coordinate_history.append([center_x, center_y])

                # ---- Draw ----
                # Nothing is drawn until the Kalman filter has been initialised,
                # which also guarantees `color` has been assigned.
                if (kf_mean is not None) and (kf_covariance is not None):
                    x1, y1, x2, y2 = kf.xyah_to_xyxy([center_x, center_y, kf_ratio_range, kf_width_range])
                    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                    text = f"{confidence:.2f}"  # confidence formatted to two decimals

                    canvas = overlayed_frame if overlay else frames_buffer[i]
                    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
                    cv2.putText(canvas, text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

                # ---- Kalman filter update (original author's note: this logic seems less than ideal) ----
                if len(detection_queue) == stable_frames_threshold:
                    # Keep only detections whose confidence exceeds 0.25.
                    reliable_detections = [
                        (x, y) for x, y, c in detection_queue if c > 0.25
                    ]

                    # With enough reliable detections (at least half the queue), average them.
                    if len(reliable_detections) >= stable_frames_threshold // 2:
                        avg_x = np.mean([x for x, y in reliable_detections])
                        avg_y = np.mean([y for x, y in reliable_detections])
                        if kf_mean is None and kf_covariance is None or stable_frames == 0:
                            # Initialisation stage: the filter has no state yet (or
                            # stable_frames is 0), so seed its mean and covariance
                            # from the averaged detection.
                            kf_mean, kf_covariance = kf.initiate(kf.xyxy_to_xyah([avg_x, avg_y, kf_ratio_range, kf_width_range]))
                            stable_frames += 1
                    # Clear the queue to prepare for the next round.
                    detection_queue.clear()

                    if stable_frames < stable_frames_threshold:
                        # Not yet stable: predict, and only update when the current
                        # box overlaps the prediction enough (IoU above threshold);
                        # otherwise decrease the stability counter.
                        # NOTE(review): with stable_frames_threshold == 1 this branch
                        # appears unreachable, since initialisation above already
                        # raises stable_frames to 1 - confirm before relying on it.
                        kf_mean, kf_covariance = kf.predict(kf_mean, kf_covariance)
                        predict_iou = kf._compute_iou(kf.xyah_to_xyxy([center_x, center_y, kf_ratio_range, kf_width_range]), kf.xyah_to_xyxy(kf_mean[:4]))

                        if predict_iou > stable_ious_threshold:
                            kf_mean, kf_covariance = kf.update(kf_mean, kf_covariance, [center_x, center_y, kf_ratio_range, kf_width_range])
                            stable_frames += 1
                        else:
                            stable_frames -= 1
                    else:
                        # Stable stage: regular predict-then-update cycle using the
                        # current detection as the measurement.
                        kf_mean, kf_covariance = kf.predict(kf_mean, kf_covariance)
                        kf_mean, kf_covariance = kf.update(kf_mean, kf_covariance, [center_x, center_y, kf_ratio_range, kf_width_range])

                # Write the frame that was just analysed (and annotated), not the
                # first frame of the group.
                out.write(overlayed_frame if overlay else frames_buffer[i])
                frame_number += 1
                print(f'frame_number: {frame_number}/{frame_count}')

            frames_buffer = []  # Clear the buffer for the next set of frames
    finally:
        # Release the capture and writer even if processing failed part-way.
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

    # ---- Bounce detection ----
    trajectory = pd.DataFrame(coordinate_history, columns=['x', 'y'])
    bounces, ix_5, x, y = detect_bounces(trajectory, output_csv_path, path_to_video=output_video_path, path_to_output_video='e:/tennis_v2/inference/HRnet_test_v3_with_bounce.mp4', save=save)



In [8]:
# Run the tracker on one clip; both output paths point into e:/tennis_v2/inference/.
csv = 'test'
video_input_path = 'e:/TennisProject-main/inference/test.mp4'
# Output video keeps the input file name, prefixed with "HRnet_v3_".
video_output_path = f"e:/tennis_v2/inference/HRnet_v3_{video_input_path.split('/')[-1]}"
output_csv_path = f"e:/tennis_v2/inference/HRnet_v3_{csv}.csv"

run_inference(
    video_input_path,
    output_path=video_output_path,
    output_csv_path=output_csv_path,
)

e:/tennis_v2/inference/HRnet_v3_test.mp4
frame_number: 1/266
frame_number: 2/266
frame_number: 3/266
frame_number: 4/266
frame_number: 5/266
frame_number: 6/266
frame_number: 7/266
frame_number: 8/266
frame_number: 9/266
frame_number: 10/266
frame_number: 11/266
frame_number: 12/266
frame_number: 13/266
frame_number: 14/266
frame_number: 15/266
frame_number: 16/266
frame_number: 17/266
frame_number: 18/266
frame_number: 19/266
frame_number: 20/266
frame_number: 21/266
frame_number: 22/266
frame_number: 23/266
frame_number: 24/266
frame_number: 25/266
frame_number: 26/266
frame_number: 27/266
frame_number: 28/266
frame_number: 29/266
frame_number: 30/266
frame_number: 31/266
frame_number: 32/266
frame_number: 33/266
frame_number: 34/266
frame_number: 35/266
frame_number: 36/266
frame_number: 37/266
frame_number: 38/266
frame_number: 39/266
frame_number: 40/266
frame_number: 41/266
frame_number: 42/266
frame_number: 43/266
frame_number: 44/266
frame_number: 45/266
frame_number: 46/266
fr