In [3]:
%matplotlib widget
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import imageio.v2 as imageio
import pandas as pd
import sys
sys.path.append('../')

from utils import config
from utils import utils

In [4]:
# >>> PREPROCESSING 1 METHODS <<< #
def common_frames(pose, acc):
    """
    Drop every POSE frame that has no ACC sample and every ACC sample whose frame has no POSE.

    pose: has shape (3, N, 18); pose[0, :, 0] holds the frame numbers
    acc: has shape (M, 5); acc[:, 0] holds the frame numbers (a frame may repeat across samples)
    Returns (pose_kept, acc_kept) with the same layouts, original order preserved.
    """
    pose_frame_ids = pose[0, :, 0]
    acc_frame_ids = acc[:, 0]

    pose_keep = np.isin(pose_frame_ids, acc_frame_ids)  # pose frames present in ACC
    acc_keep = np.isin(acc_frame_ids, pose_frame_ids)   # ACC samples whose frame is in POSE

    return pose[:, pose_keep, :], acc[acc_keep]


def shift_time_start_to_zero(acc):
    """
    Re-origin the ACC timestamps at zero and convert them from milliseconds to seconds.

    acc: has shape (N, 5) — (frame, timestamp, x, y, z), timestamp in milliseconds
    Returns a new array; the input is left untouched. Timestamps are NOT deduplicated
    even though some of them repeat in the recording.
    """
    shifted = np.copy(acc)
    raw_ms = acc[:, 1]
    shifted[:, 1] = (raw_ms - raw_ms[0]) / 1000  # ms -> s, first sample at t = 0
    return shifted


def regularize_time(acc, sample_rate=100):
    """
    Resample ACC onto a regular time grid by linear interpolation.

    The raw ACC timestamps are irregular: consecutive samples are sometimes 1 ms apart and
    sometimes 60 ms apart.

    acc: has shape (N, 5) — (frame, time, x, y, z), time in seconds (see shift_time_start_to_zero).
        The first/last rows are taken as the start/end of the recording.
    sample_rate: target sampling rate in Hz (default 100).

    Returns an array of shape (M, 5). The time column is exactly
    start, start + 1/sample_rate, ..., end, where start/end are the first/last whole seconds
    that lie inside the recorded time range. Frame, x, y and z are linearly interpolated,
    and frames are then rounded to the nearest integer.
    """
    original_time = acc[:, 1]

    # Ceil the start and floor the end so every grid point lies inside the recorded range
    # (interp1d raises for points outside it). int() truncation could place the start before
    # the first sample.
    start_t = int(np.ceil(original_time[0]))
    end_t = int(np.floor(original_time[-1]))
    window = end_t - start_t  # seconds

    # Build the grid from an integer step count so the spacing is exactly 1/sample_rate.
    # linspace(start, end, num=window*rate) would space points by window/(window*rate - 1).
    n_samples = max(int(round(window * sample_rate)) + 1, 0)
    regular_time = start_t + np.arange(n_samples) / sample_rate

    acc_regular = np.empty((regular_time.shape[0], acc.shape[1]))
    acc_regular[:, 1] = regular_time  # set timestamps (1)
    for col in (0, 2, 3, 4):  # interpolate frame, x, y, z
        f = interp1d(original_time, acc[:, col], kind='linear')
        acc_regular[:, col] = f(regular_time)

    acc_regular[:, 0] = np.around(acc_regular[:, 0])  # make frames integers

    return acc_regular

def find_pose_frame_time(pose, acc):
    """
    Look up a timestamp for every POSE frame.

    pose: has shape (3, N, 18); pose[0, :, 0] holds the frame numbers
    acc: has shape (M, 5) — (frame, time, x, y, z)
    Returns an (N, 2) array of (frame, time). The time of a frame is the time of the FIRST
    ACC sample carrying that frame number. Frames absent from ACC get NaN.
    """
    first_time_per_frame = (
        pd.DataFrame(acc[:, :2], columns=['frame', 'time'])
        .drop_duplicates(subset=['frame'], keep='first')
        .reset_index(drop=True)
    )
    pose_frames = pd.DataFrame(pose[0, :, 0], columns=['frame'])
    joined = pose_frames.merge(first_time_per_frame, on='frame', how='left')
    return joined.to_numpy()  # shape (N, 2)

