In [17]:
from huggingface_hub import HfApi
from datasets import load_dataset

In [2]:
api = HfApi()
# Every file path in the allenai/c4 dataset repository, as a flat list of strings.
repo_files = api.list_repo_files(
    repo_id="allenai/c4",
    repo_type="dataset",
)

In [4]:
# Preview only: the full listing is very long and bloats the saved notebook.
repo_files[:20]

['.gitattributes',
 'README.md',
 'en.noblocklist/c4-train.00000-of-01024.json.gz',
 'en.noblocklist/c4-train.00001-of-01024.json.gz',
 'en.noblocklist/c4-train.00002-of-01024.json.gz',
 'en.noblocklist/c4-train.00003-of-01024.json.gz',
 'en.noblocklist/c4-train.00004-of-01024.json.gz',
 'en.noblocklist/c4-train.00005-of-01024.json.gz',
 'en.noblocklist/c4-train.00006-of-01024.json.gz',
 'en.noblocklist/c4-train.00007-of-01024.json.gz',
 'en.noblocklist/c4-train.00008-of-01024.json.gz',
 'en.noblocklist/c4-train.00009-of-01024.json.gz',
 'en.noblocklist/c4-train.00010-of-01024.json.gz',
 'en.noblocklist/c4-train.00011-of-01024.json.gz',
 'en.noblocklist/c4-train.00012-of-01024.json.gz',
 'en.noblocklist/c4-train.00013-of-01024.json.gz',
 'en.noblocklist/c4-train.00014-of-01024.json.gz',
 'en.noblocklist/c4-train.00015-of-01024.json.gz',
 'en.noblocklist/c4-train.00016-of-01024.json.gz',
 'en.noblocklist/c4-train.00017-of-01024.json.gz',
 'en.noblocklist/c4-train.00018-of-01024.json.gz'

In [10]:
# Walk everything under multilingual/ (recursing into subdirectories). The
# returned entries carry per-file size metadata, which the batching step uses.
tree = api.list_repo_tree(
    repo_id="allenai/c4",
    repo_type="dataset",
    path_in_repo="multilingual",
    recursive=True,
)

In [11]:
# Materialize the listing, keeping only file entries. A recursive tree listing
# can also yield folder entries, which have no `size` and would break the
# size-based batching below.
tree_files = [entry for entry in tree if hasattr(entry, "size")]

In [13]:
# How many entries were collected from the multilingual/ listing.
n_tree_entries = len(tree_files)
n_tree_entries

59410

In [15]:
# Preview a handful of entries rather than rendering every RepoFile repr.
tree_files[:5]

[RepoFile(path='multilingual/c4-af-validation.tfrecord-00000-of-00001.json.gz', size=2246813, blob_id='9f9fa70f024b504ed8e1d4bea368a9ade03e91bb', lfs=BlobLfsInfo(size=2246813, sha256='341a806fcb73f4c59217f0f14ea24cac7a5611af441b15e76d281e3c53d8e695', pointer_size=132), last_commit=None, security=None),
 RepoFile(path='multilingual/c4-af.tfrecord-00000-of-00064.json.gz', size=36767863, blob_id='e44630649f50a217c91719b93f532340439b85ab', lfs=BlobLfsInfo(size=36767863, sha256='a9222d9694768e4e3cc2e9fed8ce0367e380db46aa029374ea52bb7e6108b47c', pointer_size=133), last_commit=None, security=None),
 RepoFile(path='multilingual/c4-af.tfrecord-00001-of-00064.json.gz', size=37092484, blob_id='864fd9c34c5728cd2907fc374df5a7d367fce974', lfs=BlobLfsInfo(size=37092484, sha256='3dcb8e41928ec01698a149c9a9e15324558b6a49a520b0f7f88deb785f28d6d9', pointer_size=133), last_commit=None, security=None),
 RepoFile(path='multilingual/c4-af.tfrecord-00002-of-00064.json.gz', size=36028152, blob_id='c6009b9be094a

In [16]:
def assign_batches(files, target_batch_size_bytes=1_000_000_000):  # 1GB
    """Group files into batches of roughly ``target_batch_size_bytes`` each.

    Files are sorted by size (largest first) and packed greedily: a new batch
    is started whenever adding the next file would push the current batch
    past the target. A single file larger than the target ends up alone in
    its own batch. This next-fit strategy tends to group files of similar
    size together; it does not guarantee equally sized batches.

    Entries whose ``size`` is missing or ``None`` are skipped, because they
    cannot be weighed. Folder entries from a recursive tree listing are
    presumably such entries; confirm this for other sources.

    Args:
        files: Iterable (a list or a one-shot iterator) of objects exposing
            ``path`` and ``size`` (bytes), e.g. entries from
            ``HfApi.list_repo_tree``.
        target_batch_size_bytes: Soft upper bound on a batch's total size.

    Returns:
        ``(result, batches)``:
        - ``result`` is a list of ``(path, batch_index, size)`` tuples, in the
          input order.
        - ``batches`` is a list of batches, each a list of the file objects,
          ordered largest-first.

    Side effects:
        Prints the total size, the batch count, and the average batch size in GB.
    """
    # Materialize once: the input is traversed several times below, and a
    # generator would otherwise be exhausted after the first pass.
    sized_files = [f for f in files if getattr(f, "size", None) is not None]

    # Largest first: big files seed batches, small files fill the tail.
    # sorted() is stable, so equal-sized files keep their input order.
    sorted_files = sorted(sized_files, key=lambda f: f.size, reverse=True)

    batches = []
    current_batch = []
    current_batch_size = 0
    for file in sorted_files:
        # Close the current batch if this file would overflow it. The
        # non-empty check guarantees an oversized file still gets a batch.
        if current_batch and current_batch_size + file.size > target_batch_size_bytes:
            batches.append(current_batch)
            current_batch = []
            current_batch_size = 0
        current_batch.append(file)
        current_batch_size += file.size

    # Flush the final, partially filled batch.
    if current_batch:
        batches.append(current_batch)

    file_to_batch = {
        file.path: batch_idx
        for batch_idx, batch in enumerate(batches)
        for file in batch
    }

    # Report assignments in the caller's original order.
    result = [(file.path, file_to_batch[file.path], file.size) for file in sized_files]

    total_size_gb = sum(file.size for file in sized_files) / 1_000_000_000
    print(f"Total size: {total_size_gb:.2f} GB")
    print(f"Number of batches: {len(batches)}")
    # Avoid ZeroDivisionError when there was nothing to batch.
    average_batch_gb = total_size_gb / len(batches) if batches else 0.0
    print(f"Average batch size: {average_batch_gb:.2f} GB")

    return result, batches

batch_assignments, batches = assign_batches(tree_files, target_batch_size_bytes=1_000_000_000)

# Preview: (path, batch index, size in bytes) for the first ten files
batch_assignments[0:10]

Total size: 10624.47 GB
Number of batches: 13506
Average batch size: 0.79 GB


[('multilingual/c4-af-validation.tfrecord-00000-of-00001.json.gz',
  13505,
  2246813),
 ('multilingual/c4-af.tfrecord-00000-of-00064.json.gz', 13223, 36767863),
 ('multilingual/c4-af.tfrecord-00001-of-00064.json.gz', 13203, 37092484),
 ('multilingual/c4-af.tfrecord-00002-of-00064.json.gz', 13262, 36028152),
 ('multilingual/c4-af.tfrecord-00003-of-00064.json.gz', 13203, 37104125),
 ('multilingual/c4-af.tfrecord-00004-of-00064.json.gz', 13201, 37131148),
 ('multilingual/c4-af.tfrecord-00005-of-00064.json.gz', 13214, 36913294),
 ('multilingual/c4-af.tfrecord-00006-of-00064.json.gz', 13261, 36066423),
 ('multilingual/c4-af.tfrecord-00007-of-00064.json.gz', 13257, 36172293),
 ('multilingual/c4-af.tfrecord-00008-of-00064.json.gz', 13184, 37349683)]

In [22]:
# Paths of every file assigned to batch 0, used below as a small trial load.
test = [path for path, batch_idx, _size in batch_assignments if batch_idx == 0]

In [33]:
# Load only the batch-0 files from the Hub, caching downloads under ./c4.
dataset = load_dataset(
    path="allenai/c4",
    data_files=test,
    num_proc=4,
    cache_dir="./c4",
)

In [51]:
import tldextract


def get_tld(batch):
    """Derive a ``domain`` column from the ``url`` column of a batch.

    Intended for ``Dataset.map(..., batched=True)``: ``batch`` maps column
    names to lists, and the returned dict holds a ``domain`` list aligned
    element-for-element with ``batch["url"]``.

    Each value is ``"<domain>.<suffix>"`` as split by ``tldextract`` (e.g.
    ``"https://news.bbc.co.uk/x"`` -> ``"bbc.co.uk"``). When tldextract finds
    no public suffix, e.g. for IP addresses or bare hostnames, only the domain
    part is returned instead of a value with a dangling trailing dot.
    """
    domains = []
    for url in batch["url"]:
        # Parse each URL a single time; extraction is the expensive step.
        extracted = tldextract.extract(url)
        if extracted.suffix:
            domains.append(f"{extracted.domain}.{extracted.suffix}")
        else:
            domains.append(extracted.domain)
    return {"domain": domains}

# Keep only the URL column, then derive each row's domain in parallel batches.
dataset = (
    dataset
    .select_columns(["url"])
    .map(get_tld, batched=True, num_proc=4)
)

Map (num_proc=4): 100%|██████████| 546712/546712 [00:00<00:00, 632285.50 examples/s]


In [54]:
# Publish the url/domain table to the Hub under chunks/, in shards of at most 1 GB.
dataset.push_to_hub(
    repo_id="nhagar/c4_test",
    max_shard_size="1GB",
    data_dir="chunks",
)

Creating parquet from Arrow format: 100%|██████████| 547/547 [00:00<00:00, 2950.39ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:05<00:00,  5.72s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/nhagar/c4_test/commit/47c362103cd24223c7fc4df6a4b8b5b445f00d45', commit_message='Upload dataset', commit_description='', oid='47c362103cd24223c7fc4df6a4b8b5b445f00d45', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/nhagar/c4_test', endpoint='https://huggingface.co', repo_type='dataset', repo_id='nhagar/c4_test'), pr_revision=None, pr_num=None)

In [55]:
# Sanity check of the final DatasetDict after column selection and domain mapping.
dataset

DatasetDict({
    train: Dataset({
        features: ['url', 'domain'],
        num_rows: 546712
    })
})