In [1]:
import os
import getpass

from datasets import load_dataset, Dataset, DatasetDict
from huggingface_hub import login
from transformers import AutoTokenizer
from dotenv import load_dotenv
import matplotlib.pyplot as plt

from tqdm import tqdm

In [2]:

# Load variables from the local .envrc file (if any) into the process environment.
load_dotenv('.envrc')

# Prefer HF_TOKEN from the environment. A blank value (e.g. `HF_TOKEN=` left in
# .envrc) is treated the same as a missing one so that we prompt interactively
# instead of handing an empty token to `login`.
hf_token = os.environ.get('HF_TOKEN', '').strip()
if not hf_token:
    hf_token = getpass.getpass('Huggingface token: ')
login(token=hf_token)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [3]:
# Chunk length encoded in the input dataset's name; presumably the number of
# token ids per row of that dataset (verify against the dataset card).
CHUNK_SIZE = 512
# Sub-chunk lengths to generate. Each must be positive and divide CHUNK_SIZE
# evenly so that a chunk splits into a whole number of sub-chunks.
TARGET_SIZES = [128, 256]
for target_size in TARGET_SIZES:
    assert target_size > 0 and CHUNK_SIZE % target_size == 0, (
        f'target size {target_size} must be a positive divisor of CHUNK_SIZE={CHUNK_SIZE}'
    )

# Build the repo ids from a shared prefix rather than str.replace() on the full
# name: replace() would rewrite *every* occurrence of 'chunked' or of the chunk
# size anywhere in the id, not just the trailing suffix.
DATASET_PREFIX = 'MikiV/SimpleStories-SimpleStories'
INPUT_DATASET = f'{DATASET_PREFIX}-chunked-{CHUNK_SIZE}'
# One output repo per target size, named '<prefix>-subchunked-<size>x<count>'.
OUTPUT_DATASET_NAMES = [
    f'{DATASET_PREFIX}-subchunked-{target_size}x{CHUNK_SIZE // target_size}'
    for target_size in TARGET_SIZES
]

In [4]:
def split_into_subchunks(example, chunk_size, subchunk_size):
    """Cut one example's ``input_ids`` into consecutive fixed-width slices.

    Args:
        example: Mapping with an ``'input_ids'`` entry that supports slicing.
        chunk_size: Nominal length of ``example['input_ids']``; only used to
            decide how many slices to emit (``chunk_size // subchunk_size``).
        subchunk_size: Width of each slice.

    Returns:
        ``{'subchunks': [{'input_ids': slice_0}, {'input_ids': slice_1}, ...]}``.
        The slice count comes from ``chunk_size``, not from the actual length
        of the sequence, so a shorter sequence yields short or empty trailing
        slices and a longer one is silently truncated.
    """
    token_ids = example['input_ids']
    n_subchunks = chunk_size // subchunk_size

    pieces = []
    for index in range(n_subchunks):
        start = index * subchunk_size
        stop = start + subchunk_size
        pieces.append({'input_ids': token_ids[start:stop]})
    return {'subchunks': pieces}

In [None]:
def _flatten_into_subchunks(batch, chunk_size, subchunk_size):
    """Batched ``map`` function: emit one output row per sub-chunk.

    Each sequence in ``batch['input_ids']`` is split with
    ``split_into_subchunks`` and the pieces are returned as a flat list, so the
    output batch has more rows than the input batch.
    """
    return {
        'input_ids': [
            subchunk['input_ids']
            for ids in batch['input_ids']
            for subchunk in split_into_subchunks(
                {'input_ids': ids}, chunk_size, subchunk_size
            )['subchunks']
        ]
    }


input_dataset = load_dataset(INPUT_DATASET)
print(f"Loaded dataset: {input_dataset}")

for target_size, output_dataset_name in zip(TARGET_SIZES, OUTPUT_DATASET_NAMES):
    # Flatten inside a single batched map instead of collecting every sub-chunk
    # into a Python list and calling Dataset.from_dict: that approach holds the
    # whole dataset in memory as Python objects, whereas map writes its results
    # out batch by batch.
    output_dataset = DatasetDict({
        split_name: split_dataset.map(
            _flatten_into_subchunks,
            batched=True,
            # The row count changes, so the original columns must be dropped;
            # the returned 'input_ids' column replaces them.
            remove_columns=split_dataset.column_names,
            # Pass sizes explicitly rather than via a lambda closing over the
            # loop variable.
            fn_kwargs={'chunk_size': CHUNK_SIZE, 'subchunk_size': target_size},
            num_proc=4,
        )
        for split_name, split_dataset in input_dataset.items()
    })
    print(f"{output_dataset_name}: {output_dataset}")
    output_dataset.push_to_hub(output_dataset_name, private=True)

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/17 [00:00<?, ?it/s]

Loaded dataset: DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 4136391
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 41774
    })
})
