Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/medcat-v2_main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ jobs:
uv run ruff check medcat --preview
- name: Test
run: |
timeout 20m uv run python -m unittest discover
timeout 30m uv run python -m unittest discover
- name: Model regression
run: |
uv run bash tests/backwards_compatibility/run_current.sh
- name: Backwards compatibility
run: |
uv run bash tests/backwards_compatibility/check_backwards_compatibility.sh
- name: Minimize uv cache
run: uv cache prune --ci
run: uv cache prune --ci
243 changes: 159 additions & 84 deletions medcat-v2/medcat/components/addons/meta_cat/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from medcat.components.addons.meta_cat.mctokenizers.tokenizers import (
TokenizerWrapperBase)
from medcat.config.config_meta_cat import ConfigMetaCAT
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -209,10 +210,124 @@ def prepare_for_oversampled_data(data: list,
return data_sampled


def find_alternate_classname(category_value2id: dict, category_values: set[str],
alternative_class_names: list[list[str]]) -> dict:
"""Find and map to alternative class names for the given category.

Example:
For Temporality category, 'Recent' is an alternative to 'Present'.

Args:
category_value2id (dict):
The pre-defined category_value2id
category_values (set[str]):
Contains the classes (labels) found in the data
alternative_class_names (list[list[str]]):
Contains the mapping of alternative class names

Returns:
category_value2id (dict):
Updated category_value2id with keys corresponding to
alternative class names

Raises:
Exception: If no alternatives are found for labels in
category_value2id that don't match any of the labels in
the data
Exception: If the alternatives defined for labels in
category_value2id that don't match any of the labels in
the data
"""

updated_category_value2id = {}
for _class in category_value2id.keys():
if _class in category_values:
updated_category_value2id[_class] = category_value2id[_class]
else:
found_in = [sub_map for sub_map in alternative_class_names
if _class in sub_map]
failed_to_find = False
if len(found_in) != 0:
class_name_matched = [label for label in found_in[0]
if label in category_values]
if len(class_name_matched) != 0:
updated_category_value2id[
class_name_matched[0]] = category_value2id[_class]
logger.info(
"Class name '%s' does not exist in the data; however "
"a variation of it '%s' is present; updating it...",
_class, class_name_matched[0])
else:
failed_to_find = True
else:
failed_to_find = True
if failed_to_find:
raise Exception(
"The classes set in the config are not the same as the "
"one found in the data. The classes present in the config "
"vs the ones found in the data - "
f"{set(category_value2id.keys())}, {category_values}. "
"Additionally, ensure the populate the "
"'alternative_class_names' attribute to accommodate for "
"variations.")
category_value2id = copy.deepcopy(updated_category_value2id)
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)
return category_value2id


def undersample_data(data: list, category_value2id: dict, label_data_,
config: ConfigMetaCAT) -> list:
"""Undersamples the data for 2 phase learning

Args:
data (list):
Output of `prepare_from_json`.
category_value2id(dict):
Map from category_value to id.
label_data_:
Map that stores the number of samples for each label
config:
MetaCAT config

Returns:
data_undersampled (list):
Return the data created for 2 phase learning) with integers
inplace of strings for category values
"""

data_undersampled = []
category_undersample = config.model.category_undersample
if category_undersample is None or category_undersample == '':
min_label = min(label_data_.values())

else:
if (category_undersample not in label_data_.keys() and
category_undersample in category_value2id.keys()):
min_label = label_data_[category_value2id[category_undersample]]
else:
min_label = label_data_[category_undersample]

label_data_counter = {v: 0 for v in category_value2id.values()}

for sample in data:
if label_data_counter[sample[-1]] < min_label:
data_undersampled.append(sample)
label_data_counter[sample[-1]] += 1

label_data = {v: 0 for v in category_value2id.values()}
for i in range(len(data_undersampled)):
if data_undersampled[i][2] in category_value2id.values():
label_data[data_undersampled[i][2]] = (
label_data[data_undersampled[i][2]] + 1)
logger.info("Updated number of samples per label (for 2-phase learning):"
" %s", label_data)
return data_undersampled


