Skip to content
204 changes: 113 additions & 91 deletions medcat-v2/medcat/components/addons/meta_cat/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Iterator, cast
import copy

from medcat.components.addons.meta_cat.mctokenizers.tokenizers import (
Expand All @@ -15,7 +15,8 @@ def prepare_from_json(data: dict,
cui_filter: Optional[set] = None,
replace_center: Optional[str] = None,
prerequisites: dict = {},
lowercase: bool = True) -> dict:
lowercase: bool = True
) -> dict[str, list[tuple[list, list, str]]]:
"""Convert the data from a json format into a CSV-like format for
training. This function is not very efficient (the one working with
documents as part of the meta_cat.pipe method is much better).
Expand Down Expand Up @@ -64,91 +65,110 @@ def prepare_from_json(data: dict,

if len(text) > 0:
doc_text = tokenizer(text)
for name, sample in _prepare_from_json_loop(
document, prerequisites, cui_filter, doc_text,
cntx_left, cntx_right, lowercase, replace_center,
tokenizer):
if name in out_data:
out_data[name].append(sample)
else:
out_data[name] = [sample]

for ann in document.get('annotations', document.get(
# A hack to support entities and annotations
'entities', {}).values()):
cui = ann['cui']
skip = False
if 'meta_anns' in ann and prerequisites:
# It is possible to require certain meta_anns to exist
# and have a specific value
for meta_ann in prerequisites:
if (meta_ann not in ann['meta_anns'] or
ann['meta_anns'][meta_ann][
'value'] != prerequisites[meta_ann]):
# Skip this annotation as the prerequisite
# is not met
skip = True
break

if not skip and (cui_filter is None or
not cui_filter or cui in cui_filter):
if ann.get('validated', True) and (
not ann.get('deleted', False) and
not ann.get('killed', False)
and not ann.get('irrelevant', False)):
start = ann['start']
end = ann['end']

# Updated implementation to extract all the tokens
# for the medical entity (rather than the one)
ctoken_idx = []
for ind, pair in enumerate(
doc_text['offset_mapping']):
if start <= pair[0] or start <= pair[1]:
if end <= pair[1]:
ctoken_idx.append(ind)
break
else:
ctoken_idx.append(ind)

_start = max(0, ctoken_idx[0] - cntx_left)
_end = min(len(doc_text['input_ids']),
ctoken_idx[-1] + 1 + cntx_right)

cpos = cntx_left + min(0, ind - cntx_left)
cpos_new = [x - _start for x in ctoken_idx]
tkns = doc_text['input_ids'][_start:_end]

if replace_center is not None:
if lowercase:
replace_center = replace_center.lower()
for p_ind, pair in enumerate(
doc_text['offset_mapping']):
if start >= pair[0] and start < pair[1]:
s_ind = p_ind
if end > pair[0] and end <= pair[1]:
e_ind = p_ind

ln = e_ind - s_ind
tkns = tkns[:cpos] + tokenizer(
replace_center)['input_ids'] + tkns[
cpos + ln + 1:]

# Backward compatibility if meta_anns is a list vs
# dict in the new approach
meta_anns: list[dict] = []
if 'meta_anns' in ann:
if isinstance(ann['meta_anns'], dict):
meta_anns.extend(ann['meta_anns'].values())
else:
meta_anns.extend(ann['meta_anns'])

# If the annotation is validated
for meta_ann in meta_anns:
name = meta_ann['name']
value = meta_ann['value']

sample = [tkns, cpos_new, value]

if name in out_data:
out_data[name].append(sample)
else:
out_data[name] = [sample]
return out_data


def _prepare_from_json_loop(document: dict,
                            prerequisites: dict,
                            cui_filter: Optional[set],
                            doc_text: dict,
                            cntx_left: int,
                            cntx_right: int,
                            lowercase: bool,
                            replace_center: Optional[str],
                            tokenizer: TokenizerWrapperBase,
                            ) -> Iterator[tuple[str, tuple[list, list, str]]]:
    """Yield one training sample per meta annotation of a single document.

    Every annotation of the document that passes the prerequisite, CUI and
    validation filters is mapped onto the tokens of `doc_text`. A window of
    `cntx_left` / `cntx_right` tokens around the entity tokens is cut out
    and one sample is yielded for each of the annotation's meta annotations.

    Args:
        document (dict):
            The document. Annotations are read from `annotations` or, if
            that key is missing, from the values of `entities`.
        prerequisites (dict):
            Map from meta annotation name to the value it is required to
            have. Only checked for annotations that have `meta_anns`.
        cui_filter (Optional[set]):
            If not None / empty, only annotations with these CUIs are used.
        doc_text (dict):
            The tokenized document text. Needs to provide `input_ids` and
            `offset_mapping` (presumably (start, end) character offsets
            per token - the pairs are compared against the annotation
            `start` / `end`).
        cntx_left (int):
            Number of tokens to keep to the left of the entity.
        cntx_right (int):
            Number of tokens to keep to the right of the entity.
        lowercase (bool):
            Whether to lowercase `replace_center` before tokenizing it.
        replace_center (Optional[str]):
            If set, the entity tokens are replaced with the tokens of this
            text.
        tokenizer (TokenizerWrapperBase):
            Used to tokenize `replace_center` (only called if it is set).

    Yields:
        tuple[str, tuple[list, list, str]]:
            The meta annotation name and the sample, i.e. the tokens, the
            positions of the entity tokens within them and the value.
    """
    # Loop invariant - no need to lowercase again for every annotation
    if replace_center is not None and lowercase:
        replace_center = replace_center.lower()

    for ann in document.get('annotations', document.get(
            # A hack to support entities and annotations
            'entities', {}).values()):
        cui = ann['cui']
        if 'meta_anns' in ann and prerequisites:
            # It is possible to require certain meta_anns to exist
            # and have a specific value
            ann_meta = ann['meta_anns']
            if any(req_name not in ann_meta or
                   ann_meta[req_name]['value'] != req_value
                   for req_name, req_value in prerequisites.items()):
                # Skip this annotation as the prerequisite is not met
                continue

        if cui_filter and cui not in cui_filter:
            continue
        if (not ann.get('validated', True) or ann.get('deleted', False) or
                ann.get('killed', False) or ann.get('irrelevant', False)):
            continue

        start = ann['start']
        end = ann['end']
        offsets = doc_text['offset_mapping']
        input_ids = doc_text['input_ids']

        # Extract all the tokens for the medical entity
        # (rather than just the one)
        ctoken_idx = []
        for ind, pair in enumerate(offsets):
            if start <= pair[0] or start <= pair[1]:
                ctoken_idx.append(ind)
                if end <= pair[1]:
                    # End of the entity reached
                    break

        if not ctoken_idx:
            # No token at or after the annotation start (e.g. the
            # annotation lies beyond the tokenized text), so no sample
            # can be built. This used to raise an IndexError below.
            continue

        _start = max(0, ctoken_idx[0] - cntx_left)
        _end = min(len(input_ids), ctoken_idx[-1] + 1 + cntx_right)

        cpos_new = [x - _start for x in ctoken_idx]
        tkns = input_ids[_start:_end]

        if replace_center is not None:
            # Reset for every annotation and fall back to the entity
            # tokens found above. Otherwise an annotation whose start/end
            # is not inside any token would either fail with an
            # UnboundLocalError or silently reuse the indices of the
            # previous annotation.
            s_ind, e_ind = ctoken_idx[0], ctoken_idx[-1]
            for p_ind, pair in enumerate(offsets):
                if start >= pair[0] and start < pair[1]:
                    s_ind = p_ind
                if end > pair[0] and end <= pair[1]:
                    e_ind = p_ind

            ln = e_ind - s_ind
            # The replacement starts at the first entity token within
            # the window. It used to be derived from the last entity
            # token, which misplaced the replacement for multi-token
            # entities within the first `cntx_left` tokens.
            cpos = cpos_new[0]
            # NOTE(review): cpos_new is not adjusted to the length of
            #               the replacement - confirm that is intended
            tkns = tkns[:cpos] + tokenizer(
                replace_center)['input_ids'] + tkns[cpos + ln + 1:]

        # Backward compatibility if meta_anns is a list vs
        # dict in the new approach
        meta_anns: list[dict] = []
        if 'meta_anns' in ann:
            if isinstance(ann['meta_anns'], dict):
                meta_anns.extend(ann['meta_anns'].values())
            else:
                meta_anns.extend(ann['meta_anns'])

        for meta_ann in meta_anns:
            name = meta_ann['name']
            value = meta_ann['value']

            # NOTE: representing as tuple so as to have better typing
            #       but using a list to allow assignment
            sample: tuple[list, list, str] = cast(
                tuple[list, list, str], [tkns, cpos_new, value])
            yield name, sample


def prepare_for_oversampled_data(data: list,
tokenizer: TokenizerWrapperBase) -> list:
"""Convert the data from a json format into a CSV-like format for
Expand Down Expand Up @@ -189,20 +209,21 @@ def prepare_for_oversampled_data(data: list,
return data_sampled


def encode_category_values(data: dict,
def encode_category_values(data: list[tuple[list, list, str]],
existing_category_value2id: Optional[dict] = None,
category_undersample=None,
category_undersample: Optional[str] = None,
alternative_class_names: list[list[str]] = []
) -> tuple:
) -> tuple[
list[tuple[list, list, str]], list, dict]:
"""Converts the category values in the data outputted by
`prepare_from_json` into integer values.

Args:
data (dict):
data (list[tuple[list, list, str]]):
Output of `prepare_from_json`.
existing_category_value2id(Optional[dict]):
Map from category_value to id (old/existing).
category_undersample:
category_undersample (Optional[str]):
Name of class that should be used to undersample the data (for 2
phase learning)
alternative_class_names (list[list[str]]):
Expand All @@ -211,9 +232,9 @@ def encode_category_values(data: dict,
`config.general.alternative_class_names`.

Returns:
dict:
list[tuple[list, list, str]]:
New data with integers inplace of strings for category values.
dict:
list:
New undersampled data (for 2 phase learning) with integers
inplace of strings for category values
dict:
Expand Down Expand Up @@ -288,7 +309,8 @@ def encode_category_values(data: dict,

# Map values to numbers
for i in range(len(data_list)):
data_list[i][2] = category_value2id[data_list[i][2]]
# NOTE: internally, it's a list so assignment will work
data_list[i][2] = category_value2id[data_list[i][2]] # type: ignore

# Creating dict with labels and its number of samples
label_data_ = {v: 0 for v in category_value2id.values()}
Expand Down
6 changes: 3 additions & 3 deletions medcat-v2/medcat/components/addons/meta_cat/meta_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from medcat.config.config import ComponentConfig
from medcat.config.config_meta_cat import ConfigMetaCAT
from medcat.components.addons.meta_cat.ml_utils import (
predict, train_model, set_all_seeds, eval_model)
predict, train_model, set_all_seeds, eval_model, EvalModelResults)
from medcat.components.addons.meta_cat.data_utils import (
prepare_from_json, encode_category_values, prepare_for_oversampled_data)
from medcat.components.addons.addons import AddonComponent
Expand Down Expand Up @@ -632,15 +632,15 @@ def train_raw(self, data_loaded: dict, save_dir_path: Optional[str] = None,
self.config.train.last_train_on = datetime.now().timestamp()
return report

def eval(self, json_path: str) -> dict:
def eval(self, json_path: str) -> EvalModelResults:
"""Evaluate from json.

Args:
json_path (str):
The json file path

Returns:
dict:
EvalModelResults:
The resulting model dict

Raises:
Expand Down
48 changes: 36 additions & 12 deletions medcat-v2/medcat/components/addons/meta_cat/ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd
import torch.optim as optim
from typing import Optional, Any, Union
from typing import Optional, Any, Union, TypedDict
from torch import nn
from scipy.special import softmax
from medcat.config.config_meta_cat import ConfigMetaCAT
Expand Down Expand Up @@ -34,7 +34,9 @@ def set_all_seeds(seed: int) -> None:
def create_batch_piped_data(data: list[tuple[list[int], int, Optional[int]]],
start_ind: int, end_ind: int,
device: Union[torch.device, str],
pad_id: int) -> tuple:
pad_id: int
) -> tuple[torch.Tensor, list[int],
torch.Tensor, Optional[torch.Tensor]]:
"""Creates a batch given data and start/end that denote batch size,
will also add padding and move to the right device.

Expand All @@ -52,13 +54,13 @@ def create_batch_piped_data(data: list[tuple[list[int], int, Optional[int]]],
Padding index

Returns:
x ():
x (torch.Tensor):
Same as data, but subsetted and as a tensor
cpos ():
cpos (list[int]):
Center positions for the data
attention_mask:
attention_mask (torch.Tensor):
Indicating padding mask for the data
y:
y (Optional[torch.Tensor]):
class label of the data
"""
max_seq_len = max([len(x[0]) for x in data])
Expand All @@ -78,7 +80,7 @@ class label of the data


