!pip install --upgrade pip

!pip install -r requirements.txt

## Mediapipe Data Preprocessing

### Keypoint Selection

We have to make a selection of which keypoints we choose for the model, since taking all keypoints would mean more than 600 keypoints of which 500 are of the face. So we need to select! In the signgraph paper from https://ieeexplore.ieee.org/abstract/document/10049842 they have this to say:

"This model generated a total of 540+ landmarks, out of
which we have used data for only 65 landmarks. These
65 landmarks consist of pose information for both hands,
arms, body torso and some significant facial nodes like eyes,
nose, ears, and lips. We have discarded all the remaining land-
marks because they were providing no additional information
in our model."

So this is what we aim for.

In [6]:
import os
os.listdir("data")

['datasets', 'noorstorage']

In [1]:
# First import and initialize everything needed

import cv2
import mediapipe as mp
import matplotlib.pyplot as plt
import numpy as np
import os
import tqdm as tqdm
import pandas as pd

mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_holistic = mp.solutions.holistic     #There are many different models you can use, like just face, hand or pose detections!

gloss_path = 'data/noorstorage/CorpusNGT/gloss_split_8frames'

In [2]:
# Found online a list of all the face mesh keypoints and their corresponding numbers so we are
# able to pick the useful ones.

from mediapipe.python.solutions.pose import PoseLandmark
from mediapipe.python.solutions.drawing_utils import DrawingSpec
from mediapipe.python.solutions.face_mesh_connections import FACEMESH_CONTOURS
from mediapipe.python.solutions.face_mesh_connections import FACEMESH_TESSELATION

MESH_ANNOTATIONS = {
  "silhouette": [
    10,  338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288,
    397, 365, 379, 378, 400, 377, 152, 148, 176, 149, 150, 136,
    172, 58,  132, 93,  234, 127, 162, 21,  54,  103, 67,  109
  ],

  "lipsUpperOuter": [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291], #keep this for mouth shape
  "lipsLowerOuter": [146, 91, 181, 84, 17, 314, 405, 321, 375, 291], #keep this for mouth shape
  "lipsUpperInner": [78, 191, 80, 81, 82, 13, 312, 311, 310, 415, 308], #keep this for mouth shape
  "lipsLowerInner": [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308], #keep this for mouth shape

  "rightEyeUpper0": [246, 161, 160, 159, 158, 157, 173], #keep this for eye location and opening
  "rightEyeLower0": [33, 7, 163, 144, 145, 153, 154, 155, 133], #keep this for eye location and opening
  "rightEyeUpper1": [247, 30, 29, 27, 28, 56, 190],
  "rightEyeLower1": [130, 25, 110, 24, 23, 22, 26, 112, 243],
  "rightEyeUpper2": [113, 225, 224, 223, 222, 221, 189],
  "rightEyeLower2": [226, 31, 228, 229, 230, 231, 232, 233, 244],
  "rightEyeLower3": [143, 111, 117, 118, 119, 120, 121, 128, 245],

  "rightEyebrowUpper": [156, 70, 63, 105, 66, 107, 55, 193], #keep this for eyebrow location and raising
  "rightEyebrowLower": [35, 124, 46, 53, 52, 65], #keep this for eyebrow location and raising

  "rightEyeIris": [473, 474, 475, 476, 477], #keep this for eye direction

  "leftEyeUpper0": [466, 388, 387, 386, 385, 384, 398], #keep this for eye location and opening
  "leftEyeLower0": [263, 249, 390, 373, 374, 380, 381, 382, 362], #keep this for eye location and opening
  "leftEyeUpper1": [467, 260, 259, 257, 258, 286, 414],
  "leftEyeLower1": [359, 255, 339, 254, 253, 252, 256, 341, 463],
  "leftEyeUpper2": [342, 445, 444, 443, 442, 441, 413],
  "leftEyeLower2": [446, 261, 448, 449, 450, 451, 452, 453, 464],
  "leftEyeLower3": [372, 340, 346, 347, 348, 349, 350, 357, 465],

  "leftEyebrowUpper": [383, 300, 293, 334, 296, 336, 285, 417], #keep this for eyebrow location and raising
  "leftEyebrowLower": [265, 353, 276, 283, 282, 295], #keep this for eyebrow location and raising

  "leftEyeIris": [468, 469, 470, 471, 472], #keep this for eye direction

  "midwayBetweenEyes": [168],

  "noseTip": [1], #keep this for nose location
  "noseBottom": [2], #keep this for nose location
  "noseRightCorner": [98], #keep this for nose location
  "noseLeftCorner": [327], #keep this for nose location

  "rightCheek": [205],
  "leftCheek": [425]
};

