# Analyzing time-dependent data

In [1]:
import os
import pandas as pd
import numpy as np
import torch
import plotly.express as px
import plotly.graph_objs as go
import plotly.colors
from collections import defaultdict

from sklearn.preprocessing import StandardScaler
from umap import UMAP

from src.ModelUtils import (load_model,
                            train_model,
                            save_model_losses,
                            load_and_plot_losses)
from src.DataUtils import (load_data_from_directory,
                           segment_all_steps,
                           steps_to_tensor,
                           reshape_data)
from src.Plot import (plot_step_trajectory, plot_trajectory,
                      plot_animated_trajectory, plot_trajectory_with_joint_traces,
                      compare_step_features_in_batches, plot_umap_from_step,
                      plot_umap_all_steps, plot_mean_spatial_trajectory,
                      compare_phase_aligned_average_single, compare_phase_aligned_average_xy,
                      plot_phase_aligned_average_single, plot_phase_aligned_average_xy,
                      plot_trajectory_with_joint_trace,
                      compare_spatial_progression_over_time, compare_spatial_progression_xy_over_time,
                      compare_spatial_angle_progression_over_time)

In [2]:
## Directories and constants
FIGURES_DIR = './figures/Pain'              # root folder for every saved figure
MODELS_DIR = './src/models'
DATA_DIR = './csv/Pain_Plot_Features'       # per-run CSVs read by load_data_from_directory
MEAN_FEATURES_FILE = './csv/Pain_hindlimb_mouse_features_2025-06-22_21-53-52.csv'
DATASETS = ['A_', 'B_', 'C_', 'D_', 'E_']   # dataset name prefixes; later cells index this list positionally
HEALTHY_KEY, SICK_KEY = 'pre', 'post'       # substrings that identify the two recording conditions
SIDE_KEY = ('left', 'right')                # later cells analyse SIDE_KEY[0] only

## Hyperparameters and early stopping
INPUT_DIM = None        # filled in from the step tensor once it is built
HIDDEN_DIM, LATENT_DIM = 64, 16
BATCH_SIZE = 32
NUM_EPOCHS = 500
LR = 1e-3
PATIENCE = 50           # number of epochs to wait for improvement before stopping
MIN_DELTA = 1e-4        # minimum change to qualify as an improvement
# Optional checkpoint path; a previous value was
# os.path.join(MODELS_DIR, 'lstm_VAE_no_first_last_20250609_121841.pt')
BEST_MODEL_PATH = None

## Plot constants
SCATTER_SIZE = 6
SCATTER_LINE_WIDTH = 1
SCATTER_SYMBOL = 'circle'
LEGEND_FONT_SIZE, TITLE_FONT_SIZE = 18, 24
AXIS_FONT_SIZE, AXIS_TITLE_FONT_SIZE = 16, 20

# Load the data: a mapping group -> mouse/side -> run -> DataFrame (as iterated below)
data = load_data_from_directory(DATA_DIR)

# Print the data structure: one header per group, one line per mouse/side with its run shapes
for group_name, mouse_runs in data.items():
    print(f"Group: {group_name}")
    for mouse_side, run_frames in mouse_runs.items():
        run_shapes = [frame.shape for frame in run_frames.values()]
        print(f"\t{mouse_side}: {len(run_frames)} runs with shapes: {run_shapes}")

