# Analyzing time-dependent data

In [None]:
import os
import pandas as pd
import numpy as np
import torch
import plotly.express as px
import plotly.graph_objs as go
import plotly.colors
from collections import defaultdict

from sklearn.preprocessing import StandardScaler
from umap import UMAP

from src.ModelUtils import (load_model,
                            train_model,
                            save_model_losses,
                            load_and_plot_losses)
from src.DataUtils import (load_data_from_directory,
                           segment_all_steps,
                           steps_to_tensor,
                           reshape_data)
from src.Plot import (plot_step_trajectory, plot_trajectory,
                      plot_animated_trajectory, plot_trajectory_with_joint_traces,
                      compare_step_features_in_batches, plot_umap_from_step,
                      plot_umap_all_steps, plot_mean_spatial_trajectory,
                      compare_phase_aligned_average_single, compare_phase_aligned_average_xy,
                      plot_phase_aligned_average_single, plot_phase_aligned_average_xy,
                      plot_trajectory_with_joint_trace,
                      compare_spatial_progression_over_time, compare_spatial_progression_xy_over_time,
                      compare_spatial_angle_progression_over_time)

In [None]:
## Paths and dataset selectors
FIGURES_DIR = './figures/Pain'          # root folder under which figure sub-folders are built
MODELS_DIR = './src/models'
DATA_DIR = './csv/Pain_Plot_Features'   # handed to load_data_from_directory
DATASETS = ['A_', 'B_', 'C_', 'D_', 'E_']
HEALTHY_KEY = 'pre'                     # substring searched for in a step's "group"
SICK_KEY = 'post'                       # substring searched for in a step's "group"
SIDE_KEY = ('left', 'right')            # substrings searched for in a step's "mouse"

## Model hyperparameters
INPUT_DIM = None                        # None -> taken from the step tensor once it exists
HIDDEN_DIM = 64
LATENT_DIM = 16
BATCH_SIZE = 32
NUM_EPOCHS = 500
LR = 0.001

## Early stopping
PATIENCE = 50                           # epochs to wait for an improvement before stopping
MIN_DELTA = 0.0001                      # smallest change that counts as an improvement

# Path of a saved model; None means no saved model is selected.
# Example: os.path.join(MODELS_DIR, 'lstm_VAE_no_first_last_20250609_121841.pt')
BEST_MODEL_PATH = None

## Plot styling
SCATTER_SIZE = 6
SCATTER_LINE_WIDTH = 1
SCATTER_SYMBOL = 'circle'
LEGEND_FONT_SIZE = 18
TITLE_FONT_SIZE = 24
AXIS_FONT_SIZE = 16
AXIS_TITLE_FONT_SIZE = 20

# Read everything found under DATA_DIR
data = load_data_from_directory(DATA_DIR)

# Summarise the nesting: group -> mouse/direction -> runs (one DataFrame per run)
for group_name, group_mice in data.items():
    print(f"Group: {group_name}")
    for mouse_name, mouse_runs in group_mice.items():
        run_shapes = [run_df.shape for run_df in mouse_runs.values()]
        print(f"\t{mouse_name}: {len(mouse_runs)} runs with shapes: {run_shapes}")

In [None]:
# Report the feature count once, read from the first run of the first mouse
# of the first group (each slice keeps at most the first item).
for datagroup, mice in list(data.items())[:1]:
    for mouse_direction, runs in list(mice.items())[:1]:
        for run, df in list(runs.items())[:1]:
            # +1 because of the index, -4 to exclude frame, forestep, hindstep, and time
            feature_count = df.shape[1] + 1 - 4
            print(f"Number of features: {feature_count}")

In [None]:
## Split the recordings into individual steps
# segment_all_steps yields two collections, one for hind steps and one for
# fore steps. Every entry is a dict with the keys:
#     "step":  step_df
#     "group": group
#     "mouse": mouse
#     "run":   run
segmented_hindsteps, segmented_foresteps = segment_all_steps(data)

# Reshape the data to have a single DataFrame for each mouse direction and run,
# with access to the hindsteps and foresteps
reshaped_data = reshape_data(data)

