In [1]:
import os
import sys

# The notebook runs from 'notebooks/'; the importable project sources live in its parent.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Export the root for any child processes launched from this kernel ...
os.environ["PYTHONPATH"] = project_root

# ... and make it importable in the running interpreter, ahead of everything else.
if project_root not in sys.path:
    sys.path[:0] = [project_root]

# Sanity check of what was exported.
print("PYTHONPATH =", os.environ["PYTHONPATH"])

PYTHONPATH = c:\Users\shich\Src\thesis\hand_emg_regression


In [2]:
from emg_hand_tracking.model import Model

CHECKPOINT_PATH = "../checkpoints/c0.ckpt"

# Restore the trained model from its checkpoint with all tensors mapped to the CPU.
m = Model.load_from_checkpoint(CHECKPOINT_PATH, map_location="cpu")
# Switch to inference mode (affects layers such as dropout / batch-norm, if present).
# Kept as an assignment so the cell does not echo the module repr as its output.
m = m.eval()

In [3]:
import torch
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots


def plot_3d_hands_dual(
    sequence1: torch.Tensor | np.ndarray,
    sequence2: torch.Tensor | np.ndarray,
    name1: str = "Sequence 1",
    name2: str = "Sequence 2",
):
    """Animate two hand-landmark sequences side by side in two 3D scenes.

    Parameters:
        sequence1: Sequence drawn in the left scene (blue markers), indexed as
            ``[frame, landmark, xyz]``.
        sequence2: Sequence drawn in the right scene (green markers), indexed
            the same way.
        name1: Title shown above the left scene.
        name2: Title shown above the right scene.

    Torch tensors are detached and copied to the CPU first. When the two
    sequences differ in length only their common prefix is animated, so a
    mismatch never indexes past the end of the shorter one. The figure is
    displayed with ``fig.show()``; nothing is returned.

    Raises:
        ValueError: if either sequence has no frames.
    """
    frame_ms = 33  # delay between animation frames in milliseconds (~30 fps)

    def to_numpy(seq):
        # Plotly needs host-side arrays; tensors may live on an accelerator / require grad.
        if isinstance(seq, torch.Tensor):
            return seq.detach().cpu().numpy()
        return np.asarray(seq)

    sequence1 = to_numpy(sequence1)
    sequence2 = to_numpy(sequence2)

    n_frames = min(sequence1.shape[0], sequence2.shape[0])
    if n_frames == 0:
        raise ValueError("both sequences must contain at least one frame")

    # Skeleton edges as (landmark_i, landmark_j) index pairs.
    # NOTE(review): these indices assume a 20-landmark hand rooted at landmark 0 —
    # confirm against the layout produced by the FK model.
    connections = [
        (0, 1),
        (1, 2),
        (2, 3),
        (0, 4),
        (4, 5),
        (5, 6),
        (6, 7),
        (0, 8),
        (8, 9),
        (9, 10),
        (10, 11),
        (0, 12),
        (12, 13),
        (13, 14),
        (14, 15),
        (0, 16),
        (16, 17),
        (17, 18),
        (18, 19),
        (1, 4),
        (4, 8),
        (8, 12),
        (12, 16),
    ]
    edge_start, edge_end = np.array(connections).T

    fig = make_subplots(
        rows=1,
        cols=2,
        specs=[[{"type": "scene"}, {"type": "scene"}]],
        horizontal_spacing=0.05,
    )

    def create_hand_trace(seq, frame_idx, color):
        """Return ``[landmark markers, skeleton lines]`` for one frame of ``seq``."""
        points = seq[frame_idx]  # (landmarks, 3)
        # All bones go into ONE line trace: each bone contributes the rows
        # (start, end, NaN) and the NaN row breaks the polyline between bones.
        # This is far cheaper to animate with redraw=True than one trace per bone.
        segments = np.full((len(connections), 3, 3), np.nan)
        segments[:, 0] = points[edge_start]
        segments[:, 1] = points[edge_end]
        segments = segments.reshape(-1, 3)
        return [
            go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode="markers+text",
                marker=dict(size=4, color=color),
                text=[str(i) for i in range(points.shape[0])],
                textposition="top center",
                textfont=dict(size=8),
                name="Landmarks",
                showlegend=False,
            ),
            go.Scatter3d(
                x=segments[:, 0],
                y=segments[:, 1],
                z=segments[:, 2],
                mode="lines",
                line=dict(color="black", width=2),
                connectgaps=False,
                showlegend=False,
            ),
        ]

    fig.add_traces(create_hand_trace(sequence1, 0, "blue"), rows=1, cols=1)
    fig.add_traces(create_hand_trace(sequence2, 0, "green"), rows=1, cols=2)

    # Frame data must list traces in the same order they were added above
    # (left hand first, then right hand).
    frames = [
        go.Frame(
            data=create_hand_trace(sequence1, k, "blue")
            + create_hand_trace(sequence2, k, "green"),
            name=str(k),
        )
        for k in range(n_frames)
    ]
    fig.frames = frames

    # Both scenes share the same fixed axes/camera so the hands are directly comparable.
    # The reversed x range ([2, -2]) is kept from the original view setup.
    scene_layout = dict(
        xaxis=dict(range=[2, -2], title="X"),
        yaxis=dict(range=[-1, 3], title="Y"),
        zaxis=dict(range=[-2, 2], title="Z"),
        camera=dict(eye=dict(x=0.2, y=0.2, z=0.2)),
        aspectmode="cube",
    )

    fig.update_layout(
        updatemenus=[
            {
                "buttons": [
                    {
                        "args": [
                            None,
                            {
                                "frame": {"duration": frame_ms, "redraw": True},
                                "fromcurrent": True,
                            },
                        ],
                        "label": "Play",
                        "method": "animate",
                    },
                    {
                        # [None] (a list) is Plotly's idiom for "stop the animation".
                        "args": [
                            [None],
                            {
                                "frame": {"duration": 0, "redraw": True},
                                "mode": "immediate",
                            },
                        ],
                        "label": "Pause",
                        "method": "animate",
                    },
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 30},
                "showactive": False,
                "type": "buttons",
                "x": 0.1,
                "xanchor": "right",
                "y": 0,
                "yanchor": "top",
            }
        ],
        sliders=[
            {
                "pad": {"b": 10, "t": 10},
                "len": 0.9,
                "x": 0.1,
                "xanchor": "left",
                "y": 0,
                "yanchor": "top",
                "steps": [
                    {
                        "args": [
                            [f.name],
                            {
                                "frame": {"duration": frame_ms, "redraw": True},
                                "mode": "immediate",
                            },
                        ],
                        "label": str(k),
                        "method": "animate",
                    }
                    for k, f in enumerate(frames)
                ],
            }
        ],
        scene=scene_layout,
        scene2=scene_layout,
        annotations=[
            dict(
                text=name1,
                showarrow=False,
                x=0.225,
                xref="paper",
                y=1.05,
                yref="paper",
                font=dict(size=22),
            ),
            dict(
                text=name2,
                showarrow=False,
                x=0.775,
                xref="paper",
                y=1.05,
                yref="paper",
                font=dict(size=22),
            ),
        ],
        margin=dict(l=0, r=0, b=10, t=40),
        height=500,
    )

    fig.show()

