In [40]:
import io
import os
import zipfile
from typing import List, NamedTuple

import numpy as np
import yaml

# NOTE: `frame` here refers to hand pose angles

# W: window length in samples. Each HandEmgTuple pairs one frame with a (W, C)
# block of EMG samples, and archive_dataset advances W samples per window.
# Only C is stored in metadata.yml, so readers of the archive must use the same W.
W = 64

# One sample: a hand-pose frame paired with the EMG window that accompanies it.
#   frame -> (20,) array, float32 expected
#   emg   -> (W, C) array, float32 expected
HandEmgTuple = NamedTuple(
    "HandEmgTuple",
    [
        ("frame", np.ndarray),
        ("emg", np.ndarray),
    ],
)


# A contiguous stretch of a recording: its (frame, emg) samples in order,
# followed by one trailing pose frame.
#   couples -> list of HandEmgTuple
#   sigma   -> (20,) array, float32 single final frame
HandEmgRecordingSegment = NamedTuple(
    "HandEmgRecordingSegment",
    [
        ("couples", List[HandEmgTuple]),
        ("sigma", np.ndarray),
    ],
)

class RecordingWriter:
    """
    A context for writing one recording of a dataset, segment by segment.

    Instances are handed out by ``DatasetWriter.add_recording()``. Each call to
    ``add`` stores one segment at ``recordings/<index>/segments/<count>`` inside
    the dataset archive, with ``count`` starting at 0.

    Each segment is written as raw float32 bytes in this order:
    [ [<20 x float32: frame>, <W x C float32: emg, row-major>], [...], ... <20 x float32: sigma frame> ]

    The EMG channel count C is shared by the whole dataset and is written to
    ``metadata.yml`` together with the first segment; W is the module-level
    window length and is not stored in the archive.
    """

    def __init__(self, context: "DatasetWriter", index: int):
        self.context = context  # owning DatasetWriter: provides .archive and .C
        self.index = index  # recording number used in the archive path
        self.count = 0  # segments written so far; also the next segment's name

    def add(self, segment: "HandEmgRecordingSegment"):
        """
        Add a single recording segment to the ZIP archive.

        Every sample is validated before anything is written, so a rejected
        segment leaves the archive and the dataset's channel count untouched.

        Args:
            segment: A HandEmgRecordingSegment. Its ``couples`` are HandEmgTuple
                samples (frame: 20 float32 values, emg: W x C float32 values) and
                its ``sigma`` is the final 20-value float32 frame. The number of
                EMG channels (C) is taken from the first sample and must match
                every other sample and every previously added segment.

        Raises:
            RuntimeError: if the owning DatasetWriter's archive is not open.
            ValueError: if the segment has no samples, or its channel count
                differs from previously written segments.
            AssertionError: if a frame, EMG window or sigma has the wrong dtype
                or shape.
        """
        if self.context.archive is None:
            raise RuntimeError("Archive is not open. Use 'with' statement to open it.")

        if len(segment.couples) == 0:
            # Without a sample, C cannot be inferred and the segment carries no data.
            raise ValueError("Segment must contain at least one (frame, emg) sample.")

        # Determine the number of EMG channels (C) from the first sample.
        first_emg = segment.couples[0].emg
        assert first_emg.ndim == 2, f"EMG shape must be ({W}, C), got {first_emg.shape}"
        C = first_emg.shape[1]
        if self.context.C is not None and self.context.C != C:
            raise ValueError("Inconsistent number of EMG channels across recordings.")

        # Validate everything up front so nothing is committed for a bad segment.
        for tup in segment.couples:
            assert (
                tup.frame.dtype == np.float32
            ), f"Frame dtype must be float32, got {tup.frame.dtype}"
            assert (
                tup.emg.dtype == np.float32
            ), f"EMG dtype must be float32, got {tup.emg.dtype}"
            assert tup.frame.shape == (
                20,
            ), f"Frame shape must be (20,), got {tup.frame.shape}"
            # Exact tuple comparison also rejects 1-D and >2-D arrays.
            assert tup.emg.shape == (
                W,
                C,
            ), f"EMG shape must be ({W}, {C}), got {tup.emg.shape}"

        # sigma has a fixed 20 x float32 slot at the end of the record; any other
        # dtype/shape would silently change the byte layout.
        assert (
            segment.sigma.dtype == np.float32
        ), f"Sigma dtype must be float32, got {segment.sigma.dtype}"
        assert segment.sigma.shape == (
            20,
        ), f"Sigma shape must be (20,), got {segment.sigma.shape}"

        if self.context.C is None:
            # First accepted segment fixes C for the whole dataset.
            self.context.C = C
            self.context.archive.writestr("metadata.yml", yaml.dump({"C": C}))

        # Serialize: frame (20 values) then emg (W * C values) per sample, then sigma.
        bio = io.BytesIO()
        for tup in segment.couples:
            bio.write(tup.frame.tobytes())
            bio.write(tup.emg.tobytes())  # tobytes() defaults to C (row-major) order
        bio.write(segment.sigma.tobytes())

        # Save the segment
        self.context.archive.writestr(f"recordings/{self.index}/segments/{self.count}", bio.getvalue())
        self.count += 1

