In [1]:
import os
import pandas as pd
import numpy as np
import torch
import plotly.express as px
import plotly.graph_objs as go
import plotly.colors
from collections import defaultdict
from tqdm import tqdm

from sklearn.preprocessing import StandardScaler
from umap import UMAP

from src.ModelUtils import (load_model,
                            train_model,
                            save_model_losses,
                            load_and_plot_losses)
from src.DataUtils import (load_data_from_directory,
                           segment_all_steps,
                           steps_to_tensor,
                           reshape_data)
from src.Plot import (plot_step_trajectory, plot_trajectory,
                      plot_animated_trajectory, plot_trajectory_with_joint_traces,
                      compare_step_features_in_batches, plot_umap_from_step,
                      plot_umap_all_steps, plot_mean_spatial_trajectory,
                      compare_phase_aligned_average_single, compare_phase_aligned_average_xy,
                      plot_phase_aligned_average_single, plot_phase_aligned_average_xy,
                      plot_trajectory_with_joint_trace,
                      compare_spatial_progression_over_time, compare_spatial_progression_xy_over_time,
                      compare_spatial_angle_progression_over_time)

In [2]:
## Directories and constants
FIGURES_DIR = './figures/Epilepsy'
MODELS_DIR = './src/models'
DATA_DIR = './csv/EpilepsyFeatures/TimeFeatures'
MEAN_FEATURES_FILE = './csv/SCI_pre_acute_hindlimb_mouse_features_2025-06-06_15-11-59.csv'
HEALTHY_KEY = 'MUTreated'
SICK_KEY = 'MUNonTreated'
SIDE_KEY = ('Left','Right')

## Hyperparameters and early stopping
INPUT_DIM = None
HIDDEN_DIM = 64
LATENT_DIM = 16
BATCH_SIZE = 32
NUM_EPOCHS = 5000
LR = 1e-3
PATIENCE = 100 # number of epochs to wait for improvement before stopping
MIN_DELTA = 1e-4 # minimum change to qualify as an improvement
BEST_MODEL_PATH = os.path.join(MODELS_DIR, 'lstm_VAE_t_SCI_20250609_200452.pt')
BEST_LOSS_PATH = os.path.join(FIGURES_DIR, 'SCI_losses.csv')

## Plot constants
SCATTER_SIZE = 6
SCATTER_LINE_WIDTH = 1
SCATTER_SYMBOL = 'circle'
LEGEND_FONT_SIZE = 18
TITLE_FONT_SIZE = 24
AXIS_FONT_SIZE = 16
AXIS_TITLE_FONT_SIZE = 20

# Load the data
data = load_data_from_directory(DATA_DIR)

# Print the data structure
for datagroup, mice in data.items():
    print(f"Group: {datagroup}")
    for mouse_direction, runs in mice.items():
        print(f"\t{mouse_direction}: {len(runs)} runs with shapes: {[df.shape for df in runs.values()]}")

Group: MUCX_Epilepsy
	Mouse1Cage1_Left: 3 runs with shapes: [(217, 247), (139, 247), (172, 247)]
	Mouse1Cage1_Right: 3 runs with shapes: [(292, 247), (191, 247), (193, 247)]
	Mouse1Cage3_Left: 3 runs with shapes: [(133, 247), (124, 247), (135, 247)]
	Mouse1Cage3_Right: 3 runs with shapes: [(136, 247), (158, 247), (104, 247)]
	Mouse2Cage1_Left: 3 runs with shapes: [(161, 247), (235, 247), (139, 247)]
	Mouse2Cage1_Right: 3 runs with shapes: [(141, 247), (171, 247), (158, 247)]
	Mouse2Cage2_Left: 4 runs with shapes: [(195, 247), (112, 247), (135, 247), (241, 247)]
	Mouse2Cage2_Right: 5 runs with shapes: [(171, 247), (161, 247), (158, 247), (235, 247), (181, 247)]
	Mouse2Cage3_Left: 3 runs with shapes: [(171, 247), (201, 247), (135, 247)]
	Mouse2Cage3_Right: 3 runs with shapes: [(161, 247), (115, 247), (100, 247)]
	Mouse3Cage2_Left: 4 runs with shapes: [(103, 247), (96, 247), (261, 247), (221, 247)]
	Mouse3Cage2_Right: 4 runs with shapes: [(299, 247), (174, 247), (107, 247), (162, 247)]
	M

In [3]:
# Print the features count
for datagroup, mice in data.items():
    for mouse_direction, runs in mice.items():
        for run, df in runs.items():
            print(f"Number of features: {df.shape[1] + 1 - 4}") # +1 bc of index, -4 to exclude frame, forestep, hindstep, and time
            break
        break
    break

Number of features: 244


In [4]:
## Segment all steps in the data
# This will create two dictionaries: segmented_hindsteps and segmented_foresteps
# Each dictionary will contain segmented steps for each mouse direction and run
    # "step": step_df,
    # "group": group,
    # "mouse": mouse,
    # "run": run
segmented_hindsteps, segmented_foresteps = segment_all_steps(data)

reshaped_data = reshape_data(data) # reshape the data to have a single DataFrame for each mouse direction and run and access to the hindsteps and foresteps

## Flatten all steps into a single array to compute global mean/std and get the scaler
all_healthy_arrays = [step_dict["step"].values for step_dict in segmented_hindsteps if HEALTHY_KEY in step_dict["group"] and SIDE_KEY[0] in step_dict["mouse"]]
flat_data = np.vstack(all_healthy_arrays)

scaler = StandardScaler()
scaler = scaler.fit(flat_data)

## Prepare the data for training
healthy_steps = [s for s in segmented_hindsteps if HEALTHY_KEY in s["group"] and SIDE_KEY[0] in s["mouse"]]
unhealthy_steps = [s for s in segmented_hindsteps if SICK_KEY in s["group"] and SIDE_KEY[0] in s["mouse"]]
step_tensor, lengths = steps_to_tensor(healthy_steps, scaler)
if INPUT_DIM is None:
    INPUT_DIM = step_tensor.shape[2]

In [5]:
print(f'Segmented dataframe shape: {segmented_hindsteps[0]["step"].shape}')
#print(f'Averaged dataframe shhape: {averaged_hindsteps[0]["step"].shape}')
print(f"Step tensor shape: {step_tensor.shape}, \nLengths shape: {lengths.shape}")

Segmented dataframe shape: (60, 244)
Step tensor shape: torch.Size([91, 93, 244]), 
Lengths shape: torch.Size([91])