face_mesh_annotations_to_use = {
    61: "lipsUpperOuter1",
    37: "lipsUpperOuter2",
    267: "lipsUpperOuter3",
    291: "lipsUpperOuter4",
    146: "lipsLowerOuter1",
    181: "lipsLowerOuter2",
    17: "lipsLowerOuter3",
    405: "lipsLowerOuter4",
    375: "lipsLowerOuter5",
    78: "lipsUpperInner1",
    81: "lipsUpperInner2",
    311: "lipsUpperInner3",
    308: "lipsUpperInner4",
    88: "lipsLowerInner1",
    87: "lipsLowerInner2",
    317: "lipsLowerInner3",
    318: "lipsLowerInner4",
    246: "rightEyeUpper1",
    160: "rightEyeUpper2",
    158: "rightEyeUpper3",
    173: "rightEyeUpper4",
    163: "rightEyeLower1",
    145: "rightEyeLower2",
    154: "rightEyeLower3",
    70: "rightEyebrowUpper1",
    105: "rightEyebrowUpper2",
    107: "rightEyebrowUpper3",
    46: "rightEyebrowLower1",
    52: "rightEyebrowLower2",
    55: "rightEyebrowLower3",
    473: "rightEyeIris",
    466: "leftEyeUpper1",
    387: "leftEyeUpper2",
    385: "leftEyeUpper3",
    398: "leftEyeUpper4",
    390: "leftEyeLower1",
    374: "leftEyeLower2",
    381: "leftEyeLower3",
    300: "leftEyebrowUpper1",
    334: "leftEyebrowUpper2",
    336: "leftEyebrowUpper3",
    276: "leftEyebrowLower1",
    282: "leftEyebrowLower2",
    285: "leftEyebrowLower3",
    468: "leftEyeIris",
    4: "noseTip"
}

print(face_mesh_annotations_to_use[4])

# list of landmarks to exclude from the drawing
excluded_landmarks = [
    PoseLandmark.LEFT_EYE, 
    PoseLandmark.RIGHT_EYE, 
    PoseLandmark.LEFT_EYE_INNER, 
    PoseLandmark.RIGHT_EYE_INNER, 
    PoseLandmark.LEFT_EAR,
    PoseLandmark.RIGHT_EAR,
    PoseLandmark.LEFT_EYE_OUTER,
    PoseLandmark.RIGHT_EYE_OUTER,
    PoseLandmark.NOSE,
    PoseLandmark.MOUTH_LEFT,
    PoseLandmark.MOUTH_RIGHT,
    PoseLandmark.LEFT_KNEE,
    PoseLandmark.RIGHT_KNEE,
    PoseLandmark.LEFT_ANKLE,
    PoseLandmark.RIGHT_ANKLE,
    PoseLandmark.LEFT_HEEL,
    PoseLandmark.RIGHT_HEEL,
    PoseLandmark.LEFT_FOOT_INDEX,
    PoseLandmark.RIGHT_FOOT_INDEX,
    PoseLandmark.LEFT_PINKY,
    PoseLandmark.RIGHT_PINKY,
    PoseLandmark.LEFT_INDEX,
    PoseLandmark.RIGHT_INDEX,
    PoseLandmark.LEFT_THUMB,
    PoseLandmark.RIGHT_THUMB]

