In [None]:
from typing import List, NamedTuple
import numpy as np
import yaml
import io
import zipfile

# NOTE: `frame` here refers to hand pose angles

# Number of consecutive EMG samples stored with each pose frame
# (the first axis of `HandEmgTuple.emg`).
W = 64


class HandEmgTuple(NamedTuple):
    """A single sample: one hand-pose frame paired with one window of EMG readings."""
    # Hand pose angles; shape (20,), float32 expected.
    frame: np.ndarray
    # EMG window; shape (W, C) with C the number of EMG channels, float32 expected.
    emg: np.ndarray


class HandEmgRecordingSegment(NamedTuple):
    """An ordered run of samples plus the single pose frame that closes the run."""
    # Samples in the order they are written to the archive.
    couples: List[HandEmgTuple]
    # Closing pose frame; shape (20,), float32. Written after all couples.
    sigma: np.ndarray


# A recording is the ordered list of segments taken from one source file
# (this is how `archive_dataset` builds it).
HandEmgRecording = List[HandEmgRecordingSegment]


class RecordingWriter:
    """
    Writes the segments of one recording into the archive owned by a `DatasetWriter`.

    Every segment becomes one archive member named
    `recordings/<recording index>/segments/<segment index>`; segment indices
    start at 0 and follow the order of the `add` calls. The member holds raw
    bytes in this layout:
    [ [<20 x float32: frame>, <W x C float32: emg>], [...], ... <20 x float32: sigma frame> ]
    """

    def __init__(self, context: "DatasetWriter", index: int):
        # Owning writer: provides the open `archive` and the shared channel count `C`.
        self.context = context
        # Recording index used in the member path.
        self.index = index
        # Number of segments written so far, i.e. the index of the next segment.
        self.count = 0

    def add(self, segment: "HandEmgRecordingSegment"):
        """
        Add a single recording segment to the ZIP archive.

        The whole segment is validated and serialised in memory first; the
        archive and the shared channel count are only touched once everything
        checked out, so a rejected segment leaves no trace.

        Args:
            segment: A HandEmgRecordingSegment. Each sample in `segment.couples` is
                     stored as its frame (20 float32 values) followed by its emg
                     (W x C float32 values, C order); `segment.sigma` (20 float32
                     values) is stored last. The number of EMG channels (C) is taken
                     from the first sample and must match every other sample and every
                     segment previously written through the same DatasetWriter.

        Raises:
            RuntimeError: the DatasetWriter's archive is not open.
            ValueError: the segment has no samples, or its channel count differs
                        from the one already recorded for the dataset.
            AssertionError: a frame, emg window or sigma has the wrong dtype or shape.
                            (These checks are `assert`s and are skipped under `python -O`.)
        """
        archive = self.context.archive
        if archive is None:
            raise RuntimeError("Archive is not open. Use 'with' statement to open it.")

        if not segment.couples:
            # Without a first sample there is no way to determine C.
            raise ValueError("Cannot write a segment without samples.")

        # Determine the number of EMG channels (C) from the first sample.
        first_emg = segment.couples[0].emg
        assert first_emg.ndim == 2, f"EMG must be 2-dimensional, got shape {first_emg.shape}"
        C = first_emg.shape[1]
        if self.context.C is not None and self.context.C != C:
            raise ValueError("Inconsistent number of EMG channels across recordings.")

        # The closing frame is part of the fixed-size layout, so it gets the same
        # checks as regular frames; a float64 sigma would silently double its size.
        assert (
            segment.sigma.dtype == np.float32
        ), f"Sigma dtype must be float32, got {segment.sigma.dtype}"
        assert segment.sigma.shape == (
            20,
        ), f"Sigma shape must be (20,), got {segment.sigma.shape}"

        bio = io.BytesIO()

        # Write each sample: frame (20 float32 values) then emg (W * C float32 values).
        for tup in segment.couples:
            # Verify data types and dimensions.
            assert (
                tup.frame.dtype == np.float32
            ), f"Frame dtype must be float32, got {tup.frame.dtype}"
            assert (
                tup.emg.dtype == np.float32
            ), f"EMG dtype must be float32, got {tup.emg.dtype}"
            assert tup.frame.shape == (
                20,
            ), f"Frame shape must be (20,), got {tup.frame.shape}"
            assert tup.emg.shape == (
                W,
                C,
            ), f"EMG shape must be ({W}, {C}), got {tup.emg.shape}"

            bio.write(tup.frame.tobytes())
            # tobytes() already emits C-order bytes, no flatten() copy needed.
            bio.write(tup.emg.tobytes())

        # Write the final frame (sigma) as well.
        bio.write(segment.sigma.tobytes())

        if self.context.C is None:
            # First valid segment of the dataset: record C in the metadata.
            self.context.C = C
            archive.writestr("metadata.yml", yaml.dump({"C": C}))

        # Save the segment
        archive.writestr(
            f"recordings/{self.index}/segments/{self.count}", bio.getvalue()
        )
        self.count += 1


