In [1]:
import os
import pandas as pd
import numpy as np
import torch
import plotly.express as px
import plotly.graph_objs as go
import plotly.colors
from collections import defaultdict

from sklearn.preprocessing import StandardScaler
from umap import UMAP

from src.ModelUtils import (load_model,
                            train_model,
                            save_model_losses,
                            load_and_plot_losses)
from src.DataUtils import (load_data_from_directory,
                           segment_all_steps,
                           steps_to_tensor,
                           reshape_data)
from src.Plot import (plot_step_trajectory, plot_trajectory,
                      plot_animated_trajectory, plot_trajectory_with_joint_traces,
                      compare_step_features_in_batches, plot_umap_from_step,
                      plot_umap_all_steps, plot_mean_spatial_trajectory,
                      compare_phase_aligned_average_single, compare_phase_aligned_average_xy,
                      plot_phase_aligned_average_single, plot_phase_aligned_average_xy,
                      plot_trajectory_with_joint_trace,
                      compare_spatial_progression_over_time, compare_spatial_progression_xy_over_time,
                      compare_spatial_angle_progression_over_time)

In [2]:
## Directories and constants
FIGURES_DIR = './figures/SCI'        # root folder for every saved figure
MODELS_DIR = './src/models'          # model checkpoints
DATA_DIR = './csv/SCI_nimo_placebo'  # input CSV tables
HEALTHY_KEY = 'Pre'                  # substring of healthy group names (e.g. 'SCI_PreSCI')
SICK_KEY = 'Post'                    # substring of injured group names (e.g. 'SCI_PostSCI')
SIDE_KEY = ('Left', 'Right')         # recording sides; the analysis below uses SIDE_KEY[0]

## Hyperparameters and early stopping
INPUT_DIM = None      # filled in from the step tensor once the training data is built
HIDDEN_DIM = 64
LATENT_DIM = 16
BATCH_SIZE = 32
NUM_EPOCHS = 5_000
LR = 0.001
PATIENCE = 100        # epochs without improvement tolerated before stopping early
MIN_DELTA = 0.0001    # smallest decrease in loss that counts as an improvement
# When set, the training cell loads this checkpoint instead of training a new model.
BEST_MODEL_PATH = os.path.join(MODELS_DIR, 'lstm_VAE_t_SCI_20250609_200452.pt')
BEST_LOSS_PATH = os.path.join(FIGURES_DIR, 'SCI_losses.csv')

## Plot constants
SCATTER_SIZE = 6
SCATTER_LINE_WIDTH = 1
SCATTER_SYMBOL = 'circle'
LEGEND_FONT_SIZE = 18
TITLE_FONT_SIZE = 24
AXIS_FONT_SIZE = 16
AXIS_TITLE_FONT_SIZE = 20

# Load the nested group -> mouse/side -> run tables from the data directory
data = load_data_from_directory(DATA_DIR)

# Summarise what was loaded: one line per group, one indented line per mouse/side
for group_name, mouse_runs in data.items():
    print(f"Group: {group_name}")
    for mouse_label, run_frames in mouse_runs.items():
        run_shapes = [frame.shape for frame in run_frames.values()]
        print(f"\t{mouse_label}: {len(run_frames)} runs with shapes: {run_shapes}")

Group: SCI_PostSCI
	Left_Mouse164: 3 runs with shapes: [(244, 199), (124, 199), (136, 199)]
	Left_Mouse174: 3 runs with shapes: [(671, 199), (561, 199), (531, 199)]
	Left_Mouse178: 3 runs with shapes: [(579, 199), (701, 199), (631, 199)]
	Left_Mouse180: 1 runs with shapes: [(601, 199)]
	Left_Mouse42: 4 runs with shapes: [(301, 199), (304, 199), (459, 199), (220, 199)]
	Left_Mouse44: 4 runs with shapes: [(176, 199), (171, 199), (134, 199), (161, 199)]
	Left_Mouse46: 4 runs with shapes: [(201, 199), (135, 199), (101, 199), (120, 199)]
	Left_Mouse50: 3 runs with shapes: [(374, 199), (711, 199), (701, 199)]
	Left_Mouse82: 4 runs with shapes: [(260, 199), (167, 199), (201, 199), (253, 199)]
	Left_Mouse84: 1 runs with shapes: [(420, 199)]
	Left_Mouse86: 4 runs with shapes: [(149, 199), (201, 199), (251, 199), (149, 199)]
	Left_Mouse90: 3 runs with shapes: [(391, 199), (377, 199), (421, 199)]
