# Setup

In [21]:
import os
import sys
import re

import pickle

import scipy


import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import statsmodels.api as sm

import kineticstoolkit.lab as ktk

module_path = os.path.abspath(os.path.join('..')) # or the path to your source code
sys.path.insert(0, module_path)

from markers_analysis import markers
from markers_analysis import constants as consts
from markers_analysis import subject as subj



In [22]:
mpl.rcParams.update({'font.size': 14})

# %matplotlib qt
%matplotlib widget

In [23]:
import importlib

import markers_analysis

importlib.reload(markers_analysis)
importlib.reload(consts)
importlib.reload(subj)
importlib.reload(markers)
importlib.reload(markers_analysis.constants)
importlib.reload(markers_analysis.subject)
importlib.reload(markers_analysis.markers)


<module 'markers_analysis.markers' from 'C:\\Users\\noamg\\OneDrive - post.bgu.ac.il\\Documents\\motor learning lab\\GitHub\\Noam-markers-analysis\\markers_analysis\\markers.py'>

## Subject list and paths

In [24]:
results_path = os.path.join('.', 'pkl')
fig_path = os.path.join('.', 'img')

In [25]:
basename = 'white_ball_hit '
subject_id_list = ['007', '008', '009', '010', '011', '012', '013', '014']
date_list = ['2023_11_20', '2023_11_20', '2023_11_20', '2023_12_03', '2023_12_03', '2023_12_03', '2023_12_14', '2023_12_14']

# subject_id_list = ['007', '008', '009', '010', '011']
# date_list = ['2023_11_20', '2023_11_20', '2023_11_20', '2023_12_03', '2023_12_03']


In [26]:
interconnections = markers.get_interconnections()

all_data = []

for subject_id,date in zip(subject_id_list,date_list):
    marker_file_name = os.path.join(results_path, f'id{subject_id} markers.pkl')

    if os.path.exists(marker_file_name):
        # File exists, unpickle the data
        with open(marker_file_name, 'rb') as file:
            subject_data = pickle.load(file)
    else:
        # Process the file
        subject_data = {}

        info = {}
        info['id'] = subject_id
        info['date'] = date
        info['basename'] = basename
        subject_data['info'] = info
        subject_data['data'] = []

        for filenum,filename in subj.find_marker_data_files(date, subject_id):
            data_dict = subj.load_marker_data_file(filename, interconnections)

            data_dict['filenum'] = filenum
            data_dict['filename'] = filename
            subject_data['data'].append(data_dict)
            print(f"************* Done with subject f{subject_data['info']['id']} file number {subject_data['data'][-1]['filenum']}: {subject_data['data'][-1]['filename']}")

        with open(marker_file_name, 'wb') as file:
            pickle.dump(subject_data, file)

    all_data.append(subject_data)






************* Done with subject f007 file number 1: ..\data\subject_007\2023_11_20\markers\white_ball_hit 1 Backup 2023-12-02 21.33.35.c3d
************* Done with subject f007 file number 2: ..\data\subject_007\2023_11_20\markers\white_ball_hit 2 Backup 2023-12-02 21.33.22.c3d
************* Done with subject f007 file number 3: ..\data\subject_007\2023_11_20\markers\white_ball_hit 3 Backup 2023-12-02 21.33.11.c3d
************* Done with subject f007 file number 4: ..\data\subject_007\2023_11_20\markers\white_ball_hit 4 Backup 2023-12-02 21.32.57.c3d
************* Done with subject f007 file number 5: ..\data\subject_007\2023_11_20\markers\white_ball_hit 5 Backup 2023-12-02 21.32.46.c3d
************* Done with subject f007 file number 6: ..\data\subject_007\2023_11_20\markers\white_ball_hit 6 Backup 2023-12-02 21.32.34.c3d
************* Done with subject f007 file number 7: ..\data\subject_007\2023_11_20\markers\white_ball_hit 7.c3d


FileNotFoundError: [Errno 2] No such file or directory: '.\\pkl\\id007 markers.pkl'

## Cut data files

This is slow so it is disabled and the data is loaded from the pickled file.

In [None]:
all_cut_data = []

