diff --git a/docs/quick-start.md b/docs/quick-start.md index e58737125..20fc1a2b1 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -492,10 +492,13 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: train_iters: 100 # (1)! logs: interval: 10 - evaluations: + evaluators: validation: - iterations: 25 interval: 100 + evaluator: + type: loss + iterations: 25 + dataset_name: validation export: # (2)! format: llama interval: 100 @@ -550,10 +553,13 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: train_iters: 100_000 # (1)! logs: interval: 10 - evaluations: + evaluators: validation: - iterations: 25 - interval: 1000 + interval: 100 + evaluator: + type: loss + iterations: 25 + dataset_name: validation checkpoint: interval: 1000 keep: 5 diff --git a/docs/recipes/continue-training.md b/docs/recipes/continue-training.md index a19d965d3..d7df7a196 100644 --- a/docs/recipes/continue-training.md +++ b/docs/recipes/continue-training.md @@ -5,11 +5,13 @@ title: Continual Pretraining of Llama 3.1 8B or Qwen 2.5 7B In this guide, we provide step-by-step instructions to do continued pretraining on The Stack with Llama 3.1 8B or Qwen 2.5 7B models. -# Preliminary steps +## Preliminary steps + - [Quick Start](../quick-start.md) - [Data preparation](data-preparation.md) -# Download the Pretrained Model +## Download the Pretrained Model + Let's download the model first: === "Llama 3.1 8B" ```bash @@ -22,21 +24,27 @@ Let's download the model first: git clone https://huggingface.co/Qwen/Qwen2.5-7B ./fast-llm-tutorial/pretrained-model ``` -# Training +## Training + This is not much different from a pretraining config. We will: + - specify the the model checkpoint to load and its format. Fast-LLM will automatically infer the corresponding model architecture. - adapt some of the training parameters for our needs. - and that's it! === "Llama 3.1 8B" + ```yaml training: train_iters: 100_000 logs: interval: 10 - evaluations: + evaluators: validation: - iterations: 25 - interval: 1000 + interval: 100 + evaluator: + type: loss + iterations: 25 + dataset_name: validation checkpoint: interval: 1000 keep: 5 @@ -55,8 +63,8 @@ This is not much different from a pretraining config. We will: path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (2)! validation: type: file - path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (2)! - optimizer: + path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (2)! + optimizer: weight_decay: 0.1 beta_1: 0.9 beta_2: 0.95 @@ -78,20 +86,24 @@ This is not much different from a pretraining config. We will: multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + training_dtype: bf16 run: experiment_dir: fast-llm-tutorial/Llama-3.1-8B-cpt ``` + === "Qwen 2.5 7B" ```yaml training: train_iters: 100_000 logs: interval: 10 - validation: - Validation: - iterations: 25 - interval: 1000 + evaluators: + validation: + interval: 100 + evaluator: + type: loss + iterations: 25 + dataset_name: validation checkpoint: interval: 1000 keep: 5 @@ -110,8 +122,8 @@ This is not much different from a pretraining config. We will: path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)! validation: type: file - path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)! - optimizer: + path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)! + optimizer: weight_decay: 0.1 beta_1: 0.9 beta_2: 0.95 @@ -133,7 +145,7 @@ This is not much different from a pretraining config. We will: multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + training_dtype: bf16 run: experiment_dir: fast-llm-tutorial/qwen-2.5-7B-cpt ``` @@ -144,7 +156,8 @@ This is not much different from a pretraining config. We will: 4. Config of the pretrained model. We load the model downloaded from the repository earlier. 5. This tells Fast-LLM to load the weights of the pretrained model. If we wanted to use the model's configuration, but train from scratch, we could use the same config but set this to `no`. -# Checkpoint usage +## Checkpoint usage + Checkpoints will be saved regularly, and every 20k steps a checkpoint will be exported in the HF format. You can use it in `transformers` as you would use the pretrained model, except this one should be stronger on programming languages! === "Llama 3.1 8B" @@ -160,4 +173,4 @@ You can use it in `transformers` as you would use the pretrained model, except tokenizer = AutoTokenizer.from_pretrained("fast-llm-tutorial/pretrained-model") pipe = pipeline("text-generation", model="fast-llm-tutorial/qwen-2.5-7B-cpt/export/qwen2/20000/", tokenizer=tokenizer) - ``` \ No newline at end of file + ``` diff --git a/docs/recipes/data-configuration.md b/docs/recipes/data-configuration.md index 23da8cc4f..eb6278550 100644 --- a/docs/recipes/data-configuration.md +++ b/docs/recipes/data-configuration.md @@ -25,13 +25,13 @@ We already saw an example dataset configuration in the [quick-start guide](../qu In this section we are interested in generalizing step 3. For more details on steps 1 and 2, please refer to the quick-start guide or [this example](data-configuration.md). -The section `data.datasets` holds descriptions of datasets used in training, validation, and testing. +The section `data.datasets` holds descriptions of datasets used in training, validation, and testing. -The Training and Testing phases must have predetermined dataset names: `training` and `testing`, respectively. Each of these phases can have only one dataset. +The Training and Testing phases must have predetermined dataset names: `training` and `testing`, respectively. Each of these phases can have only one dataset. -For validation datasets, the rules are different. There can be as many validation datasets as needed, and their names are arbitrary. In the example above, the dataset name `validation` is chosen for simplicity. The datasets names used for validation and their application details are specified in the training config `evaluations` sections. +For datasets used for loss evaluator during a validation phase, the rules are different. There can be as many such datasets as needed, and their names are arbitrary. In the example above, the dataset name `validation` is chosen for simplicity. The datasets names used for validation and their application details are specified in the training config `evaluators` sections. -Adding multiple validation datasets increases flexibility in tracking the accuracy of your trained model. One possible scenario is using a separate validation dataset for each blended training dataset, allowing you to track training progress on each subset separately and observe how the model performs in real time on different subsets of your training data. +Adding multiple datasets for loss evaluators in validation phase increases flexibility in tracking the accuracy of your trained model. One possible scenario is using a separate validation dataset for each blended training dataset, allowing you to track training progress on each subset separately and observe how the model performs in real time on different subsets of your training data. Below are examples of how to configure various aspects of training and validation datasets. @@ -128,22 +128,27 @@ data: !!! note "Default seed" In the absence of explicit seed, Fast-LLM uses a default seed (`data.sampling`'s default) instead, and uses seed shifts to ensure different seeds for each phase and for the various blended datasets. +## Example 5: Specifying Multiple Dataset for Loss Evaluators During Validation phase -## Example 5: Specifying Multiple Validation Datasets +In this example, we show how to specify multiple datasets for loss evaluators and configure how often they are applied, along with their usage attributes in the `training.evaluators` section. -In this example, we show how to specify multiple validation datasets and configure how often they are applied, along with their usage attributes in the `training.evaluations` section. - -Please note that the same dataset names must be used in the `training.evaluations` section. If a validation dataset is specified in the `datasets` section but not in `training.evaluations`, it will not be used for validation. +Please note that the same dataset names must be used in the `training.evaluators` section. If a dataset is specified in the `datasets` section but not in `training.evaluators`, it will not be used for loss evaluation. ```yaml training: - evaluations: + evaluators: the_stack: - iterations: 25 interval: 50 + evaluator: + type: loss + iterations: 25 + dataset_name: the_stack fineweb: - iterations: 25 interval: 100 + evaluator: + type: loss + iterations: 15 + dataset_name: fineweb data: datasets: the_stack: @@ -152,7 +157,7 @@ data: fineweb: type: file path: path/to/validation_fineweb_dataset.yaml - + ``` ## Example 6: Advanced scenario @@ -207,7 +212,7 @@ data: !!! note "Configure from file" If a dataset configuration is especially complex and makes the dataset configuration excessively big, or is reused across many experiments, you may want to save it to a yaml file and refer to it un the config using a `file` dataset. This can be used to reduce the present example to - + ```yaml data: datasets: diff --git a/docs/recipes/instruction-finetuning.md b/docs/recipes/instruction-finetuning.md index 15a454260..2c58a987d 100644 --- a/docs/recipes/instruction-finetuning.md +++ b/docs/recipes/instruction-finetuning.md @@ -114,10 +114,13 @@ training: train_iters: 5_000 logs: interval: 1 - evaluations: + evaluators: validation: - iterations: 25 - interval: 1000 + interval: 100 + evaluator: + type: loss + iterations: 25 + dataset_name: validation checkpoint: interval: 1000 keep: 5 diff --git a/docs/recipes/train.md b/docs/recipes/train.md index 9c5f92e5c..efdf6111b 100644 --- a/docs/recipes/train.md +++ b/docs/recipes/train.md @@ -4,13 +4,13 @@ title: Training Llama 3.1 8B Follow this guide to train a Llama-3.1 or Qwen 2.5 7B like model from scratch! +## Preliminary steps -# Preliminary steps - [Quick Start](../quick-start.md) - [Data preparation](data-preparation.md) +## Training configuration -# Training configuration In this guide, we show you how to configure a model architecture and train a model from scratch. Let's start from the following training configuration: === "Llama 3.1 8B" @@ -19,10 +19,12 @@ Let's start from the following training configuration: train_iters: 100_000 logs: interval: 10 - evaluations: - validation: - iterations: 25 - interval: 1000 + evaluators: + interval: 100 + evaluator: + type: loss + iterations: 25 + dataset_name: validation checkpoint: interval: 1000 keep: 5 @@ -68,10 +70,13 @@ Let's start from the following training configuration: train_iters: 100_000 logs: interval: 10 - evaluations: + evaluators: validation: - iterations: 25 - interval: 1000 + interval: 100 + evaluator: + type: loss + iterations: 25 + dataset_name: validation checkpoint: interval: 1000 keep: 5 @@ -133,16 +138,16 @@ By specifying a pretrained model from the HuggingFace hub, Fast-LLM automaticall === "Llama 3.1 8B" ```yaml pretrained: - format: llama + format: llama path: fast-llm-tutorial/pretrained_model - model_weights: no + model_weights: no ``` === "Qwen 2.5 7B" ```yaml pretrained: - format: qwen2 + format: qwen2 path: fast-llm-tutorial/pretrained_model - model_weights: no + model_weights: no ``` Alternatively, we define the model architecture ourselves as follows: @@ -196,4 +201,3 @@ Alternatively, we define the model architecture ourselves as follows: 1. Hidden-size/num-layers will be used to provide good defaults for weight initialization std. Configuring the model this way is a bit more verbose than using the pretrained configuration, but gives an idea of how to configure a the model with Fast-LLM. - diff --git a/examples/mistral.yaml b/examples/mistral.yaml index f1fa82795..10aa54b7f 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -3,9 +3,11 @@ training: num_workers: 8 logs: interval: 10 - evaluations: + evaluators: validation: - iterations: null + evaluator: + type: loss + iterations: null test_iters: 0 batch: sequence_length: 4096 diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 1586d370d..4c041945d 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -34,3 +34,8 @@ class TokenizerConfig(Config): desc="Path to the tokenizer file.", hint=FieldHint.core, ) + bos_token: str | None = Field( + default=None, + desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", + hint=FieldHint.core, + ) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 85bcc6561..5de5e2a2b 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -85,5 +85,4 @@ def _from_dict( assert rename not in default["datasets"] default["datasets"][rename] = default["datasets"].pop(phase.value) - cls._handle_renamed_field(default, "validation", ("evaluations", "validation")) return super()._from_dict(default, strict, flat) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index c6fece9d7..176c077a2 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -117,8 +117,10 @@ def setup( self._datasets = {} for dataset_name, sampling_parameters in self._sampling_parameters.items(): if self._tokenizer is not None: - # TODO: Too constraining? - Assert.eq(self._tokenizer.vocab_size, sampling_parameters.vocab_size) + # NOTE: Some models like Qwen2-1.5B-Instruct + # have vocab_size bigger in model config than in tokenizer + # TODO: Still, is it too constraining? + Assert.geq(sampling_parameters.vocab_size, self._tokenizer.vocab_size) if sampling_parameters.num_samples > 0: sampling = GPTSamplingData( config=self._config.sampling, diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 988e23e76..c74586207 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -1,6 +1,6 @@ import numpy as np import torch -from transformers import PreTrainedTokenizerFast +from transformers import AutoTokenizer from fast_llm.data.config import TokenizerConfig from fast_llm.engine.config_utils.run import log_main_rank @@ -13,9 +13,15 @@ class Tokenizer: def __init__(self, config: TokenizerConfig): log_main_rank(f"> loading tokenizer from {config.path} ...") - self.tokenizer: PreTrainedTokenizerFast = PreTrainedTokenizerFast.from_pretrained( - pretrained_model_name_or_path=config.path, errors="replace", max_len=None + self.tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=config.path, + errors="replace", + max_len=None, + trust_remote_code=True, + use_fast=True, ) + if config.bos_token is not None: + self.tokenizer.bos_token = config.bos_token if self.tokenizer.eos_token_id is None: raise ValueError("Tokenizer does not have an EOS token.") if self.tokenizer.bos_token_id is None: @@ -52,7 +58,7 @@ def tokenize_with_spans( token_spans = [] char_pos = 0 beginning_of_text = True - + for start, end in char_spans: if char_pos < start: curr_text = text[char_pos:start] diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py new file mode 100644 index 000000000..7223631f8 --- /dev/null +++ b/fast_llm/engine/evaluation/config.py @@ -0,0 +1,64 @@ +import abc +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLoss + + +@config_class() +class EvaluatorConfigBase(Config): + @abc.abstractmethod + def get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "Evaluator": + pass + + +@config_class(registry=True) +class EvaluatorConfig(EvaluatorConfigBase): + _abstract: typing.ClassVar[bool] = True + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + # TODO v0.x: Remove backward compatibility. + if not "type" in default: + default["type"] = "loss" + return super()._from_dict(default, strict, flat) + + +@config_class(dynamic_type={EvaluatorConfig: "loss"}) +class EvaluatorLossConfig(EvaluatorConfig): + _abstract: typing.ClassVar[bool] = False + + iterations: int | None = Field( + default=None, + desc="Number of iterations for each evaluation phase. Setting to None will disable.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + + dataset_name: str | None = Field(default=None) + + def get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "EvaluatorLoss": + from fast_llm.engine.evaluation.evaluator import EvaluatorLoss + + return EvaluatorLoss(name, self, batch_config, data_load_num_proc, train_iters) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py new file mode 100644 index 000000000..f07a8c48e --- /dev/null +++ b/fast_llm/engine/evaluation/evaluator.py @@ -0,0 +1,308 @@ +import abc +import dataclasses +import logging +import time +import typing + +from fast_llm.config import Configurable +from fast_llm.core.distributed import safe_barrier +from fast_llm.data.data.abstract import Data +from fast_llm.engine.config_utils.run import Run, log_main_rank +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase, EvaluatorLossConfig +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.engine.training.config import WandbConfig +from fast_llm.engine.training.wandb import Wandb +from fast_llm.logging import format_metrics, get_memory_usage_mib + +# from fast_llm.engine.training.lm_eval.evaluator import simple_evaluate as lm_eval_simple_evaluate + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class TrainingProgress: + done: bool + completed_steps: int + consumed_samples: int + consumed_tokens: int + + +@dataclasses.dataclass +class EvaluationMetrics: + metrics: dict[str, any] = dataclasses.field(default_factory=dict) + formatted_metrics: str | None = None + + +@dataclasses.dataclass +class EvaluatorSamplingParameters: + dataset_name: str + num_samples: int + + +class Evaluator[ConfigType: EvaluatorConfig](Configurable[ConfigType], abc.ABC): + config_class: typing.ClassVar[type[EvaluatorConfig]] = EvaluatorConfig + + _is_setup: bool = False + + def __init__( + self, + name: str, + eval_config: EvaluatorLossConfig, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ): + super().__init__(eval_config) + self._name = name + self._batch_config = batch_config + self._data_load_num_proc = data_load_num_proc + self._train_iters = train_iters + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + # TODO: check if objects passed are actually set up themselves, if appropriate + self._distributed = distributed + self._run = run + self._runner = runner + self._multi_stage = multi_stage + self._data = data + self._phase = phase + + @abc.abstractmethod + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: ... + + @abc.abstractmethod + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + """ + Returns the name and number of required samples in a dataset, + or None if the evaluation does not rely on Fast-LLM data or + if the evaluation is skipped for this run. + """ + + +class EvaluatorLoss[ConfigType: EvaluatorLossConfig](Evaluator[ConfigType]): + config_class: typing.ClassVar[type[EvaluatorLossConfig]] = EvaluatorLossConfig + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + super().setup(distributed, run, multi_stage, runner, data, phase) + + # Setup the schedule + self._schedule = Schedule( + multi_stage=self._multi_stage, + batch_config=self._batch_config, + schedule_config=runner.config, + distributed_config=distributed.config, + phase=PhaseType.validation, + ) + + self._loss_defs = self._multi_stage.base_model.loss_defs + self._evaluation_iterator = None + self._is_setup = True + + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + return EvaluatorSamplingParameters( + (self._name if self._config.dataset_name is None else self._config.dataset_name), + self._config.iterations * self._batch_config.batch_size, + ) + + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: + assert self._is_setup + if run_index is None: + run_index = 0 + + metrics = {} + formatted_metrics = None + + if self._evaluation_iterator is None: + self._evaluation_iterator = self._get_data_iterator(self._get_completed_evaluation_steps(run_index)) + # TODO: formatting metric category as Validation.evaluation_dataset_name + # maybe format each metric with evaluation_dataset_name prefix instead? + # TODO: setting performance metrics per evaluation dataset + # maybe to set aggregate performance metrics for all evaluations datasets? + phase = PhaseType.validation + metric_key = f"{phase.value}.{self._name}" + metrics[metric_key] = self._evaluate_loss( + data_iterator=self._evaluation_iterator, + phase=phase, + num_iters=self._config.iterations, + begin_iter=self._get_completed_evaluation_steps(run_index), + completed_steps=None if training_progress is None else training_progress.completed_steps, + ) + + if self._train_iters is not None: + metrics[metric_key]["train_iters"] = self._train_iters + + if training_progress is not None: + metrics[metric_key]["iteration"] = training_progress.completed_steps + metrics[metric_key]["consumed_samples"] = training_progress.consumed_samples + metrics[metric_key]["consumed_tokens"] = training_progress.consumed_tokens + + formatted_metrics = format_metrics( + metrics[metric_key], + self._loss_defs, + phase, + dataset_name=self._name, + ) + + return EvaluationMetrics(metrics, formatted_metrics) + + def _evaluate_loss( + self, + *, + data_iterator: typing.Iterator, + phase: PhaseType, + num_iters: int, + completed_steps: int | None, + begin_iter: int = 0, + ) -> dict[str, float | int]: + full_phase_name = f"{phase.value}_{self._name}" + safe_barrier(self._distributed.world_group, f"{full_phase_name} begin") + begin_time = time.perf_counter() + total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs} + for iter_ in range(num_iters): + iter_losses, _, _ = self._runner.run_step(data_iterator, self._schedule, iteration=begin_iter + iter_) + for name, value in iter_losses.items(): + total_losses[name] += value + + tensor_save_name = ( + f"{full_phase_name}_{iter_}" + if completed_steps is None + else f"{full_phase_name}_{completed_steps}_{iter_}" + ) + self._run.save_logged_tensors(tensor_save_name) + + safe_barrier( + self._distributed.world_group, + f"{full_phase_name} end", + ) + end_time = time.perf_counter() + time_per_iteration = (end_time - begin_time) / num_iters + model_tflops, hardware_tflops = self._multi_stage.get_tflops( + phase, + time_per_iteration, + self._batch_config.batch_size, + self._batch_config.sequence_length, + ) + # TODO add other relevant eval metrics + metrics = { + "batch_size": self._batch_config.batch_size, + **{name: (value / num_iters) for name, value in total_losses.items()}, + "step_time_ms": time_per_iteration * 1000, + "model_tflops": model_tflops, + "hardware_tflops": hardware_tflops, + "tokens_per_sec_per_gpu": ( + (self._batch_config.sequence_length * self._batch_config.batch_size) + / self._schedule._distributed.world_size + / time_per_iteration + ), + **get_memory_usage_mib(), + } + return metrics + + def _get_completed_evaluation_steps(self, run_index: int) -> int: + # Number of evaluations steps performed before the current step + return max(0, run_index - 1) * self.config.iterations + + def _get_data_iterator( + self, completed_steps: int = 0, prefetch_factor: int | None = None + ) -> typing.Iterator[typing.Any]: + return self._data.get_iterator( + self._batch_config, + self._name, + consumed_samples=completed_steps * self._batch_config.batch_size, + num_workers=self._data_load_num_proc, + prefetch_factor=prefetch_factor, + ) + + +# NOTE: This is not a standalone runnable; it's a submodule of Trainer used for code encapsulation. +class EvaluatorRunner: + _is_setup: bool = False + + def __init__( + self, + evaluator_configs: dict[str, EvaluatorConfigBase], + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + wandb_config: WandbConfig | None = None, + ): + self._wandb_config = wandb_config + self._evaluators = [ + eval_config.get_evaluator(name, batch_config, data_load_num_proc, train_iters) + for name, eval_config in evaluator_configs.items() + ] + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + wandb: Wandb, + phase: PhaseType, + ) -> None: + self._wandb = wandb + for evaluator in self._evaluators: + evaluator.setup(distributed, run, multi_stage, runner, data, phase) + self._is_setup = True + + def get_sampling_parameters(self) -> list[EvaluatorSamplingParameters]: + return [ + sampling_params + for sampling_params in (evaluator.get_sampling_parameters() for evaluator in self._evaluators) + if sampling_params is not None + ] + + def run( + self, + metrics: dict[str:any], + training_progress: TrainingProgress | None = None, + ): + assert self._is_setup + formatted_metrics = [] + for evaluator in self._evaluators: + evaluation_metrics = evaluator.run(training_progress) + if len(evaluation_metrics.metrics) == 0: + continue + for k, v in evaluation_metrics.metrics.items(): + metrics[k] = v + if evaluation_metrics.formatted_metrics is not None: + formatted_metrics.append(evaluation_metrics.formatted_metrics) + + if len(formatted_metrics) > 0: + formatted_metrics = "\n".join(formatted_metrics) + log_main_rank(formatted_metrics) + if self._wandb_config is not None and self._wandb_config.alert.enabled( + 0 if training_progress is None else training_progress.completed_steps + ): + self._wandb.alert("Validation results", formatted_metrics, "INFO") diff --git a/fast_llm/engine/evaluation/evaluators.py b/fast_llm/engine/evaluation/evaluators.py new file mode 100644 index 000000000..9e3b1e9b9 --- /dev/null +++ b/fast_llm/engine/evaluation/evaluators.py @@ -0,0 +1,12 @@ +from fast_llm.config import config_class +from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.training.config import TrainerConfig + + +@config_class(dynamic_type={RunnableConfig: "evaluate"}) +class EvaluatorsConfig(RunnableConfig): + + @classmethod + def parse_and_run(cls, args: list[str] | None = None) -> None: + args.append("training.train_iters=0") + return TrainerConfig.parse_and_run(args) diff --git a/fast_llm/engine/inference/config.py b/fast_llm/engine/inference/config.py index b09c88baf..b414323e4 100644 --- a/fast_llm/engine/inference/config.py +++ b/fast_llm/engine/inference/config.py @@ -108,7 +108,8 @@ def __eq__(self, other) -> bool: def to_dict(self) -> dict[str, typing.Any]: out = super().to_dict() - out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.everything) + if self.fast_llm_config is not None: + out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.everything) return out def to_diff_dict(self) -> dict[str, typing.Any]: diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 554da8cd1..3c2db428d 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -38,7 +38,13 @@ def __init__( assert config.fast_llm_config is fast_llm_model.config assert isinstance(config, self.config_class) + # The HF constructor performs a deep copy of the config, + # but config.fast_llm_config may contain non-picklable items like process groups. + # Temporarily remove it before the call and restore it afterward. + fast_llm_config = config.fast_llm_config + config.fast_llm_config = None super().__init__(config, **kwargs) + config.fast_llm_config = fast_llm_config self._inference_runner = self.runner_class(fast_llm_model, runner) diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 497d11108..00570be99 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -1,3 +1,4 @@ +import abc import dataclasses import logging import typing @@ -12,7 +13,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank, log_model_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim -from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.engine.multi_stage.fsdp import FSDP @@ -252,6 +253,11 @@ def setup(self, distributed: Distributed | None = None, mode: StageMode = StageM self.train(self._mode.support_backward) + @abc.abstractmethod + def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, sequence_length) -> tuple[int, int]: + # TODO: Do in model, automate/generalize, get other stats + pass + def _allocate_buffers( self, buffer_meta: TensorMeta, sizes: list[int], name: str ) -> tuple[tuple[torch.Tensor, ...], int]: diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index fbc94ac83..efe8f714b 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -24,6 +24,7 @@ ) from fast_llm.engine.config_utils.run import ExperimentConfig from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig @@ -32,7 +33,7 @@ if typing.TYPE_CHECKING: from fast_llm.engine.inference.runner import InferenceRunner - from fast_llm.engine.training.trainer import Trainer + from fast_llm.engine.training.trainer import Trainer, TrainingEvaluator @config_class() @@ -152,22 +153,34 @@ class WandbConfig(Config): @config_class() -class EvaluationConfig(IntervalConfig): - interval = FieldUpdate( - desc="The number of training iterations between each evaluation phase." - " Setting to None will disable evaluation." - ) - offset = FieldUpdate(desc="Offset for the first evaluation phase.") - iterations: int | None = Field( - default=None, - desc="Number of iterations for each evaluation phase. Setting to None will disable.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), - ) +class TrainingEvaluatorConfig(EvaluatorConfigBase, IntervalConfig): + evaluator: EvaluatorConfig = Field(desc="Evaluator to run") + + def get_run_count(self, training_iterations: int, extra_evaluations: int = 0): + # Number of completed evaluation runs + return (self.get_count(training_iterations) + extra_evaluations) if self.enabled() else 0 + + def get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "TrainingEvaluator": + from fast_llm.engine.training.trainer import TrainingEvaluator + + return TrainingEvaluator(name, self, batch_config, data_load_num_proc, train_iters) - def get_iteration_count(self, training_iterations: int, extra_evaluations: int = 0): - # Number of completed validation iterations - return (self.get_count(training_iterations) + extra_evaluations) * self.iterations if self.enabled() else 0 + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + # TODO v0.x: Remove backward compatibility. + cls._handle_renamed_field(default, "iterations", ("evaluator", "iterations")) + return super()._from_dict(default, strict, flat) @config_class() @@ -277,7 +290,7 @@ class ShutdownConfig(IntervalConfig): @config_class() class TrainingConfig(Config): - evaluations: dict[str, EvaluationConfig] = Field( + evaluators: dict[str, TrainingEvaluatorConfig] = Field( default_factory=dict, desc="A dictionary of evaluation dataset names and their configurations for the validation phase.", hint=FieldHint.core, @@ -325,7 +338,8 @@ def _from_dict( flat: bool = False, ) -> typing.Self: # TODO v0.x: Remove backward compatibility. - cls._handle_renamed_field(default, "validation", ("evaluations", "validation")) + cls._handle_renamed_field(default, "validation", ("evaluators", "validation")) + cls._handle_renamed_field(default, "evaluations", ("evaluators")) return super()._from_dict(default, strict, flat) def _validate(self) -> None: diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index f96b6b91b..a3cf078dc 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -15,12 +15,26 @@ from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.evaluator import ( + EvaluationMetrics, + Evaluator, + EvaluatorRunner, + EvaluatorSamplingParameters, + TrainingProgress, +) from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.engine.optimizer.optimizer import Optimizer +from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.engine.training.config import TrainerConfig, TrainingCheckpointBaseConfig, TrainingCheckpointConfig +from fast_llm.engine.training.config import ( + TrainerConfig, + TrainingCheckpointBaseConfig, + TrainingCheckpointConfig, + TrainingEvaluatorConfig, +) from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, get_memory_usage_mib, log_memory_usage from fast_llm.utils import Assert, Interrupter @@ -28,6 +42,77 @@ logger = logging.getLogger(__name__) +class TrainingEvaluator[ConfigType: TrainingEvaluatorConfig](Evaluator[ConfigType]): + config_class: typing.ClassVar[type[TrainingEvaluatorConfig]] = TrainingEvaluatorConfig + + evaluator: Evaluator + + def __init__( + self, + name: str, + eval_config: TrainingEvaluatorConfig, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ): + super().__init__(name, eval_config, batch_config, data_load_num_proc, train_iters) + + self._train_iters = 0 if self._train_iters is None else self._train_iters + + self.evaluator = eval_config.evaluator.get_evaluator(name, batch_config, data_load_num_proc, train_iters) + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + self.evaluator.setup( + distributed, + run, + multi_stage, + runner, + data, + phase, + ) + + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: + # Run index must be None because it is defined here to be passed to actual evaluator + assert run_index is None + + # Training progress can be None as it can be run in a training + # run without training, just evaluation + if training_progress is None: + done = True + completed_steps = 0 + else: + done = training_progress.done + completed_steps = training_progress.completed_steps + + if done or self.config.enabled(completed_steps): + return self.evaluator.run(training_progress, run_index=self._config.get_run_count(completed_steps - 1)) + else: + return EvaluationMetrics() + + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + name_samples = self.evaluator.get_sampling_parameters() + if name_samples is None: + return None + run_count = self._config.get_run_count( + self._train_iters, + # There may be an extra evaluation after the last training step.s + not self._config.enabled(self._train_iters), + ) + return EvaluatorSamplingParameters(name_samples.dataset_name, name_samples.num_samples * run_count) + + class Trainer[ConfigType: TrainerConfig](Configurable[ConfigType], abc.ABC): config_class: typing.ClassVar[type[TrainerConfig]] = TrainerConfig # TODO: Generalize data, schedule, logging, etc. @@ -39,13 +124,20 @@ class Trainer[ConfigType: TrainerConfig](Configurable[ConfigType], abc.ABC): _completed_steps: int + _is_evaluation_only: bool + + _evaluator_runner: EvaluatorRunner + def __init__(self, config: TrainerConfig): super().__init__(config) + + self._is_evaluation_only = config.training.train_iters == 0 + self._data = self._get_data() log_main_rank("Creating model...") self._multi_stage = self._config.model.get_model_class()( self._config.model, - optimizer_state_names=self._config.optimizer.state_names(), + optimizer_state_names=self._config.optimizer.state_names() if not self._is_evaluation_only else (), ) self._reference_models = {} for name, reference_config in self._config.reference_models.items(): @@ -55,51 +147,54 @@ def __init__(self, config: TrainerConfig): ) self._multi_stage.base_model.add_reference_model(name, self._reference_models[name]) - phase: PhaseType self._runner = ScheduleRunner( config=self._config.schedule, multi_stage=self._multi_stage, distributed_config=self._config.model.distributed, ) - steps_per_split = { - PhaseType.training: {PhaseType.training.value.lower(): self._config.training.train_iters}, - PhaseType.validation: { - dataset_name: self._config.training.evaluations[dataset_name].get_iteration_count( - self._config.training.train_iters, - # There may be an extra evaluation after the last training step. - not self._config.training.evaluations[dataset_name].enabled(self._config.training.train_iters), - ) - for dataset_name in self._config.training.evaluations.keys() - }, - PhaseType.test: {PhaseType.test.value.lower(): self._config.training.test_iters}, - } - self._samples_per_split = { - phase: { - dataset_name: self._config.batch.batch_size * steps - for dataset_name, steps in datasets.items() - if steps > 0 - } - for phase, datasets in steps_per_split.items() - } - # Prune empty phases. - self._samples_per_split = {k: v for k, v in self._samples_per_split.items() if len(v) > 0} - self._loss_defs = self._multi_stage.base_model.loss_defs - # Setup the schedules - self._schedule = { - phase: { - dataset_name: Schedule( - multi_stage=self._multi_stage, - batch_config=self._config.batch, - schedule_config=self._config.schedule, - distributed_config=self._config.model.distributed, - phase=phase, - ) - for dataset_name in datasets + if not self._is_evaluation_only: + steps_per_split = { + PhaseType.training: {PhaseType.training.value.lower(): self._config.training.train_iters}, + PhaseType.test: {PhaseType.test.value.lower(): self._config.training.test_iters}, } - for phase, datasets in self._samples_per_split.items() - } + + self._samples_per_split = { + phase: { + dataset_name: self._config.batch.batch_size * steps + for dataset_name, steps in datasets.items() + if steps > 0 + } + for phase, datasets in steps_per_split.items() + } + # Prune empty phases. + self._samples_per_split = {k: v for k, v in self._samples_per_split.items() if len(v) > 0} + + # Setup the schedules + self._schedule = { + phase: { + dataset_name: Schedule( + multi_stage=self._multi_stage, + batch_config=self._config.batch, + schedule_config=self._config.schedule, + distributed_config=self._config.model.distributed, + phase=phase, + ) + for dataset_name in datasets + } + for phase, datasets in self._samples_per_split.items() + } + else: + self._samples_per_split = {} + + self._evaluator_runner = EvaluatorRunner( + evaluator_configs=self._config.training.evaluators, + batch_config=self._config.batch, + data_load_num_proc=self._config.training.num_workers, + train_iters=self._config.training.train_iters, + wandb_config=self._config.training.wandb, + ) def setup(self, distributed: Distributed, run: Run) -> None: assert distributed.config is self._config.model.distributed @@ -118,13 +213,16 @@ def setup(self, distributed: Distributed, run: Run) -> None: reference_model.setup() # Setup the optimizer. - param_groups, grads_for_norm = self._multi_stage.get_param_groups(ParamGroup) - self._optimizer = self._config.optimizer.optimizer_cls( - self._config.optimizer, - param_groups=param_groups, - grads_for_norm=grads_for_norm, - distributed=self._distributed, - ) + if self._is_evaluation_only: + self._optimizer = None + else: + param_groups, grads_for_norm = self._multi_stage.get_param_groups(ParamGroup) + self._optimizer = self._config.optimizer.optimizer_cls( + self._config.optimizer, + param_groups=param_groups, + grads_for_norm=grads_for_norm, + distributed=self._distributed, + ) # Setup the schedules. with torch.no_grad(): @@ -137,10 +235,28 @@ def setup(self, distributed: Distributed, run: Run) -> None: dataset_name: self._get_sampling_parameters({"num_samples": samples}) for datasets in self._samples_per_split.values() for dataset_name, samples in datasets.items() + } + | { + eval_sampling_params.dataset_name: self._get_sampling_parameters( + {"num_samples": eval_sampling_params.num_samples} + ) + for eval_sampling_params in self._evaluator_runner.get_sampling_parameters() }, None if run.experiment_directory is None else run.experiment_directory / "dataset_cache", timeout=self._config.training.timeout, ) + + # Must be called with all arguments set up + self._evaluator_runner.setup( + distributed=self._distributed, + run=self._run, + multi_stage=self._multi_stage, + runner=self._runner, + data=self._data, + wandb=self._wandb, + phase=PhaseType.inference if self._is_evaluation_only else PhaseType.validation, + ) + self._is_setup = True @abc.abstractmethod @@ -162,10 +278,6 @@ def _consumed_tokens(self) -> int: assert self._is_setup return self._consumed_samples * self._config.batch.sequence_length - def _get_completed_evaluation_steps(self, dataset_name) -> int: - # Number of evaluations steps performed before the current step - return self._config.training.evaluations[dataset_name].get_iteration_count(self._completed_steps - 1) - def run(self) -> None: assert self._is_setup with self._wandb: @@ -173,10 +285,14 @@ def run(self) -> None: def _run_training(self) -> None: self._prepare_training_state() + log_main_rank("done with setup ...") log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"After initial setup", str)) self._run.save_logged_tensors("init") + if self._is_evaluation_only: + assert len(self._samples_per_split) == 0 + if PhaseType.training in self._samples_per_split: done = self._completed_steps >= self._config.training.train_iters if done: @@ -185,13 +301,15 @@ def _run_training(self) -> None: else: done, metrics = self._train() else: - done, metrics = True, {} + metrics = {} + done = True + self._evaluator_runner.run(metrics=metrics) if done and PhaseType.test in self._samples_per_split: log_main_rank(lambda: f"Running test phase ...") test_iterator = self._get_data_iterator(PhaseType.test.value.lower()) metrics_key = PhaseType.test.value - metrics[metrics_key] = self._evaluate( + metrics[metrics_key] = self._evaluate_loss( data_iterator=test_iterator, phase=PhaseType.test, num_iters=self._config.training.test_iters, @@ -220,7 +338,6 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: self._completed_steps, self._config.training.prefetch_factor, ) - evaluation_iterators = {name: None for name in self._config.training.evaluations.keys()} log_main_rank("Training ...") @@ -272,7 +389,12 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: remaining_time = average_time_per_iteration * ( self._config.training.train_iters - self._completed_steps ) - model_tflops, hardware_tflops = self.get_tflops(PhaseType.training, time_per_iteration) + model_tflops, hardware_tflops = self._multi_stage.get_tflops( + PhaseType.training, + time_per_iteration, + self._config.batch.batch_size, + self._config.batch.sequence_length, + ) metrics_key = PhaseType.training.value metrics[metrics_key] = { "train_iters": self._config.training.train_iters, @@ -318,50 +440,20 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: profiler.step() done = self._completed_steps >= self._config.training.train_iters + # TODO: Signal-based stop. + stop = done or self._config.training.shutdown.enabled(self._completed_steps) # Evaluation # TODO: Adjust valid iterator length. - if PhaseType.validation in self._samples_per_split and ( - done - or any( - evaluation_conf.enabled(self._completed_steps) - for evaluation_conf in self._config.training.evaluations.values() - ) - ): - formatted_metrics = [] - for dataset_name, evaluation_conf in self._config.training.evaluations.items(): - if not evaluation_conf.enabled(self._completed_steps): - continue - if evaluation_iterators[dataset_name] is None: - evaluation_iterators[dataset_name] = self._get_data_iterator( - dataset_name, self._get_completed_evaluation_steps(dataset_name) - ) - # TODO: formatting metric category as Validation.evaluation_dataset_name - # maybe format each metric with evaluation_dataset_name prefix instead? - # TODO: setting performance metrics per evaluation dataset - # maybe to set aggregate performance metrics for all evaluations datasets? - metric_key = f"{PhaseType.validation.value}.{dataset_name}" - metrics[metric_key] = self._evaluate( - data_iterator=evaluation_iterators[dataset_name], - phase=PhaseType.validation, - num_iters=evaluation_conf.iterations, - begin_iter=self._get_completed_evaluation_steps(dataset_name), - dataset_name=dataset_name, - ) - formatted_metrics.append( - format_metrics( - metrics[metric_key], - self._loss_defs, - PhaseType.validation, - dataset_name=dataset_name, - ) - ) - - if len(formatted_metrics) > 0: - formatted_metrics = "\n".join(formatted_metrics) - log_main_rank(formatted_metrics) - if self._config.training.wandb.alert.enabled(self._completed_steps): - self._wandb.alert("Validation results", formatted_metrics, "INFO") + self._evaluator_runner.run( + metrics=metrics, + training_progress=TrainingProgress( + done=done, + completed_steps=self._completed_steps, + consumed_samples=self._consumed_samples, + consumed_tokens=self._consumed_tokens, + ), + ) if is_main_rank() and metrics: self._wandb.log_metrics(self._completed_steps, metrics) @@ -383,55 +475,6 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: profiler.step() return done, metrics - def _evaluate( - self, - *, - data_iterator: typing.Iterator, - phase: PhaseType, - num_iters: int, - begin_iter: int = 0, - dataset_name: str | None = None, - ) -> dict[str, float | int]: - full_phase_name = phase.value if dataset_name is None else f"{phase.value}_{dataset_name}" - safe_barrier(self._distributed.world_group, f"{full_phase_name} begin") - begin_time = time.perf_counter() - total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs} - for iter_ in range(num_iters): - iter_losses, _, _ = self._runner.run_step( - data_iterator, self._schedule[phase][dataset_name], iteration=begin_iter + iter_ - ) - for name, value in iter_losses.items(): - total_losses[name] += value - self._run.save_logged_tensors(f"{full_phase_name}_{self._completed_steps}_{iter_}") - - safe_barrier( - self._distributed.world_group, - f"{full_phase_name} end", - ) - end_time = time.perf_counter() - time_per_iteration = (end_time - begin_time) / num_iters - model_tflops, hardware_tflops = self.get_tflops(phase, time_per_iteration) - # TODO add other relevant eval metrics - metrics = { - "train_iters": self._config.training.train_iters, - "batch_size": self._config.batch.batch_size, - "iteration": self._completed_steps, - **{name: (value / num_iters) for name, value in total_losses.items()}, - "consumed_samples": self._consumed_samples, - "consumed_tokens": self._consumed_tokens, - "step_time_ms": time_per_iteration * 1000, - "model_tflops": model_tflops, - "hardware_tflops": hardware_tflops, - "tokens_per_sec_per_gpu": ( - (self._config.batch.sequence_length * self._config.batch.batch_size) - / self._config.model.distributed.world_size - / time_per_iteration - ), - **get_memory_usage_mib(), - } - - return metrics - def _get_data_iterator( self, dataset_name, completed_steps: int = 0, prefetch_factor: int | None = None ) -> typing.Iterator[typing.Any]: @@ -455,9 +498,15 @@ def _prepare_training_state(self) -> None: ) self._multi_stage.load_checkpoint(self._config.pretrained) else: + if self._is_evaluation_only: + raise ValueError( + "Evaluation mode, model need to be trained first or pretrained checkpoint is provided for loading" + ) log_main_rank(f"Initializing training state from scratch...") self._multi_stage.initialize_weights() - self._optimizer.reset_state() + + if not self._is_evaluation_only: + self._optimizer.reset_state() self._completed_steps = 0 else: log_main_rank(lambda: f"Loading checkpoint from iteration {last_iteration}...") @@ -534,7 +583,8 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> config.get_load_config(checkpoint_directory, timeout=self._config.training.timeout) ) assert metadata is not None - self._optimizer.load(metadata["optimizer"]) + if not self._is_evaluation_only: + self._optimizer.load(metadata["optimizer"]) if "schedules" in metadata: # Backward compatibility. self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] @@ -561,8 +611,3 @@ def _get_last_checkpoint(self) -> int | None: iteration = -1 iteration = self._run.broadcast_int(iteration) return iteration if iteration >= 0 else None - - @abc.abstractmethod - def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: - # TODO: Do in model, automate/generalize, get other stats - pass diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 233887ec6..52637869d 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -246,12 +246,15 @@ def _logits_cross_entropy_forward_backward_split( split_size = div(target.size(0), self._cross_entropy_splits) grad_output /= self._cross_entropy_splits logit_input = input_.flatten(0, -2) - logit_input_grad = torch.empty_like(logit_input) + if self.training: + logit_input_grad = torch.empty_like(logit_input) + else: + logit_input_grad = None for logit_input_, target_, loss_mask_, logit_input_grad_ in zip( logit_input.split(split_size), target.split(split_size), [None] * self._cross_entropy_splits if loss_mask is None else loss_mask.split(split_size), - logit_input_grad.split(split_size), + logit_input_grad.split(split_size) if self.training else [None] * split_size, strict=True, ): loss_, grad_ = self._logits_cross_entropy_forward_backward( @@ -263,7 +266,8 @@ def _logits_cross_entropy_forward_backward_split( kwargs, ) # TODO: Avoid copy with explicit out argument. - logit_input_grad_.copy_(grad_) + if self.training: + logit_input_grad_.copy_(grad_) loss = loss_ if loss is None else loss + loss_ del grad_, loss_ loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1) @@ -272,7 +276,7 @@ def _logits_cross_entropy_forward_backward_split( if self._sequence_parallel_logits: # TODO: Async all_reduce(loss, group=self._tensor_space.distributed.tensor_group) - return loss, logit_input_grad.view_as(input_) + return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None def _logits_cross_entropy_forward_backward( self, @@ -358,4 +362,4 @@ def _logits_cross_entropy_forward_backward( # TODO: de-allocate earlier. del logits - return loss, output_parallel_linear_backward(grad, context) + return loss, output_parallel_linear_backward(grad, context) if self.training else None diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 9c791ba64..a70cacce6 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -92,11 +92,13 @@ _METRIC_FORMATS_KEYS = { PhaseType.training: _TRAINING_METRIC_FORMAT_KEYS, PhaseType.validation: _VALIDATION_METRIC_FORMAT_KEYS, + PhaseType.inference: _VALIDATION_METRIC_FORMAT_KEYS, PhaseType.test: _VALIDATION_METRIC_FORMAT_KEYS, } _METRIC_FORMATS = { PhaseType.training: _TRAINING_METRIC_FORMATS, PhaseType.validation: _VALIDATION_METRIC_FORMATS, + PhaseType.inference: _VALIDATION_METRIC_FORMATS, PhaseType.test: _VALIDATION_METRIC_FORMATS, } diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 3be748568..a5860096e 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -5,3 +5,5 @@ from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridSSMTrainerConfig # isort: skip + +from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 582575c01..f19ef151b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -405,6 +405,64 @@ class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel + def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, sequence_length) -> tuple[int, int]: + # TODO: Do in model, automate/generalize, get other stats + """Get tflop/s/GPU from global-batch-size and elapsed-time""" + checkpoint_activations_factor = 3 if phase == PhaseType.training else 1 + transformer_config = self._config.base_model.transformer + + consumed_tokens_per_iteration = sequence_length * batch_size + + num_transformer_layers = transformer_config.num_layers + self._config.base_model.prediction_heads - 1 + transformer_flops_base = ( + 2 * checkpoint_activations_factor * consumed_tokens_per_iteration * num_transformer_layers + ) + dense_flops_base = transformer_flops_base * transformer_config.hidden_size + # Query, key, value, dense. + flops_per_iteration = ( + 2 + * (transformer_config.num_attention_heads + transformer_config.head_groups) + * transformer_config.kv_channels + * dense_flops_base + ) + # MLP + flops_per_iteration += ( + (2 + transformer_config.gated) + * transformer_config.ffn_hidden_size + * dense_flops_base + * transformer_config.num_experts_per_token + ) + + # LM-head + flops_per_iteration += ( + 6 + * consumed_tokens_per_iteration + * transformer_config.hidden_size + * self._config.base_model.vocab_size + * self._config.base_model.prediction_heads + ) + + # Attention-matrix computation + attn_flops_base = transformer_flops_base * transformer_config.projection_size + if transformer_config.window_size is None: + # Ignore masked values (s**2/2) + attn_flops = attn_flops_base * sequence_length + model_tflops = flops_per_iteration + attn_flops + else: + # s*w - w**2/2 + attn_flops = ( + 2 + * attn_flops_base + * transformer_config.window_size + * (1 - transformer_config.window_size / 2 / sequence_length) + ) + model_tflops = flops_per_iteration + attn_flops + + # Partial recomputation (normal is 2 ops * ckpt_factor = 6, adding 1 for recomputing Q x K) + hardware_flops = flops_per_iteration + 7 / 6 * attn_flops + ratio = elapsed_time_per_iteration * self._config.distributed.world_size * 1e12 + return model_tflops / ratio, hardware_flops / ratio + class GPTInferenceRunner(InferenceRunner): model_class: typing.ClassVar[type[GPTModel]] = GPTModel diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index cc39d7f70..0b2bb3433 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -3,7 +3,6 @@ from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.gpt.config import GPTSamplingParameters -from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.training.trainer import Trainer from fast_llm.models.gpt.config import GPTTrainerConfig @@ -34,59 +33,3 @@ def _get_sampling_parameters( } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) - - def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: - # TODO: Do in model, automate/generalize, get other stats - """Get tflop/s/GPU from global-batch-size and elapsed-time""" - checkpoint_activations_factor = 3 if phase == PhaseType.training else 1 - transformer_config = self._config.model.base_model.transformer - sequence_length = self._config.batch.sequence_length - - tokens = self._config.batch.batch_size * sequence_length - num_transformer_layers = transformer_config.num_layers + self._config.model.base_model.prediction_heads - 1 - transformer_flops_base = 2 * checkpoint_activations_factor * tokens * num_transformer_layers - dense_flops_base = transformer_flops_base * transformer_config.hidden_size - # Query, key, value, dense. - flops_per_iteration = ( - 2 - * (transformer_config.num_attention_heads + transformer_config.head_groups) - * transformer_config.kv_channels - * dense_flops_base - ) - # MLP - flops_per_iteration += ( - (2 + transformer_config.gated) - * transformer_config.ffn_hidden_size - * dense_flops_base - * transformer_config.num_experts_per_token - ) - - # LM-head - flops_per_iteration += ( - 6 - * tokens - * transformer_config.hidden_size - * self._config.model.base_model.vocab_size - * self._config.model.base_model.prediction_heads - ) - - # Attention-matrix computation - attn_flops_base = transformer_flops_base * transformer_config.projection_size - if transformer_config.window_size is None: - # Ignore masked values (s**2/2) - attn_flops = attn_flops_base * sequence_length - model_tflops = flops_per_iteration + attn_flops - else: - # s*w - w**2/2 - attn_flops = ( - 2 - * attn_flops_base - * transformer_config.window_size - * (1 - transformer_config.window_size / 2 / sequence_length) - ) - model_tflops = flops_per_iteration + attn_flops - - # Partial recomputation (normal is 2 ops * ckpt_factor = 6, adding 1 for recomputing Q x K) - hardware_flops = flops_per_iteration + 7 / 6 * attn_flops - ratio = elapsed_time_per_iteration * self._config.model.distributed.world_size * 1e12 - return model_tflops / ratio, hardware_flops / ratio diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 74df75ba7..4e20e52b0 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -37,8 +37,8 @@ def test_checkpoint_and_eval(run_test_script): CONFIG_COMMON + [ "training.checkpoint.interval=1", - "training.evaluations.validation.interval=2", - "training.evaluations.validation.iterations=1", + "training.evaluators.validation.interval=2", + "training.evaluators.validation.evaluator.iterations=1", ], ) @@ -71,8 +71,8 @@ def test_resume(run_test_script): CONFIG_COMMON + [ "training.checkpoint.interval=1", - "training.evaluations.validation.interval=2", - "training.evaluations.validation.iterations=1", + "training.evaluators.validation.interval=2", + "training.evaluators.validation.evaluator.iterations=1", ], compare=f"test_{TEST_MODEL}_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, @@ -88,8 +88,8 @@ def test_resume_frozen(run_test_script): CONFIG_COMMON + [ "training.checkpoint.interval=1", - "training.evaluations.validation.interval=2", - "training.evaluations.validation.iterations=1", + "training.evaluators.validation.interval=2", + "training.evaluators.validation.evaluator.iterations=1", "model.base_model.transformer.mlp_lr_scale=0.", ], compare=f"test_{TEST_MODEL}_checkpoint_and_eval", diff --git a/tests/test_gpt_loss.py b/tests/test_gpt_loss.py new file mode 100644 index 000000000..89262eca1 --- /dev/null +++ b/tests/test_gpt_loss.py @@ -0,0 +1,121 @@ +import math + +import torch + +from fast_llm.config import NoAutoValidate +from fast_llm.data.data.gpt.data import GPTBatch +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.optimizer.config import OptimizerConfig +from fast_llm.engine.schedule.config import ScheduleConfig +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig +from tests.test_gpt_generate_and_forward import model_and_tokenizer # noqa: F401 +from tests.utils.utils import requires_cuda + + +def _get_model_runner_schedule( + model_path: str, + use_flash_attention: bool, + use_bf16: bool, + checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, + phase=PhaseType.inference, +): + assert phase == PhaseType.inference or phase == PhaseType.validation + updates = { + ("pretrained", "path"): model_path, + ("pretrained", "model_weights"): True, + ("pretrained", "format"): checkpoint_format.name, + ("model", "base_model", "cross_entropy_impl"): "fused", + ("model", "multi_stage", "zero_stage"): 2, + } + + if use_flash_attention: + updates[("model", "base_model", "transformer", "use_flash_attention")] = True + updates[("model", "distributed", "training_dtype")] = "bf16" + else: + updates[("model", "base_model", "transformer", "use_flash_attention")] = False + if use_bf16: + updates[("model", "distributed", "training_dtype")] = "bf16" + + config = PretrainedGPTModelConfig.from_dict({}, updates) + multi_stage = config.model.get_model_class()( + config.model, optimizer_state_names=OptimizerConfig.state_names() if phase == PhaseType.validation else () + ) + schedule_config = ScheduleConfig() + with NoAutoValidate(): + batch_config = GPTBatchConfig(micro_batch_size=2, sequence_length=2048, batch_size=2) + batch_config.setup(config.model.distributed) + batch_config.validate() + + schedule = Schedule( + multi_stage=multi_stage, + batch_config=batch_config, + schedule_config=schedule_config, + distributed_config=config.model.distributed, + phase=phase, + ) + + runner = ScheduleRunner( + config=schedule_config, + multi_stage=multi_stage, + distributed_config=config.model.distributed, + ) + + distributed = Distributed(config.model.distributed) + + with torch.no_grad(): + multi_stage.setup(distributed) + + with torch.no_grad(): + runner.setup(distributed) + + multi_stage.load_checkpoint(config.pretrained) + + return multi_stage, runner, schedule, batch_config + + +def _test_for_phase(model_path, fast_llm_checkpoint_format, phase): + model, runner, schedule, batch_config = _get_model_runner_schedule( + model_path, True, True, fast_llm_checkpoint_format, phase + ) + + inputs = GPTBatch( + torch.randint( + 1, + model.config.base_model.vocab_size, + [2, batch_config.sequence_length + 1], + dtype=torch.int64, + generator=torch.Generator().manual_seed(42), + ) + ) + + iteration = 1 + + # we need to set phase to validation here so preprocess would crate labels from input + # so it is the same process for validation and inference phases + # otherwise we can add labels manually after preprocess for inference phase + batch = model.base_model.preprocess(inputs, phase=PhaseType.validation, iteration=iteration) + ((inputs_, kwargs),) = batch + kwargs[LanguageModelKwargs.phase] = phase + iter_losses, _, _ = runner.run_step( + iter((((inputs_, kwargs),),)), schedule, iteration=iteration, preprocessed=True + ) + + return iter_losses + + +# @pytest.mark.extra_slow +@requires_cuda +def test_loss_validation_vs_inference(model_and_tokenizer): + model_path, _, fast_llm_checkpoint_format = model_and_tokenizer + + iter_losses_validation = _test_for_phase(model_path, fast_llm_checkpoint_format, PhaseType.validation) + + iter_losses_inference = _test_for_phase(model_path, fast_llm_checkpoint_format, PhaseType.inference) + + assert len(iter_losses_validation) == len(iter_losses_inference) + for key in iter_losses_validation.keys(): + assert math.isclose(iter_losses_validation[key], iter_losses_inference[key], rel_tol=1e-5)