Group: SCI_PreSCI
	Left_Mouse160: 3 runs with shapes: [(424, 199), (301, 199), (231, 199)]
	Left_Mous

In [3]:
# Print the feature count of one sample table: the first run of the first mouse of the
# first group (nothing is printed if any of those levels is empty).
# +1 bc of index, -4 to exclude frame, forestep, hindstep, and time
first_group_mice = next(iter(data.values()), {})
first_mouse_runs = next(iter(first_group_mice.values()), {})
first_run_df = next(iter(first_mouse_runs.values()), None)
if first_run_df is not None:
    print(f"Number of features: {first_run_df.shape[1] + 1 - 4}")

Number of features: 196


In [4]:
## Segment all steps in the data
# segment_all_steps returns (hindsteps, foresteps); each is a sequence of per-step records:
#     {"step": step_df, "group": group, "mouse": mouse, "run": run}
segmented_hindsteps, segmented_foresteps = segment_all_steps(data)

# One DataFrame per mouse direction and run, with access to the hindsteps and foresteps
reshaped_data = reshape_data(data)

## Split the SIDE_KEY[0] hindsteps by condition
healthy_steps = [s for s in segmented_hindsteps if HEALTHY_KEY in s["group"] and SIDE_KEY[0] in s["mouse"]]
unhealthy_steps = [s for s in segmented_hindsteps if SICK_KEY in s["group"] and SIDE_KEY[0] in s["mouse"]]

## Fit the scaler on every frame of the healthy steps (global per-feature mean/std)
all_healthy_arrays = [record["step"].values for record in healthy_steps]
flat_data = np.vstack(all_healthy_arrays)
scaler = StandardScaler().fit(flat_data)

## Prepare the data for training (healthy steps only)
step_tensor, lengths = steps_to_tensor(healthy_steps, scaler)
if INPUT_DIM is None:
    INPUT_DIM = step_tensor.shape[2]  # feature axis of the step tensor

In [5]:
# Sanity-check shapes: one segmented step (frames x features) and the padded
# training tensor (steps x max frames x features) with its per-step lengths.
print(f'Segmented dataframe shape: {segmented_hindsteps[0]["step"].shape}')
print(f"Step tensor shape: {step_tensor.shape}, \nLengths shape: {lengths.shape}")

Segmented dataframe shhape: (56, 196)
Step tensor shape: torch.Size([308, 134, 196]), 
Lengths shape: torch.Size([308])


In [6]:
### Plot XY pose data
# Substrings used to select feature columns by body part and by feature type.
HINDLIMB_KEY, FORELIMB_KEY, SPINE_KEY, TAIL_KEY = 'hindlimb', 'forelimb', 'spine', 'tail'
POSE_KEY = 'pose'    # positional columns, e.g. '... - X pose (m)'
ANGLE_KEY = 'Angle'  # joint-angle columns (angle, angle velocity, angle acceleration)
ALL_LIMBS = [HINDLIMB_KEY, FORELIMB_KEY, SPINE_KEY, TAIL_KEY]

# Column template taken from the first segmented hindstep
# (presumably every step shares the same columns — TODO confirm).
first_step_record = next(iter(segmented_hindsteps))
first_df = first_step_record["step"]

In [7]:
## Extract poses for each limb
# One list of pose columns per limb, in ALL_LIMBS order (hindlimb, forelimb, spine, tail).
all_limbs_keys = [
    [column for column in first_df.columns if limb_key in column and POSE_KEY in column]
    for limb_key in ALL_LIMBS
]
hindlimb_keys, forelimb_keys, spine_keys, tail_keys = all_limbs_keys

