In [None]:
# Install MediaPipe
!pip install mediapipe

from google.colab import drive
import os
import cv2
import mediapipe as mp
import csv
import json
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from torchvision.models.segmentation import deeplabv3_resnet101

# Mount Google Drive at /content/drive. force_remount=True is passed so the
# mount call is repeated even when this cell is re-run in the same session.
drive.mount('/content/drive', force_remount=True)

# Fail fast if the expected "MyDrive" folder is not visible after mounting;
# every input and output path used below lives underneath it.
if not os.path.exists("/content/drive/MyDrive"):
    raise Exception("Google Drive not mounted. Please mount it using `drive.mount('/content/drive')`.")

# Define video paths
# Input videos to process, in order. Lines that are commented out are videos
# that are currently excluded from the run.
# NOTE(review): the novice set appears to follow a TeN_1 / TeN_2 naming scheme
# (two takes per subject); each file should be listed exactly once — confirm
# that Te1_2.MOV exists on the Drive.
video_paths = [
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr1_NIE.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/JoeSupportSide.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/JoeStrongSide.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr2_NIE.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te1_1.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te1_2.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te2_1.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te2_2.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te3_1.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te3_2.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te4_1.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te4_2.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te5_1.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te5_2.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te6_1.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te6_2.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te7_1.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te7_2.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te8_1.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te8_2.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te9_1.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te9_2.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te10_1.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te10_2.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te11_1.MOV",
    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Novice/Videos/Te11_2.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr3_Anticipating.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr4_TooLittleTriggerFinger.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr5_TooMuchTriggerFinger.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr6_OvergrippingWithPrimaryHand.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr7_OvergrippingWithSecondaryHand.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr8_BreakingWristUp.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr9_CheckingTargetOverSightsBetweenShots.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr10_JerkingTrigger.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr11_LimpWristing.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr13_SightsIncorrectlyAlligned.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr12_GrippingTooLow.MOV",
#    "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Extra Data/Experienced/Video/Tr14_IncorrectHandPlacement.MOV",

    # Add other video paths here
]

# Report, for every configured video, whether the file is present on disk.
# This is informational only: missing files are reported, not removed from
# the list.
for video_path in video_paths:
    if os.path.exists(video_path):
        print(f"Video file found: {video_path}")
    else:
        print(f"Warning: Video file not found - {video_path}")

# Define base output directory
# Root folder that receives one sub-folder per processed video; it is created
# up front and an already existing folder is not an error (exist_ok=True).
base_output_dir = "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Processed_Frames"
os.makedirs(base_output_dir, exist_ok=True)

# Initialize MediaPipe solutions
# Short aliases for the MediaPipe solution modules used below: pose, hands and
# face-mesh detectors plus the landmark drawing helpers.
mp_pose = mp.solutions.pose
mp_hands = mp.solutions.hands
mp_face = mp.solutions.face_mesh
mp_drawing = mp.solutions.drawing_utils

# Load DeepLabV3 model for background removal
# Process-wide cache: building the network and loading its weights is
# expensive, and the per-video pipeline asks for the model once per video.
_DEEPLABV3_MODEL = None


def load_deeplabv3():
    """
    Return the pretrained DeepLabV3 (ResNet-101 backbone) segmentation model.

    The model is created on the first call, switched to evaluation mode and
    cached; later calls return the same instance instead of rebuilding it.

    Returns:
        The DeepLabV3 model in evaluation mode.
    """
    global _DEEPLABV3_MODEL
    if _DEEPLABV3_MODEL is None:
        model = deeplabv3_resnet101(pretrained=True, progress=True)
        model.eval()  # Set the model to evaluation mode
        _DEEPLABV3_MODEL = model
    return _DEEPLABV3_MODEL

# Preprocess the frame for DeepLabV3
def preprocess_frame(frame):
    """
    Turn an image into a normalised, batched tensor for DeepLabV3.

    The per-channel mean/std values are applied in the order in which the
    channels appear in `frame`.

    Args:
        frame: Image accepted by torchvision's ToTensor transform.

    Returns:
        The normalised tensor with an extra leading batch dimension.
    """
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Same two steps a Compose would run, applied one after the other.
    normalized = normalize(to_tensor(frame))

    # Add batch dimension
    return normalized.unsqueeze(0)