In [4]:
import time
import plotly.graph_objs as go
from emg_hand_tracking.dataset import (
    HandEmgRecordingSegment,
    load_recordings,
    get_pose_format,
)
from emg_hand_tracking.model.fk import FK_BY_POSE_FORMAT

dataset_path = "../datasets/0.zip"
d = load_recordings(dataset_path, m.emg_samples_per_frame)

# Work with the last segment of the first loaded recording.
seg = d[0][-1]
print(f"Couples in selected segment: {len(seg.couples)}")

# Keep only the trailing validation window (couples that were not used for training).
VAL_WINDOW_COUPLES = 248
seg = HandEmgRecordingSegment(couples=seg.couples[-VAL_WINDOW_COUPLES:], sigma=seg.sigma)

emg = seg.emg
poses = seg.frames

# Time one forward pass. perf_counter() is monotonic and high-resolution, so unlike
# time.time() the measured interval cannot be skewed by wall-clock adjustments.
start_time = time.perf_counter()
with torch.no_grad():
    y_hat = m.forward(
        emg=torch.tensor(emg, dtype=torch.float32),
        # The ground-truth poses of the first frames_per_window frames are passed
        # to the model as its initial poses.
        initial_poses=torch.tensor(
            poses[: m.frames_per_window, :],
            dtype=torch.float32,
        ),
    ).cpu()  # (T + 1 - I, 20) per the original note — TODO confirm against the model
