From e325925f7752c96527232522f152e3c1f3a5c51a Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 23 Sep 2025 11:51:42 +0100 Subject: [PATCH 1/8] CU-869ak0v7n: Small refactor and refinement in terms of typing for meta cat utils --- .../components/addons/meta_cat/data_utils.py | 190 ++++++++++-------- 1 file changed, 104 insertions(+), 86 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py index 7892b9ae..8999e535 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py +++ b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Iterator import copy from medcat.components.addons.meta_cat.mctokenizers.tokenizers import ( @@ -15,7 +15,8 @@ def prepare_from_json(data: dict, cui_filter: Optional[set] = None, replace_center: Optional[str] = None, prerequisites: dict = {}, - lowercase: bool = True) -> dict: + lowercase: bool = True + ) -> dict[str, list[tuple[list, list, str]]]: """Convert the data from a json format into a CSV-like format for training. This function is not very efficient (the one working with documents as part of the meta_cat.pipe method is much better). 
@@ -64,91 +65,107 @@ def prepare_from_json(data: dict, if len(text) > 0: doc_text = tokenizer(text) + for name, sample in _prepare_from_json_loop( + document, prerequisites, cui_filter, doc_text, + cntx_left, cntx_right, lowercase, replace_center, + tokenizer): + if name in out_data: + out_data[name].append(sample) + else: + out_data[name] = [sample] - for ann in document.get('annotations', document.get( - # A hack to support entities and annotations - 'entities', {}).values()): - cui = ann['cui'] - skip = False - if 'meta_anns' in ann and prerequisites: - # It is possible to require certain meta_anns to exist - # and have a specific value - for meta_ann in prerequisites: - if (meta_ann not in ann['meta_anns'] or - ann['meta_anns'][meta_ann][ - 'value'] != prerequisites[meta_ann]): - # Skip this annotation as the prerequisite - # is not met - skip = True - break - - if not skip and (cui_filter is None or - not cui_filter or cui in cui_filter): - if ann.get('validated', True) and ( - not ann.get('deleted', False) and - not ann.get('killed', False) - and not ann.get('irrelevant', False)): - start = ann['start'] - end = ann['end'] - - # Updated implementation to extract all the tokens - # for the medical entity (rather than the one) - ctoken_idx = [] - for ind, pair in enumerate( - doc_text['offset_mapping']): - if start <= pair[0] or start <= pair[1]: - if end <= pair[1]: - ctoken_idx.append(ind) - break - else: - ctoken_idx.append(ind) - - _start = max(0, ctoken_idx[0] - cntx_left) - _end = min(len(doc_text['input_ids']), - ctoken_idx[-1] + 1 + cntx_right) - - cpos = cntx_left + min(0, ind - cntx_left) - cpos_new = [x - _start for x in ctoken_idx] - tkns = doc_text['input_ids'][_start:_end] - - if replace_center is not None: - if lowercase: - replace_center = replace_center.lower() - for p_ind, pair in enumerate( - doc_text['offset_mapping']): - if start >= pair[0] and start < pair[1]: - s_ind = p_ind - if end > pair[0] and end <= pair[1]: - e_ind = p_ind - - ln = 
e_ind - s_ind - tkns = tkns[:cpos] + tokenizer( - replace_center)['input_ids'] + tkns[ - cpos + ln + 1:] - - # Backward compatibility if meta_anns is a list vs - # dict in the new approach - meta_anns: list[dict] = [] - if 'meta_anns' in ann: - if isinstance(ann['meta_anns'], dict): - meta_anns.extend(ann['meta_anns'].values()) - else: - meta_anns.extend(ann['meta_anns']) - - # If the annotation is validated - for meta_ann in meta_anns: - name = meta_ann['name'] - value = meta_ann['value'] - - sample = [tkns, cpos_new, value] - - if name in out_data: - out_data[name].append(sample) - else: - out_data[name] = [sample] return out_data +def _prepare_from_json_loop(document: dict, + prerequisites: dict, + cui_filter: Optional[set], + doc_text: dict, + cntx_left: int, + cntx_right: int, + lowercase: bool, + replace_center: Optional[str], + tokenizer: TokenizerWrapperBase, + ) -> Iterator[tuple[str, tuple[list, list, str]]]: + for ann in document.get('annotations', document.get( + # A hack to support entities and annotations + 'entities', {}).values()): + cui = ann['cui'] + skip = False + if 'meta_anns' in ann and prerequisites: + # It is possible to require certain meta_anns to exist + # and have a specific value + for meta_ann in prerequisites: + if (meta_ann not in ann['meta_anns'] or + ann['meta_anns'][meta_ann][ + 'value'] != prerequisites[meta_ann]): + # Skip this annotation as the prerequisite + # is not met + skip = True + break + + if not skip and (cui_filter is None or + not cui_filter or cui in cui_filter): + if ann.get('validated', True) and ( + not ann.get('deleted', False) and + not ann.get('killed', False) + and not ann.get('irrelevant', False)): + start = ann['start'] + end = ann['end'] + + # Updated implementation to extract all the tokens + # for the medical entity (rather than the one) + ctoken_idx = [] + for ind, pair in enumerate( + doc_text['offset_mapping']): + if start <= pair[0] or start <= pair[1]: + if end <= pair[1]: + ctoken_idx.append(ind) + 
break + else: + ctoken_idx.append(ind) + + _start = max(0, ctoken_idx[0] - cntx_left) + _end = min(len(doc_text['input_ids']), + ctoken_idx[-1] + 1 + cntx_right) + + cpos = cntx_left + min(0, ind - cntx_left) + cpos_new = [x - _start for x in ctoken_idx] + tkns = doc_text['input_ids'][_start:_end] + + if replace_center is not None: + if lowercase: + replace_center = replace_center.lower() + for p_ind, pair in enumerate( + doc_text['offset_mapping']): + if start >= pair[0] and start < pair[1]: + s_ind = p_ind + if end > pair[0] and end <= pair[1]: + e_ind = p_ind + + ln = e_ind - s_ind + tkns = tkns[:cpos] + tokenizer( + replace_center)['input_ids'] + tkns[ + cpos + ln + 1:] + + # Backward compatibility if meta_anns is a list vs + # dict in the new approach + meta_anns: list[dict] = [] + if 'meta_anns' in ann: + if isinstance(ann['meta_anns'], dict): + meta_anns.extend(ann['meta_anns'].values()) + else: + meta_anns.extend(ann['meta_anns']) + + # If the annotation is validated + for meta_ann in meta_anns: + name = meta_ann['name'] + value = meta_ann['value'] + + sample = (tkns, cpos_new, value) + yield name, sample + + def prepare_for_oversampled_data(data: list, tokenizer: TokenizerWrapperBase) -> list: """Convert the data from a json format into a CSV-like format for @@ -189,7 +206,7 @@ def prepare_for_oversampled_data(data: list, return data_sampled -def encode_category_values(data: dict, +def encode_category_values(data: list[tuple[list, list, str]], existing_category_value2id: Optional[dict] = None, category_undersample=None, alternative_class_names: list[list[str]] = [] @@ -198,7 +215,7 @@ def encode_category_values(data: dict, `prepare_from_json` into integer values. Args: - data (dict): + data (list[tuple[list, list, str]]): Output of `prepare_from_json`. existing_category_value2id(Optional[dict]): Map from category_value to id (old/existing). 
@@ -288,7 +305,8 @@ def encode_category_values(data: dict, # Map values to numbers for i in range(len(data_list)): - data_list[i][2] = category_value2id[data_list[i][2]] + # NOTE: internally, it's a list so assignment will work + data_list[i][2] = category_value2id[data_list[i][2]] # type: ignore # Creating dict with labels and its number of samples label_data_ = {v: 0 for v in category_value2id.values()} From 1028b01718839bf4c7abc1667aecfd75bb2ab228 Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 23 Sep 2025 11:54:21 +0100 Subject: [PATCH 2/8] CU-869ak0v7n: Improve typing for encode category values --- medcat-v2/medcat/components/addons/meta_cat/data_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py index 8999e535..7a9d4b9a 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py +++ b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py @@ -210,7 +210,8 @@ def encode_category_values(data: list[tuple[list, list, str]], existing_category_value2id: Optional[dict] = None, category_undersample=None, alternative_class_names: list[list[str]] = [] - ) -> tuple: + ) -> tuple[ + list[tuple[list, list, str]], list, dict]: """Converts the category values in the data outputted by `prepare_from_json` into integer values. @@ -228,9 +229,9 @@ def encode_category_values(data: list[tuple[list, list, str]], `config.general.alternative_class_names`. Returns: - dict: + list[tuple[list, list, str]]: New data with integers inplace of strings for category values. 
- dict: + list: New undersampled data (for 2 phase learning) with integers inplace of strings for category values dict: From 0cc88dcd2a32d152367593cf3ab13bd99164ce39 Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 23 Sep 2025 11:55:29 +0100 Subject: [PATCH 3/8] CU-869ak0v7n: Improve typing for encode category values (again) --- medcat-v2/medcat/components/addons/meta_cat/data_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py index 7a9d4b9a..7a8a8f3b 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py +++ b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py @@ -208,7 +208,7 @@ def prepare_for_oversampled_data(data: list, def encode_category_values(data: list[tuple[list, list, str]], existing_category_value2id: Optional[dict] = None, - category_undersample=None, + category_undersample: Optional[str] = None, alternative_class_names: list[list[str]] = [] ) -> tuple[ list[tuple[list, list, str]], list, dict]: @@ -220,7 +220,7 @@ def encode_category_values(data: list[tuple[list, list, str]], Output of `prepare_from_json`. existing_category_value2id(Optional[dict]): Map from category_value to id (old/existing). 
- category_undersample: + category_undersample (Optional[str]): Name of class that should be used to undersample the data (for 2 phase learning) alternative_class_names (list[list[str]]): From ae73d4601a5990f1bc60a326640bc0757d8a9a75 Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 23 Sep 2025 11:57:11 +0100 Subject: [PATCH 4/8] CU-869ak0v7n: Add meta_anns as an optional part to trainer export annotation typed dict --- medcat-v2/medcat/data/mctexport.py | 1 + 1 file changed, 1 insertion(+) diff --git a/medcat-v2/medcat/data/mctexport.py b/medcat-v2/medcat/data/mctexport.py index 7bbca964..e44c2251 100644 --- a/medcat-v2/medcat/data/mctexport.py +++ b/medcat-v2/medcat/data/mctexport.py @@ -14,6 +14,7 @@ class MedCATTrainerExportAnnotation( MedCATTrainerExportAnnotationRequired, total=False): id: Union[str, int] validated: Optional[bool] + meta_anns: dict[str, dict[str, str]] class MedCATTrainerExportDocument(TypedDict): From 3dd529f4203b3da88c5c2703ed89caf50f795cc1 Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 23 Sep 2025 12:05:54 +0100 Subject: [PATCH 5/8] CU-869ak0v7n: Some typing fixes --- medcat-v2/medcat/components/addons/meta_cat/data_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py index 7a8a8f3b..6e410a75 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/data_utils.py +++ b/medcat-v2/medcat/components/addons/meta_cat/data_utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Iterator +from typing import Optional, Iterator, cast import copy from medcat.components.addons.meta_cat.mctokenizers.tokenizers import ( @@ -162,7 +162,10 @@ def _prepare_from_json_loop(document: dict, name = meta_ann['name'] value = meta_ann['value'] - sample = (tkns, cpos_new, value) + # NOTE: representing as tuple so as to have better typing + # but using a list to allow assignment + sample: tuple[list, list, str] = cast( + 
tuple[list, list, str], [tkns, cpos_new, value]) yield name, sample From bf9bce8b189f8bb9059fbb728f11cbb6f8d102c7 Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 23 Sep 2025 12:19:00 +0100 Subject: [PATCH 6/8] CU-869ak0v7n: Improve typing for creating batch piped data --- .../components/addons/meta_cat/meta_cat.py | 6 +-- .../components/addons/meta_cat/ml_utils.py | 48 ++++++++++++++----- 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py index 2cd3205d..9fd37303 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py +++ b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py @@ -15,7 +15,7 @@ from medcat.config.config import ComponentConfig from medcat.config.config_meta_cat import ConfigMetaCAT from medcat.components.addons.meta_cat.ml_utils import ( - predict, train_model, set_all_seeds, eval_model) + predict, train_model, set_all_seeds, eval_model, EvalModelResults) from medcat.components.addons.meta_cat.data_utils import ( prepare_from_json, encode_category_values, prepare_for_oversampled_data) from medcat.components.addons.addons import AddonComponent @@ -632,7 +632,7 @@ def train_raw(self, data_loaded: dict, save_dir_path: Optional[str] = None, self.config.train.last_train_on = datetime.now().timestamp() return report - def eval(self, json_path: str) -> dict: + def eval(self, json_path: str) -> EvalModelResults: """Evaluate from json. 
Args: @@ -640,7 +640,7 @@ def eval(self, json_path: str) -> dict: The json file ath Returns: - dict: + EvalModelResults: The resulting model dict Raises: diff --git a/medcat-v2/medcat/components/addons/meta_cat/ml_utils.py b/medcat-v2/medcat/components/addons/meta_cat/ml_utils.py index 89b86adc..08cd701b 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/ml_utils.py +++ b/medcat-v2/medcat/components/addons/meta_cat/ml_utils.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd import torch.optim as optim -from typing import Optional, Any, Union +from typing import Optional, Any, Union, TypedDict from torch import nn from scipy.special import softmax from medcat.config.config_meta_cat import ConfigMetaCAT @@ -34,7 +34,9 @@ def set_all_seeds(seed: int) -> None: def create_batch_piped_data(data: list[tuple[list[int], int, Optional[int]]], start_ind: int, end_ind: int, device: Union[torch.device, str], - pad_id: int) -> tuple: + pad_id: int + ) -> tuple[torch.Tensor, list[int], + torch.Tensor, Optional[torch.Tensor]]: """Creates a batch given data and start/end that denote batch size, will also add padding and move to the right device. 
@@ -52,13 +54,13 @@ def create_batch_piped_data(data: list[tuple[list[int], int, Optional[int]]], Padding index Returns: - x (): + x (torch.Tensor): Same as data, but subsetted and as a tensor - cpos (): + cpos (list[int]): Center positions for the data - attention_mask: + attention_mask (torch.Tensor): Indicating padding mask for the data - y: + y (Optional[torch.Tensor]): class label of the data """ max_seq_len = max([len(x[0]) for x in data]) @@ -78,7 +80,7 @@ class label of the data def predict(model: nn.Module, data: list[tuple[list[int], int, Optional[int]]], - config: ConfigMetaCAT) -> tuple: + config: ConfigMetaCAT) -> tuple[list[int], list[float]]: """Predict on data used in the meta_cat.pipe Args: @@ -399,8 +401,17 @@ def initialize_model(classifier, data_, batch_size_, lr_, epochs=4): return winner_report +EvalModelResults = TypedDict('EvalModelResults', { + "precision": float, + "recall": float, + "f1": float, + "examples": dict, + "confusion matrix": pd.DataFrame, + }) + + def eval_model(model: nn.Module, data: list, config: ConfigMetaCAT, - tokenizer: TokenizerWrapperBase) -> dict: + tokenizer: TokenizerWrapperBase) -> EvalModelResults: """Evaluate a trained model on the provided data Args: @@ -474,9 +485,22 @@ def eval_model(model: nn.Module, data: list, config: ConfigMetaCAT, examples: dict = {'FP': {}, 'FN': {}, 'TP': {}} id2category_value = {v: k for k, v in config.general.category_value2id.items()} + return _eval_predictions( + tokenizer, data, predictions, confusion, id2category_value, + y_eval, precision, recall, f1, examples) + + +def _eval_predictions( + tokenizer: TokenizerWrapperBase, + data: list, + predictions: list[int], + confusion: pd.DataFrame, + id2category_value: dict[int, str], + y_eval: list, + precision, recall, f1, examples: dict) -> EvalModelResults: for i, p in enumerate(predictions): y = id2category_value[y_eval[i]] - p = id2category_value[p] + pred = id2category_value[p] c = data[i][1] if isinstance(c, list): c = c[-1] @@ 
-487,11 +511,11 @@ def eval_model(model: nn.Module, data: list, config: ConfigMetaCAT, tokenizer.hf_tokenizers.decode( tkns[c:c + 1]).strip() + ">> " + tokenizer.hf_tokenizers.decode(tkns[c + 1:])) - info = "Predicted: {}, True: {}".format(p, y) - if p != y: + info = "Predicted: {}, True: {}".format(pred, y) + if pred != y: # We made a mistake examples['FN'][y] = examples['FN'].get(y, []) + [(info, text)] - examples['FP'][p] = examples['FP'].get(p, []) + [(info, text)] + examples['FP'][pred] = examples['FP'].get(pred, []) + [(info, text)] else: examples['TP'][y] = examples['TP'].get(y, []) + [(info, text)] From e957ac68d46df4996d1ec3a38a2e5d9972a20d4c Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 23 Sep 2025 13:54:22 +0100 Subject: [PATCH 7/8] CU-869ak0v7n: Fix MetaCAT typing in annotation --- medcat-v2/medcat/data/mctexport.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/medcat-v2/medcat/data/mctexport.py b/medcat-v2/medcat/data/mctexport.py index e44c2251..f62a25b4 100644 --- a/medcat-v2/medcat/data/mctexport.py +++ b/medcat-v2/medcat/data/mctexport.py @@ -10,11 +10,18 @@ class MedCATTrainerExportAnnotationRequired(TypedDict): value: str +class MetaAnnotation(TypedDict): + name: str + value: str + acc: float + validated: bool + + class MedCATTrainerExportAnnotation( MedCATTrainerExportAnnotationRequired, total=False): id: Union[str, int] validated: Optional[bool] - meta_anns: dict[str, dict[str, str]] + meta_anns: list[MetaAnnotation] class MedCATTrainerExportDocument(TypedDict): From 78d082cb75331026b83812a60f30259446cecb3d Mon Sep 17 00:00:00 2001 From: mart-r Date: Wed, 24 Sep 2025 14:24:55 +0100 Subject: [PATCH 8/8] CU-869ak0v7n: Allow Meta Annotations in a MedCATtrainer export to be either a list or a dict --- medcat-v2/medcat/data/mctexport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/medcat-v2/medcat/data/mctexport.py b/medcat-v2/medcat/data/mctexport.py index f62a25b4..d84342c1 100644 --- 
a/medcat-v2/medcat/data/mctexport.py +++ b/medcat-v2/medcat/data/mctexport.py @@ -21,7 +21,7 @@ class MedCATTrainerExportAnnotation( MedCATTrainerExportAnnotationRequired, total=False): id: Union[str, int] validated: Optional[bool] - meta_anns: list[MetaAnnotation] + meta_anns: Union[list[MetaAnnotation], dict[str, MetaAnnotation]] class MedCATTrainerExportDocument(TypedDict):