# Remove background using DeepLabV3
def remove_background(frame, model):
    """
    Black out everything in `frame` that is not segmented as a person.

    Args:
        frame: BGR image as returned by OpenCV (e.g. cv2.VideoCapture.read()).
        model: Segmentation model whose output dict has an 'out' entry holding
            per-class scores (the model returned by load_deeplabv3()).

    Returns:
        tuple: (masked_frame, mask) where masked_frame is `frame` with all
        non-person pixels set to zero (still BGR) and mask is a uint8 array
        that is 255 on person pixels and 0 elsewhere.
    """
    # The normalisation statistics used in preprocessing are given in RGB
    # channel order, while OpenCV frames are BGR, so convert before inference.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    input_tensor = preprocess_frame(rgb_frame)

    # Perform inference without tracking gradients
    with torch.no_grad():
        output = model(input_tensor)['out'][0]

    # Get the segmentation mask (class 15 is the "person" class in DeepLabV3)
    person_pixels = output.argmax(0) == 15

    # Convert the boolean mask to the 0/255 uint8 form OpenCV expects
    mask = person_pixels.cpu().numpy().astype(np.uint8) * 255

    # Apply the mask to the original (BGR) frame so callers keep OpenCV's
    # channel order
    masked_frame = cv2.bitwise_and(frame, frame, mask=mask)

    return masked_frame, mask

# Function to save keypoints to a CSV file
def _iter_landmark_groups(keypoint_type, landmarks):
    """
    Yield (row_label, flat_landmark_list) pairs for one keypoint type.

    A flat list of [x, y, z] landmarks is yielded once under its own type
    name. A list of such lists (one per detected instance, e.g. one per hand)
    is yielded once per instance, labelled "<type>_<instance index>".
    """
    if not landmarks:
        return
    # A nested entry's first element is itself a sequence ([x, y, z]); a flat
    # entry's first element is a number.
    is_grouped = any(
        bool(entry) and isinstance(entry[0], (list, tuple)) for entry in landmarks
    )
    if is_grouped:
        for instance_id, group in enumerate(landmarks):
            yield f"{keypoint_type}_{instance_id}", group
    else:
        yield keypoint_type, landmarks


def save_keypoints_to_csv(keypoints, output_path):
    """
    Write keypoints to a CSV file with one row per landmark.

    Columns are: frame, type, landmark_id, x, y, z. Types stored as a flat
    landmark list (e.g. "pose") keep their name in the `type` column. Types
    stored as one landmark list per detected instance (e.g. "hands", "face")
    are written as "<type>_<instance index>" ("hands_0", "hands_1", ...) so
    that every row still holds a single landmark's coordinates.

    Args:
        keypoints (dict): Maps frame id -> {type name -> landmarks}, where
            landmarks is either a list of [x, y, z] or a list of such lists.
        output_path (str): Destination CSV path; an existing file is
            overwritten.
    """
    with open(output_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["frame", "type", "landmark_id", "x", "y", "z"])
        for frame, data in keypoints.items():
            for keypoint_type, landmarks in data.items():
                for label, group in _iter_landmark_groups(keypoint_type, landmarks):
                    for landmark_id, coords in enumerate(group):
                        writer.writerow(
                            [frame, label, landmark_id, coords[0], coords[1], coords[2]]
                        )

# Function to save keypoints to a JSON file
def save_keypoints_to_json(keypoints, output_path):
    """
    Serialise `keypoints` as JSON, indented by four spaces, to `output_path`.

    Args:
        keypoints: JSON-serialisable object (the per-frame keypoint dict).
        output_path (str): Destination file; an existing file is overwritten.
    """
    encoder = json.JSONEncoder(indent=4)
    with open(output_path, 'w') as handle:
        # Stream the encoded chunks straight into the file.
        for chunk in encoder.iterencode(keypoints):
            handle.write(chunk)