def upsample_pose(pose, acc):
    """
    Upsample POSE onto the ACC timestamps by linear interpolation.

    pose: has shape (3, N, 18) — (x/y/z, frames, (frame, 17 joints))
    acc: has shape (M, 5) — (frame, time, x, y, z)

    Returns:
        pose_upsample: shape (3, M, 19) — (frame, time, 17 joints) at every ACC timestamp.
            Frame and time are copied from acc. Joints are linearly interpolated over the pose
            frame times and extrapolated outside that range.
        original_pose_with_time: shape (3, N, 19) — the input pose with each frame's timestamp
            (from find_pose_frame_time) inserted as column 1, kept for later comparison plots.
    """
    frame_time = find_pose_frame_time(pose, acc)  # (N, 2): frame, time
    pose_time = frame_time[:, 1]
    acc_time = acc[:, 1]
    n_cols = pose.shape[2]

    pose_upsample = np.empty((3, acc_time.shape[0], n_cols + 1))
    pose_upsample[:, :, :2] = acc[:, :2]  # frame (0) and timestamp (1) straight from acc

    # A single interpolator covers every axis and joint at once: pose[:, :, 1:] drops the
    # frame column, and axis=1 interpolates along the frames. Output joint j lands in column j + 2.
    joints_interp = interp1d(pose_time, pose[:, :, 1:], axis=1, kind='linear', fill_value='extrapolate')
    pose_upsample[:, :, 2:] = joints_interp(acc_time)

    original_pose_with_time = np.empty((3, pose_time.shape[0], n_cols + 1))
    original_pose_with_time[:, :, :2] = frame_time   # frame (0) and timestamp (1)
    original_pose_with_time[:, :, 2:] = pose[:, :, 1:]  # joints (2:)

    return pose_upsample, original_pose_with_time

In [5]:
# >>> PLOTS methods <<< #
def plot_acc(acc):
    """
    Plot ACC x, y and z against time in three stacked panels.

    acc: has shape (N, 5) — (frame, time, x, y, z); time is plotted in seconds.
    """
    fig, axes = plt.subplots(3, 1, figsize=(12, 6))
    fig.subplots_adjust(hspace=0.5)
    time = acc[:, 1]
    for ax, col, name in zip(axes, (2, 3, 4), ("x", "y", "z")):
        ax.plot(time, acc[:, col], label=name)
        ax.set_xlabel("time (s)")
        ax.set_ylabel("acceleration (m/s^2)")
        ax.legend()

    plt.show()


def compare_2_acc(original_acc, modified_acc, offset=50):
    """
    Plot two ACC recordings against each other to check whether they are aligned.

    original_acc: has shape (N, 5) — (frame, time, x, y, z)
    modified_acc: has shape (M, 5) — same columns; may be sampled at different timestamps
    offset: constant added to the modified signal so the two curves do not overlap.

    Each recording is drawn against its own (time-sorted) timestamps. Outer-merging both
    onto the union of their timestamps would leave NaN between each series' own samples,
    and matplotlib does not draw line segments across NaN, so interleaved series would
    render almost blank.
    """
    original_acc = np.asarray(original_acc)
    modified_acc = np.asarray(modified_acc)
    # Stable sort by time so each line is drawn left to right, even if timestamps repeat.
    original_sorted = original_acc[np.argsort(original_acc[:, 1], kind='stable')]
    modified_sorted = modified_acc[np.argsort(modified_acc[:, 1], kind='stable')]

    fig, ax = plt.subplots(3, 1, figsize=(12, 6))
    for row, (col, name) in enumerate(zip((2, 3, 4), ("X", "Y", "Z"))):
        ax[row].plot(original_sorted[:, 1], original_sorted[:, col],
                     label="Original " + name, color='g', alpha=0.5)
        ax[row].plot(modified_sorted[:, 1], modified_sorted[:, col] + offset,
                     label="Modified " + name, color='r', alpha=0.5)
        ax[row].legend()
        ax[row].set_xlabel("Time")
        ax[row].set_ylabel("(m/s^2)")

    plt.show()


def plot_pose(pose, joint):
    """
    Plot one joint's x, y and z trajectories against time in three stacked panels.

    pose: presumably (3, N, 19) with the timestamp in column 1, as built by upsample_pose
        (TODO confirm for other callers)
    joint: column index into the last axis of pose
    """
    fig, axes = plt.subplots(3, 1)
    fig.subplots_adjust(hspace=0.5)
    time = pose[0, :, 1]
    for channel, (ax, name) in enumerate(zip(axes, ("x", "y", "z"))):
        ax.plot(time, pose[channel, :, joint], label=name)
        ax.legend()

    plt.show()


