In [1]:
import os
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import mediapipe as mp
from mediapipe.framework.formats import landmark_pb2
import matplotlib.gridspec as gridspec

In [2]:
# Input folder: parent folder that is walked recursively (os.walk) by the batch
# loop below to find the per-recording '000_raw_pose_data' folders.
# NOTE(review): absolute, machine-specific Windows path - update before running elsewhere.
raw_pose_data_in_path = r"C:\Users\mmccu\Box\MM_Personal\5_Projects\BoveLab\3_Data_and_Code\gait_bw_zeno_outputs_006\subset_for_memory_c"

In [3]:
# Task to plot. Values handled by get_zeno_image_size: 'FW_1', 'PWS_1' (Zeno videos)
# and 'right', 'left' (home videos). The value is also used as a substring filter on
# the data file names and as a prefix for the output folder and video names.
task = 'PWS_1'

In [4]:
# Do we want to re-do the plots if the output folder already exists? ('yes' or 'no')
# 'no'  -> a recording whose '<task>_all_plots' folder already exists is skipped
# 'yes' -> the recording is plotted again into the existing folder
redo = 'no'

In [5]:
# Set up MediaPipe aliases.
# mp_pose.POSE_CONNECTIONS supplies the landmark index pairs used to draw the skeleton
# in the pose and world plotting functions.
# mp_drawing is not referenced anywhere else in the visible part of this notebook.
mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose

In [6]:
# Keypoint index pairs that are joined with a line when drawing the YOLO skeleton.
# Each pair is (start index, end index) into the per-frame list of YOLO keypoints.
# NOTE(review): the body-part names below are carried over from the original
# notebook comments - confirm against the keypoint order of the YOLO model used.
yolo_connections = set([
    # head
    (1, 2),    # right to left eye
    # shoulders and arms
    (5, 6),    # right shoulder to left shoulder
    (5, 7),    # left shoulder to left elbow
    (7, 9),    # left elbow to left wrist
    (6, 8),    # right shoulder to right elbow
    (8, 10),   # right elbow to right wrist
    # trunk
    (5, 11),   # left shoulder to left hip
    (6, 12),   # right shoulder to right hip
    (11, 12),  # left hip to right hip
    # legs
    (11, 13),  # left hip to left knee
    (12, 14),  # right hip to right knee
    (13, 15),  # left knee to left ankle
    (14, 16),  # right knee to right ankle
])

In [7]:
# temp
# Participant IDs to process. The batch loop only handles folders whose parsed
# 'BW-xxxx' ID appears in this list - edit the list to run on other IDs.
# (current selection: PWS high errors)
bw_id_include = [
    'BW-0036',
    'BW-0244',
    'BW-0243',
    'BW-0041',
    'BW-0052',
    'BW-0049',
    'BW-0003',
    'BW-0092',
    'BW-0150',
    'BW-0219',
    'BW-0230',
    'BW-0233',
    'BW-0232',
    'BW-0131',
    'BW-0102',
]

### Plotting functions 