Group: A_postDLC
	mouse1_left: 3 runs with shapes: [(310, 199), (313, 199), (297, 199)]
	mouse1_right: 4 runs with shapes: [(259, 199), (251, 199), (233, 199), (351, 199)]
	mouse2_left: 4 runs with shapes: [(209, 199), (216, 199), (124, 199), (155, 199)]
	mouse2_right: 3 runs with shapes: [(223, 199), (193, 199), (211, 199)]
	mouse3_left: 5 runs with shapes: [(249, 199), (231, 199), (241, 199), (318, 199), (244, 199)]
	mouse3_right: 4 runs with shapes: [(297, 199), (230, 199), (154, 199), (224, 199)]
	mouse4_left: 3 runs with shapes: [(269, 199), (188, 199), (221, 199)]
	mouse4_right: 3 runs with shapes: [(283, 199), (213, 199), (195, 199)]
	mouse5_left: 3 runs with shapes: [(301, 199), (210, 199), (321, 199)]
	mouse5_right: 3 runs with shapes: [(341, 199), (212, 199), (191, 199)]
	mouse6_left: 4 runs with shapes: [(302, 199), (157, 199), (297, 199), (201, 199)]
	mouse6_right: 4 runs with shapes: [(271, 199), (235, 199), (163, 199), (141, 199)]
	mouse7_left: 3 runs with shapes: [(172, 

In [3]:
## Load the per-mouse mean features and split them into cohorts
mean_features_df = pd.read_csv(MEAN_FEATURES_FILE)


def _select_by_dataset(features_df, *keys):
    """Return the rows of `features_df` whose 'Dataset' value contains every key.

    Matching is a case-insensitive, literal substring test. Rows with a missing
    'Dataset' value never match; without `na=False` the boolean mask would contain
    NaN and the indexing would raise.
    """
    dataset_names = features_df['Dataset'].str.lower()
    mask = pd.Series(True, index=features_df.index)
    for key in keys:
        mask &= dataset_names.str.contains(key.lower(), regex=False, na=False)
    return features_df[mask]


# NOTE(review): despite the `_lr` suffix, only the SIDE_KEY[0] side is selected.
# NOTE(review): dataset keys are matched as substrings, so DATASETS[2].lower() ('c_')
# would also match inside names such as '..._postdlc_...' — confirm the 'Dataset'
# naming scheme, or match on a prefix instead.
healthy_mean_features_lr = _select_by_dataset(mean_features_df, HEALTHY_KEY, SIDE_KEY[0])
sick_mean_features_lr = _select_by_dataset(mean_features_df, SICK_KEY, DATASETS[2], SIDE_KEY[0])
therapy_mean_features = _select_by_dataset(mean_features_df, SICK_KEY, DATASETS[4], SIDE_KEY[0])

In [None]:
## Welch t-tests between cohorts, repeated for each limb dataset
# Cohort abbreviations used in variable and column names:
#   h = healthy, s = sick, c = control, tc = treated control
#   e.g. `hs` = healthy vs sick, `htc` = healthy vs treated control,
#   `hc_sc_mhs` = significant in hc and sc *minus* those significant in hs.
# Every `*_common_features` set is intersected across all datasets in `fore_hind_data`,
# so it ends up holding features that are significant in every dataset.
# NOTE(review): control_mean_features_lr, treated_control_mean_features_lr, the
# *_forelimb_mean_features tables, CONTROL_KEY and TREATMENT_CONTROL_KEY are not defined
# anywhere in this notebook — TODO define/load them before this cell, otherwise it raises
# NameError on a fresh kernel.
from copy import deepcopy
from scipy.stats import ttest_ind_from_stats

NON_FEATURE_COLUMNS = {'Dataset', 'Mouse', 'Run'}  # identifier columns, never tested
alpha_value = 0.05  # Significance level for t-tests (Bonferroni-corrected per dataset below)

# Cohort pairs to compare; the key is the suffix used in result columns and variable names.
COMPARISONS = {
    'hs': ('healthy', 'sick'),
    'hc': ('healthy', 'control'),
    'sc': ('sick', 'control'),
    'htc': ('healthy', 'treated_control'),
}

fore_hind_data = ((healthy_mean_features_lr, sick_mean_features_lr, control_mean_features_lr, treated_control_mean_features_lr),
                (healthy_forelimb_mean_features, sick_forelimb_mean_features, control_forelimb_mean_features, treated_control_forelimb_mean_features))

# Upper bound for every "common" set: feature columns shared by the healthy, sick and
# control tables of at least one dataset. It is built from the same tables the loop
# below tests, so the bound matches what is actually compared.
hs_hc_sc_common_features = set()
for healthy_table, sick_table, control_table, _treated_table in fore_hind_data:
    hs_hc_sc_common_features |= set(healthy_table.columns) & set(sick_table.columns) & set(control_table.columns)
hs_hc_sc_common_features -= NON_FEATURE_COLUMNS  # Remove non-feature columns

hs_common_features = deepcopy(hs_hc_sc_common_features)
hc_common_features = deepcopy(hs_hc_sc_common_features)
sc_common_features = deepcopy(hs_hc_sc_common_features)
htc_common_features = deepcopy(hs_hc_sc_common_features)

hs_hc_common_features = deepcopy(hs_hc_sc_common_features)
hs_sc_common_features = deepcopy(hs_hc_sc_common_features)
hc_sc_common_features = deepcopy(hs_hc_sc_common_features)

hc_sc_mhs_common_features = deepcopy(hs_hc_sc_common_features)
hs_mhtc_common_features = deepcopy(hs_hc_sc_common_features)

drug_less_effect_common_features = deepcopy(hs_hc_sc_common_features)
drug_effect_for_recovery_common_features = deepcopy(hs_hc_sc_common_features)
impaired_affected_by_drug_common_features = deepcopy(hs_hc_sc_common_features)
impaired_not_affected_by_drug_common_features = deepcopy(hs_hc_sc_common_features)
not_epilepsy_affected_by_drug_common_features = deepcopy(hs_hc_sc_common_features)


def _sample_summary(values):
    """Return (mean, sample std with ddof=1, n): the per-group stats ttest_ind_from_stats takes positionally."""
    return np.mean(values), np.std(values, ddof=1), len(values)


for i, (healthy_feature_data, sick_feature_data, control_feature_data, treated_control_feature_data) in enumerate(fore_hind_data):
    print('--' * 20)
    print(f"Processing dataset {i+1} / {len(fore_hind_data)}\n\tWith {len(healthy_feature_data)} healthy, {len(sick_feature_data)} sick, {len(control_feature_data)} control mice and {len(treated_control_feature_data)} treated control mice")
    print(f"Healthy feature data shape: {healthy_feature_data.shape}")
    print(f"Sick feature data shape: {sick_feature_data.shape}")
    print(f"Control feature data shape: {control_feature_data.shape}")
    print(f"Treated control feature data shape: {treated_control_feature_data.shape}")

    cohorts = {
        'healthy': healthy_feature_data,
        'sick': sick_feature_data,
        'control': control_feature_data,
        'treated_control': treated_control_feature_data,
    }

    # Only features present in all four tables can be compared.
    total_features = set.intersection(*(set(table.columns) for table in cohorts.values())) - NON_FEATURE_COLUMNS
    if not total_features:
        raise ValueError(f"Dataset {i+1} has no feature columns shared by all cohorts")

    # Bonferroni correction for multiple comparisons (one test per feature within each pair)
    corrected_alpha_value = alpha_value / len(total_features)

    # Fresh per dataset. A table shared across iterations would carry the previous
    # dataset's rows into this dataset's significance sets, re-thresholded with this
    # dataset's alpha.
    t_test_results = defaultdict(dict)

    for feature in total_features:
        samples = {name: table[feature].dropna().values for name, table in cohorts.items()}
        if any(len(values) == 0 for values in samples.values()):
            continue  # at least one cohort has no observations for this feature

        summaries = {name: _sample_summary(values) for name, values in samples.items()}
        tests = {
            suffix: ttest_ind_from_stats(*summaries[first], *summaries[second], equal_var=False)  # Welch's t-test
            for suffix, (first, second) in COMPARISONS.items()
        }

        # Column order in the results table: means, stds, t-statistics, p-values.
        row = t_test_results[feature]
        row.update({f'{name}_mean': mean for name, (mean, _std, _n) in summaries.items()})
        row.update({f'{name}_std': std for name, (_mean, std, _n) in summaries.items()})
        row.update({f't-statistic-{suffix}': result[0] for suffix, result in tests.items()})
        row.update({f'p-value-{suffix}': result[1] for suffix, result in tests.items()})

    # Convert results to DataFrame for better readability
    t_test_df = pd.DataFrame(t_test_results).T
    t_test_df = t_test_df.reset_index().rename(columns={'index': 'Feature'})
    for suffix in COMPARISONS:
        # True where the corrected Welch test finds a significant difference for that cohort pair
        t_test_df[f'Significant-{suffix}'] = t_test_df[f'p-value-{suffix}'] < corrected_alpha_value

    print(f"Number of healthy mice: {len(healthy_feature_data)}")
    print(f"Number of sick mice: {len(sick_feature_data)}")

    # Drop auxiliary feature families before collecting significant features.
    # NOTE(review): 'at' is a plain substring, so it also removes names containing e.g.
    # "ratio" or "Duration" — confirm whether a narrower pattern (' at ') was meant.
    # The Bonferroni denominator above still counts these dropped features (conservative).
    for excluded_pattern in ('at', 'std', 'angle value', 'phase value'):
        t_test_df = t_test_df[~t_test_df['Feature'].str.contains(excluded_pattern)]

    # Significant features per pair, most significant first.
    significant_by_pair = {
        suffix: t_test_df[t_test_df[f'Significant-{suffix}']].sort_values(by=f'p-value-{suffix}')['Feature'].tolist()
        for suffix in COMPARISONS
    }
    for suffix, pair_features in significant_by_pair.items():
        print(f"Number of {suffix} significant features: {len(pair_features)}")
    print()

    hs_significant_features = significant_by_pair['hs']
    hc_significant_features = significant_by_pair['hc']
    sc_significant_features = significant_by_pair['sc']
    htc_significant_features = significant_by_pair['htc']

    sc_significant_features_set = set(sc_significant_features)
    print(f"{SICK_KEY} VS {CONTROL_KEY} significantly different features: {sorted(sc_significant_features_set)}")
    print(f"\tLength: {len(sc_significant_features_set)}")

    hs_significant_features_set = set(hs_significant_features)
    print(f"{HEALTHY_KEY} VS {SICK_KEY} significantly different features: {sorted(hs_significant_features_set)}")
    print(f"\tLength: {len(hs_significant_features_set)}")

    hc_significant_features_set = set(hc_significant_features)
    print(f"{HEALTHY_KEY} VS {CONTROL_KEY} significantly different features: {sorted(hc_significant_features_set)}")
    print(f"\tLength: {len(hc_significant_features_set)}")

    htc_significant_features_set = set(htc_significant_features)
    print(f"{HEALTHY_KEY} VS {TREATMENT_CONTROL_KEY} significantly different features: {sorted(htc_significant_features_set)}")
    print(f"\tLength: {len(htc_significant_features_set)}")
    print()

    hs_and_hc_significant_features = hs_significant_features_set & hc_significant_features_set
    print(f"Features significant in both {HEALTHY_KEY} VS {SICK_KEY} and {HEALTHY_KEY} VS {CONTROL_KEY}: {sorted(hs_and_hc_significant_features)}")
    print(f"\tLength: {len(hs_and_hc_significant_features)}")

    hs_sc_significant_features = hs_significant_features_set & sc_significant_features_set
    print(f"Features significant in both {HEALTHY_KEY} VS {SICK_KEY} and {SICK_KEY} VS {CONTROL_KEY}: {sorted(hs_sc_significant_features)}")
    print(f"\tLength: {len(hs_sc_significant_features)}")

    hc_sc_significant_features = hc_significant_features_set & sc_significant_features_set
    print(f"Features significant in both {HEALTHY_KEY} VS {CONTROL_KEY} and {SICK_KEY} VS {CONTROL_KEY}: {sorted(hc_sc_significant_features)}")
    print(f"\tLength: {len(hc_sc_significant_features)}")
    print()

    hc_sc_mhs_significant_features = hc_sc_significant_features - hs_significant_features_set
    print(f"Features significant in hc and sc but not in hs: {sorted(hc_sc_mhs_significant_features)}")
    print(f"\tLength: {len(hc_sc_mhs_significant_features)}")

    hs_sc_hc_significant_features = hs_sc_significant_features & hc_significant_features_set
    print(f"Features significant in all three comparisons: {sorted(hs_sc_hc_significant_features)}")
    print(f"\tLength: {len(hs_sc_hc_significant_features)}")

    hs_mhtc_significant_features = hs_significant_features_set - htc_significant_features_set
    print(f"Features significant in {HEALTHY_KEY} VS {SICK_KEY} but not in {HEALTHY_KEY} VS {TREATMENT_CONTROL_KEY}: {sorted(hs_mhtc_significant_features)}")
    print(f"\tLength: {len(hs_mhtc_significant_features)}")

    # Interpretation of the pairwise sets (labels kept from the original analysis).
    impairment_features = hc_significant_features_set
    recovery_features = sc_significant_features_set
    drug_no_effect_features = hs_significant_features_set

    drug_less_effect_features = impairment_features & drug_no_effect_features
    drug_effect_for_recovery_features = recovery_features - drug_no_effect_features
    impaired_affected_by_drug_features = impairment_features & recovery_features
    impaired_not_affected_by_drug_features = impairment_features - recovery_features
    not_epilepsy_affected_by_drug_features = recovery_features - impairment_features

    # Keep only features that hold for every dataset processed so far.
    drug_less_effect_common_features &= drug_less_effect_features
    drug_effect_for_recovery_common_features &= drug_effect_for_recovery_features
    impaired_affected_by_drug_common_features &= impaired_affected_by_drug_features
    impaired_not_affected_by_drug_common_features &= impaired_not_affected_by_drug_features
    not_epilepsy_affected_by_drug_common_features &= not_epilepsy_affected_by_drug_features

    hs_common_features &= hs_significant_features_set
    hc_common_features &= hc_significant_features_set
    sc_common_features &= sc_significant_features_set
    htc_common_features &= htc_significant_features_set

    hs_hc_common_features &= hs_and_hc_significant_features
    hs_sc_common_features &= hs_sc_significant_features
    hc_sc_common_features &= hc_sc_significant_features

    hc_sc_mhs_common_features &= hc_sc_mhs_significant_features

    hs_hc_sc_common_features &= hs_sc_hc_significant_features
    hs_mhtc_common_features &= hs_mhtc_significant_features

print('--' * 20)
print(f"{SICK_KEY} VS {CONTROL_KEY} common significantly different features: {sorted(sc_common_features)}")
print(f"\tLength: {len(sc_common_features)}")

print(f"{HEALTHY_KEY} VS {SICK_KEY} common significantly different features: {sorted(hs_common_features)}")
print(f"\tLength: {len(hs_common_features)}")

print(f"{HEALTHY_KEY} VS {CONTROL_KEY} common significantly different features: {sorted(hc_common_features)}")
print(f"\tLength: {len(hc_common_features)}")

print(f"{HEALTHY_KEY} VS {TREATMENT_CONTROL_KEY} common significantly different features: {sorted(htc_common_features)}")
print(f"\tLength: {len(htc_common_features)}")
print()

print(f"Features significant in both {HEALTHY_KEY} VS {SICK_KEY} and {HEALTHY_KEY} VS {CONTROL_KEY}: {sorted(hs_hc_common_features)}")
print(f"\tLength: {len(hs_hc_common_features)}")

print(f"Features significant in both {HEALTHY_KEY} VS {SICK_KEY} and {SICK_KEY} VS {CONTROL_KEY}: {sorted(hs_sc_common_features)}")
print(f"\tLength: {len(hs_sc_common_features)}")

print(f"Features significant in both {HEALTHY_KEY} VS {CONTROL_KEY} and {SICK_KEY} VS {CONTROL_KEY}: {sorted(hc_sc_common_features)}")
print(f"\tLength: {len(hc_sc_common_features)}")
print()

print(f"Features significant in {HEALTHY_KEY} VS {SICK_KEY} but not in {HEALTHY_KEY} VS {TREATMENT_CONTROL_KEY}: {sorted(hs_mhtc_common_features)}")
print(f"\tLength: {len(hs_mhtc_common_features)}")

print(f"Features significant in {HEALTHY_KEY} VS {CONTROL_KEY} and {SICK_KEY} VS {CONTROL_KEY} but not in {HEALTHY_KEY} VS {SICK_KEY}: {sorted(hc_sc_mhs_common_features)}")
print(f"\tLength: {len(hc_sc_mhs_common_features)}")

print(f"Common significant features across all comparisons: {sorted(hs_hc_sc_common_features)}")
print(f"\tLength: {len(hs_hc_sc_common_features)}")
print()

print(f"Common features showing impairment due to epilepsy: {sorted(hc_common_features)}")
print(f"\tLength: {len(hc_common_features)}")

# print(f"Common features where drug has no or less effect: {sorted(drug_less_effect_common_features)}")
# print(f"\tLength: {len(drug_less_effect_common_features)}")

# print(f"Common features where drug has effect for recovery: {sorted(drug_effect_for_recovery_common_features)}")
# print(f"\tLength: {len(drug_effect_for_recovery_common_features)}")

# print(f"Common Epilepsy features affected by drug: {sorted(impaired_affected_by_drug_common_features)}")
# print(f"\tLength: {len(impaired_affected_by_drug_common_features)}")

print(f"Common Epilepsy features not affected by drug: {sorted(impaired_not_affected_by_drug_common_features)}")
print(f"\tLength: {len(impaired_not_affected_by_drug_common_features)}")

print(f"Common features not affected by Epilepsy but affected by drug: {sorted(not_epilepsy_affected_by_drug_common_features)}")
print(f"\tLength: {len(not_epilepsy_affected_by_drug_common_features)}")

print(f"Common Epilepsy features perfectly treated by drug: {sorted(hc_sc_mhs_common_features)}")
print(f"\tLength: {len(hc_sc_mhs_common_features)}")

print(f"Common Epilepsy features partially or not treated by drug: {sorted(hs_hc_sc_common_features)}")
print(f"\tLength: {len(hs_hc_sc_common_features)}")


In [None]:
# Print the features count, using the first run of the first mouse of the first group
# (nothing is printed when that mouse has no runs).
first_group_mice = next(iter(data.values()), {})
first_mouse_runs = next(iter(first_group_mice.values()), {})
if first_mouse_runs:
    first_run_df = next(iter(first_mouse_runs.values()))
    # +1 bc of index, -4 to exclude frame, forestep, hindstep, and time
    # NOTE(review): DataFrame.shape[1] already excludes the index, so the +1 assumes
    # the index should be counted as a feature — confirm.
    print(f"Number of features: {first_run_df.shape[1] + 1 - 4}")

In [None]:
## Segment all steps in the data
# `segment_all_steps` returns two sequences of per-step records (hind-steps, fore-steps).
# Each record is a dict; the keys used in this notebook are:
#     "step": step_df,
#     "group": group,
#     "mouse": mouse,
#     "run": run
segmented_hindsteps, segmented_foresteps = segment_all_steps(data)

reshaped_data = reshape_data(data) # reshape the data to have a single DataFrame for each mouse direction and run and access to the hindsteps and foresteps

## Select hind-step cohorts (SIDE_KEY[0] side only)
healthy_steps = [s for s in segmented_hindsteps if HEALTHY_KEY in s["group"] and SIDE_KEY[0] in s["mouse"]]
# NOTE(review): these dataset tests are case-sensitive substring checks, while the
# mean-feature selection lower-cases before matching — keep both selections consistent.
unhealthy_steps = [s for s in segmented_hindsteps if SICK_KEY in s["group"] and SIDE_KEY[0] in s["mouse"] and (DATASETS[2] in s["group"])]
therapy_steps = [s for s in segmented_hindsteps if SICK_KEY in s["group"] and SIDE_KEY[0] in s["mouse"] and (DATASETS[4] in s["group"])]

## Flatten every frame of every healthy step to compute the global mean/std and fit the scaler
all_healthy_arrays = [s["step"].values for s in healthy_steps]
if not all_healthy_arrays:
    raise ValueError(
        f"No healthy hind-steps matched group key {HEALTHY_KEY!r} and side {SIDE_KEY[0]!r}; "
        "cannot fit the scaler")
flat_data = np.vstack(all_healthy_arrays)

scaler = StandardScaler().fit(flat_data)

## Prepare the data for training
step_tensor, lengths = steps_to_tensor(healthy_steps, scaler)
if INPUT_DIM is None:
    # presumably (n_steps, max_len, n_features) — TODO confirm in steps_to_tensor
    INPUT_DIM = step_tensor.shape[2]

In [None]:
# Sanity check of the first segmented step and of the padded training tensor.
first_segment_shape = segmented_hindsteps[0]["step"].shape
print(f'Segmented dataframe shape: {first_segment_shape}')
print(f"Step tensor shape: {step_tensor.shape}, \nLengths shape: {lengths.shape}")

In [None]:
### Plot XY pose data
# Substrings used to pick column groups out of the step DataFrames.
HINDLIMB_KEY, FORELIMB_KEY, SPINE_KEY, TAIL_KEY = 'hindlimb', 'forelimb', 'spine', 'tail'
POSE_KEY = 'pose'
ANGLE_KEY = 'Angle'  # matched case-sensitively against raw column names
ALL_LIMBS = [HINDLIMB_KEY, FORELIMB_KEY, SPINE_KEY, TAIL_KEY]

# Template step whose columns are used to discover feature names below.
first_step_record = next(iter(segmented_hindsteps))
first_df = first_step_record["step"]

In [None]:
## Extract poses for each limb
# One list of pose columns per limb, in hindlimb / forelimb / spine / tail order.
all_limbs_keys = [
    [column for column in first_df.columns if limb_key in column and POSE_KEY in column]
    for limb_key in (HINDLIMB_KEY, FORELIMB_KEY, SPINE_KEY, TAIL_KEY)
]
hindlimb_keys, forelimb_keys, spine_keys, tail_keys = all_limbs_keys

for limb_pose_columns in all_limbs_keys:
    print(f"Limbs poses: {limb_pose_columns}")

In [None]:
## Extract angles for each limb
# One list of angle columns per limb, in hindlimb / forelimb / spine / tail order.
all_limbs_angles = [
    [column for column in first_df.columns if ANGLE_KEY in column and limb_key in column]
    for limb_key in (HINDLIMB_KEY, FORELIMB_KEY, SPINE_KEY, TAIL_KEY)
]
hindlimb_angles, forelimb_angles, spine_angles, tail_angles = all_limbs_angles

for limb_angle_columns in all_limbs_angles:
    print(f"Limbs angles: {limb_angle_columns}")

In [None]:
## Extract CoM features
# Every column whose name contains 'CoM' (case-sensitive). The original condition
# `'CoM' in key in key` was a chained comparison whose second half is always True.
CoM_features = [key for key in first_df.columns if 'CoM' in key]
print(f"CoM features: {CoM_features}")

In [None]:
## Extract relevant single point features
def _point_feature_columns(columns, body_part):
    """Return the non-angle columns whose lower-cased name contains `body_part`.

    `body_part` must already be lower case; the 'Angle' exclusion is case-sensitive.
    """
    return [key for key in columns if body_part in key.lower() and 'Angle' not in key]


single_point_feature_names = ('hindfingers', 'knee', 'ankle', 'hip', 'hindpaw', 'forepaw', 'shoulder', 'head')
(hindfinger_features, knee_features, ankle_features, hip_features,
 hindpaw_features, forepaw_features, shoulder_features, head_features) = (
    _point_feature_columns(first_df.columns, name) for name in single_point_feature_names
)

# Same order as single_point_feature_names, so zip() pairs each name with its own columns
# (previously hip_features appeared twice and the 'shoulder' entry plotted hip data).
single_point_features = (hindfinger_features, knee_features, ankle_features, hip_features,
                         hindpaw_features, forepaw_features, shoulder_features, head_features)
for features in single_point_features:
    print(f"Single point features: {features}")

In [None]:
## Save all angle plots
# NOTE(review): file names say 'SCI' although the cohorts compared are pre/post steps
# from the Pain data — rename if that is unintended.
savedir = os.path.join(FIGURES_DIR, 'Angles')
for limb_name, angle_columns in zip(ALL_LIMBS, all_limbs_angles):
    healthy_time_locked_path = os.path.join(savedir, f'Angles_healthy_time_locked_{limb_name}.svg')
    compare_time_locked_path = os.path.join(savedir, f'Angles_healthy_vs_SCI_time_locked_{limb_name}.svg')
    compare_space_locked_path = os.path.join(savedir, f'Angles_healthy_vs_SCI_space_locked_{limb_name}.svg')

    # Phase locked, healthy cohort only
    plot_phase_aligned_average_single(healthy_steps, feature_keys=angle_columns, figure_path=healthy_time_locked_path)
    # Phase locked, healthy vs unhealthy
    compare_phase_aligned_average_single(healthy_steps, unhealthy_steps, feature_keys=angle_columns, figure_path=compare_time_locked_path)
    # Space locked, healthy vs unhealthy
    compare_spatial_angle_progression_over_time(healthy_steps, unhealthy_steps, feature_keys=angle_columns, figure_path=compare_space_locked_path)

In [None]:
## Save all CoM plots
savedir = os.path.join(FIGURES_DIR, 'CoM')
# Reference columns for the space-locked plot.
# NOTE(review): 'rhindlimb lHindfingers' mixes r/l prefixes while the cohorts are
# SIDE_KEY[0]-side steps — confirm this is the intended reference marker.
com_space_lock_keys = {
    'length_key': 'rhindlimb lHindfingers - X pose (m)',
    'height_key': 'rhindlimb lHindfingers - Y pose (m)',
}

# Phase locked, healthy cohort only
plot_phase_aligned_average_xy(healthy_steps, feature_keys=CoM_features,
                              figure_path=os.path.join(savedir, 'CoM_healthy_time_locked.svg'))
# Phase locked, healthy vs unhealthy
compare_phase_aligned_average_xy(healthy_steps, unhealthy_steps, feature_keys=CoM_features,
                                 figure_path=os.path.join(savedir, 'CoM_healthy_vs_SCI_time_locked.svg'))
# Space locked, healthy vs unhealthy
compare_spatial_progression_xy_over_time(healthy_steps, unhealthy_steps, feature_keys=CoM_features,
                                         figure_path=os.path.join(savedir, 'CoM_healthy_vs_SCI_space_locked.svg'),
                                         **com_space_lock_keys)

In [None]:
## Save all XY pose plots
savedir = os.path.join(FIGURES_DIR, 'Kinematics')
# Reference columns for the space-locked plots (same marker as the CoM plots).
kinematics_space_lock_keys = {
    'length_key': 'rhindlimb lHindfingers - X pose (m)',
    'height_key': 'rhindlimb lHindfingers - Y pose (m)',
}

for body_part, part_columns in zip(single_point_feature_names, single_point_features):
    # Phase locked, healthy cohort only
    plot_phase_aligned_average_xy(healthy_steps, feature_keys=part_columns,
                                  figure_path=os.path.join(savedir, f'{body_part}_healthy_time_locked.svg'))
    # Phase locked, healthy vs unhealthy
    compare_phase_aligned_average_xy(healthy_steps, unhealthy_steps, feature_keys=part_columns,
                                     figure_path=os.path.join(savedir, f'{body_part}_healthy_vs_SCI_time_locked.svg'))
    # Space locked, healthy vs unhealthy
    compare_spatial_progression_xy_over_time(healthy_steps, unhealthy_steps, feature_keys=part_columns,
                                             figure_path=os.path.join(savedir, f'{body_part}_healthy_vs_SCI_space_locked.svg'),
                                             **kinematics_space_lock_keys)