def compare_2_pose(original_pose, modified_pose, ch_idx=0, joint_idx=-1):
    """
    Plot one joint channel of two POSEs against each other to check whether they are aligned.

    original_pose, modified_pose: have shape (3, N, 19) — (x/y/z, samples, (frame, time, 17 joints));
        the two may be sampled at different timestamps (e.g. original vs upsampled pose)
    ch_idx: 0, 1 or 2 for the x, y or z channel
    joint_idx: index of the joint column to plot (last axis)

    Each pose is drawn against its own (time-sorted) timestamps. Outer-merging both onto the
    union of their timestamps would leave NaN between each series' own samples, and
    matplotlib does not draw lines across NaN.
    """
    # Selecting the channel first gives an (N, 19) array directly. The former
    # pose[ch, :, [0, 1, j]] form mixes advanced indices around a slice, which NumPy moves to
    # the front of the result, giving (3, N) and needing a transpose.
    original = np.asarray(original_pose)[ch_idx]
    modified = np.asarray(modified_pose)[ch_idx]
    original = original[np.argsort(original[:, 1], kind='stable')]
    modified = modified[np.argsort(modified[:, 1], kind='stable')]

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.plot(original[:, 1], original[:, joint_idx], label="Original", color='g')
    ax.plot(modified[:, 1], modified[:, joint_idx], label="Modified", color='r', alpha=0.4)
    ax.legend()
    ax.set_xlabel("Time")
    ax.set_ylabel("joint coordinate")  # pose values are positions, not accelerations

    plt.show()

def frame_to_timestamp(frame, frame_time_df):
    """
    Return the time of the first row of frame_time_df whose 'frame' equals frame.

    frame_time_df: DataFrame with 'frame' and 'time' columns.
    Raises IndexError when the frame does not appear in frame_time_df.
    """
    is_match = frame_time_df['frame'] == frame
    matching_times = frame_time_df.loc[is_match, 'time']
    return matching_times.iloc[0]


def acc_label(acc, labels, frame_time):
    """
    Plot ACC x, y and z against time and shade every labelled segment in red.

    acc: has shape (N, 5) — (frame, time, x, y, z)
    labels: rows of (start frame, end frame, repetition count, activity class)
    frame_time: (M, 2) array of (frame, time), used to map label frames to times
    """
    labels_df = pd.DataFrame(labels, columns=['start_frame', 'end_frame', 'rep', 'label'])
    frame_time_df = pd.DataFrame(frame_time, columns=['frame', 'time'])

    fig, axes = plt.subplots(3, 1, figsize=(12, 6))
    fig.subplots_adjust(hspace=0.5)
    for ax, col, name in zip(axes, (2, 3, 4), ("x", "y", "z")):
        ax.plot(acc[:, 1], acc[:, col], label=name)
        ax.legend()

    for _, segment in labels_df.iterrows():
        start_time = frame_to_timestamp(frame=segment['start_frame'], frame_time_df=frame_time_df)
        end_time = frame_to_timestamp(frame=segment['end_frame'], frame_time_df=frame_time_df)
        for ax in axes:
            ax.axvspan(start_time, end_time, color='red', alpha=0.3)

    plt.show()

def pose_label(pose, labels, frame_time, joint_idx=-1):
    """
    Plot one joint's x, y and z trajectories against time and shade every labelled segment in red.

    pose: presumably (3, N, 19) with the timestamp in column 1 (TODO confirm for other callers)
    labels: rows of (start frame, end frame, repetition count, activity class)
    frame_time: (M, 2) array of (frame, time), used to map label frames to times
    joint_idx: joint column to plot (last axis of pose)
    """
    labels_df = pd.DataFrame(labels, columns=['start_frame', 'end_frame', 'rep', 'label'])
    frame_time_df = pd.DataFrame(frame_time, columns=['frame', 'time'])

    fig, axes = plt.subplots(3, 1, figsize=(12, 6))
    fig.subplots_adjust(hspace=0.5)
    time = pose[0, :, 1]
    for channel, (ax, name) in enumerate(zip(axes, ("x", "y", "z"))):
        ax.plot(time, pose[channel, :, joint_idx], label=name)
        ax.legend()

    for _, segment in labels_df.iterrows():
        start_time = frame_to_timestamp(frame=segment['start_frame'], frame_time_df=frame_time_df)
        end_time = frame_to_timestamp(frame=segment['end_frame'], frame_time_df=frame_time_df)
        for ax in axes:
            ax.axvspan(start_time, end_time, color='red', alpha=0.3)

    plt.show()

