In [None]:
import os
import pandas as pd
import numpy as np
import torch
import plotly.express as px
import plotly.graph_objs as go
import plotly.colors
from collections import defaultdict

from sklearn.preprocessing import StandardScaler
from umap import UMAP

from src.ModelUtils import (load_model,
                            train_model,
                            save_model_losses,
                            load_and_plot_losses)
from src.DataUtils import (load_data_from_directory,
                           segment_all_steps,
                           steps_to_tensor,
                           average_run_features_by_mouse)
from src.PlotUtils import (plot_averaged_xy_trajectories_plotly)

In [None]:
## Directories and constants
# FIGURES_DIR: output folder handed to save_model_losses and base of BEST_LOSS_PATH.
# MODELS_DIR: folder handed to train_model and base of BEST_MODEL_PATH.
# DATA_DIR: folder read by load_data_from_directory.
FIGURES_DIR = './figures/SCI'
MODELS_DIR = './src/models'
DATA_DIR = './csv/SCI_nimo_placebo'
# HEALTHY_KEY is matched as a substring of each step's "group" label to pick
# the steps used for fitting the scaler and for training.
HEALTHY_KEY = 'Pre'
# SICK_KEY is not referenced by the cells visible in this notebook.
SICK_KEY = 'Post'
# Only SIDE_KEY[0] is used below, as a substring test on each step's "mouse" label.
SIDE_KEY = ('Left','Right')

## Hyperparameters and early stopping
# INPUT_DIM stays None here and is filled in from step_tensor.shape[2] once the
# training tensor has been built.
INPUT_DIM = None
HIDDEN_DIM = 64
LATENT_DIM = 16
BATCH_SIZE = 32
NUM_EPOCHS = 5000
LR = 1e-3
PATIENCE = 100 # number of epochs to wait for improvement before stopping
MIN_DELTA = 1e-4 # minimum change to qualify as an improvement
# A truthy BEST_MODEL_PATH makes the training cell load this checkpoint (and plot
# the losses stored at BEST_LOSS_PATH) instead of training; set it to None to retrain.
BEST_MODEL_PATH = os.path.join(MODELS_DIR, 'lstm_VAE_t_SCI_20250609_200452.pt')
BEST_LOSS_PATH = os.path.join(FIGURES_DIR, 'SCI_losses.csv')

## Plot constants
# Marker styling shared by the 3D scatter plots, and font sizes forwarded to the
# loss-plot helpers.
SCATTER_SIZE = 6
SCATTER_LINE_WIDTH = 1
SCATTER_SYMBOL = 'circle'
LEGEND_FONT_SIZE = 18
TITLE_FONT_SIZE = 24
AXIS_FONT_SIZE = 16
AXIS_TITLE_FONT_SIZE = 20

# Read every recording found under DATA_DIR
data = load_data_from_directory(DATA_DIR)

# Summarise the nested layout: group -> mouse/direction -> runs (one frame per run)
for group_name in data:
    print(f"Group: {group_name}")
    for mouse_label, run_frames in data[group_name].items():
        frame_shapes = [frame.shape for frame in run_frames.values()]
        print(f"\t{mouse_label}: {len(run_frames)} runs with shapes: {frame_shapes}")

In [None]:
# Report the feature count of the very first run (first group, first mouse).
# Each lookup falls back to an empty container so nothing is printed when that
# first level is empty.
first_group_mice = next(iter(data.values()), {})
first_mouse_runs = next(iter(first_group_mice.values()), {})
first_run_frame = next(iter(first_mouse_runs.values()), None)

if first_run_frame is not None:
    # +1 bc of index, -4 to exclude frame, forestep, hindstep, and time
    print(f"Number of features: {first_run_frame.shape[1] + 1 - 4}")

In [None]:
## Cut the recordings into individual steps
# segment_all_steps returns two collections (hind steps, fore steps). Each entry
# is a record with the keys used throughout this notebook:
#     "step"  -> DataFrame holding the frames of one step
#     "group" -> experimental group label
#     "mouse" -> mouse / side label
#     "run"   -> run identifier
segmented_hindsteps, segmented_foresteps = segment_all_steps(data)

## Per-mouse average over runs
averaged_hindsteps = average_run_features_by_mouse(segmented_hindsteps)

## Training subset: healthy group, first side only
healthy_steps = [
    entry for entry in segmented_hindsteps
    if HEALTHY_KEY in entry["group"] and SIDE_KEY[0] in entry["mouse"]
]

