5 changes: 5 additions & 0 deletions changelog.md
@@ -11,6 +11,11 @@
 - Provided a [detailed tutorial](./docs/tutorials/tuning.md) on hyperparameter tuning, covering usage scenarios and configuration options.
 - `ScheduledOptimizer` (e.g., `@core: "optimizer"`) now supports importing optimizers using their qualified name (e.g., `optim: "torch.optim.Adam"`).
 
+### Changed
+
+- The loss of `eds.ner_crf` is now computed as the mean over the words instead of the sum. This change is compatible with multi-gpu training.
+- Having multiple stats keys matching a batching pattern now warns instead of raising an error.
+
 ### Fixed
 
 - Support packaging with poetry 2.0
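To illustrate the `eds.ner_crf` changelog entry above, here is a minimal sketch in plain PyTorch (not the edsnlp API) of why dividing the summed token loss by the number of words keeps the loss and gradient scales comparable when different batches or replicas contain very different numbers of words:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Two simulated batches with very different numbers of words.
logits_a, targets_a = torch.randn(10, 5), torch.randint(0, 5, (10,))    # 10 words
logits_b, targets_b = torch.randn(200, 5), torch.randint(0, 5, (200,))  # 200 words

# With reduction="sum", the loss scales with the word count, so gradients from
# the larger batch dominate once they are averaged across devices.
sum_a = F.cross_entropy(logits_a, targets_a, reduction="sum")
sum_b = F.cross_entropy(logits_b, targets_b, reduction="sum")

# Dividing by the word count (what the PR does via batch["stats"]["ner_words"])
# brings both back to a comparable per-word scale.
print(sum_a.item(), sum_b.item())                 # very different magnitudes
print((sum_a / 10).item(), (sum_b / 200).item())  # comparable magnitudes
```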
4 changes: 3 additions & 1 deletion edsnlp/pipes/trainable/extractive_qa/extractive_qa.py
@@ -217,8 +217,9 @@ def preprocess(self, doc, **kwargs):
         questions = [x[0] for x in prompt_contexts_and_labels]
         labels = [x[1] for x in prompt_contexts_and_labels]
         ctxs = [x[2] for x in prompt_contexts_and_labels]
+        lengths = [len(ctx) for ctx in ctxs]
         return {
-            "lengths": [len(ctx) for ctx in ctxs],
+            "lengths": lengths,
             "$labels": labels,
             "$contexts": ctxs,
             "embedding": self.embedding.preprocess(
@@ -227,6 +228,7 @@ def preprocess(self, doc, **kwargs):
                 prompts=questions,
                 **kwargs,
             ),
+            "stats": {"ner_words": sum(lengths)},
         }
 
     def preprocess_supervised(self, doc, **kwargs):
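The change above makes `preprocess` report a per-document `ner_words` count next to its existing outputs. A rough sketch of that plumbing, with hypothetical helper names rather than the edsnlp internals: each document contributes its word count, and the collate step sums the counts over the batch.

```python
from typing import Dict, List


def preprocess_stats(context_lengths: List[int]) -> Dict[str, int]:
    # Mirrors the added `"stats": {"ner_words": sum(lengths)}` entry.
    return {"ner_words": sum(context_lengths)}


def collate_stats(per_doc_stats: List[Dict[str, int]]) -> Dict[str, int]:
    # Sum each stat over the documents of a batch, as the ner_crf collate does.
    keys = {k for stats in per_doc_stats for k in stats}
    return {k: sum(stats.get(k, 0) for stats in per_doc_stats) for k in keys}


docs = [preprocess_stats([12, 7]), preprocess_stats([30])]
print(collate_stats(docs))  # {'ner_words': 49}
```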
57 changes: 36 additions & 21 deletions edsnlp/pipes/trainable/ner_crf/ner_crf.py
@@ -35,6 +35,7 @@
         "targets": NotRequired[torch.Tensor],
         "window_indices": NotRequired[torch.Tensor],
         "window_indexer": NotRequired[torch.Tensor],
+        "stats": Dict[str, int],
     },
 )
 NERBatchOutput = TypedDict(
@@ -344,10 +345,12 @@ def preprocess(self, doc, **kwargs):
            )
 
        ctxs = get_spans(doc, self.context_getter) if self.context_getter else [doc[:]]
+        lengths = [len(ctx) for ctx in ctxs]
        return {
            "embedding": self.embedding.preprocess(doc, contexts=ctxs, **kwargs),
-            "lengths": [len(ctx) for ctx in ctxs],
+            "lengths": lengths,
            "$contexts": ctxs,
+            "stats": {"ner_words": sum(lengths)},
        }
 
    def preprocess_supervised(self, doc, **kwargs):
@@ -389,9 +392,8 @@ def preprocess_supervised(self, doc, **kwargs):
 
        if discarded:
            warnings.warn(
-                f"Some spans were discarded in {doc._.note_id} ("
-                f"{', '.join(repr(d.text) for d in discarded)}) because they "
-                f"were overlapping with other spans with the same label."
+                "Some spans were discarded in the training data because they "
+                "were overlapping with other spans with the same label."
            )
 
        return {
@@ -402,6 +404,9 @@
    def collate(self, preps) -> NERBatchInput:
        collated: NERBatchInput = {
            "embedding": self.embedding.collate(preps["embedding"]),
+            "stats": {
+                k: sum(v) for k, v in preps["stats"].items() if not k.startswith("__")
+            },
        }
        lengths = [length for sample in preps["lengths"] for length in sample]
        max_len = max(lengths)
@@ -437,27 +442,37 @@ def forward(self, batch: NERBatchInput) -> NERBatchOutput:
        loss = tags = None
        if "targets" in batch:
            if self.mode == "independent":
-                loss = torch.nn.functional.cross_entropy(
-                    scores.view(-1, 5),
-                    batch["targets"].view(-1),
-                    ignore_index=-1,
-                    reduction="sum",
+                loss = (
+                    torch.nn.functional.cross_entropy(
+                        scores.view(-1, 5),
+                        batch["targets"].view(-1),
+                        ignore_index=-1,
+                        reduction="sum",
+                    )
+                    / batch["stats"]["ner_words"]
                )
            elif self.mode == "joint":
-                loss = self.crf(
-                    scores,
-                    mask,
-                    batch["targets"].unsqueeze(-1) == torch.arange(5).to(scores.device),
-                ).sum()
+                loss = (
+                    self.crf(
+                        scores,
+                        mask,
+                        batch["targets"].unsqueeze(-1)
+                        == torch.arange(5).to(scores.device),
+                    ).sum()
+                    / batch["stats"]["ner_words"]
+                )
            elif self.mode == "marginal":
-                loss = torch.nn.functional.cross_entropy(
-                    self.crf.marginal(
-                        scores,
-                        mask,
-                    ).view(-1, 5),
-                    batch["targets"].view(-1),
-                    ignore_index=-1,
-                    reduction="sum",
-                )
+                loss = (
+                    torch.nn.functional.cross_entropy(
+                        self.crf.marginal(
+                            scores,
+                            mask,
+                        ).view(-1, 5),
+                        batch["targets"].view(-1),
+                        ignore_index=-1,
+                        reduction="sum",
+                    )
+                    / batch["stats"]["ner_words"]
+                )
        else:
            if self.window == 1:
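One detail of the `forward` change above that a worked example makes concrete: targets are padded to the longest context with `-1`, which `ignore_index=-1` drops from the summed cross-entropy, so the division should use the real word count carried by `stats["ner_words"]` rather than the padded tensor size. A standalone sketch in plain PyTorch with assumed shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch, max_len, n_tags = 2, 6, 5
scores = torch.randn(batch, max_len, n_tags)

# Real lengths are 6 and 3; padded positions get target -1.
targets = torch.randint(0, n_tags, (batch, max_len))
targets[1, 3:] = -1
ner_words = int((targets != -1).sum())  # 9 real words, i.e. the "ner_words" stat

summed = F.cross_entropy(
    scores.view(-1, n_tags),
    targets.view(-1),
    ignore_index=-1,   # padded positions contribute nothing to the sum
    reduction="sum",
)
loss = summed / ner_words  # mean over real words, as in the "independent" branch
print(loss.item())
```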
11 changes: 9 additions & 2 deletions edsnlp/utils/batching.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import (
     TYPE_CHECKING,
     Callable,
@@ -435,10 +436,16 @@ def rec(
            if exact_key is None:
                candidates = [k for k in item if "/stats/" in k and key in k]
                if len(candidates) != 1:
-                    raise ValueError(
-                        f"Batching key {key!r} should match exactly one "
+                    warnings.warn(
+                        f"Batching key {key!r} should match one "
                        f"candidate in {[k for k in item if '/stats/' in k]}"
                    )
+                if len(candidates) == 0:
+                    stat_keys = [k for k in item if "/stats/"]
+                    raise ValueError(
+                        f"Pattern {key!r} doesn't match any key in {stat_keys} "
+                        " to determine the batch size."
+                    )
                exact_key = candidates[0]
            value = item[exact_key]
            if num_items > 0 and total + value > batch_size:
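The batching change above relaxes the key-matching rule: a pattern matching several `/stats/` keys now only warns and the first candidate is used, while a pattern matching none still raises. A toy reimplementation of that lookup, with hypothetical key names rather than the actual edsnlp function:

```python
import warnings
from typing import Dict


def resolve_stat_key(item: Dict[str, int], pattern: str) -> str:
    """Pick the stats key used to measure batch size, warning on ambiguity."""
    stat_keys = [k for k in item if "/stats/" in k]
    candidates = [k for k in stat_keys if pattern in k]
    if len(candidates) != 1:
        warnings.warn(
            f"Batching key {pattern!r} should match one candidate in {stat_keys}"
        )
    if len(candidates) == 0:
        raise ValueError(
            f"Pattern {pattern!r} doesn't match any key in {stat_keys} "
            "to determine the batch size."
        )
    return candidates[0]


item = {"ner/stats/ner_words": 42, "qa/stats/ner_words": 17}
print(resolve_stat_key(item, "ner_words"))  # warns (two matches), picks the first
```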