In [1]:
import os

# Set HF_HOME before huggingface_hub is imported in the next cell.
os.environ.update(HF_HOME='/home/ubuntu/scicom')

In [2]:
from huggingface_hub import snapshot_download

# Fetch the whole dataset repository into the current working directory.
snapshot_download(
    repo_id='Scicom-intl/sort-multilingual-tts',
    repo_type='dataset',
    local_dir='./',
)

  from .autonotebook import tqdm as notebook_tqdm
Fetching 441 files: 100%|██████████| 441/441 [00:00<00:00, 2839.54it/s]


'/home/ubuntu/scicom'

In [3]:
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
from glob import glob
import numpy as np
import json
import pandas as pd
from tqdm import tqdm
from multiprocess import Pool
import itertools

def chunks(l, n):
    """Lazily split ``l`` into consecutive slices of at most ``n`` items.

    Yields ``(slice, chunk_number)`` tuples, where ``chunk_number`` counts
    the slices from 0 in the order they are produced.
    """
    for chunk_number, start in enumerate(range(0, len(l), n)):
        stop = start + n
        yield l[start:stop], chunk_number

def multiprocessing(strings, function, cores=6, returned=True):
    """Apply ``function`` to slices of ``strings`` in a pool of worker processes.

    ``strings`` is cut into consecutive slices of ``len(strings) // cores``
    items (at least 1) by ``chunks``; each worker call receives one
    ``(slice, chunk_number)`` tuple.

    Args:
        strings: Sized, sliceable collection of work items.
        function: Callable taking a single ``(slice, chunk_number)`` tuple.
        cores: Number of worker processes in the pool.
        returned: When true, every worker must return an iterable and the
            results are flattened into one list. When false, worker return
            values are discarded and ``None`` is returned.

    Returns:
        A flat list of all worker results if ``returned`` is true, else ``None``.
    """
    # Nothing to do: avoid spawning a pool (and a zero chunk size) for no work.
    if len(strings) == 0:
        return [] if returned else None

    # With fewer items than cores the integer division yields 0, which would
    # make range() inside chunks() raise ValueError; fall back to 1 per chunk.
    chunk_size = max(1, len(strings) // cores)
    df_split = chunks(strings, chunk_size)

    pool = Pool(cores)
    try:
        pooled = pool.map(function, df_split)
    finally:
        # Always release the worker processes, even if a worker raised.
        pool.close()
        pool.join()

    if returned:
        return list(itertools.chain(*pooled))



In [4]:
class UInt32(Encoding):
    """MDS codec that stores an array as raw uint32 bytes.

    ``decode`` always interprets the payload as uint32, so ``encode`` coerces
    its input to uint32 first. Without the coercion an array of another dtype
    (e.g. int64) would be written with a different item size and read back as
    garbage with no error.
    """

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` (array-like of integers) to uint32 bytes."""
        # No copy is made when obj is already a uint32 ndarray.
        return np.asarray(obj, dtype=np.uint32).tobytes()

    def decode(self, data: bytes):
        """Return ``data`` viewed as a uint32 array."""
        return np.frombuffer(data, np.uint32)

# Expose the custom codec to MDS under the name 'uint32'.
_encodings['uint32'] = UInt32

# Column schema for the writers: three uint32 array fields, two string fields.
columns = {
    **dict.fromkeys(['input_ids', 'position_ids', 'attention_mask'], 'uint32'),
    **dict.fromkeys(['audio', 'text'], 'str'),
}
# Hash algorithms handed to each MDSWriter.
hashes = ('sha1', 'xxh64')

In [7]:
# df = pd.read_parquet('sort-merge.parquet')
# df

In [6]:
# Load the packing plan; each entry is later used to select rows of
# sort-merge.parquet with df.iloc.
with open('bin-packing.json') as fopen:
    bin_packing = json.loads(fopen.read())

len(bin_packing)

614552

In [9]:
# Recreate the multipacking/ directory that loop() writes its
# tokenized-{index} shard folders into.
!rm -rf multipacking
!mkdir multipacking

In [10]:
def collator(batch, batch_position_ids):
    """Flatten several token sequences into one packed sample.

    Args:
        batch: Indexable collection of token-id sequences.
        batch_position_ids: Position-id iterables, indexed in parallel with
            ``batch``.

    Returns:
        A dict with uint32 arrays ``input_ids`` and ``position_ids`` (all
        sequences concatenated in order) and ``attention_mask``, which holds
        the length of each sequence rather than a 0/1 mask. ``audio`` and
        ``text`` are empty strings.
    """
    n_sequences = len(batch)
    flat_tokens = [token for sequence in batch for token in sequence]
    flat_positions = [
        position
        for row in range(n_sequences)
        for position in batch_position_ids[row]
    ]
    lengths = [len(sequence) for sequence in batch]

    packed = {
        'input_ids': np.array(flat_tokens).astype(np.uint32),
        'position_ids': np.array(flat_positions).astype(np.uint32),
        'attention_mask': np.array(lengths).astype(np.uint32),
    }
    packed['audio'] = ''
    packed['text'] = ''
    return packed

def loop(indices):
    """Write one MDS shard directory of packed samples for a slice of the plan.

    Args:
        indices: A ``(packings, index)`` tuple as produced by ``chunks``.
            Each element of ``packings`` selects rows of
            ``sort-merge.parquet`` by position; ``index`` names the output
            directory ``multipacking/tokenized-{index}``.

    Each selected row supplies a source dataset path (column ``f``) and a
    sample index inside it (column ``i``). The ``input_ids`` of those samples
    are combined by ``collator`` into one record per packing.
    """
    # Local import keeps the function self-contained for worker processes.
    import shutil

    indices, index = indices
    out_root = f'multipacking/tokenized-{index}'
    # Remove output from a previous run without spawning a shell.
    shutil.rmtree(out_root, ignore_errors=True)

    df = pd.read_parquet('sort-merge.parquet')
    datasets = {f: LocalDataset(local=f) for f in glob('tokenized-4k-qwen3/*')}

    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:
        for packing in tqdm(indices):
            rows = df.iloc[packing]
            temp = []
            position_ids = []
            # Read the two columns once per packing instead of materialising
            # a row Series twice per sample with rows.iloc[k].
            for f, i in zip(rows['f'].tolist(), rows['i'].tolist()):
                input_ids = datasets[f][i]['input_ids']
                temp.append(input_ids)
                # Positions restart at 0 for every sequence in the pack.
                position_ids.append(range(len(input_ids)))

            out.write(collator(temp, position_ids))

In [13]:
# loop((bin_packing[:100], 0))

In [None]:
# Build the shards in parallel; loop() writes to disk, so its return values
# are not collected.
multiprocessing(strings=bin_packing, function=loop, cores=30, returned=False)

100%|██████████| 20485/20485 [00:53<00:00, 380.33it/s]
100%|██████████| 2/2 [00:00<00:00, 41.28it/s]3.61it/s]
100%|██████████| 20485/20485 [01:03<00:00, 324.95it/s]
100%|██████████| 20485/20485 [01:02<00:00, 327.16it/s]
100%|██████████| 20485/20485 [01:06<00:00, 308.59it/s]
100%|██████████| 20485/20485 [01:09<00:00, 295.83it/s]
100%|██████████| 20485/20485 [01:12<00:00, 283.08it/s]
100%|██████████| 20485/20485 [01:14<00:00, 275.91it/s]
100%|██████████| 20485/20485 [01:11<00:00, 288.40it/s]
100%|██████████| 20485/20485 [01:17<00:00, 263.90it/s]
100%|██████████| 20485/20485 [01:18<00:00, 259.58it/s]
100%|██████████| 20485/20485 [01:20<00:00, 254.34it/s]
100%|██████████| 20485/20485 [01:36<00:00, 212.32it/s]
100%|██████████| 20485/20485 [01:25<00:00, 239.94it/s]
100%|██████████| 20485/20485 [01:22<00:00, 249.18it/s]
100%|██████████| 20485/20485 [01:26<00:00, 236.32it/s]
100%|██████████| 20485/20485 [01:39<00:00, 206.53it/s]
100%|██████████| 20485/20485 [01:27<00:00, 233.93it/s]
100%|█████

In [None]:
# List the shard directories ordered by their numeric suffix
# (tokenized-2 before tokenized-10), not lexicographically.
folders = glob('multipacking/tokenized-*')
folders.sort(key=lambda path: int(path.rsplit('-', 1)[-1]))
folders

In [None]:
# Remove any previous merged output; the next cell rewrites multipacking-final.
!rm -rf multipacking-final

In [None]:
# Merge every shard directory in `folders` into the single dataset
# multipacking-final, in folder order.
failed_folders = []
with MDSWriter(out='multipacking-final', columns=columns, compression=None, hashes=hashes) as out:
    for f in folders:
        try:
            dataset = LocalDataset(local=f)
            for i in tqdm(range(len(dataset))):
                out.write(dataset[i])
        except Exception as e:
            # Best-effort merge: keep going, but record which shard failed so
            # an incomplete result cannot go unnoticed. Samples written from
            # this shard before the failure remain in the output.
            failed_folders.append(f)
            print(f'failed to merge {f}: {e!r}')

if failed_folders:
    print(
        f'WARNING: {len(failed_folders)} of {len(folders)} shard folders '
        f'were not fully merged: {failed_folders}'
    )