class DatasetWriter:
    """
    A context manager for writing segments to a ZIP archive in a proprietary binary format.

    Usage::

        with DatasetWriter("dataset.zip") as writer:
            recording = writer.add_recording()
            recording.add(segment)

    The archive is opened in mode "w" (an existing file is overwritten) with
    LZMA compression. Recordings are numbered from 1 in the order
    ``add_recording`` is called; segments are numbered from 0 within each
    recording. Archive looks like this:

    dataset.zip/
      metadata.yml          # {"C": <number of EMG channels>}, written with the first segment
      recordings/
        1/
          segments/
            0
            1
        2/
          segments/
            0
            1
    """

    def __init__(self, filename: str):
        self.filename = filename  # path of the ZIP file to create
        self.archive: zipfile.ZipFile | None = None  # open only inside the `with` block
        self.recording_index = 0  # index of the most recently started recording
        self.C: int | None = None  # To store the number of EMG channels

    def __enter__(self):
        self.archive = zipfile.ZipFile(
            self.filename,
            mode="w",
            compression=zipfile.ZIP_LZMA,
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Detach before closing so the "archive is not open" check in
        # RecordingWriter.add fires for any write after the `with` block,
        # even if close() itself raises.
        archive, self.archive = self.archive, None
        if archive is not None:
            archive.close()
        # Returning None lets exceptions raised inside the `with` body propagate.

    def add_recording(self):
        """Start a new recording and return a RecordingWriter for its segments."""
        self.recording_index += 1
        return RecordingWriter(self, self.recording_index)

In [43]:
from collections import defaultdict
from tqdm import tqdm
import h5py

# Directory containing the raw emg2pose HDF5 recordings. Set the EMG2POSE_DATA
# environment variable to use another location; the original path is the fallback.
base_path = os.environ.get("EMG2POSE_DATA", "C:/Users/shich/emg2pose_data")

def archive_dataset(
    segments: List[tuple],
    filepath: str,
    data_dir: str | None = None,
    filename_template: str = "2022-04-07-1649318400-8125c-cv-emg-pose-train@2-recording-{index}_left.hdf5",
):
    """
    Cut selected ranges of emg2pose recordings into W-sample windows and archive them.

    Args:
        segments: ``(recording_index, start, end)`` tuples; ``start``/``end`` index the
            recording's ``timeseries`` arrays. Ranges of the same recording are grouped
            (in first-appearance order) and written as consecutive segments of one
            archive recording.
        filepath: destination ZIP path, passed to ``DatasetWriter``.
        data_dir: directory holding the HDF5 files; defaults to the module-level
            ``base_path``.
        filename_template: HDF5 file name inside ``data_dir``; ``{index}`` is replaced
            by the recording index.

    Notes:
        - Archive recordings are numbered 1, 2, ... in the order they are written,
          not by the source recording index.
        - Each range is truncated to whole W-sample windows. The sigma frame is
          ``joint_angles[real_end]`` (the first frame after the last window), so the
          window count is also capped to keep that index and every EMG window inside
          the recording.
        - Ranges too short to hold a single window are skipped.
    """
    if data_dir is None:
        data_dir = base_path

    ranges_by_recording = defaultdict(list)
    for entry in segments:
        ranges_by_recording[entry[0]].append(entry[1:])

    print("Archiving...")
    with DatasetWriter(filepath) as dataset_writer, tqdm(total=len(segments), ncols=100) as pbar:
        for index, ranges in ranges_by_recording.items():
            path = f"{data_dir}/{filename_template.format(index=index)}"
            with h5py.File(path, "r") as f:
                timeseries = f["emg2pose"]["timeseries"]  # type: ignore
                joint_angles = timeseries["joint_angles"]  # type: ignore
                emg = timeseries["emg"]  # type: ignore

                # Created on first use so a recording whose ranges are all skipped
                # does not consume an archive index.
                recording_writer = None

                for bounds in ranges:
                    start, end = bounds[0], bounds[1]

                    # Whole windows in [start, end), capped so every EMG window is
                    # complete and joint_angles[real_end] exists. For ranges that stay
                    # in bounds this equals (end - start) // W.
                    slices = min(
                        (end - start) // W,
                        (emg.shape[0] - start) // W,  # type: ignore
                        (joint_angles.shape[0] - 1 - start) // W,  # type: ignore
                    )
                    if slices <= 0:
                        pbar.update(1)
                        continue
                    real_end = slices * W + start

                    couples = [
                        HandEmgTuple(
                            frame=joint_angles[start + i * W],  # type: ignore
                            emg=emg[start + i * W : start + (i + 1) * W],  # type: ignore
                        )
                        for i in range(slices)
                    ]
                    segment = HandEmgRecordingSegment(couples=couples, sigma=joint_angles[real_end])  # type: ignore

                    if recording_writer is None:
                        recording_writer = dataset_writer.add_recording()
                    recording_writer.add(segment)
                    pbar.update(1)

In [None]:
# (source recording index, start, end) ranges to export; start/end index the
# recording's timeseries arrays.
SEGMENT_RANGES = [
    (1, 23000, 38000),
    (1, 45000, 60000),
    (1, 100000, 120000),
    (2, 1000, 16000),
    (2, 16500, 50000),
    (2, 64000, 118000),
    (2, 120000, 128000),
    (3, 1100, 16000),
    (3, 51000, 62966),
    (4, 1600, 11900),
    (4, 20000, 32400),
    (4, 101000, 123667),
    (5, 12000, 35000),
    (5, 39000, 50000),
    (5, 68000, 111000),
    (5, 120000, 130000),
    (6, 1000, 16000),
    (6, 31000, 105000),
    (6, 107000, 132000),
    (6, 135000, 156754),
    (7, 1000, 49000),
    (7, 50000, 85000),
    (7, 86000, 135706),
    (8, 0, 133000),
    (8, 135000, 148923),
    (10, 0, 35000),
    (10, 101000, 166715),
    (11, 10000, 35000),
    (11, 55000, 71000),
    (12, 0, 9000),
    (12, 11000, 30000),
    (12, 35000, 39000),
]

archive_dataset(segments=SEGMENT_RANGES, filepath="../dataset.zip")

Archiving...


100%|███████████████████████████████████████████████████████████████| 32/32 [00:14<00:00,  2.16it/s]