pose_landmarks_to_use = [
    PoseLandmark.LEFT_SHOULDER,
    PoseLandmark.RIGHT_SHOULDER,
    PoseLandmark.LEFT_ELBOW,
    PoseLandmark.RIGHT_ELBOW,
    PoseLandmark.LEFT_WRIST,
    PoseLandmark.RIGHT_WRIST,
    PoseLandmark.LEFT_HIP,
    PoseLandmark.RIGHT_HIP]

# For hands we keep all of the keypoints
face_mesh_keeplist = [61, 37, 267, 291, 146, 181, 17, 405, 375, 78, 81, 311, 308, 88, 87, 317, 318, 
                      246, 160, 158, 173, 163, 145, 154, 70, 105, 107, 46, 52, 55, 473, 
                     466, 387, 385, 398, 390, 374, 381, 300, 334, 336, 276, 282, 285, 468, 4]   # 48
pose_keeplist = [11, 12, 13, 14, 15, 16, 23, 24]
hands_keeplist = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

print(len(face_mesh_keeplist)+len(pose_keeplist)+len(hands_keeplist)+len(hands_keeplist))
print(len(face_mesh_keeplist))
print(len(hands_keeplist))
# We only want to keep these keypoints so we make a list of them so we can check the results against this list later.

noseTip
96
46
21


## Testing on one video

In [3]:
## Test on one video to make sure pose estimation is working properly. The pose estimation will start up in another window, 
## check if the model is estimating properly for every frame.

# Initialize MP pose.
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_holistic = mp.solutions.holistic

holistic = mp_holistic.Holistic(static_image_mode=False, # Makes the model treat the input as a video
                    model_complexity=2,
                    enable_segmentation=True,
                    min_detection_confidence=0.5,
                    refine_face_landmarks=True)

custom_style = mp_drawing_styles.get_default_pose_landmarks_style()
custom_connections = list(mp_holistic.POSE_CONNECTIONS)

for landmark in excluded_landmarks:
    # we change the way the excluded landmarks are drawn
    custom_style[landmark] = DrawingSpec(color=(255,255,255), thickness=0) 
    # we remove all connections which contain these landmarks
    custom_connections = [connection_tuple for connection_tuple in custom_connections 
                            if landmark.value not in connection_tuple]

cap = cv2.VideoCapture("data/noorstorage/CorpusNGT/gloss/CNGT0004_S004.mpg")
#cap = cv2.VideoCapture(0)  ## Video stream from camera

while cap.isOpened():
    ret, image = cap.read()
    # if frame is read correctly ret is True
    if not ret:
        print("Video or stream has ended. Exiting ...")
        break
    else:
        #print("Frame read success")
        # To improve performance, optionally mark the image as not writeable to pass by reference.
        image.flags.writeable = True
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = holistic.process(image)
    
        # Draw landmark annotation on the image.
        mp_drawing.draw_landmarks(image,
                            results.pose_landmarks,
                            connections = custom_connections, #  passing the modified connections list
                            landmark_drawing_spec=custom_style) # and drawing style 
        mp_drawing.draw_landmarks(image,
                            results.face_landmarks,
                            mp_holistic.FACEMESH_CONTOURS,
                            landmark_drawing_spec=None,
                            connection_drawing_spec=mp_drawing_styles.get_default_face_mesh_contours_style())
        mp_drawing.draw_landmarks(image,
                            results.left_hand_landmarks,
                            mp_holistic.HAND_CONNECTIONS,
                            landmark_drawing_spec=mp_drawing.DrawingSpec(color=(0, 0, 255), thickness=1, circle_radius=1),
                            connection_drawing_spec=mp_drawing.DrawingSpec(color=(200, 200, 200), thickness=1, circle_radius=1))
        mp_drawing.draw_landmarks(image,
                            results.right_hand_landmarks,
                            mp_holistic.HAND_CONNECTIONS,
                            landmark_drawing_spec=mp_drawing.DrawingSpec(color=(0, 0, 255), thickness=1, circle_radius=1),
                            connection_drawing_spec=mp_drawing.DrawingSpec(color=(200, 200, 200), thickness=1, circle_radius=1))

        # Naming a window 
        cv2.namedWindow("Resized_Window", cv2.WINDOW_NORMAL) 
        cv2.resizeWindow("Resized_Window", 800, 600) 

        # Display frame
        cv2.imshow("Resized_Window", image)
        if cv2.waitKey(1) == ord('q'):
            break
            