def predict(model: nn.Module, data: list[tuple[list[int], int, Optional[int]]],
config: ConfigMetaCAT) -> tuple:
config: ConfigMetaCAT) -> tuple[list[int], list[float]]:
"""Predict on data used in the meta_cat.pipe

Args:
Expand Down Expand Up @@ -399,8 +401,17 @@ def initialize_model(classifier, data_, batch_size_, lr_, epochs=4):
return winner_report


# Typed description of the dict returned by `eval_model`.
# NOTE: the functional TypedDict syntax is used because the
#       "confusion matrix" key contains a space and thus cannot be
#       declared as an attribute with the class based syntax.
EvalModelResults = TypedDict('EvalModelResults', {
"precision": float,
"recall": float,
"f1": float,
"examples": dict,
"confusion matrix": pd.DataFrame,
})


def eval_model(model: nn.Module, data: list, config: ConfigMetaCAT,
tokenizer: TokenizerWrapperBase) -> dict:
tokenizer: TokenizerWrapperBase) -> EvalModelResults:
"""Evaluate a trained model on the provided data

Args:
Expand Down Expand Up @@ -474,9 +485,22 @@ def eval_model(model: nn.Module, data: list, config: ConfigMetaCAT,
examples: dict = {'FP': {}, 'FN': {}, 'TP': {}}
id2category_value = {v: k for k, v
in config.general.category_value2id.items()}
return _eval_predictions(
tokenizer, data, predictions, confusion, id2category_value,
y_eval, precision, recall, f1, examples)


def _eval_predictions(
tokenizer: TokenizerWrapperBase,
data: list,
predictions: list[int],
confusion: pd.DataFrame,
id2category_value: dict[int, str],
y_eval: list,
precision, recall, f1, examples: dict) -> EvalModelResults:
for i, p in enumerate(predictions):
y = id2category_value[y_eval[i]]
p = id2category_value[p]
pred = id2category_value[p]
c = data[i][1]
if isinstance(c, list):
c = c[-1]
Expand All @@ -487,11 +511,11 @@ def eval_model(model: nn.Module, data: list, config: ConfigMetaCAT,
tokenizer.hf_tokenizers.decode(
tkns[c:c + 1]).strip() + ">> " +
tokenizer.hf_tokenizers.decode(tkns[c + 1:]))
info = "Predicted: {}, True: {}".format(p, y)
if p != y:
info = "Predicted: {}, True: {}".format(pred, y)
if pred != y:
# We made a mistake
examples['FN'][y] = examples['FN'].get(y, []) + [(info, text)]
examples['FP'][p] = examples['FP'].get(p, []) + [(info, text)]
examples['FP'][pred] = examples['FP'].get(pred, []) + [(info, text)]
else:
examples['TP'][y] = examples['TP'].get(y, []) + [(info, text)]

Expand Down
Loading
Loading