In [1]:
import cv2 as cv
import matplotlib.pyplot as plt
import os
import numpy as np
import mediapipe as mp
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, Conv1D, MaxPooling1D, Flatten, TimeDistributed
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical
import tensorflow as tf

# MediaPipe setup: module-level aliases for the MediaPipe solutions used below.
mp_holistic = mp.solutions.holistic  # Holistic model; instantiated for per-frame landmark detection
mp_drawing = mp.solutions.drawing_utils  # Provides draw_landmarks / DrawingSpec for overlays
mp_pose = mp.solutions.pose  # Supplies POSE_CONNECTIONS used when drawing the skeleton

def mediapipe_detection(image, model):
    """Run a MediaPipe model on a single BGR frame.

    The frame is converted to RGB before being handed to ``model.process`` and
    is flagged read-only for the duration of that call. Afterwards the flag is
    restored and the frame is converted back to BGR.

    Args:
        image: Frame in BGR channel order (as produced by OpenCV capture).
        model: Object exposing ``process(rgb_image)``.

    Returns:
        Tuple ``(bgr_image, results)`` where ``results`` is whatever
        ``model.process`` returned.
    """
    rgb_frame = cv.cvtColor(image, cv.COLOR_BGR2RGB)

    # Read-only while the model looks at it; writable again afterwards.
    rgb_frame.flags.writeable = False
    detection = model.process(rgb_frame)
    rgb_frame.flags.writeable = True

    bgr_frame = cv.cvtColor(rgb_frame, cv.COLOR_RGB2BGR)
    return bgr_frame, detection

def draw_styled_landmarks(image, results):
    """Draw the detected pose skeleton onto ``image``.

    Does nothing when ``results.pose_landmarks`` is empty/None.

    Args:
        image: Frame passed to ``mp_drawing.draw_landmarks`` as the drawing target.
        results: Detection results exposing a ``pose_landmarks`` attribute.
    """
    pose_landmarks = results.pose_landmarks
    if not pose_landmarks:
        return

    # Separate styles for the joints (landmarks) and the lines between them.
    landmark_style = mp_drawing.DrawingSpec(color=(80, 22, 10), thickness=2, circle_radius=4)
    connection_style = mp_drawing.DrawingSpec(color=(80, 44, 121), thickness=2, circle_radius=2)
    mp_drawing.draw_landmarks(
        image,
        pose_landmarks,
        mp_pose.POSE_CONNECTIONS,
        landmark_style,
        connection_style,
    )

def extract_keypoints(results):
    """Flatten pose landmarks into a 1-D feature vector.

    Each landmark contributes ``x, y, z, visibility`` in that order. When no
    pose was detected, a zero vector of length ``33 * 4`` is returned instead.

    Args:
        results: Detection results exposing ``pose_landmarks`` (whose
            ``landmark`` attribute is iterable) or a falsy value.

    Returns:
        1-D ``numpy.ndarray`` of landmark values.
    """
    pose_landmarks = results.pose_landmarks
    if not pose_landmarks:
        return np.zeros(33 * 4)

    rows = [[lm.x, lm.y, lm.z, lm.visibility] for lm in pose_landmarks.landmark]
    return np.array(rows).flatten()

# Configuration
# NOTE(review): the path contains a space before '/Data'
# ('Cricket Pose Estimation /Data') - confirm it matches the folder on disk.
DATA_PATH = '/home/smayan/Desktop/Cricket Pose Estimation /Data'  # Dataset root; each sub-directory is treated as one shot class
sequence_length = 30  # Frames per sequence
min_sequences_per_class = 10  # Minimum sequences to generate per class; any shortfall is filled with noise-augmented copies

def load_cricket_data():
    """Discover the shot classes under ``DATA_PATH`` and report video counts.

    Every sub-directory of ``DATA_PATH`` is treated as one class. For each
    class the number of files ending in ``.mp4``, ``.avi`` or ``.mov`` is
    printed.

    Returns:
        ``numpy.ndarray`` of class (folder) names in sorted order.
    """
    video_suffixes = ('.mp4', '.avi', '.mov')

    class_dirs = sorted(
        entry for entry in os.listdir(DATA_PATH)
        if os.path.isdir(os.path.join(DATA_PATH, entry))
    )
    actions = np.array(class_dirs)
    print(f"Detected cricket shots: {actions}")

    # Report how many videos each class contributes.
    for action in actions:
        class_dir = os.path.join(DATA_PATH, action)
        video_count = sum(1 for name in os.listdir(class_dir) if name.endswith(video_suffixes))
        print(f"{action}: {video_count} videos")

    return actions