cap.release()
cv2.destroyAllWindows()

I0000 00:00:1734362148.120735    4596 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1734362148.249959    4696 gl_context.cc:357] GL version: 3.2 (OpenGL ES 3.2 NVIDIA 535.183.01), renderer: NVIDIA A10/PCIe/SSE2
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


## If testing is successful, run the code to estimate pose for all data

In [5]:
labels = np.load('data/noorstorage/CorpusNGT/gloss_labels.npy', allow_pickle='TRUE').item()
gloss_count = np.load('data/noorstorage/CorpusNGT/gloss_counts.npy', allow_pickle='TRUE').item()

In [6]:
# Initialize MP pose.
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_holistic = mp.solutions.holistic

holistic = mp_holistic.Holistic(static_image_mode=False, # Makes the model treat the input as a video
                    model_complexity=2,
                    enable_segmentation=True,
                    min_detection_confidence=0.5,
                    refine_face_landmarks=True)

I0000 00:00:1734362181.439532    4596 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1734362181.496799    5278 gl_context.cc:357] GL version: 3.2 (OpenGL ES 3.2 NVIDIA 535.183.01), renderer: NVIDIA A10/PCIe/SSE2


In [7]:
print(len(os.listdir(gloss_path)))

0


In [9]:
print(gloss_path)

data/noorstorage/CorpusNGT/gloss_split_8frames


In [None]:
## DEBUGGING

## Remove previous keypoints
for catg in tqdm.tqdm(os.listdir(gloss_path)):
    f1 = '{}/{}'.format(gloss_path, catg)
    for number in os.listdir(f1):
        f2 = '{}/{}'.format(f1, number)
        if os.path.isdir(f2):
            if "frame8.jpg" in os.listdir(f2):
                os.remove('{}/{}'.format(f2, "frame8.jpg"))  ## or "keypoints"

In [None]:
# Perform pose estimation on all frames and save to dataframe

# This model generated a total of 540+ landmarks, out of which we have used data for only 65 landmarks. These 65 landmarks consist of 
# pose information for both hands, arms, body torso and some significant facial nodes like eyes, nose, ears, and lips.

# Only run this once! if you break off the code here halfway and run it again it will not work.

def CreateKeypoint(Model, frame_idx, idx, landmark_name, x, y):
    keypoint_entry = {
        'frame': frame_idx,
        'landmark model': Model,
        'landmark index': idx,
        'landmark name': landmark_name,
        'x': x,
        'y': y
    }
    return keypoint_entry

def ExtractKeypoints(results, mp_pose_type, whichlandmarkstouse, frame_idx, Model="Face"):
    # If any pose is detected
    frame_keypoints = []
    if results: #if the pose estimation was successful
        for idx, landmark in enumerate(results.landmark):
            if idx in whichlandmarkstouse:
                frame_keypoints.append(CreateKeypoint(Model, frame_idx, idx, mp_pose_type[idx] if Model=="Face" else mp_pose_type(idx).name, landmark.x, landmark.y))
    else: #if the model did not estimate properly we fill in zeroes.
        if Model=="Face":
            for (_, idx) in enumerate(mp_pose_type):
                frame_keypoints.append(CreateKeypoint(Model, frame_idx, idx, mp_pose_type[idx], 0, 0))
        elif Model=="LeftHand" or Model=="RightHand" or Model=="Pose":
            for idx in enumerate(mp_pose_type):
                frame_keypoints.append(CreateKeypoint(Model, frame_idx, idx[0], mp_pose_type(idx[0]).name, 0, 0))
    return frame_keypoints

