# Evaluate a MedCATtrainer project export

In [None]:
import warnings
# Silence library noise (deprecation/future warnings from pandas, torch, sklearn, ...).
warnings.filterwarnings("ignore")
# ...but keep UserWarnings raised by this notebook's own code (module "__main__") visible:
# eval() deliberately warns when a MetaCAT model's task is missing from the export, and the
# blanket "ignore" above would otherwise hide that message.
warnings.filterwarnings("always", category=UserWarning, module="__main__")

In [None]:
# Only use for development
# IPython magics (not plain Python): load the autoreload extension in mode 2 so that edits to
# imported local modules such as mct_analysis are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from mct_analysis import MedcatTrainer_export

## Load MCT exports and MedCAT model

In [None]:
import os

# MedCATtrainer export(s) to evaluate; all of them are loaded into a single MedcatTrainer_export.
lst_mct_export = [
    '../../data/medcattrainer_export/20220817_KCH_export/MedCAT_Export_With_Text_2022-08-17_13_34_54.json',
    '../../data/medcattrainer_export/20220817_KCH_export/all_from_medcat_pc_fixed_MedCAT_Export_With_Text_2021-01-15_22_01_45_correct.json',
]
# Alternative single export:
# lst_mct_export = ['../../data/medcattrainer_export/' + 'MedCAT_Export_With_Text_2022-07-20_10_58_45.json']

# MedCAT model pack: used for the model card and, via '<mct_model>/meta_<task>', the MetaCAT models.
# The default path is specific to one machine, so it can be overridden with the
# MEDCAT_MODEL_PACK environment variable instead of editing the notebook.
mct_model = os.environ.get("MEDCAT_MODEL_PACK", "/Users/shek/Documents/medcat_models/medcat_model_pack_v1.2.8")
mct = MedcatTrainer_export(mct_export_paths=lst_mct_export, model_pack_path=mct_model)
# To inspect the export(s) without a model pack:
# mct = MedcatTrainer_export(lst_mct_export)

# Evaluate model card

In [None]:
# Load the model card
# Show the model card of mct.cat (the MedCAT model loaded from mct_model) as a dict.
mct.cat.get_model_card(as_dict=True)

In [None]:
# Look to potentially remove any filters that exist in the model.
# The triple-quoted string below is disabled code: it is only a string literal and is not executed.
# Remove the quotes to display the linking filters of the loaded model.
"""
mct.cat.config.linking['filters']
"""

# Evaluate MCT export

### View all Annotations and Meta-annotations created

In [None]:
# Load all annotations created
# Later cells rely on anns_df columns such as 'project', 'document_name', 'validated', 'deleted',
# 'killed', 'irrelevant', 'correct', 'alternative', 'acc' and one column per meta-annotation task.
anns_df = mct.annotation_df()
anns_df

### Summarise all Meta-annotations

In [None]:
# Meta_annotation summary
# Print value counts for every column to the right of 'acc' (presumably the meta-annotation
# columns produced by annotation_df() -- confirm if its column order changes).
meta_ann_columns = anns_df.loc[:, 'acc':].columns[1:]
for column_name in meta_ann_columns:
    counts = anns_df[column_name].value_counts()
    print(counts)

In [None]:
# Meta_annotation summary of combinations
# Count each distinct combination of values across the columns to the right of 'acc'
# (presumably the meta-annotation columns of annotation_df()).
# Series.iteritems() was deprecated in pandas 1.5 and removed in 2.0; Series.items() is the replacement.
for k, v in anns_df.loc[:, 'acc':].iloc[:, 1:].value_counts().items():
    print(k, v)

### Overview of the entire MCT export
This lists the names of all projects in the export and the names of all annotated documents.

In [None]:
# projects
# Distinct project names across all loaded exports.
anns_df['project'].unique()

In [None]:
# documents
# Distinct document names across all loaded exports.
anns_df['document_name'].unique()

# Annotation Summary

In [None]:
### to delete
# Optional: drop a project from the loaded export before summarising (disabled; the index 8 is
# specific to one export, so check mct.mct_export['projects'] before enabling it).
# del mct.mct_export['projects'][8]
# 

In [None]:
# Per-concept summary of the annotations (presumably a DataFrame, displayed in the next cell).
performance_summary_df = mct.concept_summary()

In [None]:
performance_summary_df

# Annotator stats

In [None]:
# User Stats
# Per-annotator statistics as computed by mct.user_stats().
mct.user_stats()

In [None]:
# Plot the user stats; save_fig=True with save_fig_filename presumably writes the plot to that HTML
# file (relative path, i.e. the current working directory).
mct.plot_user_stats(save_fig=True, save_fig_filename='20220817_KCH_user_mct_annotations.html')

### Generate report
Combine the outputs of all of the above functions into a single Excel report.

In [None]:
# Show the documentation of generate_report.
help(mct.generate_report)

In [None]:
# Write the combined report; the path is relative, so the file lands in the current working directory.
mct.generate_report(path='20220817_KCH_mct_report.xlsx')

# Meta Annotations