for limb_columns in all_limbs_keys:
    print(f"Limbs poses: {limb_columns}")

Limbs poses: ['rhindlimb hip - X pose (m)', 'rhindlimb hip - Y pose (m)', 'rhindlimb knee - X pose (m)', 'rhindlimb knee - Y pose (m)', 'rhindlimb ankle - X pose (m)', 'rhindlimb ankle - Y pose (m)', 'rhindlimb lHindpaw - X pose (m)', 'rhindlimb lHindpaw - Y pose (m)', 'rhindlimb lHindfingers - X pose (m)', 'rhindlimb lHindfingers - Y pose (m)']
Limbs poses: ['rforelimb shoulder - X pose (m)', 'rforelimb shoulder - Y pose (m)', 'rforelimb elbow - X pose (m)', 'rforelimb elbow - Y pose (m)', 'rforelimb wrist - X pose (m)', 'rforelimb wrist - Y pose (m)', 'rforelimb lForepaw - X pose (m)', 'rforelimb lForepaw - Y pose (m)']
Limbs poses: ['spine head - X pose (m)', 'spine head - Y pose (m)', 'spine spine 1 - X pose (m)', 'spine spine 1 - Y pose (m)', 'spine spine 2 - X pose (m)', 'spine spine 2 - Y pose (m)', 'spine spine 3 - X pose (m)', 'spine spine 3 - Y pose (m)', 'spine spine 4 - X pose (m)', 'spine spine 4 - Y pose (m)', 'spine base - X pose (m)', 'spine base - Y pose (m)']
Limbs po

In [8]:
## Extract angles for each limb
# One list of angle columns (angle, velocity, acceleration) per limb, in ALL_LIMBS order.
all_limbs_angles = [
    [column for column in first_df.columns if ANGLE_KEY in column and limb_key in column]
    for limb_key in ALL_LIMBS
]
hindlimb_angles, forelimb_angles, spine_angles, tail_angles = all_limbs_angles

for limb_angle_columns in all_limbs_angles:
    print(f"Limbs angles: {limb_angle_columns}")

