In [1]:
import json
import os
from datetime import date
from medcat.cat import CAT
from medcat.meta_cat import MetaCAT
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT

In [2]:
# Optional: turn on INFO-level logging for this notebook session.
# force=True replaces any handlers already attached to the root logger,
# so this call takes effect even if logging was configured earlier.
import logging
logging.basicConfig(level=logging.INFO,force=True)

#### 💡 To understand the model loading and other functionalities, please refer to the 'meta_annotation_training.ipynb' notebook

In [3]:
# Paths used by every later cell; replace both placeholders before running.
# NOTE(review): later cells pass model_pack to os.walk / os.path.join, i.e. they treat it
# as a directory — presumably this must point to the unpacked model pack folder; confirm.
model_pack = '<enter path to the model pack>' # .zip model pack location
mctrainer_export = "<enter mct export location>"  # name of your mct export

We won't load the models at this stage, as they need to be loaded separately later. <br> Let's check which meta models are present in the model pack directory.

In [4]:
# Collect the task names of all meta models contained in the model pack.
# Every sub-directory whose name begins with 'meta_' is treated as one meta model;
# the prefix is stripped to obtain the task name.
meta_prefix = 'meta_'
meta_model_names = []
for _root, subdirs, _files in os.walk(model_pack):
    meta_model_names.extend(
        name[len(meta_prefix):] for name in subdirs if name.startswith(meta_prefix)
    )

print("Meta models:",meta_model_names)

# Class weights 

Class weights can be adjusted to give more importance to specific classes. Generally, class weights are set in favour of minority classes (classes with fewer samples) to boost their performance.
<br><br>To use class weights, we have 2 options:
<br>1. calculate class weights based on class distribution
<br>2. using specified class weights


<b>#option 1 </b><br>
metacat.config.train['class_weights'] = []<br>
metacat.config.train['compute_class_weights'] = True<br>
<br>
<b>#option 2</b><br>
metacat.config.train['class_weights'] = [0.4,0.3,0.1]<br>

<b>NOTE:</b> Make sure to correctly map the class weights to their corresponding class index. <br>To check the index assigned to the classes, use: <br>`print(mc.config.general['category_value2id'])`
<br>This will print a dictionary where the class names and their corresponding IDs (indices) are displayed. <br>
The first position in the class weight list corresponds to the class with ID 0 in the dictionary, and so on.

# 2 phase learning for training

2 phase learning is used to mitigate class imbalance. In 2 phase learning, the models are trained twice: <br> 
Phase 1: trains for minority class(es) by undersampling data so that there is no class imbalance
<br>Phase 2: trains for all classes

Phase 1 ensures that the model learns minority class(es) and captures the details correctly.
<br> Phase 2 is when the model is expected to learn the majority class as it is trained on the entire dataset.

Paper reference - https://ieeexplore.ieee.org/document/7533053
<br>Make sure to use class weights in favour of minority classes with 2 phase learning

In [5]:
#--------------------------------Phase 1--------------------------------
def run_phase_1(meta_model,class_wt_phase1 = None):
    """Run phase 1 of 2 phase learning for a single meta model.

    Loads the meta model from the model pack together with its pre-defined
    ``config_ph1.json``, trains it on the MedCATtrainer export and writes the
    training report to ``meta_<meta_model>_results_phase1.json`` inside
    ``save_dir_path``.

    Relies on the notebook-level variables ``model_pack``, ``mctrainer_export``
    and ``save_dir_path``; they must be defined before this function is called.

    Args:
        meta_model: Name of the meta model (directory name without the
            ``meta_`` prefix).
        class_wt_phase1: Optional list of class weights. When empty or None,
            the class weights from the phase 1 config are left untouched.
    """
    meta_model_dir = os.path.join(model_pack,"meta_"+meta_model)

    #Loading the pre-defined config for phase 1
    with open(os.path.join(meta_model_dir,"config_ph1.json")) as f:
        config_ph1 = json.load(f)
    mc = MetaCAT.load(save_dir_path=meta_model_dir,config_dict = config_ph1)

    if class_wt_phase1:
        mc.config.train['class_weights'] = class_wt_phase1

    #You can change the number of epochs, remember to keep them higher for phase 1
    mc.config.train['nepochs'] = 40

    results = mc.train_from_json(mctrainer_export, save_dir_path=save_dir_path)

    # Save results; the context manager guarantees the report is flushed and the file closed
    results_path = os.path.join(save_dir_path,'meta_'+meta_model+'_results_phase1.json')
    with open(results_path, 'w') as results_file:
        json.dump(results['report'], results_file)

