In [4]:
import cv2
import numpy as np
import json
import os
from tensorflow.keras.models import load_model

# Root directory of the golf dataset. Defaults to the original author's
# Windows path; set the GOLF_DATASET_PATH environment variable to point the
# script at a dataset on another machine without editing the code.
DATA_PATH = os.environ.get(
    'GOLF_DATASET_PATH',
    r'C:\Users\USER\Desktop\golf_swing\golf_dataset',
)
# Directory containing the input videos.
VIDEO_PATH = os.path.join(DATA_PATH, 'videos')
# JSON file mapping each swing phase name to the list of error labels for it.
POSSIBLE_ERRORS_PATH = os.path.join(DATA_PATH, 'possible_errors.json')

# Advice text shown to the user for each error label. The feedback generator
# looks these up using the labels read from possible_errors.json, so the keys
# should match those labels.
FEEDBACK_MAP = {
    "PoorFootAlignment": "Poor foot alignment detected. Ensure your feet are shoulder-width apart and aligned with the target line to improve stability.",
    "ImproperSpineAngle": "Improper spine angle detected. Maintain a straight spine with a slight forward tilt to ensure proper posture during setup.",
    "OpenShoulderAlignment": "Open shoulder alignment detected. Align your shoulders parallel to the target line for a more accurate swing.",
    "HeadMovement": "Excessive head movement detected. Keep your head still and eyes on the ball to maintain focus during the backswing.",
    "ArmCollapse": "Arm collapse detected. Keep your lead arm straight and maintain a consistent arc to improve swing power.",
    "OverSwing": "Over-swing detected. Shorten your backswing to maintain control and avoid losing balance.",
    "EarlyRelease": "Early release detected. Delay your wrist release until impact to maximize clubhead speed and control.",
    "HipSway": "Hip sway detected. Keep your hips stable and rotate them around your spine to maintain swing plane.",
    "OverTheTop": "Over-the-top swing detected. Start the downswing with your lower body to avoid an outside-in swing path.",
    "LossOfBalance": "Loss of balance detected. Focus on maintaining a stable base and smooth weight transfer during follow-through.",
    "IncompleteFinish": "Incomplete finish detected. Complete your swing with a full follow-through to ensure proper weight shift and balance.",
    "HeadDrop": "Head drop detected. Keep your head up and chest facing forward at the finish to maintain posture."
}

