diff --git a/v1/medcat/medcat/meta_cat.py b/v1/medcat/medcat/meta_cat.py index c646f44d..8b922d48 100644 --- a/v1/medcat/medcat/meta_cat.py +++ b/v1/medcat/medcat/meta_cat.py @@ -252,7 +252,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data "The category name does not exist in this json file. You've provided '{}', " "while the possible options are: {}. Additionally, ensure the populate the " "'alternative_category_names' attribute to accommodate for variations.".format( - category_name, " | ".join(list(data.keys())))) + g_config['category_name'], " | ".join(list(data.keys())))) data = data[category_name] if data_oversampled: @@ -263,12 +263,12 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data if not category_value2id: # Encode the category values full_data, data_undersampled, category_value2id = encode_category_values(data, - category_undersample=self.config.model.category_undersample,alternative_class_names=g_config['alternative_class_names']) + alternative_class_names=g_config['alternative_class_names'],config=self.config) else: # We already have everything, just get the data full_data, data_undersampled, category_value2id = encode_category_values(data, existing_category_value2id=category_value2id, - category_undersample=self.config.model.category_undersample,alternative_class_names=g_config['alternative_class_names']) + alternative_class_names=g_config['alternative_class_names'],config=self.config) g_config['category_value2id'] = category_value2id self.config.model['nclasses'] = len(category_value2id) diff --git a/v1/medcat/medcat/utils/meta_cat/data_utils.py b/v1/medcat/medcat/utils/meta_cat/data_utils.py index 3fff0651..9272d1a2 100644 --- a/v1/medcat/medcat/utils/meta_cat/data_utils.py +++ b/v1/medcat/medcat/utils/meta_cat/data_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple, Iterable, List, Union +from typing import Any, Dict, Optional, Tuple, Iterable, List, Union, Set from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase import copy import logging @@ -153,8 +153,100 @@ def prepare_for_oversampled_data(data: List, return data_sampled +def find_alternate_classname(category_value2id: Dict, category_values: Set, alternative_class_names: List[List]) -> Dict: + """Helper function to find and map to alternative class names for the given category. + Example: For Temporality category, 'Recent' is an alternative to 'Present'. + + Args: + category_value2id (Dict): + The pre-defined category_value2id + category_values (Set): + Contains the classes (labels) found in the data + alternative_class_names (List): + Contains the mapping of alternative class names + + Returns: + category_value2id (Dict): + Updated category_value2id with keys corresponding to alternative class names + + Raises: + Exception: If no alternatives are found for labels in category_value2id that don't match any of the labels in the data + Exception: If the alternatives defined for labels in category_value2id that don't match any of the labels in the data + """ + + updated_category_value2id = {} + for _class in category_value2id.keys(): + if _class in category_values: + updated_category_value2id[_class] = category_value2id[_class] + else: + found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map] + failed_to_find = False + if len(found_in) != 0: + class_name_matched = [label for label in found_in[0] if label in category_values] + if len(class_name_matched) != 0: + updated_category_value2id[class_name_matched[0]] = category_value2id[_class] + logger.info("Class name '%s' does not exist in the data; however a variation of it " + "'%s' is present; updating it...", _class, class_name_matched[0]) + else: + failed_to_find = True + else: + failed_to_find = True + if failed_to_find: + raise Exception("The classes set in the config are not the same as the one found in the data. " + "The classes present in the config vs the ones found in the data - " + f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the " + "populate the 'alternative_class_names' attribute to accommodate for variations.") + category_value2id = copy.deepcopy(updated_category_value2id) + logger.info("Updated categoryvalue2id mapping - %s", category_value2id) + return category_value2id + + +def undersample_data(data: List, category_value2id: Dict, label_data_,config) -> List: + """Undersamples the data for 2 phase learning + + Args: + data (List): + Output of `prepare_from_json`. + category_value2id(Dict): + Map from category_value to id. + label_data_: + Map that stores the number of samples for each label + config: + MetaCAT config + + Returns: + data_undersampled (list): + Return the data created for 2 phase learning) with integers inplace of strings for category values + """ + + data_undersampled = [] + category_undersample = config.model.category_undersample + if category_undersample is None or category_undersample == '': + min_label = min(label_data_.values()) + + else: + if category_undersample not in label_data_.keys() and category_undersample in category_value2id.keys(): + min_label = label_data_[category_value2id[category_undersample]] + else: + min_label = label_data_[category_undersample] + + label_data_counter = {v: 0 for v in category_value2id.values()} + + for sample in data: + if label_data_counter[sample[-1]] < min_label: + data_undersampled.append(sample) + label_data_counter[sample[-1]] += 1 + + label_data = {v: 0 for v in category_value2id.values()} + for i in range(len(data_undersampled)): + if data_undersampled[i][2] in category_value2id.values(): + label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1 + logger.info("Updated number of samples per label (for 2-phase learning): %s", label_data) + return data_undersampled + + def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None, - category_undersample=None, alternative_class_names: List[List] = []) -> Tuple: + alternative_class_names: List[List] = [], config=None) -> Tuple: """Converts the category values in the data outputted by `prepare_from_json` into integer values. @@ -163,22 +255,24 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict Output of `prepare_from_json`. existing_category_value2id(Optional[Dict]): Map from category_value to id (old/existing). - category_undersample: - Name of class that should be used to undersample the data (for 2 phase learning) alternative_class_names: Map that stores the variations of possible class names for the given category (task) + config: + MetaCAT config Returns: - dict: + data (list): New data with integers inplace of strings for category values. - dict: + data_undersampled (list): New undersampled data (for 2 phase learning) with integers inplace of strings for category values - dict: + category_value2id (dict): Map from category value to ID for all categories in the data. Raises: - Exception: If categoryvalue2id is pre-defined and its labels do not match the labels found in the data + Exception: If the number of classes in config do not match the number of classes found in the data + Exception: If category_value2id is pre-defined, its labels do not match the labels found in the data and alternative_class_names is empty """ + data = list(data) if existing_category_value2id is not None: category_value2id = existing_category_value2id @@ -187,43 +281,29 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict category_values = set([x[2] for x in data]) - # If categoryvalue2id is pre-defined, then making sure it is same as the labels found in the data - if len(category_value2id) != 0 and set(category_value2id.keys()) != category_values: - # if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations - if len(alternative_class_names) == 0: - # Raise an exception since the labels don't match + if config: + if len(category_values) != config.model.nclasses: raise Exception( - "The classes set in the config are not the same as the one found in the data. " - "The classes present in the config vs the ones found in the data - " - f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the " - "'alternative_class_names' attribute to accommodate for variations.") - updated_category_value2id = {} - for _class in category_value2id.keys(): - if _class in category_values: - updated_category_value2id[_class] = category_value2id[_class] - else: - found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map] - failed_to_find = False - if len(found_in) != 0: - class_name_matched = [label for label in found_in[0] if label in category_values] - if len(class_name_matched) != 0: - updated_category_value2id[class_name_matched[0]] = category_value2id[_class] - logger.info("Class name '%s' does not exist in the data; however a variation of it " - "'%s' is present; updating it...", _class, class_name_matched[0]) - else: - failed_to_find = True - else: - failed_to_find = True - if failed_to_find: - raise Exception("The classes set in the config are not the same as the one found in the data. " - "The classes present in the config vs the ones found in the data - " - f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the " - "populate the 'alternative_class_names' attribute to accommodate for variations.") - category_value2id = copy.deepcopy(updated_category_value2id) - logger.info("Updated categoryvalue2id mapping - %s", category_value2id) + "The number of classes found in the data - %s does not match the number of classes defined in the config - %s (config.model.nclasses). Please update the number of classes and initialise the model again.", + len(category_values), config.model.nclasses) + + # If categoryvalue2id is pre-defined or if all the classes aren't mentioned + if len(category_value2id) != 0: + # making sure it is same as the labels found in the data + if set(category_value2id.keys()) != category_values: + # if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations + if len(alternative_class_names) == 0: + # Raise an exception since the labels don't match + raise Exception( + "The classes set in the config are not the same as the one found in the data. " + "The classes present in the config vs the ones found in the data - " + f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the " + "'alternative_class_names' attribute to accommodate for variations.") + + category_value2id = find_alternate_classname(category_value2id, category_values, alternative_class_names) # Else create the mapping from the labels found in the data - else: + if len(category_value2id) != len(category_values): for c in category_values: if c not in category_value2id: category_value2id[c] = len(category_value2id) @@ -239,30 +319,11 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict if data[i][2] in category_value2id.values(): label_data_[data[i][2]] = label_data_[data[i][2]] + 1 - logger.info("Original number of samples per label: %s",label_data_) - # Undersampling data - if category_undersample is None or category_undersample == '': - min_label = min(label_data_.values()) - - else: - if category_undersample not in label_data_.keys() and category_undersample in category_value2id.keys(): - min_label = label_data_[category_value2id[category_undersample]] - else: - min_label = label_data_[category_undersample] + logger.info("Original number of samples per label: %s", label_data_) data_undersampled = [] - label_data_counter = {v: 0 for v in category_value2id.values()} - - for sample in data: - if label_data_counter[sample[-1]] < min_label: - data_undersampled.append(sample) - label_data_counter[sample[-1]] += 1 - - label_data = {v: 0 for v in category_value2id.values()} - for i in range(len(data_undersampled)): - if data_undersampled[i][2] in category_value2id.values(): - label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1 - logger.info("Updated number of samples per label (for 2-phase learning): %s",label_data) + if config and config.model.phase_number != 0: + data_undersampled = undersample_data(data, category_value2id, label_data_, config) return data, data_undersampled, category_value2id