Use the helper function below to rename meta-annotation tasks and their values.

### Rename meta annotation tasks

In [None]:
# select which meta tasks to rename
# {old task name: new task name}
rename_meta_anns = {'Subject/Experiencer':'Subject'}
# select which meta values for the corresponding meta tasks.
# {task name: {old value: new value}} -- note the key here is the *renamed* task name ('Subject');
# presumably the values are renamed after the task rename is applied -- confirm in rename_meta_anns.
rename_meta_anns_values = {'Subject':{'Relative':'Other'}}
# run the renaming
mct.rename_meta_anns(meta_anns2rename=rename_meta_anns, meta_ann_values2rename=rename_meta_anns_values)

In [None]:
# Rebuild anns_df so later cells see the renamed tasks/values.
anns_df = mct.annotation_df()
anns_df.head()

### Performance evaluation

In [None]:
mct.cat.get_model_card(as_dict=True)

In [None]:
# Check meta models
# Names of the MetaCAT models listed under 'MetaCAT models' in the model card.
meta_models = list(mct.cat.get_model_card(as_dict=True)['MetaCAT models'].keys())
meta_models

In [None]:
# Distinct annotated values for each MetaCAT task.
# NOTE(review): assumes every model name matches a column of anns_df (e.g. after the
# 'Subject/Experiencer' -> 'Subject' rename); otherwise this raises KeyError.
for meta_ann in meta_models:
    print(meta_ann, anns_df[meta_ann].unique())

## Meta annotation performance summary

In [None]:
import json

import torch
import math
from torch import nn
import numpy as np
import pandas as pd
from typing import List, Optional, Tuple, Any, Dict
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
from sklearn.metrics import classification_report, precision_recall_fscore_support
from medcat.utils.meta_cat.ml_utils import eval_model

In [None]:
from medcat.meta_cat import MetaCAT
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values

In [None]:
def create_batch_piped_data(data: List, start_ind: int, end_ind: int, device: torch.device, pad_id: int) -> Tuple:
    r''' Slice one batch out of ``data``, pad it and move it to ``device`` as long tensors.

    Args:
        data (`List`):
            Rows of the form ``[input_ids, cpos]`` or ``[input_ids, cpos, label]``. Whether labels
            are present is decided from the first row of the whole list.
        start_ind (`int`):
            Index of the first row of the batch.
        end_ind (`int`):
            Index one past the last row of the batch (slice semantics, so it may exceed len(data)).
        device (`torch.device`):
            Device the returned tensors are moved to.
        pad_id (`int`):
            Token id appended to rows shorter than the longest row.

    Returns:
        x: long tensor of padded token ids, one row per sample. Rows are padded to the longest
            ``input_ids`` in the *whole* of ``data``, not just this batch.
        cpos: long tensor with the center position of each sample.
        y: long tensor of labels, or None when the rows have no label column.
    '''
    batch = data[start_ind:end_ind]
    # Padding width is taken over the full dataset, so every batch has the same width.
    longest = max([len(row[0]) for row in data])

    padded_rows = []
    for row in batch:
        token_ids = row[0]
        padding = [pad_id] * max(0, longest - len(token_ids))
        padded_rows.append(token_ids[0:longest] + padding)
    centre_positions = [row[1] for row in batch]

    labels = None
    if len(data[0]) == 3:
        labels = torch.tensor([row[2] for row in batch], dtype=torch.long).to(device)

    x = torch.tensor(padded_rows, dtype=torch.long).to(device)
    cpos = torch.tensor(centre_positions, dtype=torch.long).to(device)
    return x, cpos, labels


def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: TokenizerWrapperBase,
               return_metrics: bool = False) -> Any:
    r''' Run a trained MetaCAT model over ``data`` and return the predicted class ids.

    Notebook-local variant of eval_model; defining it rebinds the name imported from
    medcat.utils.meta_cat.ml_utils for later calls in this notebook.

    Args:
        model (`nn.Module`):
            Called as ``model(x, cpos, ignore_cpos=...)``; expected to return one row of logits per sample.
        data (`List`):
            Samples ``[input_ids, cpos]`` or ``[input_ids, cpos, label_id]`` (e.g. the output of
            ``encode_category_values``). The label column is only needed when ``return_metrics`` is True.
        config (`ConfigMetaCAT`):
            Reads ``general['device']``, ``general['batch_size_eval']``, ``model['padding_idx']``,
            ``model['ignore_cpos']`` and, for metrics, ``train['score_average']``.
        tokenizer (`TokenizerWrapperBase`):
            Unused; kept so the call signature matches medcat's ``eval_model``.
        return_metrics (`bool`, defaults to False):
            Also compute precision/recall/F1/support of the predictions against the labels.

    Returns:
        A 1-D numpy array with the predicted class id of every sample, in input order. When
        ``return_metrics`` is True, a dict with keys 'predictions', 'precision', 'recall', 'f1'
        and 'support' instead.
    '''
    device = torch.device(config.general['device'])
    batch_size_eval = config.general['batch_size_eval']
    pad_id = config.model['padding_idx']
    ignore_cpos = config.model['ignore_cpos']

    model.to(device)
    model.eval()

    all_logits = []
    with torch.no_grad():
        for batch_start in range(0, len(data), batch_size_eval):
            x, cpos, _ = create_batch_piped_data(data, batch_start, batch_start + batch_size_eval,
                                                 device=device, pad_id=pad_id)
            logits = model(x, cpos, ignore_cpos=ignore_cpos)
            all_logits.append(logits.detach().cpu().numpy())

    if all_logits:
        predictions = np.argmax(np.concatenate(all_logits, axis=0), axis=1)
    else:
        # np.concatenate raises on an empty list; an empty dataset simply has no predictions.
        predictions = np.array([], dtype=np.int64)

    if not return_metrics:
        return predictions

    # Imported lazily: only needed when metrics are requested.
    from sklearn.metrics import precision_recall_fscore_support
    y_eval = [sample[2] for sample in data]
    precision, recall, f1, support = precision_recall_fscore_support(
        y_eval, predictions, average=config.train['score_average'])
    return {'predictions': predictions, 'precision': precision, 'recall': recall, 'f1': f1, 'support': support}