# Function to visualize keypoints and save as an image
def visualize_keypoints(keypoints, frame_shape, output_path=None, show=True):
    """
    Draw keypoints on a blank canvas, optionally saving and/or displaying it.

    Args:
        keypoints (dict): Keypoints by type. "pose" maps to a list of
            [x, y, z] landmarks; "hands" and "face" map to a list of such
            lists (one per detected instance). x and y are multiplied by the
            frame width and height to obtain pixel positions. Other keys and
            empty entries are ignored.
        frame_shape (tuple): Shape of the original frame (height, width, channels).
        output_path (str, optional): Path to save the keypoints visualization. If None, the image is not saved.
        show (bool, optional): Whether to display the canvas with matplotlib.
            Defaults to True.
    """
    # Create a blank canvas
    canvas = np.zeros(frame_shape, dtype=np.uint8)
    colors = {"pose": (0, 255, 0), "hands": (255, 0, 0), "face": (0, 0, 255)}
    height, width = frame_shape[0], frame_shape[1]

    # Draw keypoints on the canvas
    for keypoint_type, landmarks in keypoints.items():
        if not landmarks:
            continue
        if keypoint_type == "pose":
            # Pose is a single flat landmark list; wrap it so all types can
            # share the drawing loop below.
            landmark_groups = [landmarks]
        elif keypoint_type in ("hands", "face"):
            landmark_groups = landmarks
        else:
            continue
        for landmark_group in landmark_groups:
            for landmark in landmark_group:
                x = int(landmark[0] * width)
                y = int(landmark[1] * height)
                cv2.circle(canvas, (x, y), radius=5, color=colors[keypoint_type], thickness=-1)

    # Save the keypoints visualization if an output path is provided
    if output_path:
        # cv2.imwrite reports failure through its return value, not an exception.
        if cv2.imwrite(output_path, canvas):
            print(f"Saved keypoints visualization to: {output_path}")
        else:
            print(f"Warning: Could not save keypoints visualization to: {output_path}")

    # Display the visualization (optional)
    if show:
        fig = plt.figure(figsize=(10, 10))
        plt.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
        plt.title("Keypoints Visualization")
        plt.axis("off")
        plt.show()
        # Release the figure so that processing many frames does not
        # accumulate open figures.
        plt.close(fig)