## Select the step subsets used below
reference_side = SIDE_KEY[0]
healthy_steps = [
    entry for entry in segmented_hindsteps
    if HEALTHY_KEY in entry["group"] and reference_side in entry["mouse"]
]
unhealthy_steps = [
    entry for entry in segmented_hindsteps
    if SICK_KEY in entry["group"]
    and reference_side in entry["mouse"]
    and DATASETS[2] in entry["group"]
]

## Fit the scaler on the healthy steps only
# Stack every healthy step row-wise so one global mean/std is computed.
all_healthy_arrays = [entry["step"].values for entry in healthy_steps]
flat_data = np.vstack(all_healthy_arrays)
scaler = StandardScaler().fit(flat_data)

## Build the training tensor
step_tensor, lengths = steps_to_tensor(healthy_steps, scaler)
if INPUT_DIM is None:
    # Not set by hand: use the size of the tensor's third axis
    INPUT_DIM = step_tensor.shape[2]

In [None]:
# Sanity check: shape of the first segmented hind step and of the training tensor
first_step_shape = segmented_hindsteps[0]["step"].shape
print(f'Segmented dataframe shape: {first_step_shape}')
#print(f'Averaged dataframe shape: {averaged_hindsteps[0]["step"].shape}')
tensor_shape = step_tensor.shape
lengths_shape = lengths.shape
print(f"Step tensor shape: {tensor_shape}, \nLengths shape: {lengths_shape}")

In [None]:
### Keywords used to select feature columns by name
# Feature type
POSE_KEY = 'pose'
ANGLE_KEY = 'Angle'

# Body part
HINDLIMB_KEY = 'hindlimb'
FORELIMB_KEY = 'forelimb'
SPINE_KEY = 'spine'
TAIL_KEY = 'tail'
ALL_LIMBS = [HINDLIMB_KEY, FORELIMB_KEY, SPINE_KEY, TAIL_KEY]

# The first hind step supplies the column names the selections are made from
first_step_entry = next(iter(segmented_hindsteps))
first_df = first_step_entry["step"]

In [None]:
## Collect the pose columns of every limb
def _pose_columns(limb_key):
    """Return the columns of first_df whose name contains both limb_key and POSE_KEY."""
    return [column for column in first_df.columns
            if limb_key in column and POSE_KEY in column]


hindlimb_keys = _pose_columns(HINDLIMB_KEY)
forelimb_keys = _pose_columns(FORELIMB_KEY)
spine_keys = _pose_columns(SPINE_KEY)
tail_keys = _pose_columns(TAIL_KEY)

# Same order as ALL_LIMBS
all_limbs_keys = [hindlimb_keys, forelimb_keys, spine_keys, tail_keys]
for limb in all_limbs_keys:
    print(f"Limbs poses: {limb}")

In [None]:
## Collect the angle columns of every limb
# Narrow down to the angle columns once, then split them per limb.
angle_columns = [column for column in first_df.columns if ANGLE_KEY in column]

hindlimb_angles = [column for column in angle_columns if HINDLIMB_KEY in column]
forelimb_angles = [column for column in angle_columns if FORELIMB_KEY in column]
spine_angles = [column for column in angle_columns if SPINE_KEY in column]
tail_angles = [column for column in angle_columns if TAIL_KEY in column]

# Same order as ALL_LIMBS
all_limbs_angles = [hindlimb_angles, forelimb_angles, spine_angles, tail_angles]
for limb in all_limbs_angles:
    print(f"Limbs angles: {limb}")

In [None]:
## Extract CoM features
# Keep every column of first_df whose name contains 'CoM'.
CoM_features = [key for key in first_df.columns if 'CoM' in key]
print(f"CoM features: {CoM_features}")

In [None]:
## Extract relevant single point features
def _point_columns(point_name):
    """Return the non-angle columns of first_df for one tracked point.

    point_name is matched against the lower-cased column name, so the match is
    case-insensitive; columns containing 'Angle' are left out.
    """
    return [key for key in first_df.columns
            if point_name in key.lower() and 'Angle' not in key]


hindfinger_features = _point_columns('hindfingers')
knee_features = _point_columns('knee')
ankle_features = _point_columns('ankle')
hip_features = _point_columns('hip')
shoulder_features = _point_columns('shoulder')
hindpaw_features = _point_columns('hindpaw')
forepaw_features = _point_columns('forepaw')
head_features = _point_columns('head')