In [None]:
def eval(metacat_model, mct_export):
    """Run a MetaCAT model over a MedCATtrainer export and return its predictions.

    NOTE: this shadows Python's builtin ``eval`` for the rest of the notebook; the name is kept
    because other cells call it.

    Args:
        metacat_model: A loaded ``MetaCAT`` model. Its ``config.train`` is modified in place:
            ``shuffle_data`` is set to False and ``prerequisites`` / ``cui_filter`` are cleared, so
            no prerequisite or CUI filtering is applied when building the samples.
        mct_export: Path to a MedCATtrainer export ``.json`` file, or an already-loaded export dict
            (e.g. ``mct.mct_export``).

    Returns:
        ``{'predictions': <predicted class ids>, 'meta_values': <category value -> id dict>}`` on
        success. If the model's ``category_name`` is not a meta task in the export, a
        ``UserWarning`` is issued and ``{category_name: "<category_name> does not exist"}`` is
        returned instead.
    """
    g_config = metacat_model.config.general
    t_config = metacat_model.config.train
    t_config['shuffle_data'] = False
    t_config['prerequisites'] = {}
    t_config['cui_filter'] = {}

    if isinstance(mct_export, dict):
        data_loaded: Dict = mct_export
    else:
        # JSON exports are UTF-8; don't depend on the platform default encoding (e.g. cp1252 on Windows).
        with open(mct_export, 'r', encoding='utf-8') as f:
            data_loaded = json.load(f)

    # Build samples for every meta task in the export.
    assert metacat_model.tokenizer is not None
    data = prepare_from_json(data_loaded, g_config['cntx_left'], g_config['cntx_right'], metacat_model.tokenizer,
                             cui_filter=t_config['cui_filter'],
                             replace_center=g_config['replace_center'], prerequisites=t_config['prerequisites'],
                             lowercase=g_config['lowercase'])

    category_name = g_config['category_name']
    if category_name not in data:
        warnings.warn(f"The meta_model {category_name} does not exist in this MedCATtrainer export.", UserWarning)
        return {category_name: f"{category_name} does not exist"}

    # Encode the values with the model's own category_value2id mapping.
    samples, meta_values = encode_category_values(data[category_name],
                                                  existing_category_value2id=g_config['category_value2id'])
    print(meta_values)
    print(len(samples))

    predictions = eval_model(metacat_model.model, samples, config=metacat_model.config,
                             tokenizer=metacat_model.tokenizer)
    return {'predictions': predictions, 'meta_values': meta_values}


In [None]:
# Gold-standard annotations only: validated and not deleted/killed/irrelevant. prepare_from_json
# applies the same filter, so for each task the non-null gold values below are expected to line up
# one-to-one, in order, with the model's predictions.
# NOTE(review): predictions are computed on `eval_export_path`, while meta_df comes from the
# exports in lst_mct_export; the alignment only holds if both contain the same annotations in the
# same order -- confirm.
eval_export_path = '../../data/medcattrainer_export/test.json'
meta_df = anns_df[(anns_df['validated'] == True) & (anns_df['deleted'] == False)
                  & (anns_df['killed'] == False) & (anns_df['irrelevant'] == False)]