#--------------------------------Phase 2--------------------------------
def run_phase_2(meta_model,class_wt_phase2 = None):
    """Run phase 2 of 2 phase learning for a single meta model.

    Loads the meta model from the model pack together with its pre-defined
    ``config_ph2.json``, trains it on the MedCATtrainer export and writes the
    training report to ``meta_<meta_model>_results_phase2.json`` inside
    ``save_dir_path``.

    Relies on the notebook-level variables ``model_pack``, ``mctrainer_export``
    and ``save_dir_path``; they must be defined before this function is called.

    Args:
        meta_model: Name of the meta model (directory name without the
            ``meta_`` prefix).
        class_wt_phase2: Optional list of class weights. When empty or None,
            the class weights from the phase 2 config are left untouched.
    """
    meta_model_dir = os.path.join(model_pack,"meta_"+meta_model)

    #Loading the pre-defined config for phase 2
    with open(os.path.join(meta_model_dir,"config_ph2.json")) as f:
        config_ph2 = json.load(f)

    mc = MetaCAT.load(save_dir_path=meta_model_dir,config_dict = config_ph2)

    if class_wt_phase2:
        mc.config.train['class_weights'] = class_wt_phase2

    #You can change the number of epochs
    mc.config.train['nepochs'] = 20

    results = mc.train_from_json(mctrainer_export, save_dir_path=save_dir_path)

    # Save results; the context manager guarantees the report is flushed and the file closed
    results_path = os.path.join(save_dir_path,'meta_'+meta_model+'_results_phase2.json')
    with open(results_path, 'w') as results_file:
        json.dump(results['report'], results_file)

#--------------------------------Driver--------------------------------
# Train the first meta cat model
# Fail early with a clear message instead of a bare IndexError when the
# model pack contains no 'meta_*' directory.
if not meta_model_names:
    raise ValueError(
        f"No meta models (directories starting with 'meta_') were found under {model_pack!r}"
    )
meta_model = meta_model_names[0]

# to overwrite the existing model, resave the fine-tuned model with the same model pack dir
meta_cat_task = meta_model
save_dir_path = os.path.join(model_pack,"meta_"+ meta_cat_task)

# To use your own class weights instead of the pre-defined ones for the 2 phases, put the weights in the lists below
# (an empty list keeps the class weights from the phase config)
class_wt_phase1 = [] # Example [0.4,0.4,0.2]
class_wt_phase2 = [] # Example [0.4,0.3,0.3]


# Train 2 phase learning
print("*** Training meta cat: ",meta_model)
print("Beginning Phase 1...")
run_phase_1(meta_model,class_wt_phase1)
print("Beginning Phase 2...")
run_phase_2(meta_model,class_wt_phase2)

# Generating synthetic data

You can generate synthetic data to help mitigate class imbalance. <br> Use this code to generate synthetic data using an LLM - [synthetic data generation gist](https://gist.github.com/shubham-s-agarwal/401ef8bf6cbbd66fa0c76a8fbfc1f6c4) <br> <b>NOTE</b>: the generated data will require a manual quality check to ensure that only high-quality, relevant data is used for training.

The data generated from the gist code and the format of the data required by MedCAT are different, requiring manual formatting at the moment. We will update this module to include the code to handle the same.

In [None]:
# To run the training with original + synthetic data
# Follow all the same steps as above up to and including loading the model.

# Load the meta model explicitly: the `mc` objects created inside run_phase_1 / run_phase_2
# are local to those functions, so no `mc` exists at notebook level.
# Uses `model_pack`, `meta_model` and `save_dir_path` from the earlier cells.
mc = MetaCAT.load(save_dir_path=os.path.join(model_pack,"meta_"+meta_model))

# the format expected is a list with one entry per sample:
# [[['text','of','the','document'], [index of medical entity], "label"],
#  [['text','of','the','document'], [index of medical entity], "label"]]

# Placeholder - replace with your formatted synthetic samples
synthetic_data_export = [[],[],[]]

results = mc.train_from_json(mctrainer_export, save_dir_path=save_dir_path,data_oversampled=synthetic_data_export)

# Save results; the context manager guarantees the report is flushed and the file closed
results_path = os.path.join(save_dir_path,'meta_'+meta_model+'_results.json')
with open(results_path, 'w') as results_file:
    json.dump(results['report'], results_file)