Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions medcat-v2/medcat/config/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import (Optional, Iterator, Iterable, TypeVar, cast, Type, Any,
Literal)
from typing import Protocol, runtime_checkable
Expand All @@ -12,6 +13,9 @@
from medcat.utils.defaults import workers
from medcat.utils.envsnapshot import Environment, get_environment_info
from medcat.utils.iterutils import callback_iterator
from medcat.utils.defaults import (
avoid_legacy_conversion, doing_legacy_conversion_message,
LegacyConversionDisabledError)
from medcat.storage.serialisables import SerialisingStrategy
from medcat.storage.serialisers import deserialise

Expand Down Expand Up @@ -80,6 +84,13 @@ def merge_config(self, other: dict):

@classmethod
def load(cls, path: str) -> Self:
if os.path.isfile(path) and path.endswith(".dat"):
if avoid_legacy_conversion():
raise LegacyConversionDisabledError(cls.__name__)
doing_legacy_conversion_message(logger, cls.__name__, path)
from medcat.utils.legacy.convert_config import (
get_config_from_old_per_cls)
return cast(Self, get_config_from_old_per_cls(path, cls))
obj = deserialise(path)
if not isinstance(obj, cls):
raise ValueError(f"The path '{path}' is not a {cls.__name__}!")
Expand Down
34 changes: 33 additions & 1 deletion medcat-v2/medcat/utils/legacy/convert_config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json
from typing import Any, cast, Optional
from typing import Any, cast, Optional, Type
import logging

from pydantic import BaseModel

from medcat.config import Config

from medcat.utils.legacy.helpers import fix_old_style_cnf
from medcat.config.config import SerialisableBaseModel


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -185,3 +186,34 @@ def get_config_from_old(path: str) -> Config:
with open(path) as f:
old_cnf_data = json.load(f)
return get_config_from_nested_dict(old_cnf_data)


def get_config_from_old_per_cls(
path: str, cls: Type[SerialisableBaseModel]) -> SerialisableBaseModel:
"""Convert the saved v1 config into a v2 Config for a specific class.

Args:
path (str): The v1 config path.
cls (Type[SerialisableBaseModel]): The class to convert to.

Returns:
SerialisableBaseModel: The converted config.
"""
from medcat.config.config_meta_cat import ConfigMetaCAT
from medcat.config.config_transformers_ner import ConfigTransformersNER
from medcat.config.config_rel_cat import ConfigRelCAT
if cls is Config:
return get_config_from_old(path)
elif cls is ConfigMetaCAT:
from medcat.utils.legacy.convert_meta_cat import (
load_cnf as load_meta_cat_cnf)
return load_meta_cat_cnf(path)
elif cls is ConfigTransformersNER:
from medcat.utils.legacy.convert_deid import (
get_cnf as load_deid_cnf)
return load_deid_cnf(path)
elif cls is ConfigRelCAT:
from medcat.utils.legacy.convert_rel_cat import (
load_cnf as load_rel_cat_cnf)
return load_rel_cat_cnf(path)
raise ValueError(f"The config at '{path}' is not a {cls.__name__}!")
1 change: 1 addition & 0 deletions medcat-v2/tests/resources/mct_v1_deid_cnf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"general": {"name": "NOT-DEID", "model_name": "roberta-base", "seed": 13, "description": "No description", "pipe_batch_size_in_chars": 20000000, "ner_aggregation_strategy": "simple", "chunking_overlap_window": 5, "test_size": 0.2, "last_train_on": null, "verbose_metrics": false}}
1 change: 1 addition & 0 deletions medcat-v2/tests/resources/mct_v1_meta_cat_cnf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"general": {"device": "cpu", "disable_component_lock": false, "seed": 13, "description": "No description", "category_name": "TEST CATEGORY", "alternative_category_names": [], "category_value2id": {}, "alternative_class_names": [[]], "vocab_size": 3, "lowercase": true, "cntx_left": 15, "cntx_right": 10, "replace_center": null, "batch_size_eval": 5000, "annotate_overlapping": false, "tokenizer_name": "bbpe", "save_and_reuse_tokens": false, "pipe_batch_size_in_chars": 20000000, "span_group": null}, "model": {"model_name": "lstm", "model_variant": "bert-base-uncased", "model_freeze_layers": true, "num_layers": 2, "input_size": 300, "hidden_size": 300, "dropout": 0.5, "phase_number": 0, "category_undersample": "", "model_architecture_config": {"fc2": true, "fc3": false, "lr_scheduler": true}, "num_directions": 2, "nclasses": 2, "padding_idx": -1, "emb_grad": true, "ignore_cpos": false}, "train": {"batch_size": 100, "nepochs": 50, "lr": 0.001, "test_size": 0.1, "shuffle_data": true, "class_weights": null, "compute_class_weights": false, "score_average": "weighted", "prerequisites": {}, "cui_filter": null, "auto_save_model": true, "last_train_on": null, "metric": {"base": "weighted avg", "score": "f1-score"}, "loss_funct": "cross_entropy", "gamma": 2}}
1 change: 1 addition & 0 deletions medcat-v2/tests/resources/mct_v1_rel_cat_cnf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"general": {"device": "cpu", "relation_type_filter_pairs": [], "vocab_size": null, "lowercase": true, "cntx_left": 15, "cntx_right": 15, "window_size": 300, "limit_samples_per_class": -1, "addl_rels_max_sample_size": 200, "create_addl_rels": false, "create_addl_rels_by_type": false, "tokenizer_name": "bert", "model_name": "bert-unknown", "log_level": 20, "max_seq_length": 512, "tokenizer_special_tokens": false, "annotation_schema_tag_ids": [30522, 30523, 30524, 30525], "tokenizer_relation_annotation_special_tokens_tags": ["[s1]", "[e1]", "[s2]", "[e2]"], "tokenizer_other_special_tokens": {"pad_token": "[PAD]"}, "labels2idx": {}, "idx2labels": {}, "pin_memory": true, "seed": 13, "task": "train", "language": "en"}, "model": {"input_size": 300, "hidden_size": 768, "hidden_layers": 3, "model_size": 5120, "dropout": 0.2, "num_directions": 2, "freeze_layers": true, "padding_idx": -1, "emb_grad": true, "ignore_cpos": false, "llama_use_pooled_output": false}, "train": {"nclasses": 2, "batch_size": 25, "nepochs": 1, "lr": 0.0001, "stratified_batching": false, "batching_samples_per_class": [], "batching_minority_limit": 0, "adam_betas": [0.9, 0.999], "adam_weight_decay": 0, "adam_epsilon": 1e-08, "test_size": 0.2, "gradient_acc_steps": 1, "multistep_milestones": [2, 4, 6, 8, 12, 15, 18, 20, 22, 24, 26, 30], "multistep_lr_gamma": 0.8, "max_grad_norm": 1.0, "shuffle_data": true, "class_weights": null, "enable_class_weights": false, "score_average": "weighted", "auto_save_model": true}}
53 changes: 51 additions & 2 deletions medcat-v2/tests/utils/legacy/test_convert_config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from typing import Type, Any
import os