meta_df = meta_df.reset_index(drop=True)
for meta_model in meta_models:
    print(f'Checking metacat model: {meta_model}')
    _meta_model = MetaCAT.load(mct_model + '/meta_' + meta_model)
    meta_results = eval(_meta_model, eval_export_path)
    if 'predictions' not in meta_results:
        # eval() returns {category_name: "... does not exist"} when the export lacks this task.
        print(f'Skipping {meta_model}: {meta_results}')
        continue
    # id -> value, to turn predicted class ids back into meta-annotation values.
    _meta_values = {v: k for k, v in meta_results['meta_values'].items()}
    print(_meta_values)
    predictions = meta_results['predictions']
    has_gold_value = meta_df[meta_model].notnull()
    if int(has_gold_value.sum()) != len(predictions):
        raise ValueError(
            f'{meta_model}: {len(predictions)} predictions for {int(has_gold_value.sum())} annotated '
            f'rows in meta_df; cannot align predictions with annotations.')
    # Rows without a gold value get NaN; every other row consumes the next prediction in order.
    pred_iter = iter(predictions)
    pred_meta_values = [_meta_values.get(next(pred_iter), np.nan) if has_value else np.nan
                        for has_value in has_gold_value]
    meta_df.insert(meta_df.columns.get_loc(meta_model) + 1, 'predict_' + meta_model, pred_meta_values)
    

In [None]:
meta_df

In [None]:
# Distribution of the gold 'Subject' values...
meta_df['Subject'].value_counts()

In [None]:
# ...versus the distribution of the model's 'Subject' predictions.
meta_df['predict_Subject'].value_counts()

### Meta annotation summary stats 

In [None]:
# meta_task is otherwise first assigned in a later cell, so a top-to-bottom run raised NameError here.
meta_task = 'Presence'
meta_df[meta_task].unique()

In [None]:
# Meta task to summarise and its distinct gold values (may include NaN for rows without that task).
meta_task = 'Presence'
meta_task_values = meta_df[meta_task].unique()

In [None]:
# Rows whose gold value is the string 'True' (meta values are stored as strings).
meta_df[meta_df[meta_task] == 'True']

In [None]:
meta_task_values[0]

In [None]:
# Split gold vs. predicted values of meta_task by gold value, keeping one frame per value in
# meta_task_dfs (a single temp_df was overwritten on every iteration, so only the last survived).
# temp_df / task still hold the last value's frame / value for the following cells.
meta_task_dfs = {}
for task in meta_task_values:
    temp_df = meta_df[meta_df[meta_task] == task][[meta_task, 'predict_' + meta_task]]
    meta_task_dfs[task] = temp_df

In [None]:
# Frame of the last value iterated over in the previous cell.
temp_df

In [None]:
meta_df

In [None]:
# Last value of meta_task_values iterated over.
task

### Junk

In [None]:
# NOTE(review): scratch cell. `meta_model` is whatever the evaluation loop above left behind (its
# last iteration), so this reloads the last MetaCAT model and re-runs eval() on test.json.
_meta_model = MetaCAT.load(mct_model+'/meta_'+meta_model)
    
meta_results=eval(_meta_model, '../../data/medcattrainer_export/test.json')

In [None]:
meta_results

In [None]:
# NOTE(review): pred_meta is only assigned in a later cell; on a top-to-bottom run this raises NameError.
pred_meta

In [None]:
#########
# Same gold-standard filter as the evaluation cell, but without reset_index(drop=True), so the
# rows keep anns_df's index labels.
meta_df=anns_df[(anns_df['validated']==True)&(anns_df['deleted']==False)&(anns_df['killed']==False)&(anns_df['irrelevant']==False)]


In [None]:
# NOTE(review): `results` is not defined anywhere in this notebook (stale kernel state); its
# structure (list of {name: {'predictions': ...}}) is inferred from this indexing only.
test_results = list(results[0].values())[0]['predictions']



# Hard-coded 'Presence' value -> id mapping and its inverse (id -> value).
# NOTE(review): these ids should come from the model's config.general['category_value2id'] -- confirm.
classe = {"False": 2,"Hypothetical": 1,"True": 0}
classes = { v:k for k,v in classe.items()}
pred_meta = []

# Walk the gold 'Presence' column: rows without a gold value get NaN, every other row consumes
# the next prediction in order.
counter = 0
for meta_value in meta_df['Presence']:
    if pd.isnull(meta_value):
        pred_meta.append(np.nan)
    else:
        pred_meta.append(classes.get(test_results[counter],np.nan))
        counter+=1

In [None]:
pred_meta

In [None]:
# Number of predictions consumed by the loop above.
counter

In [None]:
# Hard-coded 'Presence' value -> id mapping and its inverse (id -> value).
# NOTE(review): `results` is not defined in this notebook, and the ids should come from the
# model's config.general['category_value2id'] -- confirm.
classe = {"False": 2,"Hypothetical": 1,"True": 0}
classes = { v:k for k,v in classe.items()}
# Map each predicted id back to its value; unknown ids become NaN.
test = [classes.get(i, np.nan) for i in list(results[0].values())[0]['predictions']]

In [None]:
# NOTE(review): scratch cells. `results` is not defined anywhere in this notebook, so every cell
# that indexes it depends on state from an earlier kernel session.
results

In [None]:
# Number of predictions per task in `results` (first three entries).
print(len(list(results[0].values())[0]['predictions']))
print(len(list(results[1].values())[0]['predictions']))
print(len(list(results[2].values())[0]['predictions']))

In [None]:
# NOTE(review): 'y_eval' is not a key returned by this notebook's eval() ('predictions',
# 'meta_values'); `results` presumably came from an earlier version -- confirm.
print(len(list(results[0].values())[0]['y_eval']))
print(len(list(results[1].values())[0]['y_eval']))
print(len(list(results[2].values())[0]['y_eval']))

