In [None]:
from flask import Flask, jsonify
import threading
import requests
import time
import cv2
from collections import defaultdict, deque
import torch
import numpy as np
from ultralytics import YOLO
import torch.nn as nn

# Flask server setup
app = Flask(__name__)
# Endpoint on the main server that TargetDetector.send_data_to_server POSTs to.
MAIN_SERVER_URL = "http://192.168.0.8:8000/api/receive"

# Load the YOLO models (consider a lighter model)
pose_model = YOLO("yolov8n-pose.pt")  
# NOTE(review): object_model is loaded here but not referenced in the visible code.
object_model = YOLO("yolov8n.pt")  

# Names of the trained classes; predict_class indexes into this list.
class_names = ["running", "walking", "sitting", "lying"]
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTMPoseClassifier(nn.Module):
    """Stacked LSTM followed by a linear layer applied to the last layer's final hidden state.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of output scores produced by the linear layer.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability passed to ``nn.LSTM``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, dropout=0.5):
        super().__init__()
        # The attribute names ``lstm`` and ``fc`` determine the state_dict keys.
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Return the linear-layer output for a batch-first input sequence ``x``."""
        lstm_result = self.lstm(x)
        # nn.LSTM returns (output, (h_n, c_n)); only h_n is used here.
        final_hidden_per_layer = lstm_result[1][0]
        top_layer_hidden = final_hidden_per_layer[-1]
        return self.fc(top_layer_hidden)

# Load the LSTM model
# input_dim is 17 * 2; presumably 17 keypoints with (x, y) each, flattened per frame.
lstm_model = LSTMPoseClassifier(input_dim=17 * 2, hidden_dim=128, output_dim=len(class_names)).to(device)
# map_location lets the saved weights load onto whichever device was selected above.
lstm_model.load_state_dict(torch.load("lstm_pose_classifier2.0.pth", map_location=device))
# Switch to inference mode (this disables the LSTM's dropout).
lstm_model.eval()

class TargetDetector:
    """Hold per-person landmark history and upload results to the main server, rate-limited."""

    def __init__(self):
        # time.time() of the most recent POST attempt; 0 means nothing was sent yet.
        self.last_transmission_time = 0
        # Minimum number of seconds between two POST attempts.
        self.transmission_interval = 2
        # person_id -> rolling window holding the 96 most recent landmark entries.
        self.sequence_buffers = defaultdict(lambda: deque(maxlen=96))

    def send_data_to_server(self, person_id, pose_box, predicted_class, confidence, detected_class):
        """POST one detection to MAIN_SERVER_URL unless an attempt was made too recently.

        The throttle is a single timestamp shared by all persons. It is updated
        on every attempt, successful or not, so a slow or unreachable server
        does not make every following call block on the request timeout.

        Args:
            person_id: Identifier placed in the payload.
            pose_box: Bounding box object exposing ``tolist()``.
            predicted_class: Predicted label placed in the payload.
            confidence: Numeric confidence, or None (sent as 0).
            detected_class: Detected object class placed in the payload.

        Returns:
            None. Request errors and non-200 replies are printed, not raised.
        """
        current_time = time.time()
        if current_time - self.last_transmission_time < self.transmission_interval:
            return

        data = {
            "person_id": person_id,
            "pose_box": pose_box.tolist(),
            "predicted_class": predicted_class,
            "confidence": float(confidence) if confidence is not None else 0,
            "detected_class": detected_class
        }

        # Record the attempt before the network call so failures are throttled too.
        self.last_transmission_time = current_time

        try:
            response = requests.post(MAIN_SERVER_URL, json=data, timeout=1)
        except requests.exceptions.RequestException as e:
            print(f"Error sending data: {e}")
            return

        if response.status_code == 200:
            print(f"Data sent successfully for person {person_id}")
        else:
            # Previously a non-200 reply was dropped without any trace.
            print(f"Error sending data: server responded with HTTP {response.status_code}")

