changelog.md (5 changes: 5 additions & 0 deletions)
@@ -5,6 +5,10 @@
### Added

- `edsnlp.data.read_parquet` now accepts a `work_unit="fragment"` option to split work between workers by parquet fragment instead of by row. When this is enabled, workers no longer read every fragment while skipping 1 in n rows; instead, each worker reads all rows of 1/n of the fragments, which should be faster.
- Allow passing no validation data to the `edsnlp.train` script
- Log the training config at the beginning of training
- Support a specific model output directory for training (`output_model_dir`), and whether to save the model or not (`save_model`)
- Specify whether to log the validation results or not (`logger=False`); see the sketch after this list
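
A minimal sketch of how these additions could fit together, assuming `train` and `TrainingData` are importable from `edsnlp.training`; the parquet path, converter name and pipeline setup are illustrative, and only `work_unit="fragment"`, the now-optional `val_data`, `output_model_dir`, `save_model` and `logger` arguments come from this release:

```python
import edsnlp
from edsnlp.training import TrainingData, train  # import path assumed

nlp = edsnlp.blank("eds")
# ... add and configure the components to train here ...

# work_unit="fragment": each worker reads all rows of 1/n fragments
# instead of reading every fragment and skipping 1 in n rows.
train_docs = edsnlp.data.read_parquet(
    "data/train.parquet",  # illustrative path
    converter="omop",      # illustrative converter
    work_unit="fragment",
)

train(
    nlp=nlp,
    train_data=[
        TrainingData(data=train_docs, batch_size="4 docs", shuffle="dataset"),
    ],
    # val_data can now be omitted entirely
    max_steps=100,
    # optimizer omitted for brevity; pass one as in the existing training docs
    output_dir="artifacts",
    output_model_dir="artifacts/model-last",  # new: explicit model path
    save_model=True,                          # new: set False to skip dumping weights
    logger=False,                             # new: disable the validation log table
)
```

When `output_model_dir` is left unset it defaults to `output_dir / "model-last"`, and `save_model=False` skips writing the model to disk altogether.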

### Fixed

@@ -15,6 +19,7 @@
1. reproducibility
2. in multiprocessing mode, ensure that the same data is shuffled in the same way in all workers
- Bubble BaseComponent instantiation errors correctly
- Improved support for multi-GPU gradient accumulation (gradients are only synchronized at the end of the accumulation), now controlled by the optional `sub_batch_size` argument of `TrainingData`; see the sketch below.
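
A sketch of the renamed option, mirroring the `qlf_config.yml` change at the bottom of this diff; the JSON reader, converter and pipe name are placeholders and the import path is assumed:

```python
import edsnlp
from edsnlp.training import TrainingData  # import path assumed

train_docs = edsnlp.data.read_json("data/train.json", converter="omop")  # illustrative

td = TrainingData(
    data=train_docs,
    batch_size="4 docs",        # one optimizer step per 4-document batch
    shuffle="dataset",
    sub_batch_size="10 words",  # renamed from accumulation_batch_size: each batch is
                                # split into ~10-word sub-batches whose gradients are
                                # accumulated before the optimizer step
    pipe_names=["qualifier"],   # placeholder pipe name
)
```

In multi-GPU runs, the trainer now wraps every sub-batch except the last one in `accelerator.no_sync`, so gradients are all-reduced only once per outer batch.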

## v0.14.0 (2024-11-14)

edsnlp/training/trainer.py (114 changes: 81 additions & 33 deletions)
@@ -13,6 +13,7 @@
Dict,
Iterable,
Optional,
Sequence,
Union,
)

@@ -230,7 +231,7 @@ def __init__(
data: Stream,
batch_size: BatchSizeArg,
shuffle: str,
accumulation_batch_size: Optional[BatchSizeArg] = None,
sub_batch_size: Optional[BatchSizeArg] = None,
pipe_names: Optional[Collection[str]] = None,
post_init: bool = True,
):
@@ -256,7 +257,7 @@ def __init__(
datasets), "fragment" to shuffle the fragment-based datasets
like parquet files, or a batching expression like "2000 words"
to shuffle the dataset in chunks of 2000 words.
accumulation_batch_size: Optional[BatchSizeArg]
sub_batch_size: Optional[BatchSizeArg]
How to split each batch into sub-batches that will be fed to
the model independently to accumulate gradients over.
pipe_names: Optional[Collection[str]]
@@ -269,7 +270,7 @@
self.data = data
self.batch_size = batch_size
self.shuffle = shuffle
self.accumulation_batch_size = accumulation_batch_size
self.sub_batch_size = sub_batch_size
self.pipe_names = set(pipe_names) if pipe_names else None
self.post_init = post_init

@@ -282,25 +283,46 @@ def __call__(self, nlp, device):
data = data.map(nlp.preprocess, kwargs=dict(supervision=True))
batcher = stat_batchify(self.batch_size[1] or "docs")
data = data.batchify(batch_size=self.batch_size[0], batch_by=batcher)
if self.accumulation_batch_size:
sub_batcher = stat_batchify(self.accumulation_batch_size[1] or "docs")
if self.sub_batch_size:
sub_batcher = stat_batchify(self.sub_batch_size[1] or "docs")
data = data.map(
lambda batch: [
nlp.collate(sub_batch)
for sub_batch in sub_batcher(batch, self.accumulation_batch_size[0])
nlp.collate(sub_batch, device=device)
for sub_batch in sub_batcher(batch, self.sub_batch_size[0])
]
)
else:
data = data.map(nlp.collate, kwargs=dict(device=device))
return data


class PipeDict(torch.nn.ModuleDict):
def __init__(self, pipes, loss_scales):
super().__init__(pipes)
self.loss_scales = loss_scales

def forward(self, batch, enable: Optional[Sequence[str]] = None):
loss = None
all_results = {}
for name, pipe in self.items():
if enable is None or name in enable:
res = pipe(batch[name])
all_results[name] = res
if "loss" in res:
res["loss"] = res["loss"] * self.loss_scales.get(name, 1)
loss = res["loss"] if loss is None else loss + res["loss"]
if torch.isnan(loss):
raise ValueError(f"NaN loss at component {name}")
res[f"{name}_loss"] = res["loss"]
return all_results, loss


@validate_arguments(registry=registry)
def train(
*,
nlp: Pipeline,
train_data: AsList[TrainingData],
val_data: AsList[Stream],
val_data: AsList[Stream] = [],
seed: int = 42,
max_steps: int = 1000,
optimizer: Union[ScheduledOptimizer, torch.optim.Optimizer] = None,
@@ -313,6 +335,10 @@ def train(
cpu: bool = False,
mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no",
output_dir: Union[Path, str] = Path("artifacts"),
output_model_dir: Optional[Union[Path, str]] = None,
save_model: bool = True,
logger: bool = True,
config_meta: Optional[Dict] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -385,6 +411,15 @@ def train(
The output directory, which will contain a `model-last` directory
with the last model, and a `train_metrics.json` file with the
training metrics and stats.
output_model_dir: Optional[Union[Path, str]]
The directory in which to save the model. If None, defaults to
`output_dir / "model-last"`.
save_model: bool
Whether to save the model or not. This can be useful if you are only
interested in the metrics, but not the model, and want to avoid
spending time dumping the model weights to disk.
logger: bool
Whether to log the validation metrics in a rich table.
kwargs: Dict
Additional keyword arguments.

@@ -398,13 +433,15 @@
# accelerator.register_for_checkpointing(dataset)
is_main_process = accelerator.is_main_process
device = accelerator.device
print("Starting training on device:", device)

output_dir = Path(output_dir or Path.cwd() / "artifacts")
model_path = output_dir / "model-last"
output_model_dir = output_model_dir or output_dir / "model-last"
train_metrics_path = output_dir / "train_metrics.json"
if is_main_process:
os.makedirs(output_dir, exist_ok=True)
if config_meta is not None: # pragma: no cover
print(config_meta["unresolved_config"].to_yaml_str())
config_meta["unresolved_config"].to_disk(output_dir / "training_config.yml")

validation_interval = validation_interval or max_steps // 10
checkpoint_interval = checkpoint_interval or validation_interval
@@ -443,8 +480,8 @@ def train(
nlp.post_init(chain_zip([td.data for td in train_data if td.post_init]))

for phase_i, pipe_names in enumerate(phases):
trained_pipes = [nlp.get_pipe(name) for name in pipe_names]
trained_pipes_params = {p for pipe in trained_pipes for p in pipe.parameters()}
trained_pipes = PipeDict({n: nlp.get_pipe(n) for n in pipe_names}, loss_scales)
trained_pipes_params = set(trained_pipes.parameters())
phase_training_data = [
td
for td in train_data
@@ -492,7 +529,7 @@ def train(
)
)
)
(accel_optim, *trained_pipes) = accelerator.prepare(optim, *trained_pipes)
(accel_optim, trained_pipes) = accelerator.prepare(optim, trained_pipes)
if hasattr(accel_optim.optimizer, "initialize"):
accel_optim.optimizer.initialize()

@@ -501,7 +538,7 @@ def train(
set_seed(seed)
with (
RichTablePrinter(LOGGER_FIELDS, auto_refresh=False)
if is_main_process
if is_main_process and logger
else nullcontext()
) as logger:
# Training loop
@@ -526,10 +563,15 @@
)
cumulated_data.clear()
train_metrics_path.write_text(json.dumps(all_metrics, indent=2))
logger.log_metrics(flatten_dict(all_metrics[-1]))
if logger:
logger.log_metrics(flatten_dict(all_metrics[-1]))

if is_main_process and (step % checkpoint_interval) == 0:
nlp.to_disk(model_path)
if (
save_model
and is_main_process
and (step % checkpoint_interval) == 0
):
nlp.to_disk(output_model_dir)

if step == max_steps:
break
@@ -559,28 +601,34 @@ def train(
set_flat_stats(b, batch_stats)

res_stats = defaultdict(lambda: 0.0)
for batch, batch_pipe_names in zip(batches, batches_pipe_names):
loss = torch.zeros((), device=accelerator.device)
with nlp.cache():
for name, pipe in zip(pipe_names, trained_pipes):
if name not in batch_pipe_names:
continue
res = dict(pipe(batch[name]))
if "loss" in res:
res["loss"] = res["loss"] * loss_scales.get(name, 1)
loss += res["loss"]
res[f"{name}_loss"] = res["loss"]
for idx, (batch, batch_pipe_names) in enumerate(
zip(batches, batches_pipe_names)
):
cache_ctx = (
nlp.cache() if len(batch_pipe_names) > 1 else nullcontext()
)
no_sync_ctx = (
accelerator.no_sync(trained_pipes)
if idx < len(batches) - 1
else nullcontext()
)
with cache_ctx, no_sync_ctx:
all_res, loss = trained_pipes(
batch,
enable=batch_pipe_names,
)
for name, res in all_res.items():
for k, v in res.items():
if (
isinstance(v, float)
isinstance(v, (float, int))
or isinstance(v, torch.Tensor)
and v.ndim == 0
):
res_stats[k] += float(v)
if torch.isnan(loss):
raise ValueError(f"NaN loss at component {name}")
del k, v, res, pipe
accelerator.backward(loss)
del k, v
del res
del all_res
accelerator.backward(loss)
del loss

# Sync output stats after forward such as losses, supports, etc.
tests/training/qlf_config.yml (2 changes: 1 addition & 1 deletion)
@@ -81,7 +81,7 @@ train_data:
shuffle: dataset
batch_size: 4 docs
pipe_names: [ "qualifier" ]
accumulation_batch_size: 10 words
sub_batch_size: 10 words

val_data:
"@readers": json