In [6]:
### Plot XY pose data
HINDLIMB_KEY = 'hindlimb'
FORELIMB_KEY = 'forelimb'
SPINE_KEY = 'spine'
TAIL_KEY = 'tail'
POSE_KEY = 'pose'
ANGLE_KEY = 'Angle'
PHASE_KEY = 'Phase'
ALL_LIMBS = [HINDLIMB_KEY, FORELIMB_KEY, SPINE_KEY, TAIL_KEY]
first_df = next(iter(segmented_hindsteps))["step"]

In [7]:
## Extract poses for each limb
hindlimb_keys = [key for key in first_df.columns if HINDLIMB_KEY in key and POSE_KEY in key]
forelimb_keys = [key for key in first_df.columns if FORELIMB_KEY in key and POSE_KEY in key]
spine_keys = [key for key in first_df.columns if SPINE_KEY in key and POSE_KEY in key]
tail_keys = [key for key in first_df.columns if TAIL_KEY in key and POSE_KEY in key]

all_limbs_keys = [hindlimb_keys, forelimb_keys, spine_keys, tail_keys]
for limb in all_limbs_keys:
    print(f"Limbs poses: {limb}")

Limbs poses: ['rhindlimb hip - X pose (m)', 'rhindlimb hip - Y pose (m)', 'rhindlimb knee - X pose (m)', 'rhindlimb knee - Y pose (m)', 'rhindlimb ankle - X pose (m)', 'rhindlimb ankle - Y pose (m)', 'rhindlimb lHindpaw - X pose (m)', 'rhindlimb lHindpaw - Y pose (m)', 'rhindlimb lHindfingers - X pose (m)', 'rhindlimb lHindfingers - Y pose (m)']
Limbs poses: ['rforelimb shoulder - X pose (m)', 'rforelimb shoulder - Y pose (m)', 'rforelimb elbow - X pose (m)', 'rforelimb elbow - Y pose (m)', 'rforelimb wrist - X pose (m)', 'rforelimb wrist - Y pose (m)', 'rforelimb lForepaw - X pose (m)', 'rforelimb lForepaw - Y pose (m)']
Limbs poses: ['spine head - X pose (m)', 'spine head - Y pose (m)', 'spine spine 1 - X pose (m)', 'spine spine 1 - Y pose (m)', 'spine spine 2 - X pose (m)', 'spine spine 2 - Y pose (m)', 'spine spine 3 - X pose (m)', 'spine spine 3 - Y pose (m)', 'spine spine 4 - X pose (m)', 'spine spine 4 - Y pose (m)', 'spine base - X pose (m)', 'spine base - Y pose (m)']
Limbs po

In [8]:
## Extract angles for each limb
hindlimb_angles = [key for key in first_df.columns if ANGLE_KEY in key and HINDLIMB_KEY in key]
forelimb_angles = [key for key in first_df.columns if ANGLE_KEY in key and FORELIMB_KEY in key]
spine_angles = [key for key in first_df.columns if ANGLE_KEY in key and SPINE_KEY in key]
tail_angles = [key for key in first_df.columns if ANGLE_KEY in key and TAIL_KEY in key]

all_limbs_angles = [hindlimb_angles, forelimb_angles, spine_angles, tail_angles]
for limb in all_limbs_angles:
    print(f"Limbs angles: {limb}")

