# In [None]:
import time
from threading import Event, Lock, Thread

import cv2
import numpy as np
import torch

from slowfast.config.defaults import get_cfg
from slowfast.models import build_model
from slowfast.utils import checkpoint as cu
from slowfast.utils.parser import load_config_file

class ArcheryActionDetector:
    """Count bow-draw actions in a live camera feed with a video classification model.

    A capture thread keeps a rolling buffer of preprocessed frames. A detection
    thread periodically classifies the newest ``clip_length`` frames and feeds
    the resulting probability into a small start/stop state machine that counts
    completed actions.
    """

    def __init__(
        self,
        config_file="slowfast/configs/Kinetics/SLOWFAST_8x8_R50.yaml",
        checkpoint_file="path/to/checkpoint.pyth",
        camera_index=0,
        action_class_index=1,
    ):
        """Build the model, load its weights and open the camera.

        Args:
            config_file: SlowFast YAML config applied on top of the default config.
            checkpoint_file: Path to the model weights.
            camera_index: Device index passed to ``cv2.VideoCapture``.
            action_class_index: Model output class whose probability is treated
                as "bow is being drawn". NOTE(review): 1 is a placeholder —
                confirm the right index for the checkpoint's label map.

        Raises:
            RuntimeError: If the camera cannot be opened.
        """
        # Initialize the SlowFast model for CPU inference.
        self.cfg = get_cfg()
        load_config_file(self.cfg, config_file)
        self.cfg.NUM_GPUS = 0  # CPU inference
        self.model = build_model(self.cfg)
        # On CPU the model is not wrapped in (Distributed)DataParallel, so the
        # loader must not look for a ``.module`` wrapper attribute.
        cu.load_checkpoint(checkpoint_file, self.model, data_parallel=False)
        self.model.eval()

        # Input normalization and pathway layout taken from the model config.
        self.mean = np.asarray(self.cfg.DATA.MEAN, dtype=np.float32)
        self.std = np.asarray(self.cfg.DATA.STD, dtype=np.float32)
        # Two-pathway (SlowFast) models take a temporally subsampled "slow" clip
        # next to the full-rate "fast" clip; other architectures take one clip.
        if self.cfg.MODEL.ARCH in self.cfg.MODEL.MULTI_PATHWAY_ARCH:
            self.slowfast_alpha = self.cfg.SLOWFAST.ALPHA
        else:
            self.slowfast_alpha = None
        self.action_class_index = action_class_index

        # Video stream parameters.
        # NOTE(review): presumably should match cfg.DATA.NUM_FRAMES — confirm.
        self.clip_length = 32  # frames per model input clip
        self.cap_interval = 0.5  # seconds between two detection runs
        self.buffer = []  # recent preprocessed frames; guarded by buffer_lock
        self.buffer_lock = Lock()
        self.last_action_time = 0  # time of the most recent detection run
        self.action_count = 0
        self.is_action_ongoing = False
        self.action_start_time = 0

        # Open the camera.
        self.cap = cv2.VideoCapture(camera_index)
        if not self.cap.isOpened():
            raise RuntimeError(f"Cannot open camera {camera_index}")
        self.stop_event = Event()
        self._threads = []

    def preprocess_frame(self, frame):
        """Resize a BGR camera frame to 256x256, convert it to RGB and scale to [0, 1].

        Returns a float32 array of shape (256, 256, C).
        """
        frame = cv2.resize(frame, (256, 256))
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # float32 halves the buffer's memory compared with the float64 that a
        # plain ``/ 255.0`` on uint8 would produce; the model input is float32.
        return frame.astype(np.float32) / 255.0

    def _pack_pathways(self, frames):
        """Turn a list of (H, W, C) RGB frames into the model's list of inputs.

        Frames are normalized with ``self.mean``/``self.std`` and laid out as a
        single-clip batch of shape (1, C, T, H, W). When ``self.slowfast_alpha``
        is set, a slow pathway of ``T // alpha`` evenly spaced frames is added,
        giving ``[slow, fast]``; otherwise ``[clip]`` is returned.
        NOTE(review): intended to match slowfast.datasets.utils.pack_pathway_output.
        """
        clip = (np.stack(frames).astype(np.float32) - self.mean) / self.std  # (T, H, W, C)
        fast = torch.from_numpy(clip).permute(3, 0, 1, 2).unsqueeze(0).contiguous()
        if not self.slowfast_alpha:
            return [fast]
        num_frames = fast.shape[2]
        slow_index = torch.linspace(
            0, num_frames - 1, max(num_frames // self.slowfast_alpha, 1)
        ).long()
        slow = torch.index_select(fast, 2, slow_index)
        return [slow, fast]

    def process_clip(self):
        """Classify the newest ``clip_length`` buffered frames.

        Returns:
            The probability of class ``self.action_class_index``, or ``None``
            when fewer than ``clip_length`` frames have been buffered.
        """
        with self.buffer_lock:
            if len(self.buffer) < self.clip_length:
                return None
            # Snapshot only the newest frames: the buffer can hold up to
            # 2 * clip_length frames, but the model expects a fixed-length clip.
            frames = list(self.buffer)[-self.clip_length:]

        inputs = self._pack_pathways(frames)
        with torch.no_grad():
            preds = self.model(inputs)

        scores = preds[0]
        # NOTE(review): SlowFast heads typically apply their activation
        # (softmax) in eval mode. Re-applying softmax to probabilities would
        # flatten them towards uniform, so only normalize raw logits.
        already_probs = bool((scores >= 0).all()) and abs(float(scores.sum()) - 1.0) < 1e-3
        probs = scores if already_probs else torch.softmax(scores, dim=0)
        return probs[self.action_class_index].item()

    def camera_worker(self):
        """Capture-thread loop: read, preprocess and buffer frames until stopped."""
        while not self.stop_event.is_set():
            ret, frame = self.cap.read()
            if ret:
                processed = self.preprocess_frame(frame)
                with self.buffer_lock:
                    self.buffer.append(processed)
                    # Bound the buffer: once it exceeds two clips, keep only
                    # the newest clip. Trim in place so the list object the
                    # lock protects stays the same.
                    if len(self.buffer) > self.clip_length * 2:
                        del self.buffer[:-self.clip_length]
            time.sleep(0.01)

    def detection_worker(self):
        """Detection-thread loop: classify the latest clip every ``cap_interval`` seconds."""
        while not self.stop_event.is_set():
            current_time = time.time()
            if current_time - self.last_action_time >= self.cap_interval:
                prob = self.process_clip()
                if prob is not None:
                    self.update_action_state(prob, current_time)
                self.last_action_time = current_time
            time.sleep(0.01)

    def update_action_state(self, prob, current_time, threshold=0.85, min_duration=1.0):
        """Advance the action state machine with one detection result.

        An action starts when ``prob`` rises above ``threshold``. It ends when
        ``prob`` drops below it, and is counted only if it lasted at least
        ``min_duration`` seconds.

        Args:
            prob: Probability of the target action for the latest clip.
                ``None`` (no clip available yet) is ignored.
            current_time: Timestamp of the detection, in seconds.
            threshold: Classification threshold.
            min_duration: Minimum duration (seconds) for an action to be counted.
        """
        if prob is None:
            return

        if not self.is_action_ongoing and prob > threshold:
            # Action started.
            self.is_action_ongoing = True
            self.action_start_time = current_time
            print(f"🏹 拉弓动作开始 @ {time.strftime('%H:%M:%S')}")

        elif self.is_action_ongoing:
            duration = current_time - self.action_start_time
            if prob < threshold:
                # Action ended; short blips are reported but not counted.
                if duration >= min_duration:
                    self.action_count += 1
                    print(f"✅ 完成拉弓动作！持续时间：{duration:.2f}s 总计：{self.action_count}次")
                else:
                    print("⚠️ 检测到短时动作（忽略计数）")
                self.is_action_ongoing = False
            else:
                # Still drawing: refresh the running duration on the same line.
                print(f"⏱ 持续拉弓中... {duration:.1f}s", end='\r')

    def start(self):
        """Launch the capture and detection worker threads (as daemon threads)."""
        for target in (self.camera_worker, self.detection_worker):
            worker = Thread(target=target, daemon=True)
            worker.start()
            self._threads.append(worker)
        print("系统已启动，开始检测拉弓动作...")

    def stop(self, timeout=2.0):
        """Signal the workers to stop, wait for them, then release the camera.

        Args:
            timeout: Seconds to wait for each worker thread before giving up.
        """
        self.stop_event.set()
        # Join before releasing so camera_worker never calls read() on a
        # released capture device.
        for worker in self._threads:
            worker.join(timeout=timeout)
        self.cap.release()
        print("\n系统已停止")

if __name__ == "__main__":
    detector = ArcheryActionDetector()
    detector.start()

    try:
        # Keep the main thread alive: the workers are daemon threads and would
        # end as soon as the interpreter exits.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        pass
    finally:
        # Release the camera on Ctrl+C and on any unexpected error alike.
        detector.stop()