From ebe8cddd2a08c70b8c461e9ec5448c057a29e07b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Thu, 28 Nov 2024 16:49:09 +0100
Subject: [PATCH 1/2] feat: added more options to the trainer

---
 changelog.md               |  4 ++++
 edsnlp/training/trainer.py | 35 +++++++++++++++++++++++++++--------
 2 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/changelog.md b/changelog.md
index e46fa5761..773bb8ff1 100644
--- a/changelog.md
+++ b/changelog.md
@@ -5,6 +5,10 @@
 ### Added

 - `edsnlp.data.read_parquet` now accept a `work_unit="fragment"` option to split tasks between workers by parquet fragment instead of row. When this is enabled, workers do not read every fragment while skipping 1 in n rows, but read all rows of 1/n fragments, which should be faster.
+- Accept no validation data in the `edsnlp.train` script
+- Log the training config at the beginning of trainings
+- Support a specific model output directory path for trainings (`output_model_dir`), and whether to save the model or not (`save_model`)
+- Specify whether to log the validation results or not (`logger=False`)

 ### Fixed

diff --git a/edsnlp/training/trainer.py b/edsnlp/training/trainer.py
index 8e819702d..8abb12612 100644
--- a/edsnlp/training/trainer.py
+++ b/edsnlp/training/trainer.py
@@ -300,7 +300,7 @@ def train(
     *,
     nlp: Pipeline,
     train_data: AsList[TrainingData],
-    val_data: AsList[Stream],
+    val_data: AsList[Stream] = [],
     seed: int = 42,
     max_steps: int = 1000,
     optimizer: Union[ScheduledOptimizer, torch.optim.Optimizer] = None,
@@ -313,6 +313,10 @@ def train(
     cpu: bool = False,
     mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no",
     output_dir: Union[Path, str] = Path("artifacts"),
+    output_model_dir: Optional[Union[Path, str]] = None,
+    save_model: bool = True,
+    logger: bool = True,
+    config_meta: Dict,
     **kwargs,
 ):
     """
@@ -385,6 +389,15 @@ def train(
         The output directory, which will contain a `model-last` directory
         with the last model, and a `train_metrics.json` file with the
         training metrics and stats.
+    output_model_dir: Optional[Union[Path, str]]
+        The directory in which to save the model. If None, defaults to
+        `output_dir / "model-last"`.
+    save_model: bool
+        Whether to save the model or not. This can be useful if you are only
+        interested in the metrics, but not the model, and want to avoid
+        spending time dumping the model weights to disk.
+    logger: bool
+        Whether to log the validation metrics in a rich table.
     kwargs: Dict
         Additional keyword arguments.

@@ -398,13 +411,14 @@ def train(
     # accelerator.register_for_checkpointing(dataset)
     is_main_process = accelerator.is_main_process
     device = accelerator.device
-    print("Starting training on device:", device)
+    accelerator.print(config_meta["unresolved_config"].to_yaml_str())

     output_dir = Path(output_dir or Path.cwd() / "artifacts")
-    model_path = output_dir / "model-last"
+    output_model_dir = output_model_dir or output_dir / "model-last"
     train_metrics_path = output_dir / "train_metrics.json"
     if is_main_process:
         os.makedirs(output_dir, exist_ok=True)
+        config_meta["unresolved_config"].to_disk(output_dir / "training_config.yml")

     validation_interval = validation_interval or max_steps // 10
     checkpoint_interval = checkpoint_interval or validation_interval
@@ -501,7 +515,7 @@ def train(
         set_seed(seed)
         with (
             RichTablePrinter(LOGGER_FIELDS, auto_refresh=False)
-            if is_main_process
+            if is_main_process and logger
             else nullcontext()
         ) as logger:
             # Training loop
@@ -526,10 +540,15 @@ def train(
                     )
                     cumulated_data.clear()
                     train_metrics_path.write_text(json.dumps(all_metrics, indent=2))
-                    logger.log_metrics(flatten_dict(all_metrics[-1]))
+                    if logger:
+                        logger.log_metrics(flatten_dict(all_metrics[-1]))

-                if is_main_process and (step % checkpoint_interval) == 0:
-                    nlp.to_disk(model_path)
+                if (
+                    save_model
+                    and is_main_process
+                    and (step % checkpoint_interval) == 0
+                ):
+                    nlp.to_disk(output_model_dir)
                 if step == max_steps:
                     break

@@ -572,7 +591,7 @@ def train(
                                 res[f"{name}_loss"] = res["loss"]
                             for k, v in res.items():
                                 if (
-                                    isinstance(v, float)
+                                    isinstance(v, (float, int))
                                     or isinstance(v, torch.Tensor)
                                     and v.ndim == 0
                                 ):

From 5e8fe5dd0ff901a2a4d4e0cb217ebc8ffeed7695 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Thu, 28 Nov 2024 16:57:55 +0100
Subject: [PATCH 2/2] fix: improve gradient accumulation support

---
 changelog.md                  |  1 +
 edsnlp/training/trainer.py    | 85 +++++++++++++++++++++++------------
 tests/training/qlf_config.yml |  2 +-
 3 files changed, 59 insertions(+), 29 deletions(-)

diff --git a/changelog.md b/changelog.md
index 773bb8ff1..4e2f7f5f5 100644
--- a/changelog.md
+++ b/changelog.md
@@ -19,6 +19,7 @@
   1. reproducibility
   2. in multiprocessing mode, ensure that the same data is shuffled in the same way in all workers
 - Bubble BaseComponent instantiation errors correctly
+- Improved support for multi-GPU gradient accumulation (only sync the gradients at the end of the accumulation), now controlled by the optional `sub_batch_size` argument of `TrainingData`.

 ## v0.14.0 (2024-11-14)

diff --git a/edsnlp/training/trainer.py b/edsnlp/training/trainer.py
index 8abb12612..56ac7f0af 100644
--- a/edsnlp/training/trainer.py
+++ b/edsnlp/training/trainer.py
@@ -13,6 +13,7 @@
     Dict,
     Iterable,
     Optional,
+    Sequence,
     Union,
 )

@@ -230,7 +231,7 @@ def __init__(
         data: Stream,
         batch_size: BatchSizeArg,
         shuffle: str,
-        accumulation_batch_size: Optional[BatchSizeArg] = None,
+        sub_batch_size: Optional[BatchSizeArg] = None,
         pipe_names: Optional[Collection[str]] = None,
         post_init: bool = True,
     ):
@@ -256,7 +257,7 @@ def __init__(
             datasets), "fragment" to shuffle the fragment-based datasets like
             parquet files, or a batching expression like "2000 words" to shuffle
             the dataset in chunks of 2000 words.
-        accumulation_batch_size: Optional[BatchSizeArg]
+        sub_batch_size: Optional[BatchSizeArg]
             How to split each batch into sub-batches that will be fed to the
             model independently to accumulate gradients over.
         pipe_names: Optional[Collection[str]]
@@ -269,7 +270,7 @@ def __init__(
         self.data = data
         self.batch_size = batch_size
         self.shuffle = shuffle
-        self.accumulation_batch_size = accumulation_batch_size
+        self.sub_batch_size = sub_batch_size
         self.pipe_names = set(pipe_names) if pipe_names else None
         self.post_init = post_init

@@ -282,12 +283,12 @@ def __call__(self, nlp, device):
         data = data.map(nlp.preprocess, kwargs=dict(supervision=True))
         batcher = stat_batchify(self.batch_size[1] or "docs")
         data = data.batchify(batch_size=self.batch_size[0], batch_by=batcher)
-        if self.accumulation_batch_size:
-            sub_batcher = stat_batchify(self.accumulation_batch_size[1] or "docs")
+        if self.sub_batch_size:
+            sub_batcher = stat_batchify(self.sub_batch_size[1] or "docs")
             data = data.map(
                 lambda batch: [
-                    nlp.collate(sub_batch)
-                    for sub_batch in sub_batcher(batch, self.accumulation_batch_size[0])
+                    nlp.collate(sub_batch, device=device)
+                    for sub_batch in sub_batcher(batch, self.sub_batch_size[0])
                 ]
             )
         else:
@@ -295,6 +296,27 @@ def __call__(self, nlp, device):
         return data


+class PipeDict(torch.nn.ModuleDict):
+    def __init__(self, pipes, loss_scales):
+        super().__init__(pipes)
+        self.loss_scales = loss_scales
+
+    def forward(self, batch, enable: Optional[Sequence[str]] = None):
+        loss = None
+        all_results = {}
+        for name, pipe in self.items():
+            if enable is None or name in enable:
+                res = pipe(batch[name])
+                all_results[name] = res
+                if "loss" in res:
+                    res["loss"] = res["loss"] * self.loss_scales.get(name, 1)
+                    loss = res["loss"] if loss is None else loss + res["loss"]
+                    if torch.isnan(loss):
+                        raise ValueError(f"NaN loss at component {name}")
+                    res[f"{name}_loss"] = res["loss"]
+        return all_results, loss
+
+
 @validate_arguments(registry=registry)
 def train(
     *,
@@ -316,7 +338,7 @@ def train(
     output_model_dir: Optional[Union[Path, str]] = None,
     save_model: bool = True,
     logger: bool = True,
-    config_meta: Dict,
+    config_meta: Optional[Dict] = None,
     **kwargs,
 ):
     """
@@ -411,14 +433,15 @@ def train(
     # accelerator.register_for_checkpointing(dataset)
     is_main_process = accelerator.is_main_process
     device = accelerator.device
-    accelerator.print(config_meta["unresolved_config"].to_yaml_str())

     output_dir = Path(output_dir or Path.cwd() / "artifacts")
     output_model_dir = output_model_dir or output_dir / "model-last"
     train_metrics_path = output_dir / "train_metrics.json"
     if is_main_process:
         os.makedirs(output_dir, exist_ok=True)
-        config_meta["unresolved_config"].to_disk(output_dir / "training_config.yml")
+        if config_meta is not None:  # pragma: no cover
+            print(config_meta["unresolved_config"].to_yaml_str())
+            config_meta["unresolved_config"].to_disk(output_dir / "training_config.yml")

     validation_interval = validation_interval or max_steps // 10
     checkpoint_interval = checkpoint_interval or validation_interval
@@ -457,8 +480,8 @@ def train(
         nlp.post_init(chain_zip([td.data for td in train_data if td.post_init]))

     for phase_i, pipe_names in enumerate(phases):
-        trained_pipes = [nlp.get_pipe(name) for name in pipe_names]
-        trained_pipes_params = {p for pipe in trained_pipes for p in pipe.parameters()}
+        trained_pipes = PipeDict({n: nlp.get_pipe(n) for n in pipe_names}, loss_scales)
+        trained_pipes_params = set(trained_pipes.parameters())
         phase_training_data = [
             td
             for td in train_data
@@ -506,7 +529,7 @@ def train(
                 )
             )
         )
-        (accel_optim, *trained_pipes) = accelerator.prepare(optim, *trained_pipes)
+        (accel_optim, trained_pipes) = accelerator.prepare(optim, trained_pipes)

         if hasattr(accel_optim.optimizer, "initialize"):
             accel_optim.optimizer.initialize()
@@ -578,17 +601,23 @@ def train(
                     set_flat_stats(b, batch_stats)

                 res_stats = defaultdict(lambda: 0.0)
-                for batch, batch_pipe_names in zip(batches, batches_pipe_names):
-                    loss = torch.zeros((), device=accelerator.device)
-                    with nlp.cache():
-                        for name, pipe in zip(pipe_names, trained_pipes):
-                            if name not in batch_pipe_names:
-                                continue
-                            res = dict(pipe(batch[name]))
-                            if "loss" in res:
-                                res["loss"] = res["loss"] * loss_scales.get(name, 1)
-                                loss += res["loss"]
-                                res[f"{name}_loss"] = res["loss"]
+                for idx, (batch, batch_pipe_names) in enumerate(
+                    zip(batches, batches_pipe_names)
+                ):
+                    cache_ctx = (
+                        nlp.cache() if len(batch_pipe_names) > 1 else nullcontext()
+                    )
+                    no_sync_ctx = (
+                        accelerator.no_sync(trained_pipes)
+                        if idx < len(batches) - 1
+                        else nullcontext()
+                    )
+                    with cache_ctx, no_sync_ctx:
+                        all_res, loss = trained_pipes(
+                            batch,
+                            enable=batch_pipe_names,
+                        )
+                        for name, res in all_res.items():
                             for k, v in res.items():
                                 if (
                                     isinstance(v, (float, int))
                                     or isinstance(v, torch.Tensor)
                                     and v.ndim == 0
                                 ):
                                     res_stats[k] += float(v)
-                            if torch.isnan(loss):
-                                raise ValueError(f"NaN loss at component {name}")
-                            del k, v, res, pipe
-                    accelerator.backward(loss)
+                            del k, v
+                            del res
+                        del all_res
+                        accelerator.backward(loss)
                 del loss

                 # Sync output stats after forward such as losses, supports, etc.
diff --git a/tests/training/qlf_config.yml b/tests/training/qlf_config.yml
index 884a8e349..afd6da65e 100644
--- a/tests/training/qlf_config.yml
+++ b/tests/training/qlf_config.yml
@@ -81,7 +81,7 @@ train_data:
   shuffle: dataset
   batch_size: 4 docs
   pipe_names: [ "qualifier" ]
-  accumulation_batch_size: 10 words
+  sub_batch_size: 10 words

 val_data:
   "@readers": json
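
Taken together, the two patches change how a training run is configured: validation data becomes optional, the model output location and saving are controlled by `output_model_dir` and `save_model`, the rich validation table can be turned off with `logger=False`, and gradient accumulation is driven by the `sub_batch_size` argument of `TrainingData`. The sketch below is only an illustration of how the updated `train` entry point might be called from Python: the pipeline, reader, paths and batch-size values are placeholders, and the training-data entry is passed as a plain mapping on the assumption that the validated entry point parses it the same way as the YAML configs (compare `tests/training/qlf_config.yml`).

```python
import edsnlp
from edsnlp.training.trainer import train

# Placeholder pipeline: add whichever trainable components you actually train
# (e.g. a NER or qualifier pipe); this part is not prescribed by the patch.
nlp = edsnlp.blank("eds")

# Any edsnlp reader returning a stream of annotated docs works here;
# the path and format are illustrative only.
train_docs = edsnlp.data.read_standoff("corpus/train/")

train(
    nlp=nlp,
    # Training data entries mirror the YAML config: batch size expressions
    # such as "4 docs" or "10 words" are parsed by the entry point.
    train_data=[
        dict(
            data=train_docs,
            batch_size="4 docs",
            sub_batch_size="10 words",  # sub-batches to accumulate gradients over
            shuffle="dataset",
        )
    ],
    # val_data can now be omitted entirely: it defaults to an empty list.
    max_steps=100,
    output_dir="artifacts",                   # train_metrics.json is written here
    output_model_dir="artifacts/model-last",  # explicit model path (same as the default)
    save_model=True,    # set to False to skip dumping weights to disk
    logger=False,       # disable the rich validation table
)
```

Note that `config_meta` now defaults to `None`, so the config print and the `training_config.yml` dump are skipped in a direct call like this one; they only happen when the training script is launched from a config file that provides it.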