The current implementation supports Hugging Face Llama and BERT models.
This section covers only BERT, since Llama usage is the same apart from the imports.

In [4]:
# Install medcat
! pip install medcat~=1.16.0

[0m[31mERROR: Could not find a version that satisfies the requirement medcat~=1.16.0 (from versions: 0.1.4, 0.1.5, 0.1.6, 0.1.7, 0.1.7.1, 0.1.8, 0.1.9, 0.1.9.2, 0.1.9.3, 0.1.9.4, 0.1.9.5, 0.1.9.6, 0.1.9.7, 0.1.9.9, 0.2.0.0, 0.2.0.1, 0.2.0.2, 0.2.0.3, 0.2.0.4, 0.2.0.5, 0.2.0.6, 0.2.0.7, 0.2.1, 0.2.2, 0.2.3, 0.2.3.1, 0.2.3.2, 0.2.3.3, 0.2.3.4, 0.2.3.5, 0.2.3.6, 0.2.3.7, 0.2.4.0, 0.2.4.1, 0.2.4.2, 0.2.4.3, 0.2.4.4, 0.2.4.5, 0.2.4.6, 0.2.4.7, 0.2.4.8, 0.2.4.9, 0.2.5.0, 0.2.5.1, 0.2.5.2, 0.2.5.3, 0.2.5.4, 0.2.5.5, 0.2.5.6, 0.2.5.7, 0.2.5.8, 0.2.5.9, 0.2.6.0, 0.2.6.1, 0.2.6.2, 0.2.6.3, 0.2.6.4, 0.2.6.5, 0.2.6.7, 0.2.6.8, 0.2.6.9, 0.2.7.0, 0.2.7.1, 0.2.7.2, 0.2.7.3, 0.2.7.4, 0.2.7.6, 0.2.7.7, 0.2.7.8, 0.2.7.9, 0.2.8.0, 0.2.8.1, 0.2.8.2, 0.2.8.3, 0.2.8.4, 0.2.8.5, 0.2.8.6, 0.2.8.7, 0.2.8.8, 0.2.8.9, 0.2.9.0, 0.2.9.1, 0.2.9.2, 0.2.9.3, 0.2.9.4, 0.2.9.5, 0.2.9.6, 0.2.9.7, 0.2.9.8, 0.2.9.9, 0.3.0.0, 0.3.0.1, 0.3.0.2, 0.3.0.3, 0.3.0.4, 0.3.0.5, 0.3.0.6, 0.3.0.7, 0.3.0.8, 0.3.0.9, 0.3.1.0, 0.3.1.

In [2]:
import logging
from medcat.cdb import CDB
from medcat.config_rel_cat import ConfigRelCAT
from medcat.rel_cat import RelCAT
from medcat.utils.relation_extraction.base_component import BaseComponent_RelationExtraction
from medcat.utils.relation_extraction.bert.model import BaseModel_RelationExtraction
from medcat.utils.relation_extraction.bert.config import BaseConfig_RelationExtraction
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper_RelationExtraction

<h1>Training RelCAT models with custom datasets from scratch.</h1>
<h2>1. Create the RelCAT config and set its parameters</h2>

In [3]:
config = ConfigRelCAT()
config.general.log_level = logging.INFO
config.general.model_name = "bert-base-uncased" # base model to use; here, the Hugging Face bert-base-uncased model

<h3> 1.1 Depending on the model you use, check config.model.hidden_size, config.model.model_size and config.model.hidden_layers</h3>

In [4]:
config.model.hidden_size = 256
config.model.model_size = 2304 # 4096 for Llama

<h3> 1.2 Other notable configurations</h3>

In [5]:
config.general.cntx_left = 15 # number of context tokens to take to the left of the first (start) entity
config.general.cntx_right = 15 # number of context tokens to take to the right of the second (end) entity
config.general.window_size = 300 # maximum distance (in characters) between two entities for them to be considered a relation
config.train.nclasses = 2 # number of relation classes in your MedCAT export / dataset
config.train.nepochs = 10 # number of epochs to train for
config.model.freeze_layers = False # whether to freeze the layers of the base model
config.general.limit_samples_per_class = 300 # cap on training samples per class, to avoid overfitting on unbalanced datasets
config.train.batch_size = 32 # batch size
config.train.lr = 3e-5 # learning rate
config.train.adam_epsilon = 1e-8 # epsilon for the Adam optimizer
config.train.adam_weight_decay = 0.0005 # weight decay for the Adam optimizer
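The three context settings (`cntx_left`, `cntx_right`, `window_size`) can be pictured with the pure-Python sketch below. It is a hypothetical illustration, not RelCAT's internal code: it assumes entity positions are given as (start, end) token indices and character offsets, keeps `cntx_left` tokens before the first entity and `cntx_right` tokens after the second, and treats a pair as a candidate only when the character gap between the two entities is at most `window_size`. The function `candidate_context` and the example sentence are made up.

```python
from typing import List, Optional, Tuple

# Hypothetical sketch of context windowing; not RelCAT's implementation.
def candidate_context(tokens: List[str],
                      ent1_tokens: Tuple[int, int],
                      ent2_tokens: Tuple[int, int],
                      ent1_chars: Tuple[int, int],
                      ent2_chars: Tuple[int, int],
                      cntx_left: int = 15,
                      cntx_right: int = 15,
                      window_size: int = 300) -> Optional[List[str]]:
    """Return the token window around an entity pair, or None if they are too far apart.

    Spans are (start, end) with end exclusive; ent1 is assumed to precede ent2.
    """
    # Character gap between the end of the first entity and the start of the second
    char_gap = ent2_chars[0] - ent1_chars[1]
    if char_gap > window_size:
        return None  # too far apart to be a relation candidate
    # Up to cntx_left tokens before the first entity ...
    start = max(0, ent1_tokens[0] - cntx_left)
    # ... and up to cntx_right tokens after the second entity
    end = min(len(tokens), ent2_tokens[1] + cntx_right)
    return tokens[start:end]


text = "patient developed severe rash after taking amoxicillin 500 mg daily"
tokens = text.split()
rash = (text.index("rash"), text.index("rash") + len("rash"))
drug = (text.index("amoxicillin"), text.index("amoxicillin") + len("amoxicillin"))
print(candidate_context(tokens, (3, 4), (6, 7), rash, drug, cntx_left=2, cntx_right=2))
# ['developed', 'severe', 'rash', 'after', 'taking', 'amoxicillin', '500', 'mg']
print(candidate_context(tokens, (3, 4), (6, 7), rash, drug, window_size=10))
# None (the 14-character gap exceeds the window)
```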

<h2>2. Create a CDB</h2>
<p>It can be the CDB of another model of your choice, or an empty one. The CDB is used only when filtering by concept unique identifiers (CUIs) or concept type IDs (TUIs).</p>

In [6]:
cdb = CDB()
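To illustrate what filtering by CUI or TUI means, here is a hypothetical pure-Python sketch. The entity dicts, the `cui2tuis` mapping (standing in for the CUI-to-type lookup a CDB would provide), the function `filter_entity_pairs` and the example CUIs/TUIs are all made up; this is not MedCAT's API.

```python
from itertools import combinations
from typing import Dict, Iterable, List, Optional, Set, Tuple

# Hypothetical sketch of CUI/TUI filtering; not MedCAT's API.
def filter_entity_pairs(entities: List[dict],
                        cui2tuis: Dict[str, Set[str]],
                        allowed_cuis: Optional[Iterable[str]] = None,
                        allowed_tuis: Optional[Iterable[str]] = None) -> List[Tuple[dict, dict]]:
    """Return all pairs of entities that pass both filters (an empty filter allows everything)."""
    cuis = set(allowed_cuis or [])
    tuis = set(allowed_tuis or [])

    def keep(ent: dict) -> bool:
        if cuis and ent["cui"] not in cuis:
            return False
        # Look up the entity's type IDs through the CUI -> TUI mapping
        if tuis and not (cui2tuis.get(ent["cui"], set()) & tuis):
            return False
        return True

    kept = [e for e in entities if keep(e)]
    return list(combinations(kept, 2))


ents = [{"cui": "C1", "text": "amoxicillin"},
        {"cui": "C2", "text": "rash"},
        {"cui": "C3", "text": "500 mg"}]
cui2tuis = {"C1": {"T_DRUG"}, "C2": {"T_AE"}, "C3": {"T_DOSE"}}
pairs = filter_entity_pairs(ents, cui2tuis, allowed_tuis=["T_DRUG", "T_AE"])
print([(a["text"], b["text"]) for a, b in pairs])  # [('amoxicillin', 'rash')]
```

Without any filter, every pair of entities is a candidate, which is why an empty CDB is fine when you do not filter.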

<h2>3. Create a tokenizer</h2>

In [7]:
tokenizer = BaseTokenizerWrapper_RelationExtraction.load(tokenizer_path=config.general.model_name,
                                                         relcat_config=config)

<h2>4. Add the entity tag tokens to the tokenizer</h2>
<p>This step is optional because the [s1], [e1], [s2], [e2] tags are already set in the default ConfigRelCAT. If you are using a Llama-based model, you also need to add a [PAD] token to the tokenizer, as shown below.</p>

In [8]:
special_ent_tokens = ["[s1]", "[e1]", "[s2]", "[e2]"]
tokenizer.hf_tokenizers.add_tokens(special_ent_tokens, special_tokens=True)
tokenizer.hf_tokenizers.add_special_tokens({'pad_token': '[PAD]'}) # needed for Llama tokenizers; bert-base-uncased already has [PAD], so this adds 0 tokens

0
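The `0` above is the number of tokens `add_special_tokens` added: bert-base-uncased already has a `[PAD]` token, so nothing new was added, while the four entity tags are new. The toy class below mimics, in a few lines, how a tokenizer appends unseen tokens to the end of its vocabulary and skips ones it already knows. It is a simplified illustration, not the Hugging Face implementation; `ToyVocab` is a made-up name.

```python
from typing import Dict, List

class ToyVocab:
    """Minimal stand-in for a tokenizer vocabulary (illustration only)."""

    def __init__(self, tokens: List[str]) -> None:
        self.token2id: Dict[str, int] = {t: i for i, t in enumerate(tokens)}

    def add_tokens(self, new_tokens: List[str]) -> int:
        """Append unseen tokens at the end of the vocabulary; return how many were added."""
        added = 0
        for tok in new_tokens:
            if tok not in self.token2id:
                self.token2id[tok] = len(self.token2id)  # next free ID
                added += 1
        return added

    def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
        return [self.token2id[t] for t in tokens]

    def __len__(self) -> int:
        return len(self.token2id)


vocab = ToyVocab(["[PAD]", "[UNK]", "[CLS]", "[SEP]", "rash"])
print(vocab.add_tokens(["[s1]", "[e1]", "[s2]", "[e2]"]))  # 4
print(vocab.add_tokens(["[PAD]"]))                          # 0, already present
print(vocab.convert_tokens_to_ids(["[s1]", "[e2]"]))        # [5, 8]
print(len(vocab))                                           # 9
```

Because the new tokens get IDs at or beyond the old vocabulary size, the model's embedding matrix has no rows for them until it is resized, which is why step 6.1 calls `resize_token_embeddings`.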

<h2>5. Add the tag tokens to the ConfigRelCAT</h2>

In [9]:
config.general.tokenizer_relation_annotation_special_tokens_tags = special_ent_tokens
config.general.annotation_schema_tag_ids = tokenizer.hf_tokenizers.convert_tokens_to_ids(special_ent_tokens)
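`annotation_schema_tag_ids` holds the token IDs of the four tags. As a hypothetical illustration of why those IDs are useful (not RelCAT's actual code), the sketch below locates the tags in an encoded sequence, which is the kind of lookup a model needs in order to pick out the representations of the two marked entities. The function `find_tag_positions` and all IDs in the example are made up.

```python
from typing import Dict, List, Sequence

# Hypothetical sketch; not RelCAT's implementation.
def find_tag_positions(input_ids: Sequence[int],
                       tag_tokens: List[str],
                       tag_ids: List[int]) -> Dict[str, int]:
    """Map each tag token (e.g. "[s1]") to the index of its first occurrence in input_ids."""
    ids = list(input_ids)
    positions: Dict[str, int] = {}
    for tok, tid in zip(tag_tokens, tag_ids):
        if tid not in ids:
            raise ValueError(f"tag {tok} (id {tid}) not found in the sequence")
        positions[tok] = ids.index(tid)
    return positions


tags = ["[s1]", "[e1]", "[s2]", "[e2]"]
# Made-up IDs: pretend the tags were appended after a 30522-token vocabulary
tag_ids = [30522, 30523, 30524, 30525]
# [CLS] [s1] rash [e1] after taking [s2] amoxicillin [e2] [SEP] (word IDs are made up)
input_ids = [101, 30522, 2001, 30523, 2002, 2003, 30524, 2004, 30525, 102]
print(find_tag_positions(input_ids, tags, tag_ids))
# {'[s1]': 1, '[e1]': 3, '[s2]': 6, '[e2]': 8}
```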

<h2>6. Create the relCAT object and initialize its components</h2>

In [10]:
# To skip the steps in section 6.1, pass the init_model=True argument to initialize the components with the default ConfigRelCAT settings.
relCAT = RelCAT(cdb, config=config)

INFO:medcat.utils.relation_extraction.base_component:BaseComponent_RelationExtraction initialized


<h3>6.1 Set up the BaseComponent object, which holds the tokenizer, the model and the model config. Each of these must be initialized beforehand.</h3>

<p>Because we added tokens to the tokenizer, the model's token embeddings must be resized; do this after adding the tokens to the tokenizer. It is not needed when loading a model that was created and saved this way, since the resized embeddings are retained.</p>

In [11]:
model_config = BaseConfig_RelationExtraction.load(pretrained_model_name_or_path=config.general.model_name,
                                                  relcat_config=config)

# update the model config with the proper vocab size, since we added special tokens to the tokenizer
model_config.hf_model_config.vocab_size = tokenizer.get_size()

# set the padding index in both the model config and the RelCAT config; this is necessary because it depends on the tokenizer you use
config.model.padding_idx = model_config.pad_token_id = tokenizer.get_pad_id()

model = BaseModel_RelationExtraction.load(pretrained_model_name_or_path=config.general.model_name,
                                          model_config=model_config,
                                          relcat_config=config)

# we have to update the model to reflect the new token embeddings, since we added special tokens to the tokenizer
model.hf_model.resize_token_embeddings(len(tokenizer.hf_tokenizers)) # type: ignore

component = BaseComponent_RelationExtraction(tokenizer=tokenizer, config=config)
component.model = model
component.model_config = model_config
component.relcat_config = config
component.tokenizer = tokenizer

relCAT.component = component

You are using a model of type bert to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO:medcat.utils.relation_extraction.config:Loaded config from : bert-base-uncased/model_config.json
INFO:medcat.utils.relation_extraction.models:RelCAT model config: PretrainedConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.51.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30526
}

INFO:medcat.utils.relation_extraction.bert.model:RelCAT model config: PretrainedConfig 
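To make the resize step concrete, here is a pure-Python sketch of what growing an embedding table means: the existing rows (the learned vectors for the original vocabulary) are kept, and one new row is appended per added token. In this notebook the table grows from the 30522 rows of bert-base-uncased to the `vocab_size` of 30526 logged above. The function `resize_embeddings` is hypothetical, and initializing new rows with a standard deviation of 0.02 (the `initializer_range` in the logged config) is our assumption; Hugging Face's `resize_token_embeddings` may initialize new rows differently.

```python
import random
from typing import List

# Hypothetical sketch; not the Hugging Face implementation.
def resize_embeddings(table: List[List[float]], new_vocab_size: int,
                      seed: int = 0) -> List[List[float]]:
    """Return a copy of `table` with rows added (or truncated) to match new_vocab_size."""
    if new_vocab_size <= len(table):
        return [row[:] for row in table[:new_vocab_size]]
    dim = len(table[0])
    rng = random.Random(seed)
    # Keep every existing row unchanged
    resized = [row[:] for row in table]
    # Append one freshly initialized row per new token
    for _ in range(new_vocab_size - len(table)):
        resized.append([rng.gauss(0.0, 0.02) for _ in range(dim)])
    return resized


old = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
new = resize_embeddings(old, 7)
print(len(new), new[:3] == old)  # 7 True
```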

<h2>7. Train the model on the ADE dataset</h2>

In [12]:
! rm -rf "./ade_relcat_model"
! mkdir -p "./ade_relcat_model"
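The `!` shell commands above assume a Unix-like shell. A portable equivalent using only the Python standard library could look like this; the helper name `reset_dir` is our own.

```python
import shutil
from pathlib import Path

def reset_dir(path: str) -> Path:
    """Remove `path` if it exists, then recreate it as an empty directory.

    Roughly equivalent to `rm -rf path && mkdir -p path`.
    """
    p = Path(path)
    if p.is_dir():
        shutil.rmtree(p)  # delete the directory and everything in it
    elif p.exists():
        p.unlink()        # a plain file is in the way: remove it
    p.mkdir(parents=True, exist_ok=True)
    return p

# In this notebook: reset_dir("./ade_relcat_model")
```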

In [13]:
relCAT.train(train_csv_path="./data/rel_cat_ADE_V2.tsv", checkpoint_path="./ade_relcat_model")

# for MedCAT Trainer exports, use the export_data_path argument instead: relCAT.train(export_data_path="./data/MedCAT_Export_relation_extraction.json")


INFO:medcat.utils.relation_extraction.rel_dataset:CSV dataset | No. of relations detected:7093| from : ./data/rel_cat_ADE_V2.tsv | nclasses: 2 | idx2label: {0: 'DRUG-DOSE', 1: 'DRUG-AE'}
INFO:medcat.utils.relation_extraction.rel_dataset:Samples per class: 
INFO:medcat.utils.relation_extraction.rel_dataset: label: DRUG-DOSE | samples: 279
INFO:medcat.utils.relation_extraction.rel_dataset: label: DRUG-AE | samples: 6814
INFO:root:Relations after train, test split :  train - 524 | test - 115
INFO:root: label: DRUG-AE samples | train 300 | test 60
INFO:root: label: DRUG-DOSE samples | train 224 | test 55
INFO:root:Attempting to load RelCAT model on device: cpu
INFO:medcat.rel_cat:Starting training process...
INFO:medcat.rel_cat:Total epochs on this model: 10 | currently training epoch 0
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- E
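The log reports 6814 DRUG-AE and 279 DRUG-DOSE relations, and `limit_samples_per_class = 300` keeps the dominant class from swamping the minority one. The sketch below shows one simple way to cap samples per class. It is an illustration under our own assumptions (random sampling with a fixed seed, applied before any train/test split), not RelCAT's actual sampling logic, so its counts need not match the train/test numbers in the log. `cap_per_class` is a made-up name.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical sketch; not RelCAT's implementation.
def cap_per_class(samples: List[Tuple[str, str]], limit: int,
                  seed: int = 42) -> List[Tuple[str, str]]:
    """Keep at most `limit` (text, label) samples per label, chosen at random."""
    by_label: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    for sample in samples:
        by_label[sample[1]].append(sample)
    rng = random.Random(seed)
    capped: List[Tuple[str, str]] = []
    for label in sorted(by_label):
        # Shuffle a copy so the selection is random but reproducible
        group = by_label[label][:]
        rng.shuffle(group)
        capped.extend(group[:limit])
    return capped


# Toy data with the same class counts as the ADE log above
data = [(f"ae_{i}", "DRUG-AE") for i in range(6814)] + \
       [(f"dose_{i}", "DRUG-DOSE") for i in range(279)]
capped = cap_per_class(data, limit=300)
counts = {label: sum(1 for _, lab in capped if lab == label)
          for label in ("DRUG-AE", "DRUG-DOSE")}
print(counts)  # {'DRUG-AE': 300, 'DRUG-DOSE': 279}
```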

In [14]:
# save the model
relCAT.save(save_path="./ade_relcat_model")