Limbs angles: ['Angle - rhindlimb - hip (°)', 'Angle - rhindlimb - knee (°)', 'Angle - rhindlimb - ankle (°)', 'Angle - rhindlimb - lHindpaw (°)', 'Angle - rhindlimb - lHindfingers (°)', 'Angle velocity - rhindlimb - hip (rad/s)', 'Angle velocity - rhindlimb - knee (rad/s)', 'Angle velocity - rhindlimb - ankle (rad/s)', 'Angle velocity - rhindlimb - lHindpaw (rad/s)', 'Angle velocity - rhindlimb - lHindfingers (rad/s)', 'Angle acceleration - rhindlimb - hip (rad/s^2)', 'Angle acceleration - rhindlimb - knee (rad/s^2)', 'Angle acceleration - rhindlimb - ankle (rad/s^2)', 'Angle acceleration - rhindlimb - lHindpaw (rad/s^2)', 'Angle acceleration - rhindlimb - lHindfingers (rad/s^2)']
Limbs angles: ['Angle - rforelimb - shoulder (°)', 'Angle - rforelimb - elbow (°)', 'Angle - rforelimb - wrist (°)', 'Angle - rforelimb - lForepaw (°)', 'Angle velocity - rforelimb - shoulder (rad/s)', 'Angle velocity - rforelimb - elbow (rad/s)', 'Angle velocity - rforelimb - wrist (rad/s)', 'Angle velocity

In [9]:
## Extract CoM features
# Every column whose name contains 'CoM' (per-segment centre-of-mass X/Y coordinates).
# (Previously written as `'CoM' in key in key`, a chained comparison that only worked
# because `key in key` is always True.)
CoM_features = [key for key in first_df.columns if 'CoM' in key]
print(f"CoM features: {CoM_features}")

CoM features: ['rforelimb - CoM X (m)', 'rforelimb - CoM Y (m)', 'rhindlimb - CoM X (m)', 'rhindlimb - CoM Y (m)', 'spine - CoM X (m)', 'spine - CoM Y (m)', 'tail - CoM X (m)', 'tail - CoM Y (m)']


In [10]:
## Extract relevant single point features
# For each landmark: every non-angle column whose lower-cased name contains it
# (X/Y pose, velocity, acceleration and jerk).
single_point_feature_names = ['hindfingers', 'hindpaw', 'forepaw', 'head']
single_point_features = [
    [column for column in first_df.columns if landmark in column.lower() and 'Angle' not in column]
    for landmark in single_point_feature_names
]
hindfinger_features, hindpaw_features, forepaw_features, head_features = single_point_features

for landmark_columns in single_point_features:
    print(f"Single point features: {landmark_columns}")

Single point features: ['rhindlimb lHindfingers - X pose (m)', 'rhindlimb lHindfingers - X velocity (m/s)', 'rhindlimb lHindfingers - X acceleration (m/s^2)', 'rhindlimb lHindfingers - X jerk (m/s^3)', 'rhindlimb lHindfingers - Y pose (m)', 'rhindlimb lHindfingers - Y velocity (m/s)', 'rhindlimb lHindfingers - Y acceleration (m/s^2)', 'rhindlimb lHindfingers - Y jerk (m/s^3)']
Single point features: ['rhindlimb lHindpaw - X pose (m)', 'rhindlimb lHindpaw - X velocity (m/s)', 'rhindlimb lHindpaw - X acceleration (m/s^2)', 'rhindlimb lHindpaw - X jerk (m/s^3)', 'rhindlimb lHindpaw - Y pose (m)', 'rhindlimb lHindpaw - Y velocity (m/s)', 'rhindlimb lHindpaw - Y acceleration (m/s^2)', 'rhindlimb lHindpaw - Y jerk (m/s^3)']
Single point features: ['rforelimb lForepaw - X pose (m)', 'rforelimb lForepaw - X velocity (m/s)', 'rforelimb lForepaw - X acceleration (m/s^2)', 'rforelimb lForepaw - X jerk (m/s^3)', 'rforelimb lForepaw - Y pose (m)', 'rforelimb lForepaw - Y velocity (m/s)', 'rforelimb

In [None]:
## Save all angle plots
savedir = os.path.join(FIGURES_DIR, 'Angles')
for limb, angles in zip(ALL_LIMBS, all_limbs_angles):
    healthy_path = os.path.join(savedir, f'Angles_healthy_time_locked_{limb}.svg')
    time_locked_path = os.path.join(savedir, f'Angles_healthy_vs_SCI_time_locked_{limb}.svg')
    space_locked_path = os.path.join(savedir, f'Angles_healthy_vs_SCI_space_locked_{limb}.svg')

    # Phase-locked average for the healthy state only
    plot_phase_aligned_average_single(healthy_steps, feature_keys=angles, figure_path=healthy_path)
    # Phase-locked averages, healthy vs SCI
    compare_phase_aligned_average_single(healthy_steps, unhealthy_steps, feature_keys=angles,
                                         figure_path=time_locked_path)
    # Space-locked progression, healthy vs SCI
    compare_spatial_angle_progression_over_time(healthy_steps, unhealthy_steps, feature_keys=angles,
                                                figure_path=space_locked_path)

In [None]:
## Save all CoM plots
savedir = os.path.join(FIGURES_DIR, 'CoM')
# Hind-finger X/Y pose columns, passed as length_key / height_key to the space-locked plot
finger_length_key = 'rhindlimb lHindfingers - X pose (m)'
finger_height_key = 'rhindlimb lHindfingers - Y pose (m)'

# Phase-locked average for the healthy state only
plot_phase_aligned_average_xy(healthy_steps, feature_keys=CoM_features,
                              figure_path=os.path.join(savedir, 'CoM_healthy_time_locked.svg'))
# Phase-locked averages, healthy vs SCI
compare_phase_aligned_average_xy(healthy_steps, unhealthy_steps, feature_keys=CoM_features,
                                 figure_path=os.path.join(savedir, 'CoM_healthy_vs_SCI_time_locked.svg'))
# Space-locked progression, healthy vs SCI
compare_spatial_progression_xy_over_time(healthy_steps, unhealthy_steps, feature_keys=CoM_features,
                                         length_key=finger_length_key, height_key=finger_height_key,
                                         figure_path=os.path.join(savedir, 'CoM_healthy_vs_SCI_space_locked.svg'))

In [None]:
## Save all XY pose plots
savedir = os.path.join(FIGURES_DIR, 'Kinematics')
# Hind-finger X/Y pose columns, passed as length_key / height_key to the space-locked plot
finger_length_key = 'rhindlimb lHindfingers - X pose (m)'
finger_height_key = 'rhindlimb lHindfingers - Y pose (m)'

for name, feature in zip(single_point_feature_names, single_point_features):
    healthy_path = os.path.join(savedir, f'{name}_healthy_time_locked.svg')
    time_locked_path = os.path.join(savedir, f'{name}_healthy_vs_SCI_time_locked.svg')
    space_locked_path = os.path.join(savedir, f'{name}_healthy_vs_SCI_space_locked.svg')

    # Phase-locked average for the healthy state only
    plot_phase_aligned_average_xy(healthy_steps, feature_keys=feature, figure_path=healthy_path)
    # Phase-locked averages, healthy vs SCI
    compare_phase_aligned_average_xy(healthy_steps, unhealthy_steps, feature_keys=feature,
                                     figure_path=time_locked_path)
    # Space-locked progression, healthy vs SCI
    compare_spatial_progression_xy_over_time(healthy_steps, unhealthy_steps, feature_keys=feature,
                                             length_key=finger_length_key, height_key=finger_height_key,
                                             figure_path=space_locked_path)

In [None]:
# Plot stick figure forelimb (bottom) and hindlimb (upper) trajectories
savedir = os.path.join(FIGURES_DIR, 'Trajectories/Full')
# Forelimb pose columns, matched case-insensitively on the limb name
forelimb_keys = list(filter(lambda column: 'forelimb' in column.lower() and POSE_KEY in column,
                            first_df.columns))
# NOTE(review): figure_path is a directory here, while the other plotting cells pass a .svg
# file path — confirm this is what plot_trajectory_with_joint_traces expects.
plot_trajectory_with_joint_traces(reshaped_data, forelimb_keys, hindlimb_keys, figure_path=savedir)

In [None]:
# Mean spatial trajectory over all segmented hindsteps (both conditions, both sides);
# time is normalized per step.
plot_mean_spatial_trajectory(segmented_hindsteps) 

In [None]:
# Presumably a UMAP view of a single step — the first segmented hindstep (see src.Plot).
plot_umap_from_step(segmented_hindsteps[0])

In [None]:
# UMAP over all healthy (HEALTHY_KEY group, SIDE_KEY[0]) hindsteps.
plot_umap_all_steps(healthy_steps)

In [None]:
# UMAP over all injured (SICK_KEY group, SIDE_KEY[0]) hindsteps.
plot_umap_all_steps(unhealthy_steps)

In [None]:
# Feature comparison for the first segmented hindstep (plotted in batches; see src.Plot).
compare_step_features_in_batches(segmented_hindsteps[0])

In [None]:
# Per-step trajectories of the hindlimb pose columns across all segmented hindsteps.
plot_step_trajectory(segmented_hindsteps, hindlimb_keys)

In [None]:
# Full-run trajectory of the hindlimb pose columns with joint traces.
plot_trajectory_with_joint_trace(reshaped_data, hindlimb_keys)

In [None]:
# Every centre-of-mass column, drawn as joint traces over the full runs
com_keys = list(filter(lambda column: 'CoM' in column, first_df.columns))
print(f"CoM keys: {com_keys}")
plot_trajectory_with_joint_trace(reshaped_data, com_keys)

In [None]:
# Animated version of the hindlimb pose trajectory over the full runs.
plot_animated_trajectory(reshaped_data, hindlimb_keys)

In [None]:
# Static hindlimb pose trajectory over the full runs.
plot_trajectory(reshaped_data, hindlimb_keys)

In [None]:
## Compare healthy vs SCI spatial progression over time, keyed on the hind-finger X pose
# NOTE(review): the other space-locked plots use the *_xy_* variant with length_key/height_key;
# confirm that finger_pose_key is the intended keyword for compare_spatial_progression_over_time.
compare_spatial_progression_over_time(healthy_steps, unhealthy_steps, finger_pose_key='rhindlimb lHindfingers - X pose (m)')

In [None]:
### Training block
# If BEST_MODEL_PATH is set, load that checkpoint and plot its saved losses; otherwise train a
# new model on the healthy step tensor (with early stopping) and save the loss curves.
model = None
if BEST_MODEL_PATH:
    # Fail with an actionable message rather than an opaque error deep inside load_model.
    if not os.path.isfile(BEST_MODEL_PATH):
        raise FileNotFoundError(
            f"BEST_MODEL_PATH does not exist: {BEST_MODEL_PATH} "
            "(set BEST_MODEL_PATH = None to train a new model)"
        )
    model = load_model(model_path=BEST_MODEL_PATH, input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM, latent_dim=LATENT_DIM)
    load_and_plot_losses(losses_file=BEST_LOSS_PATH,
                         title_font_size=TITLE_FONT_SIZE,
                         axis_title_font_size=AXIS_TITLE_FONT_SIZE,
                         legend_font_size=LEGEND_FONT_SIZE)
else:
    model, train_losses, val_losses = train_model(
        step_tensor=step_tensor,
        lengths=lengths,
        input_dim=INPUT_DIM,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        lr=LR,
        patience=PATIENCE,
        min_delta=MIN_DELTA,
        models_dir=MODELS_DIR
    )
    # Axis titles use AXIS_TITLE_FONT_SIZE, matching the loaded-checkpoint branch above
    # (previously AXIS_FONT_SIZE was passed here by mistake).
    save_model_losses(train_losses, val_losses, figures_dir=FIGURES_DIR,
                      title_font_size=TITLE_FONT_SIZE,
                      axis_title_font_size=AXIS_TITLE_FONT_SIZE,
                      legend_font_size=LEGEND_FONT_SIZE)

assert model is not None, "Model training failed or no model was loaded."

In [None]:
### Embed every SIDE_KEY[0] step (healthy and SCI) with the trained model
selected_steps = [
    s for s in segmented_hindsteps
    if SIDE_KEY[0] in s["mouse"]
]

step_tensor_all, lengths_all = steps_to_tensor(selected_steps, scaler)

## Get all embeddings for selected steps
model.eval()
with torch.no_grad():
    x_hat, mu_t, logvar_t = model(step_tensor_all, lengths_all)  # (B, T, F), (B, T, L), (B, T, L)

    # Summarise each step by the mean latent over its *valid* frames only. Steps are padded
    # to a common length, so a plain mu_t.mean(dim=1) would mix padded frames into the
    # summary and bias short steps.
    time_idx = torch.arange(mu_t.shape[1], device=mu_t.device)
    valid = time_idx.unsqueeze(0) < lengths_all.to(mu_t.device).unsqueeze(1)  # (B, T) bool
    masked_mu = mu_t.masked_fill(~valid.unsqueeze(-1), 0.0)                 # zero padded frames
    frame_counts = valid.sum(dim=1, keepdim=True).clamp(min=1).to(mu_t.dtype)  # (B, 1)
    z_summary = masked_mu.sum(dim=1) / frame_counts  # → shape: (B, L)

reducer = UMAP(random_state=42, n_components=3)
embedding = reducer.fit_transform(z_summary.cpu().numpy())  # shape: (B, 3)

In [None]:
# Colour groups by condition: groups whose name contains HEALTHY_KEY get the first Plotly
# colour, every other group the second (color_map is also used by the trajectory plot below).
palette = plotly.colors.qualitative.Plotly
datasets = sorted(set(s["group"] for s in selected_steps))
color_map = {ds: palette[0] if HEALTHY_KEY in ds else palette[1] for ds in datasets}

# Rows of `embedding` follow the order of `selected_steps`.
assert len(embedding) == len(selected_steps), "embedding and selected_steps are out of sync"

# One trace per dataset (instead of one per step) keeps the figure light and yields
# exactly one legend entry per dataset.
fig = go.Figure()
for dataset in datasets:
    step_idx = [i for i, meta in enumerate(selected_steps) if meta["group"] == dataset]
    points = embedding[step_idx]  # (n_steps_in_dataset, 3)

    fig.add_trace(go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        mode='markers',
        name=dataset,
        legendgroup=dataset,
        showlegend=True,
        marker=dict(size=SCATTER_SIZE, symbol=SCATTER_SYMBOL, color=color_map.get(dataset, "gray"),
                    line=dict(width=SCATTER_LINE_WIDTH, color='black')),
        hoverinfo='text',
        text=[f"{dataset} | {selected_steps[i]['mouse']} | run={selected_steps[i]['run']}" for i in step_idx]
    ))

fig.update_layout(
    title="Step-Averaged Latent Embeddings by Dataset",
    scene=dict(
        xaxis_title='UMAP1',
        yaxis_title='UMAP2',
        zaxis_title='UMAP3'
    ),
    legend=dict(title="Dataset"),
    width=900,
    height=700,
    template='plotly_white'
)

fig.show()

In [None]:
## TODO: make the 3D plot with trajectories for both pre and post sci
# Time-resolved latent view: embed every valid (non-padded) latent frame of a random subset
# of steps with UMAP and draw each step as a 3D trajectory.
N_TRAJECTORY_STEPS = 100  # kept small so the plot stays readable
TRAJECTORY_SEED = 42      # fixed so the sampled subset is reproducible

sample_generator = torch.Generator().manual_seed(TRAJECTORY_SEED)
sample_idx = torch.randperm(mu_t.shape[0], generator=sample_generator)[:N_TRAJECTORY_STEPS]
# Local names only: the full mu_t from the evaluation cell is left untouched.
mu_t_sample = mu_t[sample_idx]
lengths_sample = lengths_all[sample_idx]
selected_sample = [selected_steps[i] for i in sample_idx.tolist()]

# Keep only each sampled step's valid frames so padding does not enter the UMAP fit.
sample_latents = [
    mu_t_sample[i, :int(lengths_sample[i])].cpu().numpy()  # (T_i, L)
    for i in range(mu_t_sample.shape[0])
]

# Reduce L → 3 with a dedicated UMAP (the step-summary `reducer` keeps its own fit).
trajectory_reducer = UMAP(random_state=42, n_components=3)
embedding_flat = trajectory_reducer.fit_transform(np.concatenate(sample_latents, axis=0))  # (sum T_i, 3)

# Split the flat embedding back into one (T_i, 3) trajectory per sampled step.
split_points = np.cumsum([latent.shape[0] for latent in sample_latents])[:-1]
step_embeddings = np.split(embedding_flat, split_points)

print("Moving to 3D plotting...")

# Plot each step as a 3D trajectory, pairing it with its *sampled* metadata.
fig = go.Figure()
legend_shown = defaultdict(bool)

for meta, emb in zip(selected_sample, step_embeddings):
    dataset = meta["group"]
    mouse = meta["mouse"]
    run = meta["run"]
    color = color_map.get(dataset, "gray")
    show_legend = not legend_shown[dataset]
    legend_shown[dataset] = True

    fig.add_trace(go.Scatter3d(
        x=emb[:, 0],
        y=emb[:, 1],
        z=emb[:, 2],
        mode='lines+markers',
        name=dataset if show_legend else None,
        legendgroup=dataset,
        showlegend=show_legend,
        line=dict(color=color),
        marker=dict(size=SCATTER_SIZE, color=color, line=dict(width=SCATTER_LINE_WIDTH, color='black')),
        hoverinfo='text',
        text=[f"{dataset} | {mouse} | run={run} | t={t}" for t in range(len(emb))]
    ))

fig.update_layout(
    title="Latent Trajectories Over Time (per Step)",
    scene=dict(
        xaxis_title='UMAP1',
        yaxis_title='UMAP2',
        zaxis_title='UMAP3'
    ),
    legend=dict(title="Dataset"),
    width=900,
    height=700,
    template='plotly_white'
)

fig.show()

In [None]:
def plot_timewise_umap_trajectories(segmented_hindsteps, scaler, side=None, healthy_key=None):
    """
    Plot a UMAP projection of every timepoint of every step, one 3D trajectory per step.

    Each frame of each selected step is standardised with ``scaler``. A 3-component UMAP is
    then fitted on all frames together, and the frames of one step are joined into a line.
    Steps from the same dataset (group) share a colour and a single legend entry, using the
    same convention as the other latent plots in this notebook: groups whose name contains
    ``healthy_key`` get the first Plotly colour, all others the second.

    Args:
        segmented_hindsteps: iterable of step records with keys "step" (DataFrame),
            "group", "mouse" and "run".
        scaler: fitted scaler whose ``transform`` accepts ``step.values``.
        side: substring a record's "mouse" must contain to be included
            (default: ``SIDE_KEY[0]``).
        healthy_key: substring identifying healthy groups (default: ``HEALTHY_KEY``).

    Raises:
        ValueError: if no step matches ``side``.
    """
    if side is None:
        side = SIDE_KEY[0]
    if healthy_key is None:
        healthy_key = HEALTHY_KEY

    filtered_steps = [s for s in segmented_hindsteps if side in s["mouse"]]
    if not filtered_steps:
        raise ValueError(f"No steps found whose mouse label contains {side!r}.")

    # Standardise each step with the training scaler: one (T_i, D) array per step.
    scaled_steps = [scaler.transform(s["step"].values) for s in filtered_steps]

    # One metadata row per frame, in the same order as the concatenated frames below.
    metadata = [
        {"step_id": step_id, "t": t, "mouse": s["mouse"], "group": s["group"], "run": s["run"]}
        for step_id, (s, arr) in enumerate(zip(filtered_steps, scaled_steps))
        for t in range(arr.shape[0])
    ]

    X = np.concatenate(scaled_steps, axis=0)  # (sum T_i, D)
    umap_coords = UMAP(n_components=3, random_state=42).fit_transform(X)

    umap_df = pd.DataFrame(umap_coords, columns=["UMAP1", "UMAP2", "UMAP3"])
    umap_df = pd.concat([umap_df, pd.DataFrame(metadata)], axis=1)

    palette = plotly.colors.qualitative.Plotly
    datasets = sorted(umap_df["group"].unique())
    color_map = {ds: palette[0] if healthy_key in ds else palette[1] for ds in datasets}

    fig = go.Figure()
    legend_shown = defaultdict(bool)

    for _, frames in umap_df.groupby("step_id"):
        dataset = frames["group"].iloc[0]
        color = color_map[dataset]
        show_legend = not legend_shown[dataset]
        legend_shown[dataset] = True

        fig.add_trace(go.Scatter3d(
            x=frames["UMAP1"],
            y=frames["UMAP2"],
            z=frames["UMAP3"],
            mode="lines+markers",
            line=dict(color=color, width=2),
            marker=dict(size=3, color=color),
            name=dataset,
            legendgroup=dataset,
            showlegend=show_legend,
            text=[f"{frames['mouse'].iloc[0]} | run={r} | t={t}" for r, t in zip(frames["run"], frames["t"])]
        ))

    fig.update_layout(
        title="Time-Resolved UMAP of Step Dynamics by Dataset",
        scene=dict(
            xaxis_title="UMAP1",
            yaxis_title="UMAP2",
            zaxis_title="UMAP3"
        ),
        legend=dict(title="Dataset", font=dict(size=12)),
        width=900,
        height=700,
        template="plotly_white"
    )

    fig.show()

# Usage: embed every frame of every SIDE_KEY[0] hindstep and plot one trajectory per step.
# UMAP is fitted on all frames at once, so this cell may take a while.
plot_timewise_umap_trajectories(segmented_hindsteps, scaler)