In [None]:
import os
import pandas as pd
import numpy as np
import torch
import plotly.express as px
import plotly.graph_objs as go
import plotly.colors
from collections import defaultdict

from sklearn.preprocessing import StandardScaler
from umap import UMAP

from src.ModelUtils import (load_model,
                            train_model)
from src.DataUtils import (load_data_from_directory,
                           segment_all_steps,
                           steps_to_tensor)

In [None]:
## Directories and constants
FIGURES_DIR = './figures/'
MODELS_DIR = './src/models'
DATA_DIR = './csv/SCI_Plot_Features'
HEALTHY_KEY = 'Pre'   # substring used to select the healthy groups
SICK_KEY = 'Post'
SIDE_KEY = ('Left', 'Right')

## Reproducibility
# Seed the random number generators so that a "Restart & Run All" reproduces the
# same training run. 42 matches the random_state already used for UMAP below.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

## Hyperparameters and early stopping
INPUT_DIM = None  # filled in from the step tensor once the data has been prepared
HIDDEN_DIM = 64
LATENT_DIM = 16
BATCH_SIZE = 32
NUM_EPOCHS = 5000
LR = 1e-3
PATIENCE = 100  # number of epochs to wait for improvement before stopping
MIN_DELTA = 1e-4  # minimum change to qualify as an improvement
# Set to a checkpoint path to load a saved model instead of training, e.g.
#   os.path.join(MODELS_DIR, 'lstm_VAE_no_first_last_20250609_121841.pt')
BEST_MODEL_PATH = None

## Plot constants
SCATTER_SIZE = 6
SCATTER_LINE_WIDTH = 1
SCATTER_SYMBOL = 'circle'
LEGEND_FONT_SIZE = 18
TITLE_FONT_SIZE = 24
AXIS_FONT_SIZE = 16
AXIS_TITLE_FONT_SIZE = 20

# Load the data
data = load_data_from_directory(DATA_DIR)

# Print the data structure: one line per mouse with the shape of each of its runs
for datagroup, mice in data.items():
    print(f"Group: {datagroup}")
    for mouse_direction, runs in mice.items():
        print(f"\t{mouse_direction}: {len(runs)} runs with shapes: {[df.shape for df in runs.values()]}")

In [None]:
# Report the feature count from the first run of the first mouse in the first group.
# Each lookup yields None when its level is empty, in which case nothing is printed.
first_mice = next(iter(data.values()), None)
first_runs = next(iter(first_mice.values()), None) if first_mice is not None else None
first_df = next(iter(first_runs.values()), None) if first_runs is not None else None

if first_df is not None:
    # +1 bc of index, -4 to exclude frame, forestep, hindstep, and time
    print(f"Number of features: {first_df.shape[1] + 1 - 4}")

In [None]:
## Segment all steps in the data
# segment_all_steps returns two collections of step records (hind steps and fore steps).
# Each record is a dict with the keys:
#     "step": step_df,
#     "group": group,
#     "mouse": mouse,
#     "run": run
segmented_hindsteps, segmented_foresteps = segment_all_steps(data)

## Select the training steps once: healthy groups, first side.
# The same selection feeds both the scaler and the training tensor, so it is built
# a single time to keep the two from drifting apart.
healthy_steps = [
    s for s in segmented_hindsteps
    if HEALTHY_KEY in s["group"] and SIDE_KEY[0] in s["mouse"]
]
if not healthy_steps:
    raise ValueError(
        f"No hind steps found with {HEALTHY_KEY!r} in the group name and "
        f"{SIDE_KEY[0]!r} in the mouse name (matching is case-sensitive); "
        "check HEALTHY_KEY, SIDE_KEY and the contents of DATA_DIR."
    )

## Flatten all healthy steps into a single array to compute global mean/std and get the scaler
all_healthy_arrays = [s["step"].values for s in healthy_steps]
flat_data = np.vstack(all_healthy_arrays)
scaler = StandardScaler().fit(flat_data)

## Prepare the data for training
step_tensor, lengths = steps_to_tensor(healthy_steps, scaler)
if INPUT_DIM is None:
    # The feature dimension is read from the third axis of the step tensor
    INPUT_DIM = step_tensor.shape[2]