def _read_video_keypoints(video_path, holistic):
    """Decode a video once and return one keypoint vector per readable frame."""
    cap = cv.VideoCapture(video_path)
    frame_keypoints = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                # End of stream (or unreadable file): stop with what we have.
                break

            # Resize frame for consistency
            frame = cv.resize(frame, (640, 480))

            _, results = mediapipe_detection(frame, holistic)
            frame_keypoints.append(extract_keypoints(results))
    finally:
        # Release the capture even if detection raises.
        cap.release()
    return frame_keypoints


def _sliding_windows(frame_keypoints, window, stride):
    """Return every full window of `window` consecutive frames, `stride` apart."""
    last_start = len(frame_keypoints) - window
    return [
        frame_keypoints[start:start + window]
        for start in range(0, last_start + 1, stride)
    ]


def extract_sequences_from_videos(actions):
    """Build fixed-length keypoint sequences for every class.

    Each video is decoded and run through MediaPipe exactly once. Overlapping
    windows of ``sequence_length`` frames (stride ``sequence_length // 4``,
    at least 1) are then sliced from the cached per-frame keypoints, so no
    frame is re-detected for each window it belongs to.

    Classes that end up with fewer than ``min_sequences_per_class`` sequences
    are topped up with Gaussian-noise copies, cycling over the class's real
    sequences. A class with no usable sequence at all cannot be augmented; it
    is reported and skipped instead of looping forever.

    Args:
        actions: Iterable of class names; each must be a folder under
            ``DATA_PATH``.

    Returns:
        Tuple ``(sequences, labels, label_map)``: a numpy array of sequences,
        a numpy array of integer labels aligned with it, and the
        ``{class name: index}`` mapping built from ``actions``.
    """
    sequences = []
    labels = []
    label_map = {label: num for num, label in enumerate(actions)}

    video_suffixes = ('.mp4', '.avi', '.mov')
    stride = max(1, sequence_length // 4)  # Overlapping sequences

    with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
        for action in actions:
            action_path = os.path.join(DATA_PATH, action)
            # Sorted so the sequence order does not depend on directory listing order.
            video_files = sorted(f for f in os.listdir(action_path) if f.endswith(video_suffixes))

            action_sequences = []
            for video_file in video_files:
                frame_keypoints = _read_video_keypoints(
                    os.path.join(action_path, video_file), holistic
                )
                action_sequences.extend(
                    _sliding_windows(frame_keypoints, sequence_length, stride)
                )

            real_count = len(action_sequences)
            if real_count == 0:
                # Nothing to augment from; skipping avoids an endless top-up loop.
                print(f"Warning: no usable sequences for {action}; skipping this class")
                continue

            # Data augmentation for classes with fewer sequences: add noise to
            # the real sequences in turn (never to already-augmented copies).
            shortfall = max(0, min_sequences_per_class - real_count)
            for aug_index in range(shortfall):
                source_seq = np.array(action_sequences[aug_index % real_count])
                noise = np.random.normal(0, 0.01, source_seq.shape)
                action_sequences.append((source_seq + noise).tolist())

            # Add sequences and labels
            sequences.extend(action_sequences)
            labels.extend([label_map[action]] * len(action_sequences))

            print(f"Generated {len(action_sequences)} sequences for {action}")

    return np.array(sequences), np.array(labels), label_map

def create_hybrid_cnn_lstm_model(input_shape, num_classes):
    """Assemble the (uncompiled) hybrid CNN-LSTM classifier.

    Layout, in order:
      * two per-frame convolutional stages (64 then 128 filters), each made
        of two ``Conv1D`` layers, a ``MaxPooling1D`` and a ``Dropout``, all
        wrapped in ``TimeDistributed``;
      * a per-frame ``Flatten``;
      * two stacked ``LSTM`` layers (128 then 64 units);
      * two ``Dense`` + ``Dropout`` pairs and a softmax output layer.

    Args:
        input_shape: Shape of one sample, forwarded to the first layer.
        num_classes: Width of the softmax output layer.

    Returns:
        The ``Sequential`` model.
    """
    model = Sequential()

    # CNN layers for spatial feature extraction (applied to each timestep)
    for stage_index, filters in enumerate((64, 128)):
        # Only the very first layer carries the input shape.
        first_layer_kwargs = {'input_shape': input_shape} if stage_index == 0 else {}
        model.add(TimeDistributed(Conv1D(filters, kernel_size=3, activation='relu'),
                                  **first_layer_kwargs))
        model.add(TimeDistributed(Conv1D(filters, kernel_size=3, activation='relu')))
        model.add(TimeDistributed(MaxPooling1D(pool_size=2)))
        model.add(TimeDistributed(Dropout(0.25)))

    # Flatten the CNN output for LSTM
    model.add(TimeDistributed(Flatten()))

    # LSTM layers for temporal sequence modeling
    model.add(LSTM(128, return_sequences=True, dropout=0.3, recurrent_dropout=0.3))
    model.add(LSTM(64, return_sequences=False, dropout=0.3, recurrent_dropout=0.3))

    # Dense layers for classification
    for units in (128, 64):
        model.add(Dense(units, activation='relu'))
        model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))

    return model

