In [1]:
from typing import Optional, Dict, Any, Union, Sequence
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from streaming import Stream, StreamingDataset
import numpy as np
import torch
from torch.utils.data import DataLoader
from datetime import datetime

In [2]:
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

class StreamingTextDataset(StreamingDataset):
    def __init__(
        self,
        tokenizer: Tokenizer,
        max_seq_len: int,
        streams: Optional[Sequence[Stream]] = None,
        remote: Optional[str] = None,
        local: Optional[str] = None,
        split: Optional[str] = None,
        download_retry: int = 2,
        download_timeout: float = 60,
        validate_hash: Optional[str] = None,
        keep_zip: bool = False,
        epoch_size: Optional[int] = None,
        predownload: int = 100_000,
        partition_algo: str = "orig",
        num_canonical_nodes: Optional[int] = None,
        batch_size: Optional[int] = None,
        shuffle: bool = False,
        shuffle_algo: str = "py1s",
        shuffle_seed: int = 9176,
        cache_limit: Optional[int] = None,
        **kwargs: Dict[str, Any],
    ):
        # Build Dataset
        super().__init__(
            streams=streams,
            remote=remote,
            local=local,
            split=split,
            download_retry=download_retry,
            download_timeout=download_timeout,
            validate_hash=validate_hash,
            keep_zip=keep_zip,
            epoch_size=epoch_size,
            predownload=predownload,
            partition_algo=partition_algo,
            num_canonical_nodes=num_canonical_nodes,
            batch_size=batch_size,
            shuffle=shuffle,
            shuffle_algo=shuffle_algo,
            shuffle_seed=shuffle_seed,
            cache_limit=cache_limit,
        )
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    # How to tokenize a text sample to a token sample
    def _tokenize(self, text_sample):
        if self.tokenizer._pad_token is None:
            # Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs
            raise RuntimeError("If tokenizing on-the-fly, tokenizer must have a pad_token_id")

        encoded = self.tokenizer(text_sample["text"], truncation=True, padding="max_length", max_length=self.max_seq_len)
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "token_type_ids": encoded.get("token_type_ids", [0] * self.max_seq_len),
        }

    def _read_binary_tokenized_sample(self, sample):
        seq_len = sample["len"] if "len" in sample else len(sample["input_ids"])

        input_ids = np.frombuffer(sample["input_ids"], dtype=np.int64).copy()
        if "attention_mask" in sample:
            attention_mask = np.frombuffer(sample["attention_mask"], dtype=np.int64).copy()
        else:
            attention_mask = np.ones_like(input_ids)

        # calculate padding
        pad_len = self.max_seq_len - seq_len

        # pad or truncate input_ids and attention_mask
        if pad_len > 0:
            input_ids = np.pad(input_ids, (0, pad_len), constant_values=self.tokenizer.pad_token_id)
            attention_mask = np.pad(attention_mask, (0, pad_len), constant_values=0)
        elif pad_len < 0:
            input_ids = input_ids[: self.max_seq_len]
            attention_mask = attention_mask[: self.max_seq_len]

        token_type_ids = np.zeros(self.max_seq_len, dtype=np.int64)

        return {
            "input_ids": input_ids.tolist(),
            "attention_mask": attention_mask.tolist(),
            "token_type_ids": token_type_ids.tolist(),
            "source": sample["source"],
        }

    # How to process a sample
    def __getitem__(self, idx: int) -> Union[Dict[str, Any], torch.Tensor]:
        sample = super().__getitem__(idx)
        if "text" in sample:
            token_sample = {
                **self._tokenize(sample),
                "source": sample["_source"],
            }
        elif "input_ids" in sample:
            token_sample = self._read_binary_tokenized_sample(sample)
        else:
            raise RuntimeError("StreamingTextDataset needs samples to have a `text` or `input_ids` column")
        return token_sample


In [3]:
import os 
os.getcwd()

'/home/nlp/achimoa/workspace/hebrew_text_retrieval/notebooks/data'

