# K2 Data Prep

## Set up a conda environment

```shell
conda create -n k2-data-prep python=3.10
conda activate k2-data-prep
pip install fire datasets tqdm
```

## Download necessary datasets
- falcon_refinedweb-json: https://huggingface.co/datasets/tiiuae/falcon-refinedweb
- s2orc: https://github.com/allenai/s2orc?tab=readme-ov-file#download-instructions
- pile: https://the-eye.eu/public/AI/pile/
- redpajama: https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T
- starcoder: https://huggingface.co/datasets/bigcode/starcoderdata

*Note:* The Pile is no longer hosted or distributed by The-Eye or its affiliates. It is, however, still shared within the AI/ML community, presumably because of its historical significance. One community-provided alternative: https://huggingface.co/datasets/monology/pile-uncopyrighted

## Generate Pile data split
```shell
python gen_data_split.py
```
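`gen_data_split.py` is not reproduced here, and this README does not say exactly which split it generates. The data mix at the end lists Pile subsets individually (`dm-math`, `pubmed-abstracts`, `pubmed-central`, `uspto`), so the Pile appears to be used per subset. The following is a rough sketch only. It assumes the original Pile JSONL layout, in which each record stores its subset under `meta["pile_set_name"]`, and shows how the Pile could be split into one file per subset. `subset_slug`, `split_pile_by_subset`, and the output file naming are hypothetical.

```python
import json
import os
import re
import tempfile
from typing import Dict, Iterable


def subset_slug(pile_set_name: str) -> str:
    """Turn a Pile subset name such as "PubMed Abstracts" into "pubmed-abstracts"."""
    return re.sub(r"[^a-z0-9]+", "-", pile_set_name.lower()).strip("-")


def split_pile_by_subset(lines: Iterable[str], out_dir: str) -> Dict[str, int]:
    """Write each Pile record to <out_dir>/<slug>.jsonl and return per-slug counts.

    Assumes every line is a JSON object whose "meta" dict has "pile_set_name".
    """
    os.makedirs(out_dir, exist_ok=True)
    handles = {}
    counts: Dict[str, int] = {}
    try:
        for line in lines:
            record = json.loads(line)
            slug = subset_slug(record["meta"]["pile_set_name"])
            if slug not in handles:  # open one output file per subset, lazily
                handles[slug] = open(os.path.join(out_dir, f"{slug}.jsonl"), "w")
            handles[slug].write(json.dumps(record) + "\n")
            counts[slug] = counts.get(slug, 0) + 1
    finally:
        for handle in handles.values():
            handle.close()
    return counts


if __name__ == "__main__":
    sample = [
        json.dumps({"text": "a", "meta": {"pile_set_name": "PubMed Abstracts"}}),
        json.dumps({"text": "b", "meta": {"pile_set_name": "USPTO Backgrounds"}}),
    ]
    with tempfile.TemporaryDirectory() as tmp:
        print(split_pile_by_subset(sample, tmp))
```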

## Dedup Pile data split
```shell
python dedup.py
```
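This README does not document which deduplication method `dedup.py` uses. One common baseline is exact deduplication: hash each document's text and keep only the first occurrence. The sketch below illustrates that baseline under stated assumptions. The `text` field name and the whitespace normalization are assumptions, and `text_fingerprint` and `exact_dedup` are hypothetical helpers.

```python
import hashlib
import json
from typing import Iterable, Iterator


def text_fingerprint(text: str) -> bytes:
    """SHA-256 digest of `text` after collapsing whitespace runs (an assumed normalization)."""
    normalized = " ".join(text.split())
    return hashlib.sha256(normalized.encode("utf-8")).digest()


def exact_dedup(lines: Iterable[str], text_key: str = "text") -> Iterator[str]:
    """Yield JSONL lines whose text has not been seen before; the first occurrence wins."""
    seen = set()  # raw 32-byte digests, smaller in memory than hex strings
    for line in lines:
        digest = text_fingerprint(json.loads(line)[text_key])
        if digest not in seen:
            seen.add(digest)
            yield line
```

Exact hashing misses near-duplicates (documents that differ by a few words). MinHash-based near-deduplication catches more of them at a higher cost.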

## Prepare Raw Data Chunks
```shell
python prepare_jsonl_chunks.py
python prepare_pile_of_law_chunks.py
python prepare_redpajama_chunks.py
python prepare_starcoder_chunks.py
```
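The chunk-preparation scripts are not shown. The data-mix code at the end of this README reads `subset_name` and `src_filename` from every record, so those fields are presumably attached somewhere in the pipeline. As an illustration only, here is a hypothetical helper that splits one source's JSONL stream into fixed-size chunk files and tags each record with both fields. The function name, chunk file naming, and chunk size are all assumptions.

```python
import json
import os
import tempfile
from typing import Iterable, List


def write_tagged_chunks(lines: Iterable[str], out_dir: str, subset_name: str,
                        src_filename: str, chunk_size: int) -> List[str]:
    """Split JSONL `lines` into files of at most `chunk_size` records each.

    Every record gains `subset_name` and `src_filename` fields. Returns the
    paths written, named <subset_name>.<index>.jsonl (a made-up convention).
    """
    os.makedirs(out_dir, exist_ok=True)
    paths: List[str] = []
    out = None
    try:
        for i, line in enumerate(lines):
            if i % chunk_size == 0:  # start a new chunk file
                if out is not None:
                    out.close()
                paths.append(os.path.join(out_dir, f"{subset_name}.{len(paths)}.jsonl"))
                out = open(paths[-1], "w")
            record = json.loads(line)
            record["subset_name"] = subset_name
            record["src_filename"] = src_filename
            out.write(json.dumps(record) + "\n")
    finally:
        if out is not None:
            out.close()
    return paths


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        docs = [json.dumps({"text": f"doc {i}"}) for i in range(5)]
        written = write_tagged_chunks(docs, tmp, "redpajama.arxiv", "arxiv.jsonl", chunk_size=2)
        print([os.path.basename(p) for p in written])
```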

## Tokenize
```shell
python tokenize_datasets.py
```

## Create FIM Data for StarCoder
```shell
python starcoder_fim_main.py --spm_rate 0.
python starcoder_fim_main.py --spm_rate 1.
```
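These two runs presumably produce the two StarCoder variants that the data-mix code labels `starcoder.FIM` (files containing `spm0.0`) and `starcoder.SPM` (files containing `spm1.0`). In the fill-in-the-middle (FIM) scheme from "Efficient Training of Language Models to Fill in the Middle" (Bavarian et al., 2022), a document is cut into prefix, middle, and suffix and reordered so that the middle comes last. An SPM rate then sets how often the suffix-prefix-middle (SPM) ordering is used instead of prefix-suffix-middle (PSM). The character-level sketch below illustrates that reading of `--spm_rate`. `fim_transform` is hypothetical, and `starcoder_fim_main.py` may instead work on token IDs and sample cut points differently.

```python
import random

# StarCoder's FIM sentinel tokens, used here as plain-text markers.
FIM_PREFIX = "<fim_prefix>"
FIM_SUFFIX = "<fim_suffix>"
FIM_MIDDLE = "<fim_middle>"


def fim_transform(doc: str, spm_rate: float, rng: random.Random) -> str:
    """Rearrange `doc` into a fill-in-the-middle training example.

    Two random cut points split `doc` into prefix / middle / suffix. With
    probability `spm_rate` the SPM ordering is emitted, otherwise PSM, so
    spm_rate=0.0 always gives PSM and spm_rate=1.0 always gives SPM.
    """
    lo, hi = sorted(rng.randint(0, len(doc)) for _ in range(2))
    prefix, middle, suffix = doc[:lo], doc[lo:hi], doc[hi:]
    if rng.random() < spm_rate:
        # SPM: suffix first; the prefix follows <fim_middle> and flows into the middle.
        return FIM_PREFIX + FIM_SUFFIX + suffix + FIM_MIDDLE + prefix + middle
    # PSM: prefix and suffix as context; the middle comes last as the target.
    return FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE + middle


if __name__ == "__main__":
    rng = random.Random(0)
    code = "def add(a, b):\n    return a + b\n"
    print(fim_transform(code, spm_rate=0.0, rng=rng))
    print(fim_transform(code, spm_rate=1.0, rng=rng))
```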

## Gather into 360 chunks
```shell
python gather.py
```
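How `gather.py` assigns examples to the 360 output chunks is not documented here. One simple scheme deals records round-robin across the chunks. If all sources are concatenated one after another, each chunk then receives floor(L/360) or ceil(L/360) records of a source with L records, which gives every chunk nearly the same mix. The sketch below illustrates that assumption. `deal_round_robin` is hypothetical, and it holds everything in memory for clarity.

```python
from typing import Iterable, List


def deal_round_robin(lines: Iterable[str], n_chunks: int = 360) -> List[List[str]]:
    """Distribute `lines` across `n_chunks` buckets in turn (line i -> bucket i % n_chunks).

    Held in memory for clarity; a real run would instead stream each bucket
    to its own chunk_<i>.jsonl file.
    """
    chunks: List[List[str]] = [[] for _ in range(n_chunks)]
    for i, line in enumerate(lines):
        chunks[i % n_chunks].append(line)
    return chunks
```

Assigning each record to a uniformly random chunk would also work, but it matches the source proportions only in expectation, not per chunk.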

## Shuffle
```shell
python shuffle.py
```
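`shuffle.py` is not shown either. The sketch below assumes that it shuffles the records within each gathered chunk and that one chunk fits in memory (`chunk_0.jsonl` holds about 1.9M samples, per the data-mix output). It shows a seeded, reproducible per-chunk shuffle. The helper names are hypothetical, and deriving the seed from the chunk index (for example `seed=i` for `chunk_{i}.jsonl`) is also an assumption.

```python
import random
from typing import List


def shuffle_lines(lines: List[str], seed: int) -> List[str]:
    """Return a shuffled copy of `lines`; the same seed always gives the same order."""
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)
    return shuffled


def shuffle_chunk_file(path: str, seed: int) -> None:
    """Shuffle one JSONL chunk in place. Loads the whole file into memory."""
    with open(path) as f:
        # Make sure the last record also ends with a newline so records stay separated.
        lines = [line if line.endswith("\n") else line + "\n" for line in f]
    with open(path, "w") as f:
        f.writelines(shuffle_lines(lines, seed))
```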

## Print Data Mix
```shell
python analyze.py
```

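The data-mix numbers below were produced by the following snippet, which counts samples per subset in the first output chunk (`chunk_0.jsonl`):

```python
import json

import tqdm

OUTPUT_DIR = '/lustre/scratch/shared-folders/llm_project/bowen.tan/final_data_chunks'
N_CHUNKS = 360  # total number of chunks; only chunk_0 is analyzed below


def main():
    counter = {}
    # Count samples per subset in the first chunk as a sample of the overall mix.
    with open(f'{OUTPUT_DIR}/chunk_0.jsonl') as f:
        for line in tqdm.tqdm(f):
            example = json.loads(line)
            subset_name = example['subset_name']

            # Merge all 'starcoder.*' subsets into one bucket, but keep the
            # two FIM variants (spm_rate 0.0 and 1.0) as separate buckets.
            if subset_name.startswith('starcoder.'):
                subset_name = 'starcoder'
                if 'spm0.0' in example['src_filename']:
                    subset_name += '.FIM'
                elif 'spm1.0' in example['src_filename']:
                    subset_name += '.SPM'

            counter[subset_name] = counter.get(subset_name, 0) + 1

    n_total = sum(counter.values())
    print(f'{n_total} samples.')
    # Columns: subset name, sample count, fraction of all samples in the chunk.
    for key, value in counter.items():
        print(f'{key}\t{value}\t{value / n_total}')


if __name__ == '__main__':
    main()
```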

Output for `chunk_0.jsonl` (1920926 samples in total):

| Subset | Samples | Fraction |
|---|---|---|
| dm-math | 19514 | 0.010158642238170548 |
| pile-of-law | 113377 | 0.059022054988063045 |
| pubmed-abstracts | 21738 | 0.011316417186294527 |
| pubmed-central | 37637 | 0.019593154551502765 |
| redpajama.arxiv | 40985 | 0.021336063960818896 |
| redpajama.book | 118312 | 0.06159112844534355 |
| redpajama.stackexchange | 89629 | 0.046659267457465826 |
| redpajama.wikipedia | 196791 | 0.10244590369436407 |
| refinedweb | 904359 | 0.4707932528374336 |
| s2orc | 158929 | 0.08273561813417071 |
| starcoder.FIM | 49345 | 0.025688131661500756 |
| starcoder.SPM | 49497 | 0.025767260165149516 |
| starcoder | 98819 | 0.051443418434650785 |
| uspto | 21994 | 0.011449686245071387 |