# Main training pipeline
def train_cricket_model(epochs=10, batch_size=16):
    """Extract sequences, train the CNN-LSTM classifier and save the artifacts.

    Pipeline: discover classes, extract keypoint sequences, split into
    train / validation / test, fit with class weighting, evaluate on the
    held-out test split, then write ``cricket_pose_model.h5``,
    ``cricket_pose_model.keras`` and ``cricket_label_map.npy``.

    The validation split is carved out of the training data so that the
    callbacks (early stopping / LR schedule / best-weight restore) never see
    the test split that the final accuracy is reported on.

    NOTE(review): sequences are overlapping windows and noise-augmented
    copies, and the split is per sequence, so windows from the same video can
    still land in different splits. Splitting per video would need the video
    id for each sequence, which ``extract_sequences_from_videos`` does not
    return.

    Args:
        epochs: Number of training epochs (default 10).
        batch_size: Mini-batch size (default 16).

    Returns:
        Tuple ``(model, history, label_map)``.

    Raises:
        ValueError: If no sequences could be extracted.
    """
    # Load data
    actions = load_cricket_data()
    print("Extracting sequences from videos...")
    X, y, label_map = extract_sequences_from_videos(actions)

    print(f"Dataset shape: {X.shape}")
    print(f"Labels shape: {y.shape}")

    # Fail with a clear message instead of an IndexError in the reshape below.
    if X.ndim != 3 or X.shape[0] == 0:
        raise ValueError(
            "No training sequences were extracted; expected an array of shape "
            f"(samples, timesteps, features) but got {X.shape}"
        )

    num_classes = len(actions)

    # Reshape X for CNN input (samples, timesteps, features, 1)
    X = X.reshape(X.shape[0], X.shape[1], X.shape[2], 1)

    # Train-test split on integer labels (same partition as splitting the
    # one-hot labels, since indices depend only on random_state and stratify).
    X_train, X_test, y_train_labels, y_test_labels = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    # Hold out part of the training data for validation / callbacks.
    X_train, X_val, y_train_labels, y_val_labels = train_test_split(
        X_train, y_train_labels, test_size=0.2, random_state=42, stratify=y_train_labels
    )

    # Convert labels to categorical
    y_train = to_categorical(y_train_labels, num_classes=num_classes)
    y_val = to_categorical(y_val_labels, num_classes=num_classes)
    y_test = to_categorical(y_test_labels, num_classes=num_classes)

    print(f"Training set: {X_train.shape}")
    print(f"Validation set: {X_val.shape}")
    print(f"Test set: {X_test.shape}")

    # Handle class imbalance. Weights are computed on the training labels and
    # keyed by the actual class index; classes absent from the training split
    # keep a neutral weight so the dict always covers 0..num_classes-1.
    present_classes = np.unique(y_train_labels)
    present_weights = compute_class_weight(
        'balanced',
        classes=present_classes,
        y=y_train_labels
    )
    class_weight_dict = {class_index: 1.0 for class_index in range(num_classes)}
    class_weight_dict.update(
        {int(cls): float(weight) for cls, weight in zip(present_classes, present_weights)}
    )

    # Create model
    model = create_hybrid_cnn_lstm_model(
        input_shape=X.shape[1:],
        num_classes=num_classes
    )

    # Compile model
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Callbacks
    # NOTE(review): with the default epochs=10 the patience values below (15
    # and 10) are never reached, so early stopping and LR reduction cannot
    # fire; they only come into play for longer runs.
    callbacks = [
        TensorBoard(log_dir='logs/cricket_model'),
        EarlyStopping(patience=15, restore_best_weights=True),
        ReduceLROnPlateau(factor=0.5, patience=10, min_lr=1e-7)
    ]

    # Train model
    history = model.fit(
        X_train, y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(X_val, y_val),
        callbacks=callbacks,
        class_weight=class_weight_dict,
        verbose=1
    )

    # Evaluate model on data that played no part in training or callbacks
    test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
    print(f"Test Accuracy: {test_accuracy:.4f}")

    # Save model
    model.save('cricket_pose_model.h5')
    model.save('cricket_pose_model.keras')

    # Save label mapping (a dict, so numpy stores it pickled; loading it back
    # requires allow_pickle=True)
    np.save('cricket_label_map.npy', label_map)

    return model, history, label_map