In [4]:
def get_dataloader(batch_size=4096, shuffle=False, predownload=100_000):
    # Set up the dataset
    dataset = StreamingTextDataset(
        tokenizer=AutoTokenizer.from_pretrained("/home/nlp/achimoa/workspace/ModernHebrewBERT/tokenizer"),
        max_seq_len=1024,
        local="/home/nlp/achimoa/workspace/hebrew_text_retrieval/data/hebrew_modernbert/v20250428",         # Path to your local MDS directory
        remote=None,                 # No remote; purely local
        split="train",               # MDS split name
        shuffle=shuffle,                # Enable shuffling
        shuffle_algo="py1s",         # Optional: fast shuffling algorithm
        batch_size=batch_size,                # Optional: to control shuffle buffering
        predownload=predownload,            # How many samples to prefetch into the shuffle buffer
    )

    # Wrap in a DataLoader
    dataloader = DataLoader(dataset, batch_size=batch_size)
    return dataloader

In [None]:
dataloader = get_dataloader(batch_size=4096, shuffle=True)

total_samples = 0
total_ones = 0
end = datetime.now()
for i, batch in enumerate(dataloader):
    if i > 10:
        break
    start = datetime.now()
    print(f"Time to load batch #{i}: {(start-end).total_seconds():.2f}s")
    transposed = list(zip(*batch['attention_mask']))  # now [batch_size x seq_len]

    for sample_mask in transposed:
        total_ones += sum(sample_mask).item()
        total_samples += 1
    end = datetime.now()
    print(f"batch #{i}: rolling average = {total_ones/total_samples:.2f} (total_samples = {total_samples}) [{(end - start).total_seconds():.2f}s]")


