diff --git a/changelog.md b/changelog.md index 0c1d2e749..e993ed824 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,11 @@ # Changelog +## Unreleased + +### Fixed + +- Numbers are now only detected without trying to remove the pollution in between digits, ie `55 @ 77777` could be detected as a full number before, but not anymore. + ## v0.13.0 ### Added diff --git a/edsnlp/pipes/misc/measurements/measurements.py b/edsnlp/pipes/misc/measurements/measurements.py index f86fe8813..b7a42c0c4 100644 --- a/edsnlp/pipes/misc/measurements/measurements.py +++ b/edsnlp/pipes/misc/measurements/measurements.py @@ -714,7 +714,12 @@ def __init__( self.unitless_patterns[pattern_name] = {"name": name, **pattern} # NUMBER PATTERNS - self.regex_matcher.add("number", [number_regex]) + self.regex_matcher.add( + "number", + [number_regex], + ignore_excluded=False, + ignore_space_tokens=False, + ) self.number_label_hashes = {nlp.vocab.strings["number"]} for number, terms in number_terms.items(): self.term_matcher.build_patterns(nlp, {number: terms}) diff --git a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py index eb8be14b4..e6abd939b 100644 --- a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py +++ b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py @@ -207,7 +207,20 @@ def to_disk(self, path, *, exclude: Optional[Set[str]]): if repr_id in exclude: return self.tokenizer.save_pretrained(path) + + # Fix for https://github.com/aphp/edsnlp/issues/317 + old_params_data = {} + for param in self.transformer.parameters(): + if not param.is_contiguous(): + old_params_data[param] = param.data + param.data = param.data.contiguous() + self.transformer.save_pretrained(path) + + # Restore non-contiguous tensors + for param, data in old_params_data.items(): + param.data = data + for param in self.transformer.parameters(): exclude.add(object.__repr__(param)) cfg = super().to_disk(path, exclude=exclude) or {} diff --git a/tests/pipelines/misc/test_measurements.py b/tests/pipelines/misc/test_measurements.py index db2fed001..b7461be3e 100644 --- a/tests/pipelines/misc/test_measurements.py +++ b/tests/pipelines/misc/test_measurements.py @@ -226,6 +226,7 @@ def test_numbers(blank_nlp: PipelineProtocol, matcher: MeasurementsMatcher): ("2 m", "2 m"), ("⅛ m", "0.125 m"), ("0 m", "0 m"), + ("55 @ 77777 cm", "77777 cm"), ]: doc = blank_nlp(text) doc = matcher(doc)