In [None]:
import os
import json
import pandas as pd
from datetime import date
from medcat.cat import CAT
from medcat.stats.stats import get_stats

In [None]:
# Compact YYYYMMDD date stamp, used below to name the trained model
today = date.today().strftime("%Y%m%d")

In [None]:
# Location of the MedCATtrainer export used for supervised training.
ann_dir = "working_with_cogstack/data/annotated_docs/"
mctrainer_export_path = os.path.join(ann_dir, "MedCAT_Export_With_Text_2021-08-25_19_55_45.json")  # name of your mct export

# Directory that holds the model packs (both the one loaded and the one saved).
model_dir = 'working_with_cogstack/models/modelpack'

modelpack = ''  # name of modelpack
model_pack_path = os.path.join(model_dir, modelpack)

# Name (not a path) of the trained model pack. The save cell joins this onto
# model_dir, so it must not contain model_dir itself; otherwise the directory
# ends up duplicated in the final save path.
output_modelpack = f"{today}_trained_model"

# Add training filter if needed
snomed_filter_path = None  # path to snomed filter

In [None]:
# Create CAT - the main class from medcat used for concept annotation
# The model pack is loaded from model_pack_path, which is set in the configuration cell.
cat = CAT.load_model_pack(model_pack_path)
cat.config.components.linking.filters.cuis = set()  # To remove existing filters

# Set filter

This will speed up training, as you will only train a select number of concepts at once.

In [None]:
# Build the extra CUI filter that is handed to supervised training.
if snomed_filter_path:
    # Read the filter from a JSON file; the context manager makes sure the
    # file handle is closed as soon as the contents have been parsed.
    with open(snomed_filter_path) as filter_file:
        snomed_filter = set(json.load(filter_file))
else:
    # No filter file configured: fall back to every CUI known to the model's CDB.
    snomed_filter = set(cat.cdb.cui2info.keys())


# Train

In [None]:
# Load the MedCATtrainer export (json is already imported in the first cell).
with open(mctrainer_export_path) as f:
    data = json.load(f)

# Run supervised training on the export with the loaded model.
cat.trainer.train_supervised_raw(
    data=data,
    nepochs=3,
    reset_cui_count=False,
    print_stats=True,
    use_filters=True,
    extra_cui_filter=snomed_filter,  # If no filter is set, remove this line
)


# Stats

In [None]:
# Compute evaluation statistics for the trained model on the same export.
# The eight returned values are unpacked by position; the per-CUI mappings
# (cui_prec, cui_rec, cui_f1) are inspected in the next cell.
(
    fps,
    fns,
    tps,
    cui_prec,
    cui_rec,
    cui_f1,
    cui_counts,
    examples,
) = get_stats(cat, data, use_project_filters=True)

In [None]:
# Example concept to inspect; change this to a CUI present in your export.
cui = "22298006"  # Myocardial infarction

# Print F1, precision and recall for the chosen concept. The example CUI is
# hard-coded, so guard against it being absent from the evaluated data
# instead of failing the cell with a KeyError.
if cui in cui_f1:
    print(cui_f1[cui], cui_prec[cui], cui_rec[cui])
else:
    print(f"No stats available for CUI {cui}; it was not found in the evaluated data")

# Save

Also remember that you can save specific components within the modelpack, rather than creating a new one.

In [None]:
# save modelpack
# The save target is built by joining the output_modelpack setting onto model_dir.
cat.save_model_pack(os.path.join(model_dir, output_modelpack))

# Test

In [None]:
# Short sample sentence used as a quick sanity check of the trained model.
text = "The pateint has hypertension and an MI"
# Run the model on the sample text; the result is displayed in the next cell.
doc = cat.get_entities(text)

In [None]:
# Display the result of cat.get_entities for the sample text.
doc