class DatasetWriter:
    """
    A context manager for writing segments to a ZIP archive in a proprietary binary format.

    Recording indices and segment indices both start at 0. `metadata.yml` is
    written when the first segment is added. Archive looks like this:

    dataset.zip/
      metadata.yml
      recordings/
        0/
          segments/
            0
            1
        1/
          segments/
            0
            1
    """

    def __init__(self, filename: str):
        self.filename = filename
        # Open zipfile.ZipFile while inside the `with` block, None otherwise.
        self.archive = None
        # Index of the most recently created recording; -1 means none yet.
        self.recording_index = -1
        self.C: int | None = None  # To store the number of EMG channels

    def __enter__(self):
        """Create (or truncate) the archive file and return this writer."""
        self.archive = zipfile.ZipFile(
            self.filename,
            mode="w",
            compression=zipfile.ZIP_DEFLATED,
            compresslevel=9,
        )
        # Mode "w" starts from an empty archive, so any bookkeeping left over
        # from a previous `with` block on this object must start over too;
        # otherwise metadata.yml would never be written again.
        self.recording_index = -1
        self.C = None
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the archive; exceptions from the `with` body are not suppressed."""
        # Drop the reference before closing so that writers created from this
        # context see `archive is None` afterwards, even if close() raises.
        archive, self.archive = self.archive, None
        if archive is not None:
            archive.close()

    def add_recording(self):
        """Start the next recording and return the RecordingWriter for its segments."""
        self.recording_index += 1
        return RecordingWriter(self, self.recording_index)

In [None]:
from typing import Dict, Tuple
from tqdm import tqdm
import h5py


def archive_dataset(
    grouped_segments: Dict[str, List[Tuple[int, int]]],
    filepath: str,
    desc: str,
    filter_rec_len: int = 32 * 2,  # for 32 fps, thats 2 sec
    filter_segment_len: int = 8,  # for 32 fps, thats 0.25 sec
):
    """
    Cut segments out of HDF5 recordings and write them into one dataset archive.

    Every `(start, end)` pair is split into `n = (end - start) // W` windows.
    Window `k` pairs the pose frame `joint_angles[start + k * W]` with the EMG
    rows `emg[start + k * W : start + (k + 1) * W]`; the pose frame at
    `start + n * W` closes the segment. Each input file becomes at most one
    recording in the archive.

    Args:
        grouped_segments: maps an HDF5 file path to the `(start, end)` index
            pairs to extract from it.
        filepath: path of the archive to create.
        desc: label appended to the progress bar title.
        filter_rec_len: a recording is written only if its kept segments hold
            at least this many windows in total.
        filter_segment_len: a segment is kept only if it holds at least this
            many windows. Segments without any complete window are always
            dropped, because they cannot be serialised.
    """
    total_segments = sum(len(v) for v in grouped_segments.values())
    with DatasetWriter(filepath) as dataset_writer:
        with tqdm(
            total=total_segments,
            ncols=100,
            desc=f"Archiving {desc}",
        ) as pbar:
            for path, segments in grouped_segments.items():
                rec: HandEmgRecording = []
                with h5py.File(path, "r") as f:
                    timeseries = f["emg2pose"]["timeseries"]  # type: ignore
                    joint_angles = timeseries["joint_angles"]  # type: ignore
                    emg = timeseries["emg"]  # type: ignore

                    for bounds in segments:
                        start, end = bounds[0], bounds[1]
                        n_windows = (end - start) // W

                        # Decide before touching the file: segments that are
                        # filtered out anyway are not read at all.
                        if n_windows > 0 and n_windows >= filter_segment_len:
                            rec.append(_read_segment(joint_angles, emg, start, n_windows))
                        pbar.update(1)

                if sum(len(segment.couples) for segment in rec) >= filter_rec_len:
                    recording_writer = dataset_writer.add_recording()
                    for segment in rec:
                        recording_writer.add(segment)


def _read_segment(joint_angles, emg, start, n_windows):
    """Read `n_windows` consecutive (frame, EMG window) pairs starting at `start`, plus the closing frame."""
    couples = []
    for k in range(n_windows):
        lo = start + k * W
        couples.append(
            HandEmgTuple(
                # The archive format is float32 only; asarray is a no-op when
                # the stored data already has that dtype.
                frame=np.asarray(joint_angles[lo], dtype=np.float32),
                emg=np.asarray(emg[lo : lo + W], dtype=np.float32),
            )
        )
    sigma = np.asarray(joint_angles[start + n_windows * W], dtype=np.float32)
    return HandEmgRecordingSegment(couples=couples, sigma=sigma)

In [None]:
def extract_segments(
    file: str,
    min_segment_length: int = 4096,
    tail_trim: int = 512,
):
    """
    Find the stretches of an emg2pose recording that contain no IK failure.

    A frame counts as an IK failure when all of its joint angles are
    (numerically) zero. Maximal runs of consecutive non-failure frames are
    collected, shrunk by `tail_trim` frames on both ends, and returned if
    they are still long enough.

    Args:
        file: path of the HDF5 recording.
        min_segment_length: minimum value of `end - start` (after trimming)
            for a segment to be returned.
        tail_trim: number of frames dropped from each end of every run.

    Returns:
        A list of `(start, end)` tuples of plain ints. Both are frame indices
        into the recording and `end` is the last frame of the segment
        (inclusive). The list is empty when nothing qualifies.
    """
    with h5py.File(file, "r") as f:
        # Materialise the angles once while the file is open; shape (T, 20).
        joint_angles = np.asarray(f["emg2pose"]["timeseries"]["joint_angles"])  # type: ignore

    # True where the frame is usable, i.e. not an all-zero IK failure frame.
    # Comparing against the scalar 0.0 avoids building a second (T, 20) array.
    is_valid = ~np.all(np.isclose(joint_angles, 0.0), axis=-1)
    valid_idx = np.flatnonzero(is_valid)

    if valid_idx.size == 0:
        # the whole file is ik failure
        return []

    # Positions in `valid_idx` after which the next valid frame is not adjacent.
    breaks = np.flatnonzero(np.diff(valid_idx) != 1)
    run_first = np.insert(breaks + 1, 0, 0)
    run_last = np.append(breaks, valid_idx.size - 1)

    # Trim both ends of every run: because of interpolation, frames near an
    # IK failure are not valid. This can produce runs of negative length,
    # which the length filter below discards.
    trimmed = [
        (int(valid_idx[a]) + tail_trim, int(valid_idx[b]) - tail_trim)
        for a, b in zip(run_first, run_last)
    ]

    # finally, filter segments by length
    return [(lo, hi) for lo, hi in trimmed if hi - lo >= min_segment_length]

In [None]:
import os

import pandas as pd

# Root directory of the emg2pose download. Set the EMG2POSE_DATA_DIR environment
# variable to point at your own copy; the fallback is the original author's path.
base_path = os.environ.get("EMG2POSE_DATA_DIR", "C:/Users/shich/emg2pose_data")

metadata = pd.read_csv(f"{base_path}/metadata.csv")
metadata.head()

In [None]:
# Sanity check: the last four characters of a user id must be unique across users.
# NOTE: `first_two_symbols` keeps its historical name, but it holds the
# 4-character suffixes (`user[-4:]`), not two-character prefixes.
users = metadata["user"].unique()
first_two_symbols = set(map(lambda user_id: user_id[-4:], users))
# Fewer distinct suffixes than users means at least two users share one.
has_collisions = len(users) != len(first_two_symbols)
has_collisions

In [None]:
# Every distinct session id found in the metadata table.
sessions = metadata["session"].unique()
print(f"Found {len(sessions)} sessions")

In [None]:
import os

start = 0
count = 32
# Select the window into a new name: overwriting `sessions` here would make
# the cell slice its own previous result when it is run a second time.
selected_sessions = sessions[start : start + count]

# zipfile does not create missing parent directories.
output_dir = "../datasets"
os.makedirs(output_dir, exist_ok=True)

# for each session, load recordings, split to segments and archive them as a separate dataset
for i, session in enumerate(selected_sessions):
    # load left hand recordings for the session
    recordings = metadata[
        (metadata["session"] == session) & (metadata["side"] == "left")
    ]["filename"].unique()

    # load segments from each recording
    grouped_segments = {}
    for recording in recordings:
        fname = f"{base_path}/{recording}.hdf5"
        segments = extract_segments(fname)
        if segments:
            grouped_segments[fname] = segments

    if not grouped_segments:
        raise ValueError(f"Cannot proper segmentate recordings from {session}")

    # archive segments to a dataset; file numbers are 1-based positions in `sessions`
    archive_dataset(
        grouped_segments,
        f"{output_dir}/s{start + i + 1}.z",
        desc=f"{i + 1}/{len(selected_sessions)}",
    )