In [13]:
import json
import os
from glob import glob
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer
from multiprocess import Pool
import itertools

def chunks(l, n):
    """
    Lazily split `l` into consecutive slices of at most `n` items.

    Yields `(slice, chunk_number)` tuples, where `chunk_number` starts at 0
    and increases by one per slice. The last slice may be shorter than `n`.
    """
    for chunk_no, start in enumerate(range(0, len(l), n)):
        stop = start + n
        yield (l[start:stop], chunk_no)

def multiprocessing(strings, function, cores=6, returned=True):
    """
    Split `strings` into roughly `cores` chunks and run `function` on each
    chunk in a process pool.

    Parameters
    ----------
    strings : sequence
        Items to distribute; must support `len()` and slicing.
    function : callable
        Called once per chunk with a `(chunk, chunk_index)` tuple, as
        produced by `chunks`.
    cores : int
        Number of worker processes; also the divisor used to size chunks.
    returned : bool
        When true, the per-chunk results are flattened into a single list
        and returned. When false, nothing is returned.

    Returns
    -------
    list or None
        The concatenation of every non-None per-chunk result when
        `returned` is true, otherwise None.
    """
    # Floor division gives 0 when there are fewer items than cores, and a
    # zero step makes range() inside chunks() raise; never go below 1.
    chunk_size = max(1, len(strings) // cores)
    df_split = chunks(strings, chunk_size)

    pool = Pool(cores)
    try:
        pooled = pool.map(function, df_split)
        pool.close()
    except BaseException:
        # Do not leave worker processes behind when a task fails or the
        # run is interrupted.
        pool.terminate()
        raise
    finally:
        pool.join()

    if returned:
        # A worker that returns None contributes nothing; chaining it would
        # raise TypeError only after all the work has already been done.
        return list(itertools.chain.from_iterable(
            result for result in pooled if result is not None
        ))

class UInt32(Encoding):
    """
    MDS column encoding that stores an array as raw `uint32` bytes.

    Only the flat element data is stored: `decode` always returns a 1-D
    array, so any original shape is not restored.
    """

    def encode(self, obj) -> bytes:
        """Serialize `obj` as the raw bytes of a `uint32` array."""
        # Coerce to uint32 before dumping bytes. decode() always reads the
        # buffer as uint32, so an array of any other dtype (e.g. int64)
        # would otherwise round-trip to garbage without any error.
        return np.asarray(obj, dtype=np.uint32).tobytes()

    def decode(self, data: bytes):
        """Reinterpret `data` as a 1-D `uint32` array."""
        return np.frombuffer(data, np.uint32)

# Register the encoding under the name used in the `columns` spec, by adding
# it to the `_encodings` mapping imported from streaming's MDS encodings module.
_encodings['uint32'] = UInt32

# Column spec handed to MDSWriter: every field uses the 'uint32' encoding.
columns = dict.fromkeys(
    ('input_ids', 'position_ids', 'attention_mask'),
    'uint32',
)
# Hash algorithm names handed to MDSWriter.
hashes = ('sha1', 'xxh64')

In [4]:
# Module-level tokenizer; loop() reads this global to tokenize each row.
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-1.7B-Base')

In [5]:
# !wget https://huggingface.co/datasets/malaysia-ai/pretrain-text-dataset/resolve/main/wikipedia-2023-10-01.jsonl

In [6]:
!wc -l wikipedia-2023-10-01.jsonl

438316 wikipedia-2023-10-01.jsonl


In [7]:
# Read the corpus: one JSON value per line. Records whose len() is 1 or less
# are dropped. (Presumably each record is a plain string, since rows are
# later passed straight to the tokenizer -- verify against the dataset.)
texts = []
# Explicit UTF-8 so decoding does not depend on the platform's default locale.
with open('wikipedia-2023-10-01.jsonl', encoding='utf-8') as fopen:
    for line in tqdm(fopen):
        line = line.strip()
        if not line:
            # A blank line is not valid JSON and would make json.loads raise.
            continue
        record = json.loads(line)
        if len(record) > 1:
            texts.append(record)

438316it [00:01, 284358.71it/s]


In [8]:
def collator(batch, batch_position_ids):
    """
    Flatten one packed block into the three arrays stored per sample.

    `batch` is a list of token-id sequences and `batch_position_ids` the
    matching list of position-id sequences (looked up by the same index).

    Returns a dict of `uint32` arrays:
      - 'input_ids': every sequence in `batch` concatenated in order,
      - 'position_ids': the matching position sequences concatenated,
      - 'attention_mask': the length of each sequence in `batch`
        (one entry per sequence, not per token).
    """
    flat_tokens, flat_positions, segment_lengths = [], [], []
    for idx, segment in enumerate(batch):
        flat_tokens += segment
        flat_positions += batch_position_ids[idx]
        segment_lengths.append(len(segment))

    as_uint32 = lambda values: np.array(values).astype(np.uint32)
    return {
        'input_ids': as_uint32(flat_tokens),
        'position_ids': as_uint32(flat_positions),
        'attention_mask': as_uint32(segment_lengths),
    }

def slice_and_balance(nested_list, size):
    """
    Split a list of sequences at a total-element budget of `size`.

    Sequences are consumed in order. Whole sequences go to the first result
    while they fit; the sequence that crosses the budget is cut in two, and
    everything after the budget is reached goes to the second result.

    Returns `(first, balance)`: `first` holds at most `size` elements in
    total, `balance` holds the remainder.
    """
    head, overflow = [], []
    filled = 0

    for seq in nested_list:
        if filled >= size:
            # Budget already used up: everything else is carried over.
            overflow.append(seq)
            continue

        room = size - filled
        seq_len = len(seq)
        if seq_len > room:
            # Cut the sequence exactly where the budget runs out.
            head.append(seq[:room])
            overflow.append(seq[room:])
            filled = size
        else:
            head.append(seq)
            filled += seq_len

    return head, overflow

def loop(rows, block_size = 4096, folder = 'tokenized-4k-qwen'):
    """
    Tokenize one chunk of documents and pack the tokens into fixed-size
    samples written to `{folder}/tokenized-{index}`.

    Parameters
    ----------
    rows : tuple
        `(documents, index)` as produced by `chunks`; `index` names the
        output sub-directory.
    block_size : int
        Number of tokens in every written sample.
    folder : str
        Parent directory for the per-chunk output.

    Returns
    -------
    dict or None
        The final, padded sample if one was written, otherwise None.

    Notes
    -----
    Tokens are accumulated across documents and a sample is written every
    time `block_size` tokens are available. Leftover tokens at the end are
    completed to a full block by prepending the first `block_size - leftover`
    tokens of the last written block, so that data appears twice.
    """
    import shutil  # local import: shutil is not part of the notebook's import cell

    rows, index = rows
    out_root = f'{folder}/tokenized-{index}'
    # Start from a clean output directory. rmtree replaces a shell `rm -rf`,
    # which would mis-handle paths containing spaces or shell metacharacters.
    shutil.rmtree(out_root, ignore_errors=True)
    count = 0
    temp = []
    position_ids = []
    last_block, last_position_block = None, None
    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:
        for row in tqdm(rows):
            outputs = tokenizer(row, add_special_tokens = False)
            token_ids = outputs['input_ids']
            if len(token_ids) == 0:
                # An empty tokenization would only add a zero-length segment
                # to 'attention_mask'.
                continue
            temp.append(token_ids)
            position_ids.append(range(len(token_ids)))
            count += len(token_ids)
            while count >= block_size:
                block, temp = slice_and_balance(temp, block_size)
                block_position, position_ids = slice_and_balance(position_ids, block_size)
                count = count - block_size
                o = collator(block, block_position)
                last_block = block
                last_position_block = block_position
                out.write(o)

        if count == 0:
            # Nothing left over. Padding here would take the whole last block
            # and write it a second time.
            return None
        if last_block is None:
            # Fewer than block_size tokens in the whole chunk: there is no
            # earlier block to borrow padding from, so no full sample exists.
            return None

        # Complete the leftover tokens with the start of the last written block.
        missing = block_size - count
        block, _ = slice_and_balance(last_block, missing)
        block_position, _ = slice_and_balance(last_position_block, missing)

        block.extend(temp)
        block_position.extend(position_ids)

        o = collator(block, block_position)
        if len(o['input_ids']) == block_size:
            out.write(o)
            return o

In [9]:
# loop((texts, 0))

In [None]:
# The per-chunk return value of loop() is a dict or None, not a list, and is
# not needed here: the samples are already on disk. Passing returned=False
# skips the flattening step, which would raise on a None result.
multiprocessing(texts, loop, cores = 20, returned = False)

In [15]:
# Collect the per-chunk output directories, ordered by their numeric suffix
# (so that 'tokenized-10' sorts after 'tokenized-9').
folders = glob('tokenized-4k-qwen/tokenized-*')
folders.sort(key = lambda path: int(path.rsplit('-', 1)[-1]))
folders

['tokenized-4k-qwen/tokenized-0',
 'tokenized-4k-qwen/tokenized-1',
 'tokenized-4k-qwen/tokenized-2',
 'tokenized-4k-qwen/tokenized-3',
 'tokenized-4k-qwen/tokenized-4',
 'tokenized-4k-qwen/tokenized-5',
 'tokenized-4k-qwen/tokenized-6',
 'tokenized-4k-qwen/tokenized-7',
 'tokenized-4k-qwen/tokenized-8',
 'tokenized-4k-qwen/tokenized-9',
 'tokenized-4k-qwen/tokenized-10',
 'tokenized-4k-qwen/tokenized-11',
 'tokenized-4k-qwen/tokenized-12',
 'tokenized-4k-qwen/tokenized-13',
 'tokenized-4k-qwen/tokenized-14',
 'tokenized-4k-qwen/tokenized-15',
 'tokenized-4k-qwen/tokenized-16',
 'tokenized-4k-qwen/tokenized-17',
 'tokenized-4k-qwen/tokenized-18',
 'tokenized-4k-qwen/tokenized-19',
 'tokenized-4k-qwen/tokenized-20']

In [16]:
!rm -rf multipacking

In [17]:
# Concatenate every per-chunk dataset into one 'multipacking' dataset.
# Merging stays best-effort: a folder that fails is reported and skipped, but
# the failures are collected so an incomplete merge is visible at the end.
failed_folders = []
with MDSWriter(out='multipacking', columns=columns, compression=None, hashes=hashes) as out:
    for f in folders:
        try:
            dataset = LocalDataset(local=f)
            for i in tqdm(range(len(dataset))):
                out.write(dataset[i])
        except Exception as e:
            # Name the folder; the bare exception text alone does not say
            # which shard is missing from the merged output.
            failed_folders.append(f)
            print(f'{f}: {e!r}')

if failed_folders:
    print(f'WARNING: {len(failed_folders)} folder(s) were not fully merged: {failed_folders}')

100%|██████████| 3538/3538 [00:00<00:00, 10516.82it/s]
100%|██████████| 2459/2459 [00:00<00:00, 10172.88it/s]
100%|██████████| 943/943 [00:00<00:00, 5555.07it/s]
100%|██████████| 1168/1168 [00:00<00:00, 20330.62it/s]
100%|██████████| 323/323 [00:00<00:00, 2347.84it/s]
100%|██████████| 604/604 [00:00<00:00, 19859.67it/s]
100%|██████████| 1095/1095 [00:00<00:00, 20381.39it/s]
100%|██████████| 673/673 [00:00<00:00, 4219.48it/s]
100%|██████████| 291/291 [00:00<00:00, 20265.03it/s]
100%|██████████| 270/270 [00:00<00:00, 20451.89it/s]
100%|██████████| 235/235 [00:00<00:00, 19258.35it/s]
100%|██████████| 357/357 [00:00<00:00, 19947.86it/s]
100%|██████████| 524/524 [00:00<00:00, 3537.77it/s]
100%|██████████| 269/269 [00:00<00:00, 20316.70it/s]
100%|██████████| 976/976 [00:00<00:00, 20661.49it/s]
100%|██████████| 1360/1360 [00:00<00:00, 7347.94it/s]
100%|██████████| 941/941 [00:00<00:00, 20656.62it/s]
100%|██████████| 1832/1832 [00:00<00:00, 8782.87it/s]
100%|██████████| 1869/1869 [00:00<00:00,

In [18]:
# Re-open the merged output and show how many samples it contains.
dataset = LocalDataset(local='multipacking')
len(dataset)

22206

In [21]:
# Spot-check the second-to-last sample of the merged dataset.
dataset[-2]

{'attention_mask': array([ 325,  302,   45, 1377,  562,   74,  516,  538,   87,  270],
       dtype=uint32),
 'input_ids': array([57635,  5103,   300, ...,    17,    11, 10371], dtype=uint32),
 'position_ids': array([1483, 1484, 1485, ...,  267,  268,  269], dtype=uint32)}

In [None]:
!hf upload Scicom-intl/mosaic-ms-wikipedia-2023-10-01 multipacking --repo-type=dataset