Limbs angles: ['Angle - rhindlimb - hip (°)', 'Angle - rhindlimb - knee (°)', 'Angle - rhindlimb - ankle (°)', 'Angle - rhindlimb - lHindpaw (°)', 'Angle - rhindlimb - lHindfingers (°)', 'Angle velocity - rhindlimb - hip (rad/s)', 'Angle velocity - rhindlimb - knee (rad/s)', 'Angle velocity - rhindlimb - ankle (rad/s)', 'Angle velocity - rhindlimb - lHindpaw (rad/s)', 'Angle velocity - rhindlimb - lHindfingers (rad/s)', 'Angle acceleration - rhindlimb - hip (rad/s^2)', 'Angle acceleration - rhindlimb - knee (rad/s^2)', 'Angle acceleration - rhindlimb - ankle (rad/s^2)', 'Angle acceleration - rhindlimb - lHindpaw (rad/s^2)', 'Angle acceleration - rhindlimb - lHindfingers (rad/s^2)']
Limbs angles: ['Angle - rforelimb - shoulder (°)', 'Angle - rforelimb - elbow (°)', 'Angle - rforelimb - wrist (°)', 'Angle - rforelimb - lForepaw (°)', 'Angle velocity - rforelimb - shoulder (rad/s)', 'Angle velocity - rforelimb - elbow (rad/s)', 'Angle velocity - rforelimb - wrist (rad/s)', 'Angle velocity

In [9]:
## Extract phases for each limb
hindlimb_phase = [key for key in first_df.columns if PHASE_KEY in key and HINDLIMB_KEY in key]
forelimb_phase = [key for key in first_df.columns if PHASE_KEY in key and FORELIMB_KEY in key]
spine_phase = [key for key in first_df.columns if PHASE_KEY in key and SPINE_KEY in key]
tail_phase = [key for key in first_df.columns if PHASE_KEY in key and TAIL_KEY in key]

all_limbs_phase = [hindlimb_phase, forelimb_phase, spine_phase, tail_phase]
for limb in all_limbs_phase:
    print(f"Limbs phases: {limb}")

Limbs phases: ['Phase - rhindlimb - hip (°)', 'Phase - rhindlimb - knee (°)', 'Phase - rhindlimb - ankle (°)', 'Phase - rhindlimb - lHindpaw (°)', 'Phase velocity - rhindlimb - hip (rad/s)', 'Phase velocity - rhindlimb - knee (rad/s)', 'Phase velocity - rhindlimb - ankle (rad/s)', 'Phase velocity - rhindlimb - lHindpaw (rad/s)', 'Phase acceleration - rhindlimb - hip (rad/s^2)', 'Phase acceleration - rhindlimb - knee (rad/s^2)', 'Phase acceleration - rhindlimb - ankle (rad/s^2)', 'Phase acceleration - rhindlimb - lHindpaw (rad/s^2)']
Limbs phases: ['Phase - rforelimb - shoulder (°)', 'Phase - rforelimb - elbow (°)', 'Phase - rforelimb - wrist (°)', 'Phase velocity - rforelimb - shoulder (rad/s)', 'Phase velocity - rforelimb - elbow (rad/s)', 'Phase velocity - rforelimb - wrist (rad/s)', 'Phase acceleration - rforelimb - shoulder (rad/s^2)', 'Phase acceleration - rforelimb - elbow (rad/s^2)', 'Phase acceleration - rforelimb - wrist (rad/s^2)']
Limbs phases: ['Phase - spine - head (°)', '

In [10]:
## Extract CoM features
CoM_features = [key for key in first_df.columns if 'CoM' in key in key]
print(f"CoM features: {CoM_features}")

CoM features: ['rforelimb - CoM X (m)', 'rforelimb - CoM Y (m)', 'rhindlimb - CoM X (m)', 'rhindlimb - CoM Y (m)', 'spine - CoM X (m)', 'spine - CoM Y (m)', 'tail - CoM X (m)', 'tail - CoM Y (m)']


In [11]:
## Extract relevant single point features
hindfinger_features = [key for key in first_df.columns if 'hindfingers' in key.lower() and 'Angle' not in key]
hindpaw_features = [key for key in first_df.columns if 'hindpaw' in key.lower() and 'Angle' not in key]
forepaw_features = [key for key in first_df.columns if 'forepaw' in key.lower() and 'Angle' not in key]
head_features = [key for key in first_df.columns if 'head' in key.lower() and 'Angle' not in key]

single_point_features = [hindfinger_features, hindpaw_features, forepaw_features, head_features]
single_point_feature_names = ['hindfingers', 'hindpaw', 'forepaw', 'head']
for features in single_point_features:
    print(f"Single point features: {features}")

Single point features: ['rhindlimb lHindfingers - X pose (m)', 'rhindlimb lHindfingers - X velocity (m/s)', 'rhindlimb lHindfingers - X acceleration (m/s^2)', 'rhindlimb lHindfingers - X jerk (m/s^3)', 'rhindlimb lHindfingers - Y pose (m)', 'rhindlimb lHindfingers - Y velocity (m/s)', 'rhindlimb lHindfingers - Y acceleration (m/s^2)', 'rhindlimb lHindfingers - Y jerk (m/s^3)']
Single point features: ['rhindlimb lHindpaw - X pose (m)', 'rhindlimb lHindpaw - X velocity (m/s)', 'rhindlimb lHindpaw - X acceleration (m/s^2)', 'rhindlimb lHindpaw - X jerk (m/s^3)', 'rhindlimb lHindpaw - Y pose (m)', 'rhindlimb lHindpaw - Y velocity (m/s)', 'rhindlimb lHindpaw - Y acceleration (m/s^2)', 'rhindlimb lHindpaw - Y jerk (m/s^3)', 'Phase - rhindlimb - lHindpaw (°)', 'Phase velocity - rhindlimb - lHindpaw (rad/s)', 'Phase acceleration - rhindlimb - lHindpaw (rad/s^2)']
Single point features: ['rforelimb lForepaw - X pose (m)', 'rforelimb lForepaw - X velocity (m/s)', 'rforelimb lForepaw - X accelera

In [12]:
## Save all angle plots
savedir = os.path.join(FIGURES_DIR, 'Angles')
# Make sure the directory exists
os.makedirs(savedir, exist_ok=True)

for limb, angles in tqdm(zip(ALL_LIMBS, all_limbs_angles), total=len(ALL_LIMBS), desc="Processing limbs"):
    tqdm.write(f"Processing limb: {limb}")
    tqdm.write(f"Plotting angles...")
    plot_phase_aligned_average_single(healthy_steps, feature_keys=angles, figure_path=os.path.join(savedir, f'Angles_healthy_time_locked_{limb}.svg')) # Plot phase locked for one state
    tqdm.write(f"Plotting phase aligned average for healthy steps...")
    compare_phase_aligned_average_single(healthy_steps, unhealthy_steps, feature_keys=angles, figure_path=os.path.join(savedir, f'Angles_healthy_vs_SCI_time_locked_{limb}.svg')) # Plot phase locked for two states
    tqdm.write(f"Plotting phase aligned average for healthy vs unhealthy steps...")
    compare_spatial_angle_progression_over_time(healthy_steps, unhealthy_steps, feature_keys=angles, figure_path=os.path.join(savedir, f'Angles_healthy_vs_SCI_space_locked_{limb}.svg')) # Plot space locked for two states

Processing limbs:   0%|          | 0/4 [00:00<?, ?it/s]

Processing limb: hindlimb
Plotting angles...


Processing limbs:   0%|          | 0/4 [00:06<?, ?it/s]

Plotting phase aligned average for healthy steps...


Processing limbs:   0%|          | 0/4 [00:13<?, ?it/s]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing limbs:  25%|██▌       | 1/4 [00:19<00:59, 19.92s/it]

Processing limb: forelimb
Plotting angles...


Processing limbs:  25%|██▌       | 1/4 [00:22<00:59, 19.92s/it]

Plotting phase aligned average for healthy steps...


Processing limbs:  25%|██▌       | 1/4 [00:27<00:59, 19.92s/it]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing limbs:  50%|█████     | 2/4 [00:32<00:31, 15.81s/it]

Processing limb: spine
Plotting angles...


Processing limbs:  50%|█████     | 2/4 [00:36<00:31, 15.81s/it]

Plotting phase aligned average for healthy steps...


Processing limbs:  50%|█████     | 2/4 [00:44<00:31, 15.81s/it]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing limbs:  75%|███████▌  | 3/4 [00:52<00:17, 17.78s/it]

Processing limb: tail
Plotting angles...


Processing limbs:  75%|███████▌  | 3/4 [00:56<00:17, 17.78s/it]

Plotting phase aligned average for healthy steps...


Processing limbs:  75%|███████▌  | 3/4 [01:03<00:17, 17.78s/it]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing limbs: 100%|██████████| 4/4 [01:09<00:00, 17.46s/it]


In [13]:
## Save all phase plots
savedir = os.path.join(FIGURES_DIR, 'Phases')
# Make sure the directory exists
os.makedirs(savedir, exist_ok=True)

for limb, phases in tqdm(zip(ALL_LIMBS, all_limbs_phase), total=len(ALL_LIMBS), desc="Processing limbs"):
    tqdm.write(f"Processing limb: {limb}")
    tqdm.write(f"Plotting phases...")
    plot_phase_aligned_average_single(healthy_steps, feature_keys=phases, figure_path=os.path.join(savedir, f'Phases_healthy_time_locked_{limb}.svg')) # Plot phase locked for one state
    tqdm.write(f"Plotting phase aligned average for healthy steps...")
    compare_phase_aligned_average_single(healthy_steps, unhealthy_steps, feature_keys=phases, figure_path=os.path.join(savedir, f'Phases_healthy_vs_SCI_time_locked_{limb}.svg')) # Plot phase locked for two states
    tqdm.write(f"Plotting phase aligned average for healthy vs unhealthy steps...")
    compare_spatial_angle_progression_over_time(healthy_steps, unhealthy_steps, feature_keys=angles, figure_path=os.path.join(savedir, f'Phases_healthy_vs_SCI_space_locked_{limb}.svg')) # Plot space locked for two states

Processing limbs:   0%|          | 0/4 [00:00<?, ?it/s]

Processing limb: hindlimb
Plotting phases...


Processing limbs:   0%|          | 0/4 [00:02<?, ?it/s]

Plotting phase aligned average for healthy steps...


Processing limbs:   0%|          | 0/4 [00:08<?, ?it/s]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing limbs:  25%|██▌       | 1/4 [00:14<00:43, 14.66s/it]

Processing limb: forelimb
Plotting phases...


Processing limbs:  25%|██▌       | 1/4 [00:17<00:43, 14.66s/it]

Plotting phase aligned average for healthy steps...


Processing limbs:  25%|██▌       | 1/4 [00:20<00:43, 14.66s/it]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing limbs:  50%|█████     | 2/4 [00:27<00:27, 13.61s/it]

Processing limb: spine
Plotting phases...


Processing limbs:  50%|█████     | 2/4 [00:31<00:27, 13.61s/it]

Plotting phase aligned average for healthy steps...


Processing limbs:  50%|█████     | 2/4 [00:37<00:27, 13.61s/it]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing limbs:  75%|███████▌  | 3/4 [00:44<00:14, 14.93s/it]

Processing limb: tail
Plotting phases...


Processing limbs:  75%|███████▌  | 3/4 [00:46<00:14, 14.93s/it]

Plotting phase aligned average for healthy steps...


Processing limbs:  75%|███████▌  | 3/4 [00:52<00:14, 14.93s/it]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing limbs: 100%|██████████| 4/4 [00:58<00:00, 14.68s/it]


In [14]:
## Save all CoM plots
savedir = os.path.join(FIGURES_DIR, 'CoM')
# Make sure the directory exists
os.makedirs(savedir, exist_ok=True)

plot_phase_aligned_average_xy(healthy_steps, feature_keys=CoM_features, figure_path=os.path.join(savedir, 'CoM_healthy_time_locked.svg')) # Plot phase locked for one state
compare_phase_aligned_average_xy(healthy_steps, unhealthy_steps, feature_keys=CoM_features, figure_path=os.path.join(savedir, 'CoM_healthy_vs_SCI_time_locked.svg')) # Plot phase locked for two states
compare_spatial_progression_xy_over_time(healthy_steps, unhealthy_steps, feature_keys=CoM_features, length_key='rhindlimb lHindfingers - X pose (m)', height_key='rhindlimb lHindfingers - Y pose (m)', figure_path=os.path.join(savedir, 'CoM_healthy_vs_SCI_space_locked.svg')) # Plot space locked for two states

In [15]:
## Save all XY pose plots
savedir = os.path.join(FIGURES_DIR, 'Kinematics')
# Make sure the directory exists
os.makedirs(savedir, exist_ok=True)

for name, feature in tqdm(zip(single_point_feature_names, single_point_features), total=len(single_point_feature_names), desc="Processing single point features"):
    tqdm.write(f"Processing feature: {name}")
    tqdm.write(f"Plotting phase aligned average for healthy steps...")
    plot_phase_aligned_average_xy(healthy_steps, feature_keys=feature, figure_path=os.path.join(savedir, f'{name}_healthy_time_locked.svg')) # Plot phase locked for one state
    tqdm.write(f"Plotting phase aligned average for healthy vs unhealthy steps...")
    compare_phase_aligned_average_xy(healthy_steps, unhealthy_steps, feature_keys=feature, figure_path=os.path.join(savedir, f'{name}_healthy_vs_SCI_time_locked.svg')) # Plot phase locked for two states
    tqdm.write(f"Plotting spatial progression over time for healthy vs unhealthy steps...")
    compare_spatial_progression_xy_over_time(healthy_steps, unhealthy_steps, feature_keys=feature, length_key='rhindlimb lHindfingers - X pose (m)', height_key='rhindlimb lHindfingers - Y pose (m)', figure_path=os.path.join(savedir, f'{name}_healthy_vs_SCI_space_locked.svg')) # Plot space locked for two states

Processing single point features:   0%|          | 0/4 [00:00<?, ?it/s]

Processing feature: hindfingers
Plotting phase aligned average for healthy steps...


Processing single point features:   0%|          | 0/4 [00:01<?, ?it/s]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing single point features:   0%|          | 0/4 [00:05<?, ?it/s]

Plotting spatial progression over time for healthy vs unhealthy steps...


Processing single point features:  25%|██▌       | 1/4 [00:08<00:26,  8.95s/it]

Processing feature: hindpaw
Plotting phase aligned average for healthy steps...


Processing single point features:  25%|██▌       | 1/4 [00:10<00:26,  8.95s/it]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing single point features:  25%|██▌       | 1/4 [00:14<00:26,  8.95s/it]

Plotting spatial progression over time for healthy vs unhealthy steps...


Processing single point features:  50%|█████     | 2/4 [00:17<00:17,  8.74s/it]

Processing feature: forepaw
Plotting phase aligned average for healthy steps...


Processing single point features:  50%|█████     | 2/4 [00:19<00:17,  8.74s/it]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing single point features:  50%|█████     | 2/4 [00:22<00:17,  8.74s/it]

Plotting spatial progression over time for healthy vs unhealthy steps...


Processing single point features:  75%|███████▌  | 3/4 [00:26<00:08,  8.81s/it]

Processing feature: head
Plotting phase aligned average for healthy steps...


Processing single point features:  75%|███████▌  | 3/4 [00:28<00:08,  8.81s/it]

Plotting phase aligned average for healthy vs unhealthy steps...


Processing single point features:  75%|███████▌  | 3/4 [00:31<00:08,  8.81s/it]

Plotting spatial progression over time for healthy vs unhealthy steps...


Processing single point features: 100%|██████████| 4/4 [00:34<00:00,  8.74s/it]


In [12]:
forelimb_keys = [key for key in first_df.columns if 'forelimb' in key.lower() and POSE_KEY in key]
# Plot stick figure forelimb (bottom) and hindlimb (upper) trajectories
savedir = os.path.join(FIGURES_DIR, 'Trajectories/Full')
# Make sure the directory exists
os.makedirs(savedir, exist_ok=True)

plot_trajectory_with_joint_traces(reshaped_data, forelimb_keys, hindlimb_keys, figure_path=savedir)

Plotting Trajectory with Joint Traces: 100%|██████████| 374/374 [02:10<00:00,  2.86it/s]


In [12]:
# time is nornmalized by step
fig = plot_mean_spatial_trajectory(segmented_hindsteps) 

# fig.write_image(os.path.join('/home/filthyweeb/Downloads', 'xy'), format='svg')

In [13]:
plot_umap_from_step(segmented_hindsteps[0])


n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [14]:
plot_umap_all_steps(healthy_steps)


n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [15]:
plot_umap_all_steps(unhealthy_steps)


n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [16]:
compare_step_features_in_batches(segmented_hindsteps[0])

Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_1.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_2.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_3.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_4.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_5.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_6.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_7.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_8.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_9.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_10.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_11.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_12.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_13.html
Saved: ./figures/SCI\Mouse1Cage1_Left_Run3_feature_pairs_batch_14.html
Saved: ./figure

In [17]:
plot_step_trajectory(segmented_hindsteps, hindlimb_keys)

In [18]:
plot_trajectory_with_joint_trace(reshaped_data, hindlimb_keys)

In [None]:
com_keys = [key for key in first_df.columns if 'CoM' in key]
print(f"CoM keys: {com_keys}")
plot_trajectory_with_joint_trace(reshaped_data, com_keys)

In [20]:
plot_animated_trajectory(reshaped_data, hindlimb_keys)

In [21]:
plot_trajectory(reshaped_data, hindlimb_keys)

In [22]:
## Plt one step for xy pose
compare_spatial_progression_over_time(healthy_steps, unhealthy_steps, finger_pose_key='rhindlimb lHindfingers - X pose (m)')

In [23]:
## Check step features vs velocity 
features_to_check = ('Mean step height (m)', 'Mean step length (m)', 'Mean step duration (s)', 'Mean swing duration (s)', 'Mean stance duration (s)', 'Mean step frequency (Hz)')
features_to_check_against = ('Mean step velocity X at stance - lHindpaw (m/s)', 'Mean step velocity Y at stance - lHindpaw (m/s)', 'Mean step velocity X at swing - lHindpaw (m/s)', 'Mean step velocity Y at swing - lHindpaw (m/s)')
# features_to_check_against = ('Mean step velocity X at stance - lHindfingers (m/s)', 'Mean step velocity Y at stance - lHindfingers (m/s)', 'Mean step velocity X at swing - lHindfingers (m/s)', 'Mean step velocity Y at swing - lHindfingers (m/s)')

## Load the mean features
mean_features_df = pd.read_csv(MEAN_FEATURES_FILE)
healthy_mean_features = mean_features_df[mean_features_df['Dataset'].str.lower().str.contains(HEALTHY_KEY.lower())]
sick_mean_features = mean_features_df[mean_features_df['Dataset'].str.lower().str.contains(SICK_KEY.lower())]
print(f"Healthy mean features shape: {healthy_mean_features.shape}\nSick mean features shape: {sick_mean_features.shape}")

Healthy mean features shape: (0, 2684)
Sick mean features shape: (0, 2684)


In [None]:
## Plot features against velocity
from plotly.subplots import make_subplots

# Prepare subplots: rows = len(features_to_check), cols = len(features_to_check_against)
fig = make_subplots(
    rows=len(features_to_check), 
    cols=len(features_to_check_against),
    # subplot_titles=[
    #     f"{y} vs {x}" for y in features_to_check for x in features_to_check_against
    # ],
    horizontal_spacing=0.05,
    vertical_spacing=0.08,
    print_grid=False
)

for i, y_feat in enumerate(features_to_check):
    for j, x_feat in enumerate(features_to_check_against):
        if x_feat in healthy_mean_features.columns and y_feat in healthy_mean_features.columns:
            fig.add_trace(
                go.Scatter(
                    x=healthy_mean_features[x_feat],
                    y=healthy_mean_features[y_feat],
                    mode='markers',
                    marker=dict(size=SCATTER_SIZE, color='green', opacity=1, line=dict(width=SCATTER_LINE_WIDTH, color='black')),
                    name=f"{y_feat} vs {x_feat}",
                    showlegend=False
                ),
                row=i+1, col=j+1
            )
            # Linear regression for healthy group
            x = healthy_mean_features[x_feat].values
            y = healthy_mean_features[y_feat].values
            if len(x) > 1 and len(y) > 1:
                # Remove NaNs
                mask = ~np.isnan(x) & ~np.isnan(y)
                if mask.sum() > 1:
                    x_clean = x[mask]
                    y_clean = y[mask]
                    # Fit linear regression
                    coef = np.polyfit(x_clean, y_clean, 1)
                    x_line = np.linspace(x_clean.min(), x_clean.max(), 100)
                    y_line = np.polyval(coef, x_line)
                    fig.add_trace(
                        go.Scatter(
                            x=x_line,
                            y=y_line,
                            mode='lines',
                            line=dict(color='red', width=2, dash='dash'),
                            showlegend=False
                        ),
                        row=i+1, col=j+1
                    )
                    if j + 1 == 3:
                        m = coef[0] * 100 if i + 1 != 6 else coef[0]
                        print(f"Healthy linear regression for {y_feat} vs {x_feat}: y = {m:.2f}x + {coef[1]:.2f}")
        #     fig.add_trace(
        #         go.Scatter(
        #             x=sick_mean_features[x_feat],
        #             y=sick_mean_features[y_feat],
        #             mode='markers',
        #             marker=dict(size=SCATTER_SIZE, color='yellow', opacity=1, line=dict(width=SCATTER_LINE_WIDTH, color='black')),
        #             name=f"{y_feat} vs {x_feat}",
        #             showlegend=False
        #         ),
        #         row=i+1, col=j+1
        #     )
        #     # Linear regression for healthy group
        #     x = sick_mean_features[x_feat].values
        #     y = sick_mean_features[y_feat].values
        #     if len(x) > 1 and len(y) > 1:
        #         # Remove NaNs
        #         mask = ~np.isnan(x) & ~np.isnan(y)
        #         if mask.sum() > 1:
        #             x_clean = x[mask]
        #             y_clean = y[mask]
        #             # Fit linear regression
        #             coef = np.polyfit(x_clean, y_clean, 1)
        #             x_line = np.linspace(x_clean.min(), x_clean.max(), 100)
        #             y_line = np.polyval(coef, x_line)
        #             fig.add_trace(
        #                 go.Scatter(
        #                     x=x_line,
        #                     y=y_line,
        #                     mode='lines',
        #                     line=dict(color='blue', width=2, dash='dash'),
        #                     showlegend=False
        #                 ),
        #                 row=i+1, col=j+1
        #             )
        #             if j + 1 == 3:
        #                 print(f"Sick linear regression for {y_feat} vs {x_feat}: y = {coef[0]:.2f}x + {coef[1]:.2f}")
        fig.update_xaxes(title_text=x_feat, row=i+1, col=j+1)
        fig.update_yaxes(title_text=y_feat, row=i+1, col=j+1)

fig.update_layout(
    height=300*len(features_to_check), 
    width=350*len(features_to_check_against),
    title_text="Healthy Mean Features: Step Features vs Velocity Features",
    template="plotly_white"
)
print(f"Number of mice: {len(healthy_mean_features['Mouse'].unique())}")
fig.show()
# fig.write_image(os.path.join(os.path.join(FIGURES_DIR, 'Validation'), 'Healthy_Mean_Features_Step_vs_Velocity.svg'), format='svg')

Processing features: 100%|██████████| 6/6 [00:00<00:00, 62.80it/s]

Number of mice: 0





In [25]:
## Print the ratio between stance and swing durations 
healthy_stance_ratio = healthy_mean_features['Mean stance duration (s)'].mean() / healthy_mean_features['Mean step duration (s)'].mean()
healthy_swing_ratio = healthy_mean_features['Mean swing duration (s)'].mean() / healthy_mean_features['Mean step duration (s)'].mean()
healthy_step_duration = healthy_mean_features['Mean step duration (s)'].mean()
healthy_step_duration_std = healthy_mean_features['Mean step duration (s)'].std()

sick_stance_ratio = sick_mean_features['Mean stance duration (s)'].mean() / sick_mean_features['Mean step duration (s)'].mean()
sick_swing_ratio = sick_mean_features['Mean swing duration (s)'].mean() / sick_mean_features['Mean step duration (s)'].mean()
sick_step_duration = sick_mean_features['Mean step duration (s)'].mean()

print(f"Healthy mean stance duration ratio: {healthy_stance_ratio:.2f}, mean swing duration: {healthy_swing_ratio:.2f}, mean step duration: {healthy_step_duration:.2f} ± {healthy_step_duration_std:.2f}")
print(f"Sick mean stance duration ratio: {sick_stance_ratio:.2f}, mean swing duration: {sick_swing_ratio:.2f}, mean step duration: {sick_step_duration:.2f}")

Healthy mean stance duration ratio: nan, mean swing duration: nan, mean step duration: nan ± nan
Sick mean stance duration ratio: nan, mean swing duration: nan, mean step duration: nan


In [26]:
## Pint the average phase excursions and angle excursions
phase_excursion_cols = [col for col in healthy_mean_features.columns if 'phase excursion' in col.lower() and 'std' not in col.lower()]
angle_excursion_cols = [col for col in healthy_mean_features.columns if 'angle excursion' in col.lower() and 'std' not in col.lower()]

# for col in phase_excursion_cols:
#     healthy_mean = healthy_mean_features[col].mean()
#     healthy_std = healthy_mean_features[col].std() / np.sqrt(len(healthy_mean_features))
#     healthy_sem = healthy_mean_features[col].std() / np.sqrt(len(healthy_mean_features))
    
#     print(f"{col}: Healthy mean = {healthy_mean:.2f} ± {healthy_sem:.2f}, std = {healthy_std:.2f}")

# for col in angle_excursion_cols:
#     healthy_mean = healthy_mean_features[col].mean()
#     healthy_sem = healthy_mean_features[col].std() / np.sqrt(len(healthy_mean_features))
#     sick_mean = sick_mean_features[col].mean()
#     sick_sem = sick_mean_features[col].std() / np.sqrt(len(sick_mean_features))
    
#     print(f"{col}: Healthy mean = {healthy_mean:.2f} ± {healthy_sem:.2f}, Sick mean = {sick_mean:.2f} ± {sick_sem:.2f}")

hip_col = [col for col in phase_excursion_cols if 'hip' in col.lower()]
knee_col = [col for col in phase_excursion_cols if 'knee' in col.lower()]
finger_col = [col for col in phase_excursion_cols if 'paw' in col.lower()]
cols_list = [hip_col, knee_col, finger_col]
cols = ('hip', 'knee', 'paw')

print('number of steps:', len(segmented_hindsteps))
for bp, col in zip(cols, cols_list):
    healthy_mean = healthy_mean_features[col].mean()
    healthy_std = healthy_mean_features[col].std()
    healthy_sem = healthy_std / np.sqrt(len(healthy_mean_features))
    sick_mean = sick_mean_features[col].mean()
    sick_std = sick_mean_features[col].std()
    sick_sem = sick_std / np.sqrt(len(sick_mean_features))
    
    print("---------")
    print(f"{bp} phase excursion: Healthy mean = {healthy_mean.values} ± {healthy_std.values} sem = {healthy_sem.values} \n"
          f"Sick mean = {sick_mean.values} ± {sick_std.values} sem = {sick_sem.values}")


number of steps: 1169
---------
hip phase excursion: Healthy mean = [nan] ± [nan] sem = [nan] 
Sick mean = [nan] ± [nan] sem = [nan]
---------
knee phase excursion: Healthy mean = [nan] ± [nan] sem = [nan] 
Sick mean = [nan] ± [nan] sem = [nan]
---------
paw phase excursion: Healthy mean = [nan] ± [nan] sem = [nan] 
Sick mean = [nan] ± [nan] sem = [nan]


In [27]:
from scipy.stats import ttest_ind_from_stats

# Bellardita values (mean ± SD, n = 10 steps, for trot gait)
bellardita_hip_mean = 62
bellardita_hip_sd = 2
bellardita_knee_mean = 80
bellardita_knee_sd = 2
bellardita_paw_mean = 92
bellardita_paw_sd = 3
n_bellardita = 10

# Your values (mean ± SD, n = 3932)
your_hip_mean = 62.68777983
your_hip_sd = 9.55570075
your_knee_mean = 75.19756741
your_knee_sd = 4.02287821
your_paw_mean = 121.09001636
your_paw_sd = 13.85861924
n_your = 3932

# Perform t-tests
t_hip, p_hip = ttest_ind_from_stats(your_hip_mean, your_hip_sd, n_your,
                                    bellardita_hip_mean, bellardita_hip_sd, n_bellardita)

t_knee, p_knee = ttest_ind_from_stats(your_knee_mean, your_knee_sd, n_your,
                                      bellardita_knee_mean, bellardita_knee_sd, n_bellardita)

t_paw, p_paw = ttest_ind_from_stats(your_paw_mean, your_paw_sd, n_your,
                                    bellardita_paw_mean, bellardita_paw_sd, n_bellardita)

# Print results
print("Hip t-test: t = {:.3f}, p = {:.3e}".format(t_hip, p_hip))
if p_hip < 0.05:
    print(f"Hip phase excursion is significantly different from Bellardita's values. with a confidence level of {1 - p_hip:.2f}")
else:
    print("Hip phase excursion is not significantly different from Bellardita's values.")
print("Knee t-test: t = {:.3f}, p = {:.3e}".format(t_knee, p_knee))
if p_knee < 0.05:
    print(f"Knee phase excursion is significantly different from Bellardita's values. with a confidence level of {1 - p_knee:.2f}")
else:
    print("Knee phase excursion is not significantly different from Bellardita's values.")
print("Paw t-test: t = {:.3f}, p = {:.3e}".format(t_paw, p_paw))
if p_paw < 0.05:
    print(f"Paw phase excursion is significantly different from Bellardita's values. with a confidence level of {1 - p_paw:.2f}")
else:
    print("Paw phase excursion is not significantly different from Bellardita's values.")

Hip t-test: t = 0.228, p = 8.200e-01
Hip phase excursion is not significantly different from Bellardita's values.
Knee t-test: t = -3.774, p = 1.633e-04
Knee phase excursion is significantly different from Bellardita's values. with a confidence level of 1.00
Paw t-test: t = 6.637, p = 3.647e-11
Paw phase excursion is significantly different from Bellardita's values. with a confidence level of 1.00


In [28]:
fig = px.scatter_3d(
    healthy_mean_features,
    x='Mean stance duration (s)',
    y='Mean swing duration (s)',
    z='Mean step velocity X at stance - lHindpaw (m/s)',
    color='Mouse',
    hover_name='Mouse',
    title="Healthy Mean Features: Stance Duration vs Swing Duration vs Velocity",
    labels={
        'Mean stance duration (s)': 'Average Stance Duration (s)',
        'Mean swing duration (s)': 'Average Swing Duration (s)',
        'Mean step velocity X at stance - lHindpaw (m/s)': 'Average Step Velocity X at Stance (m/s)'
    },
    template="plotly_white"
)
fig.update_traces(marker=dict(size=3, line=dict(width=SCATTER_LINE_WIDTH, color='black'), color = 'blue'))
fig.update_layout(
    legend=dict(font=dict(size=LEGEND_FONT_SIZE)),
    title=dict(font=dict(size=TITLE_FONT_SIZE)),
    scene=dict(
        xaxis_title='Average Stance Duration (s)',
        yaxis_title='Average Swing Duration (s)',
        zaxis_title='Average Step Velocity X at Stance (m/s)',
        xaxis=dict(tickfont=dict(size=AXIS_FONT_SIZE)),
        yaxis=dict(tickfont=dict(size=AXIS_FONT_SIZE)),
        zaxis=dict(tickfont=dict(size=AXIS_FONT_SIZE)),
    )
)
fig.show()

In [29]:
# Linear regression for healthy group with intercept fixed at 0
x = healthy_mean_features[x_feat].values
y = healthy_mean_features[y_feat].values
if len(x) > 1 and len(y) > 1:
    # Remove NaNs
    mask = ~np.isnan(x) & ~np.isnan(y)
    if mask.sum() > 1:
        x_clean = x[mask]
        y_clean = y[mask]
        
        # Fit linear regression with intercept=0
        # Use np.linalg.lstsq for regression through origin
        slope = np.sum(x_clean * y_clean) / np.sum(x_clean * x_clean)
        
        # Alternative using sklearn:
        # from sklearn.linear_model import LinearRegression
        # model = LinearRegression(fit_intercept=False)
        # model.fit(x_clean.reshape(-1, 1), y_clean)
        # slope = model.coef_[0]
        
        x_line = np.linspace(x_clean.min(), x_clean.max(), 100)
        y_line = slope * x_line  # No intercept term
        
        fig.add_trace(
            go.Scatter(
                x=x_line,
                y=y_line,
                mode='lines',
                line=dict(color='red', width=2, dash='dash'),
                showlegend=False
            ),
            row=i+1, col=j+1
        )
        if j + 1 == 3:
            print(f"Healthy linear regression for {y_feat} vs {x_feat}: y = {slope:.2f}x")

In [30]:
### Training block
model = None
if BEST_MODEL_PATH:
    model = load_model(model_path=BEST_MODEL_PATH, input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM, latent_dim=LATENT_DIM)
    load_and_plot_losses(losses_file=BEST_LOSS_PATH,
                        title_font_size=TITLE_FONT_SIZE,
                        axis_title_font_size=AXIS_TITLE_FONT_SIZE,
                        legend_font_size=LEGEND_FONT_SIZE)
else:
    model, train_losses, val_losses = train_model(
        step_tensor=step_tensor,
        lengths=lengths,
        input_dim=INPUT_DIM,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        lr=LR,
        patience=PATIENCE,
        min_delta=MIN_DELTA,
        models_dir=MODELS_DIR
    )
    save_model_losses(train_losses, val_losses, figures_dir=FIGURES_DIR,
                        title_font_size=TITLE_FONT_SIZE,
                        axis_title_font_size=AXIS_FONT_SIZE,
                        legend_font_size=LEGEND_FONT_SIZE)

assert model is not None, "Model training failed or no model was loaded."

VAE_t device: cuda
Loaded model from epoch 1042 with validation loss 8.1214


RuntimeError: Error(s) in loading state_dict for LSTMVAE_t:
	size mismatch for encoder_lstm.weight_ih_l0: copying a param with shape torch.Size([256, 196]) from checkpoint, the shape in current model is torch.Size([256, 244]).
	size mismatch for output_layer.weight: copying a param with shape torch.Size([196, 64]) from checkpoint, the shape in current model is torch.Size([244, 64]).
	size mismatch for output_layer.bias: copying a param with shape torch.Size([196]) from checkpoint, the shape in current model is torch.Size([244]).

In [None]:
### Evaluate the model on pathological steps
selected_steps = [
    s for s in segmented_hindsteps 
    if SIDE_KEY[0] in s["mouse"] 
]

step_tensor_all, lengths_all = steps_to_tensor(selected_steps, scaler)

## Get all embeddings for selected steps
model.eval()
with torch.no_grad():
    x_hat, mu_t, logvar_t  = model(step_tensor_all, lengths_all)  # (B, T, F), (B, T, L), (B, T, L)
    z_summary = mu_t.mean(dim=1)  # → shape: (B, L)

reducer = UMAP(random_state=42, n_components=3)
embedding = reducer.fit_transform(z_summary.cpu().numpy())  # shape: (B, 3)

# Convert each step’s latent sequence to list of (T_i, L)
# mu_t_masked = [
#     mu_t[i, :lengths_all[i]].cpu().numpy()  # shape: (T_i, L)
#     for i in range(mu_t.shape[0])
# ]

# Stack all latent vectors across all steps
# all_latents = np.concatenate(mu_t_masked, axis=0)  # shape: (sum of T_i, L)

# # Fit UMAP
# umap_coords = UMAP(n_components=3, random_state=42).fit_transform(all_latents)

# ## Now split the coordinates back by original steps
# # for each step, have all the umap coordinates
# umap_split = []
# idx = 0
# for i in range(len(mu_t_masked)):
#     length = mu_t_masked[i].shape[0]
#     umap_split.append(umap_coords[idx:idx + length])  # shape: (T_i, 3)
#     idx += length

In [None]:
datasets = sorted(set(s["group"] for s in selected_steps))
color_map = {ds: color for ds, color in zip(datasets, plotly.colors.qualitative.Plotly)}
# if pre is in group, use the first color, otherwise use the second color
for i, ds in enumerate(datasets):
    if HEALTHY_KEY in ds:
        color_map[ds] = plotly.colors.qualitative.Plotly[0]
    else:
        color_map[ds] = plotly.colors.qualitative.Plotly[1]

fig = go.Figure()
legend_shown = defaultdict(bool)

for i, (x, y, z) in enumerate(embedding):
    step_meta = selected_steps[i]
    dataset = step_meta["group"]
    mouse = step_meta["mouse"]
    run = step_meta["run"]
    color = color_map.get(dataset, "gray")
    show_legend = not legend_shown[dataset]
    legend_shown[dataset] = True

    fig.add_trace(go.Scatter3d(
        x=[x], y=[y], z=[z],
        mode='markers',
        name=dataset if show_legend else None,
        legendgroup=dataset,
        showlegend=show_legend,
        marker=dict(size=SCATTER_SIZE, color=color, line=dict(width=SCATTER_LINE_WIDTH, color='black')),
        hoverinfo='text',
        text=[f"{dataset} | {mouse} | run={run}"]
    ))

fig.update_layout(
    title="Latent Trajectories Over Time by Dataset",
    scene=dict(
        xaxis_title='UMAP1',
        yaxis_title='UMAP2',
        zaxis_title='UMAP3'
    ),
    legend=dict(title="Dataset"),
    width=900,
    height=700,
    template='plotly_white'
)

fig.show()

In [None]:
## TODO: make the 3D plot with trajectories for both pre and post sci
sample_idx = torch.randperm(mu_t.shape[0])[:100]
mu_t_sample = mu_t[sample_idx]
lengths_sample = lengths_all[sample_idx]
selected_sample = [selected_steps[i] for i in sample_idx.tolist()]

mu_t = mu_t_sample

B, T, L = mu_t.shape

# Flatten all time steps for UMAP
mu_t_flat = mu_t.reshape(-1, L).cpu().numpy()  # (B*T, L)

# Reduce L → 3 using UMAP
embedding_flat = reducer.fit_transform(mu_t_flat)  # (B*T, 3)

# Reshape embedding back to (B, T, 3)
embedding_3d = embedding_flat.reshape(B, T, 3)

print("Moving to 3D plotting...")

# Plot each step as a 3D trajectory
fig = go.Figure()
legend_shown = defaultdict(bool)

for i in range(B):
    emb = embedding_3d[i, :lengths_all[i]]  # shape: T × 3
    meta = selected_steps[i]
    dataset = meta["group"]
    mouse = meta["mouse"]
    run = meta["run"]
    color = color_map.get(dataset, "gray")
    show_legend = not legend_shown[dataset]
    legend_shown[dataset] = True

    fig.add_trace(go.Scatter3d(
        x=emb[:, 0],
        y=emb[:, 1],
        z=emb[:, 2],
        mode='lines+markers',
        name=dataset if show_legend else None,
        legendgroup=dataset,
        showlegend=show_legend,
        marker=dict(size=SCATTER_SIZE, color=color, line=dict(width=SCATTER_LINE_WIDTH, color='black')),
        hoverinfo='text',
        text=[f"{dataset} | {mouse} | run={run} | t={t}" for t in range(len(emb))]
    ))

fig.update_layout(
    title="Latent Trajectories Over Time (per Step)",
    scene=dict(
        xaxis_title='UMAP1',
        yaxis_title='UMAP2',
        zaxis_title='UMAP3'
    ),
    legend=dict(title="Dataset"),
    width=900,
    height=700,
    template='plotly_white'
)

fig.show()

In [None]:
def plot_timewise_umap_trajectories(segmented_hindsteps, scaler):
    """
    Plot UMAP projection of all timepoints in each step.
    Each step becomes a trajectory (line) in 3D UMAP space.
    Mice from the same dataset (e.g., 'pre', 'post') share color and legend.
    """
    from collections import defaultdict

    filtered_steps = [s for s in segmented_hindsteps if SIDE_KEY[0] in s["mouse"]]

    step_id = 0
    timepoints = []
    metadata = []

    for s in filtered_steps:
        arr = scaler.transform(s["step"].values)  # (T, D)
        for t in range(arr.shape[0]):
            timepoints.append(arr[t])
            metadata.append({
                "step_id": step_id,
                "t": t,
                "mouse": s["mouse"],
                "group": s["group"],
                "run": s["run"]
            })
        step_id += 1

    X = np.stack(timepoints)
    umap_coords = UMAP(n_components=3, random_state=42).fit_transform(X)

    umap_df = pd.DataFrame(umap_coords, columns=["UMAP1", "UMAP2", "UMAP3"])
    umap_df = pd.concat([umap_df, pd.DataFrame(metadata)], axis=1)

    datasets = sorted(umap_df["group"].unique())
    color_map = {ds: color for ds, color in zip(datasets, plotly.colors.qualitative.Plotly)}

    fig = go.Figure()
    legend_shown = defaultdict(bool)

    for step_id, group in umap_df.groupby("step_id"):
        dataset = group["group"].iloc[0]
        color = color_map[dataset]
        show_legend = not legend_shown[dataset]
        legend_shown[dataset] = True

        fig.add_trace(go.Scatter3d(
            x=group["UMAP1"],
            y=group["UMAP2"],
            z=group["UMAP3"],
            mode="lines+markers",
            line=dict(color=color, width=2),
            marker=dict(size=3, color=color),
            name=dataset,
            legendgroup=dataset,
            showlegend=show_legend,
            text=[f"{group['mouse'].iloc[0]} | run={r} | t={t}" for r, t in zip(group["run"], group["t"])]
        ))

    fig.update_layout(
        title="Time-Resolved UMAP of Step Dynamics by Dataset",
        scene=dict(
            xaxis_title="UMAP1",
            yaxis_title="UMAP2",
            zaxis_title="UMAP3"
        ),
        legend=dict(title="Dataset", font=dict(size=12)),
        width=900,
        height=700,
        template="plotly_white"
    )

    fig.show()

# Usage:
plot_timewise_umap_trajectories(segmented_hindsteps, scaler)