-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy path: test_mtsamples.py
147 lines (124 loc) · 4.4 KB
/
test_mtsamples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import json
import logging
from pathlib import Path
from medkit.core.doc_pipeline import DocPipeline
from medkit.core.pipeline import Pipeline, PipelineStep
from medkit.core.text import TextDocument
from medkit.text.context import NegationDetector
from medkit.text.ner import RegexpMatcher
from medkit.text.preprocessing import RegexpReplacer
from medkit.text.segmentation import SentenceTokenizer
_PATH_TO_MTSAMPLES = Path(__file__).parent / ".." / "data" / "mtsamples"
def _get_medkit_docs():
    """Load the translated MT Samples dataset and wrap each record in a TextDocument.

    The "transcription_translated" field of each record becomes the document
    text; every other field is kept as document metadata.

    Raises:
        FileNotFoundError: if the (non-distributed) dataset file is absent.
    """
    dataset_file = _PATH_TO_MTSAMPLES / "mtsamples_translated.json"
    if not dataset_file.exists():
        msg = (
            "For running this test, you need to have mtsamples_translated.json file in"
            " `tests/data/mtsamples` folder.\nThe file is not provided with medkit"
            " library. Please contact us to get this file."
        )
        raise FileNotFoundError(msg)
    with dataset_file.open() as fp:
        records = json.load(fp)
    documents = []
    for record in records:
        # document text comes from one well-known key, the rest is metadata
        text = record.get("transcription_translated", "")
        metadata = {key: value for key, value in record.items() if key != "transcription_translated"}
        documents.append(TextDocument(text=text, metadata=metadata))
    return documents
def test_mt_samples_without_pipeline(caplog):
    """Annotate every MT Samples document by chaining operations manually.

    Runs char normalization -> sentence tokenization -> negation detection ->
    entity matching on each document, checks that no operation emitted a
    warning, and pins the total annotation count expected for the dataset.
    """
    docs = _get_medkit_docs()
    assert len(docs) == 4999

    # init and configure operations
    rules = [
        (r"[nN]\s*°", "numéro"),
        (r"(?<=[0-9]\s)°", " degrés"),
        (r"(?<=[0-9])°", " degrés"),
        ("\u00c6", "AE"),  # ascii
        ("\u00e6", "ae"),  # ascii
        ("\u0152", "OE"),  # ascii
        ("\u0153", "oe"),  # ascii
        (r"«|»", '"'),
        ("®|©", ""),
        ("½", "1/2"),  # ascii
        ("…", "..."),  # ascii
        ("¼", "1/4"),  # ascii
    ]
    regexp_replacer = RegexpReplacer(output_label="norm_text", rules=rules)
    sentence_tokenizer = SentenceTokenizer()
    negation_detector = NegationDetector(output_label="negation")
    regexp_matcher = RegexpMatcher(attrs_to_copy=["negation"])

    # annotate each doc
    nb_tot_anns = 0
    for doc in docs:
        anns = [doc.raw_segment]
        anns = regexp_replacer.run(anns)
        anns = sentence_tokenizer.run(anns)
        # a warning here would mean a negation rule failed silently
        with caplog.at_level(logging.WARNING, logger="medkit.text.context.negation_detector"):
            negation_detector.run(anns)
            assert len(caplog.messages) == 0
        # FIX: RegexpMatcher lives in medkit.text.ner, not medkit.text.context,
        # so the previous logger name ("medkit.text.context.regexp_matcher")
        # never matched the operation's module logger and the level override
        # was a no-op for the logger we meant to watch.
        with caplog.at_level(logging.WARNING, logger="medkit.text.ner.regexp_matcher"):
            anns = regexp_matcher.run(anns)
            assert len(caplog.messages) == 0
        for ann in anns:
            doc.anns.add(ann)
        nb_tot_anns += len(doc.anns)
    # expected annotation total for the 4999-doc dataset
    assert nb_tot_anns == 13631
def test_mt_samples_with_doc_pipeline():
    """Annotate the MT Samples documents through a Pipeline wrapped in a DocPipeline.

    Same processing chain as the manual test (normalization, sentence
    tokenization, negation detection, entity matching), expressed as pipeline
    steps, and the same expected total annotation count.
    """
    docs = _get_medkit_docs()
    assert len(docs) == 4999

    # init and configure operations
    rules = [
        (r"[nN]\s*°", "numéro"),
        (r"(?<=[0-9]\s)°", " degrés"),
        (r"(?<=[0-9])°", " degrés"),
        ("\u00c6", "AE"),  # ascii
        ("\u00e6", "ae"),  # ascii
        ("\u0152", "OE"),  # ascii
        ("\u0153", "oe"),  # ascii
        (r"«|»", '"'),
        ("®|©", ""),
        ("½", "1/2"),  # ascii
        ("…", "..."),  # ascii
        ("¼", "1/4"),  # ascii
    ]
    replace_step = PipelineStep(
        operation=RegexpReplacer(output_label="norm_text", rules=rules),
        input_keys=["full_text"],
        output_keys=["norm_text"],
    )
    tokenize_step = PipelineStep(
        operation=SentenceTokenizer(),
        input_keys=["norm_text"],
        output_keys=["sentences"],
    )
    negation_step = PipelineStep(
        operation=NegationDetector(output_label="negation"),
        input_keys=["sentences"],
        output_keys=[],  # adds attributes in place, yields no new segments
    )
    matcher_step = PipelineStep(
        operation=RegexpMatcher(attrs_to_copy=["negation"]),
        input_keys=["sentences"],
        output_keys=["entities"],
    )
    pipeline = Pipeline(
        steps=[replace_step, tokenize_step, negation_step, matcher_step],
        input_keys=replace_step.input_keys,
        output_keys=matcher_step.output_keys,
    )
    doc_pipeline = DocPipeline(
        pipeline=pipeline,
        labels_by_input_key={"full_text": [TextDocument.RAW_LABEL]},
    )

    # annotate each doc and count everything the pipeline attached
    doc_pipeline.run(docs)
    nb_tot_anns = sum(len(doc.anns) for doc in docs)
    assert nb_tot_anns == 13631