# Recreating and Improving MiniPile Dataset Creation

**Objectives:**
- [.] Implement and verify MiniPile’s filtering pipeline according to [Kaddour (2023)](https://arxiv.org/abs/2304.08442), but intended for decoder-only model use
- [] Evaluate and compare the performance of Pythia $160\text{M}$ pretrained on The Pile vs. trained on the *newly self-created MiniPile* on MMLU and ARC-Challenge
- [] Improve the dataset creation process, create new SuperMiniPile dataset (ideally smaller and more information-retaining)
- [] Train Pythia $160\text{M}$ on SuperMiniPile, evaluate on MMLU and ARC-Challenge
- [] Evaluate and compare the performance of Pythia $1.4\text{B}$ pretrained on The Pile vs. trained on the created MiniPile on the MMLU and ARC benchmarks

In [None]:
#! pip install sentence-transformers

In [18]:
import os
import numpy as np
from tqdm import tqdm
from pathlib import Path
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer

In [5]:
base_dir = "/mnt/data"
base_path = Path(base_dir)

def download_model(down_dir: str, target_folder: str, cache_folder: str, repo_id: str, branch: str = "main") -> None:
    down_dir = Path(down_dir)
    target_dir = down_dir / target_folder
    cache_dir = down_dir / cache_folder

    os.makedirs(target_dir, exist_ok=True)
    os.makedirs(cache_dir, exist_ok=True)

    print(f"Downloading {repo_id}/{branch}...")

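    # Retry on transient failures, but give up eventually instead of looping forever
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            snapshot_download(
                repo_id,
                repo_type="model",
                revision=branch,
                cache_dir=str(cache_dir),
                local_dir=str(target_dir)
            )
            break
        except Exception as e:
            print(f"Download attempt {attempt}/{max_attempts} failed: {e}")
            if attempt == max_attempts:
                raise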

---

## Recreating MiniPile Dataset Creation

Following Kaddour (2023), the MiniPile pipeline consists of three steps:<br>
(1) document embedding extraction,<br>
(2) clustering of embeddings, and<br>
(3) human-guided exclusion of unwanted clusters
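
To make the three steps concrete, here is a minimal toy sketch of the pipeline shape. It is not the MiniPile implementation: random unit vectors stand in for the E5-Large embeddings, `spherical_kmeans` is a hypothetical helper written for this illustration (a plain Lloyd-style loop using cosine similarity), and the excluded cluster ID is picked arbitrarily, whereas in the real pipeline that choice is made by a human inspecting the clusters.

```python
import numpy as np

def spherical_kmeans(points: np.ndarray, k: int, n_iter: int = 20, seed: int = 0) -> np.ndarray:
    """Toy k-means on unit vectors using cosine similarity; returns one label per point."""
    rng = np.random.default_rng(seed)
    centroids = points[rng.choice(len(points), size=k, replace=False)].copy()
    for _ in range(n_iter):
        # For unit vectors, the largest dot product marks the closest centroid
        labels = np.argmax(points @ centroids.T, axis=1)
        for c in range(k):
            members = points[labels == c]
            if len(members) > 0:
                mean = members.mean(axis=0)
                norm = np.linalg.norm(mean)
                if norm > 0:
                    centroids[c] = mean / norm
    return np.argmax(points @ centroids.T, axis=1)

# Step 1 (stand-in): one embedding per document, here just random unit vectors
rng = np.random.default_rng(42)
embeddings = rng.normal(size=(200, 8))
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

# Step 2: cluster the embeddings
labels = spherical_kmeans(embeddings, k=5)

# Step 3 (stand-in): drop every document that falls into an excluded cluster
excluded_clusters = [0]  # arbitrary choice for illustration
keep_indices = np.flatnonzero(~np.isin(labels, excluded_clusters))
```

`keep_indices` holds positions into the original document order, which is why the embedding indices have to stay aligned with the dataset indices throughout.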

- 22 data subset sources
- 5.91 KiB mean document size (before deduplication)

### Document Embedding Extraction

- The MiniPile paper uses the term "document". Since these are quite large, I assume it refers to individual training examples from "The Pile Deduplicated"
- "The Pile Deduplicated" predominantly contains English text, as stated in the Pile paper
- `E5-Large` does not require sentence-splitting beforehand; I was misled by the example code at https://huggingface.co/intfloat/e5-large
- I will use `E5-Large`, treating one "document" as one "sentence" for MiniPile
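
Treating a whole document as one "sentence" works because the encoder reduces all token embeddings of an input to a single vector. The example code on the model page does this with masked average pooling; the NumPy sketch below mirrors that idea on toy data. It is an illustration only: the array shapes and values are made up, and whether the locally loaded `SentenceTransformer` pipeline pools in exactly this way is an assumption. Also worth keeping in mind: as far as I can tell from the model card, long inputs are truncated to 512 tokens, so only the beginning of a long document contributes to its embedding.

```python
import numpy as np

def average_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Mean over real tokens only; padding positions (mask == 0) are ignored."""
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, seq, 1)
    summed = (token_embeddings * mask).sum(axis=1)                   # (batch, hidden)
    counts = mask.sum(axis=1)                                        # (batch, 1)
    return summed / counts

# Two toy "documents" with 3 token slots and hidden size 2;
# the second document has one padding slot at the end
tokens = np.array([[[1.0, 0.0], [3.0, 0.0], [5.0, 0.0]],
                   [[2.0, 2.0], [4.0, 4.0], [9.0, 9.0]]])
mask = np.array([[1, 1, 1],
                 [1, 1, 0]])

doc_embeddings = average_pool(tokens, mask)  # one vector per document
```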

In [6]:
# Starting point is the deduplicated The Pile
# Infer embeddings for all documents using E5-Large

# https://huggingface.co/intfloat/e5-large
download_model(down_dir=base_dir, target_folder="e5-large", 
               cache_folder="e5-large_Cache",
               repo_id="intfloat/e5-large") # Chose this because nothing beyond E5-Large was specified

e5_large = SentenceTransformer(str(base_path / "e5-large"), local_files_only=True) # no .from_pretrained() here

# https://huggingface.co/datasets/EleutherAI/the_pile_deduplicated
pile_dedup = load_dataset("parquet",
                          data_files={
                              "train": str(base_path / "Pile_Deduplicated" / "data" / "train-*.parquet"),
                          },
                          cache_dir=str(base_path / "MiniPile_Cache"),
                          split="train",
                          streaming=True)

Downloading intfloat/e5-large/main...



Given the model and the local stream of The Pile, we iterate through the dataset and extract an embedding for each document.<br>
For convenience and later processing, we assemble these into an embedding dataset.

When creating the embedding dataset, we need to make sure that embedding indices match the document indices in the original dataset.<br>
This is strictly necessary for the filtering step later on.<br>
To check that the embedding dataset produced by the embedding code further below is correctly aligned with the original dataset, I ran the following code on a small subset of `16384` documents:

```python
saved_embeddings = load_dataset("parquet", data_files=str(embd_dir / "shard_*.parquet"), split="train")

for index in tqdm(range(len(saved_embeddings)), desc="Verifying embeddings"):
    # Newly embed the text at this index
    original_text = next(iter(pile_dedup.skip(index).take(1)))['text']
    generated_embedding = e5_large.encode(original_text, show_progress_bar=False)
    # Newly encoded embedding should correspond to the saved embedding at this index -> index consistency
    saved_embedding = saved_embeddings[index]['embedding']
    if not np.allclose(generated_embedding, saved_embedding, atol=1e-5):
        print(f"Mismatch found at index: {index}")
```

No mismatches were found, which means we can scale the embedding set creation to the full dataset.
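
One caveat about the check above: calling `skip(index)` on a streaming dataset for every index restarts the stream each time, so the cost grows quadratically with the subset size. A single pass that walks the documents and the saved embeddings in lockstep should give the same verdict. The sketch below shows that idea with stand-ins only: `fake_embed` is a hypothetical replacement for `e5_large.encode`, and the plain lists replace the streamed dataset and the saved shards.

```python
import math
from typing import Iterable, List, Sequence

def fake_embed(text: str) -> List[float]:
    """Hypothetical stand-in for the real encoder: a deterministic toy vector per text."""
    return [float(len(text)), float(sum(map(ord, text)) % 97)]

def find_mismatches(texts: Iterable[str],
                    saved: Iterable[Sequence[float]],
                    atol: float = 1e-5) -> List[int]:
    """Re-embed each text and compare it with the saved vector at the same position."""
    mismatches = []
    for index, (text, saved_vec) in enumerate(zip(texts, saved)):
        fresh = fake_embed(text)
        if not all(math.isclose(a, b, abs_tol=atol) for a, b in zip(fresh, saved_vec)):
            mismatches.append(index)
    return mismatches

docs = ["alpha", "beta", "gamma", "delta"]
saved_ok = [fake_embed(d) for d in docs]
saved_shifted = saved_ok[1:] + saved_ok[:1]  # simulate an off-by-one misalignment

aligned_result = find_mismatches(docs, saved_ok)
shifted_result = find_mismatches(docs, saved_shifted)
```

With aligned inputs the result is empty, while the shifted copy is flagged at every position, which is the failure mode the alignment check is meant to catch.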

In [28]:
# Took the example code from the intfloat/e5-large page
embd_dir = base_path / Path("Pile_Deduplicated_Embeddings")
embd_dir.mkdir(exist_ok=True)

batch_size = 1024
shard_size = batch_size ** 2 # shard embed count upper bound

embedding_shard = []
shard_index = 0

def save_shard(embeddings, output_dir, shard_index):
    shard_path = output_dir / f"shard_{shard_index:08d}.parquet"
    dataset = Dataset.from_dict({"embedding": embeddings})
    dataset.to_parquet(str(shard_path))

# tqdm wraps the enumerate() generator; without a known total it only shows a running count
for batch_idx, batch in tqdm(enumerate(pile_dedup.iter(batch_size=batch_size))):
    # show_progress_bar=True is handy for debugging but clutters the output; set it to False for the full run
    batch_embds = e5_large.encode(batch['text'], show_progress_bar=True)
    embedding_shard.extend(batch_embds)

    if len(embedding_shard) >= shard_size:
        save_shard(embedding_shard, embd_dir, shard_index)
        shard_index += 1
        embedding_shard = []

# Save the remaining embeddings as a final, partial shard
if embedding_shard:
    save_shard(embedding_shard, embd_dir, shard_index)

15it [15:16, 61.12s/it]
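
Because `shard_size` is a multiple of `batch_size`, every shard except possibly the last should hold exactly `shard_size` embeddings. Under that assumption, a global document index maps to a shard file and a row inside it with a single `divmod`, which may come in handy for the later filtering step. The helper names below (`locate`, `shard_filename`) are hypothetical and not part of the notebook; only the file name pattern is taken from `save_shard`.

```python
from typing import Tuple

BATCH_SIZE = 1024
SHARD_SIZE = BATCH_SIZE ** 2  # same upper bound as in the embedding cell

def locate(doc_index: int, shard_size: int = SHARD_SIZE) -> Tuple[int, int]:
    """Map a global document index to (shard index, row within that shard)."""
    return divmod(doc_index, shard_size)

def shard_filename(shard_index: int) -> str:
    """File name pattern used by save_shard."""
    return f"shard_{shard_index:08d}.parquet"

shard_index, row = locate(SHARD_SIZE + 1)  # second document of the second shard
filename = shard_filename(shard_index)
```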
