# Analyzing time-dependent data

In [1]:
import os
import pandas as pd
import numpy as np
import plotly.express as px
import torch
import timeit

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from datetime import datetime
from umap import UMAP

from tqdm.notebook import tqdm

from src.Models import LSTMAutoencoder, LSTMVAE, LSTMVAE_t

In [2]:
## Directories and constants
FIGURES_DIR = './figures/'
MODELS_DIR = './src/models'
DATA_DIR = './csv/Pain_Plot_Features'
DATASETS = ['A', 'B', 'C', 'D', 'E']
GROUPS = ['pre', 'post']
DIRECTION = ['left', 'right']

## Hyperparameters and early stopping
INPUT_DIM = None
HIDDEN_DIM = 64
LATENT_DIM = 16
BATCH_SIZE = 32
NUM_EPOCHS = 500
LR = 1e-3
PATIENCE = 50 # number of epochs to wait for improvement before stopping
MIN_DELTA = 1e-4 # minimum change to qualify as an improvement
BEST_MODEL_PATH = None#os.path.join(MODELS_DIR, 'lstm_VAE_no_first_last_20250609_121841.pt')

## Plot constants
SCATTER_SIZE = 6
SCATTER_LINE_WIDTH = 1
SCATTER_SYMBOL = 'circle'
LEGEND_FONT_SIZE = 18
TITLE_FONT_SIZE = 24
AXIS_FONT_SIZE = 16
AXIS_TITLE_FONT_SIZE = 20

# Load the data
data = {}
directory = os.listdir(DATA_DIR)
for file in directory:
    if file.endswith('.csv'):
        components = file.split('_')

        ## Assuming the file naming convention is:
        # dataset_group_mouse_direction_run.csv
        dataset = components[0]
        group = components[1]
        mouse = components[2]
        direction = components[3]
        run = components[4]

        datagroup = dataset + '_' + group
        if datagroup not in data:
            data[datagroup] = {}

        mouse_direction = mouse + '_' + direction
        if mouse_direction not in data[datagroup]:
            data[datagroup][mouse_direction] = {}

        data[datagroup][mouse_direction][run] = pd.read_csv(
            os.path.join(DATA_DIR, file), index_col=0)

In [3]:
## Print the data structure
for datagroup, mice in data.items():
    print(f"Group: {datagroup}")
    for mouse_direction, runs in mice.items():
        print(f"\t{mouse_direction}: {len(runs)} runs with shapes: {[df.shape for df in runs.values()]}")