# These two tuples are zipped together when the XY pose plots are saved, so
# entry i of single_point_features must hold the columns for name i.
single_point_feature_names = ('hindfingers', 'knee', 'ankle', 'hip',
                              'hindpaw', 'forepaw', 'shoulder', 'head')
single_point_features = (hindfinger_features, knee_features, ankle_features, hip_features,
                         hindpaw_features, forepaw_features, shoulder_features, head_features)
for features in single_point_features:
    print(f"Single point features: {features}")

In [None]:
## Save all angle plots
savedir = os.path.join(FIGURES_DIR, 'Angles')
# Make sure the output folder exists before any figure is written to it.
# NOTE(review): the plotting helpers are defined elsewhere; this is a no-op if
# they already create the folder themselves.
os.makedirs(savedir, exist_ok=True)

for limb, angles in zip(ALL_LIMBS, all_limbs_angles):
    # Phase locked, one state
    plot_phase_aligned_average_single(
        healthy_steps,
        feature_keys=angles,
        figure_path=os.path.join(savedir, f'Angles_healthy_time_locked_{limb}.svg'),
    )
    # Phase locked, two states
    compare_phase_aligned_average_single(
        healthy_steps,
        unhealthy_steps,
        feature_keys=angles,
        figure_path=os.path.join(savedir, f'Angles_healthy_vs_SCI_time_locked_{limb}.svg'),
    )
    # Space locked, two states
    compare_spatial_angle_progression_over_time(
        healthy_steps,
        unhealthy_steps,
        feature_keys=angles,
        figure_path=os.path.join(savedir, f'Angles_healthy_vs_SCI_space_locked_{limb}.svg'),
    )

In [None]:
## Save all CoM plots
savedir = os.path.join(FIGURES_DIR, 'CoM')
# Make sure the output folder exists before any figure is written to it.
# NOTE(review): the plotting helpers are defined elsewhere; this is a no-op if
# they already create the folder themselves.
os.makedirs(savedir, exist_ok=True)

# Columns passed as length_key / height_key to the space-locked comparison
length_key = 'rhindlimb lHindfingers - X pose (m)'
height_key = 'rhindlimb lHindfingers - Y pose (m)'

# Phase locked, one state
plot_phase_aligned_average_xy(
    healthy_steps,
    feature_keys=CoM_features,
    figure_path=os.path.join(savedir, 'CoM_healthy_time_locked.svg'),
)
# Phase locked, two states
compare_phase_aligned_average_xy(
    healthy_steps,
    unhealthy_steps,
    feature_keys=CoM_features,
    figure_path=os.path.join(savedir, 'CoM_healthy_vs_SCI_time_locked.svg'),
)
# Space locked, two states
compare_spatial_progression_xy_over_time(
    healthy_steps,
    unhealthy_steps,
    feature_keys=CoM_features,
    length_key=length_key,
    height_key=height_key,
    figure_path=os.path.join(savedir, 'CoM_healthy_vs_SCI_space_locked.svg'),
)

In [None]:
## Save all XY pose plots
savedir = os.path.join(FIGURES_DIR, 'Kinematics')
# Make sure the output folder exists before any figure is written to it.
# NOTE(review): the plotting helpers are defined elsewhere; this is a no-op if
# they already create the folder themselves.
os.makedirs(savedir, exist_ok=True)

# Columns passed as length_key / height_key to the space-locked comparison
length_key = 'rhindlimb lHindfingers - X pose (m)'
height_key = 'rhindlimb lHindfingers - Y pose (m)'

for name, feature in zip(single_point_feature_names, single_point_features):
    # Phase locked, one state
    plot_phase_aligned_average_xy(
        healthy_steps,
        feature_keys=feature,
        figure_path=os.path.join(savedir, f'{name}_healthy_time_locked.svg'),
    )
    # Phase locked, two states
    compare_phase_aligned_average_xy(
        healthy_steps,
        unhealthy_steps,
        feature_keys=feature,
        figure_path=os.path.join(savedir, f'{name}_healthy_vs_SCI_time_locked.svg'),
    )
    # Space locked, two states
    compare_spatial_progression_xy_over_time(
        healthy_steps,
        unhealthy_steps,
        feature_keys=feature,
        length_key=length_key,
        height_key=height_key,
        figure_path=os.path.join(savedir, f'{name}_healthy_vs_SCI_space_locked.svg'),
    )