end_time = time.perf_counter()

predictions_per_second = y_hat.shape[0] / (end_time - start_time)
print(f"Predictions per second: {predictions_per_second:.2f}")
print()

# Downsample both streams 2x. Ground truth starts after the frames_per_window poses
# that were handed to the model as initial poses.
DOWNSAMPLE = 2
y_hat = y_hat[::DOWNSAMPLE, :]
y = torch.tensor(poses[m.frames_per_window :, :][::DOWNSAMPLE, :], dtype=torch.float32)

# Prediction and ground truth are compared step by step below. A length mismatch would
# either raise later or, with a length-1 side, broadcast silently — so fail here, loudly.
assert y_hat.shape[0] == y.shape[0], (
    f"prediction/ground-truth length mismatch: {y_hat.shape[0]} vs {y.shape[0]}"
)

# Forward-kinematics function registered for this dataset's pose format.
raw_fk = FK_BY_POSE_FORMAT[get_pose_format(dataset_path)]


def fk(y):
    """Apply ``raw_fk`` to ``y`` and divide the resulting coordinates by 90.

    NOTE(review): 90.0 appears to rescale landmarks into the small ranges used by
    the 3D scenes; the error plot later multiplies by 9.0 to report centimetres,
    which implies raw_fk returns millimetres — confirm.
    """
    landmarks = raw_fk(y)
    return landmarks / 90.0


# fk is called with a leading batch axis of size 1, which is stripped again afterwards.
landmarks_pred = fk(y_hat[None])[0]  # (S, L, 3)
landmarks_gt = fk(y[None])[0]  # (S, L, 3)

plot_3d_hands_dual(landmarks_gt, landmarks_pred, name1="Ground Truth", name2="Estimation")


