In [35]:
import glob
import json
import os
from itertools import islice, chain
from typing import List, Dict, Any

import sentencepiece as spm
from datasets import load_dataset, IterableDataset
from transformers import PreTrainedTokenizerFast

In [18]:
# Path to the trained SentencePiece model. Set the TOKENIZER_MODEL_FILE
# environment variable to run the notebook outside the original author's
# home directory; otherwise the original absolute path is used unchanged.
model_file = os.environ.get(
    "TOKENIZER_MODEL_FILE",
    "/home/nlp/achimoa/workspace/hebrew_text_retrieval/outputs/tokenizer/HebrewModernBERT_mixed_1M_100K.model",
)
tokenizer = spm.SentencePieceProcessor(model_file=model_file)

In [None]:
def tokenize(text: str, processor=None) -> List[int]:
    """Encode ``text`` into token ids.

    Args:
        text: Raw input string.
        processor: Optional object exposing ``encode(text)``. When omitted, the
            module-level ``tokenizer`` is used, which preserves the original
            single-argument behaviour.

    Returns:
        Whatever ``processor.encode(text)`` returns. With
        ``SentencePieceProcessor.encode`` and its default ``out_type=int``
        these are integer ids, not piece strings, so the return annotation
        is ``List[int]``.
    """
    sp = tokenizer if processor is None else processor
    return sp.encode(text)

def filter_valid_text(example):
    """Return True when the example's ``text`` field holds usable content.

    An example is rejected when ``text`` is missing (no key or None),
    the empty string, or the literal string "null".
    """
    text = example.get("text")
    if text in (None, "", "null"):
        return False
    return True

def keep_text_source(example):
    """Project an example down to its ``text`` and ``source`` fields.

    ``text`` is required (a missing key raises KeyError); ``source`` falls
    back to None when absent. All other fields are dropped.
    """
    text = example["text"]
    source = example.get("source")
    return dict(text=text, source=source)

def _take_within_token_budget(examples, token_budget, tokenizer):
    """Yield examples in order while the running token total stays within ``token_budget``.

    Each example costs ``len(tokenizer.encode(example["text"]))`` tokens.
    Sampling stops at the first example that would push the total past the
    budget; that example is not yielded. A summary line is printed either
    when the budget stops sampling or when the input runs out first.
    """
    total_tokens = 0
    n_yielded = 0
    for example in examples:
        token_count = len(tokenizer.encode(example["text"]))
        if total_tokens + token_count > token_budget:
            # n_yielded (not index + 1) is the number of samples actually emitted:
            # the example that overflowed the budget is dropped.
            print(f"Reached token budget limit ({n_yielded} samples, {total_tokens} tokens). Stopping.")
            return
        total_tokens += token_count
        n_yielded += 1
        yield example
    print(f"Stream exhausted before token budget ({n_yielded} samples, {total_tokens} tokens).")


def load_and_sample_by_tokens(files, token_budget, tokenizer, shuffle_buffer, seed):
    """Build a zero-argument generator function that streams a token-budgeted sample.

    The files are loaded with ``load_dataset("json", ...)`` in streaming mode.
    Examples rejected by ``filter_valid_text`` are dropped, and each remaining
    example is reduced by ``keep_text_source`` to ``text``/``source``. The
    stream is then shuffled with a buffer of ``shuffle_buffer`` examples,
    seeded by ``seed``.

    Args:
        files: Paths passed as ``data_files`` to ``load_dataset("json", ...)``.
        token_budget: Upper bound on the total number of tokens emitted.
        tokenizer: Object whose ``encode(text)`` returns a sequence; only its
            length is used, as the token count.
        shuffle_buffer: ``buffer_size`` for the streaming shuffle.
        seed: Seed for the streaming shuffle.

    Returns:
        A function that, on every call, returns a new generator over the
        sampled examples. Each call iterates the streaming dataset again.
    """
    dataset = (
        load_dataset("json", data_files=files, split="train", streaming=True)
        .filter(filter_valid_text)
        .map(keep_text_source)
        .shuffle(buffer_size=shuffle_buffer, seed=seed)
    )

    def generator():
        yield from _take_within_token_budget(dataset, token_budget, tokenizer)

    return generator


In [34]:
# Paths to all gzipped JSON shards under ../../data.
# glob.glob returns entries in a filesystem-dependent order, so sort them.
# This way the seeded .shuffle(...) in load_and_sample_by_tokens sees the same
# input order on every run, and the sample is reproducible.
all_files = sorted(glob.glob('../../data/**/*.json.gz', recursive=True))
assert all_files, "No *.json.gz files found under ../../data"

# Separate StarCoder files from the rest (substring match on the full path)
starcoder_files = [f for f in all_files if 'starcoder' in f]
other_files = [f for f in all_files if 'starcoder' not in f]

In [36]:
def _round_robin(*iterables):
    """Yield one item from each iterable in turn, dropping iterables as they run out.

    Example: ``_round_robin("ABC", "D", "EF")`` yields A D E B F C.
    """
    exhausted = object()
    active = [iter(it) for it in iterables]
    while active:
        still_active = []
        for it in active:
            item = next(it, exhausted)
            if item is exhausted:
                continue
            yield item
            still_active.append(it)
        active = still_active


def combined_generator():
    """Stream up to 25B tokens of StarCoder data mixed with up to 25B tokens of other data.

    The two token-budgeted streams are interleaved round-robin rather than
    concatenated. With ``chain(...)``, every StarCoder example came before every
    other example. The later buffer-based ``shuffle`` only mixes examples within
    its buffer, so unless a source contributes fewer examples than the buffer
    size, the output stayed segregated by source. Round-robin keeps the sources
    mixed until the stream with fewer examples ends; the rest of the longer
    stream then follows on its own.
    """
    starcoder_gen = load_and_sample_by_tokens(
        starcoder_files,
        token_budget=25_000_000_000,
        tokenizer=tokenizer,
        shuffle_buffer=100_000,
        seed=42
    )
    other_gen = load_and_sample_by_tokens(
        other_files,
        token_budget=25_000_000_000,
        tokenizer=tokenizer,
        shuffle_buffer=1_000_000,
        seed=42
    )
    return _round_robin(starcoder_gen(), other_gen())

combined_dataset = IterableDataset.from_generator(combined_generator)

In [None]:
# Seeded buffer shuffle over the combined stream (buffer of 1M examples).
shuffled_dataset = combined_dataset.shuffle(
    seed=42,
    buffer_size=1_000_000,
)

In [None]:
# Write the shuffled stream as JSON Lines (one JSON object per line).
# encoding="utf-8" is required: ensure_ascii=False emits raw non-ASCII
# characters, and open()'s default encoding is platform/locale dependent.
output_path = "../../data/dolma/corpus_sampled_50B.jsonl"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
    for sample in shuffled_dataset:
        json.dump(sample, f, ensure_ascii=False)
        f.write("\n")

Resolving data files:   0%|          | 0/49 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/2370 [00:00<?, ?it/s]