Group: A_postDLC
	mouse1_left: 3 runs with shapes: [(310, 199), (313, 199), (297, 199)]
	mouse1_right: 4 runs with shapes: [(259, 199), (251, 199), (233, 199), (351, 199)]
	mouse2_left: 4 runs with shapes: [(209, 199), (216, 199), (124, 199), (155, 199)]
	mouse2_right: 3 runs with shapes: [(223, 199), (193, 199), (211, 199)]
	mouse3_left: 5 runs with shapes: [(249, 199), (231, 199), (241, 199), (318, 199), (244, 199)]
	mouse3_right: 4 runs with shapes: [(297, 199), (230, 199), (154, 199), (224, 199)]
	mouse4_left: 3 runs with shapes: [(269, 199), (188, 199), (221, 199)]
	mouse4_right: 3 runs with shapes: [(283, 199), (213, 199), (195, 199)]
	mouse5_left: 3 runs with shapes: [(301, 199), (210, 199), (321, 199)]
	mouse5_right: 3 runs with shapes: [(341, 199), (212, 199), (191, 199)]
	mouse6_left: 4 runs with shapes: [(302, 199), (157, 199), (297, 199), (201, 199)]
	mouse6_right: 4 runs with shapes: [(271, 199), (235, 199), (163, 199), (141, 199)]
	mouse7_left: 3 runs with shapes: [(172, 

In [4]:
# Print the features count
for datagroup, mice in data.items():
    for mouse_direction, runs in mice.items():
        for run, df in runs.items():
            print(f"Number of features: {df.shape[1] + 1 - 4}") # +1 bc of index, -4 to exclude frame, forestep, hindstep, and time
            break
        break
    break

Number of features: 196


In [5]:
def segment_steps_by_phase(df, phase_col="phase"):
    phases = df[phase_col]

    # 1. Find where phase changes
    changes = phases != phases.shift()
    change_points = df.index[changes]

    # 2. Get the phases at these change points
    change_phases = phases.loc[change_points].reset_index(drop=True)

    # 3. Identify "stance" → ... → "swing" → "stance" sequences
    segments = []
    i = 0
    while i < len(change_phases) - 2:
        if change_phases[i] == "stance":
            swing_found = False
            for j in range(i+1, len(change_phases)):
                if change_phases[j] == "swing":
                    swing_found = True
                elif swing_found and change_phases[j] == "stance":
                    # # Skip first and last segment
                    # if i == 0 or j == len(change_phases) - 1:
                    #     i = j - 1
                    #     break
                    start_idx = change_points[i]
                    end_idx = change_points[j] - 1
                    segment = df.loc[start_idx:end_idx].drop(columns=["Step Phase Forelimb", "Step Phase Hindlimb"], errors='ignore')
                    
                    ## Reset x to 0
                    pose_cols = [col for col in segment.columns if 'pose' in col]
                    segment[pose_cols] = segment[pose_cols] - segment[pose_cols].iloc[0]
                    segments.append(segment)
                    i = j - 1  # skip ahead
                    break
            else:
                # no closing stance; go to end
                start_idx = change_points[i]
                segment = df.loc[start_idx:].drop(columns=["Step Phase Forelimb", "Step Phase Hindlimb"], errors='ignore')
                ## Reset x to 0
                pose_cols = [col for col in segment.columns if 'pose' in col]
                segment[pose_cols] = segment[pose_cols] - segment[pose_cols].iloc[0]
                segments.append(segment)
                break
        i += 1

    return segments

segmented_hindsteps = []
segmented_foresteps = []
for group in data:
    for mouse in data[group]:
        for run in data[group][mouse]:
            df = data[group][mouse][run]
            to_drop = ['Frame', "Time (s)"]
            df = df.drop(columns=to_drop, errors='ignore')
            hindsteps = segment_steps_by_phase(df, phase_col="Step Phase Hindlimb")  # or the correct column name
            for step_df in hindsteps:
                segmented_hindsteps.append({
                    "step": step_df,
                    "group": group,
                    "mouse": mouse,
                    "run": run
                })
            foresteps = segment_steps_by_phase(df, phase_col="Step Phase Forelimb")  # or the correct column name
            for step_df in foresteps:
                segmented_foresteps.append({
                    "step": step_df,
                    "group": group,
                    "mouse": mouse,
                    "run": run
                })

In [6]:
# Flatten all steps into a single array to compute global mean/std
all_healthy_arrays = [step_dict["step"].values for step_dict in segmented_hindsteps if "pre" in step_dict["group"] and "left" in step_dict["mouse"]]
flat_data = np.vstack(all_healthy_arrays)

scaler = StandardScaler()
scaler = scaler.fit(flat_data)

In [7]:
def steps_to_tensor(step_dicts, scaler):
    """
        Convert a list of step dictionaries to a padded tensor of shape (num_steps, max_length, num_features).

        Returns:
            - A tensor of shape (num_steps, max_length, num_features) containing the scaled step data (B, T, F).
            - A tensor of lengths for each step indicating the actual length of each step (B).
    """
    step_arrays = [scaler.transform(sd["step"].values) for sd in step_dicts]
    lengths = [len(step) for step in step_arrays]
    max_len = max(lengths)
    dim = step_arrays[0].shape[1]

    padded = np.zeros((len(step_arrays), max_len, dim), dtype=np.float32)
    for i, arr in enumerate(step_arrays):
        padded[i, :len(arr)] = arr

    return torch.tensor(padded), torch.tensor(lengths)

healthy_steps = [s for s in segmented_hindsteps if "pre" in s["group"] and "left" in s["mouse"]]
step_tensor, lengths = steps_to_tensor(healthy_steps, scaler)

print(f"Step tensor shape: {step_tensor.shape}, \nLengths shape: {lengths.shape}")

Step tensor shape: torch.Size([560, 189, 196]), 
Lengths shape: torch.Size([560])


In [None]:
def masked_mse_loss(pred, target, lengths):
    """
        Compute the masked mean squared error loss between predicted and target tensors.
    """
    mask = torch.arange(target.size(1))[None, :].to(lengths.device) < lengths[:, None]
    mask = mask.unsqueeze(-1).expand_as(target)  # (B, T, F)
    mse = (pred - target) ** 2
    masked_mse = (mse * mask).sum() / mask.sum()
    return masked_mse

def masked_vae_loss(x_hat, x, lengths, mu, logvar, free_bits=0.1):
    """
    Masked VAE loss: reconstruction + KL divergence
    """
    B, T, F = x.shape
    mask = torch.arange(T, device=lengths.device)[None, :] < lengths[:, None]
    mask = mask.unsqueeze(-1).expand_as(x)

    recon_loss = ((x_hat - x) ** 2) * mask
    recon_loss = recon_loss.sum() / mask.sum()

    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / B

    kl_div = torch.clamp(kl_div, min=0.1)  # where free_bits ~ 0.1

    return recon_loss + kl_div

def masked_vae_t_loss(x_hat, x, lengths, mu_t, logvar_t, beta=1.0, free_bits=0.5):
    """
    Time-resolved VAE loss with masking support, per-unit free bits.
    """
    B, T, F = x.shape
    device = lengths.device

    # Compute mask
    mask = torch.arange(T, device=device)[None, :] < lengths[:, None]
    mask = mask.unsqueeze(-1).expand_as(x)  # (B, T, F)

    # Reconstruction loss
    recon_loss = ((x_hat - x) ** 2) * mask
    recon_loss = recon_loss.sum() / mask.sum()

    # KL divergence per latent unit
    kl_per_unit = -0.5 * (1 + logvar_t - mu_t.pow(2) - logvar_t.exp())  # (B, T, D)
    
    # Mask time steps (assume all D dimensions are valid per timepoint)
    time_mask = mask[..., 0]  # (B, T)

    # Clamp each latent unit to free bits
    kl_clamped = torch.clamp(kl_per_unit, min=free_bits)
    kl_loss = (kl_clamped * time_mask.unsqueeze(-1)).sum() / time_mask.sum()

    # Debug prints
    # kl_raw = (kl_per_unit * time_mask.unsqueeze(-1)).sum() / time_mask.sum()
    # tqdm.write(f"Reconstruction loss: {recon_loss.item():.4f}, Raw KL: {kl_raw.item():.4f}, Clamped KL: {kl_loss.item():.4f}, β: {beta:.3f}")

    return recon_loss + beta * kl_loss

class EarlyStopping:
    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.should_stop = False

    def step(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

if INPUT_DIM is None:
    INPUT_DIM = step_tensor.shape[2]

if BEST_MODEL_PATH:
    print(f"Loading best model from {BEST_MODEL_PATH}")
    ## Load the best model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # model = LSTMAutoencoder(INPUT_DIM, HIDDEN_DIM, LATENT_DIM)
    # model = LSTMVAE(INPUT_DIM, HIDDEN_DIM, LATENT_DIM)
    model = LSTMVAE_t(INPUT_DIM, HIDDEN_DIM, LATENT_DIM)
    checkpoint = torch.load(BEST_MODEL_PATH, map_location=device)
    print(f"Loaded model from epoch {checkpoint['epoch']} with validation loss {checkpoint['val_loss']:.4f}")
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
else:
    # Data loader
    train_idx, val_idx = train_test_split(np.arange(len(step_tensor)), test_size=0.2, random_state=42)
    train_data = TensorDataset(step_tensor[train_idx], lengths[train_idx])
    val_data = TensorDataset(step_tensor[val_idx], lengths[val_idx])

    def collate_fn(batch):
        """
            Custom collate function to handle variable-length sequences.
            Returns a padded tensor and lengths.
        """
        x, lengths = zip(*batch)
        # print(f"element shapes: {[xi.shape for xi in x]}, lengths: {lengths}")
        return torch.stack(x).to(device), torch.stack(lengths).to(device)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    # Model, optimizer
    # model = LSTMAutoencoder(INPUT_DIM, HIDDEN_DIM, LATENT_DIM)
    # model = LSTMVAE(INPUT_DIM, HIDDEN_DIM, LATENT_DIM)
    model = LSTMVAE_t(INPUT_DIM, HIDDEN_DIM, LATENT_DIM)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    early_stopper = EarlyStopping(patience=PATIENCE, min_delta=MIN_DELTA)
    best_val_loss = float('inf')
    now = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_filname = f'lstm_autoencoder_{now}.pt'
    best_model_path = os.path.join(MODELS_DIR, f'lstm_autoencoder_{now}.pt')

    # Training loop
    val_losses = []
    train_losses = []
    WARMUP_EPOCHS = 20
    for epoch in tqdm(range(NUM_EPOCHS)):
        t1 = timeit.default_timer()
        beta = 1#min(1.0, epoch / WARMUP_EPOCHS)
        model.train()
        train_loss = 0
        for batch_x, batch_lens in train_loader:
            # batch_x, batch_lens = batch_x.to(device), batch_lens.to(device)
            # tqdm.write(f"Batch shape: {batch_x.shape}, Lengths shape: {batch_lens.shape}")
            # tqdm.write(f"Device of batch_x: {batch_x.device}, batch_lens: {batch_lens.device}")
            
            ## LSTMVAE
            # x_hat, mu, logvar = model(batch_x, batch_lens)
            # loss = masked_vae_loss(x_hat, batch_x, batch_lens, mu, logvar)
            
            ## LSTMVAE_t
            x_hat, mu_t, logvar_t = model(batch_x, batch_lens)
            loss = masked_vae_t_loss(x_hat, batch_x, batch_lens, mu_t, logvar_t, beta=beta)
            optimizer.zero_grad()
            ## LSTMAutoencoder
            # pred, _ = model(batch_x, batch_lens)
            # loss = masked_mse_loss(pred, batch_x, batch_lens)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch_x, batch_lens in val_loader:
                # batch_x, batch_lens = batch_x.to(device), batch_lens.to(device)
                x_hat, mu_t, logvar_t = model(batch_x, batch_lens)
                loss = masked_vae_t_loss(x_hat, batch_x, batch_lens, mu_t, logvar_t, beta=beta)
                # x_hat, mu, logvar = model(batch_x, batch_lens)
                # loss = masked_vae_loss(x_hat, batch_x, batch_lens, mu, logvar)
                # pred, _ = model(batch_x, batch_lens)
                # loss = masked_mse_loss(pred, batch_x, batch_lens).item()
                val_loss += loss.item()
        avg_train = train_loss / len(train_loader)
        avg_val = val_loss / len(val_loader)
        val_losses.append(avg_val)
        train_losses.append(avg_train)
        t2 = timeit.default_timer()
            
        # Save model if validation loss improves
        if avg_val < best_val_loss:
            tqdm.write(f"Epoch {epoch+1}, Train: {avg_train:.4f} - Val: {avg_val:.4f} - Time: {t2-t1:.2f}s")
            best_val_loss = avg_val
            best_epoch = epoch
            best_state = model.state_dict()

        early_stopper.step(avg_val)
        if early_stopper.should_stop:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    torch.save({'model_state_dict':best_state,
                'epoch': epoch,
                'val_loss': best_val_loss},
                best_model_path)
    print(f"Best model saved at: {best_model_path} with val loss: {best_val_loss:.4f}")

    ## Save the losses in a plotly figure and a csv file
    losses_df = pd.DataFrame({
        'epoch': np.arange(len(train_losses)),
        'train_loss': train_losses,
        'val_loss': val_losses
    })
    losses_df.to_csv(os.path.join(FIGURES_DIR, 'losses.csv'), index=False)
    fig = px.line(losses_df, x='epoch', y=['train_loss', 'val_loss'],
                  labels={'value': 'Loss', 'epoch': 'Epoch'},
                  title='Training and Validation Losses')
    fig.update_layout(title_font_size=TITLE_FONT_SIZE,
                      xaxis_title_font_size=AXIS_TITLE_FONT_SIZE,
                      yaxis_title_font_size=AXIS_TITLE_FONT_SIZE,
                      legend_font_size=LEGEND_FONT_SIZE)
    fig.show()

Using device: cpu
VAE_t device: cpu


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1, Train: 9.0251 - Val: 8.8739 - Time: 11.01s
Epoch 11, Train: 8.6602 - Val: 8.5551 - Time: 7.46s
Epoch 21, Train: 8.4516 - Val: 8.4018 - Time: 7.62s
Epoch 31, Train: 8.3536 - Val: 8.3154 - Time: 7.01s
Epoch 41, Train: 8.2940 - Val: 8.2783 - Time: 6.86s
Epoch 51, Train: 8.2723 - Val: 8.2535 - Time: 6.86s
Epoch 61, Train: 8.2702 - Val: 8.2461 - Time: 6.95s
Epoch 71, Train: 8.2460 - Val: 8.2415 - Time: 6.93s
Epoch 81, Train: 8.2479 - Val: 8.2384 - Time: 9.39s
Epoch 91, Train: 8.2365 - Val: 8.2226 - Time: 7.00s
Epoch 101, Train: 8.2312 - Val: 8.2159 - Time: 6.59s
Epoch 111, Train: 8.2255 - Val: 8.2130 - Time: 6.58s
Epoch 121, Train: 8.2211 - Val: 8.2116 - Time: 7.48s
Epoch 131, Train: 8.2201 - Val: 8.2073 - Time: 6.63s
Epoch 141, Train: 8.2167 - Val: 8.2004 - Time: 7.75s
Epoch 151, Train: 8.2161 - Val: 8.2078 - Time: 6.97s
Epoch 161, Train: 8.2084 - Val: 8.1980 - Time: 7.72s
Epoch 171, Train: 8.2102 - Val: 8.1961 - Time: 6.85s
Epoch 181, Train: 8.2041 - Val: 8.1947 - Time: 7.28s
Epo

In [33]:
selected_steps = [
    s for s in segmented_hindsteps 
    if "left" in s["mouse"] 
]

step_tensor_all, lengths_all = steps_to_tensor(selected_steps, scaler)

## Get all embeddings for selected steps
model.eval()
with torch.no_grad():
    # _, embeddings_all = model(step_tensor_all, lengths_all) # Shape: (B, T, F) -> (B, L)
    # _, embeddings_all, _ = model(step_tensor_all, lengths_all)  # Shape: (B, L)
    _, mu_t, _ = model(step_tensor_all, lengths_all)  # (B, T, D)

# Convert each step’s latent sequence to list of (T_i, D)
mu_t_masked = [
    mu_t[i, :lengths_all[i]].cpu().numpy()  # shape: (T_i, D)
    for i in range(mu_t.shape[0])
]

In [35]:
### Plot UMAP over time

# Stack all latent vectors across all steps
all_latents = np.concatenate(mu_t_masked, axis=0)  # shape: (total_timepoints, D)

# Fit UMAP
umap_coords = UMAP(n_components=3, random_state=42).fit_transform(all_latents)

# Now split the coordinates back by original steps
umap_split = []
idx = 0
for i in range(len(mu_t_masked)):
    length = mu_t_masked[i].shape[0]
    umap_split.append(umap_coords[idx:idx + length])  # shape: (T_i, 3)
    idx += length

import plotly.graph_objs as go
import plotly.colors
from collections import defaultdict

datasets = sorted(set(s["group"] for s in selected_steps))
color_map = {ds: color for ds, color in zip(datasets, plotly.colors.qualitative.Plotly)}
# if pre is in group, use the first color, otherwise use the second color
for i, ds in enumerate(datasets):
    if "pre" in ds:
        color_map[ds] = plotly.colors.qualitative.Plotly[0]
    else:
        color_map[ds] = plotly.colors.qualitative.Plotly[1]

fig = go.Figure()
legend_shown = defaultdict(bool)

for i, umap_seq in enumerate(umap_split):
    step_meta = selected_steps[i]
    dataset = step_meta["group"]
    mouse = step_meta["mouse"]
    run = step_meta["run"]
    color = color_map.get(dataset, "gray")
    show_legend = not legend_shown[dataset]
    legend_shown[dataset] = True

    fig.add_trace(go.Scatter3d(
        x=umap_seq[:, 0],
        y=umap_seq[:, 1],
        z=umap_seq[:, 2],
        mode='lines+markers',
        name=dataset if show_legend else None,
        legendgroup=dataset,
        showlegend=show_legend,
        line=dict(width=2, color=color),
        marker=dict(size=2, color=color),
        hoverinfo='text',
        text=[f"{dataset} | {mouse} | run={run} | t={t}" for t in range(len(umap_seq))]
    ))

fig.update_layout(
    title="Latent Trajectories Over Time by Dataset",
    scene=dict(
        xaxis_title='UMAP1',
        yaxis_title='UMAP2',
        zaxis_title='UMAP3'
    ),
    legend=dict(title="Dataset"),
    width=900,
    height=700,
    template='plotly_white'
)

fig.show()


'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.


n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [23]:
### Visualization of embeddings
umap_coords = UMAP(n_components=3, random_state=42).fit_transform(embeddings_all)

umap_df = pd.DataFrame({
    "UMAP1": umap_coords[:, 0],
    "UMAP2": umap_coords[:, 1],
    "UMAP3": umap_coords[:, 2],
    "Group": ["healthy" if "pre" in s["group"] else "unhealthy" for s in selected_steps],
    "Mouse": [s["mouse"] for s in selected_steps],
    "Run": [s["run"] for s in selected_steps],
    "Dataset": [s["group"] for s in selected_steps],
})


'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.



ValueError: Found array with dim 3. None expected <= 2.

In [None]:
fig = px.scatter_3d(
    umap_df,
    x="UMAP1", y="UMAP2", z="UMAP3",
    color="Dataset",#"Group",                
    hover_data=["Mouse", "Run", "Dataset"],
    title="3D UMAP Projection of Step Embeddings"
)

fig.update_traces(marker=dict(size=SCATTER_SIZE, line=dict(width=SCATTER_LINE_WIDTH, color='DarkSlateGrey')))
fig.update_layout(
    legend=dict(title="Group", font=dict(size=LEGEND_FONT_SIZE)), template="plotly_white",
    width=900,
    height=700,
    scene=dict(
            xaxis=dict(title_font=dict(size=AXIS_TITLE_FONT_SIZE), tickfont=dict(size=AXIS_FONT_SIZE)),
            yaxis=dict(title_font=dict(size=AXIS_TITLE_FONT_SIZE), tickfont=dict(size=AXIS_FONT_SIZE)),
            zaxis=dict(title_font=dict(size=AXIS_TITLE_FONT_SIZE), tickfont=dict(size=AXIS_FONT_SIZE)),
        ),
    title=dict(font=dict(size=TITLE_FONT_SIZE))
    )
fig.show()

In [None]:
import plotly.graph_objects as go
import plotly.colors

def plot_timewise_umap_trajectories(segmented_hindsteps, scaler):
    """
    Plot UMAP projection of all timepoints in each step.
    Each step becomes a trajectory (line) in 3D UMAP space.
    Mice from the same dataset (e.g., 'pre', 'post') share color and legend.
    """
    from collections import defaultdict

    filtered_steps = [s for s in segmented_hindsteps if "left" in s["mouse"]]

    step_id = 0
    timepoints = []
    metadata = []

    for s in filtered_steps:
        arr = scaler.transform(s["step"].values)  # (T, D)
        for t in range(arr.shape[0]):
            timepoints.append(arr[t])
            metadata.append({
                "step_id": step_id,
                "t": t,
                "mouse": s["mouse"],
                "group": s["group"],
                "run": s["run"]
            })
        step_id += 1

    X = np.stack(timepoints)
    umap_coords = UMAP(n_components=3, random_state=42).fit_transform(X)

    umap_df = pd.DataFrame(umap_coords, columns=["UMAP1", "UMAP2", "UMAP3"])
    umap_df = pd.concat([umap_df, pd.DataFrame(metadata)], axis=1)

    datasets = sorted(umap_df["group"].unique())
    color_map = {ds: color for ds, color in zip(datasets, plotly.colors.qualitative.Plotly)}

    fig = go.Figure()
    legend_shown = defaultdict(bool)

    for step_id, group in umap_df.groupby("step_id"):
        dataset = group["group"].iloc[0]
        color = color_map[dataset]
        show_legend = not legend_shown[dataset]
        legend_shown[dataset] = True

        fig.add_trace(go.Scatter3d(
            x=group["UMAP1"],
            y=group["UMAP2"],
            z=group["UMAP3"],
            mode="lines+markers",
            line=dict(color=color, width=2),
            marker=dict(size=3, color=color),
            name=dataset,
            legendgroup=dataset,
            showlegend=show_legend,
            text=[f"{group['mouse'].iloc[0]} | run={r} | t={t}" for r, t in zip(group["run"], group["t"])]
        ))

    fig.update_layout(
        title="Time-Resolved UMAP of Step Dynamics by Dataset",
        scene=dict(
            xaxis_title="UMAP1",
            yaxis_title="UMAP2",
            zaxis_title="UMAP3"
        ),
        legend=dict(title="Dataset", font=dict(size=12)),
        width=900,
        height=700,
        template="plotly_white"
    )

    fig.show()

# Usage:
plot_timewise_umap_trajectories(segmented_hindsteps, scaler)