In [None]:
# Distinct values of each task over *all* annotations (anns_df, not the filtered meta_df).
print(anns_df['Presence'].unique())
print(anns_df['Subject'].unique())
print(anns_df['Time'].unique())

In [None]:
# Rows of anns_df without a value for each task.
print(anns_df['Presence'].isnull().sum())
print(anns_df['Subject'].isnull().sum())
print(anns_df['Time'].isnull().sum())

In [None]:
len(anns_df['Subject'])

In [None]:
# NOTE(review): compares prediction counts with missing-value counts of anns_df, which also contains
# rows (non-validated/deleted/killed/irrelevant) that prepare_from_json drops -- confirm intent.
print(len(list(results[0].values())[0]['predictions'])-anns_df['Presence'].isnull().sum())
print(len(list(results[1].values())[0]['predictions'])-anns_df['Subject'].isnull().sum())
print(len(list(results[2].values())[0]['predictions'])-anns_df['Time'].isnull().sum())

In [None]:
# NOTE(review): `test_df` is not defined in this notebook and 1015 is an unexplained hard-coded row count.
print(1015-test_df['Presence'].isnull().sum())
print(1015-test_df['Subject'].isnull().sum())
print(1015-test_df['Time'].isnull().sum())

In [None]:
_meta_model = MetaCAT.load(mct_model+'/meta_'+'Presence')

In [None]:
_meta_model.config.general

In [None]:
_meta_model.config.train

In [None]:
_meta_model.config.model

In [None]:
# Config values read by eval_model. Only the last expression of a cell is displayed; the first
# four lookups are evaluated and discarded.
_meta_model.config.general['device'] # device name only; no torch.device is created here
_meta_model.config.general['batch_size_eval']
_meta_model.config.model['padding_idx']
_meta_model.config.model['ignore_cpos']
_meta_model.config.train['class_weights']

In [None]:
# NOTE(review): `presence_results` is not defined in this notebook.
classe = {"False": 2,"Hypothetical": 1,"True": 0}
classes = { v:k for k,v in classe.items()}
test = [classes.get(i, np.nan) for i in presence_results]

In [None]:
classes.get(1, 'Nan')

In [None]:
test = [classes.get(i, np.nan) for i in presence_results]

In [None]:
test

In [None]:
presence_results

In [None]:
# NOTE(review): 878 is a hard-coded count whose origin is not shown here.
len(anns_df['Presence'])-878

In [None]:
# Gold annotations: marked correct or alternative, and not deleted/killed/irrelevant.
# The OR must be parenthesised: `&` binds tighter than `|`, so without the parentheses the
# deleted/killed/irrelevant exclusions only applied to the 'alternative' rows.
anns_df[((anns_df['correct']==True)|(anns_df['alternative']==True))&(anns_df['deleted']==False)&(anns_df['killed']==False)&(anns_df['irrelevant']==False)]

In [None]:
# Distribution of each review flag over all annotations.
anns_df['alternative'].value_counts()

In [None]:
anns_df['correct'].value_counts()

In [None]:
anns_df['deleted'].value_counts()

In [None]:
anns_df['killed'].value_counts()

In [None]:
anns_df['irrelevant'].value_counts()

In [None]:
# NOTE(review): `results` is not defined in this notebook.
list(results[0].keys())[0]

