In [2]:
HDR_MAGIC = b"LITPKDS"
HDR_SIZE = 24  # bytes

In [3]:
len(HDR_MAGIC)

7

In [4]:
import numpy as np
import torch
import struct

In [5]:
dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16}

In [54]:
class PackedDatasetIterator:
    def __init__(self, filenames, n_chunks, block_size):

        self._block_idxs = None

        self._filenames = filenames
        self._file_idx = 0

        self._n_chunks = n_chunks

        self._dtype = None
        self._block_size = block_size
        self._n_blocks = None

        self._mmaps = []
        self._buffers = []

        self._block_idxs = []
        self._curr_idx = 0

        self._load_n_chunks()
    
    def _load_n_chunks(self):
        self._close_mmaps()
        self._mmaps = []
        self._buffers = []

        if self._n_chunks > len(self._filenames[self._file_idx :]):
            self._file_idx = 0

        for i in range(self._n_chunks):
            filename = self._filenames[self._file_idx + i]
            # print("filename :",filename)
            # print("i :",i)
            if self._dtype is None:
                self._dtype, self._chunk_size = self._read_header(filename)
                self._n_blocks = self._chunk_size // self._block_size
            # TODO: check header matches with previous files
            mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
            self._mmaps.append(mmap)
            self._buffers.append(memoryview(mmap))

        self._file_idx += self._n_chunks
        n_all_blocks = self._n_chunks * self._n_blocks

        # self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)

        self._block_idxs = range(n_all_blocks)

        self._curr_idx = 0
    
    def _close_mmaps(self):
        for mmap in self._mmaps:
            mmap._mmap.close()
    
    def _read_header(self, path):
        with open(path, "rb") as f:
            magic = f.read(len(HDR_MAGIC))
            print(magic)
            assert magic == HDR_MAGIC, "File doesn't match expected format."
            version = struct.unpack("<Q", f.read(8))
            assert version == (1,)
            (dtype_code,) = struct.unpack("<B", f.read(1))
            dtype = dtypes[dtype_code]
            (chunk_size,) = struct.unpack("<Q", f.read(8))
        return dtype, chunk_size
    
    def __del__(self):
        self._close_mmaps()
        del self._mmaps
        del self._buffers

    def __iter__(self):
        return self
    
    def __next__(self):
        if self._curr_idx >= len(self._block_idxs):
            self._load_n_chunks()
            # TODO: trigger fetching next next n_chunks if remote
        block_idx = self._block_idxs[self._curr_idx]
        chunk_id = block_idx // self._n_blocks
        buffer = self._buffers[chunk_id]
        elem_id = (block_idx % self._n_blocks) * self._block_size
        offset = np.dtype(self._dtype).itemsize * elem_id
        arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
        self._curr_idx += 1
        return torch.from_numpy(arr.astype(np.int64))


In [55]:
import glob

In [56]:
data_config = [
    ("arxiv", 2.5),
    ("book", 4.5),
    ("c4", 15.0),
    ("cc", 67.0),
    ("github", 4.5),
    ("stackexchange", 2.0),
    ("wikipedia", 4.5),
]

In [57]:
prefix = "arxiv"

In [58]:
from pathlib import Path

In [59]:
data_dir = Path("./data/RedPajama-Data-1T-Sample/*")

In [60]:
filenames = glob.glob("/home/rampfire/Downloads/data/lit-redpajama-sample/ar*")

In [61]:
temp = PackedDatasetIterator(filenames,
                        4,
                        4096)

b'LITPKDS'


In [62]:
from torch.utils.data import IterableDataset

In [74]:
class PackedDataset(IterableDataset):
    def __init__(self):
        pass

    def __iter__(self):

        return PackedDatasetIterator(filenames,
                        4,
                        2049)

In [75]:
test_data = PackedDataset()

In [76]:
test_dataloader = DataLoader(test_data,6)

In [77]:
next(iter(test_dataloader)).shape

b'LITPKDS'


torch.Size([6, 2049])