From b18a8d788326caa42c10cf22e74e2f9c1688e14b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Wed, 19 Feb 2025 08:57:30 +0100
Subject: [PATCH] feat: compute eds.ner_crf loss as mean over words

---
 changelog.md                                      |  5 ++
 .../trainable/extractive_qa/extractive_qa.py      |  4 +-
 edsnlp/pipes/trainable/ner_crf/ner_crf.py         | 57 ++++++++++++-------
 edsnlp/utils/batching.py                          | 11 +++-
 4 files changed, 53 insertions(+), 24 deletions(-)

diff --git a/changelog.md b/changelog.md
index bceb2c3616..1eb1dde93f 100644
--- a/changelog.md
+++ b/changelog.md
@@ -11,6 +11,11 @@
 - Provided a [detailed tutorial](./docs/tutorials/tuning.md) on hyperparameter tuning, covering usage scenarios and configuration options.
 - `ScheduledOptimizer` (e.g., `@core: "optimizer"`) now supports importing optimizers using their qualified name (e.g., `optim: "torch.optim.Adam"`).
 
+### Changed
+
+- The loss of `eds.ner_crf` is now computed as the mean over the words instead of the sum. This change is compatible with multi-gpu training.
+- Having multiple stats keys matching a batching pattern now warns instead of raising an error.
+
 ### Fixed
 
 - Support packaging with poetry 2.0
diff --git a/edsnlp/pipes/trainable/extractive_qa/extractive_qa.py b/edsnlp/pipes/trainable/extractive_qa/extractive_qa.py
index 4cb0c5385c..d1dc883bf1 100644
--- a/edsnlp/pipes/trainable/extractive_qa/extractive_qa.py
+++ b/edsnlp/pipes/trainable/extractive_qa/extractive_qa.py
@@ -217,8 +217,9 @@ def preprocess(self, doc, **kwargs):
         questions = [x[0] for x in prompt_contexts_and_labels]
         labels = [x[1] for x in prompt_contexts_and_labels]
         ctxs = [x[2] for x in prompt_contexts_and_labels]
+        lengths = [len(ctx) for ctx in ctxs]
         return {
-            "lengths": [len(ctx) for ctx in ctxs],
+            "lengths": lengths,
             "$labels": labels,
             "$contexts": ctxs,
             "embedding": self.embedding.preprocess(
@@ -227,6 +228,7 @@
                 prompts=questions,
                 **kwargs,
             ),
+            "stats": {"ner_words": sum(lengths)},
         }
 
     def preprocess_supervised(self, doc, **kwargs):
diff --git a/edsnlp/pipes/trainable/ner_crf/ner_crf.py b/edsnlp/pipes/trainable/ner_crf/ner_crf.py
index 5c2fda0988..73ef63bab0 100644
--- a/edsnlp/pipes/trainable/ner_crf/ner_crf.py
+++ b/edsnlp/pipes/trainable/ner_crf/ner_crf.py
@@ -35,6 +35,7 @@
         "targets": NotRequired[torch.Tensor],
         "window_indices": NotRequired[torch.Tensor],
         "window_indexer": NotRequired[torch.Tensor],
+        "stats": Dict[str, int],
     },
 )
 NERBatchOutput = TypedDict(
@@ -344,10 +345,12 @@ def preprocess(self, doc, **kwargs):
         )
         ctxs = get_spans(doc, self.context_getter) if self.context_getter else [doc[:]]
+        lengths = [len(ctx) for ctx in ctxs]
         return {
             "embedding": self.embedding.preprocess(doc, contexts=ctxs, **kwargs),
-            "lengths": [len(ctx) for ctx in ctxs],
+            "lengths": lengths,
             "$contexts": ctxs,
+            "stats": {"ner_words": sum(lengths)},
         }
 
     def preprocess_supervised(self, doc, **kwargs):
@@ -389,9 +392,8 @@ def preprocess_supervised(self, doc, **kwargs):
         if discarded:
             warnings.warn(
-                f"Some spans were discarded in {doc._.note_id} ("
-                f"{', '.join(repr(d.text) for d in discarded)}) because they "
-                f"were overlapping with other spans with the same label."
+                "Some spans were discarded in the training data because they "
+                "were overlapping with other spans with the same label."
             )
         return {
@@ -402,6 +404,9 @@ def preprocess_supervised(self, doc, **kwargs):
     def collate(self, preps) -> NERBatchInput:
         collated: NERBatchInput = {
             "embedding": self.embedding.collate(preps["embedding"]),
+            "stats": {
+                k: sum(v) for k, v in preps["stats"].items() if not k.startswith("__")
+            },
         }
         lengths = [length for sample in preps["lengths"] for length in sample]
         max_len = max(lengths)
@@ -437,27 +442,37 @@ def forward(self, batch: NERBatchInput) -> NERBatchOutput:
         loss = tags = None
         if "targets" in batch:
             if self.mode == "independent":
-                loss = torch.nn.functional.cross_entropy(
-                    scores.view(-1, 5),
-                    batch["targets"].view(-1),
-                    ignore_index=-1,
-                    reduction="sum",
+                loss = (
+                    torch.nn.functional.cross_entropy(
+                        scores.view(-1, 5),
+                        batch["targets"].view(-1),
+                        ignore_index=-1,
+                        reduction="sum",
+                    )
+                    / batch["stats"]["ner_words"]
                 )
             elif self.mode == "joint":
-                loss = self.crf(
-                    scores,
-                    mask,
-                    batch["targets"].unsqueeze(-1) == torch.arange(5).to(scores.device),
-                ).sum()
-            elif self.mode == "marginal":
-                loss = torch.nn.functional.cross_entropy(
-                    self.crf.marginal(
+                loss = (
+                    self.crf(
                         scores,
                         mask,
-                    ).view(-1, 5),
-                    batch["targets"].view(-1),
-                    ignore_index=-1,
-                    reduction="sum",
+                        batch["targets"].unsqueeze(-1)
+                        == torch.arange(5).to(scores.device),
+                    ).sum()
+                    / batch["stats"]["ner_words"]
+                )
+            elif self.mode == "marginal":
+                loss = (
+                    torch.nn.functional.cross_entropy(
+                        self.crf.marginal(
+                            scores,
+                            mask,
+                        ).view(-1, 5),
+                        batch["targets"].view(-1),
+                        ignore_index=-1,
+                        reduction="sum",
+                    )
+                    / batch["stats"]["ner_words"]
                 )
         else:
             if self.window == 1:
diff --git a/edsnlp/utils/batching.py b/edsnlp/utils/batching.py
index 8a70941c38..c719e39971 100644
--- a/edsnlp/utils/batching.py
+++ b/edsnlp/utils/batching.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import (
     TYPE_CHECKING,
     Callable,
@@ -435,10 +436,16 @@ def rec(
             if exact_key is None:
                 candidates = [k for k in item if "/stats/" in k and key in k]
                 if len(candidates) != 1:
-                    raise ValueError(
-                        f"Batching key {key!r} should match exactly one "
+                    warnings.warn(
+                        f"Batching key {key!r} should match one "
                         f"candidate in {[k for k in item if '/stats/' in k]}"
                     )
+                if len(candidates) == 0:
+                    stat_keys = [k for k in item if "/stats/" in k]
+                    raise ValueError(
+                        f"Pattern {key!r} doesn't match any key in {stat_keys} "
+                        "to determine the batch size."
+                    )
                 exact_key = candidates[0]
             value = item[exact_key]
             if num_items > 0 and total + value > batch_size:
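
Reviewer note (not part of the patch): the plain-Python sketch below illustrates the effect of the loss change above. With reduction="sum", the loss of a batch grows with its number of words, so averaging per-GPU losses (as DDP does) or varying the batch size changes the gradient scale; dividing by the word count, which the patch reads from batch["stats"]["ner_words"], keeps the loss per word and therefore on a stable scale. The 5-class scores mirror the hard-coded scores.view(-1, 5) of eds.ner_crf; the shapes, word counts, and helper names here are made up for illustration only.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

def summed_loss(scores, targets):
    # pre-patch behaviour: per-batch sum of the token cross-entropy
    return F.cross_entropy(
        scores.view(-1, 5), targets.view(-1), ignore_index=-1, reduction="sum"
    )

def mean_loss(scores, targets, n_words):
    # post-patch behaviour: the same sum divided by the word count
    # (the patch takes n_words from batch["stats"]["ner_words"])
    return summed_loss(scores, targets) / n_words

# two shards with very different numbers of words, e.g. two GPUs
small = (torch.randn(1, 10, 5), torch.randint(0, 5, (1, 10)))
large = (torch.randn(1, 1000, 5), torch.randint(0, 5, (1, 1000)))

print(summed_loss(*small).item(), summed_loss(*large).item())
# the sums differ by roughly the ratio of word counts, so the larger
# shard dominates when per-GPU losses are averaged
print(mean_loss(*small, 10).item(), mean_loss(*large, 1000).item())
# the per-word means stay on a comparable scale regardless of shard size

Averaging per-shard means is not exactly the global per-word mean when shards differ in size, but it removes the dependence of the loss magnitude on the number of words per batch, which is presumably what the changelog entry means by compatibility with multi-gpu training.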