In [None]:
def prepare_from_json(data: Dict,
                      cntx_left: int,
                      cntx_right: int,
                      tokenizer: 'TokenizerWrapperBase',
                      cui_filter: Optional[set] = None,
                      replace_center: Optional[str] = None,
                      prerequisites: Optional[Dict] = None,
                      lowercase: bool = True) -> Dict:
    """ Convert a MedCATtrainer export into MetaCAT samples, grouped by meta-annotation task.

    Defining this in the notebook rebinds the name imported from medcat.utils.meta_cat.data_utils,
    so later calls in this notebook (including eval()) use this version. It is not very efficient
    (the one working with spacy documents as part of the meta_cat.pipe method is much better); if
    your dataset is > 1M documents think about rewriting it.

    Args:
        data (`dict`):
            Loaded output of MedCATtrainer. If we have a `my_export.json` from MedCATtrainer, then data = json.load(<my_export>).
        cntx_left (`int`):
            Number of tokens of context to keep to the left of the center token.
        cntx_right (`int`):
            Number of tokens of context to keep to the right of the center token.
        tokenizer (`medcat.tokenizers.meta_cat_tokenizers`):
            Callable that, given a string, returns a dict with 'input_ids' and 'offset_mapping'
            (the (start, end) character span of every token).
        cui_filter (`set`, optional):
            If non-empty, only annotations whose CUI is in this set are used. None or empty means no filter.
        replace_center (`str`, optional):
            If not None the concept's tokens are replaced with the tokens of this string
            (lowercased first when `lowercase` is True).
        prerequisites (`dict`, optional):
            A map of prerequisites, for example our data has two meta-annotations (experiencer, negation). Assume I want to create
            a dataset for `negation` but only in those cases where `experiencer=patient`, my prerequisites would be:
                {'Experiencer': 'Patient'} - Take care that the CASE has to match whatever is in the data.
            None (the default) means no prerequisites.
        lowercase (`bool`, defaults to True):
            Should the text be lowercased before tokenization

    Returns:
        out_data (`dict`):
            {'<meta task name>': [[<[token ids]>, <center position>, '<meta value>'], ...], ...}
            The center token is the token whose span contains the annotation start. Only annotations
            that are validated (or have no 'validated' key) and are not deleted, killed or irrelevant are used.
    """
    prerequisites = prerequisites or {}

    # Tokenize the replacement text once instead of once per annotation.
    replace_ids = None
    if replace_center is not None:
        if lowercase:
            replace_center = replace_center.lower()
        replace_ids = tokenizer(replace_center)['input_ids']

    out_data: Dict = {}
    for project in data['projects']:
        for document in project['documents']:
            text = str(document['text'])
            if lowercase:
                text = text.lower()
            if len(text) == 0:
                continue

            doc_text = tokenizer(text)
            input_ids = doc_text['input_ids']
            offsets = doc_text['offset_mapping']

            # A hack to support both 'annotations' (trainer export) and 'entities' (CAT output)
            for ann in document.get('annotations', document.get('entities', {}).values()):
                cui = ann['cui']
                if not _meets_prerequisites(ann, prerequisites):
                    continue
                if cui_filter and cui not in cui_filter:
                    continue
                if (not ann.get('validated', True) or ann.get('deleted', False)
                        or ann.get('killed', False) or ann.get('irrelevant', False)):
                    continue

                start = ann['start']
                end = ann['end']

                ind = _center_token_index(offsets, start)
                _start = max(0, ind - cntx_left)
                _end = min(len(input_ids), ind + 1 + cntx_right)
                tkns = input_ids[_start:_end]
                # Position of the center token inside tkns (< cntx_left near the document start).
                cpos = cntx_left + min(0, ind - cntx_left)

                if replace_ids is not None:
                    # First/last token covered by the concept. Default to the center token so a
                    # missing match never reuses indices from a previous annotation.
                    s_ind = e_ind = ind
                    for p_ind, pair in enumerate(offsets):
                        if pair[0] <= start < pair[1]:
                            s_ind = p_ind
                        if pair[0] < end <= pair[1]:
                            e_ind = p_ind
                    ln = e_ind - s_ind
                    tkns = tkns[:cpos] + replace_ids + tkns[cpos + ln + 1:]

                for meta_ann in _meta_anns_list(ann):
                    out_data.setdefault(meta_ann['name'], []).append([tkns, cpos, meta_ann['value']])

    return out_data


def _center_token_index(offsets: List, start: int) -> int:
    """Index of the first token whose (start, end) span contains `start`.

    Falls back to the last token when no span contains it, and to 0 when there are no tokens.
    """
    ind = 0
    for ind, pair in enumerate(offsets):
        if pair[0] <= start < pair[1]:
            break
    return ind


def _meets_prerequisites(ann: Dict, prerequisites: Dict) -> bool:
    """True if `ann` has every required meta-annotation value (always True without prerequisites or 'meta_anns')."""
    if 'meta_anns' not in ann or not prerequisites:
        return True
    meta_anns = ann['meta_anns']
    if not isinstance(meta_anns, dict):
        # Legacy list format: index by name so prerequisites work for both formats.
        meta_anns = {meta_ann['name']: meta_ann for meta_ann in meta_anns}
    return all(name in meta_anns and meta_anns[name]['value'] == value
               for name, value in prerequisites.items())


def _meta_anns_list(ann: Dict) -> List:
    """The annotation's meta-annotations as a list, accepting both the dict and the legacy list format."""
    meta_anns = ann.get('meta_anns', [])
    return list(meta_anns.values()) if isinstance(meta_anns, dict) else list(meta_anns)

In [None]:
# Build MetaCAT samples for every task in the loaded export with no CUI filter or prerequisites,
# using the context sizes and tokenizer of `_meta_model`, to count the samples per task.
test = prepare_from_json(mct.mct_export, cntx_left=_meta_model.config.general['cntx_left'],
                         cntx_right= _meta_model.config.general['cntx_right'],
                         tokenizer=_meta_model.tokenizer, cui_filter={},prerequisites={})


In [None]:
len(test['Subject'])

In [None]:
len(test['Presence'])

In [None]:
len(test['Time'])

In [None]:
mct.mct_export['projects']

In [None]:
_meta_model.config.train

In [None]:
# NOTE(review): the rename cell earlier renames 'Subject/Experiencer' to 'Subject' and anns_df is
# rebuilt afterwards, so this column likely no longer exists (KeyError) -- confirm.
anns_df['Subject/Experiencer'].unique()

In [None]:
# Last value left by the evaluation loop.
meta_model