def encode_category_values(data: list[tuple[list, list, str]],
existing_category_value2id: Optional[dict] = None,
category_undersample: Optional[str] = None,
alternative_class_names: list[list[str]] = []
alternative_class_names: list[list[str]] = [],
config: Optional[ConfigMetaCAT] = None,
) -> tuple[
list[tuple[list, list, str]], list, dict]:
"""Converts the category values in the data outputted by
Expand All @@ -223,13 +338,12 @@ def encode_category_values(data: list[tuple[list, list, str]],
Output of `prepare_from_json`.
existing_category_value2id(Optional[dict]):
Map from category_value to id (old/existing).
category_undersample (Optional[str]):
Name of class that should be used to undersample the data (for 2
phase learning)
alternative_class_names (list[list[str]]):
A list of lists of strings, where each list contains variations
of a class name. Usually read from the config at
`config.general.alternative_class_names`.
config (Optional[ConfigMetaCAT]):
The MetaCAT Config.

Returns:
list[tuple[list, list, str]]:
Expand All @@ -252,98 +366,59 @@ def encode_category_values(data: list[tuple[list, list, str]],

category_values = set([x[2] for x in data_list])

if (len(category_value2id) != 0 and
set(category_value2id.keys()) != category_values):
# if categoryvalue2id doesn't match the labels in the data,
# then 'alternative_class_names' has to be defined to check
# for variations
if len(alternative_class_names) == 0:
# Raise an exception since the labels don't match
if config:
if len(category_values) != config.model.nclasses:
raise Exception(
"The classes set in the config are not the same as the one "
"found in the data. The classes present in the config vs the "
"ones found in the data - {set(category_value2id.keys())}, "
f"{category_values}. Additionally, ensure the populate the "
"'alternative_class_names' attribute to accommodate for "
"variations.")
updated_category_value2id = {}
for _class in category_value2id.keys():
if _class in category_values:
updated_category_value2id[_class] = category_value2id[_class]
else:
found_in = [sub_map for sub_map in alternative_class_names
if _class in sub_map]
failed_to_find = False
if len(found_in) != 0:
class_name_matched = [label for label in found_in[0]
if label in category_values]
if len(class_name_matched) != 0:
updated_category_value2id[class_name_matched[0]
] = category_value2id[_class]
logger.info(
"Class name '%s' does not exist in the data; "
"however a variation of it '%s' is present; "
"updating it...", _class, class_name_matched[0])
else:
failed_to_find = True
else:
failed_to_find = True
if failed_to_find:
raise Exception(
"The classes set in the config are not the same as "
"the one found in the data. The classes present in "
"the config vs the ones found in the data - "
f"{set(category_value2id.keys())}, {category_values}. "
"Additionally, ensure the populate the "
"'alternative_class_names' attribute to accommodate "
"for variations.")
category_value2id = copy.deepcopy(updated_category_value2id)
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)
"The number of classes found in the data - %s does not match "
"the number of classes defined in the config - %s "
"(config.model.nclasses). Please update the number of classes "
"and initialise the model again.", len(category_values),
config.model.nclasses)
# If categoryvalue2id is pre-defined or if all the classes aren't mentioned
if len(category_value2id) != 0:
# making sure it is same as the labels found in the data
if set(category_value2id.keys()) != category_values:
# if categoryvalue2id doesn't match the labels in the data,
# then 'alternative_class_names' has to be defined to check for
# variations
if len(alternative_class_names) == 0:
# Raise an exception since the labels don't match
raise Exception(
"The classes set in the config are not the same as the "
"one found in the data. The classes present in the config "
"vs the ones found in the data - "
f"{set(category_value2id.keys())}, {category_values}. "
"Additionally, ensure the populate the "
"'alternative_class_names' attribute to accommodate for "
"variations.")

category_value2id = find_alternate_classname(
category_value2id, category_values, alternative_class_names)

# Else create the mapping from the labels found in the data
else:
if len(category_value2id) != len(category_values):
for c in category_values:
if c not in category_value2id:
category_value2id[c] = len(category_value2id)
logger.info("Categoryvalue2id mapping created with labels found "
"in the data - %s", category_value2id)
logger.info("Categoryvalue2id mapping created with labels found in "
"the data - %s", category_value2id)

# Map values to numbers
for i in range(len(data_list)):
# NOTE: internally, it's a a list so assingment will work
data_list[i][2] = category_value2id[data_list[i][2]] # type: ignore
for i in range(len(data)):
# represented as a tuple so that we can type hint, but it's a list
data[i][2] = category_value2id[data[i][2]] # type: ignore

# Creating dict with labels and its number of samples
label_data_ = {v: 0 for v in category_value2id.values()}
for i in range(len(data_list)):
if data_list[i][2] in category_value2id.values():
label_data_[data_list[i][2]] = label_data_[data_list[i][2]] + 1
for i in range(len(data)):
if data[i][2] in category_value2id.values():
label_data_[data[i][2]] = label_data_[data[i][2]] + 1

logger.info("Original number of samples per label: %s", label_data_)
# Undersampling data
if category_undersample is None or category_undersample == '':
min_label = min(label_data_.values())

else:
if (category_undersample not in label_data_.keys() and
category_undersample in category_value2id.keys()):
min_label = label_data_[category_value2id[category_undersample]]
else:
min_label = label_data_[category_undersample]

data_undersampled = []
label_data_counter = {v: 0 for v in category_value2id.values()}

for sample in data_list:
if label_data_counter[sample[-1]] < min_label:
data_undersampled.append(sample)
label_data_counter[sample[-1]] += 1

label_data = {v: 0 for v in category_value2id.values()}
for i in range(len(data_undersampled)):
if data_undersampled[i][2] in category_value2id.values():
label_data[data_undersampled[i][2]] = label_data[
data_undersampled[i][2]] + 1
logger.info("Updated number of samples per label (for 2-phase learning): "
"%s", label_data)
if config and config.model.phase_number != 0:
data_undersampled = undersample_data(
data, category_value2id, label_data_, config)

return data_list, data_undersampled, category_value2id
5 changes: 2 additions & 3 deletions medcat-v2/medcat/components/addons/meta_cat/meta_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,15 +556,14 @@ def train_raw(self, data_loaded: dict, save_dir_path: Optional[str] = None,
# Encode the category values
(full_data, data_undersampled,
category_value2id) = encode_category_values(
data,
category_undersample=self.config.model.category_undersample,
data, config=self.config,
alternative_class_names=g_config.alternative_class_names)
else:
# We already have everything, just get the data
(full_data, data_undersampled,
category_value2id) = encode_category_values(
data, existing_category_value2id=category_value2id,
category_undersample=self.config.model.category_undersample,
config=self.config,
alternative_class_names=g_config.alternative_class_names)
g_config.category_value2id = category_value2id
self.config.model.nclasses = len(category_value2id)
Expand Down
11 changes: 6 additions & 5 deletions medcat-v2/medcat/components/addons/meta_cat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,10 @@ def __init__(self, config: ConfigMetaCAT,
"DO NOT use this model without loading the model state!",
exc_info=e)

self.config = config
self._config = config
self.bert = bert
self.bert_config = _bertconfig
# NOTE: potentially used downstream
self.config = self.bert_config = _bertconfig
self.num_labels = config.model.nclasses
for param in self.bert.parameters():
param.requires_grad = not config.model.model_freeze_layers
Expand Down Expand Up @@ -252,14 +253,14 @@ def forward(
x = self.fc1(x)
x = self.relu(x)

if self.config.model.model_architecture_config is not None:
if self.config.model.model_architecture_config['fc2'] is True:
if self._config.model.model_architecture_config is not None:
if self._config.model.model_architecture_config['fc2'] is True:
# fc2
x = self.fc2(x)
x = self.relu(x)
x = self.dropout(x)

if self.config.model.model_architecture_config['fc3'] is True:
if self._config.model.model_architecture_config['fc3'] is True:
# fc3
x = self.fc3(x)
x = self.relu(x)
Expand Down
Loading
Loading