for subject_data in all_data:
    subject_id = subject_data['info']['id']
    date = subject_data['info']['date']

    table_cut_filename = os.path.join(results_path, f"id{subject_id} cut.pkl")
    marker_cut_filename = os.path.join(results_path, f"id{subject_id} table markers cut.pkl")

    with open(table_cut_filename, 'rb') as file:
        table_data = pickle.load(file)
    removed_table_shots = False

    if os.path.exists(marker_cut_filename):
        # File exists, unpickle the data
        with open(marker_cut_filename, 'rb') as file:
            shots_data = pickle.load(file)
    else:
        shots_data = {}
        shots_data['info'] = subject_data['info']
        shots_data['shots'] = []

        for data_dict in subject_data['data']:
            file_data = data_dict["markers"]

            filenum = data_dict["filenum"]
            frames = data_dict["frames"]
            euler = data_dict["euler"]
            euler_vels = data_dict["euler_vels"]

            ts_rename = [frames, euler, euler_vels]
            ts_names = ['frames', 'angles', 'vels']
            for ts, name in zip(ts_rename, ts_names):
                old_names = list(ts.data.keys())
                new_names = [f'{name}_{on}' for on in old_names]
                for o,n in zip(old_names, new_names):
                    ts.rename_data(o, n, in_place=True)
                file_data.merge(ts, in_place=True)

            file_shots = [s for s in table_data['shots'] if s['filenum'] == filenum]

            indices_to_remove = []
            time = np.arange(start=-1, stop=1, step=0.01)
            for s in file_shots:
                # If we've run out of marker file (why??)
                if file_data.time[-1] < s['zero_time']-1:
                    indices_to_remove.append(table_data['shots'].index(s))
                    continue

                start_time = np.max([file_data.time[0], s['zero_time']-1])
                end_time = np.min([file_data.time[-1], s['zero_time']+1])

                shot_ts = file_data.get_ts_between_times(start_time, end_time, inclusive=True)
                shot_ts.shift(-s['zero_time'], in_place=True)

                shot_ts.resample(time, in_place=True)

                shot_data = {}
                shot_data['filenum'] = filenum
                shot_data['shotnum'] = s['shotnum']
                shot_data['start_time'] = start_time
                shot_data['zero_time'] = s['zero_time']
                shot_data['end_time'] = end_time
                
                shot_data['interconnections'] = data_dict["interconnections"]
                shot_data['global_transform'] = data_dict["global_transform"]

                shot_data['ts'] = shot_ts

                shots_data['shots'].append(shot_data)
            
            for i in reversed(indices_to_remove):
                del table_data['shots'][i]
            if indices_to_remove:
                removed_table_shots = True

        with open(marker_cut_filename, 'wb') as file:
            pickle.dump(shots_data, file)

    if removed_table_shots:
        with open(table_cut_filename, 'wb') as file:
            pickle.dump(table_data, file)

    all_cut_data.append({'marker': shots_data, 'table': table_data})

## Plot position and velocity "learning curves"

# Get successes for each participant

In [None]:
mean_success_list = []
for subject_data in all_cut_data:
    marker_data = subject_data['marker']
    table_data = subject_data['table']

    hits = subj.is_hit(table_data['shots'])
    shots_good = [m for i,m in enumerate(marker_data['shots']) if hits[i]]

    df_list = []
    for shot in shots_good:
        ts = shot['ts']
        success_test_index = ts.get_index_before_time(0)
        shot_df = ts.to_dataframe().iloc[[success_test_index],:]

        df_list.append(shot_df)

    success_df = pd.concat(df_list, axis=0, ignore_index = True)
    # Take robust mean of this subject's movements
    success_df.fillna(np.inf)
    mean_success = success_df.apply(scipy.stats.trim_mean, proportiontocut=0.1)
    success_df.replace(np.inf, np.nan)

    mean_success_list.append(mean_success)

mean_success_df = pd.concat(mean_success_list, axis=1).T


In [None]:
# import statsmodels.api as sem_dfs



mean_success_df

## Outlier removal

### Get a dataframe of the time zero locations and velocities

In [None]:
time_0_data_list = []
for subject_data in all_cut_data:
    marker_data = subject_data['marker']
    table_data = subject_data['table']
    
    plot_list = []
    for shot in marker_data['shots']:
        ts = shot['ts']
        plot_ts = ts.get_subset(['R_SAE', 'R_ELB', 'vels_R_Arm', 'vels_R_Forearm'])
        plot_index = plot_ts.get_index_before_time(0)

        shot_df = plot_ts.to_dataframe().iloc[[plot_index], :]
        plot_list.append(shot_df)

    plot_df = pd.concat(plot_list, ignore_index=True)

    time_0_data_list.append(plot_df)



