# check amber mmap token count

In [1]:
import json
import os
import sys
import tempfile

import nltk
import requests

from megatron.core.datasets.indexed_dataset import MMapIndexedDataset
from megatron.tokenizer.gpt2_tokenization import (
    PRETRAINED_MERGES_ARCHIVE_MAP,
    PRETRAINED_VOCAB_ARCHIVE_MAP,
)
from tools.merge_datasets import main as merge_main
from tools.preprocess_data import Encoder
from tools.preprocess_data import get_args as build_args
from tools.preprocess_data import main as build_main

__HUGGINGFACE_BERT_BASE_UNCASED_VOCAB = (
    "https://huggingface.co/bert-base-uncased/raw/main/vocab.txt"
)

Zarr-based strategies will not be registered because of missing packages


In [2]:
slice_path = "full-amber-8node_0_token_ids_document"
dataset = MMapIndexedDataset(slice_path)

In [6]:
dataset[0].shape[0]

2050

In [5]:
len(dataset)

76783616

In [7]:
from tqdm import tqdm
def count_ids(dataset):
    count = 0
    for doc_ids in tqdm(dataset):
        count += doc_ids.shape[0]
    return count

In [8]:
total_cnt = count_ids(dataset)
print("Total number of tokens: ", total_cnt)

100%|██████████| 76783616/76783616 [02:22<00:00, 539157.76it/s]

Total number of tokens:  157406412952





In [12]:
(157406412952/10**9)*8

1259.251303616

In [10]:
157*8

1256

# check slimpajama mmap token count

In [None]:
import json
import os
import sys
import tempfile

import nltk
import requests

from megatron.core.datasets.indexed_dataset import MMapIndexedDataset
from megatron.tokenizer.gpt2_tokenization import (
    PRETRAINED_MERGES_ARCHIVE_MAP,
    PRETRAINED_VOCAB_ARCHIVE_MAP,
)
from tools.merge_datasets import main as merge_main
from tools.preprocess_data import Encoder
from tools.preprocess_data import get_args as build_args
from tools.preprocess_data import main as build_main

__HUGGINGFACE_BERT_BASE_UNCASED_VOCAB = (
    "https://huggingface.co/bert-base-uncased/raw/main/vocab.txt"
)

In [14]:
slice_path = "full-gpt2-8node_0_text_document"
dataset2 = MMapIndexedDataset(slice_path)

In [15]:
len(dataset2)

73799329

In [16]:
total_cnt = count_ids(dataset2)
print("Total number of tokens: ", total_cnt)

100%|██████████| 73799329/73799329 [02:17<00:00, 535339.15it/s]

Total number of tokens:  80730556394





In [17]:
80730556394/10**9

80.730556394

In [18]:
80.730556394*8

645.844451152

In [22]:
print("slimpajama .bin dtype=", dataset2[0].dtype)
print("amber .bin dtype=", dataset[0].dtype)

slimpajama .bin dtype= int32
amber .bin dtype= uint16


In [23]:
2**16-1

65535

# check icp mmap disable shuffle
after disable shuffle, check if mmap data order is the same as jsonl data order

In [1]:
import json
import os
import sys
import tempfile

import nltk
import requests

from megatron.core.datasets.indexed_dataset import MMapIndexedDataset
from megatron.tokenizer.gpt2_tokenization import (
    PRETRAINED_MERGES_ARCHIVE_MAP,
    PRETRAINED_VOCAB_ARCHIVE_MAP,
)
from tools.merge_datasets import main as merge_main
from tools.preprocess_data import Encoder
from tools.preprocess_data import get_args as build_args
from tools.preprocess_data import main as build_main
from tools.preprocess_data import Encoder

__HUGGINGFACE_BERT_BASE_UNCASED_VOCAB = (
    "https://huggingface.co/bert-base-uncased/raw/main/vocab.txt"
)

Zarr-based strategies will not be registered because of missing packages