for catg in tqdm.tqdm(os.listdir(gloss_path)):
    f1 = '{}/{}'.format(gloss_path, catg)
    for number in os.listdir(f1):
        f2 = '{}/{}'.format(f1, number)
        if os.path.isdir(f2):
            allFramesLandmarks = pd.DataFrame()
            frame_idx = 0
            if "keypoints" in os.listdir(f2):
                continue
            else:
                for frame in os.listdir(f2):
#                try:
                    f3 = '{}/{}'.format(f2, frame)
                    split = f3.split(".")
                    if split[1] == "jpg":
                        # Read image
                        img = cv2.imread(f3)
                        
                        # Process.
                        results = holistic.process(img)
    
                        landmark_list_pose = pd.DataFrame(ExtractKeypoints(
                            results.pose_landmarks, mp_holistic.PoseLandmark, pose_keeplist, frame_idx, Model="Pose"))
                        landmark_list_left = pd.DataFrame(ExtractKeypoints(
                            results.left_hand_landmarks, mp_holistic.HandLandmark, hands_keeplist, frame_idx, Model="LeftHand"))
                        landmark_list_right = pd.DataFrame(ExtractKeypoints(
                            results.right_hand_landmarks, mp_holistic.HandLandmark, hands_keeplist, frame_idx, Model="RightHand"))
                        landmark_list_face = pd.DataFrame(ExtractKeypoints(
                            results.face_landmarks, face_mesh_annotations_to_use, face_mesh_keeplist, frame_idx, Model="Face"))
    
                        allFramesLandmarks = pd.concat(
                            [allFramesLandmarks, landmark_list_pose, landmark_list_left, landmark_list_right, landmark_list_face], ignore_index=True)
                        allFramesLandmarks.sort_values(by=['frame', 'landmark model', 'landmark index'], ascending=True, inplace=True)
                        allFramesLandmarks.reset_index(drop=True, inplace=True)
    
#                except Exception as err:
#                    print(Exception, err)
#                    print(f2)
                    
                    frame_idx += 1
            
            allFramesLandmarks.to_pickle('{}/{}'.format(f2, "keypoints"), compression='infer', protocol=5, storage_options=None)
                    

In [None]:
#for idx in enumerate(mp_holistic.HandLandmark):
#    print(idx[1])

print(mp_holistic.HandLandmark(5).name)

In [None]:
with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
    print(frame_keypoints)

Congrats, you successfully ran the mediapipe pose estimation

-------

## Get Mediapipe Statistics

In [23]:
## Write some code to find out what percentage of the mediapipe data contains 'all' keypoints

import os
import shutil
import pandas as pd
import numpy as np
from numpy import asarray
import cv2 as cv
import re
import math
from subprocess import check_call, PIPE, Popen
import shlex

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
from torch import nn
import torcheval
from torcheval.metrics.functional import multiclass_auprc

from torch.utils.data import Dataset, DataLoader, Subset
import glob
from PIL import Image
import random
from torchvision.transforms.functional import to_pil_image
import matplotlib.pylab as plt

np.random.seed(2024)
random.seed(2024)
torch.manual_seed(2024)

train = pd.read_pickle("data/noorstorage/CorpusNGT/train_cutoff0_upperlimit999999.pkl")
test  = pd.read_pickle("data/noorstorage/CorpusNGT/test_cutoff0_upperlimit999999.pkl")
val   = pd.read_pickle("data/noorstorage/CorpusNGT/validate_cutoff0_upperlimit999999.pkl")

In [24]:
tr = (train["id"]).unique() ## So there are 157 classes in this dataset of which each have about
va = (val["id"]).unique()
te = (test["id"]).unique()
total = np.concatenate((tr, va, te), axis=0, out=None)
total = np.unique(total)
total = np.sort(total)

catgs = total
num_classes = len(catgs)
labels_dict = {}
ind = 0
for uc in catgs:
    labels_dict[uc] = ind
    ind+=1

def preprocess_datasubset(split):
    split_labels = split["id"].tolist()
    split_ids = split["path"].tolist()
    print(len(split_ids), len(split_labels))
    return split_ids, split_labels