In [None]:
_meta_model = MetaCAT.load(mct_model+'/meta_'+meta_model)
_meta_model.config.train['prerequisites']

In [None]:
_meta_model.config.train

In [None]:
_meta_model.config.general

In [None]:
# NOTE(review): `metacat_model` only exists as the parameter of eval(); at notebook level this is a
# NameError. `_meta_model` was probably intended -- confirm.
g_config = metacat_model.config.general
t_config = metacat_model.config.train

In [None]:
# NOTE(review): `_` is IPython's "previous output" variable, not something defined in this notebook,
# so the meaning of this cell depends on which cell ran last; use an explicit MetaCAT variable.
# The chained assignment sets both `p` and the model's train['prerequisites'] to the same {}.
g_config = _.config.general
t_config = _.config.train
p =_.config.train['prerequisites']={}
data = prepare_from_json(mct.mct_export, g_config['cntx_left'], g_config['cntx_right'], _.tokenizer,
                         cui_filter=t_config['cui_filter'], replace_center=g_config['replace_center'], prerequisites=p,
                         lowercase=g_config['lowercase'])

In [None]:
data

In [None]:
_.config.train['prerequisites']

In [None]:
test_Presence = MetaCAT.load(mct_model+'/meta_'+meta_models[0])

In [None]:
# NOTE(review): `mct_export` is not defined at notebook level (it is only eval()'s parameter name).
# This calls MetaCAT's own eval() method, not this notebook's eval() function.
test_Presence.config.train['prerequisites']={}
results_presence = test_Presence.eval(mct_export)

In [None]:
results_presence

In [None]:
test_subject = MetaCAT.load(mct_model+'/meta_'+meta_models[1])

In [None]:
test_subject

In [None]:
# Point the Subject model at the export's original task name 'Subject/Experiencer' before evaluating.
test_subject = MetaCAT.load(mct_model+'/meta_'+meta_models[1])
test_subject.config.general['category_name'] = 'Subject/Experiencer'
test_subject.config.train['prerequisites']={}
results_subject = test_subject.eval(mct_export)

In [None]:
test_time = MetaCAT.load(mct_model+'/meta_'+meta_models[2])

In [None]:
test_time.config.train['prerequisites']={}
results_time =test_time.eval(mct_export)

In [None]:
results_time

In [None]:
# Does the model-card name match the model's own category_name?
meta_models[2] == test_time.config.general['category_name']

In [None]:
test_time

In [None]:
test_time.config.general['category_name']

In [None]:
# Rebinds eval_model to medcat's implementation, replacing the notebook's eval_model defined earlier.
import json
from medcat.utils.meta_cat.ml_utils import eval_model

In [None]:
# NOTE(review): `ml_utils` is never imported as a module (only eval_model is imported by name), and
# `data` should be encoded samples rather than an export, so this call fails as written.
ml_utils.eval_model(model=test_time.model,
                    config=test_time.config,
                    data=mct_export,
                    tokenizer=test_time.tokenizer)