## Fit the scaler on every frame of the training subset (global mean/std)
all_healthy_arrays = [entry["step"].values for entry in healthy_steps]
flat_data = np.vstack(all_healthy_arrays)
scaler = StandardScaler().fit(flat_data)

## Build the scaled training tensor together with the per-step lengths
step_tensor, lengths = steps_to_tensor(healthy_steps, scaler)

# Infer the model input width from the data unless it was set by hand
INPUT_DIM = step_tensor.shape[2] if INPUT_DIM is None else INPUT_DIM

In [None]:
# Sanity check: shape of the first segmented step, of the first averaged step,
# and of the training tensor with its lengths vector.
print(f'Segmented dataframe shape: {segmented_hindsteps[0]["step"].shape}')
print(f'Averaged dataframe shape: {averaged_hindsteps[0]["step"].shape}')
print(f"Step tensor shape: {step_tensor.shape}, \nLengths shape: {lengths.shape}")

In [None]:
### XY pose plotting: column selection
## Collect the hindlimb pose columns from the first segmented step
# NOTE(review): 'rhindlimb' reads like the right hindlimb, while the other cells
# filter on SIDE_KEY[0] ('Left') — confirm this pairing is intended.
HINDLIMB_KEY = 'rhindlimb'
POSE_KEY = 'pose'

first_step_record = next(iter(segmented_hindsteps))
first_df = first_step_record["step"]

# Keep the columns whose name mentions both the limb and the pose marker
hindlimb_keys = []
for column_name in first_df.columns:
    if HINDLIMB_KEY in column_name and POSE_KEY in column_name:
        hindlimb_keys.append(column_name)

print(f"Hindlimb pose keys: {hindlimb_keys}")

In [None]:
# i=0
# for step in averaged_hindsteps:
#     step_df = step["step"][hindlimb_keys]
#     mouse = step["mouse"]
#     group = step["group"]
#     if HEALTHY_KEY not in group:
#         continue
#     if i>10:
#         break
#     i+=1
    
#     fig = go.Figure()
#     x_keys = [key for key in hindlimb_keys if 'x' in key.lower()]
#     y_keys = [key for key in hindlimb_keys if 'y' in key.lower()]
    
#     # Get joint base names (without x/y)
#     joint_names = [x.split('_x')[0] for x in x_keys]
    
#     # For each frame, create a trace connecting all joints
#     for frame_idx, frame in enumerate(step_df.index):
#         frame_data = step_df.loc[frame]
        
#         # Create arrays for x, y positions of all joints
#         x_positions = [frame_data[x] for x in x_keys]
#         y_positions = [frame_data[y] for y in y_keys]
#         z_positions = [frame_idx] * len(x_keys)  # Same z (frame) for all joints
        
#         # Add trace for this frame
#         fig.add_trace(go.Scatter(
#             x=x_positions,
#             y=y_positions,
#             # z=z_positions,
#             mode='lines+markers',
#             name=f"Frame {frame}",
#             line=dict(width=SCATTER_LINE_WIDTH, color=f'rgba(100,100,255,{0.3 + 0.7*frame_idx/len(step_df)})'),
#             marker=dict(
#                 size=SCATTER_SIZE, 
#                 symbol=SCATTER_SYMBOL,
#                 color=f'rgba(100,100,255,{0.3 + 0.7*frame_idx/len(step_df)})'
#             ),
#             text=[f"{joint} - Frame {frame}" for joint in joint_names],
#             hoverinfo='text'
#         ))
    
#     # Add vertical lines connecting same joints across frames
#     for joint_idx in range(len(x_keys)):
#         # Get all x, y positions for this joint across frames
#         joint_x = [step_df.loc[frame, x_keys[joint_idx]] for frame in step_df.index]
#         joint_y = [step_df.loc[frame, y_keys[joint_idx]] for frame in step_df.index]
#         joint_z = list(range(len(step_df)))  # Frames as z values
        
#         fig.add_trace(go.Scatter(
#             x=joint_x,
#             y=joint_y,
#             # z=joint_z,
#             mode='lines',
#             line=dict(width=1, color='rgba(200,200,200,0.5)'),
#             showlegend=False
#         ))
        
