Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions v1/medcat/medcat/meta_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
"The category name does not exist in this json file. You've provided '{}', "
"while the possible options are: {}. Additionally, ensure the populate the "
"'alternative_category_names' attribute to accommodate for variations.".format(
category_name, " | ".join(list(data.keys()))))
g_config['category_name'], " | ".join(list(data.keys()))))

data = data[category_name]
if data_oversampled:
Expand All @@ -263,12 +263,12 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
if not category_value2id:
# Encode the category values
full_data, data_undersampled, category_value2id = encode_category_values(data,
category_undersample=self.config.model.category_undersample,alternative_class_names=g_config['alternative_class_names'])
alternative_class_names=g_config['alternative_class_names'],config=self.config)
else:
# We already have everything, just get the data
full_data, data_undersampled, category_value2id = encode_category_values(data,
existing_category_value2id=category_value2id,
category_undersample=self.config.model.category_undersample,alternative_class_names=g_config['alternative_class_names'])
alternative_class_names=g_config['alternative_class_names'],config=self.config)
g_config['category_value2id'] = category_value2id
self.config.model['nclasses'] = len(category_value2id)

Expand Down
189 changes: 125 additions & 64 deletions v1/medcat/medcat/utils/meta_cat/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple, Iterable, List, Union
from typing import Any, Dict, Optional, Tuple, Iterable, List, Union, Set
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
import copy
import logging
Expand Down Expand Up @@ -153,8 +153,100 @@ def prepare_for_oversampled_data(data: List,
return data_sampled


def find_alternate_classname(category_value2id: Dict, category_values: Set, alternative_class_names: List[List]) -> Dict:
"""Helper function to find and map to alternative class names for the given category.
Example: For Temporality category, 'Recent' is an alternative to 'Present'.

Args:
category_value2id (Dict):
The pre-defined category_value2id
category_values (Set):
Contains the classes (labels) found in the data
alternative_class_names (List):
Contains the mapping of alternative class names

Returns:
category_value2id (Dict):
Updated category_value2id with keys corresponding to alternative class names

Raises:
Exception: If no alternatives are found for labels in category_value2id that don't match any of the labels in the data
Exception: If the alternatives defined for labels in category_value2id that don't match any of the labels in the data
"""

updated_category_value2id = {}
for _class in category_value2id.keys():
if _class in category_values:
updated_category_value2id[_class] = category_value2id[_class]
else:
found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map]
failed_to_find = False
if len(found_in) != 0:
class_name_matched = [label for label in found_in[0] if label in category_values]
if len(class_name_matched) != 0:
updated_category_value2id[class_name_matched[0]] = category_value2id[_class]
logger.info("Class name '%s' does not exist in the data; however a variation of it "
"'%s' is present; updating it...", _class, class_name_matched[0])
else:
failed_to_find = True
else:
failed_to_find = True
if failed_to_find:
raise Exception("The classes set in the config are not the same as the one found in the data. "
"The classes present in the config vs the ones found in the data - "
f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the "
"populate the 'alternative_class_names' attribute to accommodate for variations.")
category_value2id = copy.deepcopy(updated_category_value2id)
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)
return category_value2id


def undersample_data(data: List, category_value2id: Dict, label_data_,config) -> List:
"""Undersamples the data for 2 phase learning

Args:
data (List):
Output of `prepare_from_json`.
category_value2id(Dict):
Map from category_value to id.
label_data_:
Map that stores the number of samples for each label
config:
MetaCAT config

Returns:
data_undersampled (list):
Return the data created for 2 phase learning) with integers inplace of strings for category values
"""

data_undersampled = []
category_undersample = config.model.category_undersample
if category_undersample is None or category_undersample == '':
min_label = min(label_data_.values())

else:
if category_undersample not in label_data_.keys() and category_undersample in category_value2id.keys():
min_label = label_data_[category_value2id[category_undersample]]
else:
min_label = label_data_[category_undersample]

label_data_counter = {v: 0 for v in category_value2id.values()}

for sample in data:
if label_data_counter[sample[-1]] < min_label:
data_undersampled.append(sample)
label_data_counter[sample[-1]] += 1

label_data = {v: 0 for v in category_value2id.values()}
for i in range(len(data_undersampled)):
if data_undersampled[i][2] in category_value2id.values():
label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1
logger.info("Updated number of samples per label (for 2-phase learning): %s", label_data)
return data_undersampled


def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is used by the trainer:

data, _, _ = encode_category_values(data, existing_category_value2id=category_value2id)

Now, it looks like this change doesn't change the API in a way that would break that (at least not immediately). However, I'd like to have some stability in our API.

Perhaps a test for this method to make sure the behaviour is consistent?