In [2]:
def build_detokenizer():
    extra_args = [
            "--tokenizer-model",
            "/workspace/megatron/baichuan.tokenizer.model",
            "--tokenizer-type",
            "SentencePieceTokenizer",
            "--append-eod",
            "--workers",
            "1",
            "--log-interval",
            "1",
        ]

    sys.argv = [sys.argv[0], "--input", None, "--output-prefix", None,] + extra_args

    encoder = Encoder(build_args())
    encoder.initializer()
    detok = encoder.tokenizer.detokenize
    return detok

detok = build_detokenizer()

In [3]:
import random

# set random seed
random.seed(42)


def load_head(jsonl_path, n=10):
    jsonls = []
    with open(jsonl_path, "r") as f:
        for i, line in enumerate(f):
            _dict = json.loads(line)
            _dict["jsonl_idx"] = i
            jsonls.append(_dict)
            if len(jsonls) == n:
                break
    return jsonls

def load_random_percent(jsonl_path, percent=1.4e-5, n=100):
    jsonls = []
    with open(jsonl_path, "r") as f:
        for i, line in enumerate(f):
            if random.random() < percent:
                _dict = json.loads(line)
                _dict["jsonl_idx"] = i
                jsonls.append(_dict)
                if len(jsonls) == n:
                    break
    return jsonls



In [4]:
# !pip install fast-edit-distance
# !pip install editdistance
from editdistance import eval as edit_distance
from tqdm import tqdm


def check_match(jsonls, dataset, detok, error_rate_threshold=3):
    match_cnt = 0
    not_match_samples = []
    for _dict in tqdm(jsonls):
        jsonl_text = _dict["text"]
        jsonl_idx = _dict["jsonl_idx"]
        dataset_text = detok(dataset[jsonl_idx].tolist())
        dist = edit_distance(jsonl_text, dataset_text)
        error_rate = (dist / len(jsonl_text)) * 100
        if error_rate > error_rate_threshold:
            print(f"Not match! Error rate: {error_rate}% for jsonl_idx: {jsonl_idx}.")
            not_match_samples.append((_dict, error_rate))
            print(f"jsonl_text: \n{jsonl_text}\n\n\n\ndataset_text: \n{dataset_text}")
        else:
            match_cnt += 1
    print(f"Matched {match_cnt} out of {len(jsonls)}")
    return match_cnt, not_match_samples

In [8]:
# jsonl_path = "/workspace/dataset/test/sample.json"
# jsonl_path = "/workspace/rawdata/slimpajama/slimpajama/in_context_pretraining/sorting_output/chunk1_sorted.jsonl"
jsonl_path = "/workspace/rawdata/slimpajama/slimpajama/in_context_pretraining/sorting_output/final_merged_sorted.jsonl"

# mmap_path = "debug-sample-data-sequential_text_document"
# mmap_path = "slimpajama-icp-chunk1_sorted-8node_text_document"
mmap_path = "slimpajama-icp-final_merged_sorted-8node_text_document"

dataset = MMapIndexedDataset(mmap_path)

# jsonls = load_head(jsonl_path, n=10)
jsonls = load_random_percent(jsonl_path, percent=1.4e-5, n=100)
match_cnt, not_match_samples = check_match(jsonls, dataset, detok, error_rate_threshold=3)

 91%|█████████ | 91/100 [00:07<00:01,  8.99it/s]

Not match! Error rate: 3.2710280373831773% for jsonl_idx: 6539480.
jsonl_text: 
Finished the 2013 Buller Gorge Recreational Walk in 03:21:17 in F3544 class!
She came 37th in class and 264th overall!
No photos of Shelley Dunnings in 2013 , we may still be uploading/processing, check back later.



dataset_text: 
Finished the 2013 Buller Gorge Recreational Walk in 03:21:17 in F3544 class!
She came 37th in class and 264th overall!
No photos of Shelley Dunnings in 2013 , we may still be uploading/processing, check back later. <EOD> 


100%|██████████| 100/100 [00:09<00:00, 10.44it/s]

Matched 99 out of 100





In [13]:
len(dataset), len(jsonls)

(58996336, 10)