def extract_landmarks_and_boxes(frame, model):
    """
    Extract normalized landmarks and bounding boxes with a YOLO pose model.

    Args:
        frame (np.array): Frame image; its first two dimensions are read as
            (height, width), so both colour and 2-D grayscale frames work.
        model: Pose model, called as ``model(frame, verbose=False)``. Only the
            first returned result is used.
    Returns:
        list: One ``(person_id, landmarks, box)`` tuple per detection.
            ``person_id`` is the detection's index within this frame,
            ``landmarks`` are the keypoint xy coordinates divided by
            ``[width, height]``, and ``box`` is the matching xyxy row.
            Empty when there are no results or the first result carries no
            keypoints or boxes.
    """
    results = model(frame, verbose=False)
    if len(results) == 0:
        return []

    first_result = results[0]
    keypoints_obj = getattr(first_result, "keypoints", None)
    boxes_obj = getattr(first_result, "boxes", None)
    # The attribute can exist but be None (e.g. a non-pose model); hasattr() alone misses that.
    if keypoints_obj is None or boxes_obj is None:
        return []

    # Slice instead of 3-way unpacking so frames without a channel axis are accepted.
    height, width = frame.shape[:2]

    keypoints = keypoints_obj.xy.cpu().numpy()  # (N, 17, 2)
    boxes = boxes_obj.xyxy.cpu().numpy()  # (N, 4)

    # Normalize the coordinates by the frame size.
    return [
        (person_id, landmarks / [width, height], box)
        for person_id, (landmarks, box) in enumerate(zip(keypoints, boxes))
    ]

def predict_class(sequence, model):
    """
    Predict a class name from a landmark sequence with the LSTM model.

    Args:
        sequence (deque): Landmark sequence; at least 96 entries are required.
        model (nn.Module): Trained LSTM model.
    Returns:
        str: The predicted class name, or "Waiting for data..." while the
            sequence holds fewer than 96 entries.
    """
    enough_frames = len(sequence) >= 96
    if not enough_frames:
        return "Waiting for data..."

    # Stack the entries and add a leading batch axis of size 1.
    stacked = np.array(sequence)
    batch = torch.tensor(stacked, dtype=torch.float32).unsqueeze(0).to(device)
    # Flatten everything after the time axis into one feature vector per step.
    batch_size, num_steps = batch.size(0), batch.size(1)
    batch = batch.view(batch_size, num_steps, -1)

    with torch.no_grad():
        scores = model(batch)
        best_index = torch.max(scores, 1)[1]
    return class_names[best_index.item()]

@app.route("/classification", methods=["GET"])
def send_data():
    """Run the webcam pose-classification loop until the stream ends or 'q' is pressed.

    Every second frame read from camera 0 is passed through the pose model.
    Each detected person's landmarks are appended to that person's buffer and
    classified with the LSTM. The result is handed to the rate-limited
    uploader and drawn on the preview window.

    Returns:
        A JSON-serializable dict once the loop ends, or a ``(dict, 503)``
        tuple when the camera cannot be opened. Flask turns both into JSON
        responses when this runs as a view; the value is ignored when the
        function is used as a thread target.
    """
    detector = TargetDetector()
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        # Report the failure instead of silently skipping the loop.
        cap.release()
        return {"status": "error", "message": "camera 0 could not be opened"}, 503

    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 320)   # low resolution
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 240)  # low resolution
    frame_count = 0

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            if frame_count % 2 != 0:  # frame skipping: process every second frame
                continue

            # Extract landmarks and boxes with the YOLO pose model
            people_data = extract_landmarks_and_boxes(frame, pose_model)

            for person_id, landmarks, pose_box in people_data:
                detector.sequence_buffers[person_id].append(landmarks)
                predicted_class = predict_class(detector.sequence_buffers[person_id], lstm_model)

                # Send the data to the server (throttled to a 2-second interval)
                detector.send_data_to_server(
                    person_id,
                    pose_box,
                    predicted_class,
                    None,  # confidence temporarily removed
                    "person"
                )

                # Visualization (optional)
                x1, y1, x2, y2 = map(int, pose_box)
                cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
                cv2.putText(
                    frame,
                    f"Pose: {predicted_class}",
                    (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.7,
                    (255, 0, 0),
                    2,
                )

            cv2.imshow("Pose Detection", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera and close the preview even if a frame raised.
        cap.release()
        cv2.destroyAllWindows()

    # A Flask view must return a response; the original returned None here.
    return {"status": "stopped", "frames_read": frame_count}

if __name__ == "__main__":
    # Run the capture/classification loop in a daemon thread so it does not
    # keep the process alive once the Flask server below stops.
    thread = threading.Thread(target=send_data, daemon=True)
    thread.start()
    # Serve HTTP on all interfaces, port 5000; this call blocks, so the
    # thread above has to be started first.
    app.run(host="0.0.0.0", port=5000, debug=False, threaded=True)

 * Serving Flask app '__main__'
 * Debug mode: off


  return torch._C._cuda_getDeviceCount() > 0
 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://192.168.0.9:5000
Press CTRL+C to quit


Data sent successfully for person 0


qt.qpa.plugin: Could not find the Qt platform plugin "wayland" in "/home/hyun/venv/torch_venv/lib/python3.12/site-packages/cv2/qt/plugins"


Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Error sending data: HTTPConnectionPool(host='192.168.0.8', port=8000): Read timed out. (read timeout=1)
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
Data sent successfully for person 0