Because `shuffle_block_size` was not specified, it will default to max(4_000_000 // num_canonical_nodes, 1 << 18) if num_canonical_nodes is not None, otherwise 262144. Prior to Streaming v0.7.0, `shuffle_block_size` defaulted to 262144.


Time to load batch #0: 100.40s
batch #0: rolling average = 348.58 (total_samples = 4096) [29.00s]
Time to load batch #1: 16.55s
batch #1: rolling average = 349.55 (total_samples = 8192) [42.33s]
Time to load batch #2: 16.00s
batch #2: rolling average = 350.71 (total_samples = 12288) [36.05s]
Time to load batch #3: 15.46s
batch #3: rolling average = 349.59 (total_samples = 16384) [35.89s]
Time to load batch #4: 15.22s
batch #4: rolling average = 350.10 (total_samples = 20480) [39.01s]
Time to load batch #5: 17.04s
batch #5: rolling average = 349.05 (total_samples = 24576) [36.01s]
Time to load batch #6: 15.94s
batch #6: rolling average = 349.52 (total_samples = 28672) [34.48s]
Time to load batch #7: 14.39s
batch #7: rolling average = 348.54 (total_samples = 32768) [38.34s]
Time to load batch #8: 33.37s
batch #8: rolling average = 348.16 (total_samples = 36864) [39.65s]
Time to load batch #9: 43.88s
batch #9: rolling average = 348.00 (total_samples = 40960) [40.11s]
Time to load batch #1

In [None]:
dataloader = get_dataloader(batch_size=4096, shuffle=False)

total_samples = 0
total_ones = 0
end = datetime.now()
for i, batch in enumerate(dataloader):
    if i > 10:
        break
    start = datetime.now()
    print(f"Time to load batch #{i}: {(start-end).total_seconds():.2f}s")
    transposed = list(zip(*batch['attention_mask']))  # now [batch_size x seq_len]

    for sample_mask in transposed:
        total_ones += sum(sample_mask).item()
        total_samples += 1
    end = datetime.now()
    print(f"batch #{i}: rolling average = {total_ones/total_samples:.2f} (total_samples = {total_samples}) [{(end - start).total_seconds():.2f}s]")


Because `shuffle_block_size` was not specified, it will default to max(4_000_000 // num_canonical_nodes, 1 << 18) if num_canonical_nodes is not None, otherwise 262144. Prior to Streaming v0.7.0, `shuffle_block_size` defaulted to 262144.


Time to load batch #0: 38.03s
batch #0: rolling average = 100.15 (total_samples = 4096) [41.06s]
Time to load batch #1: 10.41s
batch #1: rolling average = 102.68 (total_samples = 8192) [37.77s]
Time to load batch #2: 9.45s
batch #2: rolling average = 100.97 (total_samples = 12288) [37.36s]
Time to load batch #3: 10.06s
batch #3: rolling average = 99.97 (total_samples = 16384) [35.87s]
Time to load batch #4: 11.23s
batch #4: rolling average = 99.72 (total_samples = 20480) [39.37s]
Time to load batch #5: 9.83s
batch #5: rolling average = 99.15 (total_samples = 24576) [33.08s]
Time to load batch #6: 9.63s
batch #6: rolling average = 99.89 (total_samples = 28672) [36.85s]
Time to load batch #7: 9.09s
batch #7: rolling average = 99.45 (total_samples = 32768) [33.20s]
Time to load batch #8: 9.77s
batch #8: rolling average = 100.21 (total_samples = 36864) [36.83s]
Time to load batch #9: 9.73s
batch #9: rolling average = 101.10 (total_samples = 40960) [32.98s]
Time to load batch #10: 9.21s
bat

In [6]:
dataloader = get_dataloader(batch_size=128, shuffle=True)

total_samples = 0
total_ones = 0
end = datetime.now()
for i, batch in enumerate(dataloader):
    if i > 10:
        break
    start = datetime.now()
    print(f"Time to load batch #{i}: {(start-end).total_seconds():.2f}s")
    transposed = list(zip(*batch['attention_mask']))  # now [batch_size x seq_len]

    for sample_mask in transposed:
        total_ones += sum(sample_mask).item()
        total_samples += 1
    end = datetime.now()
    print(f"batch #{i}: rolling average = {total_ones/total_samples:.2f} (total_samples = {total_samples}) [{(end - start).total_seconds():.2f}s]")


Because `shuffle_block_size` was not specified, it will default to max(4_000_000 // num_canonical_nodes, 1 << 18) if num_canonical_nodes is not None, otherwise 262144. Prior to Streaming v0.7.0, `shuffle_block_size` defaulted to 262144.


Time to load batch #0: 52.09s
batch #0: rolling average = 337.09 (total_samples = 128) [7.44s]
Time to load batch #1: 0.73s
batch #1: rolling average = 349.04 (total_samples = 256) [6.38s]
Time to load batch #2: 0.65s
batch #2: rolling average = 347.68 (total_samples = 384) [6.85s]
Time to load batch #3: 0.59s
batch #3: rolling average = 348.18 (total_samples = 512) [5.94s]
Time to load batch #4: 0.47s
batch #4: rolling average = 340.21 (total_samples = 640) [1.19s]
Time to load batch #5: 0.63s
batch #5: rolling average = 341.70 (total_samples = 768) [1.16s]
Time to load batch #6: 0.49s
batch #6: rolling average = 337.70 (total_samples = 896) [1.20s]
Time to load batch #7: 0.56s
batch #7: rolling average = 339.68 (total_samples = 1024) [1.21s]
Time to load batch #8: 0.56s
batch #8: rolling average = 338.37 (total_samples = 1152) [0.82s]
Time to load batch #9: 0.56s
batch #9: rolling average = 336.34 (total_samples = 1280) [1.25s]
Time to load batch #10: 0.63s
batch #10: rolling average

In [8]:
dataloader = get_dataloader(batch_size=4096, shuffle=True, predownload=10_000)

total_samples = 0
total_ones = 0
end = datetime.now()
for i, batch in enumerate(dataloader):
    if i > 10:
        break
    start = datetime.now()
    print(f"Time to load batch #{i}: {(start-end).total_seconds():.2f}s")
    transposed = list(zip(*batch['attention_mask']))  # now [batch_size x seq_len]
    print(batch['source'])
    for sample_mask in transposed:
        total_ones += sum(sample_mask).item()
        total_samples += 1
    end = datetime.now()
    print(f"batch #{i}: rolling average = {total_ones/total_samples:.2f} (total_samples = {total_samples}) [{(end - start).total_seconds():.2f}s]")


Because `shuffle_block_size` was not specified, it will default to max(4_000_000 // num_canonical_nodes, 1 << 18) if num_canonical_nodes is not None, otherwise 262144. Prior to Streaming v0.7.0, `shuffle_block_size` defaulted to 262144.


Time to load batch #0: 67.58s
['YifatDataBatch2-Round3-Deduped.forgpt.jsonl', 'corpus_sampled_50B.jsonl', 'YifatDataBatch2-Round3-Deduped.forgpt.jsonl', 'corpus_sampled_50B.jsonl', 'AllHebNLIFiles-Deduped-D2.forgpt.jsonl', 'corpus_sampled_50B.jsonl', 'HeC4DictaCombined-Clean-Deduped.forgpt.jsonl', 'corpus_sampled_50B.jsonl', 'corpus_sampled_50B.jsonl', 'hebrew_tweets_text_clean_full-Deduped.forgpt.jsonl', 'YifatToCombine-Deduped.forgpt.jsonl', 'corpus_sampled_50B.jsonl', 'HeC4DictaCombined-Clean-Deduped.forgpt.jsonl', 'corpus_sampled_50B.jsonl', 'corpus_sampled_50B.jsonl', 'YifatToCombine-Deduped.forgpt.jsonl', 'corpus_sampled_50B.jsonl', 'AllHebNLIFiles-Deduped-D2.forgpt.jsonl', 'corpus_sampled_50B.jsonl', 'BooksNLI2-Combined-Deduped.forgpt.jsonl', 'YifatToCombine-Deduped.forgpt.jsonl', 'corpus_sampled_50B.jsonl', 'YifatDataBatch3-Round5-DedupedD2.forgpt.jsonl', 'HeC4DictaCombined-Clean-Deduped.forgpt.jsonl', 'YifatToCombine-Deduped.forgpt.jsonl', 'corpus_sampled_50B.jsonl', 'YifatToC

In [9]:
dataloader = get_dataloader(batch_size=160, shuffle=True, predownload=10_000)

total_samples = 0
total_ones = 0
end = datetime.now()
for i, batch in enumerate(dataloader):
    if i > 10:
        break
    start = datetime.now()
    print(f"Time to load batch #{i}: {(start-end).total_seconds():.2f}s")
    transposed = list(zip(*batch['attention_mask']))  # now [batch_size x seq_len]
    # print(batch['source'])
    for sample_mask in transposed:
        total_ones += sum(sample_mask).item()
        total_samples += 1
    end = datetime.now()
    print(f"batch #{i}: rolling average = {total_ones/total_samples:.2f} (total_samples = {total_samples}) [{(end - start).total_seconds():.2f}s]")


Because `shuffle_block_size` was not specified, it will default to max(4_000_000 // num_canonical_nodes, 1 << 18) if num_canonical_nodes is not None, otherwise 262144. Prior to Streaming v0.7.0, `shuffle_block_size` defaulted to 262144.


Time to load batch #0: 49.07s
batch #0: rolling average = 358.74 (total_samples = 160) [7.26s]
Time to load batch #1: 0.67s
batch #1: rolling average = 338.45 (total_samples = 320) [1.11s]
Time to load batch #2: 0.74s
batch #2: rolling average = 349.69 (total_samples = 480) [1.09s]
Time to load batch #3: 0.61s
batch #3: rolling average = 340.21 (total_samples = 640) [1.05s]
Time to load batch #4: 0.74s
batch #4: rolling average = 340.14 (total_samples = 800) [1.65s]
Time to load batch #5: 0.71s
batch #5: rolling average = 340.13 (total_samples = 960) [1.16s]
Time to load batch #6: 0.68s
batch #6: rolling average = 338.16 (total_samples = 1120) [1.45s]
Time to load batch #7: 0.67s
batch #7: rolling average = 336.34 (total_samples = 1280) [1.56s]
Time to load batch #8: 0.88s
batch #8: rolling average = 342.21 (total_samples = 1440) [1.48s]
Time to load batch #9: 0.79s
batch #9: rolling average = 343.46 (total_samples = 1600) [1.06s]
Time to load batch #10: 0.99s
batch #10: rolling averag