In [1]:
import os
from tqdm import tqdm
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer

In [2]:
# Load the tokenizer from the local model directory, requesting the fast implementation.
tokenizer = AutoTokenizer.from_pretrained("/home/user/models/rugpt", use_fast=True)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [3]:
# Worker-process count used for both dataset loading and tokenization.
num_proc = 8

# Load the dataset and split its 'train' portion, holding out 0.5% of the rows
# (fixed seed so the split is reproducible).
# NOTE(review): trust_remote_code=True lets code shipped with the dataset
# repository run locally - confirm the source is trusted.
ds = load_dataset(
    'Den4ikAI/russian_instructions_2',
    num_proc=num_proc,
    trust_remote_code=True,
)['train'].train_test_split(test_size=0.005, shuffle=True, seed=2357)
# train_test_split names the held-out part 'test'; re-key it as 'val'.
ds['val'] = ds.pop('test')

In [4]:
def process(example):
    """Format one question/answer record as a prompt and tokenize it.

    The prompt is the tokenizer's BOS token, the formatted question and
    answer, then the EOS token. Special tokens are therefore part of the text
    itself and the tokenizer is told not to add its own. No truncation is
    applied.

    Args:
        example: mapping with 'question' and 'answer' entries.

    Returns:
        dict with 'ids' (the tokenizer's ``input_ids`` for the prompt) and
        'len' (the number of ids).
    """
    body = f"Вопрос: {example['question']}\n### Ответ: {example['answer']}"
    prompt = tokenizer.bos_token + body + tokenizer.eos_token
    encoding = tokenizer(prompt, add_special_tokens=False, truncation=False)
    ids = encoding['input_ids']
    return {"ids": ids, "len": len(ids)}

In [5]:
# Tokenize every split in parallel; the raw text columns are dropped so only
# the columns produced by `process` remain.
tokenized = ds.map(
    process,
    num_proc=num_proc,
    desc="Tokenizing",
    remove_columns=['question', 'answer'],
)

Tokenizing (num_proc=8):   0%|          | 0/236094 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (2098 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2447 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2165 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2222 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2121 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Tokenizing (num_proc=8):   0%|          | 0/1187 [00:00<?, ? examples/s]

In [6]:
# Write each tokenized split to a flat binary file '<split>.bin' holding all
# of its token ids concatenated back to back.
for split, dset in tokenized.items():
    # Total number of tokens in the split; this fixes the on-disk array size.
    arr_len = np.sum(dset['len'], dtype=np.uint64)
    print(arr_len)
    filename = os.path.join('/home/user/data', f'{split}.bin')
    dtype = np.uint16
    # Largest token id the on-disk dtype can hold; larger ids would not be
    # stored faithfully, so they are rejected explicitly below.
    max_token_id = np.iinfo(dtype).max
    arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))

    idx = 0
    # Never ask for more shards than there are rows: an empty shard leaves
    # np.concatenate with nothing to join, which raises.
    total_batches = min(1024, len(dset))
    for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
        # Contiguous shards keep the examples in their original order.
        batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
        arr_batch = np.concatenate(batch['ids'])
        if arr_batch.size and arr_batch.max() > max_token_id:
            raise ValueError(
                f"token id {arr_batch.max()} in split {split!r} does not fit in "
                f"{np.dtype(dtype).name}; use a wider dtype for the .bin file"
            )
        arr[idx : idx + len(arr_batch)] = arr_batch
        idx += len(arr_batch)

    arr.flush()
    # Sanity check: every preallocated slot must have been filled.
    if idx != int(arr_len):
        raise RuntimeError(
            f"wrote {idx} tokens to {filename} but expected {int(arr_len)}"
        )

30168808


writing /home/user/data/train.bin: 100%|██████████| 1024/1024 [00:03<00:00, 329.73it/s]


152874


writing /home/user/data/val.bin: 100%|██████████| 1024/1024 [00:01<00:00, 586.95it/s]