print(f"Step tensor shape: {step_tensor.shape}, \nLengths shape: {lengths.shape}")

In [None]:
### Training block
# Load the model from BEST_MODEL_PATH when one is configured; otherwise call train_model.
model = None
architecture = dict(input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM, latent_dim=LATENT_DIM)

if not BEST_MODEL_PATH:
    model = train_model(
        step_tensor=step_tensor,
        lengths=lengths,
        **architecture,
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        lr=LR,
        patience=PATIENCE,
        min_delta=MIN_DELTA,
        models_dir=MODELS_DIR,
        figures_dir=FIGURES_DIR,
        title_font_size=TITLE_FONT_SIZE,
        axis_font_size=AXIS_FONT_SIZE,
        legend_font_size=LEGEND_FONT_SIZE,
    )
else:
    model = load_model(model_path=BEST_MODEL_PATH, **architecture)

# Guard the following cells against a helper that returned nothing
assert model is not None, "Model training failed or no model was loaded."

In [None]:
### Embed the hind steps of the first side (all groups) with the model
# Match the side case-insensitively: the training cell selects mice with SIDE_KEY[0]
# ('Left'), so a hard-coded lowercase "left" can silently select nothing.
side_token = SIDE_KEY[0].lower()
selected_steps = [
    s for s in segmented_hindsteps
    if side_token in s["mouse"].lower()
]
if not selected_steps:
    raise ValueError(
        f"No hind steps found for mice whose name contains {SIDE_KEY[0]!r} "
        "(case-insensitive); check SIDE_KEY and the mouse names in the data."
    )

step_tensor_all, lengths_all = steps_to_tensor(selected_steps, scaler)

## Get all embeddings for selected steps
model.eval()
with torch.no_grad():
    _, mu_t, _ = model(step_tensor_all, lengths_all)  # (B, T, D)

# Convert each step's latent sequence to a list of (T_i, D) arrays, dropping the
# positions beyond each step's own length
mu_t_np = mu_t.detach().cpu().numpy()
mu_t_masked = [
    mu_t_np[i, :int(lengths_all[i])]  # shape: (T_i, D)
    for i in range(mu_t_np.shape[0])
]

# Stack all latent vectors across all steps
all_latents = np.concatenate(mu_t_masked, axis=0)  # shape: (total_timepoints, D)

# Fit UMAP
umap_coords = UMAP(n_components=3, random_state=42).fit_transform(all_latents)

# Split the coordinates back by original step: cutting at the cumulative step
# lengths (last boundary excluded) yields one (T_i, 3) array per step, in order
step_lengths = [latents.shape[0] for latents in mu_t_masked]
umap_split = np.split(umap_coords, np.cumsum(step_lengths)[:-1])

In [None]:
# One colour for the healthy groups and one for all the others.
# The healthy key is matched case-insensitively so it agrees with the training
# cell, which selects healthy groups with HEALTHY_KEY ('Pre').
datasets = sorted(set(s["group"] for s in selected_steps))
healthy_color, other_color = plotly.colors.qualitative.Plotly[:2]
healthy_token = HEALTHY_KEY.lower()
color_map = {
    ds: healthy_color if healthy_token in ds.lower() else other_color
    for ds in datasets
}

fig = go.Figure()
legend_shown = defaultdict(bool)  # dataset -> whether its legend entry already exists

# One 3D trace per step; umap_split is aligned one-to-one with selected_steps
for step_meta, umap_seq in zip(selected_steps, umap_split):
    dataset = step_meta["group"]
    mouse = step_meta["mouse"]
    run = step_meta["run"]
    color = color_map.get(dataset, "gray")
    # Only the first trace of each dataset gets a legend entry; the rest share
    # its legend group so they toggle together
    show_legend = not legend_shown[dataset]
    legend_shown[dataset] = True

    fig.add_trace(go.Scatter3d(
        x=umap_seq[:, 0],
        y=umap_seq[:, 1],
        z=umap_seq[:, 2],
        mode='lines+markers',
        name=dataset,
        legendgroup=dataset,
        showlegend=show_legend,
        line=dict(width=2, color=color),
        marker=dict(size=2, color=color),
        hoverinfo='text',
        text=[f"{dataset} | {mouse} | run={run} | t={t}" for t in range(len(umap_seq))]
    ))