#     fig.update_layout(
#         title=f"3D Hindlimb Pose Trajectory - {group} {mouse}",
#         scene=dict(
#             xaxis_title="X Position (m)",
#             yaxis_title="Y Position (m)",
#             # zaxis_title="Frame",
#             aspectmode='manual',
#             # aspectratio=dict(x=1, y=1, z=2)
#         ),
#         width=900,
#         height=700,
#         legend=dict(font=dict(size=LEGEND_FONT_SIZE)),
#         title_font=dict(size=TITLE_FONT_SIZE),
#         template='plotly_white'
#     )
    
#     fig.show()

In [None]:
### Training block
# Either restore a previously trained checkpoint or train a new model.
# BEST_MODEL_PATH is the switch: a truthy value means "load that checkpoint and
# plot the stored loss curves"; None / '' means "train with the hyperparameters
# from the configuration cell and save the loss curves".
model = None
if BEST_MODEL_PATH:
    model = load_model(model_path=BEST_MODEL_PATH, input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM, latent_dim=LATENT_DIM)
    load_and_plot_losses(losses_file=BEST_LOSS_PATH,
                        title_font_size=TITLE_FONT_SIZE,
                        axis_title_font_size=AXIS_TITLE_FONT_SIZE,
                        legend_font_size=LEGEND_FONT_SIZE)
else:
    model, train_losses, val_losses = train_model(
        step_tensor=step_tensor,
        lengths=lengths,
        input_dim=INPUT_DIM,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        lr=LR,
        patience=PATIENCE,
        min_delta=MIN_DELTA,
        models_dir=MODELS_DIR
    )
    # Use AXIS_TITLE_FONT_SIZE for the axis titles, as the load branch does, so
    # the loss figure looks the same whether the model was trained or restored.
    save_model_losses(train_losses, val_losses, figures_dir=FIGURES_DIR,
                        title_font_size=TITLE_FONT_SIZE,
                        axis_title_font_size=AXIS_TITLE_FONT_SIZE,
                        legend_font_size=LEGEND_FONT_SIZE)

# Fail early: every later cell needs a usable model.
assert model is not None, "Model training failed or no model was loaded."

In [None]:
### Embed every hindstep of the first side (all groups, healthy and pathological)
selected_steps = [
    s for s in segmented_hindsteps
    if SIDE_KEY[0] in s["mouse"]
]

# Scale with the scaler fitted on the healthy steps so every group is expressed
# relative to the training distribution.
step_tensor_all, lengths_all = steps_to_tensor(selected_steps, scaler)

## Get all embeddings for selected steps
model.eval()
with torch.no_grad():
    x_hat, mu_t, logvar_t = model(step_tensor_all, lengths_all)  # (B, T, F), (B, T, L), (B, T, L)

    # Summarise each step by the mean latent over its valid frames only.
    # lengths_all is used as the per-step number of valid frames (it is the slice
    # bound applied to the time axis in the trajectory cell); frames past that
    # bound are padding and would otherwise drag the summary of short steps
    # towards the padding value.
    valid_lengths = torch.as_tensor(lengths_all, device=mu_t.device)
    frame_index = torch.arange(mu_t.shape[1], device=mu_t.device)
    valid_mask = frame_index.unsqueeze(0) < valid_lengths.unsqueeze(1)  # (B, T)
    masked_sum = (mu_t * valid_mask.unsqueeze(-1)).sum(dim=1)  # (B, L)
    # clamp guards against a division by zero for a zero-length step
    frame_counts = valid_lengths.clamp(min=1).to(mu_t.dtype).unsqueeze(1)  # (B, 1)
    z_summary = masked_sum / frame_counts  # → shape: (B, L)

reducer = UMAP(random_state=42, n_components=3)
embedding = reducer.fit_transform(z_summary.cpu().numpy())  # shape: (B, 3)

# Convert each step’s latent sequence to list of (T_i, L)
# mu_t_masked = [
#     mu_t[i, :lengths_all[i]].cpu().numpy()  # shape: (T_i, L)
#     for i in range(mu_t.shape[0])
# ]

# Stack all latent vectors across all steps
# all_latents = np.concatenate(mu_t_masked, axis=0)  # shape: (sum of T_i, L)

# # Fit UMAP
# umap_coords = UMAP(n_components=3, random_state=42).fit_transform(all_latents)

# ## Now split the coordinates back by original steps
# # for each step, have all the umap coordinates
# umap_split = []
# idx = 0
# for i in range(len(mu_t_masked)):
#     length = mu_t_masked[i].shape[0]
#     umap_split.append(umap_coords[idx:idx + length])  # shape: (T_i, 3)
#     idx += length