def plot_error_percentiles(
    data: np.ndarray,
    title: str,
    xlabel: str,
    ylabel: str,
):
    """
    Plots percentile bands of prediction errors over time using Plotly.

    At every step the values across landmarks are summarized by nested bands
    centred on the median (0-100, 10-90, 20-80, 30-70 and 40-60 percentiles,
    drawn progressively darker towards the centre) plus a black median line.

    Parameters:
        data (np.ndarray): Array of shape (S, L), where S is number of steps, L is number of landmarks.
        title (str): Figure title.
        xlabel (str): Label of the x (step) axis.
        ylabel (str): Label of the y (error) axis.

    Raises:
        AssertionError: if ``data`` is not 2-D.
        ValueError: if ``data`` is empty.
    """
    data = np.asarray(data)
    assert data.ndim == 2, "Input must be a 2D array of shape (S, L)"
    if data.size == 0:
        raise ValueError("data must contain at least one step and one landmark")

    S = data.shape[0]
    x = np.arange(S)
    # Levels 0, 10, ..., 100. Pairing level i with level -(i + 1) gives bands that are
    # symmetric around the median (0-100, ..., 40-60); starting at 0 rather than 10 is
    # what keeps every band centred on the 50th percentile.
    quantile_levels = np.arange(0, 100 + 1, 10)
    # One call for all levels: row k holds the quantile_levels[k] percentile of each step.
    percentiles = np.percentile(data, quantile_levels, axis=1)  # (len(levels), S)

    fig = go.Figure()

    # Add percentile bands, outermost first; each is a closed polygon low -> high(reversed).
    for i in range(len(quantile_levels) // 2):
        low = percentiles[i]
        high = percentiles[-(i + 1)]
        fig.add_trace(
            go.Scatter(
                x=np.concatenate([x, x[::-1]]),
                y=np.concatenate([low, high[::-1]]),
                fill="toself",
                fillcolor=f"rgba(31, 119, 180, {0.15 + 0.1 * i})",
                line=dict(color="rgba(255,255,255,0)"),
                showlegend=False,
            )
        )

    # Median line (the middle level, 50).
    median = percentiles[len(quantile_levels) // 2]
    fig.add_trace(
        go.Scatter(
            x=x,
            y=median,
            mode="lines",
            line=dict(color="black", width=2),
            showlegend=False,
        )
    )

    # Y-axis upper limit: largest value + 5% headroom. Fall back to 1.0 when every
    # value is zero so the axis does not collapse to zero height.
    y_max = float(data.max()) * 1.05
    if not y_max > 0:
        y_max = 1.0

    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center", pad=dict(t=10)),
        xaxis=dict(
            title=xlabel,
            showgrid=True,
            zeroline=False,
            showline=False,
            ticks="outside",
            constrain="domain",
        ),
        yaxis=dict(
            title=ylabel,
            range=[0, y_max],
            showgrid=True,
            zeroline=False,
            showline=False,
            ticks="outside",
            ticklabelposition="outside",
            automargin=True,
            constrain="domain",
        ),
        template="plotly_white",
        margin=dict(l=40, r=20, t=40, b=40),
        hovermode="x unified",
        width=1000,
        height=350,
    )

    fig.show()


# Per-landmark Euclidean distance between prediction and ground truth at every step.
sq_delta = (landmarks_pred - landmarks_gt).pow(2)  # (S, L, 3)
err_per_lmk = sq_delta.sum(dim=-1).sqrt().numpy()  # (S, L)

# Steps run along the x axis; the spread over landmarks is drawn as percentile bands.
# NOTE(review): 9.0 undoes fk's /90 scaling and divides by 10, presumably mm -> cm — confirm.
plot_error_percentiles(
    data=err_per_lmk * 9.0,
    title="Error Progression",
    xlabel="Step",
    ylabel="Error (cm)",
)

Loading dataset: 100%|██████████| 1/1 [00:00<00:00,  2.72it/s]
Upsampling segments: 100%|██████████| 1/1 [00:00<00:00,  2.53it/s]


92828
Predictions per second: 1475.28



In [8]:
import numpy as np
from scipy.signal import spectrogram


def plot_emg_spectrogram(
    emg: np.ndarray,
    sample_rate: int = 2000,
    channels_to_plot: int = 6,
    nperseg: int = 64,
    noverlap: int = 32,
):
    """
    Plots the spectrogram of EMG signals using Plotly, one figure per channel.

    Parameters:
        emg (np.ndarray): EMG array of shape (T, C), where T is time and C is number
            of channels. A 1-D array is treated as a single channel.
        sample_rate (int): Sampling rate of the EMG signal in Hz; sets the time and
            frequency axes. NOTE(review): the 2000 Hz default is an assumption —
            confirm it matches the recordings.
        channels_to_plot (int): Number of leading channels to visualize.
        nperseg (int): Length of each STFT segment, in samples.
        noverlap (int): Number of samples shared by consecutive segments.

    Raises:
        ValueError: if ``emg`` has no samples or is not a 1-D/2-D array.
    """
    emg = np.asarray(emg)
    if emg.ndim == 1:
        emg = emg[:, np.newaxis]
    if emg.ndim != 2 or emg.shape[0] == 0:
        raise ValueError(f"expected a non-empty (T, C) array, got shape {emg.shape}")

    n_samples, n_channels = emg.shape
    # For short inputs scipy shrinks nperseg to the signal length (with a warning) and
    # then rejects noverlap >= nperseg. Clamp both up front so short recordings still plot;
    # for signals at least nperseg long this leaves the parameters unchanged.
    seg_len = min(nperseg, n_samples)
    seg_overlap = min(noverlap, seg_len - 1)

    for ch in range(min(n_channels, channels_to_plot)):
        f, t, Sxx = spectrogram(
            emg[:, ch], fs=sample_rate, nperseg=seg_len, noverlap=seg_overlap
        )
        Sxx_dB = 10 * np.log10(Sxx + 1e-10)  # power -> dB; epsilon avoids log10(0)

        fig = go.Figure(
            data=go.Heatmap(
                z=Sxx_dB,
                x=t,
                y=f,
                colorscale="Viridis",
                colorbar=dict(title="Power (dB)"),
            )
        )

        fig.update_layout(
            title=f"Spectrogram - Channel {ch}",
            xaxis_title="Time (s)",
            yaxis_title="Frequency (Hz)",
            width=900,
            height=300,
            margin=dict(l=50, r=30, t=40, b=40),
        )

        fig.show()


# Spectrograms of the validation-window EMG selected earlier (first channels, default settings).
plot_emg_spectrogram(emg=emg)