In [12]:
from typing import List, NamedTuple
import numpy as np
import yaml
import io
import zipfile

# NOTE: `frame` here refers to hand pose angles

# Number of EMG samples in one window: every `HandEmgTuple.emg` is a (W, C) array.
W = 64


class HandEmgTuple(NamedTuple):
    """One sample: a hand pose frame together with one window of EMG samples."""

    frame: np.ndarray  # (20,), float32 expected
    emg: np.ndarray  # (W, C), float32 expected


class HandEmgRecordingSegment(NamedTuple):
    """An ordered list of (frame, emg) samples followed by one closing frame."""

    couples: List[HandEmgTuple]
    sigma: np.ndarray  # (20,), float32 single final frame

    @property
    def emg(self):
        """All EMG windows of the segment joined along the first axis."""
        windows = [couple.emg for couple in self.couples]
        return np.concatenate(windows, axis=0)

    @property
    def frames(self):
        """Every pose frame of the segment, with `sigma` stacked as the last row."""
        poses = [couple.frame for couple in self.couples]
        poses.append(self.sigma)
        return np.stack(poses, axis=0)


# A recording is an ordered list of segments.
HandEmgRecording = List[HandEmgRecordingSegment]


class RecordingWriter:
    """
    A context for writing one recording, segment by segment.

    Each segment becomes one archive member named
    `recordings/<index>/segments/<count>` (segment numbering starts at 0) holding raw bytes:
    [ [<20 x float32: frame>, <W x C float32: emg>], [...], ... <20 x float32: sigma frame> ]
    """

    def __init__(self, context: "DatasetWriter", index: int):
        self.context = context
        self.index = index
        # Number of segments written so far; also the name of the next segment.
        self.count = 0

    @staticmethod
    def _check_array(array: np.ndarray, shape: tuple, label: str) -> None:
        """Raise ValueError unless `array` is float32 with exactly the given shape."""
        if array.dtype != np.float32:
            raise ValueError(f"{label} dtype must be float32, got {array.dtype}")
        if array.shape != shape:
            raise ValueError(f"{label} shape must be {shape}, got {array.shape}")

    def add(self, segment: "HandEmgRecordingSegment"):
        """
        Add a single recording segment to the ZIP archive.

        The whole segment (every sample and the final `sigma` frame) is validated
        before anything is written, so a rejected segment leaves the archive untouched.

        Args:
            segment: The segment to store. Each sample contributes its frame
                     (20 float32 values) and its emg (W x C float32 values); `sigma`
                     (20 float32 values) is appended last. The number of EMG channels (C)
                     is taken from the first sample and must match every other sample and
                     every segment previously written through the same DatasetWriter.

        Raises:
            RuntimeError: the archive of the owning DatasetWriter is not open.
            ValueError: the segment is empty, the channel count is inconsistent, or any
                        array has the wrong dtype or shape. Explicit exceptions are used
                        instead of `assert` so the checks survive `python -O`.
        """
        if self.context.archive is None:
            raise RuntimeError("Archive is not open. Use 'with' statement to open it.")

        if not segment.couples:
            raise ValueError("Cannot write an empty segment: it contains no samples.")

        # Determine the number of EMG channels (C) from the first sample.
        first_emg = segment.couples[0].emg
        if first_emg.ndim != 2:
            raise ValueError(f"EMG must be 2-D (W, C), got shape {first_emg.shape}")
        C = first_emg.shape[1]
        if self.context.C is not None and self.context.C != C:
            raise ValueError("Inconsistent number of EMG channels across recordings.")

        # Serialize into memory first: frame (20 float32) then emg (W * C float32) per sample.
        bio = io.BytesIO()
        for tup in segment.couples:
            self._check_array(tup.frame, (20,), "Frame")
            self._check_array(tup.emg, (W, C), "EMG")
            bio.write(tup.frame.tobytes())
            # tobytes() already emits C-order bytes, no flatten() copy is needed.
            bio.write(tup.emg.tobytes())

        # The final frame (sigma) shares the frame layout, so it gets the same checks.
        self._check_array(segment.sigma, (20,), "Sigma frame")
        bio.write(segment.sigma.tobytes())

        if self.context.C is None:
            # First valid segment of the dataset: remember C and store it as metadata.
            self.context.C = C
            self.context.archive.writestr("metadata.yml", yaml.dump({"C": C}))

        # Save the segment
        self.context.archive.writestr(
            f"recordings/{self.index}/segments/{self.count}", bio.getvalue()
        )
        self.count += 1