In [None]:
# Colour scheme: every group whose label contains HEALTHY_KEY gets the first
# palette colour, every other group gets the second one.
palette = plotly.colors.qualitative.Plotly
datasets = sorted(set(s["group"] for s in selected_steps))
color_map = {ds: (palette[0] if HEALTHY_KEY in ds else palette[1]) for ds in datasets}

# Bucket the point indices by group (insertion order = order of first
# appearance, which fixes the legend order). Drawing one trace per group instead
# of one trace per point keeps the figure light: the rendered markers, hover
# texts and legend entries are the same, but Plotly handles a handful of traces
# instead of one per step.
indices_by_dataset = {}
for i in range(len(embedding)):
    indices_by_dataset.setdefault(selected_steps[i]["group"], []).append(i)

embedding_points = np.asarray(embedding)

fig = go.Figure()
for dataset, point_ids in indices_by_dataset.items():
    points = embedding_points[point_ids]  # (n_points_in_group, 3)

    hover_text = []
    for i in point_ids:
        step_meta = selected_steps[i]
        hover_text.append(f"{dataset} | {step_meta['mouse']} | run={step_meta['run']}")

    fig.add_trace(go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        mode='markers',
        name=dataset,
        legendgroup=dataset,
        showlegend=True,
        marker=dict(size=SCATTER_SIZE, color=color_map.get(dataset, "gray"), line=dict(width=SCATTER_LINE_WIDTH, color='black')),
        hoverinfo='text',
        text=hover_text
    ))

# NOTE(review): each marker is one step's time-averaged latent, not a trajectory —
# consider rewording the title.
fig.update_layout(
    title="Latent Trajectories Over Time by Dataset",
    scene=dict(
        xaxis_title='UMAP1',
        yaxis_title='UMAP2',
        zaxis_title='UMAP3'
    ),
    legend=dict(title="Dataset"),
    width=900,
    height=700,
    template='plotly_white'
)

fig.show()

In [None]:
## TODO: make the 3D plot with trajectories for both pre and post sci
N_TRAJECTORIES = 100   # number of steps drawn as trajectories
TRAJECTORY_SEED = 42   # fixed seed so the same steps are drawn on every run

# Draw a reproducible random subset of steps. mu_t itself is left untouched so
# this cell can be re-run without shrinking its own input.
sample_generator = torch.Generator().manual_seed(TRAJECTORY_SEED)
sample_idx = torch.randperm(mu_t.shape[0], generator=sample_generator)[:N_TRAJECTORIES]
mu_t_sample = mu_t[sample_idx]
lengths_sample = torch.as_tensor(lengths_all)[sample_idx]
selected_sample = [selected_steps[i] for i in sample_idx.tolist()]

B, T, L = mu_t_sample.shape

# Number of valid frames per sampled step, capped at the padded length T
sample_lengths = [min(int(n), T) for n in lengths_sample.tolist()]

# Keep only the valid frames of every sampled step; padded frames are not real
# observations and must not take part in the UMAP fit.
valid_latents = [
    mu_t_sample[i, :n].cpu().numpy()  # (T_i, L)
    for i, n in enumerate(sample_lengths)
]
mu_t_flat = np.concatenate(valid_latents, axis=0)  # (sum of T_i, L)

# Reduce L → 3 with a dedicated reducer so the one fitted on the per-step
# summaries is not re-fitted as a side effect.
trajectory_reducer = UMAP(random_state=42, n_components=3)
embedding_flat = trajectory_reducer.fit_transform(mu_t_flat)  # (sum of T_i, 3)

# Split the flat embedding back into one (T_i, 3) array per step
embedding_steps = np.split(embedding_flat, np.cumsum(sample_lengths)[:-1])

print("Moving to 3D plotting...")

# Plot each step as a 3D trajectory
fig = go.Figure()
legend_shown = defaultdict(bool)

