diff --git a/.github/workflows/fa2_unit_tests.yaml b/.github/workflows/fa2_unit_tests.yaml index cc8e58ee8..342be45e3 100644 --- a/.github/workflows/fa2_unit_tests.yaml +++ b/.github/workflows/fa2_unit_tests.yaml @@ -39,7 +39,7 @@ jobs: python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - - name: Instal nanotron + - name: Install nanotron run: | python -m pip install --upgrade pip pip install packaging @@ -55,4 +55,4 @@ jobs: - name: Run tests # NOTE: -m fa2 will only run the unit tests that have the mark # "fa2" (these are FA2-related tests) - run: pytest -m fa2 --color=yes --durations=0 --ignore tests/fp8 --verbose tests/ + run: pytest -m fa2 --color=yes --durations=0 --ignore tests/fp8 --ignore tests/nanoset --verbose tests/ diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml new file mode 100644 index 000000000..9cbbf6803 --- /dev/null +++ b/.github/workflows/trufflehog.yml @@ -0,0 +1,15 @@ +on: + push: + +name: Secret Leaks + +jobs: + trufflehog: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@main diff --git a/Makefile b/Makefile index b9e181686..0ab20da62 100644 --- a/Makefile +++ b/Makefile @@ -14,3 +14,9 @@ test: --ignore tests/fp8 \ --verbose \ examples/doremi/tests/ + + pip install -r examples/llama/requirements.txt + pytest \ + --color=yes \ + --verbose \ + examples/llama/tests/ diff --git a/docs/nanoset.md b/docs/nanoset.md index 02649bd0b..9dce21b73 100644 --- a/docs/nanoset.md +++ b/docs/nanoset.md @@ -1,41 +1,42 @@ # Nanosets -Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a kind of datasets based on [numpy memory-mapped arrays](https://numpy.org/doc/stable/reference/generated/numpy.memmap.html). `Nanosets` are capable of serving batches from files containing pre-tokenized datasets. They allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches. +Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a dataset for processing tokenized documents with [`datatrove`](https://github.com/huggingface/datatrove). They allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches. ## Install To use `Nanosets`, it's necessary to install Nanotron with the `nanosets` flavor. ``` -pip install -e '.[nanosets]' +pip install nanotron[nanosets] ``` This will install the following dependencies: -- `transformers`: To tokenize the datasets -- `datasets`: To preprocess the datasets +- `datatrove`: To preprocess the datasets - `numba`: To compile helper functions in order to speed up the creation of `Nanosets` +- `transformers`: For the tokenizers ## Data pre-processing -To use these datasets, first, we need to preprocess the data. The input format can either be a column of a Hugging Face Dataset or a .json file containing a text sample per line. For example: +To use this dataset, first, we need to preprocess the data using `datatrove`'s `DocumentTokenizer` pipeline. We invite you to take a look at `datatrove`, since it contains multiple features that allow, for example, filter out documents based on specific rules/criteria, extract text content from raw formats or scheduling the preprocessing in a Slurm cluster. We have also added a simple script capable of tokenizing datasets. -
-{"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
-{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}
-
- -The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. Below we show an example for processing a corpus with the Llama2 tokenizer. +The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. The input format can either be a Hugging Face Dataset, a path to a `.jsonl` or a path to a folder containing multiple `.jsonl` files. Below we show an example for processing a Hugging Face Dataset from the Hub with the Llama3 tokenizer.
-torchrun --nproc-per-node 16 tools/preprocess_data.py \
-       --input HuggingFaceH4/testing_alpaca_small \
-       --split train \
-       --column completion \
-       --output-prefix datasets/testing_alpaca_small \
-       --tokenizer-name-or-path openai-community/gpt2
+python3 tools/preprocess_data.py \
+       --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B \
+       --output-folder datasets/emotion \
+       --n-tasks 16 \
+       hf \
+       --dataset dair-ai/emotion \
 
-The preprocessing script has to be launched with `torchrun` in order to spawn `--nproc-per-node` workers that will preprocess the dataset concurrently. The `--input` dataset can be either a Hugging Face Dataset from the Hub or a `.json` file. The processed dataset will be stored in *`--output-prefix`_input_ids.npy*. In `--tokenizer-name-or-path`, we will have to specify a tokenizer in the same way as we do when using `AutoTokenizers.from_pretrained(...)`. +First with `--tokenizer-name-or-path` we will specify a tokenizer in the same way as we do when using `AutoTokenizers.from_pretrained(...)`. Then we specify the `--output-folder` where we will store the tokenized documents and the number of workers with `--n-tasks`. Finally we will indicate the type of dataset (whether if it's a Hugging Face Dataset ["**hf**"] or in jsonl ["**jsonl**"] format) and the dataset that we want to preprocess. Check the different settings with `python3 tools/preprocess_data.py --help`, `python3 tools/preprocess_data.py hf --help` & `python3 tools/preprocess_data.py jsonl --help`. -The output will be one file named, in this case, `datasets/testing_alpaca_small_input_ids.npy`. We will then have to specify this file in the `dataset_path` field in the config file. +Every worker will store in `--output-folder` 3 different kind of files: +- `*.ds` Containing the tokenized documents +- `*.ds.index` Containing the bounds of each tokenized document +- `*.ds.metadata` Containing the number of tokens and tokenizer used + +> [!IMPORTANT] +Remember to introduce the type of dataset to process. e.g. python3 tools/preprocess_data.py --tokenizer-name-or-path gpt2 --n-tasks 16 **jsonl** --dataset raw_datasets/c4-es-json-files ## Working with Nanosets To work with `Nanosets`, we just need to configure 1 argument: -1. `dataset_path`: This argument specifies the file or files that will compose the `Nanoset`. There are 3 ways to specify it: +1. `dataset_folder`: This argument specifies the file or files that will compose the `Nanoset`. There are 3 ways to specify it: 1. If we specify a single path, we will create a `Nanoset` from a single dataset file. ```yaml data_stages: @@ -43,7 +44,7 @@ To work with `Nanosets`, we just need to configure 1 argument: start_training_step: 1 data: dataset: - dataset_path: datasets/SlimPajama-6B_input_ids.npy + dataset_folder: datasets/SlimPajama-6B num_loading_workers: 0 seed: 1234 ``` @@ -54,9 +55,9 @@ To work with `Nanosets`, we just need to configure 1 argument: start_training_step: 15 data: dataset: - dataset_path: - - datasets/SlimPajama-6B_input_ids.npy - - datasets/testing_alpaca_small_input_ids.npy + dataset_folder: + - datasets/SlimPajama-6B + - datasets/testing_alpaca_small num_loading_workers: 0 seed: 1234 ``` @@ -67,9 +68,9 @@ To work with `Nanosets`, we just need to configure 1 argument: start_training_step: 25 data: dataset: - dataset_path: - datasets/SlimPajama-6B_input_ids.npy: 0.8 - datasets/testing_alpaca_small_input_ids.npy: 0.2 + dataset_folder: + datasets/SlimPajama-6B: 0.8 + datasets/testing_alpaca_small: 0.2 num_loading_workers: 0 seed: 1234 ``` @@ -82,7 +83,10 @@ torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml ``` ## Under the hood -`Nanosets` are responsible of building samples of `sequence length + 1` tokens from the preprocessed dataset files. The `dataset lengths` of each dataset will be determined by the `(dataset_number_of_tokens - 1) / sequence length`, discarding the last sample if its length < `sequence length`. +`Nanosets` are responsible of building samples of `sequence length + 1` tokens from the preprocessed dataset files. Despite most of the extracting logic lies in `DatatroveFolderDataset`, `Nanosets` will take care of the following: +1. Creating dataset mixtures from different dataset folder paths +2. Ensure that in each epoch, we consume each sample only once +3. Ensure that we never exhaust the `DataLoader` Based on the `dataset lengths`, the `dataset weights` and the `number of samples per epoch` (defined as the `sum(dataset lengths)`), we build the two indexes we need in order to extract samples from the `Nanoset` ([build_nanoset_index_helper](../src/nanotron/data/nanoset.py)): - `dataset index`: Contains the index of the dataset from the list of `dataset paths` from which to extract the sample, respecting the established dataset weight. diff --git a/examples/config_nanoset.yaml b/examples/config_nanoset.yaml index 31f23bf0d..127ddb5e0 100644 --- a/examples/config_nanoset.yaml +++ b/examples/config_nanoset.yaml @@ -7,25 +7,25 @@ checkpoints: data_stages: - data: dataset: - dataset_path: datasets/testing_alpaca_small_input_ids.npy + dataset_folder: datasets/c4-es/tokenized num_loading_workers: 1 seed: 42 name: General purpose training (Single dataset) start_training_step: 1 - data: dataset: - dataset_path: - - datasets/yelp_review_full_input_ids.npy - - datasets/testing_alpaca_small_input_ids.npy + dataset_folder: + - datasets/SlimPajama-6B/tokenized + - datasets/c4-es/tokenized num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) start_training_step: 15 - data: dataset: - dataset_path: - datasets/testing_alpaca_small_input_ids.npy: 0.8 - datasets/yelp_review_full_input_ids.npy: 0.2 + dataset_folder: + datasets/SlimPajama-6B/tokenized: 0.8 + datasets/c4-es/tokenized: 0.2 num_loading_workers: 1 seed: 42 name: Third purpose training (Blended dataset) @@ -57,7 +57,7 @@ model: initializer_range: 0.02 intermediate_size: 64 is_llama_config: true - max_position_embeddings: 256 + max_position_embeddings: 1024 num_attention_heads: 4 num_hidden_layers: 2 num_key_value_heads: 4 @@ -67,7 +67,7 @@ model: rope_scaling: null tie_word_embeddings: true use_cache: true - vocab_size: 32000 + vocab_size: 50257 optimizer: accumulate_grad_in_fp32: true clip_grad: 1.0 @@ -88,11 +88,11 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 2 + dp: 1 expert_parallel_size: 1 pp: 1 pp_engine: 1f1b - tp: 2 + tp: 1 tp_linear_async_communication: true tp_mode: REDUCE_SCATTER profiler: null @@ -105,6 +105,6 @@ tokens: limit_test_batches: 0 limit_val_batches: 0 micro_batch_size: 2 - sequence_length: 128 + sequence_length: 1024 train_steps: 200 val_check_interval: -1 diff --git a/examples/llama/README.md b/examples/llama/README.md new file mode 100644 index 000000000..d8915d38a --- /dev/null +++ b/examples/llama/README.md @@ -0,0 +1,17 @@ +## Debugging the tests with vscode + +To debug the tests with vscode, add the following json to your `launch.json` file. + +``` +{ + "name": "Test conversion", + "type": "python", + "request": "launch", + "module": "pytest", + "console": "integratedTerminal", + "args": [ + "examples/llama/tests" + ], + "justMyCode": false +} +``` diff --git a/examples/llama/__init__.py b/examples/llama/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/llama/convert_hf_to_nanotron.py b/examples/llama/convert_hf_to_nanotron.py new file mode 100644 index 000000000..9fc81949b --- /dev/null +++ b/examples/llama/convert_hf_to_nanotron.py @@ -0,0 +1,119 @@ +""" +Converts a HF model to nanotron format +Command: + torchrun --nproc_per_node=1 convert_hf_to_nanotron.py --checkpoint_path=hf_weights --save_path=nanotron_weights +""" + +import dataclasses +import json +from argparse import ArgumentParser +from pathlib import Path + +import nanotron +import torch +from convert_weights import get_config_mapping, get_weight_mapping, load_nanotron_model +from nanotron.config import LlamaConfig as NanotronLlamaConfig +from nanotron.models.llama import LlamaForTraining +from transformers import LlamaConfig as HFLlamaConfig +from transformers import LlamaForCausalLM + + +def _handle_attention_block( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, n_q_heads: int, n_kv_heads: int, d_qk: int +) -> torch.Tensor: + # Huggingface Llama separates the q, k, v weights (as opposed to nanotron). + # Furthermore, in the rotary embeddings in nanotron expects interleaved pairs of even + # and odd dimensions GPT-J style, while the huggingface implementation expects + # the whole 1st half and then the whole 2nd half GPT-NeoX style (for more information + # see flash_attn.layers.rotary.RotaryEmbedding). + # This function handles the concatenation of the q, k, v weights and proper permutation + # to ensure correct transformation. + + def interleave(w: torch.Tensor): + w_new = [] + for head_w in w.split(d_qk): + head_w = head_w.view(2, d_qk // 2, -1).transpose(0, 1).reshape(d_qk, -1) + w_new.append(head_w) + return torch.cat(w_new) + + q = interleave(q) + k = interleave(k) + return torch.cat([q, k, v]) + + +def convert_hf_to_nt(model_hf: LlamaForCausalLM, model_nt: LlamaForTraining, config: NanotronLlamaConfig): + """Converts the weights from the model_hf to model_nt, making modifications + in-place.""" + + hf_sd = model_hf.state_dict() + nt_to_hf = get_weight_mapping(config, nt_to_hf=True) + + for module_name_nt, module_nt in model_nt.named_modules(): + for param_name_nt, param_nt in module_nt.named_parameters(recurse=False): + # In the case of qkv_proj, the nt_to_hf has exactly three keys, ccorresponding + # to q, k, v. + if "qkv_proj" in module_name_nt: + key_k, key_q, key_v = sorted(nt_to_hf[f"{module_name_nt}.{param_name_nt}"]) + q = hf_sd[key_q] + k = hf_sd[key_k] + v = hf_sd[key_v] + param = _handle_attention_block( + q, + k, + v, + config.num_attention_heads, + config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, + ) + # The case of gate_up_proj, nt_to_hf_map has two keys. + elif "gate_up_proj" in module_name_nt: + key_gate, key_up = sorted(nt_to_hf[f"{module_name_nt}.{param_name_nt}"]) + gate = hf_sd[key_gate] + up = hf_sd[key_up] + param = torch.cat([gate, up]) + # All other cases are simple 1-to-1 correspondence. + else: + hf_key = nt_to_hf[f"{module_name_nt}.{param_name_nt}"] + param = hf_sd[hf_key] + + with torch.no_grad(): + param_nt.copy_(param) + + +def get_nanotron_config(config: HFLlamaConfig) -> NanotronLlamaConfig: + """Converts a huggingface configuration to nanotron configuration.""" + attrs = {key: getattr(config, value) for key, value in get_config_mapping(nt_to_hf=True).items()} + return NanotronLlamaConfig(**attrs) + + +def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): + """Loads the huggingface checkpoint in `checkpoint_path`, creates + a new nanotron instance, copies the weights from the huggingface checkpoint + and saves the transformed nanotron to `save_path`.""" + + # Load huggingface. + hf_model = LlamaForCausalLM.from_pretrained(checkpoint_path) + + # Init nanotron model. + model_config = get_nanotron_config(hf_model.config) + nanotron_model = load_nanotron_model(model_config=model_config) + + # Copy weights and save model. + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=1 + ) + convert_hf_to_nt(hf_model, nanotron_model, model_config) + nanotron.serialize.save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=save_path) + with open(save_path / "model_config.json", "w+") as f: + json.dump(dataclasses.asdict(model_config), f) + print(f"Model saved to {save_path}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert HF weights to nanotron format") + parser.add_argument("--checkpoint_path", type=Path, default="llama-7b", help="Path to the checkpoint") + parser.add_argument("--save_path", type=Path, default="llama-7b-hf", help="Path to save the nanotron model") + args = parser.parse_args() + + # Convert HF model to nanotron format. + convert_checkpoint_and_save(checkpoint_path=args.checkpoint_path, save_path=args.save_path) diff --git a/examples/llama/convert_nanotron_to_hf.py b/examples/llama/convert_nanotron_to_hf.py new file mode 100644 index 000000000..e11b27da6 --- /dev/null +++ b/examples/llama/convert_nanotron_to_hf.py @@ -0,0 +1,154 @@ +""" +Converts a nanotron model to HF format +Command: + torchrun --nproc_per_node=1 convert_nanotron_to_hf.py --checkpoint_path=nanotron-path --save_path=hf-path +""" + +import json +from argparse import ArgumentParser +from pathlib import Path +from typing import Literal, Optional + +import torch +from convert_weights import get_config_mapping, get_weight_mapping, load_nanotron_model +from nanotron.config import LlamaConfig as NanotronLlamaConfig +from nanotron.models import init_on_device_and_dtype +from nanotron.models.llama import LlamaForTraining +from transformers import AutoTokenizer, LlamaForCausalLM +from transformers import LlamaConfig as HFLlamaConfig + +TEST_PROMPT = "What is the meaning of the word chutzpah?\nThe word chutzpah means" + + +def _handle_attention_block( + qkv: torch.Tensor, part: Literal["q", "k", "v"], n_q_heads: int, n_kv_heads: int, d_qk: int +) -> torch.Tensor: + # Huggingface Llama separates the q, k, v weights (as opposed to nanotron). + # Furthermore, in the rotary embeddings in nanotron expects interleaved pairs of even + # and odd dimensions GPT-J style, while the huggingface implementation expects + # the whole 1st half and then the whole 2nd half GPT-NeoX style (for more information + # see flash_attn.layers.rotary.RotaryEmbedding). + # This function selects the proper chunk of the bundled qkv tensor and permutation + # to ensure correct transformation to huggingface. + + def interleave(w: torch.Tensor): + w_new = [] + for head_w in w.split(d_qk): + head_w = head_w.view(d_qk // 2, 2, -1).transpose(0, 1).reshape(d_qk, -1) + w_new.append(head_w) + return torch.cat(w_new) + + assert part in ["q", "k", "v"], "part must be one of [q, k, v]" + + index_end_q = n_q_heads * d_qk + index_end_k = index_end_q + n_kv_heads * d_qk + if part == "q": + return interleave(qkv[:index_end_q]) + if part == "k": + return interleave(qkv[index_end_q:index_end_k]) + return qkv[index_end_k:] + + +def _handle_gate_up_proj(gate_up_proj: torch.Tensor, gate: bool) -> torch.Tensor: + # The gate and up projection are bundled in nanotron. + # This function selects the proper chunk in the bundled weights to return + # either the gate or the up projection only. + weight_size = gate_up_proj.shape[0] // 2 + if gate: + return gate_up_proj[:weight_size] + else: + return gate_up_proj[weight_size:] + + +def convert_nt_to_hf(nanotron_model: LlamaForTraining, hf_model: LlamaForCausalLM, model_config: NanotronLlamaConfig): + """Converts the weights from the nanotron_model to hf_model, making modifications + in-place.""" + + nanotron_model_state_dict = nanotron_model.state_dict() + + hf_to_nt = get_weight_mapping(model_config, nt_to_hf=False) + for module_name_hf, module_hf in hf_model.named_modules(): + for param_name_hf, param_hf in module_hf.named_parameters(recurse=False): + # Get the Nanotron parameter + nanotron_key = hf_to_nt[f"{module_name_hf}.{param_name_hf}"] + param = nanotron_model_state_dict[nanotron_key] + + if "qkv_proj" in nanotron_key: + proj_name = module_name_hf.split(".")[4][0] + param = _handle_attention_block( + param, + proj_name, + model_config.num_attention_heads, + model_config.num_key_value_heads, + model_config.hidden_size // model_config.num_attention_heads, + ) + + elif "gate_up_proj" in nanotron_key: + gate = "gate" in module_name_hf + param = _handle_gate_up_proj(param, gate) + + with torch.no_grad(): + param_hf.copy_(param) + + +def get_hf_config(config: NanotronLlamaConfig) -> HFLlamaConfig: + """Converts a nanotron configuration to huggingface configuration.""" + attrs = {key: getattr(config, value) for key, value in get_config_mapping(nt_to_hf=False).items()} + return HFLlamaConfig(**attrs) + + +def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str] = None): + """Loads the nanotron checkpoint in `checkpoint_path`, creates + a new huggingface instance, copies the weights from the nanotron checkpoint + and saves the transformed huggingface to `save_path`.""" + + # Init nanotron model. + with open(checkpoint_path / "model_config.json", "r") as f: + attrs = json.load(f) + model_config = NanotronLlamaConfig(**attrs) + nanotron_model = load_nanotron_model( + model_config=model_config, + checkpoint_path=checkpoint_path, + ) + # Init huggingface model. + with init_on_device_and_dtype(torch.device("cuda"), torch.bfloat16): + model_config_hf = get_hf_config(model_config) + hf_model = LlamaForCausalLM._from_config(model_config_hf) + + # Copy weights, initialize tokenizer and save model. + if tokenizer_name is not None: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer.save_pretrained(save_path) + convert_nt_to_hf(nanotron_model, hf_model, model_config) + hf_model.save_pretrained(save_path) + print(f"Model saved to {save_path}") + + +def check_converted_model_generation(save_path: Path): + """Loads a huggingface model and tokenizer from `save_path` and + performs a dummy text generation.""" + + tokenizer = AutoTokenizer.from_pretrained(save_path) + input_ids = tokenizer(TEST_PROMPT, return_tensors="pt")["input_ids"].cuda() + print("Inputs:", tokenizer.batch_decode(input_ids)) + + model = LlamaForCausalLM.from_pretrained(save_path).cuda().bfloat16() + out = model.generate(input_ids, max_new_tokens=100) + print("Generation (converted): ", tokenizer.batch_decode(out)) + + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert Nanotron weights to HF format") + parser.add_argument("--checkpoint_path", type=Path, default="llama-7b", help="Path to the checkpoint") + parser.add_argument("--save_path", type=Path, default="llama-7b-hf", help="Path to save the HF model") + parser.add_argument("--tokenizer_name", type=str, default="meta-llama/Llama-2-7b-chat-hf") + args = parser.parse_args() + + # Convert Nanotron model to HF format. + convert_checkpoint_and_save( + checkpoint_path=args.checkpoint_path, save_path=args.save_path, tokenizer_name=args.tokenizer_name + ) + + # Check if the conversion was successful by generating some text. + if args.tokenizer_name is not None: + check_converted_model_generation(save_path=args.save_path) diff --git a/examples/llama/convert_weights.py b/examples/llama/convert_weights.py new file mode 100644 index 000000000..7663399a6 --- /dev/null +++ b/examples/llama/convert_weights.py @@ -0,0 +1,141 @@ +import json +from pathlib import Path +from typing import Optional + +import nanotron +import torch +from nanotron.config import LlamaConfig as NanotronLlamaConfig +from nanotron.models.llama import LlamaForTraining +from nanotron.trainer import mark_tied_parameters + + +def get_weight_mapping(config: NanotronLlamaConfig, nt_to_hf: bool = True) -> dict[str, str]: + """Returns the nanotron to huggingface parameter mapping if `nt_to_hf`, otherwise the + huggingface to nanotron mapping.""" + + hf_to_nt_map = {} + hf_to_nt_map["lm_head.weight"] = "model.lm_head.pp_block.weight" + hf_to_nt_map["model.embed_tokens.weight"] = "model.token_position_embeddings.pp_block.token_embedding.weight" + hf_to_nt_map["model.norm.weight"] = "model.final_layer_norm.pp_block.weight" + hf_to_nt_map["model.embed_tokens.weight"] = "model.token_position_embeddings.pp_block.token_embedding.weight" + + for i in range(config.num_hidden_layers): + hf_prefix = f"model.layers.{i}" + nt_prefix = f"model.decoder.{i}.pp_block" + hf_to_nt_map[f"{hf_prefix}.self_attn.q_proj.weight"] = f"{nt_prefix}.attn.qkv_proj.weight" + hf_to_nt_map[f"{hf_prefix}.self_attn.k_proj.weight"] = f"{nt_prefix}.attn.qkv_proj.weight" + hf_to_nt_map[f"{hf_prefix}.self_attn.v_proj.weight"] = f"{nt_prefix}.attn.qkv_proj.weight" + hf_to_nt_map[f"{hf_prefix}.self_attn.o_proj.weight"] = f"{nt_prefix}.attn.o_proj.weight" + hf_to_nt_map[f"{hf_prefix}.mlp.gate_proj.weight"] = f"{nt_prefix}.mlp.gate_up_proj.weight" + hf_to_nt_map[f"{hf_prefix}.mlp.gate_proj.bias"] = f"{nt_prefix}.mlp.gate_up_proj.bias" + hf_to_nt_map[f"{hf_prefix}.mlp.up_proj.weight"] = f"{nt_prefix}.mlp.gate_up_proj.weight" + hf_to_nt_map[f"{hf_prefix}.mlp.up_proj.bias"] = f"{nt_prefix}.mlp.gate_up_proj.bias" + hf_to_nt_map[f"{hf_prefix}.mlp.down_proj.weight"] = f"{nt_prefix}.mlp.down_proj.weight" + hf_to_nt_map[f"{hf_prefix}.mlp.down_proj.bias"] = f"{nt_prefix}.mlp.down_proj.bias" + hf_to_nt_map[f"{hf_prefix}.input_layernorm.weight"] = f"{nt_prefix}.input_layernorm.weight" + hf_to_nt_map[f"{hf_prefix}.post_attention_layernorm.weight"] = f"{nt_prefix}.post_attention_layernorm.weight" + + if nt_to_hf: + nt_to_hf_map = {} + for hf, nt in hf_to_nt_map.items(): + # Because the qkv and gate_up projections are separated in the + # huggingface format, when we return nanotron to huggingface + # we will need to return a list of parameters instead (e.g. + # the `qkv_proj` will point to a list `[q_proj, k_proj, v_proj]`). + if nt in nt_to_hf_map and isinstance(nt_to_hf_map[nt], list): + nt_to_hf_map[nt].append(hf) + elif nt in nt_to_hf_map: + nt_to_hf_map[nt] = [nt_to_hf_map[nt], hf] + else: + nt_to_hf_map[nt] = hf + return nt_to_hf_map + return hf_to_nt_map + + +def get_config_mapping(nt_to_hf: bool = True) -> dict[str, str]: + """Returns either the nanotron to huggingface (if `nt_to_hf`) + configuration mapping, or the huggingface to nanotron.""" + + hf_to_nt_map = { + "bos_token_id": "bos_token_id", + "eos_token_id": "eos_token_id", + "hidden_act": "hidden_act", + "hidden_size": "hidden_size", + "initializer_range": "initializer_range", + "intermediate_size": "intermediate_size", + "max_position_embeddings": "max_position_embeddings", + "num_attention_heads": "num_attention_heads", + "num_hidden_layers": "num_hidden_layers", + "num_key_value_heads": "num_key_value_heads", + "pad_token_id": "pad_token_id", + "pretraining_tp": "pretraining_tp", + "rms_norm_eps": "rms_norm_eps", + "rope_scaling": "rope_scaling", + "rope_theta": "rope_theta", + "tie_word_embeddings": "tie_word_embeddings", + "use_cache": "use_cache", + "vocab_size": "vocab_size", + } + if nt_to_hf: + return {nt: hf for hf, nt in hf_to_nt_map.items()} + return hf_to_nt_map + + +def make_parallel_config( + dp: int = 1, + pp: int = 1, + tp: int = 1, +): + parallel_config = nanotron.config.ParallelismArgs( + dp=dp, + pp=pp, + tp=tp, + pp_engine=nanotron.config.AllForwardAllBackwardPipelineEngine(), + tp_mode=nanotron.config.TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + return parallel_config + + +def load_nanotron_model( + model_config: Optional[NanotronLlamaConfig] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + checkpoint_path: Optional[Path] = None, +) -> LlamaForTraining: + """ + Creates and returns a nanotron model. + If `model_config` is None, then `checkpoint_path` must be set, in which case + the configuration will be loaded from such path. + If `checkpoint_path` is None, then `model_config` must be set, in which case + the model created will have random weights. + """ + + if model_config is None: + assert checkpoint_path is not None + with open(checkpoint_path / "model_config.json") as f: + model_config = NanotronLlamaConfig(**json.load(f)) + parallel_config = make_parallel_config() + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + nanotron_model = nanotron.models.build_model( + model_builder=lambda: LlamaForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + # Load checkpoint directly in memory and then only keep the state dictionary + if checkpoint_path is not None: + nanotron.serialize.load_weights( + model=nanotron_model, parallel_context=parallel_context, root_folder=checkpoint_path + ) + return nanotron_model diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt new file mode 100644 index 000000000..440127437 --- /dev/null +++ b/examples/llama/requirements.txt @@ -0,0 +1 @@ +transformers==4.39.3 diff --git a/examples/llama/tests/test_conversion.py b/examples/llama/tests/test_conversion.py new file mode 100644 index 000000000..b5ce35290 --- /dev/null +++ b/examples/llama/tests/test_conversion.py @@ -0,0 +1,251 @@ +# ruff: noqa: E402 +import dataclasses +import json +from pathlib import Path + +import pytest +import torch +from transformers import LlamaForCausalLM +from utils import set_system_path + +set_system_path() + +import nanotron +from nanotron.config import LlamaConfig as NanotronLlamaConfig +from nanotron.models.base import init_on_device_and_dtype +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.trainer import mark_tied_parameters + +from examples.llama.convert_hf_to_nanotron import convert_checkpoint_and_save as convert_hf_to_nt_and_save +from examples.llama.convert_hf_to_nanotron import convert_hf_to_nt +from examples.llama.convert_nanotron_to_hf import convert_checkpoint_and_save as convert_nt_to_hf_and_save +from examples.llama.convert_nanotron_to_hf import convert_nt_to_hf, get_hf_config +from examples.llama.convert_weights import load_nanotron_model, make_parallel_config +from tests.helpers.context import TestContext +from tests.helpers.utils import init_distributed + +CONFIG = NanotronLlamaConfig( + **{ + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 1024, + "is_llama_config": True, + "max_position_embeddings": 128, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pad_token_id": None, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 4096, + } +) + + +BATCH_SIZE = 3 +SEQUENCE_LENGTH = 5 +ATOL = 0.03 + + +def create_nanotron_model(parallel_context: ParallelContext) -> LlamaForTraining: + parallel_config = make_parallel_config( + tp=parallel_context.tensor_parallel_size, + dp=parallel_context.data_parallel_size, + pp=parallel_context.pipeline_parallel_size, + ) + nanotron_model = nanotron.models.build_model( + model_builder=lambda: LlamaForTraining( + config=CONFIG, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=torch.bfloat16, + device=torch.device("cuda"), + ) + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + return nanotron_model + + +def create_huggingface_model() -> LlamaForCausalLM: + config_hf = get_hf_config(CONFIG) + with init_on_device_and_dtype(torch.device("cuda"), torch.bfloat16): + model_hf = LlamaForCausalLM._from_config(config_hf) + return model_hf + + +@pytest.fixture(autouse=True, scope="module") +def fix_seed(): + torch.manual_seed(0) + yield + + +@pytest.fixture +def input_ids() -> torch.Tensor: + return torch.randint(0, CONFIG.vocab_size, size=(BATCH_SIZE, SEQUENCE_LENGTH), device="cuda") + + +def _test_nt_to_hf(parallel_context: ParallelContext, input_ids: torch.Tensor): + model_nt = create_nanotron_model(parallel_context) + model_hf = create_huggingface_model() + convert_nt_to_hf(model_nt, model_hf, CONFIG) + input_mask = torch.ones_like(input_ids) + logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) + logits_hf = model_hf(input_ids).logits + assert logits_nt.size() == logits_hf.size() + assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) + + +def test_nt_to_hf(input_ids: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt_to_hf)(input_ids=input_ids) + + +def _test_nt_to_hf_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext): + # Create and save nanotron model. + model_nt = create_nanotron_model(parallel_context) + root = test_context.get_auto_remove_tmp_dir() + nt_path = root / "nanotron" + hf_path = root / "hf" + nanotron.serialize.save_weights(model=model_nt, parallel_context=parallel_context, root_folder=nt_path) + with open(nt_path / "model_config.json", "w+") as f: + json.dump(dataclasses.asdict(CONFIG), f) + input_mask = torch.ones_like(input_ids) + logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) + del model_nt + # Perform conversion. + convert_nt_to_hf_and_save(nt_path, hf_path) + # Load huggingface and get logits. + model_hf = LlamaForCausalLM.from_pretrained(hf_path).cuda() + logits_hf = model_hf(input_ids).logits + assert logits_nt.size() == logits_hf.size() + torch.testing.assert_allclose(logits_nt, logits_hf, atol=ATOL) + + +def test_nt_to_hf_with_files(input_ids: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt_to_hf_with_files)(input_ids=input_ids, test_context=TestContext()) + + +def _test_hf_to_nt(parallel_context: ParallelContext, input_ids: torch.Tensor): + model_nt = create_nanotron_model(parallel_context) + model_hf = create_huggingface_model() + convert_hf_to_nt(model_hf, model_nt, CONFIG) + input_mask = torch.ones_like(input_ids) + logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) + logits_hf = model_hf(input_ids).logits + assert logits_nt.size() == logits_hf.size() + torch.testing.assert_allclose(logits_hf, logits_nt, atol=ATOL) + + +def test_hf_to_nt(input_ids: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_hf_to_nt)(input_ids=input_ids) + + +def _test_hf_to_nt_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext): + # Create and save hf model. + model_hf = create_huggingface_model() + root = test_context.get_auto_remove_tmp_dir() + nt_path = root / "nanotron" + hf_path = root / "hf" + model_hf.save_pretrained(hf_path) + logits_hf = model_hf(input_ids).logits + del model_hf + # Perform conversion. + convert_hf_to_nt_and_save(hf_path, nt_path) + # Load nanotron and get logits. + input_mask = torch.ones_like(input_ids) + model_nt = load_nanotron_model(checkpoint_path=nt_path) + logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) + assert logits_nt.size() == logits_hf.size() + assert torch.allclose(logits_nt, logits_hf, atol=ATOL) + + +def test_hf_to_nt_with_files(input_ids: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_hf_to_nt_with_files)(input_ids=input_ids, test_context=TestContext()) + + +def _test_composed_conversion(parallel_context: ParallelContext): + # Get HF statedict. + model_hf = create_huggingface_model() + hf_sd = {key: val.clone() for key, val in model_hf.state_dict().items()} + # Convert once to nanotron, save its statedict. + model_nt = create_nanotron_model(parallel_context) + convert_hf_to_nt(model_hf, model_nt, CONFIG) + nt_sd = {key: val.clone() for key, val in model_nt.state_dict().items()} + # Convert back to HF, compare statedicts. + del model_hf + model_hf = create_huggingface_model() + convert_nt_to_hf(model_nt, model_hf, CONFIG) + hf_sd_new = model_hf.state_dict() + assert set(hf_sd_new) == set(hf_sd) + assert all(torch.all(hf_sd[key] == hf_sd_new[key]) for key in hf_sd_new) + # Convert to nanotron one more time, compare statedicts. + del model_nt + model_nt = create_nanotron_model(parallel_context) + convert_hf_to_nt(model_hf, model_nt, CONFIG) + nt_sd_new = model_nt.state_dict() + assert set(nt_sd_new) == set(nt_sd) + assert all(torch.all(nt_sd[key] == nt_sd_new[key]) for key in nt_sd_new) + + +def test_composed_conversion(): + init_distributed(tp=1, dp=1, pp=1)(_test_composed_conversion)() + + +def _save_parallel_nanotron(parallel_context: ParallelContext, input_ids: torch.Tensor, nt_path: Path): + # Create and save a parallel model. + model_nt = create_nanotron_model(parallel_context) + nanotron.serialize.save_weights(model=model_nt, parallel_context=parallel_context, root_folder=nt_path) + with open(nt_path / "model_config.json", "w+") as f: + json.dump(dataclasses.asdict(CONFIG), f) + + # Get parallel predictions. + input_ids = input_ids.cuda() # Move them to the current device index. + input_mask = torch.ones_like(input_ids) + logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) + if torch.distributed.get_rank() == 0: + torch.save(logits_nt.detach().cpu(), nt_path / "logits.pt") + + # Convert nanotron to hf, load it and compare logits. + # hf_path = root/"hf" + # convert_nt_to_hf_and_save(nt_path, hf_path) + # model_hf = LlamaForCausalLM.from_pretrained(hf_path).cuda() + # logits_hf = model_hf(input_ids).logits + + # assert logits_nt.size() == logits_hf.size() + # assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) + + +def _convert_from_parallel(parallel_context: ParallelContext, input_ids: torch.Tensor, nt_path: Path, hf_path: Path): + # Convert parallel nanotron to hf, get and save huggingface predictions. + convert_nt_to_hf_and_save(nt_path, hf_path) + model_hf = LlamaForCausalLM.from_pretrained(hf_path).cuda() + logits_hf = model_hf(input_ids).logits + torch.save(logits_hf.detach().cpu(), hf_path / "logits.pt") + + +def test_tensor_parallel_conversion(input_ids: torch.Tensor): + # Set up test. + test_context = TestContext() + root = test_context.get_auto_remove_tmp_dir() + nt_path = root / "nanotron" + hf_path = root / "nanotron" + + # Launch both parts. + init_distributed(tp=2, dp=1, pp=1)(_save_parallel_nanotron)(input_ids=input_ids, nt_path=nt_path) + assert (nt_path / "logits.pt").exists() + init_distributed(tp=1, dp=1, pp=1)(_convert_from_parallel)(input_ids=input_ids, nt_path=nt_path, hf_path=hf_path) + assert (hf_path / "logits.pt").exists() + + # Load logits and verify they match. + logits_nt = torch.load(nt_path / "logits.pt") + logits_hf = torch.load(hf_path / "logits.pt") + assert logits_nt.size() == logits_hf.size() + assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) diff --git a/examples/llama/tests/test_conversion.py.orig b/examples/llama/tests/test_conversion.py.orig new file mode 100644 index 000000000..af0688371 --- /dev/null +++ b/examples/llama/tests/test_conversion.py.orig @@ -0,0 +1,264 @@ +# ruff: noqa: E402 +import json +<<<<<<< HEAD +from pathlib import Path +======= +>>>>>>> main + +import pytest +import torch +from transformers import LlamaForCausalLM +from utils import set_system_path + +set_system_path() + +import nanotron +from nanotron.config import LlamaConfig as NanotronLlamaConfig +from nanotron.models.base import init_on_device_and_dtype +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext + +from examples.llama.convert_hf_to_nanotron import convert_checkpoint_and_save as convert_hf_to_nt_and_save +<<<<<<< HEAD +from examples.llama.convert_nanotron_to_hf import convert_checkpoint_and_save as convert_nt_to_hf_and_save +from examples.llama.convert_hf_to_nanotron import convert_hf_to_nt +from examples.llama.convert_nanotron_to_hf import convert_nt_to_hf, get_hf_config +from examples.llama.convert_weights import load_nanotron_model +from tests.helpers.context import TestContext +from tests.helpers.utils import init_distributed +======= +from examples.llama.convert_hf_to_nanotron import convert_hf_to_nt +from examples.llama.convert_nanotron_to_hf import convert_checkpoint_and_save as convert_nt_to_hf_and_save +from examples.llama.convert_nanotron_to_hf import convert_nt_to_hf, get_hf_config +from examples.llama.convert_weights import load_nanotron_model, make_parallel_config +from tests.helpers.context import TestContext +from tests.helpers.utils import init_distributed, rerun_if_address_is_in_use +>>>>>>> main + +CONFIG = NanotronLlamaConfig( + **{ + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 1024, + "is_llama_config": True, + "max_position_embeddings": 128, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pad_token_id": None, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 4096, + } +) + + +BATCH_SIZE = 3 +SEQUENCE_LENGTH = 5 +ATOL = 0.02 + + +def create_nanotron_model(pp: int = 1, tp: int = 1, dp: int = 1) -> LlamaForTraining: + parallel_config = make_parallel_config(dp, pp, tp) + return load_nanotron_model(parallel_config, CONFIG, torch.device("cuda"), torch.bfloat16) + + +def create_huggingface_model() -> LlamaForCausalLM: + config_hf = get_hf_config(CONFIG) + with init_on_device_and_dtype(torch.device("cuda"), torch.bfloat16): + model_hf = LlamaForCausalLM._from_config(config_hf) + return model_hf + + +@pytest.fixture(autouse=True, scope="module") +def fix_seed(): + torch.manual_seed(0) + yield + + +@pytest.fixture +def input_ids() -> torch.Tensor: + return torch.randint(0, CONFIG.vocab_size, size=(BATCH_SIZE, SEQUENCE_LENGTH), device="cuda") + + +def _test_nt_to_hf(parallel_context: ParallelContext, input_ids: torch.Tensor): + model_nt = create_nanotron_model() + model_hf = create_huggingface_model() + convert_nt_to_hf(model_nt, model_hf, CONFIG) + input_mask = torch.ones_like(input_ids) + logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) + logits_hf = model_hf(input_ids).logits + assert logits_nt.size() == logits_hf.size() + assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) + + +def test_nt_to_hf(input_ids: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt_to_hf)(input_ids=input_ids) + + +def _test_nt_to_hf_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext): + # Create and save nanotron model. + model_nt = create_nanotron_model() + root = test_context.get_auto_remove_tmp_dir() + nt_path = root / "nanotron" + hf_path = root / "hf" + nanotron.serialize.save_weights(model=model_nt, parallel_context=parallel_context, root_folder=nt_path) + with open(nt_path / "model_config.json", "w+") as f: + json.dump(vars(CONFIG), f) + input_mask = torch.ones_like(input_ids) + logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) + del model_nt + # Perform conversion. + convert_nt_to_hf_and_save(nt_path, hf_path) + # Load huggingface and get logits. + model_hf = LlamaForCausalLM.from_pretrained(hf_path).cuda() + logits_hf = model_hf(input_ids).logits + assert logits_nt.size() == logits_hf.size() + assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) + + +def test_nt_to_hf_with_files(input_ids: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt_to_hf_with_files)(input_ids=input_ids, test_context=TestContext()) + + +def _test_hf_to_nt(parallel_context: ParallelContext, input_ids: torch.Tensor): + model_nt = create_nanotron_model() + model_hf = create_huggingface_model() + convert_hf_to_nt(model_hf, model_nt, CONFIG) + input_mask = torch.ones_like(input_ids) + logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) + logits_hf = model_hf(input_ids).logits + assert logits_nt.size() == logits_hf.size() + assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) + + +def test_hf_to_nt(input_ids: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_hf_to_nt)(input_ids=input_ids) + + +def _test_hf_to_nt_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext): + # Create and save hf model. + model_hf = create_huggingface_model() + root = test_context.get_auto_remove_tmp_dir() + nt_path = root / "nanotron" + hf_path = root / "hf" + model_hf.save_pretrained(hf_path) + logits_hf = model_hf(input_ids).logits + del model_hf + # Perform conversion. + convert_hf_to_nt_and_save(hf_path, nt_path) + # Load nanotron and get logits. + input_mask = torch.ones_like(input_ids) + model_nt = load_nanotron_model(checkpoint_path=nt_path) + logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) + assert logits_nt.size() == logits_hf.size() + assert torch.allclose(logits_nt, logits_hf, atol=ATOL) + + +def test_hf_to_nt_with_files(input_ids: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_hf_to_nt_with_files)(input_ids=input_ids, test_context=TestContext()) + + +def _test_composed_conversion(parallel_context: ParallelContext): + # Get HF statedict. + model_hf = create_huggingface_model() + hf_sd = {key: val.clone() for key, val in model_hf.state_dict().items()} + # Convert once to nanotron, save its statedict. + model_nt = create_nanotron_model() + convert_hf_to_nt(model_hf, model_nt, CONFIG) + nt_sd = {key: val.clone() for key, val in model_nt.state_dict().items()} + # Convert back to HF, compare statedicts. + del model_hf + model_hf = create_huggingface_model() + convert_nt_to_hf(model_nt, model_hf, CONFIG) + hf_sd_new = model_hf.state_dict() + assert set(hf_sd_new) == set(hf_sd) + assert all(torch.all(hf_sd[key] == hf_sd_new[key]) for key in hf_sd_new) + # Convert to nanotron one more time, compare statedicts. + del model_nt + model_nt = create_nanotron_model() + convert_hf_to_nt(model_hf, model_nt, CONFIG) + nt_sd_new = model_nt.state_dict() + assert set(nt_sd_new) == set(nt_sd) + assert all(torch.all(nt_sd[key] == nt_sd_new[key]) for key in nt_sd_new) + + +def test_composed_conversion(): + init_distributed(tp=1, dp=1, pp=1)(_test_composed_conversion)() + + +<<<<<<< HEAD +def _save_parallel_nanotron(parallel_context: ParallelContext, input_ids: torch.Tensor, nt_path: Path): + # Create and save a parallel model. + model_nt = create_nanotron_model(tp=parallel_context.tensor_parallel_size, pp=parallel_context.pipeline_parallel_size) + # print(torch.distributed.get_rank(), "model_nt", set(p.device for p in model_nt.parameters())) + nanotron.serialize.save_weights(model=model_nt, parallel_context=parallel_context, root_folder=nt_path) + with open(nt_path/"model_config.json", "w+") as f: + json.dump(vars(CONFIG), f) + + # Get parallel predictions. + input_ids = input_ids.cuda() # Move them to the current device index. + input_mask = torch.ones_like(input_ids) + # print(torch.distributed.get_rank(), "input_ids", input_ids.device) + logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) + if torch.distributed.get_rank() == 0: + torch.save(logits_nt.detach().cpu(), nt_path/"logits.pt") + # print(torch.distributed.get_rank(), logits_nt.shape) + + # Convert nanotron to hf, load it and compare logits. + # hf_path = root/"hf" + # convert_nt_to_hf_and_save(nt_path, hf_path) + # model_hf = LlamaForCausalLM.from_pretrained(hf_path).cuda() + # logits_hf = model_hf(input_ids).logits + + # assert logits_nt.size() == logits_hf.size() + # assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) + + +def _convert_from_parallel(parallel_context: ParallelContext, input_ids: torch.Tensor, nt_path: Path, hf_path: Path): + # Convert parallel nanotron to hf, get and save huggingface predictions. + convert_nt_to_hf_and_save(nt_path, hf_path) + model_hf = LlamaForCausalLM.from_pretrained(hf_path).cuda() + logits_hf = model_hf(input_ids).logits + torch.save(logits_hf.detach().cpu(), hf_path/"logits.pt") + +def test_tensor_parallel_conversion(input_ids: torch.Tensor): + # Set up test. + test_context = TestContext() + root = test_context.get_auto_remove_tmp_dir() + nt_path =root/"nanotron" + hf_path =root/"nanotron" + + # Launch both parts. + init_distributed(tp=2, dp=1, pp=1)(_save_parallel_nanotron)(input_ids=input_ids, nt_path=nt_path) + assert (nt_path/"logits.pt").exists() + init_distributed(tp=1, dp=1, pp=1)(_convert_from_parallel)(input_ids=input_ids, nt_path=nt_path, hf_path=hf_path) + assert (hf_path/"logits.pt").exists() + + # Load logits and verify they match. + logits_nt = torch.load(nt_path/"logits.pt") + logits_hf = torch.load(hf_path/"logits.pt") + assert logits_nt.size() == logits_hf.size() + assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) +======= +def _test_tensor_parallel_conversion(parallel_context: ParallelContext): + model_nt = create_nanotron_model(tp=2) + model_hf = create_huggingface_model() + convert_nt_to_hf(model_nt, model_hf, CONFIG) + input_mask = torch.ones_like(input_ids) + logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) + logits_hf = model_hf(input_ids).logits + assert logits_nt.size() == logits_hf.size() + assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) + + +@rerun_if_address_is_in_use() +def test_tensor_parallel_conversion(): + init_distributed(tp=2, dp=1, pp=1)(_test_tensor_parallel_conversion)() +>>>>>>> main diff --git a/examples/llama/tests/utils.py b/examples/llama/tests/utils.py new file mode 100644 index 000000000..6ac3c4650 --- /dev/null +++ b/examples/llama/tests/utils.py @@ -0,0 +1,15 @@ +import importlib +import sys +from pathlib import Path + + +def set_system_path(): + package = importlib.import_module("nanotron") + # NOTE: Path(package.__file__).parent = .../nanotron/src/nanotron + # we want .../nanotron + package_path = Path(package.__file__).parent.parent.parent + sys.path.insert(0, str(package_path)) + + # we also want ../llama + llama_path = Path(__file__).parent.parent + sys.path.insert(0, str(llama_path)) diff --git a/pyproject.toml b/pyproject.toml index e65f37a53..6a0cfb83d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ fast-modeling = [ nanosets = [ "transformers", - "datasets", + "datatrove[io,processing]@git+https://github.com/huggingface/datatrove", "numba", ] diff --git a/run_train.py b/run_train.py index b33231f4f..021d955de 100644 --- a/run_train.py +++ b/run_train.py @@ -143,17 +143,17 @@ def get_dataloader_from_data_stage( elif isinstance(data.dataset, NanosetDatasetsArgs): # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) - token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 del tokenizer # Create Nanoset from nanotron.data.nanoset import Nanoset with main_rank_first(trainer.parallel_context.world_pg): train_dataset = Nanoset( - dataset_paths=data.dataset.dataset_path, + dataset_folders=data.dataset.dataset_folder, dataset_weights=data.dataset.dataset_weights, sequence_length=trainer.sequence_length, - token_dtype=token_dtype, + token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, random_seed=data.seed, ) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 16ef085c7..05b499554 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -93,18 +93,18 @@ def __post_init__(self): @dataclass class NanosetDatasetsArgs: - dataset_path: Union[str, dict, List[str]] + dataset_folder: Union[str, dict, List[str]] def __post_init__(self): - if isinstance(self.dataset_path, str): # Case 1: 1 Dataset file - self.dataset_path = [self.dataset_path] + if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file + self.dataset_folder = [self.dataset_folder] self.dataset_weights = [1] - elif isinstance(self.dataset_path, List): # Case 2: > 1 Dataset file + elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file self.dataset_weights = None # Set to None so we consume all the samples randomly - elif isinstance(self.dataset_path, dict): # Case 3: dict with > 1 dataset_path and weights - tmp_dataset_path = self.dataset_path.copy() - self.dataset_path = list(tmp_dataset_path.keys()) - self.dataset_weights = list(tmp_dataset_path.values()) + elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights + tmp_dataset_folder = self.dataset_folder.copy() + self.dataset_folder = list(tmp_dataset_folder.keys()) + self.dataset_weights = list(tmp_dataset_folder.values()) @dataclass @@ -161,6 +161,7 @@ class GeneralArgs: Args: project: Name of the project (a project gather several runs in common tensorboard/hub-folders) + entity: Weights and bias entity name (optional) run: Name of the run step: Global step (updated when we save the checkpoint) consumed_train_samples: Number of samples consumed during training (should be actually just step*batch_size) @@ -168,6 +169,7 @@ class GeneralArgs: """ project: str + entity: Optional[str] = None run: Optional[str] = None seed: Optional[int] = None step: Optional[int] = None @@ -247,7 +249,7 @@ class LRSchedulerArgs: lr_warmup_steps: number of steps to warmup the learning rate lr_warmup_style: linear or constant - lr_decay_style: linear or cosine + lr_decay_style: linear, cosine or 1-sqrt min_decay_lr: minimum learning rate after decay lr_decay_steps: optional number of steps to decay the learning rate otherwise will default to train_steps - lr_warmup_steps lr_decay_starting_step: optional number of steps to decay the learning rate otherwise will default to train_steps - lr_warmup_steps @@ -270,9 +272,9 @@ def __post_init__(self): self.lr_warmup_style = "linear" if self.lr_decay_style is None: self.lr_decay_style = "linear" - if self.lr_decay_style not in ["linear", "cosine"]: + if self.lr_decay_style not in ["linear", "cosine", "1-sqrt"]: raise ValueError( - f"lr_decay_style should be a string selected in ['linear', 'cosine'] and not {self.lr_decay_style}" + f"lr_decay_style should be a string selected in ['linear', 'cosine', '1-sqrt'] and not {self.lr_decay_style}" ) if self.min_decay_lr is None: self.min_decay_lr = self.learning_rate diff --git a/src/nanotron/data/collator.py b/src/nanotron/data/collator.py new file mode 100644 index 000000000..199527e15 --- /dev/null +++ b/src/nanotron/data/collator.py @@ -0,0 +1,80 @@ +import dataclasses +from typing import Dict, List, Union + +import numpy as np +import torch +from nanotron import distributed as dist +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer + + +@dataclasses.dataclass +class NanosetDataCollatorForCLM: + """ + Data collator used for causal language modeling with Nanosets dataset. + + - input_pp_rank: Discards last input id token + - output_pp_rank: Discards first label id token + - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. + """ + + sequence_length: int + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + + def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. + current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert all(len(example) == 0 for example in examples) + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + } + + # Make sure we load only what's necessary, ie we only load a `input_ids` column. + assert all(list(example.keys()) == ["input_ids"] for example in examples) + + # TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor? + input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) + batch_size, expanded_input_length = input_ids.shape + + result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + + assert ( + expanded_input_length == self.sequence_length + 1 + ), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}" + + # Process inputs: last token is the label + if current_pp_rank == self.input_pp_rank: + result["input_ids"] = input_ids[:, :-1] + result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + + # Process labels: shift them to the left + if current_pp_rank == self.output_pp_rank: + result["label_ids"] = input_ids[:, 1:] + result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + + if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." + ) + if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." + ) + + return result diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 4719c476c..9d3285f60 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -1,7 +1,7 @@ import nanotron.distributed as dist from nanotron import logging +from nanotron.data.collator import NanosetDataCollatorForCLM from nanotron.dataloader import ( - DataCollatorForCLM, EmptyInfiniteDataset, get_dataloader_worker_init, get_sampler, @@ -32,7 +32,7 @@ def build_nanoset_dataloader( # No need to spawn a lot of workers, we can just use main dataloader_num_workers = 0 - data_collator = DataCollatorForCLM( + data_collator = NanosetDataCollatorForCLM( sequence_length=sequence_length, input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, diff --git a/src/nanotron/data/nanoset.py b/src/nanotron/data/nanoset.py index 9d62b33d1..902009670 100644 --- a/src/nanotron/data/nanoset.py +++ b/src/nanotron/data/nanoset.py @@ -1,7 +1,10 @@ +import os +import warnings from typing import Dict, List, Tuple, Union import numpy as np import torch +from datatrove.utils.dataset import DatatroveFolderDataset from nanotron import logging from nanotron.data.utils import count_dataset_indexes, normalize from nanotron.logging import log_rank @@ -15,49 +18,60 @@ class Nanoset(torch.utils.data.Dataset): The Nanoset dataset Args: - dataset_paths (List[str]): List of paths to tokenized datasets - dataset_weights (List[float]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__ + dataset_folders (List[str]): List of folders with tokenized datasets + dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__ sequence_length (int): Sequence length of the built samples - token_dtype (Union[np.uint16, np.int32]): dtype of the tokens stored in the processed dataset files. np.uin16 for vocab sizes < 65535, np.int32 otherwise + token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size """ def __init__( self, - dataset_paths: List[str], - dataset_weights: Union[List[float], None], + dataset_folders: List[str], sequence_length: int, - token_dtype: Union[np.uint16, np.int32], + token_size: int, train_split_num_samples: int, + dataset_weights: Union[List[float], None] = None, random_seed: int = 1234, ) -> None: + # Checks + if isinstance(dataset_folders, str): + warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]") + dataset_folders = [dataset_folders] + # Init - self.dataset_paths = dataset_paths - self.dataset_weights = dataset_weights + self.dataset_folders = dataset_folders self.sequence_length = sequence_length - self.token_dtype = token_dtype + self.token_size = token_size self.train_split_num_samples = train_split_num_samples self.random_seed = random_seed + self.datatrove_datasets = [] + for dataset_folder in self.dataset_folders: + self.datatrove_datasets.append( + DatatroveFolderDataset( + folder_path=dataset_folder, + filename_pattern=os.path.join(dataset_folder, "*.ds"), + seq_len=sequence_length, + recursive=False, + token_size=token_size, + shuffle=True, + ) + ) # Build Nanoset Index ## To build the index we need the length of each dataset - self.dataset_lengths = [] - for dataset_path in self.dataset_paths: - self.dataset_buffer_mmap = np.memmap(dataset_path, mode="r", order="C", dtype=self.token_dtype) - self.dataset_buffer = memoryview(self.dataset_buffer_mmap) - dataset_number_of_tokens = int(len(self.dataset_buffer)) - number_of_samples = int( - (dataset_number_of_tokens - 1) / sequence_length - ) # Discard last sample if length < sequence_length - self.dataset_lengths.append(number_of_samples) + self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets] ## Set dataset weights if ( - self.dataset_weights is None + dataset_weights is None ): # Case of training with > 1 datasets without weighting them: Consume both datasets entirely on each epoch self.dataset_weights = normalize(self.dataset_lengths) else: self.dataset_weights = normalize(dataset_weights) + assert len(dataset_folders) == len( + self.dataset_weights + ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() @@ -79,25 +93,12 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: idx (int): The index into the dataset Returns: - Dict[str, numpy.ndarray]: The input ids wrapped in a dictionary + Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary """ - dataset = self.dataset_index[idx] dataset_sample = self.dataset_sample_index[idx] - # Rebuild the memmap in every access to free memory - # https://stackoverflow.com/a/61472122 - self.dataset_buffer_mmap = np.memmap(self.dataset_paths[dataset], mode="r", order="C", dtype=self.token_dtype) - self.dataset_buffer = memoryview(self.dataset_buffer_mmap) - - # uint16 -> 2 bytes per token, int32 -> 4 bytes per token - offset = dataset_sample * self.sequence_length * (np.iinfo(self.token_dtype).bits / 8) - input_ids_tokens = np.frombuffer( - self.dataset_buffer, dtype=self.token_dtype, count=(self.sequence_length + 1), offset=int(offset) - ) - - # Return tokens as np.int32 as Torch can't handle uint16 - return {"input_ids": input_ids_tokens.astype(np.int32)} + return self.datatrove_datasets[dataset][dataset_sample] def build_nanoset_index(self) -> np.ndarray: """ @@ -124,15 +125,6 @@ def build_nanoset_index(self) -> np.ndarray: return dataset_index, dataset_sample_index - def __del__(self) -> None: - """ - Clean up Nanoset - """ - - if hasattr(self, "dataset_buffer_mmap"): - self.dataset_buffer_mmap._mmap.close() - del self.dataset_buffer_mmap - def print_nanoset_info(self): log_rank(f"> Total number of samples: {len(self)}", logger=logger, level=logging.INFO, rank=0) @@ -141,10 +133,10 @@ def print_nanoset_info(self): ) # Print samples from each dataset + weight - dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_paths)) + dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders)) for index, sample_count in enumerate(dataset_sample_count): log_rank( - f"> Total number of samples from the {self.dataset_paths[index].rsplit('/', 1)[-1]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", + f"> Total number of samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", logger=logger, level=logging.INFO, rank=0, diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index f7bf63e5b..a82f0294a 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -146,6 +146,12 @@ def lr_lambda(current_step: int, initial_lr: float): * (lr_decay_steps - (current_step - lr_decay_starting_step)) / lr_decay_steps ) + elif lr_scheduler_args.lr_decay_style == "1-sqrt": + lmbda = ( + lr_scheduler_args.min_decay_lr + + (initial_lr - lr_scheduler_args.min_decay_lr) + * (1 - math.sqrt((current_step - lr_decay_starting_step) / lr_decay_steps)) + ) else: raise ValueError(f"Unknown decay style {lr_scheduler_args.lr_decay_style}") diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 2072a7892..2411e5fa0 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -306,6 +306,7 @@ def __init__( self.rotary_embedding = RotaryEmbedding( dim=self.d_qk, end=config.max_position_embeddings, + theta=config.rope_theta, ) # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) diff --git a/src/nanotron/parallel/context.py b/src/nanotron/parallel/context.py index e04e26f56..6c28f4fdd 100644 --- a/src/nanotron/parallel/context.py +++ b/src/nanotron/parallel/context.py @@ -26,8 +26,8 @@ def __init__( world_size % data_parallel_size == 0 ), "The total number of processes must be divisible by the data parallel size." assert world_size % num_gpus_per_model == 0, ( - "The total number of processes must be divisible by" - "the number of GPUs per model (tensor_parallel_size * pipeline_parallel_size)." + f"The total number of processes ({world_size}) must be divisible by " + f"the number of GPUs per model ({num_gpus_per_model}, i.e. tensor_parallel_size * pipeline_parallel_size)." ) if num_gpus_per_model * data_parallel_size != world_size: raise ValueError( diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 0eda00dc5..b6752f381 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -276,7 +276,8 @@ def pre_training(self, *args, **kwargs): if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: wandb.init( project=self.config.general.project, - name=f"{current_time}_{self.config.general.run}", + name=f"{current_time}_{self.config.general.project}_{self.config.general.run}", + entity=self.config.general.entity, config={"nanotron_config": self.config.as_dict()}, ) diff --git a/tests/helpers/data.py b/tests/helpers/data.py index 33bb24808..72deb7f5b 100644 --- a/tests/helpers/data.py +++ b/tests/helpers/data.py @@ -3,6 +3,7 @@ import json import os import sys +from argparse import Namespace from collections import OrderedDict from pathlib import Path @@ -10,8 +11,6 @@ package_path = Path(package.__file__).parent.parent.parent sys.path.append(str(package_path)) -from argparse import Namespace - import nanotron.distributed as dist import torch from nanotron.data.nanoset import Nanoset @@ -23,31 +22,34 @@ def create_dataset_paths(tmp_dir: str, quantity: int): - json_dataset_path = [os.path.join(tmp_dir, f"pytest_{i}") for i in range(quantity)] - mmap_dataset_path = [f"{path}_input_ids.npy" for path in json_dataset_path] + json_dataset_path = [os.path.join(tmp_dir, f"pytest_{i}.json") for i in range(quantity)] + datatrove_tokenized_dataset_paths = [os.path.join(tmp_dir, f"tokenized_documents_{i}") for i in range(quantity)] - return json_dataset_path, mmap_dataset_path + return json_dataset_path, datatrove_tokenized_dataset_paths def create_dummy_json_dataset(path_to_json: str, dummy_text: str, n_samples: int = 50000): - with open(path_to_json + ".json", "a") as json_file: + with open(path_to_json, "a") as json_file: for sample in range(n_samples): sample_dict = {"text": f"[{sample}] Hello! Im sample {sample}! And this is my dummy text: {dummy_text}"} json_file.write(json.dumps(sample_dict)) json_file.write("\n") -def preprocess_dummy_dataset(path_to_json: str, tokenizer: str): +def preprocess_dummy_dataset(json_dataset_path: str, datatrove_tokenized_dataset_path: str, tokenizer: str): # Create args for preprocessing args = Namespace( - input=path_to_json + ".json", + readers="jsonl", + dataset=json_dataset_path, column="text", - output_prefix=path_to_json, + glob_pattern=None, + output_folder=datatrove_tokenized_dataset_path, tokenizer_name_or_path=tokenizer, - add_special_tokens=False, + eos_token=None, + n_tasks=1, + logging_dir=None, ) - # tools/preprocess_data.py main main(args) @@ -122,7 +124,7 @@ def assert_nanoset_sync_across_all_ranks(nanoset: Nanoset, parallel_context: Par IDX_SAMPLE = 23 nanoset_identifiers = OrderedDict() - nanoset_identifiers["dataset_paths"] = nanoset.dataset_paths + nanoset_identifiers["dataset_folders"] = nanoset.dataset_folders nanoset_identifiers["dataset_weights"] = nanoset.dataset_weights.tolist() nanoset_identifiers["sequence_length"] = nanoset.sequence_length nanoset_identifiers["train_split_num_samples"] = nanoset.train_split_num_samples @@ -131,6 +133,7 @@ def assert_nanoset_sync_across_all_ranks(nanoset: Nanoset, parallel_context: Par nanoset_identifiers["input_ids"] = nanoset[IDX_SAMPLE]["input_ids"].tolist() nanoset_identifiers["dataset_index"] = nanoset.dataset_index.tolist() nanoset_identifiers["dataset_sample_index"] = nanoset.dataset_sample_index.tolist() + nanoset_identifiers["token_size"] = nanoset.token_size unique_description_hash = compute_hash(nanoset_identifiers) assert_tensor_synced_across_pg( diff --git a/tests/test_build_nanoset_dataloader.py b/tests/nanoset/test_build_nanoset_dataloader.py similarity index 79% rename from tests/test_build_nanoset_dataloader.py rename to tests/nanoset/test_build_nanoset_dataloader.py index e8ea8abb5..113c545c6 100644 --- a/tests/test_build_nanoset_dataloader.py +++ b/tests/nanoset/test_build_nanoset_dataloader.py @@ -1,4 +1,10 @@ +import sys from math import isclose +from pathlib import Path +from typing import List + +package_path = Path(__file__).parent.parent +sys.path.append(str(package_path)) import numpy as np import pytest @@ -28,7 +34,7 @@ for all_3d_configs in get_all_3d_configurations(gpus) ], ) -@pytest.mark.parametrize("train_steps", [5, 100]) +@pytest.mark.parametrize("train_steps", [500, 10000]) @pytest.mark.parametrize("sequence_length", [512, 8192]) @pytest.mark.parametrize("tokenizer_name_or_path", ["openai-community/gpt2", "unsloth/llama-3-8b-bnb-4bit"]) @rerun_if_address_is_in_use() @@ -37,16 +43,21 @@ def test_build_nanoset_dataloader( ): test_context = TestContext() - # Create dataset files - json_paths, mmap_dataset_paths = create_dataset_paths(tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2) + # Create dataset folders + json_paths, datatrove_tokenized_dataset_folders = create_dataset_paths( + tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2 + ) # Create dummy json datasets for idx, json_path in enumerate(json_paths): create_dummy_json_dataset(path_to_json=json_path, dummy_text=f"Nanoset {idx}!", n_samples=(idx + 1) * 50000) + # Preprocess json dataset with datatrove + for json_path, datatrove_tokenized_dataset_folder in zip(json_paths, datatrove_tokenized_dataset_folders): + preprocess_dummy_dataset(json_path, datatrove_tokenized_dataset_folder, tokenizer_name_or_path) + init_distributed(tp=tp, dp=dp, pp=pp)(_test_build_nanoset_dataloader)( - json_paths=json_paths, - path_to_mmap_files=mmap_dataset_paths, + datatrove_tokenized_dataset_folders=datatrove_tokenized_dataset_folders, train_steps=train_steps, sequence_length=sequence_length, tokenizer_name_or_path=tokenizer_name_or_path, @@ -55,8 +66,7 @@ def test_build_nanoset_dataloader( def _test_build_nanoset_dataloader( parallel_context: ParallelContext, - json_paths: str, - path_to_mmap_files: str, + datatrove_tokenized_dataset_folders: List[str], train_steps: int, sequence_length: int, tokenizer_name_or_path: str, @@ -66,41 +76,37 @@ def _test_build_nanoset_dataloader( N_MICRO_BATCHES_PER_BATCH = 8 GLOBAL_BATCH_SIZE = MICRO_BATCH_SIZE * N_MICRO_BATCHES_PER_BATCH * parallel_context.dp_pg.size() - # Preprocess dummy json datasets - for json_path in json_paths: - preprocess_dummy_dataset(path_to_json=json_path, tokenizer=tokenizer_name_or_path) - input_pp_rank, output_pp_rank = 0, int(parallel_context.pp_pg.size() - 1) # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 del tokenizer # Create Nanoset configs: 1. Normal 2. Blended 3. Blended with weights nanoset_config = { - "dataset_paths": [path_to_mmap_files[0]], + "dataset_folders": [datatrove_tokenized_dataset_folders[0]], "dataset_weights": [1], "sequence_length": sequence_length, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": train_steps * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_nanoset_config = { - "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], + "dataset_folders": datatrove_tokenized_dataset_folders, "dataset_weights": None, "sequence_length": sequence_length, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": train_steps * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_weighted_nanoset_config = { - "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], + "dataset_folders": datatrove_tokenized_dataset_folders, "dataset_weights": [8, 2], "sequence_length": sequence_length, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": train_steps * GLOBAL_BATCH_SIZE, "random_seed": SEED, } @@ -114,7 +120,7 @@ def _test_build_nanoset_dataloader( # Assert we have the same Nanoset in all ranks assert_nanoset_sync_across_all_ranks(train_dataset, parallel_context) - dataset_sample_count = count_dataset_indexes(train_dataset.dataset_index, len(train_dataset.dataset_paths)) + dataset_sample_count = count_dataset_indexes(train_dataset.dataset_index, len(train_dataset.dataset_folders)) for idx, ds_length in enumerate(train_dataset.dataset_lengths): # Assert Nanoset doesn't sample indexes greater than the datasets assert ( @@ -124,7 +130,7 @@ def _test_build_nanoset_dataloader( # Assert Nanoset builds up the correct blend WRT the dataset_weights assert isclose( normalize(dataset_sample_count).tolist()[idx], train_dataset.dataset_weights[idx], abs_tol=0.05 - ), f"Requested Nanoset to contain {round(train_dataset.dataset_weights[idx]*100, 2)}% of samples from {train_dataset.dataset_paths[idx]} but got {round(normalize(dataset_sample_count).tolist()[idx]*100, 2)}%" + ), f"Requested Nanoset to contain {round(train_dataset.dataset_weights[idx]*100, 2)}% of samples from {train_dataset.dataset_folders[idx]} but got {round(normalize(dataset_sample_count).tolist()[idx]*100, 2)}%" # Create Dataloaders dataloader = build_nanoset_dataloader( train_dataset, @@ -157,22 +163,27 @@ def _test_build_nanoset_dataloader( for all_3d_configs in get_all_3d_configurations(gpus) ], ) -@pytest.mark.parametrize("skipped_batches", [20, 50]) +@pytest.mark.parametrize("skipped_batches", [20, 5555]) @pytest.mark.parametrize("tokenizer_name_or_path", ["openai-community/gpt2", "unsloth/llama-3-8b-bnb-4bit"]) @rerun_if_address_is_in_use() def test_recover_nanoset_dataloader(tp: int, dp: int, pp: int, skipped_batches: int, tokenizer_name_or_path: str): test_context = TestContext() - # Create dataset files - json_paths, mmap_dataset_paths = create_dataset_paths(tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2) + # Create dataset folders + json_paths, datatrove_tokenized_dataset_folders = create_dataset_paths( + tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2 + ) # Create dummy json datasets for idx, json_path in enumerate(json_paths): create_dummy_json_dataset(path_to_json=json_path, dummy_text=f"Nanoset {idx}!", n_samples=(idx + 1) * 50000) + # Preprocess json dataset with datatrove + for json_path, datatrove_tokenized_dataset_folder in zip(json_paths, datatrove_tokenized_dataset_folders): + preprocess_dummy_dataset(json_path, datatrove_tokenized_dataset_folder, tokenizer_name_or_path) + init_distributed(tp=tp, dp=dp, pp=pp)(_test_recover_nanoset_dataloader)( - json_paths=json_paths, - path_to_mmap_files=mmap_dataset_paths, + datatrove_tokenized_dataset_folders=datatrove_tokenized_dataset_folders, skipped_batches=skipped_batches, tokenizer_name_or_path=tokenizer_name_or_path, ) @@ -180,8 +191,7 @@ def test_recover_nanoset_dataloader(tp: int, dp: int, pp: int, skipped_batches: def _test_recover_nanoset_dataloader( parallel_context: ParallelContext, - json_paths: str, - path_to_mmap_files: str, + datatrove_tokenized_dataset_folders: List[str], skipped_batches: int, tokenizer_name_or_path: str, ): @@ -190,43 +200,39 @@ def _test_recover_nanoset_dataloader( N_MICRO_BATCHES_PER_BATCH = 8 GLOBAL_BATCH_SIZE = MICRO_BATCH_SIZE * N_MICRO_BATCHES_PER_BATCH * parallel_context.dp_pg.size() SEQUENCE_LENGTH = 1024 - TRAIN_STEPS = 100 - - # Preprocess dummy json datasets - for json_path in json_paths: - preprocess_dummy_dataset(path_to_json=json_path, tokenizer=tokenizer_name_or_path) + TRAIN_STEPS = 10000 input_pp_rank, output_pp_rank = 0, int(parallel_context.pp_pg.size() - 1) # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 del tokenizer # Create Nanoset configs: 1. Normal 2. Blended 3. Blended with weights nanoset_config = { - "dataset_paths": [path_to_mmap_files[0]], + "dataset_folders": [datatrove_tokenized_dataset_folders[0]], "dataset_weights": [1], "sequence_length": SEQUENCE_LENGTH, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": TRAIN_STEPS * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_nanoset_config = { - "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], + "dataset_folders": datatrove_tokenized_dataset_folders, "dataset_weights": None, "sequence_length": SEQUENCE_LENGTH, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": TRAIN_STEPS * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_weighted_nanoset_config = { - "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], + "dataset_folders": datatrove_tokenized_dataset_folders, "dataset_weights": [8, 2], "sequence_length": SEQUENCE_LENGTH, - "token_dtype": token_dtype, + "token_size": token_size, "train_split_num_samples": TRAIN_STEPS * GLOBAL_BATCH_SIZE, "random_seed": SEED, } diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index 465d22f04..38db67f19 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -1,26 +1,21 @@ -import argparse -import os -import shutil -import sys +""" +To process HuggingFace Datasets: + python3 tools/preprocess_data.py --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B --output-folder datasets/emotion --n-tasks 16 hf --dataset dair-ai/emotion +To process Jsonl files: + python3 tools/preprocess_data.py --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B --output-folder datasets/c4-es --n-tasks 16 jsonl --dataset raw_datasets/c4-es-json-files +""" -import numpy as np -import torch.distributed as dist -from tqdm import tqdm -from transformers import AutoTokenizer +import argparse -from datasets import concatenate_datasets, load_dataset +from datatrove.executor.local import LocalPipelineExecutor +from datatrove.pipeline.readers import HuggingFaceDatasetReader, JsonlReader +from datatrove.pipeline.tokens import DocumentTokenizer def get_args(): parser = argparse.ArgumentParser() - group = parser.add_argument_group(title="input data") - group.add_argument( - "--input", type=str, required=True, help="Path to local stored dataset or repository on the Hugging Face hub" - ) - group.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset") - parser.add_argument("--split", type=str, default="train", help="Which split of the data to process") - group = parser.add_argument_group(title="tokenizer") + group = parser.add_argument_group(title="Tokenizer") group.add_argument( "--tokenizer-name-or-path", type=str, @@ -28,13 +23,54 @@ def get_args(): help="A path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub.", ) group.add_argument( - "--add-special-tokens", - action="store_true", - help="Whether or not to add special tokens when encoding the sequences. This will be passed to the Tokenizer", + "--eos-token", + type=str, + default=None, + help="EOS token to add after each document. Default: None", + ) + + group = parser.add_argument_group(title="Output data") + group.add_argument( + "--output-folder", type=str, required=True, help="Path to the output folder to store the tokenized documents" + ) + group = parser.add_argument_group(title="Miscellaneous configs") + group.add_argument( + "--logging-dir", + type=str, + default=None, + help="Path to a folder for storing the logs of the preprocessing step. Default: None", + ) + group.add_argument( + "--n-tasks", type=int, default=8, help="Total number of tasks to run the preprocessing step. Default: 8" + ) + # Subparsers for processing either Hugging Face datasets or jsonl files + sp = parser.add_subparsers( + dest="readers", + required=True, + description="Type of dataset to process. It can be either a Hugging Face Dataset loaded with datasets.load_data ('hf') or a .jsonl dataset ('jsonl')", + ) + + p1 = sp.add_parser(name="hf") + p1.add_argument( + "--dataset", + type=str, + required=True, + help="Path to local stored dataset or repository on the Hugging Face hub that can be loaded with datasets.load_dataset", ) + p1.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset. Default: text") + p1.add_argument("--split", type=str, default="train", help="Which split of the data to process. Default: train") - group = parser.add_argument_group(title="output data") - group.add_argument("--output-prefix", type=str, required=True, help="Path to the output processed dataset file") + p2 = sp.add_parser(name="jsonl") + p2.add_argument( + "--dataset", + type=str, + required=True, + help="Path to a .jsonl file or a folder containing multiple .jsonl files", + ) + p2.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset. Default: text") + p2.add_argument( + "--glob-pattern", type=str, default=None, help="A glob pattern to filter files to read. Default: None" + ) args = parser.parse_args() @@ -42,74 +78,32 @@ def get_args(): def main(args): - - world_size, rank = int(os.environ["WORLD_SIZE"]), int(os.environ["RANK"]) - - # Remove stdout from all processes except main to not flood the stdout - if rank: - sys.stdout = open(os.devnull, "w") - - # Check if output directory exists - if not os.path.isdir(os.path.abspath(os.path.join(args.output_prefix, os.path.pardir))): - print(f"Creating {os.path.abspath(os.path.join(args.output_prefix, os.path.pardir))} directory...") - os.makedirs(os.path.abspath(os.path.join(args.output_prefix, os.path.pardir)), exist_ok=True) - - if args.input.endswith(".json"): # For processing JSON files (Cross compatibility with other projects) - ds = load_dataset("json", data_files=args.input) - ds = concatenate_datasets( - [ds[splits] for splits in ds.keys()] - ) # load_dataset returns DatasetDict and we want a Dataset + # Build datatrove reader + if args.readers == "hf": + datatrove_reader = HuggingFaceDatasetReader( + dataset=args.dataset, + text_key=args.column, + dataset_options={"split": args.split}, + ) else: - ds = load_dataset(args.input, split=args.split) - - ds = ds.shard(num_shards=world_size, index=rank, contiguous=True) - ds = ds.select_columns(args.column) - - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path) - token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 - - # Create tmp directory for worker outputs - tmp_folder = os.path.abspath(os.path.join(args.output_prefix, os.pardir, "tmp")) - os.makedirs(tmp_folder, exist_ok=True) - - print("Creating workers output files...") - worker_output_file = os.path.join(tmp_folder, f"worker_{rank}_input_ids.npy") - ds = ds.map( - lambda x: {"input_ids": tokenizer(x, add_special_tokens=args.add_special_tokens).input_ids}, - input_columns=args.column, - batched=True, - desc="Tokenizing Dataset", - remove_columns=[args.column], + datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern) + + preprocess_executor = LocalPipelineExecutor( + pipeline=[ + datatrove_reader, + DocumentTokenizer( + output_folder=args.output_folder, + tokenizer_name_or_path=args.tokenizer_name_or_path, + eos_token=args.eos_token, + max_tokens_per_file=1e9, + ), + ], + tasks=args.n_tasks, + logging_dir=args.logging_dir, ) - - worker_input_ids_file = open(worker_output_file, "wb") - for sample in ds: - np_array = np.array(sample["input_ids"], dtype=token_dtype) - worker_input_ids_file.write(np_array.tobytes(order="C")) - worker_input_ids_file.close() - - # Wait for all workers to process each shard of the Dataset - dist.barrier() - - # Only the main rank merges the worker files - if not rank: - output_file = f"{args.output_prefix}_input_ids.npy" - input_ids_file = open(output_file, "wb") - for worker_idx in tqdm(range(world_size), desc="Merging workers output files"): - worker_output_file = os.path.join(tmp_folder, f"worker_{worker_idx}_input_ids.npy") - with open(worker_output_file, "rb") as f: - shutil.copyfileobj(f, input_ids_file) - os.remove(worker_output_file) - - input_ids_file.close() - os.rmdir(tmp_folder) - print(f"Done! {args.input} processed dataset stored in {output_file}") - - else: # Close devnull stdout redirect - sys.stdout.close() + preprocess_executor.run() if __name__ == "__main__": _args = get_args() - dist.init_process_group(backend="gloo") main(_args)