In [1]:
import cv2
import numpy as np
import json
import os
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam

# Đường dẫn tới thư mục dữ liệu
DATA_PATH = r'C:\Users\USER\Desktop\golf_swing\golf_dataset'
VIDEO_PATH = os.path.join(DATA_PATH, 'videos')
ANNOTATION_PATH = os.path.join(DATA_PATH, 'annotations')
POSSIBLE_ERRORS_PATH = os.path.join(DATA_PATH, 'possible_errors.json')

# Hàm trích xuất tổng số frame từ video
def get_total_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f'Error: Cannot open video file {video_path}')
        return 0
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    return total_frames

# Hàm trích xuất frame và gắn nhãn lỗi
def extract_frames_and_error_labels(video_file, annotation_file, phase_classifier):
    frames = []
    labels = []
    video_path = os.path.join(VIDEO_PATH, video_file)
    cap = cv2.VideoCapture(video_path)
    
    if not cap.isOpened():
        print(f'Error: Cannot open video file {video_path}')
        return np.array([]), np.array([])
    
    try:
        with open(annotation_file, 'r') as f:
            annotations = json.load(f)
    except FileNotFoundError:
        print(f'Error: Annotation file {annotation_file} not found')
        return np.array([]), np.array([])
    except json.JSONDecodeError:
        print(f'Error: Invalid JSON format in {annotation_file}')
        return np.array([]), np.array([])
    
    total_frames = get_total_frames(video_path)
    frame_count = 0
    labeled_frames = 0
    
    while cap.isOpened() and frame_count < total_frames:
        ret, frame = cap.read()
        if not ret:
            break
        frame_resized = cv2.resize(frame, (128, 128))
        frames.append(frame_resized)
        
        frame_input = frame_resized / 255.0
        frame_input = np.expand_dims(frame_input, axis=0)
        phase_pred = phase_classifier.predict(frame_input, verbose=0)
        phase = list(possible_errors.keys())[np.argmax(phase_pred)]
        
        label = np.zeros(num_errors)
        for phase_key, info in annotations.items():
            try:
                if phase_key == phase and info['start_frame'] <= frame_count <= info['end_frame']:
                    for error in info['errors']:
                        if error in error_to_index:
                            label[error_to_index[error]] = 1
                    labeled_frames += 1
            except KeyError:
                print(f'Error: Invalid phase data in {annotation_file} for phase {phase_key}')
                continue
        labels.append(label)
        frame_count += 1
    
    cap.release()
    print(f'Info: {labeled_frames}/{total_frames} frames labeled in {video_file}')
    if labeled_frames == 0:
        print(f'Error: No frames labeled in {video_file}')
    return np.array(frames), np.array(labels)

# Load toàn bộ dữ liệu
def load_error_dataset(phase_classifier):
    all_frames = []
    all_labels = []
    
    if not os.path.exists(VIDEO_PATH):
        print(f'Error: Directory {VIDEO_PATH} does not exist')
        return np.array([]), np.array([])
    
    video_files = [f for f in os.listdir(VIDEO_PATH) if f.endswith('.mp4')]
    if not video_files:
        print(f'Error: No .mp4 files found in {VIDEO_PATH}')
        return np.array([]), np.array([])
    
    for video_file in video_files:
        video_path = os.path.join(VIDEO_PATH, video_file)
        annotation_file = os.path.join(ANNOTATION_PATH, f'{video_file[:-4]}_error.json')
        if not os.path.exists(annotation_file):
            print(f'Error: No annotation file for {video_file}')
            continue
        frames, labels = extract_frames_and_error_labels(video_path, annotation_file, phase_classifier)
        if len(frames) > 0 and len(labels) > 0:
            all_frames.append(frames)
            all_labels.append(labels)
        else:
            print(f'Warning: No valid frames or labels for {video_file}')
    
    if not all_frames:
        print('Error: No valid data loaded')
        return np.array([]), np.array([])
    
    return np.concatenate(all_frames), np.concatenate(all_labels)