In [None]:
def prepare_from_json(data: Dict,
                      cntx_left: int,
                      cntx_right: int,
                      tokenizer: 'TokenizerWrapperBase',
                      cui_filter: Optional[set] = None,
                      replace_center: Optional[str] = None,
                      prerequisites: Optional[Dict] = None,
                      lowercase: bool = True) -> Dict:
    """ Convert a MedCATtrainer export into MetaCAT samples, grouped by meta-annotation task.

    Defining this in the notebook rebinds the name imported from medcat.utils.meta_cat.data_utils,
    so later calls in this notebook (including eval()) use this version. It is not very efficient
    (the one working with spacy documents as part of the meta_cat.pipe method is much better); if
    your dataset is > 1M documents think about rewriting it.

    Args:
        data (`dict`):
            Loaded output of MedCATtrainer. If we have a `my_export.json` from MedCATtrainer, then data = json.load(<my_export>).
        cntx_left (`int`):
            Number of tokens of context to keep to the left of the center token.
        cntx_right (`int`):
            Number of tokens of context to keep to the right of the center token.
        tokenizer (`medcat.tokenizers.meta_cat_tokenizers`):
            Callable that, given a string, returns a dict with 'input_ids' and 'offset_mapping'
            (the (start, end) character span of every token).
        cui_filter (`set`, optional):
            If non-empty, only annotations whose CUI is in this set are used. None or empty means no filter.
        replace_center (`str`, optional):
            If not None the concept's tokens are replaced with the tokens of this string
            (lowercased first when `lowercase` is True).
        prerequisites (`dict`, optional):
            A map of prerequisites, for example our data has two meta-annotations (experiencer, negation). Assume I want to create
            a dataset for `negation` but only in those cases where `experiencer=patient`, my prerequisites would be:
                {'Experiencer': 'Patient'} - Take care that the CASE has to match whatever is in the data.
            None (the default) means no prerequisites.
        lowercase (`bool`, defaults to True):
            Should the text be lowercased before tokenization

    Returns:
        out_data (`dict`):
            {'<meta task name>': [[<[token ids]>, <center position>, '<meta value>'], ...], ...}
            The center token is the token whose span contains the annotation start. Only annotations
            that are validated (or have no 'validated' key) and are not deleted, killed or irrelevant are used.
    """
    prerequisites = prerequisites or {}

    # Tokenize the replacement text once instead of once per annotation.
    replace_ids = None
    if replace_center is not None:
        if lowercase:
            replace_center = replace_center.lower()
        replace_ids = tokenizer(replace_center)['input_ids']

    out_data: Dict = {}
    for project in data['projects']:
        for document in project['documents']:
            text = str(document['text'])
            if lowercase:
                text = text.lower()
            if len(text) == 0:
                continue

            doc_text = tokenizer(text)
            input_ids = doc_text['input_ids']
            offsets = doc_text['offset_mapping']

            # A hack to support both 'annotations' (trainer export) and 'entities' (CAT output)
            for ann in document.get('annotations', document.get('entities', {}).values()):
                cui = ann['cui']
                if not _meets_prerequisites(ann, prerequisites):
                    continue
                if cui_filter and cui not in cui_filter:
                    continue
                if (not ann.get('validated', True) or ann.get('deleted', False)
                        or ann.get('killed', False) or ann.get('irrelevant', False)):
                    continue

                start = ann['start']
                end = ann['end']

                ind = _center_token_index(offsets, start)
                _start = max(0, ind - cntx_left)
                _end = min(len(input_ids), ind + 1 + cntx_right)
                tkns = input_ids[_start:_end]
                # Position of the center token inside tkns (< cntx_left near the document start).
                cpos = cntx_left + min(0, ind - cntx_left)

                if replace_ids is not None:
                    # First/last token covered by the concept. Default to the center token so a
                    # missing match never reuses indices from a previous annotation.
                    s_ind = e_ind = ind
                    for p_ind, pair in enumerate(offsets):
                        if pair[0] <= start < pair[1]:
                            s_ind = p_ind
                        if pair[0] < end <= pair[1]:
                            e_ind = p_ind
                    ln = e_ind - s_ind
                    tkns = tkns[:cpos] + replace_ids + tkns[cpos + ln + 1:]

                for meta_ann in _meta_anns_list(ann):
                    out_data.setdefault(meta_ann['name'], []).append([tkns, cpos, meta_ann['value']])

    return out_data


def _center_token_index(offsets: List, start: int) -> int:
    """Index of the first token whose (start, end) span contains `start`.

    Falls back to the last token when no span contains it, and to 0 when there are no tokens.
    """
    ind = 0
    for ind, pair in enumerate(offsets):
        if pair[0] <= start < pair[1]:
            break
    return ind


def _meets_prerequisites(ann: Dict, prerequisites: Dict) -> bool:
    """True if `ann` has every required meta-annotation value (always True without prerequisites or 'meta_anns')."""
    if 'meta_anns' not in ann or not prerequisites:
        return True
    meta_anns = ann['meta_anns']
    if not isinstance(meta_anns, dict):
        # Legacy list format: index by name so prerequisites work for both formats.
        meta_anns = {meta_ann['name']: meta_ann for meta_ann in meta_anns}
    return all(name in meta_anns and meta_anns[name]['value'] == value
               for name, value in prerequisites.items())


def _meta_anns_list(ann: Dict) -> List:
    """The annotation's meta-annotations as a list, accepting both the dict and the legacy list format."""
    meta_anns = ann.get('meta_anns', [])
    return list(meta_anns.values()) if isinstance(meta_anns, dict) else list(meta_anns)

In [None]:
g_config = test_time.config.general
# Use the whole train config: the next cell reads t_config['cui_filter'], which the
# (emptied) train['prerequisites'] dict does not contain.
t_config = test_time.config.train

In [None]:
# NOTE(review): `data_loaded` only exists as a local inside eval(); at notebook level it is undefined.
# Requires t_config to be the MetaCAT train config (it must contain a 'cui_filter' key).
data = prepare_from_json(data_loaded, g_config['cntx_left'], g_config['cntx_right'], test_time.tokenizer,
                         cui_filter=t_config['cui_filter'],
                         replace_center=g_config['replace_center'], prerequisites={},
                         lowercase=g_config['lowercase'])

In [None]:
# Keep only the samples of this model's task.
category_name = g_config['category_name']
if category_name not in data:
    raise Exception("The category name does not exist in this json file.")
data = data[category_name]

In [None]:
# From here on eval_model refers to medcat's implementation, replacing the notebook's eval_model
# defined earlier; this also affects later calls made through the notebook's eval().
from medcat.utils.meta_cat.ml_utils import predict, train_model, set_all_seeds, eval_model

In [None]:
# Encode values with the model's category_value2id (this also rebinds `_`), then evaluate.
category_value2id = g_config['category_value2id']
data, _ = encode_category_values(data, existing_category_value2id=category_value2id)

# Run evaluation
assert test_time.tokenizer is not None
result = eval_model(test_time.model, data, config=test_time.config, tokenizer=test_time.tokenizer)
