In [None]:
import os
import sys

# Go one level up from 'notebooks/' to the project root. This assumes the
# kernel's working directory is the directory containing this notebook.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))

# PYTHONPATH only affects processes started from this kernel. Prepend the
# project root instead of overwriting the variable so any entries that were
# already set survive, and skip it if present so re-running this cell does
# not accumulate duplicates.
existing_paths = [p for p in os.environ.get("PYTHONPATH", "").split(os.pathsep) if p]
if project_root not in existing_paths:
    existing_paths.insert(0, project_root)
os.environ["PYTHONPATH"] = os.pathsep.join(existing_paths)

# sys.path is what the *current* kernel consults when resolving imports.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Optional: verify
print("PYTHONPATH =", os.environ["PYTHONPATH"])

In [None]:
from emg_hand_tracking.model import V42

# Model class under evaluation; its name() selects which checkpoint file to load.
c = V42

# Restore the trained weights onto the CPU and put the model in eval mode.
# Chaining .eval() here is equivalent to loading and then reassigning m = m.eval().
ckpt_path = f"../checkpoints/s0_{c.name()}.ckpt"
m = c.load_from_checkpoint(ckpt_path, map_location="cpu").eval()

In [None]:
# Run the loaded model `m` over one recorded segment, report throughput,
# animate ground truth vs. estimate, and plot per-step landmark error.
import time

import emg2pose.visualization as visualization
import matplotlib.pyplot as plt
import plotly.io as pio
import torch
from emg2pose.kinematics import forward_kinematics

from emg_hand_tracking.dataset import load_recordings

# Which recording/segment to evaluate, and the stride used to thin frames
# for the animations and the error curve.
DATASET_PATH = "../datasets/s1.z"
RECORDING_INDEX = 10
SEGMENT_INDEX = -1  # last segment of the chosen recording
DOWNSAMPLE_FACTOR = 2

d = load_recordings(DATASET_PATH, m.emg_samples_per_frame)

seg = d[RECORDING_INDEX][SEGMENT_INDEX]
print(len(seg.couples))

emg = seg.emg
poses = seg.frames

# The first `frames_per_window` ground-truth poses seed the model; the frames
# after them are what the predictions are compared against below.
# perf_counter is monotonic and high-resolution, unlike time.time(), whose
# coarse resolution can make a fast run measure 0 s (ZeroDivisionError below).
start_time = time.perf_counter()
with torch.no_grad():
    # Call the module rather than .forward so any nn.Module hooks still run.
    y_hat = m(
        emg=torch.tensor(emg, dtype=torch.float32),
        initial_poses=torch.tensor(
            poses[: m.frames_per_window, :],
            dtype=torch.float32,
        ),
    ).cpu()  # (T + 1 - I, 20) per the original note — TODO confirm vs. len(poses) - I
end_time = time.perf_counter()

predictions_per_second = y_hat.shape[0] / (end_time - start_time)
print(f"Predictions per second: {predictions_per_second:.2f}")
print()

# Downsample predictions and ground truth with the same stride so they stay aligned.
y_hat = y_hat[::DOWNSAMPLE_FACTOR, :]
y = torch.tensor(
    poses[m.frames_per_window :, :][::DOWNSAMPLE_FACTOR, :], dtype=torch.float32
)

# Show the ground truth first, then the model's estimate.
for title, joint_angles in (("Ground Truth", y), ("Estimated", y_hat)):
    print(title)
    pio.show(
        visualization.get_plotly_animation_for_joint_angles(
            joint_angles.numpy(), title=title
        )
    )

print("Progressive Error")
# The error below is computed frame-by-frame; fail with a readable message
# instead of an opaque broadcasting error if the two sequences disagree.
assert y_hat.shape == y.shape, (
    f"prediction shape {tuple(y_hat.shape)} != ground truth shape {tuple(y.shape)}"
)
# forward_kinematics receives a (1, 20, S) tensor; [0] drops the leading batch axis.
landmarks_pred = forward_kinematics(y_hat.unsqueeze(0).permute(0, 2, 1))[0]  # (S, L, 3)
landmarks_gt = forward_kinematics(y.unsqueeze(0).permute(0, 2, 1))[0]  # (S, L, 3)

# Euclidean distance between predicted and true position of every landmark.
err_per_lmk = torch.linalg.vector_norm(landmarks_pred - landmarks_gt, dim=-1)  # (S, L)
err_per_prediction = err_per_lmk.mean(dim=-1).numpy()  # (S,)
err_per_prediction_std = err_per_lmk.std(dim=-1).numpy()  # (S,) spread across landmarks

# x-axis counts downsampled prediction steps (every DOWNSAMPLE_FACTOR-th frame).
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(err_per_prediction, label="Mean Error", color="b")
ax.fill_between(
    range(len(err_per_prediction)),
    err_per_prediction - err_per_prediction_std,
    err_per_prediction + err_per_prediction_std,
    color="b",
    alpha=0.2,
    label="Standard Deviation",
)
ax.set(
    title="Error per Prediction with Standard Deviation",
    xlabel="Prediction Step",
    ylabel="Error",
)
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()