class KeyPointDataset(Dataset):
    def __init__(self, ids, labels, transform):
        self.transform = transform
        self.ids = ids
        self.labels = labels
    def __len__(self):
        return len(self.ids)
    def __getitem__(self, idx):
        keypointfile = pd.read_pickle(glob.glob(self.ids[idx]+"/keypoints")[0], compression='infer', storage_options=None)
        keypointfile = keypointfile[["frame","x","y"]]
        label = labels_dict[self.labels[idx]]   ### WAS BUSY HERE TRYING TO FIGURE OUT KEYERROR
        try:
            frames = []
            framelist = list(keypointfile.groupby(by="frame"))
            for frame in framelist:
                frames.append(frame[1].to_numpy())
    
            seed = np.random.randint(1e9)
            frames_tr = []
            for frame in frames:
                random.seed(seed)
                np.random.seed(seed)
                frames_tr.append(torch.from_numpy(frame))
            if len(frames_tr)>0:
                frames_tr = torch.stack(frames_tr)
            return frames_tr, label
        except Exception as e:
            print(f"Error in __getitem__ at idx {idx}: {e}")
            return 0, 0


def denormalize(x_, mean, std):
    x = x_.clone()
    for i in range(3):
        x[i] = x[i]*std[i]+mean[i]
    x = to_pil_image(x)
    return x

mean = [0.43216, 0.394666, 0.37645]
std = [0.22803, 0.22145, 0.216989]

In [26]:
train_ids, train_labels = preprocess_datasubset(train)
test_ids, test_labels = preprocess_datasubset(test)
val_ids, val_labels = preprocess_datasubset(val)

80424 80424
10054 10054
10053 10053


In [28]:
train_transformer_keypoints = transforms.Compose([
            transforms.ToTensor()
            ])

train_ds = KeyPointDataset(ids= train_ids, labels= train_labels, transform= train_transformer_keypoints)

test_ds = KeyPointDataset(ids= test_ids, labels= test_labels, transform= train_transformer_keypoints)

val_ds = KeyPointDataset(ids= val_ids, labels= val_labels, transform= train_transformer_keypoints)

In [30]:
def calculate_keypoint_errors(dataset):
    """
    Calculate the error percentage for each keypoint based on the number of missed values for that keypoint.
    
    Args:
        dataset: A dataset where each sample is a tuple (img, label). 
                 `img` is expected to be a collection of frames, and each frame is a tensor of shape (96, 3).

    Returns:
        dict: A dictionary with the following keys:
            - "total_keypoints": Total number of keypoints processed (x and y coordinates).
            - "missed_keypoints": Total number of keypoints with missing values.
            - "fraction_missed": Fraction of missed keypoints.
            - "missing_percentages_per_keypoint": List of missing percentages for each keypoint (1 to 96), calculated independently.
    """
    total = 0
    missed_keypoints = 0

    # Initialize counters for each keypoint
    keypoint_miss_count = np.zeros(96, dtype=int)  # Count of misses for each keypoint
    keypoint_total_count = np.zeros(96, dtype=int)  # Total count for each keypoint

    for idx, (img, label) in enumerate(dataset):
        try:
            for tensor in img:  # Each tensor corresponds to a frame
                # Ensure tensor is a NumPy array
                if not isinstance(tensor, np.ndarray):
                    tensor = tensor.numpy()  # Convert from PyTorch tensor to NumPy array

                # Update counts for each keypoint
                for i in range(96):
                    keypoint_total_count[i] += 1  # Every frame contributes a value for each keypoint
                    if tensor[i, 1] == 0 or tensor[i, 2] == 0:  # Check if x or y is missing
                        keypoint_miss_count[i] += 1

                # Update global counts
                total += tensor.shape[0] * 2  # Each keypoint has x and y
                missed_keypoints += (tensor[:, 1:] == 0).sum()

        except Exception as e:
            print(f"Error processing dataset at index {idx}: {e}")

    # Calculate missing percentage for each keypoint based on the total for each keypoint
    keypoint_miss_percentages = (keypoint_miss_count / keypoint_total_count) * 100

    # Calculate the global fraction of missed keypoints
    fraction_missed = missed_keypoints / total

    return {
        "total_keypoints": total,
        "missed_keypoints": missed_keypoints,
        "fraction_missed": fraction_missed,
        "missing_percentages_per_keypoint": keypoint_miss_percentages.tolist()
    }