In [8]:
def plot_mp_pose_per_frame(ax, frame, grouped_df, x_min, x_max, y_min, y_max): 
    """Draw the MediaPipe pose landmarks of a single frame as a 2D skeleton.

    Parameters
    ----------
    ax : matplotlib axes the frame is drawn on (titled 'Pose').
    frame : key passed to ``grouped_df.get_group`` to select this frame's rows.
    grouped_df : groupby object whose groups have columns "X", "Y", "Z" and "vis";
        each row of a group is treated as one landmark, in row order.
    x_min, x_max, y_min, y_max : axis limits, so every frame shares one scale.

    Returns
    -------
    None. The axes are modified in place.
    """
    ax.set_title('Pose')

    # rows of position data belonging to this frame (one row per landmark)
    frame_rows = grouped_df.get_group(frame)

    # wrap every row in the mediapipe landmark container
    landmarks = [
        landmark_pb2.NormalizedLandmark(
            x=row["X"], y=row["Y"], z=row["Z"], visibility=row["vis"]
        )
        for _, row in frame_rows.iterrows()
    ]

    # landmark positions as green markers
    ax.scatter([lm.x for lm in landmarks],
               [lm.y for lm in landmarks],
               c='green', s=20)

    # the skeleton is only drawn when the frame holds all 33 landmarks,
    # because the connection pairs index into the landmark list
    if len(landmarks) == 33:
        for start_idx, end_idx in mp_pose.POSE_CONNECTIONS:
            first, second = landmarks[start_idx], landmarks[end_idx]
            ax.plot([first.x, second.x], [first.y, second.y], c='black')

    # same limits on every plot; y limits are passed reversed because pose y is inverse
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_max, y_min)
    ax.set_aspect('equal')  # x and y on same scale

    ax.set_xlabel('X Axis')
    ax.set_ylabel('Y Axis')

    # dashed grey box around the 0-1 by 0-1 region
    # (MP "guessing" if outside these bounds)
    box_style = dict(color='grey', linestyle='-.', alpha=0.5)
    for y_edge in (1, 0):
        ax.hlines(y=y_edge, xmin=0, xmax=1, **box_style)
    for x_edge in (1, 0):
        ax.vlines(x=x_edge, ymin=1, ymax=0, **box_style)

In [9]:
def plot_mp_world_per_frame(ax, frame, grouped_df, x_min, x_max, y_min, y_max): 
    """Draw the MediaPipe world landmarks of a single frame as a 2D skeleton.

    Parameters
    ----------
    ax : matplotlib axes the frame is drawn on (titled 'World (Meters)').
    frame : key passed to ``grouped_df.get_group`` to select this frame's rows.
    grouped_df : groupby object whose groups have columns "X", "Y", "Z" and "vis";
        each row of a group is treated as one landmark, in row order.
    x_min, x_max, y_min, y_max : axis limits, so every frame shares one scale.

    Returns
    -------
    None. The axes are modified in place.
    """
    ax.set_title('World (Meters)')

    # rows of position data belonging to this frame (one row per landmark)
    frame_rows = grouped_df.get_group(frame)

    # collect one landmark container per row
    landmarks = []
    for _, row in frame_rows.iterrows():
        landmarks.append(
            landmark_pb2.NormalizedLandmark(
                x=row["X"], y=row["Y"], z=row["Z"], visibility=row["vis"]
            )
        )

    # landmark positions as blue markers
    xs = [point.x for point in landmarks]
    ys = [point.y for point in landmarks]
    ax.scatter(xs, ys, c='blue', s=20)

    # the skeleton is only drawn when the frame holds all 33 landmarks
    has_all_landmarks = len(landmarks) == 33
    if has_all_landmarks:
        for start_idx, end_idx in mp_pose.POSE_CONNECTIONS:
            segment_x = [landmarks[start_idx].x, landmarks[end_idx].x]
            segment_y = [landmarks[start_idx].y, landmarks[end_idx].y]
            ax.plot(segment_x, segment_y, c='black')

    # same limits on every plot; y limits are passed reversed because pose y is inverse
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_max, y_min)
    ax.set_aspect('equal')  # x and y on same scale

    ax.set_xlabel('X Axis')
    ax.set_ylabel('Y Axis')

