In [104]:
import numpy as np
import torch
from torch.utils.data import IterableDataset
from typing import *
import struct
from functools import reduce

In [116]:
class DatasetWriter:
    def __init__(self, map_file: str, fields: Dict[str, np.dtype]):
        self.file_handle = open(map_file, 'wb')
        self.fields = list(fields.items())
        self.fields.sort()
        self.entry_count = 0
        self.file_handle.write(self.get_header())

    @staticmethod
    def dtype_to_id(dtype: np.dtype):
        dtypes = [np.uint8, np.int8, np.uint16, np.int16, np.uint32, np.int32, np.uint64, np.int64, np.float16, np.float32, np.float64]
        assert dtype in dtypes
        return dtypes.index(dtype)

    def add_entry(self, **kwargs: np.array):
        assert all(k in [n for n, d in self.fields] for k in kwargs)
        assert all(kwargs[n].dtype == d for n, d in self.fields if n in kwargs)
        entry = [kwargs[name] if name in kwargs else None for name, dtype in self.fields]
        for e in entry:
            entry_header = b''
            if e is None:
                entry_header += struct.pack('I', 0)
            else:
                entry_header += struct.pack('I', len(e.shape))
                entry_header += b''.join(struct.pack('I', d) for d in e.shape)
            self.file_handle.write(entry_header)
            if e is not None:
                self.file_handle.write(e.tobytes('C'))
        self.entry_count += 1

    def get_header(self):
        header_data = b''
        for field_name, field_dtype in self.fields:
            field_name = field_name.encode('utf-8')
            header_data += struct.pack(f'I{len(field_name)}sB', len(field_name), field_name, self.dtype_to_id(field_dtype))
        header_data = struct.pack('I', len(header_data)) + header_data
        header_data = header_data + struct.pack('Q', self.entry_count)
        return header_data

    def finish(self):
        self.file_handle.seek(0)
        self.file_handle.write(self.get_header())
        self.file_handle.close()


class DatasetReader(IterableDataset):
    def __init__(self, map_file: str):
        self.file_handle = open(map_file, 'rb')
        self.fields, self.entry_count = self.__read_header()
        self.cur_entry_index = 0

    @staticmethod
    def id_to_dtype(dtype_id: int):
        dtypes = [np.uint8, np.int8, np.uint16, np.int16, np.uint32, np.int32, np.uint64, np.int64, np.float16, np.float32, np.float64]
        return dtypes[dtype_id]

    def __read_header(self):
        header_length = struct.unpack('I', self.file_handle.read(4))[0]
        header, header_cursor = self.file_handle.read(header_length), 0
        fields = []
        while header_cursor < header_length:
            name_length = struct.unpack('I', header[header_cursor:header_cursor + 4])[0]
            name, dtype_id = struct.unpack(f'{name_length}sB', header[header_cursor + 4:header_cursor + 4 + name_length + 1])
            fields.append((name, self.id_to_dtype(dtype_id)))
            header_cursor += (4 + name_length + 1)
        entry_count = struct.unpack('Q', self.file_handle.read(8))[0]
        return fields, entry_count

    def reset(self):
        self.file_handle.seek(0)
        self.__read_header()

    @staticmethod
    def calc_entry_field_size(shape, dtype):
        num_elem = reduce(lambda a, b: a * b, shape, 1)
        dtype_size_map = {np.uint8: 1, np.int8: 1, np.uint16: 2, np.int16: 2, np.uint32: 4, np.int32: 4, np.uint64: 8, np.int64: 8,
                          np.float16: 2, np.float32: 4, np.float64: 8}
        return num_elem * dtype_size_map[dtype]

    def __next__(self):
        if self.cur_entry_index == self.entry_count:
            raise StopIteration()
        entry = {}
        for field_name, field_dtype in self.fields:
            num_dims = struct.unpack('I', self.file_handle.read(4))[0]
            if num_dims == 0:
                entry[field_name] = None
                continue
            dims = struct.unpack('I' * num_dims, self.file_handle.read(4 * num_dims))
            data = np.fromfile(self.file_handle, field_dtype, reduce(lambda a, b: a * b, dims, 1))
            print(data, dims, data.reshape(dims))
            entry[field_name] = data


# w = DatasetWriter('test_writer.pt', {'input_ids': np.uint16, 'attn_mask': np.uint8})
# w.add_entry(input_ids=np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16))
# w.finish()
r = DatasetReader('test_writer.pt')
next(r)

[1 2 3 4 5 6] (2, 3) [[1 2 3]
 [4 5 6]]