In [6]:
# >>> PREPROCESS 1: REGULARIZE TIME + UPSAMPLE POSE <<< #
# Per subject: load raw POSE / ACC / labels, put ACC on a regular time grid, keep only the
# frames both modalities share, then upsample POSE onto the ACC timestamps.
subjects = {}
TMP = ['w01']  # NOTE(review): unused in this cell; looks like a leftover debug subset of subject ids
for id_prefix in config.TRAIN_W_IDS + config.VAL_W_IDS + config.TEST_W_IDS:
    id_dir = os.path.join(config.mmfit_data_dir, id_prefix)
    pose_file = os.path.join(id_dir, id_prefix + '_' + config.original_pose_file)
    acc_file = os.path.join(id_dir, id_prefix + '_' + config.original_acc_file)
    label_file = os.path.join(id_dir, id_prefix + '_' + config.labels_file)

    pose = utils.load_modality(pose_file)  # shape (3, N, 18) 18 is (frame, 17 joints)
    acc = utils.load_modality(acc_file)  # shape (N, 5) 5 is (frame, timestamp, acc_x, acc_y, acc_z)
    labels = utils.load_labels(label_file)  # list: (Start Frame, End Frame, Repetition Count, Activity Class)

    acc_t0 = shift_time_start_to_zero(acc=acc)
    acc_regular = regularize_time(acc_t0)
    pose_intersect, acc_intersect = common_frames(pose=pose, acc=acc_regular)
    # shape (XYZ, N, (frame, timestamps, 17 joints)) N is same as acc
    pose_upsample, original_pose_with_time = upsample_pose(pose=pose_intersect, acc=acc_intersect)

    # (frame, time) for every kept POSE frame; used to place label spans on the time axis.
    frame_time = original_pose_with_time[0, :, :2]
    # common_frames keeps only pose frames present in acc, so every frame must have a timestamp.
    assert not np.isnan(frame_time[:, 1]).any(), id_prefix + ": pose frames without an acc timestamp"

    subjects[id_prefix] = {
        'pose': pose_upsample,
        'acc': acc_intersect,
        # Keep labels per subject so label plots do not depend on the last subject's
        # `labels` leaking out of this loop.
        'labels': labels,
        'frame_time': frame_time,
    }

In [7]:
# >>> PREPROCESS 2: STANDARDIZE ACC <<< #
# Standardize ACC x/y/z with statistics of the TRAINING subjects only, apply the same transform
# to every split, and save the result.

# Keep the un-standardized signal, so re-running this cell standardizes raw ACC again instead of
# standardizing already-standardized data (the previous version updated the array in place).
for id_prefix in config.TRAIN_W_IDS + config.VAL_W_IDS + config.TEST_W_IDS:
    subjects[id_prefix].setdefault('acc_raw', subjects[id_prefix]['acc'])

train_acc = [subjects[id_prefix]['acc_raw'][:, 2:] for id_prefix in config.TRAIN_W_IDS]  # each (N_i, 3)

# Per-participant statistics, kept for inspection: shape (len(TRAIN_W_IDS), 3)
train_means = np.array([a.mean(axis=0) for a in train_acc])
train_stds = np.array([a.std(axis=0) for a in train_acc])

# Pooled statistics over all training samples. Averaging per-subject stds ignores the spread
# *between* subjects' means, so it understates the overall spread of the training data.
train_acc_all = np.concatenate(train_acc, axis=0)
total_mean = train_acc_all.mean(axis=0)  # shape (3,)
total_std = train_acc_all.std(axis=0)  # shape (3,)

for id_prefix in config.TRAIN_W_IDS + config.VAL_W_IDS + config.TEST_W_IDS:
    acc_raw = subjects[id_prefix]['acc_raw']  # shape (N, 5) 5 is (frame, timestamp, acc_x, acc_y, acc_z)

    acc = np.copy(acc_raw)
    acc[:, 2:] = (acc_raw[:, 2:] - total_mean) / (total_std + 1e-8)  # epsilon guards a zero std
    subjects[id_prefix]['acc'] = acc

    id_dir = os.path.join(config.mmfit_data_dir, id_prefix)
    np.save(os.path.join(id_dir, id_prefix + '_' + config.acc_file), acc)