In [10]:
def plot_mp_pixel_per_frame(ax, frame, grouped_df, connections, x_min, x_max, y_min, y_max): 
    """Draw the YOLO pixel keypoints of a single frame as a 2D skeleton.

    Parameters
    ----------
    ax : matplotlib axes the frame is drawn on (titled 'Pixels').
    frame : key passed to ``grouped_df.get_group`` to select this frame's rows.
    grouped_df : groupby object whose groups have columns "X" and "Y";
        each row of a group is treated as one keypoint, in row order.
    connections : iterable of ``(start_idx, end_idx)`` pairs indexing into the
        frame's keypoints; each pair is drawn as one black line.
    x_min, x_max, y_min, y_max : axis limits, so every frame shares one scale.
        The notebook passes 0 / video width / 0 / video height.

    Returns
    -------
    None. The axes are modified in place.

    Notes
    -----
    A connection whose index is outside the keypoints available for this frame
    is skipped instead of raising IndexError. NaN coordinates are passed to
    matplotlib unchanged.
    """
    # add title 
    ax.set_title('Pixels')

    # position data for this frame (one row per keypoint)
    group = grouped_df.get_group(frame)

    # yolo only provides x and y, so the coordinates are read straight from the
    # columns - no mediapipe landmark container is needed for pixel data
    x_vals = group["X"].tolist()
    y_vals = group["Y"].tolist()

    # plot 2D 
    ax.scatter(x_vals, y_vals, c='purple', s=20)

    # draw each connection whose two keypoints are present in this frame
    n_points = len(x_vals)
    for start_idx, end_idx in connections:
        highest, lowest = max(start_idx, end_idx), min(start_idx, end_idx)
        if highest >= n_points or lowest < -n_points:
            continue  # keypoint missing for this frame - nothing to connect
        ax.plot([x_vals[start_idx], x_vals[end_idx]],
                [y_vals[start_idx], y_vals[end_idx]],
                c='black')

    # set consistent limits throughout all plots 
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_max, y_min)  # flip because image y grows downwards
    ax.set_aspect('equal')  # x and y on same scale 

    # Set axis labels
    ax.set_xlabel('X Axis')
    ax.set_ylabel('Y Axis')

In [11]:
def get_zeno_image_size(dir_path, after_bw_file_path, task): 
    """Return ``(height, width)`` in pixels of the source video of one recording.

    The video path is rebuilt from the raw-pose-data path, the video is opened
    with OpenCV and the size is taken from its first frame.

    Parameters
    ----------
    dir_path : str
        Raw pose data folder. Only used for home videos, where it must contain
        '_V1' or 'V2' to select the instruction-version video folder.
    after_bw_file_path : str
        Part of the raw pose data path that follows the first 'BW-'. 'BW-' is
        prepended again and the last path component is dropped to get the
        '<id>/<date>' sub-folder of the video.
    task : str
        'FW_1' or 'PWS_1' (Zeno videos, always .mp4) or 'right' / 'left'
        (home videos, extension looked up in the video folder).

    Returns
    -------
    (height, width) : tuple of int

    Raises
    ------
    ValueError
        If ``task`` is not one of the supported values, or a home-video
        ``dir_path`` contains neither '_V1' nor 'V2'.
    FileNotFoundError
        If the home-video folder is empty.
    OSError
        If the video cannot be opened or its first frame cannot be read.
    """
    zeno_tasks = ('FW_1', 'PWS_1')
    home_tasks = ('right', 'left')
    if task not in zeno_tasks + home_tasks:
        raise ValueError(
            f"Unknown task {task!r} - expected one of {zeno_tasks + home_tasks}"
        )

    # '<BW id>/<date>' sub-folder shared by the raw data and the video folders
    data_file_path = f"BW-{after_bw_file_path}"
    data_id_date, data_pose_folder = os.path.split(data_file_path) 

    if task in zeno_tasks: 
        # zeno videos 
        video_folder = r"C:\Users\mmccu\Box\MM_Personal\5_Projects\BoveLab\3_Data_and_Code\2024_10_Megan_BW_Zeno"
        file_name = f"gait_vertical_{task}.mp4"  # all zeno videos .mp4 
        video_path = os.path.join(video_folder,
                                  data_id_date,
                                  file_name) 
    else:
        # home videos - folder depends on the instruction version 
        if '_V1' in dir_path:
            video_folder = r"C:\Users\mmccu\Box\Brainwalk\Home Video Walking\Walking home videos for analysis\Instruction_V1"
        elif 'V2' in dir_path:
            video_folder = r"C:\Users\mmccu\Box\Brainwalk\Home Video Walking\Walking home videos for analysis\Instruction_V2"
        else:
            raise ValueError(
                f"Cannot tell instruction version (V1 / V2) from dir_path: {dir_path}"
            )

        # get extension for video: prefer the file that is actually named after
        # this task, fall back to the first file in the folder 
        video_folder_path = os.path.join(video_folder, data_id_date)
        video_folder_listdir = sorted(os.listdir(video_folder_path))
        if not video_folder_listdir:
            raise FileNotFoundError(f"No video files found in {video_folder_path}")
        video_stem = f"gait_vertical_{task}"
        matching = [name for name in video_folder_listdir
                    if os.path.splitext(name)[0] == video_stem]
        ext_source = matching[0] if matching else video_folder_listdir[0]
        _, ext = os.path.splitext(ext_source)
        print(f"ext {ext}")
        video_path = os.path.join(video_folder_path, f"{video_stem}{ext}") 

    # Open the video
    print(f"video path for image size = {video_path}")
    cap = cv2.VideoCapture(video_path)
    try:
        # Check if video opened successfully
        if not cap.isOpened():
            raise OSError(f"Error opening video file: {video_path}")

        # Read the first frame
        ret, frame = cap.read()
        if not ret:
            raise OSError(f"Error reading first frame: {video_path}")

        # Get height and width of the frame
        height, width = frame.shape[:2]
    finally:
        # Release the capture even when opening / reading failed
        cap.release()

    return height, width
    
    