# Function to process a single frame and extract keypoints
def process_frame_and_extract_keypoints(frame, pose, hands, face, frame_count, output_dir, visualize=False):
    """
    Run pose, hand and face-mesh detection on one frame and collect keypoints.

    The detected landmarks are drawn onto a copy of the frame, and a
    keypoints-only image is written to
    `<output_dir>/keypoints_frame_<frame_count>.jpg`.

    Args:
        frame: BGR image (it is converted to RGB before detection).
        pose, hands, face: Detector objects exposing `process(rgb_image)`
            (the MediaPipe Pose, Hands and FaceMesh instances).
        frame_count (int): Index of the frame; used in the output file name
            and as the key of the returned keypoint dict.
        output_dir (str): Directory that receives the keypoints image.
        visualize (bool): If True, the keypoints image is also displayed.

    Returns:
        tuple: (annotated_frame, {frame_count: keypoints}) where keypoints
        holds "pose" (list of [x, y, z]), "hands" and "face" (list of
        per-instance lists of [x, y, z]); a key is present only when that
        detector found something.
    """
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pose_results = pose.process(frame_rgb)
    hand_results = hands.process(frame_rgb)
    face_results = face.process(frame_rgb)
    annotated_frame = frame.copy()
    keypoints = {}

    if pose_results.pose_landmarks:
        mp_drawing.draw_landmarks(annotated_frame, pose_results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
        keypoints["pose"] = [[lm.x, lm.y, lm.z] for lm in pose_results.pose_landmarks.landmark]

    if hand_results.multi_hand_landmarks:
        keypoints["hands"] = []
        for hand_landmarks in hand_results.multi_hand_landmarks:
            mp_drawing.draw_landmarks(annotated_frame, hand_landmarks, mp_hands.HAND_CONNECTIONS)
            keypoints["hands"].append([[lm.x, lm.y, lm.z] for lm in hand_landmarks.landmark])

    if face_results.multi_face_landmarks:
        keypoints["face"] = []
        for face_landmarks in face_results.multi_face_landmarks:
            mp_drawing.draw_landmarks(annotated_frame, face_landmarks, mp_face.FACEMESH_CONTOURS)
            keypoints["face"].append([[lm.x, lm.y, lm.z] for lm in face_landmarks.landmark])

    # Save the keypoints visualization as an image; display it only when the
    # caller asked for it, and only once.
    keypoints_output_path = os.path.join(output_dir, f"keypoints_frame_{frame_count:04d}.jpg")
    visualize_keypoints(keypoints, frame.shape, output_path=keypoints_output_path, show=visualize)

    return annotated_frame, {frame_count: keypoints}

# Function to extract frames from a video
def extract_and_process_frames(video_path, output_dir, extract_interval=10, visualize=False):
    """
    Process every `extract_interval`-th frame of a video and save the results.

    For each selected frame the background is removed, pose/hand/face
    keypoints are extracted, and the annotated frame is written to
    `<output_dir>/<video name>/frame_<index>.jpg`. After the last frame all
    collected keypoints are written to `keypoints.csv` and `keypoints.json`
    in the same folder.

    Args:
        video_path (str): Path of the video file to read.
        output_dir (str): Base directory; a sub-folder named after the video
            (file name without extension) is created inside it.
        extract_interval (int): Process frames whose index is a multiple of
            this value.
        visualize (bool): Forwarded to the per-frame processing step to
            control whether keypoint images are displayed.

    Returns:
        None. If the video cannot be opened an error is printed and nothing
        else is written.
    """
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    video_output_dir = os.path.join(output_dir, video_name)
    os.makedirs(video_output_dir, exist_ok=True)

    all_keypoints = {}
    extracted_count = 0

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video {video_path}.")
            return

        # Load DeepLabV3 model
        deeplabv3_model = load_deeplabv3()

        with mp_pose.Pose(min_detection_confidence=0.28, min_tracking_confidence=0.28) as pose, \
             mp_hands.Hands(min_detection_confidence=0.346109, min_tracking_confidence=0.34609) as hands, \
             mp_face.FaceMesh(min_detection_confidence=0.3, min_tracking_confidence=0.3) as face:
            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % extract_interval == 0:
                    # Remove background
                    masked_frame, _ = remove_background(frame, deeplabv3_model)

                    # Process the frame with MediaPipe and extract keypoints
                    annotated_frame, keypoints = process_frame_and_extract_keypoints(
                        masked_frame, pose, hands, face, frame_count, video_output_dir, visualize
                    )

                    # Save the processed frame; cv2.imwrite signals failure
                    # through its return value rather than by raising.
                    output_path = os.path.join(video_output_dir, f"frame_{frame_count:04d}.jpg")
                    if cv2.imwrite(output_path, annotated_frame):
                        print(f"Processed and saved: {output_path}")
                    else:
                        print(f"Warning: Could not write frame image: {output_path}")

                    # Save the keypoints
                    all_keypoints.update(keypoints)
                    extracted_count += 1
                frame_count += 1
    finally:
        # Always give the capture handle back, even if a frame fails midway.
        cap.release()

    csv_output_path = os.path.join(video_output_dir, "keypoints.csv")
    save_keypoints_to_csv(all_keypoints, csv_output_path)
    print(f"Saved keypoints to CSV: {csv_output_path}")
    json_output_path = os.path.join(video_output_dir, "keypoints.json")
    save_keypoints_to_json(all_keypoints, json_output_path)
    print(f"Saved keypoints to JSON: {json_output_path}")
    print(f"Finished processing {extracted_count} frames from {video_path}")

# Process each video
# Runs the full pipeline on every configured path, one video after another,
# writing results under base_output_dir and requesting on-screen keypoint
# visualisation (visualize=True). The default extract_interval is used.
for video_path in video_paths:
    extract_and_process_frames(video_path, base_output_dir, visualize=True)