In [8]:
# # --- TEST synchronization plots --- #
# compare_2_acc(original_acc=acc_t0, modified_acc=acc_regular, offset=10)

# frame_time = find_pose_frame_time(pose=pose_intersect, acc=acc_intersect)
# acc_label(acc=acc_intersect, labels=labels, frame_time=frame_time)

In [9]:
# # --- TEST sanity check --- #
# acc_file_suffix = '_acc-cut.npy'
# acc_std_file_suffix = '_acc-std.npy'
# T = ['w01']
# for id_prefix in T:
#     id_dir = os.path.join(config.mmfit_data_dir, id_prefix)
#     acc_file = os.path.join(id_dir, id_prefix + acc_file_suffix)
#     acc_std_file = os.path.join(id_dir, id_prefix + acc_std_file_suffix)
#     acc = np.load(acc_file) # shape (N, 5) 5 is (frame, timestamp, acc_x, acc_y, acc_z)
#     acc_std = np.load(acc_std_file) # shape (N, 5) 5 is (frame, timestamp, acc_x, acc_y, acc_z)

#     compare_2_acc(original_acc=acc, modified_acc=acc_std)
#     break

In [None]:
# >>> PREPROCESS 3: NORMALIZE POSE <<< #
# Column layout of the upsampled pose (last axis): 0 frame, 1 timestamp, 2.. the 17 joints.
MIDHIP_COL = 2  # MidHip: index 1 in the original (frame, 17 joints) layout
NECK_COL = 11  # Neck: index 10 in the original (frame, 17 joints) layout
MEDIAN_HALF_WINDOW = 150  # samples on each side of t (1.5 s each side at the 100 Hz ACC rate)

for id_prefix in config.TRAIN_W_IDS + config.VAL_W_IDS + config.TEST_W_IDS:
    # Normalize from the un-normalized pose so re-running this cell does not normalize twice.
    # shape (XYZ, N, (frame, timestamps, 17 joints))
    pose = subjects[id_prefix].setdefault('pose_raw', subjects[id_prefix]['pose'])

    # Step 1: MidHip and Neck coordinates, each (3, N)
    midhip_coords = pose[:, :, MIDHIP_COL]
    neck_coords = pose[:, :, NECK_COL]

    # Step 2: MidHip-Neck Euclidean distance for each sample, shape (N,)
    distances = np.linalg.norm(neck_coords - midhip_coords, axis=0)

    # Step 3: sliding median of that distance over samples [t - 150, t + 150] (~3 s in total),
    # truncated at the edges. This is the same window as the former Python loop, computed by
    # pandas in one pass. Unlike np.median, NaN samples are skipped rather than propagated.
    median_distances = (
        pd.Series(distances)
        .rolling(window=2 * MEDIAN_HALF_WINDOW + 1, center=True, min_periods=1)
        .median()
        .to_numpy()
    )

    # Steps 4 and 5: every joint after MidHip, relative to MidHip and scaled by the median distance.
    pose_normal = np.empty_like(pose)
    pose_normal[:, :, :2] = pose[:, :, :2]  # copy frame and timestamps
    relative = pose[:, :, MIDHIP_COL + 1:] - midhip_coords[:, :, np.newaxis]
    # NOTE(review): the "2x - 1" step maps a joint located exactly at MidHip to -1, while MidHip
    # itself is set to 0 in step 6 — confirm the -1 shift is intended.
    pose_normal[:, :, MIDHIP_COL + 1:] = 2.0 * (relative / median_distances[np.newaxis, :, np.newaxis]) - 1.0

    # Step 6: place the MidHip joint at the origin for all samples
    pose_normal[:, :, MIDHIP_COL] = 0

    subjects[id_prefix]['pose'] = pose_normal

    id_dir = os.path.join(config.mmfit_data_dir, id_prefix)
    np.save(os.path.join(id_dir, id_prefix + '_' + config.pose_file), pose_normal)

In [None]:
# --- TEST synchronization plots --- #
# plot_pose(pose=pose_normal, joint=-1)
# compare_2_pose(original_pose=pose_upsample, modified_pose=pose_normal, joint_idx=-1)

# frame_time = find_pose_frame_time(pose=pose_intersect, acc=acc_intersect)
# pose_label(pose=pose_normal, labels=labels, frame_time=frame_time, joint_idx=-1)
# acc_label(acc=acc_intersect, labels=labels, frame_time=frame_time)