class DatasetWriter:
    """
    A context manager for writing segments to a ZIP archive in a proprietary binary format.

    Recordings are numbered from 1 (see `add_recording`); the segments inside a
    recording are numbered from 0 by the RecordingWriter. `metadata.yml` is written
    together with the first segment, so an archive without segments has no metadata.

    Archive looks like this:

    dataset.zip/
      metadata.yml
      recordings/
        1/
          segments/
            0
            1
        2/
          segments/
            0
            1
    """

    def __init__(self, filename: str):
        self.filename = filename
        # Open ZipFile while inside the `with` block, otherwise None.
        self.archive = None
        self.recording_index = 0
        self.C: int | None = None  # To store the number of EMG channels

    def __enter__(self):
        self.archive = zipfile.ZipFile(
            self.filename,
            mode="w",
            compression=zipfile.ZIP_DEFLATED,
            compresslevel=9,
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.archive is not None:
            try:
                self.archive.close()
            finally:
                # Drop the closed handle so later writes hit the explicit
                # "archive is not open" check instead of a closed ZipFile.
                self.archive = None

    def add_recording(self):
        """Start the next recording and return the writer used to add its segments."""
        self.recording_index += 1
        return RecordingWriter(self, self.recording_index)

In [13]:
import torch


def rotation_matrix_from_vectors(
    vec1: torch.Tensor,
    vec2: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Build the 3x3 rotation matrix that turns the direction of `vec1` into the direction of `vec2`.

    Args:
        vec1: (3,) source vector; only its direction is used.
        vec2: (3,) target vector; only its direction is used.
        eps: threshold on the norm of the cross product below which the two
             directions are treated as collinear.

    Returns:
        (3, 3) tensor R, with the dtype and device of `vec1`, such that
        R @ (vec1 / |vec1|) equals vec2 / |vec2|. Collinear inputs give the identity
        when they point the same way and a half-turn when they point in opposite
        directions.
    """
    dtype, device = vec1.dtype, vec1.device
    identity = torch.eye(3, dtype=dtype, device=device)

    a = vec1 / vec1.norm()
    b = vec2 / vec2.norm()
    v = torch.cross(a, b, dim=-1)
    s = v.norm()  # sine of the angle between a and b
    c = torch.dot(a, b)  # cosine of the angle between a and b

    if s < eps:
        if c > 0:
            # Same direction: nothing to rotate.
            return identity
        # Opposite directions: the cross product vanishes here too, but the answer is a
        # rotation by pi. Any axis perpendicular to `a` works; build one from the
        # coordinate axis least aligned with `a` to keep the cross product well conditioned.
        helper = torch.zeros(3, dtype=dtype, device=device)
        helper[torch.argmin(a.abs())] = 1.0
        axis = torch.cross(a, helper, dim=-1)
        axis = axis / axis.norm()
        # Rotation by pi around a unit axis u is 2 * u u^T - I.
        return 2.0 * torch.outer(axis, axis) - identity

    # Rodrigues' formula with the (unnormalized) cross product as the skew-symmetric part.
    K = torch.tensor(
        [[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]],
        dtype=dtype,
        device=device,
    )
    return identity + K + K @ K * ((1 - c) / (s * s))


def normalize_hand(
    hand_3d_points: torch.Tensor | np.ndarray,
    whrist_base: int = 0,
    middle_finger_inner_bone: int = 8,
    point_finger_inner_bone: int = 4,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Bring one hand (L, 3) or a batch of hands (B, L, 3) into a canonical pose.

    Steps applied to every hand:
      1. shift the points so landmark `whrist_base` sits at the origin
      2. rotate so the vector to `middle_finger_inner_bone` points along +Y
      3. divide by the Y coordinate of that landmark (plus `eps`) so it has length ~1
      4. spin around Y so `point_finger_inner_bone` lands in the +Z half-plane (x == 0)

    Non-tensor input is converted to a float32 tensor. The result has the same
    shape as the input.
    """
    points = hand_3d_points
    if not isinstance(points, torch.Tensor):
        points = torch.tensor(points, dtype=torch.float32)

    # A batch is handled one hand at a time through the single-hand path below.
    if points.dim() == 3:
        out = torch.empty_like(points)
        for idx, single_hand in enumerate(points):
            out[idx] = normalize_hand(
                single_hand,
                whrist_base,
                middle_finger_inner_bone,
                point_finger_inner_bone,
                eps,
            )
        return out

    # --- single hand, shape (L, 3) ---

    # 1) wrist base goes to the origin
    centered = points - points[whrist_base]

    # 2) middle finger inner bone is turned onto +Y
    up = torch.tensor([0, 1, 0], dtype=points.dtype, device=points.device)
    align_y = rotation_matrix_from_vectors(
        centered[middle_finger_inner_bone], up, eps=eps
    )
    aligned = centered @ align_y.T

    # 3) middle finger inner bone length becomes 1
    scaled = aligned / (aligned[middle_finger_inner_bone][1] + eps)

    # 4) spin around Y until the point finger inner bone has x == 0, z > 0
    index_base = scaled[point_finger_inner_bone]
    xz = torch.stack([index_base[0], index_base[2]])
    radius = xz.norm()
    if not (radius >= eps):
        # The landmark lies (almost) on the Y axis: no spin angle is defined.
        return scaled

    sin_a, cos_a = xz[0] / radius, xz[1] / radius
    spin = torch.tensor(
        [[cos_a, 0.0, -sin_a], [0.0, 1.0, 0.0], [sin_a, 0.0, cos_a]],
        dtype=points.dtype,
        device=points.device,
    )
    return scaled @ spin.T

In [14]:
import torch

# Desired bone lengths, keyed by (parent landmark index, child landmark index).
# Landmark 0 is the common parent of all five chains; within a chain every
# parent appears before its child. The label after the last bone of a chain
# names the finger that the preceding entries belong to.
BONE_LENGTHS = {
    (0, 1): 7.0,
    (1, 2): 3.5,
    (2, 3): 2.5,  # ^ Thumb chain: 0 -> 1 -> 2 -> 3
    (0, 4): 9.0,
    (4, 5): 4.0,
    (5, 6): 2.5,
    (6, 7): 2.0,  # ^ Index chain: 0 -> 4 -> 5 -> 6 -> 7
    (0, 8): 9.0,
    (8, 9): 5.0,
    (9, 10): 3.0,
    (10, 11): 2.0,  # ^ Middle chain: 0 -> 8 -> 9 -> 10 -> 11
    (0, 12): 8.5,
    (12, 13): 5.0,
    (13, 14): 3.0,
    (14, 15): 2.0,  # ^ Ring chain: 0 -> 12 -> 13 -> 14 -> 15
    (0, 16): 8.0,
    (16, 17): 4.0,
    (17, 18): 2.5,
    (18, 19): 2.0,  # ^ Pinky chain: 0 -> 16 -> 17 -> 18 -> 19
}


def fix_hand_landmarks_anatomy_batched(joints_batch: torch.Tensor) -> torch.Tensor:
    """
    Normalize bone lengths of hand landmarks in a batched manner.

    For every hand the skeleton is rebuilt bone by bone: each bone keeps the
    direction it has in the input but gets the length listed in BONE_LENGTHS.
    The rebuilt hand is then passed through `normalize_hand`, so the lengths in
    BONE_LENGTHS act as proportions rather than absolute sizes.

    Input:
        joints_batch: 3D tensor of shape (B, 20, 3) of hand joint coordinates
    Output:
        3D tensor of shape (B, 20, 3) with normalized bone lengths
    """
    fixed_batch = joints_batch.clone()
    for i, joints in enumerate(joints_batch):
        # Start from the input translated so the root (landmark 0) is at the origin.
        # Every other landmark is overwritten below, so normalizing the hand up front
        # would be wasted work: only the root position survives, and it is the origin.
        fixed = joints - joints[0]

        # BONE_LENGTHS lists parents before their children, so fixed[p] is final
        # by the time a child is placed.
        for (p, c), target_length in BONE_LENGTHS.items():
            vec = joints[c] - joints[p]  # (3,) bone direction taken from the input
            length = torch.norm(vec)
            if length > 0:
                direction = vec / length
            else:
                # Degenerate bone: collapse the child onto its parent.
                direction = torch.zeros_like(vec)
            fixed[c] = fixed[p] + direction * target_length

        fixed_batch[i] = normalize_hand(fixed)

    return fixed_batch

In [15]:
from shared import rot, _V, _P, DEFAULT_MORPHOLOGY
import torch

# Margin used by irot_full / irot_alpha to clamp dot products into
# [-1 + E, 1 - E] before they are passed to asin / arccos.
E = 1e-7


def irot_full(
    v: torch.Tensor,
    p: torch.Tensor,
    v_hat: torch.Tensor,
    eps: torch.Tensor,  # Tolerance for "cosA == 0"
    fallback_beta: torch.Tensor,
):
    """
    Batched inverse rotation recovering both angles of a bone.

    `v_hat` is decomposed in the frame (v, p, v x p): its component along
    v x p gives sin(alpha), and the components along v and p give beta. When
    |cos(alpha)| is below `eps`, beta is not recoverable from the projection
    and `fallback_beta` is used instead.

    v, p, v_hat: (B, 3) tensors
    Returns: alpha (B,), beta (B,), v_hat (B, 3), p_hat (B, 3)
    """
    normal = torch.cross(v, p, dim=-1)

    along_v = (v_hat * v).sum(dim=-1)  # cosB*cosA
    along_p = (v_hat * p).sum(dim=-1)  # sinB*cosA
    sin_alpha = torch.clamp(
        (v_hat * normal).sum(dim=-1), min=-1 + E, max=1 - E
    )  # sinA

    alpha = torch.asin(sin_alpha)

    cos_alpha = torch.sqrt(
        torch.clamp(torch.ones_like(sin_alpha) - sin_alpha.square(), min=0)
    )

    beta_from_projection = torch.atan2(along_p / cos_alpha, along_v / cos_alpha)
    beta = torch.where(cos_alpha.abs() < eps, fallback_beta, beta_from_projection)

    # p turned by beta inside the (v, p) plane
    cos_beta = torch.cos(beta).unsqueeze(-1)
    sin_beta = torch.sin(beta).unsqueeze(-1)
    p_hat = p * cos_beta - v * sin_beta

    return alpha, beta, v_hat, p_hat


def irot_alpha(
    v: torch.Tensor,
    p: torch.Tensor,
    v_hat: torch.Tensor,
):
    """
    Batched inverse rotation recovering alpha only.

    The magnitude of alpha is the angle between `v` and `v_hat`; its sign is the
    sign of the component of `v_hat` along v x p. Beta is always zero and `p` is
    returned unchanged.

    v, p, v_hat: (B, 3) tensors
    Returns: alpha (B,), beta (B,), v_hat (B, 3), p_hat (B, 3)
    """
    normal = torch.cross(v, p, dim=-1)

    cos_alpha = torch.clamp((v_hat * v).sum(dim=-1), min=-1 + E, max=1 - E)
    side = torch.sign((v_hat * normal).sum(dim=-1))
    alpha = side * torch.arccos(cos_alpha)

    return alpha, torch.zeros_like(cos_alpha), v_hat, p


def inverse_hand_angles_by_landmarks(
    landmarks: torch.Tensor,
):
    """
    Recover per-finger joint angles from 3D hand landmarks.

    The hand morphology is not a parameter: it is always built from
    DEFAULT_MORPHOLOGY (imported from `shared`). The code iterates over its rows
    as fingers and reads columns 3:6 and 6:10 of each row, so it is presumably
    a (5, 10) table - TODO confirm against `shared`.

    Args:
        landmarks: (B, 20, 3) - known hand landmarks in 3D space

    Returns:
        angles: (B, 20) - recovered joint angles, four per finger, laid out as
                [alpha, beta] of the first bone followed by alpha of the next two bones
    """

    B = landmarks.shape[0]

    # One row of four angles per finger; flattened to (B, 20) on return.
    angles = torch.zeros(B, 5, 4, dtype=landmarks.dtype, device=landmarks.device)

    morphology = torch.tensor(
        DEFAULT_MORPHOLOGY,
        dtype=angles.dtype,
        device=angles.device,
    )

    # NOTE: assuming 0, 4, 8, 12, 16 landmarks are existing in the morphology
    # 20 landmarks -> 21 landmarks, restoring thumb base from the morphology
    #   so that now assuming 0, 1, 5, 9, 13, 17 are existing in the morphology
    landmarks = torch.cat(
        [
            landmarks[:, :1, :],
            morphology[0][3:6].expand(B, 1, 3),  # Repeat thumb_base across batch
            landmarks[:, 1:, :],
        ],
        dim=1,
    )

    # Tolerance handed to irot_full: when |cos(alpha)| is below it, the finger's
    # fallback beta is used instead of the value recovered from the landmarks.
    eps = torch.tensor(0.4, dtype=angles.dtype, device=angles.device)

    # Reference vectors from `shared`, broadcast over the batch.
    V = torch.tensor(_V, dtype=angles.dtype, device=angles.device).expand(B, 3)
    P = torch.tensor(_P, dtype=angles.dtype, device=angles.device).expand(B, 3)

    for i, morph in enumerate(morphology):
        # After the thumb base insertion every finger owns four consecutive
        # landmarks starting at index 4 * i + 1.
        base_idx = i * 4 + 1
        alpha, beta, gamma, fallback_beta = morph[6:10].clone()

        # Frame of the finger base built from the morphology angles.
        v, p = rot(V, P, alpha, beta)

        # Rotate p around v according to parameter gamma
        sinG, cosG = torch.sin(gamma), torch.cos(gamma)
        p = p.mul(cosG).add(torch.cross(p, v, dim=-1).mul(sinG))

        # First bone: unit direction from the finger base to the next landmark;
        # both alpha and beta are recovered.
        target = torch.nn.functional.normalize(
            landmarks[:, base_idx + 1].sub(landmarks[:, base_idx]), dim=-1
        )
        alpha, beta, v, p = irot_full(v, p, target, eps, fallback_beta)
        angles[:, i, 0] = alpha
        angles[:, i, 1] = beta

        # Remaining two bones: only alpha is recovered, relative to the frame
        # returned for the previous bone.
        for j in range(1, 3):
            target = torch.nn.functional.normalize(
                landmarks[:, base_idx + j + 1].sub(landmarks[:, base_idx + j]), dim=-1
            )

            alpha, beta, v, p = irot_alpha(v, p, target)

            angles[:, i, j + 1] = alpha

    return angles.flatten(start_dim=1)

In [16]:
from emg2pose.kinematics import forward_kinematics, load_default_hand_model
import numpy as np
import torch

# Hand model from emg2pose, passed to forward_kinematics below.
hand_model = load_default_hand_model()

# Indices applied to the landmark axis of the forward-kinematics output: they pick
# 20 landmarks and put them into the order the rest of this notebook works with.
POINTS_SELECT = [5, 6, 7, 0, 8, 9, 10, 1, 11, 12, 13, 2, 14, 15, 16, 3, 17, 18, 19, 4]


def emg2pose_forward_hand_kinematics(x: torch.Tensor):
    """
    Run the emg2pose forward kinematics on a batch of frames.

    x: (B, C)
    Returns the landmarks selected by POINTS_SELECT: (B, L, 3)
    """
    # (B, C) -> (B, 1, C) -> (B, C, 1)
    model_input = x.unsqueeze(1).permute(0, 2, 1)
    all_landmarks = forward_kinematics(model_input, hand_model).squeeze(1)
    return all_landmarks[:, POINTS_SELECT, :]

In [None]:
from typing import Dict, Tuple
from tqdm import tqdm
import h5py


def archive_dataset(
    grouped_segments: Dict[str, List[Tuple[int, int]]],
    filepath: str,
    desc: str,
    filter_rec_len: int = 32 * 2,  # for 32 fps, thats 2 sec
    filter_segment_len: int = 8,  # for 32 fps, thats 0.25 sec
):
    """
    Convert emg2pose recordings into one dataset archive.

    Every HDF5 file in `grouped_segments` becomes at most one recording. Each of
    its (start, end) ranges is cut into whole windows of W EMG samples and the
    pose frames are re-encoded (see `_reencode_segment`).

    Args:
        grouped_segments: maps an HDF5 file path to its (start, end) sample ranges.
        filepath: where the archive is written.
        desc: label shown in the progress bar.
        filter_rec_len: a recording is written only if all of its kept segments
                        together hold at least this many windows.
        filter_segment_len: a segment is kept only if it holds at least this many windows.
    """
    total_segments = sum(len(bounds) for bounds in grouped_segments.values())

    with DatasetWriter(filepath) as dataset_writer, tqdm(
        total=total_segments,
        ncols=100,
        desc=f"Archiving {desc}",
    ) as pbar:
        for path, segments in grouped_segments.items():
            rec: HandEmgRecording = []
            with h5py.File(path, "r") as f:
                timeseries = f["emg2pose"]["timeseries"]  # type: ignore
                joint_angles = timeseries["joint_angles"]  # type: ignore
                emg = timeseries["emg"]  # type: ignore

                for start, end in segments:
                    # The number of windows is known up front, so segments that are
                    # too short are dropped before the costly re-encoding runs.
                    if (end - start) // W >= filter_segment_len:
                        rec.append(_reencode_segment(joint_angles, emg, start, end))
                    pbar.update(1)

            if sum(len(segment.couples) for segment in rec) >= filter_rec_len:
                recording_writer = dataset_writer.add_recording()
                for segment in rec:
                    recording_writer.add(segment)


def _reencode_segment(joint_angles, emg, start: int, end: int) -> "HandEmgRecordingSegment":
    """Cut [start, end] into W-sample windows and re-encode their pose frames to our format."""
    n_windows = (end - start) // W

    # One source frame at the start of every window, plus the frame right after
    # the last whole window, which becomes the segment's sigma.
    source_frames = np.stack(
        [joint_angles[start + i * W] for i in range(n_windows + 1)]
    )

    # Source joint angles -> 3D landmarks -> normalized hand -> our joint angles.
    # fix_hand_landmarks_anatomy_batched(landmarks) can replace normalize_hand here
    # to additionally enforce the reference bone lengths.
    landmarks = emg2pose_forward_hand_kinematics(torch.tensor(source_frames))
    new_frames = inverse_hand_angles_by_landmarks(normalize_hand(landmarks)).numpy()

    couples = [
        HandEmgTuple(
            frame=new_frames[i],
            emg=emg[start + i * W : start + (i + 1) * W],
        )
        for i in range(n_windows)
    ]
    return HandEmgRecordingSegment(couples=couples, sigma=new_frames[-1])

In [18]:
def extract_segments(
    file: str,
    min_segment_length: int = 4096,
    tail_trim: int = 512,
):
    """
    Find the stretches of an emg2pose recording that are free of IK failures.

    A sample counts as an IK failure when all of its joint angles are (close to)
    zero. Consecutive valid samples form a run; `tail_trim` samples are cut from
    both ends of every run and the runs still spanning at least
    `min_segment_length` samples are returned as (start, end) index pairs.
    """
    with h5py.File(file, "r") as f:
        joint_angles = f["emg2pose"]["timeseries"]["joint_angles"]  # (T, 20)  # type: ignore

        # indices of the samples that are NOT an IK failure
        all_zero = np.isclose(joint_angles, np.zeros_like(joint_angles)).all(axis=-1)
        valid_idx = np.flatnonzero(~all_zero)

    if valid_idx.size == 0:
        # the whole file is ik failure
        return []

    # positions inside valid_idx after which the run of consecutive indices breaks
    breaks = np.flatnonzero(np.diff(valid_idx) != 1)
    first_pos = np.concatenate(([0], breaks + 1))
    last_pos = np.concatenate((breaks, [valid_idx.size - 1]))

    # Frames next to an IK failure are interpolated and therefore not valid, so both
    # ends of every run are trimmed. This can leave runs of negative length.
    trimmed = [
        (valid_idx[i] + tail_trim, valid_idx[j] - tail_trim)
        for i, j in zip(first_pos, last_pos)
    ]

    # finally, keep only the runs that are long enough
    return [(lo, hi) for lo, hi in trimmed if (hi - lo) >= min_segment_length]

In [19]:
import pandas as pd

import os

# Root folder of the emg2pose dataset. Set the EMG2POSE_DATA_DIR environment variable
# to run the notebook against another location; the original path stays the default.
base_path = os.environ.get("EMG2POSE_DATA_DIR", "C:/Users/shich/emg2pose_data")

metadata = pd.read_csv(f"{base_path}/metadata.csv")
metadata.head()

Unnamed: 0,session,user,stage,start,end,side,filename,moving_hand,held_out_user,held_out_stage,split,generalization
0,2022-04-07-1649318400-8125c-cv-emg-pose-train@2,29ddab35d7,ThumbsUpDownThumbRotationsCWCCWP,1649400000.0,1649400000.0,left,2022-04-07-1649318400-8125c-cv-emg-pose-train@...,both,True,False,val,user
1,2022-04-07-1649318400-8125c-cv-emg-pose-train@2,29ddab35d7,ThumbsUpDownThumbRotationsCWCCWP,1649400000.0,1649400000.0,right,2022-04-07-1649318400-8125c-cv-emg-pose-train@...,both,True,False,val,user
2,2022-04-07-1649318400-8125c-cv-emg-pose-train@2,29ddab35d7,HandClawGraspFlicks,1649401000.0,1649401000.0,left,2022-04-07-1649318400-8125c-cv-emg-pose-train@...,both,True,False,val,user
3,2022-04-07-1649318400-8125c-cv-emg-pose-train@2,29ddab35d7,HandClawGraspFlicks,1649401000.0,1649401000.0,right,2022-04-07-1649318400-8125c-cv-emg-pose-train@...,both,True,False,val,user
4,2022-04-07-1649318400-8125c-cv-emg-pose-train@2,29ddab35d7,ShakaVulcanPeace,1649401000.0,1649401000.0,left,2022-04-07-1649318400-8125c-cv-emg-pose-train@...,both,True,True,val,user_stage


In [20]:
# check that no two users collide when identified only by a slice of their id
users = metadata["user"].unique()
# NOTE: despite its name, this set holds the LAST FOUR characters of every user id.
first_two_symbols = {user[-4:] for user in users}
# fewer distinct suffixes than users means at least two users share a suffix
has_collisions = len(first_two_symbols) != len(users)
has_collisions

False

In [21]:
# all distinct session ids present in the metadata table
sessions = metadata["session"].unique()
print(f"Found {len(sessions)} sessions")

Found 751 sessions


In [22]:
start = 0
count = 8
# Slice into a new name instead of overwriting `sessions`: re-running this cell (or
# running it with start != 0) must always select from the full list of sessions.
selected_sessions = sessions[start : start + count]

# for each session, load recordings, split to segments and archive them as a separate dataset
for i, session in enumerate(selected_sessions):
    # load left hand recordings for the session
    recordings = metadata[
        (metadata["session"] == session) & (metadata["side"] == "left")
    ]["filename"].unique()

    # load segments from each recording, skipping recordings without usable segments
    grouped_segments = {}
    for recording in recordings:
        fname = f"{base_path}/{recording}.hdf5"
        segments = extract_segments(fname)

        if len(segments) == 0:
            continue

        grouped_segments[fname] = segments

    if len(grouped_segments) == 0:
        raise ValueError(f"Cannot proper segmentate recordings from {session}")

    # archive segments to a dataset; files are numbered by the session's global position
    archive_dataset(
        grouped_segments,
        f"../datasets/s{start + i + 1}.z",
        desc=f"{i + 1}/{len(selected_sessions)}",
    )

Archiving 1/8: 100%|████████████████████████████████████████████████| 75/75 [00:42<00:00,  1.75it/s]
Archiving 2/8: 100%|████████████████████████████████████████████████| 32/32 [00:14<00:00,  2.17it/s]
Archiving 3/8: 100%|████████████████████████████████████████████████| 63/63 [00:29<00:00,  2.10it/s]
Archiving 4/8: 100%|████████████████████████████████████████████████| 74/74 [00:43<00:00,  1.70it/s]
Archiving 5/8: 100%|████████████████████████████████████████████████| 82/82 [00:29<00:00,  2.74it/s]
Archiving 6/8: 100%|████████████████████████████████████████████████| 62/62 [00:50<00:00,  1.23it/s]
Archiving 7/8: 100%|████████████████████████████████████████████████| 83/83 [00:28<00:00,  2.87it/s]
Archiving 8/8: 100%|████████████████████████████████████████████████| 63/63 [00:43<00:00,  1.45it/s]
