In [None]:
# Download the 10k-sample Malaysian SFT JSONL file that the loading cell below reads.
!wget https://huggingface.co/datasets/mesolitica/Malaysian-SFT/resolve/main/combine/combined-malaysian-sft-10k-sample.jsonl

In [None]:
from transformers import AutoTokenizer

# Tokenizer used by loop() both to render rows with its chat template and to tokenize the result.
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-30B-A3B-Instruct-2507')

In [None]:
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
import numpy as np
from tqdm import tqdm
from glob import glob
import os
import json

class UInt32(Encoding):
    """MDS column codec that stores an array as its raw ``uint32`` bytes."""

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` as a flat run of ``uint32`` values.

        The input is coerced to ``uint32`` first: ``decode`` always reads the
        buffer as ``uint32``, so writing the bytes of an array with any other
        dtype (e.g. the default ``int64``) would be decoded as garbage.
        An array that is already ``uint32`` is serialized unchanged.
        """
        return np.asarray(obj, dtype=np.uint32).tobytes()

    def decode(self, data: bytes):
        """Reinterpret ``data`` as a 1-D ``uint32`` array (no copy is made)."""
        return np.frombuffer(data, np.uint32)

# Make the codec above available to MDS under the name 'uint32'.
# NOTE(review): if the installed `streaming` version already ships an encoding
# called 'uint32', this assignment replaces it — confirm that is intended.
_encodings['uint32'] = UInt32

# Schema of every packed sample; each field is serialized with the 'uint32' codec.
columns = {
    'input_ids': 'uint32',
    'position_ids': 'uint32',
    'attention_mask': 'uint32',
}
# Hash algorithm names passed to MDSWriter through its `hashes` argument.
hashes = 'sha1', 'xxh64'

In [None]:
# Load the whole JSONL sample into memory, one parsed object per non-empty line.
combine = []
# Read as UTF-8 explicitly so the result does not depend on the platform's
# default text encoding.
with open('combined-malaysian-sft-10k-sample.jsonl', encoding='utf-8') as fopen:
    for line in fopen:
        # A blank line (for example a trailing newline) is not a JSON document;
        # skip it instead of letting json.loads raise.
        if not line.strip():
            continue
        combine.append(json.loads(line))

len(combine)

In [None]:
import gc

def collator(batch, batch_position_ids):
    """Flatten the segments of one packed block into uint32 arrays.

    Args:
        batch: list of token-id sequences making up the block.
        batch_position_ids: list of position-id sequences; entry ``k`` belongs
            to ``batch[k]``.

    Returns:
        Dict with ``input_ids`` (all tokens concatenated), ``position_ids``
        (all position ids concatenated) and ``attention_mask`` (the length of
        each segment, in order), every value a ``uint32`` array.
    """
    def as_uint32(values):
        return np.array(values).astype(np.uint32)

    flat_tokens, flat_positions, segment_lengths = [], [], []
    for idx, segment in enumerate(batch):
        flat_tokens += segment
        flat_positions += batch_position_ids[idx]
        segment_lengths.append(len(segment))

    return {
        'input_ids': as_uint32(flat_tokens),
        'position_ids': as_uint32(flat_positions),
        'attention_mask': as_uint32(segment_lengths),
    }

def slice_and_balance(nested_list, size):
    """Split a list of sequences at a cumulative length of ``size``.

    Sequences are consumed in order until ``size`` elements have been taken.
    The sequence that crosses the limit is cut in two; everything after the
    limit goes to the remainder.

    Returns:
        ``(first, balance)``: ``first`` holds at most ``size`` elements in
        total, ``balance`` holds whatever did not fit.
    """
    head, tail = [], []
    filled = 0

    for seq in nested_list:
        # Block already full: everything else is carried over untouched.
        if filled >= size:
            tail.append(seq)
            continue

        room = size - filled
        if len(seq) > room:
            # Crosses the boundary: keep the part that fits, carry the rest.
            head.append(seq[:room])
            tail.append(seq[room:])
            filled = size
        else:
            head.append(seq)
            filled += len(seq)

    return head, tail

In [None]:

def loop(files, block_size = 7168):
    """Tokenize one chunk of rows and write it as fixed-size packed records.

    Every row is rendered with ``tokenizer.apply_chat_template`` and tokenized.
    The token streams are concatenated and cut into records of exactly
    ``block_size`` tokens, written to ``tokenized-8k/tokenized-{index}``.
    A sequence that crosses a record boundary is split, and the position ids
    of its second part continue where the first part stopped.

    Tokens left over at the end are shorter than one record. They are
    completed with the leading tokens of the last full record so that no row
    is lost. If no full record was ever produced there is nothing to borrow
    from and the leftover is dropped.

    Args:
        files: ``(rows, index)`` pair as yielded by ``chunks``.
        block_size: number of tokens in every written record.

    Returns:
        The back-filled final record when one is written, otherwise ``None``.
    """
    import shutil  # local import: keeps this cell independent of the import cell

    rows, index = files
    out_root = f'tokenized-8k/tokenized-{index}'
    # Clear output of a previous run without passing the path through a shell.
    shutil.rmtree(out_root, ignore_errors=True)
    count = 0  # number of tokens currently buffered in `temp`
    temp = []
    position_ids = []
    last_block, last_position_block = None, None
    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:
        for row in tqdm(rows):
            prompt = tokenizer.apply_chat_template(row, tokenize=False)
            outputs = tokenizer(prompt, add_special_tokens = False)
            temp.append(outputs['input_ids'])
            position_ids.append(range(len(outputs['input_ids'])))
            count += len(outputs['input_ids'])
            # One long row can fill several records, hence `while`, not `if`.
            while count >= block_size:
                block, temp = slice_and_balance(temp, block_size)
                block_position, position_ids = slice_and_balance(position_ids, block_size)
                count = count - block_size
                o = collator(block, block_position)
                last_block = block
                last_position_block = block_position
                out.write(o)

        # Nothing left over: back-filling would borrow a full block and write
        # the last record a second time.
        if count == 0:
            return None
        # No full record was produced, so there is nothing to borrow from.
        if last_block is None:
            return None

        # Borrow just enough leading tokens of the last record to complete the
        # leftover to a full block.
        block, _ = slice_and_balance(last_block, block_size - count)
        block_position, _ = slice_and_balance(last_position_block, block_size - count)

        block.extend(temp)
        block_position.extend(position_ids)

        o = collator(block, block_position)
        if len(o['input_ids']) == block_size:
            out.write(o)
            return o
        return None

In [None]:
from multiprocess import Pool
import itertools


def chunks(l, n):
    """Lazily yield ``(piece, piece_number)`` for consecutive slices of ``l``.

    Each piece has ``n`` items except possibly the last one; numbering starts
    at 0.
    """
    for piece_number, start in enumerate(range(0, len(l), n)):
        stop = start + n
        yield l[start:stop], piece_number


def multiprocessing(strings, function, cores=6, returned=True):
    """Apply ``function`` to slices of ``strings`` in a pool of worker processes.

    ``strings`` is cut with ``chunks`` into pieces of ``len(strings) // cores``
    items (at least one item per piece). ``function`` is called once per piece
    with the ``(piece, piece_number)`` tuple that ``chunks`` yields.

    Args:
        strings: sliceable sequence of work items.
        function: callable taking one ``(piece, piece_number)`` tuple.
        cores: number of worker processes.
        returned: when true, each worker result must be iterable and the
            results are concatenated into one list.

    Returns:
        The flattened list of results when ``returned`` is true, else ``None``.
    """
    # With fewer items than workers the integer division yields 0, and
    # chunks() cannot step through the input by 0.
    chunk_size = max(1, len(strings) // cores)
    df_split = chunks(strings, chunk_size)
    pool = Pool(cores)
    try:
        pooled = pool.map(function, df_split)
    finally:
        # Release the workers even when a task raised.
        pool.close()
        pool.join()

    if returned:
        return list(itertools.chain.from_iterable(pooled))

In [None]:
from multiprocess import Pool

# Keep the generator under its own name: assigning it to `chunks` would replace
# the chunks() function, so re-running this cell or calling multiprocessing()
# afterwards would fail with "'generator' object is not callable".
row_chunks = chunks(combine, 50000)
pool = Pool(10)
try:
    # Each worker receives one (rows, index) tuple and writes its own shard.
    pooled = pool.map(loop, row_chunks)
finally:
    # Release the workers even when a task raised.
    pool.close()
    pool.join()

In [None]:
# Per-worker shard directories, ordered by the number after the last '-'.
folders = glob('tokenized-8k/tokenized-*')
folders.sort(key=lambda path: int(path.rsplit('-', 1)[-1]))
folders

In [None]:
# Concatenate every per-worker shard into the single dataset 'packing-8k'.
with MDSWriter(out='packing-8k', columns=columns, compression=None, hashes=hashes) as out:
    for f in folders:
        try:
            dataset = LocalDataset(local=f)
            for i in tqdm(range(len(dataset))):
                out.write(dataset[i])
        except Exception as e:
            # Best-effort merge: a shard that cannot be read is skipped, but
            # report which one and what failed so missing data can be traced.
            print(f'skipping {f}: {e!r}')

In [None]:
dataset = LocalDataset('packing-8k')
# Every record written by loop() holds exactly one block of tokens, so the
# estimate must use the block size the shards were built with (loop()'s
# default block_size), not an unrelated constant.
PACKED_BLOCK_SIZE = 7168
# Total number of packed tokens, in billions.
(len(dataset) * PACKED_BLOCK_SIZE) / 1e9

In [None]:
# Spot-check: decode the third-from-last packed record back to text.
tokenizer.decode(dataset[-3]['input_ids'])


In [None]:
# Spot-check: decode the second-from-last packed record back to text.
tokenizer.decode(dataset[-2]['input_ids'])


In [None]:
# Delete the merged 'packing-8k' dataset directory written by the merge cell.
!rm -rf packing-8k

Unit test

In [None]:
from transformers import AutoTokenizer
import numpy as np

# Simulate the environment from latest.ipynb
# Use a dedicated name: rebinding `tokenizer` here would replace the
# instruct-model tokenizer that loop() reads from the notebook namespace.
unit_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B-Base")

texts = [
    "Machine learning enables computers to learn from data.",
    "Natural language processing helps computers understand human language.",
    "Large language models are trained on billions of tokens.",
]

# NOTE(review): pack_exact only emits full blocks; if the three sentences
# tokenize to fewer than block_size tokens in total, no record is produced —
# check the lengths printed below.
block_size = 40  # small to visualize clearly

# 1️⃣ Tokenize each sample
tokenized = [unit_tokenizer(t, add_special_tokens=False)["input_ids"] for t in texts]
print("---- Original Samples ----")
for i, t in enumerate(tokenized):
    print(f"Sample {i}: len={len(t)} tokens")

# 2️⃣ Exact packing (same logic as latest.ipynb)
def slice_and_balance(nested_list, size):
    """Take the first ``size`` elements out of a list of sequences.

    Returns ``(first, balance)``: ``first`` contains the sequences (or the
    leading part of the sequence that crosses the limit) adding up to at most
    ``size`` elements; ``balance`` contains everything that did not fit.
    """
    taken, leftover = [], []
    used = 0
    for chunk in nested_list:
        if used >= size:
            # Limit reached earlier: carry the sequence over unchanged.
            leftover.append(chunk)
        elif len(chunk) <= size - used:
            taken.append(chunk)
            used += len(chunk)
        else:
            # Sequence crosses the limit: cut it at the remaining capacity.
            cut = size - used
            taken.append(chunk[:cut])
            leftover.append(chunk[cut:])
            used = size
    return taken, leftover

def collator(batch, batch_position_ids=None):
    """Flatten the segments of one packed block into uint32 arrays.

    This definition replaces the two-argument ``collator`` defined earlier in
    the notebook once the cell has run. The optional second argument keeps
    ``loop()``'s ``collator(block, block_position)`` calls working afterwards.

    Args:
        batch: list of token-id sequences making up the block.
        batch_position_ids: optional list of position-id sequences, entry
            ``k`` belonging to ``batch[k]``. When omitted, every segment is
            numbered ``0 .. len(segment) - 1``.

    Returns:
        Dict with ``input_ids`` (tokens concatenated), ``position_ids`` and
        ``attention_mask`` (length of each segment), all ``uint32`` arrays.
    """
    input_ids, position_ids, masks = [], [], []
    for i, seg in enumerate(batch):
        L = len(seg)
        input_ids.extend(seg)
        if batch_position_ids is None:
            position_ids.extend(range(L))
        else:
            position_ids.extend(batch_position_ids[i])
        masks.append(L)
    return {
        "input_ids": np.array(input_ids, np.uint32),
        "position_ids": np.array(position_ids, np.uint32),
        "attention_mask": np.array(masks, np.uint32),
    }

def pack_exact(samples, block_size=40):
    """Pack token sequences into records of exactly ``block_size`` tokens.

    Sequences are concatenated in order and cut every ``block_size`` tokens;
    a sequence that crosses a boundary is split between two records.

    Tokens remaining after the last full block are always fewer than
    ``block_size`` and are not emitted. Unlike ``loop()``, this helper does
    not back-fill the tail from the previous block.

    Args:
        samples: iterable of token-id sequences.
        block_size: number of tokens per record; must be positive.

    Returns:
        List of records as produced by ``collator``.

    Raises:
        ValueError: if ``block_size`` is not positive (the packing loop would
            otherwise never terminate).
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    temp = []
    count = 0  # number of tokens currently buffered in `temp`
    records = []
    for s in samples:
        temp.append(s)
        count += len(s)
        # One long sample can fill several blocks, hence `while`, not `if`.
        while count >= block_size:
            block, temp = slice_and_balance(temp, block_size)
            records.append(collator(block))
            count -= block_size
    # Here count < block_size, so the buffered tail can never form a full
    # record; it is intentionally left out.
    return records

# Pack the tokenized samples into fixed-size records.
records = pack_exact(tokenized, block_size)

# 3️⃣ Inspect the result
print("\n---- Packed Output ----")
for i, rec in enumerate(records):
    print(f"Record {i}: len(input_ids)={len(rec['input_ids'])}")
    # attention_mask stores one length per segment, not a per-token mask.
    print(f"  attention_mask (segment lengths) = {rec['attention_mask'].tolist()}")
    print(f"  position_ids (first 20) = {rec['position_ids'][:20].tolist()}")

# 4️⃣ Show overall token count conservation
# pack_exact only emits full blocks, so the written total can be smaller than
# the original total; the two numbers are printed for comparison, not asserted.
total_input_tokens = sum(len(s) for s in tokenized)
total_packed_tokens = sum(len(r["input_ids"]) for r in records)
print(f"\nOriginal total tokens = {total_input_tokens}")
print(f"Total written tokens  = {total_packed_tokens}")