# Real-time prediction function
def real_time_cricket_prediction():
    """Classify cricket shots live from the default webcam.

    Loads ``cricket_pose_model.h5`` and ``cricket_label_map.npy``, keeps a
    rolling buffer of the last ``sequence_length`` keypoint vectors and, once
    the buffer is full, predicts on every new frame. The top prediction is
    drawn when its confidence exceeds 0.7, together with a probability bar
    per class. Press ``q`` to quit.

    The camera and any OpenCV windows are released even if an error occurs
    mid-loop.
    """
    model = tf.keras.models.load_model('cricket_pose_model.h5')
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load a label
    # map file that was written by this project's training run.
    label_map = np.load('cricket_label_map.npy', allow_pickle=True).item()
    # Order class names by their stored index so actions[i] matches output i,
    # independent of dict iteration order.
    actions = sorted(label_map, key=label_map.get)

    # Prediction variables
    sequence = []
    threshold = 0.7

    cap = cv.VideoCapture(0)  # Use webcam

    try:
        with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                # Same resize as applied to training frames, so live keypoints
                # come from identically preprocessed images.
                frame = cv.resize(frame, (640, 480))

                # Make detections
                image, results = mediapipe_detection(frame, holistic)
                draw_styled_landmarks(image, results)

                # Extract keypoints and keep only the most recent window
                sequence.append(extract_keypoints(results))
                sequence = sequence[-sequence_length:]

                if len(sequence) == sequence_length:
                    # Reshape for model input: (1, timesteps, features, 1)
                    input_seq = np.array(sequence).reshape(1, sequence_length, -1, 1)

                    # Make prediction
                    res = model.predict(input_seq, verbose=0)[0]
                    predicted_action = actions[int(np.argmax(res))]
                    confidence = np.max(res)

                    # Display prediction
                    if confidence > threshold:
                        cv.putText(image, f'{predicted_action}: {confidence:.2f}',
                                   (10, 50), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

                    # Visualization of probabilities
                    for i, (action, prob) in enumerate(zip(actions, res)):
                        y_pos = 100 + i * 30
                        cv.rectangle(image, (10, y_pos), (int(prob * 300) + 10, y_pos + 25),
                                     (0, 255, 0), -1)
                        cv.putText(image, f'{action}: {prob:.2f}', (15, y_pos + 18),
                                   cv.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)

                cv.imshow('Cricket Pose Estimation', image)

                if cv.waitKey(10) & 0xFF == ord('q'):
                    break
    finally:
        cap.release()
        cv.destroyAllWindows()

if __name__ == "__main__":
    # Train the model (runs load_cricket_data / extract_sequences_from_videos,
    # fits the network and writes the model and label-map files)
    model, history, label_map = train_cricket_model()
    
    # Optionally run real-time prediction (webcam demo; needs the files saved
    # by the training step above)
    # real_time_cricket_prediction()

2025-07-27 20:16:03.701215: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-27 20:16:03.708034: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753627563.716105    9550 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753627563.718472    9550 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753627563.724704    9550 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Detected cricket shots: ['Backfoot punch' 'Cover drive' 'Cut Shot' 'FBD' 'Flick'
 'Front Food defence' 'On Drive' 'Pull Shot' 'Reverse Sweep'
 'Straight Drive' 'Sweep' 'Uppercut' 'loft']
Backfoot punch: 19 videos
Cover drive: 29 videos
Cut Shot: 43 videos
FBD: 15 videos
Flick: 22 videos
Front Food defence: 32 videos
On Drive: 40 videos
Pull Shot: 40 videos
Reverse Sweep: 30 videos
Straight Drive: 25 videos
Sweep: 27 videos
Uppercut: 29 videos
loft: 31 videos
Extracting sequences from videos...


I0000 00:00:1753627564.899793    9550 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1753627564.957533   12295 gl_context.cc:369] GL version: 3.2 (OpenGL ES 3.2 NVIDIA 570.172.08), renderer: NVIDIA GeForce RTX 4070 SUPER/PCIe/SSE2
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
W0000 00:00:1753627564.991091   12274 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1753627565.020352   12294 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1753627565.021698   12271 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1753627565.022071   12290 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature 

KeyboardInterrupt: 