# Return the number of frames OpenCV reports for a video file.
def get_total_frames(video_path):
    """Return the frame count OpenCV reports for *video_path*.

    Args:
        video_path: Path to the video file.

    Returns:
        The reported frame count, or 0 if the file cannot be opened or the
        backend reports a non-positive count.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f'Error: Cannot open video file {video_path}')
            return 0
        # CAP_PROP_FRAME_COUNT comes from container metadata and may be 0 or
        # negative when the backend cannot determine it; clamp so callers can
        # treat any non-positive result uniformly as "no frames".
        return max(0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
    finally:
        # Release on every path, including the not-opened one.
        cap.release()

# Load the phase and error classifiers from the current working directory.
def _load_models():
    """Load ``phase_classifier.h5`` and ``error_classifier.h5``.

    Returns:
        ``(phase_classifier, error_classifier)``, or ``None`` after printing an
        error message if either model could not be loaded.
    """
    try:
        phase_classifier = load_model('phase_classifier.h5')
        error_classifier = load_model('error_classifier.h5')
    except (OSError, ValueError) as exc:
        # Depending on the Keras version, a missing/unreadable model file is
        # reported as OSError (IOError / h5py) or ValueError rather than
        # FileNotFoundError (which is itself a subclass of OSError).
        print('Error: One or more model files (phase_classifier.h5 or error_classifier.h5) not found.')
        print(f'  Details: {exc}')
        return None
    print(f"Phase classifier output shape: {phase_classifier.output_shape}")
    return phase_classifier, error_classifier


# Read the phase -> error-labels mapping from POSSIBLE_ERRORS_PATH.
def _load_possible_errors():
    """Return the parsed possible-errors JSON, or ``None`` after printing an error."""
    try:
        with open(POSSIBLE_ERRORS_PATH, 'r', encoding='utf-8') as f:
            possible_errors = json.load(f)
    except FileNotFoundError:
        print(f'Error: {POSSIBLE_ERRORS_PATH} not found')
        return None
    except json.JSONDecodeError:
        print(f'Error: Invalid JSON format in {POSSIBLE_ERRORS_PATH}')
        return None
    print(f"Phase order: {list(possible_errors.keys())}")
    return possible_errors


# Size of a model's last output axis, when it is a single known integer.
def _output_units(model):
    """Return ``model.output_shape[-1]`` if it is an int, else ``None``."""
    units = model.output_shape[-1]
    return units if isinstance(units, int) else None


# Run both classifiers over the frames of an already-open capture.
def _classify_frames(cap, total_frames, phase_classifier, error_classifier,
                     phases, error_labels, phase_threshold, error_threshold):
    """Classify up to *total_frames* frames read from *cap*.

    Returns:
        ``(phase_history, error_history, phase_prob_history)`` with one entry
        per frame successfully read.
    """
    phase_history = []
    error_history = []
    phase_prob_history = []

    for _ in range(total_frames):
        ret, frame = cap.read()
        if not ret:
            break

        # NOTE(review): OpenCV decodes frames as BGR. Confirm the classifiers
        # were trained on BGR input; if they expect RGB, convert first with
        # cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).
        frame_resized = cv2.resize(frame, (128, 128))
        frame_input = np.expand_dims(frame_resized / 255.0, axis=0)

        # Phase prediction: keep the top label only if it is confident enough.
        phase_probs = phase_classifier.predict(frame_input, verbose=0)[0]
        phase_prob_history.append(phase_probs.tolist())
        best = int(np.argmax(phase_probs))
        if phase_probs[best] > phase_threshold:
            phase_history.append(phases[best])
        else:
            phase_history.append("unknown")

        # Error prediction: independent per-label threshold (multi-label).
        error_probs = error_classifier.predict(frame_input, verbose=0)[0]
        error_history.append(
            [label for label, p in zip(error_labels, error_probs) if p > error_threshold]
        )

    return phase_history, error_history, phase_prob_history


# Analyze a video from VIDEO_PATH and print per-phase feedback.
def analyze_video(video_file, phase_threshold=0.7, error_threshold=0.5):
    """Classify every frame of a video and print swing feedback.

    Args:
        video_file: File name of the video, relative to VIDEO_PATH.
        phase_threshold: Minimum top-class probability for a frame's phase to
            be accepted; frames below it are labelled "unknown".
        error_threshold: Per-label probability above which an error is
            considered present in a frame.

    Returns:
        None. Problems (missing models/JSON/video, mismatched label lists)
        are printed and cause an early return.
    """
    video_path = os.path.join(VIDEO_PATH, video_file)

    models = _load_models()
    if models is None:
        return
    phase_classifier, error_classifier = models

    possible_errors = _load_possible_errors()
    if possible_errors is None:
        return

    # Phase names in JSON order; index i is assumed to be output unit i of the
    # phase classifier (TODO confirm against the training code).
    phases = list(possible_errors.keys())
    # Error labels flattened in phase order; position i is assumed to be output
    # unit i of the error classifier. A flat list (rather than a label->index
    # dict) keeps every position even if a label is listed under two phases.
    error_labels = [error for errors in possible_errors.values() for error in errors]

    # Refuse to run when the label lists cannot index the model outputs,
    # instead of failing with an IndexError part-way through the video.
    n_phase_outputs = _output_units(phase_classifier)
    if n_phase_outputs is not None and n_phase_outputs > len(phases):
        print(f'Error: phase classifier has {n_phase_outputs} outputs but only '
              f'{len(phases)} phases are defined in {POSSIBLE_ERRORS_PATH}')
        return
    n_error_outputs = _output_units(error_classifier)
    if n_error_outputs is not None and n_error_outputs < len(error_labels):
        print(f'Error: error classifier has {n_error_outputs} outputs but '
              f'{len(error_labels)} error labels are defined in {POSSIBLE_ERRORS_PATH}')
        return

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f'Error: Cannot open video file {video_path}')
            return

        total_frames = get_total_frames(video_path)
        # "<= 0" also covers backends that report an unknown count as negative.
        if total_frames <= 0:
            print(f'Error: No frames detected in {video_path}')
            return

        phase_history, error_history, phase_prob_history = _classify_frames(
            cap, total_frames, phase_classifier, error_classifier,
            phases, error_labels, phase_threshold, error_threshold,
        )
    finally:
        # Release the capture on every path, including the early returns.
        cap.release()

    # A single phase across all frames usually indicates a model/input problem.
    if len(set(phase_history)) == 1:
        print(f"Warning: Only one phase ({phase_history[0]}) detected across all frames. Check model or video input.")

    # Average phase probabilities, for debugging the phase classifier.
    if phase_prob_history:
        avg_probs = np.mean(phase_prob_history, axis=0)
        print(f"Average phase probabilities: {dict(zip(phases, (float(p) for p in avg_probs)))}")

    print(f'Video analysis complete for {video_file}')

    feedback = generate_feedback(phase_history, error_history, possible_errors)
    print('Feedback:')
    print(feedback)

# Build the textual swing report from per-frame phase and error predictions.
def generate_feedback(phase_history, error_history, possible_errors,
                      min_error_frames=1, feedback_map=None):
    """Summarise per-frame predictions into a multi-line feedback string.

    Args:
        phase_history: Phase label for each frame (may include labels such as
            "unknown" that are not keys of *possible_errors*).
        error_history: For each frame, the list of error labels flagged in it.
        possible_errors: Mapping of phase name -> list of error labels that
            belong to that phase. An error is reported under every detected
            phase whose list contains it.
        min_error_frames: An error is listed only if it was flagged in at
            least this many frames. The default of 1 lists every error seen.
        feedback_map: Mapping of error label -> advice text. Defaults to the
            module-level FEEDBACK_MAP. Labels missing from the map get a
            generic "<label> detected." line instead of raising KeyError.

    Returns:
        The report text, one line per item, ending with a newline.
    """
    if feedback_map is None:
        feedback_map = FEEDBACK_MAP

    # The four standard phases come first in the report; any other label seen
    # (e.g. "unknown") is appended after them in first-seen order.
    phase_counts = {'setup': 0, 'backswing': 0, 'downswing': 0, 'follow-through': 0}
    error_counts = {}
    for phase, errors in zip(phase_history, error_history):
        phase_counts[phase] = phase_counts.get(phase, 0) + 1
        for error in errors:
            error_counts[error] = error_counts.get(error, 0) + 1

    lines = ["Analysis of your golf swing:"]
    for phase, count in phase_counts.items():
        if count == 0:
            continue
        lines.append(f"- {phase.capitalize()} phase detected for {count} frames.")
        if phase not in possible_errors:
            continue
        relevant_errors = [
            error for error, n in error_counts.items()
            if n >= min_error_frames and error in possible_errors[phase]
        ]
        if relevant_errors:
            lines.append("  Common errors:")
            for error in relevant_errors:
                lines.append(f"  - {feedback_map.get(error, f'{error} detected.')}")
        else:
            lines.append("  No significant errors detected. Good job!")

    return "\n".join(lines) + "\n"

def main(video_file='video_demo.mp4'):
    """Analyze *video_file* from VIDEO_PATH, printing an error if it is missing.

    Args:
        video_file: File name relative to VIDEO_PATH; defaults to the demo video.
    """
    if not os.path.exists(os.path.join(VIDEO_PATH, video_file)):
        print(f'Error: {video_file} not found in {VIDEO_PATH}')
        return
    analyze_video(video_file)

if __name__ == '__main__':
    main()



Phase classifier output shape: (None, 4)
Phase order: ['setup', 'backswing', 'downswing', 'follow-through']
Average phase probabilities: {'setup': 0.0007859868245820204, 'backswing': 0.9917015512784322, 'downswing': 0.006925395857542753, 'follow-through': 0.0005870667418154577}
Video analysis complete for video_demo.mp4
Feedback:
Analysis of your golf swing:
- Backswing phase detected for 75 frames.
  Common errors:
  - Excessive head movement detected. Keep your head still and eyes on the ball to maintain focus during the backswing.
  - Arm collapse detected. Keep your lead arm straight and maintain a consistent arc to improve swing power.
  - Over-swing detected. Shorten your backswing to maintain control and avoid losing balance.