## Loop through all folders and create video 

In [12]:
# Batch driver.
# For every '000_raw_pose_data' folder below raw_pose_data_in_path that belongs to
# an ID in bw_id_include:
#   STEP 1: save one three-panel PNG per frame (MediaPipe pose, MediaPipe world, YOLO pixels)
#   STEP 2: combine the PNGs into '<task>_video_plots.mp4'
#   STEP 3: delete the PNGs
raw_data_full_path_all = [] 
raw_data_file_names_all = []

# plotting stops once the frame index exceeds fps * max_plot_seconds
max_plot_seconds = 60

# loop through all files in input path 
for (dir_path, dir_names, file_names) in os.walk(raw_pose_data_in_path):

    # temp - use if only using specific BW IDs 
    # parse the 'BW-xxxx' ID out of the folder path 
    if ('BW-' in dir_path) and ('000_raw_pose_data' in dir_path):
        before_bw, after_bw = dir_path.split("BW-", 1)  # split once at the first occurrence of ID 
        before_slash = after_bw.split("\\", 1)[0]
        try:
            id_num = int(before_slash)
        except ValueError:
            id_num = np.nan  # ID is not purely numeric 
        id_string = f"BW-{before_slash}"
    else: 
        id_num = np.nan
        id_string = 'blank'

    # only folders with raw pose data of a requested ID; to filter differently
    # change the last condition, e.g. (id_num > 200) 
    if ('000_raw_pose_data' not in dir_path) or ('plots' in dir_path) or (id_string not in bw_id_include):
        continue

    print('------------------------------')
    print(f"input dir_path: {dir_path}")

    # --------------------------------------------
    # set output folder 
    plot_folder_full = os.path.join(dir_path, f"{task}_all_plots") 
    print(f"output folder: {plot_folder_full}")
    plot_folder_exists = os.path.exists(plot_folder_full)
    if plot_folder_exists:
        # if not re-doing plots - skip this folder 
        if redo == 'no': 
            print("'plot_folder_full' already exists - skipping this folder") 
            continue  # move to next for loop iteration 
        elif redo == 'yes': 
            print("'plot_folder_full' already exits  and replotting'") 

    # -----------------------------------------------
    # save all files for this task as one list 
    task_files = [file_name for file_name in file_names if task in file_name]

    # if there are not four files (pose, world, yolo, fps) skip to next folder
    if len(task_files) != 4: 
        print('Not all files available - skipping to next') 
        continue 

    # assign correct file to name; reset first so that nothing is carried over
    # from the folder handled in the previous iteration 
    world_file = pixel_file = fps_file = pose_file = None
    for task_file in task_files: 
        if 'mediapipe_world' in task_file: 
            world_file = task_file 
        elif 'yolo' in task_file: 
            pixel_file = task_file
        elif 'fps' in task_file: 
            fps_file = task_file 
        else: 
            pose_file = task_file 

    if None in (world_file, pixel_file, fps_file, pose_file):
        print('Not all files available - skipping to next') 
        continue 

    # the output folder is only created once the inputs are known to be complete;
    # otherwise an empty folder would make redo == 'no' skip this recording forever 
    if not plot_folder_exists:
        print("'plot_folder_full' does not exist - making new directory")
        os.mkdir(plot_folder_full) 

    # -----------------------------------------------
    # load data 
    # mp pose 
    df_pose = pd.read_csv(os.path.join(dir_path, pose_file), index_col = 0)

    # mp world 
    df_world = pd.read_csv(os.path.join(dir_path, world_file), index_col = 0)

    # yolo pixels 
    df_pixel = pd.read_csv(os.path.join(dir_path, pixel_file), index_col = 0)

    # frames per second - first cell of the fps file 
    df_fps = pd.read_csv(os.path.join(dir_path, fps_file), index_col = 0)
    fps = float(df_fps.iloc[0, 0])  # plain float for cv2.VideoWriter 

    # -----------------------------------------------
    # format pose data 
    # remove inf values - replace with nan for min and max calculation 
    df_pose = df_pose.replace([np.inf, -np.inf], np.nan)

    # Set min and max for plots 
    xmin_pose = df_pose['X'].min(skipna = True)
    xmax_pose = df_pose['X'].max(skipna = True)
    ymin_pose = df_pose['Y'].min(skipna = True)
    ymax_pose = df_pose['Y'].max(skipna = True)

    # Group pose data by frame
    df_pose_grouped = df_pose.groupby("frame")

    # --------------------------------------------
    # format world data 
    # remove inf values - replace with nan for min and max calculation 
    df_world = df_world.replace([np.inf, -np.inf], np.nan)

    # Set min and max for plots 
    xmin_world = df_world['X'].min(skipna = True)
    xmax_world = df_world['X'].max(skipna = True)
    ymin_world = df_world['Y'].min(skipna = True)
    ymax_world = df_world['Y'].max(skipna = True)

    # Group world data by frame
    df_world_grouped = df_world.groupby("frame")

    # ---------------------------------------------
    # format pixel data 
    # if X and Y both equal zero --> replace with nan 
    mask = (df_pixel['X'] == 0) & (df_pixel['Y'] == 0)
    df_pixel.loc[mask, ['X', 'Y']] = df_pixel.loc[mask, ['X', 'Y']].replace(0, np.nan)

    # Group pixel data by frame
    df_pixel_grouped = df_pixel.groupby("frame")

    #---------------------------------------------
    # Get landmark data for each frame and save plot with all three data sources 
    print('STEP 1: plotting')

    # part of the path after 'BW-' - used for the plot title and the video lookup 
    after = dir_path.split("BW-", 1)[1]

    # get image size for pixel plots 
    video_height, video_width = get_zeno_image_size(dir_path, after, task)

    # frames are driven by the pose data; a frame can only be plotted when the
    # world and pixel data contain it as well 
    world_frames = set(df_world_grouped.groups)
    pixel_frames = set(df_pixel_grouped.groups)
    n_frames_skipped = 0

    for frame_idx in df_pose_grouped.groups:
        # if frame index > fps * max_plot_seconds, stop plotting 
        if frame_idx > (fps * max_plot_seconds): 
            print('***ended plotting early - whole video not plotted') 
            break 

        if (frame_idx not in world_frames) or (frame_idx not in pixel_frames):
            n_frames_skipped += 1
            continue

        # initiate plot - three panels of equal visual width 
        fig = plt.figure(figsize=(15, 5))
        gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 1])
        fig.suptitle(f"{task}: BW-{after}")

        ax1 = fig.add_subplot(gs[0])
        ax2 = fig.add_subplot(gs[1])
        ax3 = fig.add_subplot(gs[2])

        # Mediapipe Pose 
        plot_mp_pose_per_frame(ax = ax1,
                               frame = frame_idx, 
                               grouped_df = df_pose_grouped,
                               x_min = xmin_pose,
                               x_max = xmax_pose,
                               y_min = ymin_pose,
                               y_max = ymax_pose)

        # Mediapipe World 
        plot_mp_world_per_frame(ax = ax2, 
                                frame = frame_idx, 
                                grouped_df = df_world_grouped, 
                                x_min = xmin_world, 
                                x_max = xmax_world, 
                                y_min = ymin_world, 
                                y_max = ymax_world)

        # Pixels (Yolo) - limits are the full video frame 
        plot_mp_pixel_per_frame(ax = ax3, 
                                frame = frame_idx, 
                                grouped_df = df_pixel_grouped, 
                                connections = yolo_connections,
                                x_min = 0, 
                                x_max = video_width, 
                                y_min = 0, 
                                y_max = video_height)

        # save plot; close this figure explicitly so figures do not pile up in memory 
        fig.savefig(os.path.join(plot_folder_full, 
                                 f"frame_{frame_idx:06d}.png"))
        plt.close(fig)

    if n_frames_skipped:
        print(f"skipped {n_frames_skipped} frame(s) missing from world or pixel data")

    # -------------------------------------------------
    print('STEP 2: creating video')
    # take all plots saved and convert into single video 
    # zero-padded frame numbers make the sorted order the frame order 
    image_file_names = sorted(name for name in os.listdir(plot_folder_full)
                              if name.endswith('.png'))

    if not image_file_names:
        print('No plot images available - skipping video')
        continue

    # Read the first image to get the frame size
    first_image_path = os.path.join(plot_folder_full, image_file_names[0])
    frame = cv2.imread(first_image_path)
    if frame is None:
        print(f"Could not read {first_image_path} - skipping video")
        continue
    height, width = frame.shape[:2]
    frame_size = (width, height)

    # Define the codec and create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Use 'XVID' for .avi
    out = cv2.VideoWriter(os.path.join(plot_folder_full,
                                       f"{task}_video_plots.mp4"),
                          fourcc, fps, frame_size)

    if not out.isOpened():
        # keep the images - without a video they are the only output 
        out.release()
        print('Could not open video writer - keeping plot images')
        continue

    # Write each image to the video; always release the writer 
    try:
        for image_name in image_file_names:
            image_path = os.path.join(plot_folder_full, image_name)
            frame = cv2.imread(image_path)
            if frame is None:
                continue  # unreadable image - cannot be written 
            out.write(frame)
    finally:
        out.release()
    print(f"Video saved in {plot_folder_full}")

    # delete all plot images to save space 
    print('STEP 3: deleting images after creating video')
    for filename in image_file_names:
        os.remove(os.path.join(plot_folder_full, filename))