# Pair every trajectory with the metadata of the *sampled* step it came from
# (position i in the sample corresponds to selected_steps[sample_idx[i]]).
for emb, meta in zip(embedding_steps, selected_sample):
    dataset = meta["group"]
    mouse = meta["mouse"]
    run = meta["run"]
    color = color_map.get(dataset, "gray")
    show_legend = not legend_shown[dataset]
    legend_shown[dataset] = True

    fig.add_trace(go.Scatter3d(
        x=emb[:, 0],
        y=emb[:, 1],
        z=emb[:, 2],
        mode='lines+markers',
        name=dataset,
        legendgroup=dataset,
        showlegend=show_legend,
        # colour the connecting line like the markers so a trajectory reads as one group
        line=dict(color=color, width=2),
        marker=dict(size=SCATTER_SIZE, color=color, line=dict(width=SCATTER_LINE_WIDTH, color='black')),
        hoverinfo='text',
        text=[f"{dataset} | {mouse} | run={run} | t={t}" for t in range(len(emb))]
    ))

fig.update_layout(
    title="Latent Trajectories Over Time (per Step)",
    scene=dict(
        xaxis_title='UMAP1',
        yaxis_title='UMAP2',
        zaxis_title='UMAP3'
    ),
    legend=dict(title="Dataset"),
    width=900,
    height=700,
    template='plotly_white'
)

fig.show()

In [None]:
def plot_timewise_umap_trajectories(segmented_hindsteps, scaler):
    """
    Plot a 3D UMAP projection of every scaled frame, one trajectory per step.

    Only steps whose "mouse" label contains SIDE_KEY[0] are used. Each frame of
    each step is scaled with `scaler`, all frames are embedded together with
    UMAP, and the frames of one step are drawn as a connected line. Steps of
    the same group share a colour and a single legend entry.

    Parameters
    ----------
    segmented_hindsteps : iterable of dict
        Step records with the keys "step" (DataFrame of frames), "mouse",
        "group" and "run".
    scaler : object with a `transform` method
        Applied to `step["step"].values` of every selected step.

    Raises
    ------
    ValueError
        If no step matches the side filter (there is nothing to embed).
    """
    filtered_steps = [s for s in segmented_hindsteps if SIDE_KEY[0] in s["mouse"]]
    if not filtered_steps:
        raise ValueError(
            f"No steps with '{SIDE_KEY[0]}' in their mouse label; nothing to embed."
        )

    # Scale step by step and record one metadata row per frame, in the same order
    # in which the frames are stacked below.
    scaled_steps = []
    metadata = []
    for step_id, s in enumerate(filtered_steps):
        arr = scaler.transform(s["step"].values)  # (T, D)
        scaled_steps.append(arr)
        metadata.extend(
            {
                "step_id": step_id,
                "t": t,
                "mouse": s["mouse"],
                "group": s["group"],
                "run": s["run"]
            }
            for t in range(arr.shape[0])
        )

    X = np.vstack(scaled_steps)  # (total frames, D)
    umap_coords = UMAP(n_components=3, random_state=42).fit_transform(X)

    umap_df = pd.DataFrame(umap_coords, columns=["UMAP1", "UMAP2", "UMAP3"])
    umap_df = pd.concat([umap_df, pd.DataFrame(metadata)], axis=1)

    # Cycle through the palette so that more groups than palette colours still
    # get a colour each (a plain zip would silently drop the extra groups).
    palette = plotly.colors.qualitative.Plotly
    datasets = sorted(umap_df["group"].unique())
    color_map = {ds: palette[i % len(palette)] for i, ds in enumerate(datasets)}

    fig = go.Figure()
    legend_done = set()  # groups that already own a legend entry

    for step_id, step_rows in umap_df.groupby("step_id"):
        dataset = step_rows["group"].iloc[0]
        color = color_map[dataset]
        show_legend = dataset not in legend_done
        legend_done.add(dataset)

        fig.add_trace(go.Scatter3d(
            x=step_rows["UMAP1"],
            y=step_rows["UMAP2"],
            z=step_rows["UMAP3"],
            mode="lines+markers",
            line=dict(color=color, width=2),
            marker=dict(size=3, color=color),
            name=dataset,
            legendgroup=dataset,
            showlegend=show_legend,
            text=[f"{step_rows['mouse'].iloc[0]} | run={r} | t={t}" for r, t in zip(step_rows["run"], step_rows["t"])]
        ))

    fig.update_layout(
        title="Time-Resolved UMAP of Step Dynamics by Dataset",
        scene=dict(
            xaxis_title="UMAP1",
            yaxis_title="UMAP2",
            zaxis_title="UMAP3"
        ),
        legend=dict(title="Dataset", font=dict(size=12)),
        width=900,
        height=700,
        template="plotly_white"
    )

    fig.show()

# Usage:
plot_timewise_umap_trajectories(segmented_hindsteps, scaler)