fig.update_layout(
    title="Latent Trajectories Over Time by Dataset",
    scene=dict(
        xaxis_title='UMAP1',
        yaxis_title='UMAP2',
        zaxis_title='UMAP3'
    ),
    legend=dict(title="Dataset"),
    width=900,
    height=700,
    template='plotly_white'
)

fig.show()

In [None]:
def plot_timewise_umap_trajectories(segmented_hindsteps, scaler, side="left"):
    """
    Plot a 3D UMAP projection of every timepoint of the selected steps.

    UMAP is fit on the scaled input features (not on model latents). Each step
    becomes one trajectory (line) in UMAP space, and steps from the same dataset
    (the "group" field) share a colour and a single legend entry.

    Parameters
    ----------
    segmented_hindsteps : iterable of dict
        Step records with the keys "step" (object exposing ``.values``),
        "group", "mouse" and "run".
    scaler : object with a ``transform`` method
        Applied to the values of each step before the projection.
    side : str, default "left"
        Only steps whose "mouse" name contains this substring are plotted.
        Matching is case-insensitive, so "left" also selects "Left".

    Raises
    ------
    ValueError
        If no step matches ``side``.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.
    """
    side_token = side.lower()
    filtered_steps = [s for s in segmented_hindsteps if side_token in s["mouse"].lower()]
    if not filtered_steps:
        raise ValueError(f"No steps found for mice whose name contains {side!r} (case-insensitive).")

    # Scale each step once; every row of a scaled step is one timepoint
    scaled_steps = [scaler.transform(s["step"].values) for s in filtered_steps]
    X = np.vstack(scaled_steps)
    umap_coords = UMAP(n_components=3, random_state=42).fit_transform(X)

    # One metadata row per timepoint, in the same order as the rows of X
    metadata = pd.DataFrame([
        {
            "step_id": step_id,
            "t": t,
            "mouse": s["mouse"],
            "group": s["group"],
            "run": s["run"],
        }
        for step_id, (s, arr) in enumerate(zip(filtered_steps, scaled_steps))
        for t in range(arr.shape[0])
    ])
    umap_df = pd.concat(
        [pd.DataFrame(umap_coords, columns=["UMAP1", "UMAP2", "UMAP3"]), metadata],
        axis=1,
    )

    # Cycle through the palette so that every dataset gets a colour even when
    # there are more datasets than palette entries
    palette = plotly.colors.qualitative.Plotly
    datasets = sorted(umap_df["group"].unique())
    color_map = {ds: palette[i % len(palette)] for i, ds in enumerate(datasets)}

    fig = go.Figure()
    datasets_in_legend = set()

    for _, step_df in umap_df.groupby("step_id"):
        dataset = step_df["group"].iloc[0]
        mouse = step_df["mouse"].iloc[0]
        color = color_map[dataset]
        # Only the first trace of each dataset gets a legend entry
        show_legend = dataset not in datasets_in_legend
        datasets_in_legend.add(dataset)

        fig.add_trace(go.Scatter3d(
            x=step_df["UMAP1"],
            y=step_df["UMAP2"],
            z=step_df["UMAP3"],
            mode="lines+markers",
            line=dict(color=color, width=2),
            marker=dict(size=3, color=color),
            name=dataset,
            legendgroup=dataset,
            showlegend=show_legend,
            text=[f"{mouse} | run={r} | t={t}" for r, t in zip(step_df["run"], step_df["t"])]
        ))

    fig.update_layout(
        title="Time-Resolved UMAP of Step Dynamics by Dataset",
        scene=dict(
            xaxis_title="UMAP1",
            yaxis_title="UMAP2",
            zaxis_title="UMAP3"
        ),
        legend=dict(title="Dataset", font=dict(size=12)),
        width=900,
        height=700,
        template="plotly_white"
    )

    fig.show()

# Usage:
plot_timewise_umap_trajectories(segmented_hindsteps, scaler)