------------------------------
input dir_path: C:\Users\mmccu\Box\MM_Personal\5_Projects\BoveLab\3_Data_and_Code\gait_bw_zeno_outputs_006\subset_for_memory_c\BW-0219\2023_08_21\000_raw_pose_data
output folder: C:\Users\mmccu\Box\MM_Personal\5_Projects\BoveLab\3_Data_and_Code\gait_bw_zeno_outputs_006\subset_for_memory_c\BW-0219\2023_08_21\000_raw_pose_data\PWS_1_all_plots
'plot_folder_full' does not exist - making new directory
STEP 1: plotting
video path for image size = C:\Users\mmccu\Box\MM_Personal\5_Projects\BoveLab\3_Data_and_Code\2024_10_Megan_BW_Zeno\BW-0219\2023_08_21\gait_vertical_PWS_1.mp4
STEP 2: creating video
Video saved in C:\Users\mmccu\Box\MM_Personal\5_Projects\BoveLab\3_Data_and_Code\gait_bw_zeno_outputs_006\subset_for_memory_c\BW-0219\2023_08_21\000_raw_pose_data\PWS_1_all_plots
STEP 3: deleting images after creating video
------------------------------
input dir_path: C:\Users\mmccu\Box\MM_Personal\5_Projects\BoveLab\3_Data_and_Code\gait_bw_zeno_outputs_006\subset_f