### Now remove outliers from it

In [None]:
for plot_df in time_0_data_list:

    # Loop over the columns of the original data frame
    for col_name, col_series in plot_df.items():
        # Remove NaNs from the column
        col_series = col_series.dropna()
        
        col_median = col_series.median()
        col_mad = (col_series - col_median).abs().median()


        # Define the lower and upper bounds for outliers
        lower_bound = col_median - 3 * 1.5 * col_mad
        upper_bound = col_median + 3 * 1.5 * col_mad
        # Replace outliers with NaNs

        col_series = col_series.mask((col_series < lower_bound) | (col_series > upper_bound), np.nan)
        # Add the column to the new data frame
        plot_df[col_name] = col_series

# Now get the distance from the time 0 to the average of success

In [None]:
# fig = plt.figure()
# ax = fig.add_subplot(111)

# Function to interpolate NaN values in a column
def interpolate_column(column):
    indices = np.arange(len(column))
    mask = np.isnan(column)
    column[mask] = np.interp(indices[mask], indices[~mask], column[~mask])
    return column

def dist_to_success(df, mean_success, name, dims=[0,1,2]):
    names = [f'{name}[{i}]' for i in dims]
    delta = np.array(
        [df[n] - mean_success[n] for n in names]
    ).T
    np.apply_along_axis(interpolate_column, axis=0, arr=delta)

    return np.linalg.norm(delta, axis=1)



In [None]:
primary_joint_velocity_angles_dict = {
    'Pelvis': 2, 'Thorax': 1, 'Arm': 0, 'Forearm': 0, 'Hand': 1
}

In [None]:
time_0_dist_list = []

for plot_df, (i,mean_success) in zip(time_0_data_list, mean_success_df.iterrows()):
    
    dist_df = pd.DataFrame()

    for j in ['R_SAE', 'R_ELB']:
        dist_df[j] = dist_to_success(plot_df, mean_success, j)

    for j in ['vels_R_Arm', 'vels_R_Forearm']:
        dist_df[j] = dist_to_success(plot_df, mean_success, j, dims=[0])
    
    time_0_dist_list.append(dist_df)

In [None]:
upper_arm = vel_dfs['Upper arm'][zero_index]
upper_arm_success = upper_arm[hits].mean()

forearm = vel_dfs['Forearm'][zero_index]
forearm_success = forearm[hits].mean()

dist_df['vels_R_Arm'] = np.abs(upper_arm - upper_arm_success)
dist_df['vels_R_Forearm'] = np.abs(forearm - forearm_success)

In [None]:
dist_df

In [None]:

fig, axes = plt.subplots(1, 2, figsize=(15, 4))

# Flatten the axes array if needed
axes = axes.flatten()

plot_in_axes = {'R_SAE': 0, 'R_ELB': 0, 'vels_R_Arm': 1, 'vels_R_Forearm': 1}
labels = {'R_SAE': 'Shoulder', 'R_ELB': 'Elbow', 'vels_R_Arm': 'Upper arm', 'vels_R_Forearm': 'Forearm'}
scale = {'R_SAE': 100, 'R_ELB': 100, 'vels_R_Arm': 1, 'vels_R_Forearm': 1}


# Iterate through columns and plot
for col in dist_df.columns:
    ax = axes[plot_in_axes[col]]
    ax.plot(dist_df.index, dist_df[col]*scale[col], label=labels[col], linewidth=2)

axes[0].axvspan(xmin=0, xmax=15, color='#D2B48C', alpha=0.3, label='Baseline')
axes[1].axvspan(xmin=0, xmax=15, color='#D2B48C', alpha=0.3, label='Baseline')

axes[0].set_xlim(0, 200)
axes[0].legend(loc='upper right', fontsize=12)
axes[0].set_title('Joint position')
axes[0].set_xlabel('Shot')
axes[0].set_ylabel('Error (cm)')

axes[1].set_xlim(0, 200)
axes[1].legend(loc='upper right', fontsize=12)
axes[1].set_title('Angular velocity')
axes[1].set_xlabel('Shot')
axes[1].set_ylabel('Error (deg/sec)')


# Adjust layout for better spacing
plt.tight_layout()

fig_savename = os.path.join(fig_path, f"joint learning {subject_id}.png")
plt.savefig(fig_savename, bbox_inches='tight')