from medcat.utils.legacy import convert_config

from medcat.config import Config
from medcat.config.config import SerialisableBaseModel
from medcat.config.config_meta_cat import ConfigMetaCAT
from medcat.config.config_rel_cat import ConfigRelCAT
from medcat.config.config_transformers_ner import ConfigTransformersNER

import unittest


TESTS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__),
"..", ".."))
from ... import RESOURCES_PATH
TESTS_PATH = os.path.dirname(RESOURCES_PATH)


class ValAndModelGetterTests(unittest.TestCase):
Expand Down Expand Up @@ -78,3 +83,47 @@ def test_migrates_partial(self):
def test_preprocesses_sets(self):
self.assertEqual(self.cnf.preprocessing.words_to_skip,
self.EXP_WORDS_TO_SKIP)


class PerClsConfigConversionTests(unittest.TestCase):
# paths, classes, expected path, expected value
# NOTE: These are hard-coded values I know I changed in the confgis
# before saving
PATHS_AND_CLASSES: list[str, Type[SerialisableBaseModel], str, Any] = [
(os.path.join(RESOURCES_PATH,
"mct_v1_cnf.json"), Config,
'meta.description', "FAKE MODEL"),
(os.path.join(RESOURCES_PATH,
"mct_v1_meta_cat_cnf.json"), ConfigMetaCAT,
"general.category_name", 'TEST CATEGORY'),
(os.path.join(RESOURCES_PATH,
"mct_v1_rel_cat_cnf.json"), ConfigRelCAT,
"general.model_name", 'bert-unknown'),
(os.path.join(RESOURCES_PATH,
"mct_v1_deid_cnf.json"), ConfigTransformersNER,
"general.name", 'NOT-DEID'),
]

@classmethod
def setUpClass(cls):
return super().setUpClass()

def _get_attr_nested(self, obj: SerialisableBaseModel, path: str) -> Any:
"""Get an attribute from a nested object using a dot-separated path."""
parts = path.split('.')
for part in parts:
obj = getattr(obj, part)
return obj

def assert_can_convert(
self, path, cls: Type[SerialisableBaseModel],
exp_path: str, exp_value: Any):
cnf = convert_config.get_config_from_old_per_cls(path, cls)
self.assertIsInstance(cnf, cls, f"Failed for {cls.__name__}")
self.assertEqual(self._get_attr_nested(cnf, exp_path), exp_value,
f"Failed for {cls.__name__} at {exp_path}")

def test_can_convert(self):
for path, cls, exp_path, exp_value in self.PATHS_AND_CLASSES:
with self.subTest(f"Testing {cls.__name__} at {path}"):
self.assert_can_convert(path, cls, exp_path, exp_value)