
feat(medcat):CU-869cy3xa0 Improve training #414

Merged
mart-r merged 15 commits into main from feat/medcat/CU-869cy3xa0-specify-unsupervised-training-in-trainable-component-protocol
May 13, 2026

Conversation

@mart-r
Collaborator

@mart-r mart-r commented Apr 21, 2026

This PR overhauls the training setup of MedCAT:

  • It modifies the existing TrainableComponent protocol to also include a train_unsupervised method
    • And uses that over the "check config for train and run inference" unsupervised training
  • It allows all components that follow the TrainableComponent protocol to be trained supervised
    • Previously only the linker was able to be trained in a supervised manner
  • It provides a few utility methods to allow training and evaluating components individually
    • I.e. dataset-aware components that make it possible to train or evaluate only the NER or only the linker if/when required
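As a rough illustration of the first two points, here is a minimal sketch of what a protocol with both a supervised and an unsupervised training method could look like, and how a trainer might dispatch on it. Only the name `train_unsupervised` comes from this PR; `TrainableComponentSketch`, the supervised method name `train`, the signatures, and the toy classes are all assumptions made for the example, not MedCAT's actual definitions.

```python
from typing import Any, Iterable, Protocol, runtime_checkable


@runtime_checkable
class TrainableComponentSketch(Protocol):
    """Hypothetical stand-in for the protocol; not MedCAT's real definition."""

    def train(self, annotated_doc: dict[str, Any]) -> None:
        """Supervised update from one annotated document (assumed signature)."""
        ...

    def train_unsupervised(self, texts: Iterable[str]) -> None:
        """Unsupervised update from raw texts (assumed signature)."""
        ...


class CountingLinker:
    """Toy component that only counts how often it was asked to train."""

    def __init__(self) -> None:
        self.sup_train_calls = 0
        self.unsup_train_calls = 0

    def train(self, annotated_doc: dict[str, Any]) -> None:
        self.sup_train_calls += 1

    def train_unsupervised(self, texts: Iterable[str]) -> None:
        # One "call" per text, purely for illustration.
        self.unsup_train_calls += sum(1 for _ in texts)


class InferenceOnlyNER:
    """Toy component with no training methods at all."""

    def __call__(self, text: str) -> list:
        return []


def train_all_unsupervised(components, texts):
    """Train every component that structurally satisfies the protocol."""
    trained = []
    for comp in components:
        # runtime_checkable protocols only check that the methods exist.
        if isinstance(comp, TrainableComponentSketch):
            comp.train_unsupervised(list(texts))
            trained.append(comp)
    return trained
```

With this shape, the trainer does not need to know which concrete component it is dealing with: anything exposing both methods is trained, and anything else is skipped.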

Example code snippets:

  1. When only training linker
with dataset_aware_component(cat, CoreComponentType.ner, DATASET):
    trainer.train_supervised_raw(DATASET, nepochs=1)
  2. When only training NER
with dataset_aware_component(cat, CoreComponentType.linking, self.DATASET):
    texts = [doc['text'] for proj in self.DATASET['projects'] for doc in proj['documents']]
    trainer.train_unsupervised(texts, nepochs=1)
  3. When doing evaluation / stats one component at a time
with dataset_aware_component(cat, CoreComponentType.ner, self.DATASET):
    fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples = get_stats(
        cat, self.DATASET, do_print=False)
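To make the "dataset-aware" idea more concrete, below is a minimal sketch of a stand-in NER step that simply returns the gold spans recorded in the dataset instead of predicting anything. It is not the implementation from this PR: `GoldLookupNER`, `build_gold_lookup`, and `EXAMPLE_DATASET` are hypothetical names, and the `annotations` / `start` / `end` / `cui` fields are assumed here to follow a MedCATtrainer-style export.

```python
from typing import Any

# Assumed dataset layout: projects -> documents -> annotations.
EXAMPLE_DATASET = {
    "projects": [
        {
            "documents": [
                {
                    "text": "Patient has diabetes.",
                    "annotations": [
                        {"start": 12, "end": 20, "cui": "C0011849"},
                    ],
                }
            ]
        }
    ]
}


def build_gold_lookup(dataset: dict[str, Any]) -> dict[str, list[dict[str, Any]]]:
    """Map each document text to the gold annotations stored for it."""
    lookup: dict[str, list[dict[str, Any]]] = {}
    for proj in dataset["projects"]:
        for doc in proj["documents"]:
            lookup[doc["text"]] = list(doc.get("annotations", []))
    return lookup


class GoldLookupNER:
    """Hypothetical non-trainable NER that 'cheats' by reading the dataset."""

    def __init__(self, dataset: dict[str, Any]) -> None:
        self._lookup = build_gold_lookup(dataset)

    def __call__(self, text: str) -> list[tuple[int, int]]:
        # Unknown texts (or an empty dataset) yield no entities at all.
        return [(a["start"], a["end"]) for a in self._lookup.get(text, [])]
```

Because a component like this has no training methods, running training while it is in place leaves only the remaining trainable components to be updated, which is what the first two snippets rely on.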

Comment thread medcat-v2/medcat/trainer.py
@adam-sutton-1992
Contributor

A few queries but I think it looks good. I might've missed these within the commits:

If you have two trainable components, is it possible to turn off training for one of them when running training methods? Do the dataset-aware components serve that purpose?

And one more above^^^

@mart-r
Collaborator Author

mart-r commented Apr 25, 2026

If you have two trainable components, is it possible to turn off training for one of them when running training methods? Do the dataset-aware components serve that purpose?

The description already had 2 examples for this :)

The dataset-aware implementation can serve that purpose, because it replaces the specific component with another one (which isn't trainable, but that's kind of irrelevant since it's a different component) for the duration of the context manager.

But I think what makes it unclear is that in the examples I've given it a dataset, whereas realistically you could provide an empty dataset for it, i.e. like this:

# supervised
with dataset_aware_component(cat, CoreComponentType.ner, {"projects" : []}):
    trainer.train_supervised_raw(DATASET, nepochs=1)
# unsupervised
with dataset_aware_component(cat, CoreComponentType.ner, {"projects" : []}):
    trainer.train_unsupervised(["list", "of", "texts"], nepochs=1)
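
The "replace for the duration of the context manager" behaviour can be sketched roughly as below. This is only an illustration of the swap-and-restore pattern under assumed names (`ToyPipeline`, `swapped_component`); it is not how `dataset_aware_component` is actually implemented.

```python
from contextlib import contextmanager


class ToyPipeline:
    """Hypothetical pipeline holding its components by name."""

    def __init__(self, components):
        self.components = dict(components)


@contextmanager
def swapped_component(pipeline, name, replacement):
    """Temporarily put `replacement` in place of the component `name`."""
    original = pipeline.components[name]
    pipeline.components[name] = replacement
    try:
        yield replacement
    finally:
        # Restore the original component even if training raised.
        pipeline.components[name] = original


pipe = ToyPipeline({"ner": "real-ner", "linking": "real-linker"})
with swapped_component(pipe, "ner", "stub-ner"):
    # Inside the block only the stub is visible, so the real NER
    # cannot be trained or evaluated by anything run here.
    inside = pipe.components["ner"]
after = pipe.components["ner"]
```

The `try`/`finally` is the important part of the design: the original component comes back when the block exits, whether or not the code inside it failed.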

@adam-sutton-1992 adam-sutton-1992 self-assigned this May 13, 2026
Contributor

@adam-sutton-1992 adam-sutton-1992 left a comment


Just the one comment:

Comment on lines +254 to +278
def test_train_supervised_can_train_only_linker_when_ner_is_cheating(self):
    ner = _TrainableNER()
    linker = _TrainablePassThroughLinker()
    cat = _FakeCat(self.DATASET, [ner, linker])
    trainer = Trainer(cat.cdb, cat.__call__, cat.pipe)

    with unittest.mock.patch("medcat.trainer.prepare_name", return_value={"abc": {}}):
        with dataset_aware_component(cat, CoreComponentType.ner, self.DATASET):
            trainer.train_supervised_raw(self.DATASET, disable_progress=True)

    self.assertEqual(ner.sup_train_calls, 0)
    self.assertEqual(linker.sup_train_calls, 1)

def test_train_supervised_can_train_only_ner_when_linker_is_cheating(self):
    ner = _TrainableNER()
    linker = _TrainablePassThroughLinker()
    cat = _FakeCat(self.DATASET, [ner, linker])
    trainer = Trainer(cat.cdb, cat.__call__, cat.pipe)

    with unittest.mock.patch("medcat.trainer.prepare_name", return_value={"abc": {}}):
        with dataset_aware_component(cat, CoreComponentType.linking, self.DATASET):
            trainer.train_supervised_raw(self.DATASET, disable_progress=True)

    self.assertEqual(ner.sup_train_calls, 1)
    self.assertEqual(linker.sup_train_calls, 0)
Contributor


Why not train both at the same time?

Collaborator Author


You can. But the point is that you don't have to! I.e. flexibility.

Contributor


Sorry, I misread this. I assume you intended for only one or the other.

@mart-r mart-r merged commit de36124 into main May 13, 2026
22 checks passed
@mart-r mart-r deleted the feat/medcat/CU-869cy3xa0-specify-unsupervised-training-in-trainable-component-protocol branch May 13, 2026 13:26

2 participants