# Load possible errors và tạo từ điển ánh xạ
try:
    with open(POSSIBLE_ERRORS_PATH, 'r') as f:
        possible_errors = json.load(f)
except FileNotFoundError:
    print(f'Error: {POSSIBLE_ERRORS_PATH} not found')
    exit()
except json.JSONDecodeError:
    print(f'Error: Invalid JSON format in {POSSIBLE_ERRORS_PATH}')
    exit()

error_to_index = {}
index = 0
for phase, errors in possible_errors.items():
    for error in errors:
        error_to_index[error] = index
        index += 1
num_errors = len(error_to_index)

def main():
    # Load mô hình phase classifier
    try:
        phase_classifier = load_model('phase_classifier.h5')
    except FileNotFoundError:
        print('Error: phase_classifier.h5 not found. Run model_phase.py first.')
        return

    # Chuẩn bị dữ liệu
    frames, labels = load_error_dataset(phase_classifier)
    if len(frames) == 0:
        print('Exiting due to no valid data')
        return
    
    frames = frames / 255.0  # Chuẩn hóa giá trị pixel

    # Chia tập train/test
    X_train, X_test, y_train, y_test = train_test_split(frames, labels, test_size=0.2, random_state=42)

    # Data augmentation
    datagen = ImageDataGenerator(
        rotation_range=10,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True
    )
    datagen.fit(X_train)

    # Xây dựng mô hình CNN
    model = Sequential([
        Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3)),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Conv2D(128, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(256, activation='relu'),
        Dropout(0.5),
        Dense(num_errors, activation='sigmoid')  # Multi-label output
    ])

    model.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])

    # Huấn luyện mô hình với augmentation
    model.fit(datagen.flow(X_train, y_train, batch_size=32), 
              epochs=15, 
              validation_data=(X_test, y_test))

    # Lưu mô hình
    model.save('error_classifier.h5')

    # Đánh giá mô hình
    loss, accuracy = model.evaluate(X_test, y_test)
    print(f'Test accuracy: {accuracy:.4f}')

if __name__ == '__main__':
    main()



Info: 18/75 frames labeled in C:\Users\USER\Desktop\golf_swing\golf_dataset\videos\video_001.mp4
Info: 13/54 frames labeled in C:\Users\USER\Desktop\golf_swing\golf_dataset\videos\video_002.mp4
Info: 19/79 frames labeled in C:\Users\USER\Desktop\golf_swing\golf_dataset\videos\video_003.mp4
Info: 13/55 frames labeled in C:\Users\USER\Desktop\golf_swing\golf_dataset\videos\video_004.mp4
Info: 13/54 frames labeled in C:\Users\USER\Desktop\golf_swing\golf_dataset\videos\video_005.mp4
Epoch 1/15


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  self._warn_if_super_not_called()


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 241ms/step - accuracy: 0.0545 - loss: 0.6353 - val_accuracy: 0.0312 - val_loss: 0.5781
Epoch 2/15
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 221ms/step - accuracy: 0.0756 - loss: 0.0694 - val_accuracy: 0.0312 - val_loss: 0.7766
Epoch 3/15
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 220ms/step - accuracy: 0.0937 - loss: 0.0636 - val_accuracy: 0.0000e+00 - val_loss: 1.1166
Epoch 4/15
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 220ms/step - accuracy: 0.0685 - loss: 0.0360 - val_accuracy: 0.0000e+00 - val_loss: 1.5001
Epoch 5/15
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 219ms/step - accuracy: 0.1437 - loss: 0.0650 - val_accuracy: 0.0000e+00 - val_loss: 1.8639
Epoch 6/15
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 218ms/step - accuracy: 0.1212 - loss: 0.0338 - val_accuracy: 0.0000e+00 - val_loss: 2.1777
Epoch 7/15
[1m8/8[0m [32m━━━━━━━



[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step - accuracy: 0.0000e+00 - loss: 3.3193
Test accuracy: 0.0000
