In [3]:
import json
import os
import shutil
from glob import glob

import numpy as np
from streaming import MDSWriter
from streaming import LocalDataset
from streaming.base.format.mds.encodings import Encoding, _encodings
from tqdm import tqdm
from transformers import AutoTokenizer

class UInt32(Encoding):
    """MDS column encoding that stores an array of uint32 values as raw bytes.

    The payload has no header: values are written in the machine's native
    byte order and the element count is implied by the payload length.
    Multi-dimensional input is flattened (C order) and decodes as 1-D.
    """

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` (array-like of non-negative ints) to raw uint32 bytes.

        ``obj`` is coerced to ``uint32`` before serializing because ``decode``
        always reinterprets the payload as uint32; writing e.g. an int64 array
        as-is would decode into twice as many meaningless values.
        """
        return np.asarray(obj, dtype=np.uint32).tobytes()

    def decode(self, data: bytes) -> np.ndarray:
        """Reinterpret ``data`` as a 1-D uint32 array (a zero-copy view of the buffer)."""
        return np.frombuffer(data, np.uint32)

# Make the custom encoding available to MDSWriter / LocalDataset under the
# name 'uint32', which is the encoding name every column below uses.
_encodings['uint32'] = UInt32

# All three MDS columns are flat uint32 arrays. Note that 'attention_mask'
# holds the length of each packed segment (as built by `collator`), not a
# 0/1 mask.
columns = {
    column_name: 'uint32'
    for column_name in ('input_ids', 'position_ids', 'attention_mask')
}
hashes = ('sha1', 'xxh64')

In [4]:
# Tokenizer (vocabulary + chat template) used by `loop` to render and tokenize rows.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="openai/gpt-oss-20b")

In [30]:
def collator(batch, batch_position_ids):
    """Flatten one packed block of segments into an MDS sample.

    Args:
        batch: sequence of token-id sequences (segments) forming one block.
        batch_position_ids: position-id sequences indexed in parallel with
            ``batch`` (entry ``i`` belongs to ``batch[i]``).

    Returns:
        dict of uint32 numpy arrays:
          - 'input_ids': every segment's token ids, concatenated in order.
          - 'position_ids': every segment's position ids, concatenated in order.
          - 'attention_mask': the length of each segment of ``batch``
            (segment boundaries, not a 0/1 mask).
    """
    flat_tokens = []
    flat_positions = []
    segment_lengths = []
    for idx in range(len(batch)):
        segment = batch[idx]
        segment_lengths.append(len(segment))
        flat_tokens += segment
        flat_positions += batch_position_ids[idx]

    packed = {
        'input_ids': flat_tokens,
        'position_ids': flat_positions,
        'attention_mask': segment_lengths,
    }
    return {key: np.array(values).astype(np.uint32) for key, values in packed.items()}

def slice_and_balance(nested_list, size):
    """Split a list of sequences after the first ``size`` total elements.

    Sequences are taken in order into ``first`` while their running total is
    below ``size``. The sequence that would overflow is cut: its head goes to
    ``first`` and its tail to ``balance``. Once ``size`` is reached, every
    remaining sequence (including empty ones) goes to ``balance`` unchanged.
    Cutting uses slicing, so the sequence type is preserved (a ``range``
    stays a ``range``).

    Returns:
        ``(first, balance)`` as two new lists.
    """
    head = []
    tail = []
    filled = 0
    for seq in nested_list:
        if filled >= size:
            tail.append(seq)
            continue
        room = size - filled
        if len(seq) > room:
            head.append(seq[:room])
            tail.append(seq[room:])
            filled = size
        else:
            head.append(seq)
            filled += len(seq)
    return head, tail

def loop(rows, block_size = 1024*16, folder = 'tokenized-16k'):
    """Tokenize chat rows and pack them into fixed-length MDS samples.

    Each row is rendered with the tokenizer's chat template, tokenized with
    ``add_special_tokens=False`` and appended to a pending buffer together
    with position ids ``0..len-1``. Whenever the buffer holds at least
    ``block_size`` tokens, the first ``block_size`` tokens are written as one
    sample. A row that crosses the boundary is split, and its tail (with its
    continuing position ids) carries over to the next sample.

    After all rows, a non-empty leftover of ``count < block_size`` tokens is
    topped up to ``block_size`` with the first ``block_size - count`` tokens
    of the last written sample and written as a final sample. The top-up
    tokens therefore appear twice in the shard. If no full sample was ever
    written, there is nothing to top up from and the leftover is dropped.

    Args:
        rows: ``(rows, index)`` tuple -- the chat rows to process and an index
            used in the shard directory name.
        block_size: number of tokens per written sample.
        folder: parent directory; the shard is written to
            ``{folder}/tokenized-{index}``, which is deleted first.

    Returns:
        The final topped-up sample dict if one was written, otherwise None.
    """
    rows, index = rows
    out_root = f'{folder}/tokenized-{index}'
    # Start from a clean shard directory without shelling out to `rm -rf`.
    shutil.rmtree(out_root, ignore_errors=True)
    count = 0          # tokens currently pending in `temp`
    temp = []          # pending token-id segments
    position_ids = []  # position ids, parallel to `temp`
    last_block, last_position_block = None, None
    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:
        for row in tqdm(rows):
            text = tokenizer.apply_chat_template(row, tokenize=False)
            token_ids = tokenizer(text, add_special_tokens=False)['input_ids']
            temp.append(token_ids)
            position_ids.append(range(len(token_ids)))
            count += len(token_ids)
            while count >= block_size:
                block, temp = slice_and_balance(temp, block_size)
                block_position, position_ids = slice_and_balance(position_ids, block_size)
                count -= block_size
                out.write(collator(block, block_position))
                last_block = block
                last_position_block = block_position

        # Total was an exact multiple of block_size: nothing is left over, and
        # "topping up" an empty tail would just re-write the last sample.
        if count == 0:
            return None
        # No full sample was written, so there is nothing to top up from.
        if last_block is None:
            return None

        fill = block_size - count
        block, _ = slice_and_balance(last_block, fill)
        block_position, _ = slice_and_balance(last_position_block, fill)
        block.extend(temp)
        block_position.extend(position_ids)

        o = collator(block, block_position)
        # Sanity guard: last_block holds exactly block_size tokens, so this
        # should always be exactly block_size.
        if len(o['input_ids']) == block_size:
            out.write(o)
            return o
    return None

In [31]:
# Load the prepared rows (`json` is already imported in the first cell).
# Each element is presumably one chat conversation, since `loop` passes it
# to `tokenizer.apply_chat_template` -- confirm against the data producer.
# Read explicitly as UTF-8 so non-ASCII text doesn't depend on the
# platform's default locale encoding.
with open('prepared-reasoning-data.json', encoding='utf-8') as fopen:
    texts = json.load(fopen)

In [39]:
# Pack all rows into 16k-token samples under tokenized-16k/tokenized-0.
r = loop((texts, 0), block_size=1024 * 16, folder='tokenized-16k')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33717/33717 [01:53<00:00, 297.49it/s]


In [40]:
# multiprocessing(texts, loop, cores = 10, returned = False)

In [41]:
# Re-open the 16k shard and report how many packed samples it holds.
dataset = LocalDataset(local='tokenized-16k/tokenized-0')
len(dataset)

4577

In [37]:
# Same packing at 2k tokens per sample; the cell displays the final topped-up sample.
loop((texts, 0), block_size=2048, folder='tokenized-2k')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33717/33717 [01:57<00:00, 286.11it/s]


{'input_ids': array([    25,    478,   4402, ...,    346,     13, 200002], dtype=uint32),
 'position_ids': array([1328, 1329, 1330, ..., 1909, 1910, 1911], dtype=uint32),
 'attention_mask': array([ 171,  293, 1584], dtype=uint32)}

In [38]:
# Re-open the 2k shard and report how many packed samples it holds.
dataset = LocalDataset(local='tokenized-2k/tokenized-0')
len(dataset)

36610