category_undersample=None, alternative_class_names: List[List] = []) -> Tuple:
alternative_class_names: List[List] = [], config=None) -> Tuple:
"""Converts the category values in the data outputted by `prepare_from_json`
into integer values.

Expand All @@ -163,22 +255,24 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
Output of `prepare_from_json`.
existing_category_value2id(Optional[Dict]):
Map from category_value to id (old/existing).
category_undersample:
Name of class that should be used to undersample the data (for 2 phase learning)
alternative_class_names:
Map that stores the variations of possible class names for the given category (task)
config:
MetaCAT config

Returns:
dict:
data (list):
New data with integers inplace of strings for category values.
dict:
data_undersampled (list):
New undersampled data (for 2 phase learning) with integers inplace of strings for category values
dict:
category_value2id (dict):
Map from category value to ID for all categories in the data.

Raises:
Exception: If categoryvalue2id is pre-defined and its labels do not match the labels found in the data
Exception: If the number of classes in config do not match the number of classes found in the data
Exception: If category_value2id is pre-defined, its labels do not match the labels found in the data and alternative_class_names is empty
"""

data = list(data)
if existing_category_value2id is not None:
category_value2id = existing_category_value2id
Expand All @@ -187,43 +281,29 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict

category_values = set([x[2] for x in data])

# If categoryvalue2id is pre-defined, then making sure it is same as the labels found in the data
if len(category_value2id) != 0 and set(category_value2id.keys()) != category_values:
# if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations
if len(alternative_class_names) == 0:
# Raise an exception since the labels don't match
if config:
if len(category_values) != config.model.nclasses:
raise Exception(
"The classes set in the config are not the same as the one found in the data. "
"The classes present in the config vs the ones found in the data - "
f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the "
"'alternative_class_names' attribute to accommodate for variations.")
updated_category_value2id = {}
for _class in category_value2id.keys():
if _class in category_values:
updated_category_value2id[_class] = category_value2id[_class]
else:
found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map]
failed_to_find = False
if len(found_in) != 0:
class_name_matched = [label for label in found_in[0] if label in category_values]
if len(class_name_matched) != 0:
updated_category_value2id[class_name_matched[0]] = category_value2id[_class]
logger.info("Class name '%s' does not exist in the data; however a variation of it "
"'%s' is present; updating it...", _class, class_name_matched[0])
else:
failed_to_find = True
else:
failed_to_find = True
if failed_to_find:
raise Exception("The classes set in the config are not the same as the one found in the data. "
"The classes present in the config vs the ones found in the data - "
f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the "
"populate the 'alternative_class_names' attribute to accommodate for variations.")
category_value2id = copy.deepcopy(updated_category_value2id)
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)
"The number of classes found in the data - %s does not match the number of classes defined in the config - %s (config.model.nclasses). Please update the number of classes and initialise the model again.",
len(category_values), config.model.nclasses)

# If categoryvalue2id is pre-defined or if all the classes aren't mentioned
if len(category_value2id) != 0:
# making sure it is same as the labels found in the data
if set(category_value2id.keys()) != category_values:
# if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations
if len(alternative_class_names) == 0:
# Raise an exception since the labels don't match
raise Exception(
"The classes set in the config are not the same as the one found in the data. "
"The classes present in the config vs the ones found in the data - "
f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the "
"'alternative_class_names' attribute to accommodate for variations.")

category_value2id = find_alternate_classname(category_value2id, category_values, alternative_class_names)

# Else create the mapping from the labels found in the data
else:
if len(category_value2id) != len(category_values):
for c in category_values:
if c not in category_value2id:
category_value2id[c] = len(category_value2id)
Expand All @@ -239,30 +319,11 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
if data[i][2] in category_value2id.values():
label_data_[data[i][2]] = label_data_[data[i][2]] + 1

logger.info("Original number of samples per label: %s",label_data_)
# Undersampling data
if category_undersample is None or category_undersample == '':
min_label = min(label_data_.values())

else:
if category_undersample not in label_data_.keys() and category_undersample in category_value2id.keys():
min_label = label_data_[category_value2id[category_undersample]]
else:
min_label = label_data_[category_undersample]
logger.info("Original number of samples per label: %s", label_data_)

data_undersampled = []
label_data_counter = {v: 0 for v in category_value2id.values()}

for sample in data:
if label_data_counter[sample[-1]] < min_label:
data_undersampled.append(sample)
label_data_counter[sample[-1]] += 1

label_data = {v: 0 for v in category_value2id.values()}
for i in range(len(data_undersampled)):
if data_undersampled[i][2] in category_value2id.values():
label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1
logger.info("Updated number of samples per label (for 2-phase learning): %s",label_data)
if config and config.model.phase_number != 0:
data_undersampled = undersample_data(data, category_value2id, label_data_, config)

return data, data_undersampled, category_value2id

Expand Down