def compute_weighted_average(stats_train, stats_test, stats_val, sizes):
    """
    Compute the weighted average of missing percentages and other statistics from three datasets.

    Args:
        stats_train: Statistics dictionary from the train dataset.
        stats_test: Statistics dictionary from the test dataset.
        stats_val: Statistics dictionary from the validation dataset.
        sizes: Tuple with the sizes of the datasets (train_size, test_size, val_size).

    Returns:
        dict: A dictionary with the weighted averages for the following:
            - "total_keypoints": Weighted average of total keypoints.
            - "missed_keypoints": Weighted average of missed keypoints.
            - "fraction_missed": Weighted average of fraction missed.
            - "missing_percentages_per_keypoint": Weighted average of missing percentages per keypoint.
    """
    train_size, test_size, val_size = sizes
    total_size = train_size + test_size + val_size
    
    # Calculate the weighted averages for the overall statistics
    weighted_avg_total_keypoints = (
        stats_train['total_keypoints'] * train_size +
        stats_test['total_keypoints'] * test_size +
        stats_val['total_keypoints'] * val_size
    ) / total_size

    weighted_avg_missed_keypoints = (
        stats_train['missed_keypoints'] * train_size +
        stats_test['missed_keypoints'] * test_size +
        stats_val['missed_keypoints'] * val_size
    ) / total_size

    weighted_avg_fraction_missed = (
        stats_train['fraction_missed'] * train_size +
        stats_test['fraction_missed'] * test_size +
        stats_val['fraction_missed'] * val_size
    ) / total_size

    # Calculate the weighted average of missing percentages for each keypoint
    weighted_avg_missing_percentages = (
        (np.array(stats_train['missing_percentages_per_keypoint']) * train_size) +
        (np.array(stats_test['missing_percentages_per_keypoint']) * test_size) +
        (np.array(stats_val['missing_percentages_per_keypoint']) * val_size)
    ) / total_size

    return {
        "total_keypoints": weighted_avg_total_keypoints,
        "missed_keypoints": weighted_avg_missed_keypoints,
        "fraction_missed": weighted_avg_fraction_missed,
        "missing_percentages_per_keypoint": weighted_avg_missing_percentages.tolist()
    }

Error in __getitem__ at idx 0: list index out of range
Processing sample 1
Error processing sample 1: 'int' object is not iterable
Error in __getitem__ at idx 1: list index out of range
Processing sample 2
Error processing sample 2: 'int' object is not iterable
Error in __getitem__ at idx 2: list index out of range
Processing sample 3
Error processing sample 3: 'int' object is not iterable
Error in __getitem__ at idx 3: list index out of range
Processing sample 4
Error processing sample 4: 'int' object is not iterable
Error in __getitem__ at idx 4: list index out of range
Processing sample 5
Error processing sample 5: 'int' object is not iterable
Error in __getitem__ at idx 5: list index out of range
Processing sample 6
Error processing sample 6: 'int' object is not iterable
Error in __getitem__ at idx 6: list index out of range
Processing sample 7
Error processing sample 7: 'int' object is not iterable
Error in __getitem__ at idx 7: list index out of range
Processing sample 8
Error pr

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



KeyboardInterrupt: 

In [None]:
# Example usage
train_size = (len(train_ds))
test_size = (len(test_ds))
val_size = (len(val_ds))

# Compute statistics for each dataset
train_stats = calculate_keypoint_errors(train_ds)
test_stats = calculate_keypoint_errors(test_ds)
val_stats = calculate_keypoint_errors(val_ds)

# Compute the weighted average of the statistics
weighted_avg_stats = compute_weighted_average(train_stats, test_stats, val_stats, (train_size, test_size, val_size))