Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "ec4a8509",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -92,7 +92,7 @@
],
"source": [
"# Install medcat\n",
"! pip install \"medcat[spacy,meta-cat] @ git+https://github.com/CogStack/cogstack-nlp@medcat/v0.11.2#subdirectory=medcat-v2\" # NOTE: VERSION-STRING"
"! pip install \"medcat[spacy,rel-cat] @ git+https://github.com/CogStack/cogstack-nlp@medcat/v0.11.2#subdirectory=medcat-v2\" # NOTE: VERSION-STRING"
]
},
{
Expand Down
23 changes: 8 additions & 15 deletions medcat-v2/medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from medcat.components.types import AbstractCoreComponent, HashableComponet
from medcat.components.addons.addons import AddonComponent
from medcat.utils.legacy.identifier import is_legacy_model_pack
from medcat.utils.defaults import AVOID_LEGACY_CONVERSION_ENVIRON
from medcat.utils.defaults import avoid_legacy_conversion
from medcat.utils.defaults import doing_legacy_conversion_message
from medcat.utils.defaults import LegacyConversionDisabledError
from medcat.utils.usage_monitoring import UsageMonitor


Expand Down Expand Up @@ -602,22 +604,13 @@ def load_model_pack(cls, model_pack_path: str) -> 'CAT':
logger.info("Attempting to load model from file: %s",
model_pack_path)
is_legacy = is_legacy_model_pack(model_pack_path)
should_avoid = os.environ.get(
AVOID_LEGACY_CONVERSION_ENVIRON, "False").lower() == "true"
if is_legacy and not should_avoid:
avoid_legacy = avoid_legacy_conversion()
if is_legacy and not avoid_legacy:
from medcat.utils.legacy.conversion_all import Converter
logger.warning(
"Doing legacy conversion on model pack '%s'. "
"This will make the model load take significantly longer. "
"If you wish to avoid this, set the environment variable '%s' "
"to 'true'", model_pack_path, AVOID_LEGACY_CONVERSION_ENVIRON)
doing_legacy_conversion_message(logger, 'CAT', model_pack_path)
return Converter(model_pack_path, None).convert()
elif is_legacy and should_avoid:
raise ValueError(
f"The model pack '{model_pack_path}' is a legacy model pack. "
"Please set the environment variable "
f"'{AVOID_LEGACY_CONVERSION_ENVIRON}' "
"to 'true' to allow automatic conversion.")
elif is_legacy and avoid_legacy:
raise LegacyConversionDisabledError("CAT")
# NOTE: ignoring addons since they will be loaded later / separately
cat = deserialise(model_pack_path, model_load_path=model_pack_path,
ignore_folders_prefix={
Expand Down
11 changes: 11 additions & 0 deletions medcat-v2/medcat/cdb/cdb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Iterable, Any, Collection, Union, Literal
import os

from medcat.storage.serialisables import AbstractSerialisable
from medcat.cdb.concepts import CUIInfo, NameInfo, TypeInfo
Expand All @@ -9,6 +10,9 @@
from medcat.storage.zip_utils import (
should_serialise_as_zip, serialise_as_zip, deserialise_from_zip)
from medcat.utils.defaults import default_weighted_average, StatusTypes as ST
from medcat.utils.defaults import avoid_legacy_conversion
from medcat.utils.defaults import doing_legacy_conversion_message
from medcat.utils.defaults import LegacyConversionDisabledError
from medcat.utils.hasher import Hasher
from medcat.preprocessors.cleaners import NameDescriptor
from medcat.config import Config
Expand Down Expand Up @@ -510,6 +514,13 @@ def save(self, save_path: str,
def load(cls, path: str) -> 'CDB':
if should_serialise_as_zip(path, 'auto'):
cdb = deserialise_from_zip(path)
elif os.path.isfile(path) and path.endswith('.dat'):
if not avoid_legacy_conversion():
from medcat.utils.legacy.convert_cdb import get_cdb_from_old
doing_legacy_conversion_message(logger, 'CDB', path)
cdb = get_cdb_from_old(path)
else:
raise LegacyConversionDisabledError("CDB")
else:
cdb = deserialise(path)
if not isinstance(cdb, CDB):
Expand Down
29 changes: 29 additions & 0 deletions medcat-v2/medcat/utils/defaults.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from typing import Optional
from multiprocessing import cpu_count
from functools import lru_cache
import logging


DEFAULT_SPACY_MODEL = 'en_core_web_md'
Expand All @@ -9,6 +11,33 @@
AVOID_LEGACY_CONVERSION_ENVIRON = "MEDCAT_AVOID_LECACY_CONVERSION"


def avoid_legacy_conversion() -> bool:
return os.environ.get(
AVOID_LEGACY_CONVERSION_ENVIRON, "False").lower() == "true"


class LegacyConversionDisabledError(Exception):
"""Raised when legacy conversion is disabled."""

def __init__(self, component_name: str):
super().__init__(
f"Legacy conversion is disabled (while loading {component_name}). "
f"Set the environment variable {AVOID_LEGACY_CONVERSION_ENVIRON} "
"to `False` to allow conversion.")


def doing_legacy_conversion_message(
logger: logging.Logger, component_name: str, file_path: str = '',
level: int = logging.WARNING
) -> None:
logger.log(
level,
"Doing legacy conversion on %s (at '%s'). "
"Set the environment variable %s "
"to `True` to avoid this.",
component_name, file_path, AVOID_LEGACY_CONVERSION_ENVIRON)


@lru_cache(maxsize=100)
def default_weighted_average(step: int, factor: float = 0.0004) -> float:
return max(0.1, 1 - (step ** 2 * factor))
Expand Down
16 changes: 16 additions & 0 deletions medcat-v2/medcat/vocab.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional, Any, cast, Union, Literal
from typing_extensions import TypedDict
import os
import logging

# import dill
import numpy as np
Expand All @@ -9,6 +11,12 @@
deserialise, AvailableSerialisers, serialise)
from medcat.storage.zip_utils import (
should_serialise_as_zip, serialise_as_zip, deserialise_from_zip)
from medcat.utils.defaults import avoid_legacy_conversion
from medcat.utils.defaults import doing_legacy_conversion_message
from medcat.utils.defaults import LegacyConversionDisabledError


logger = logging.getLogger(__name__)


WordDescriptor = TypedDict('WordDescriptor',
Expand Down Expand Up @@ -323,6 +331,14 @@ def save(self, save_path: str,
def load(cls, path: str) -> 'Vocab':
if should_serialise_as_zip(path, 'auto'):
vocab = deserialise_from_zip(path)
elif os.path.isfile(path) and path.endswith('.dat'):
if not avoid_legacy_conversion():
from medcat.utils.legacy.convert_vocab import (
get_vocab_from_old)
doing_legacy_conversion_message(logger, 'Vocab', path)
vocab = get_vocab_from_old(path)
else:
raise LegacyConversionDisabledError("Vocab")
else:
vocab = deserialise(path)
if not isinstance(vocab, Vocab):
Expand Down
6 changes: 6 additions & 0 deletions medcat-v2/tests/cdb/test_cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
import tempfile

from .. import UNPACKED_EXAMPLE_MODEL_PACK_PATH, RESOURCES_PATH
from .. import UNPACKED_V1_MODEL_PACK_PATH


ZIPPED_CDB_PATH = os.path.join(RESOURCES_PATH, "mct2_cdb.zip")


class CDBTests(TestCase):
CDB_PATH = os.path.join(UNPACKED_EXAMPLE_MODEL_PACK_PATH, "cdb")
LEGACY_CDB_PATH = os.path.join(UNPACKED_V1_MODEL_PACK_PATH, "cdb.dat")
CUI_TO_REMOVE = "C03"
NAMES_TO_REMOVE = ['high~temperature']
TO_FILTER = ['C01', 'C02']
Expand All @@ -40,6 +42,10 @@ def test_can_load_from_zip(self):
# make sure it's actually a file not a folder
self.assertTrue(os.path.isfile(ZIPPED_CDB_PATH))

def test_can_convert_legacy_upon_load(self):
loaded = cdb.CDB.load(self.LEGACY_CDB_PATH)
self.assertIsInstance(loaded, cdb.CDB)

def test_can_save_to_zip(self):
with tempfile.TemporaryDirectory() as temp_dir:
file_name = os.path.join(temp_dir, "cdb.zip")
Expand Down
3 changes: 2 additions & 1 deletion medcat-v2/tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from medcat.utils.cdb_state import captured_state_cdb
from medcat.components.addons.meta_cat import MetaCATAddon
from medcat.utils.defaults import AVOID_LEGACY_CONVERSION_ENVIRON
from medcat.utils.defaults import LegacyConversionDisabledError

import unittest
import tempfile
Expand Down Expand Up @@ -648,7 +649,7 @@ def test_can_load_legacy_model_unpacked(self):
def test_cannot_load_legacy_with_environ_set(self):
with unittest.mock.patch.dict(os.environ, {
AVOID_LEGACY_CONVERSION_ENVIRON: "true"}, clear=True):
with self.assertRaises(ValueError):
with self.assertRaises(LegacyConversionDisabledError):
cat.CAT.load_model_pack(V1_MODEL_PACK_PATH)


Expand Down
6 changes: 6 additions & 0 deletions medcat-v2/tests/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import tempfile

from . import UNPACKED_EXAMPLE_MODEL_PACK_PATH, RESOURCES_PATH
from . import UNPACKED_V1_MODEL_PACK_PATH


ZIPPED_VOCAB_PATH = os.path.join(RESOURCES_PATH, "mct2_vocab.zip")
Expand Down Expand Up @@ -169,6 +170,7 @@ def test_neg_sampling_does_not_include_vectorless(

class DefaultVocabTests(unittest.TestCase):
VOCAB_PATH = os.path.join(UNPACKED_EXAMPLE_MODEL_PACK_PATH, 'vocab')
LEGACY_VOCAB_PATH = os.path.join(UNPACKED_V1_MODEL_PACK_PATH, "vocab.dat")
EXP_SHAPE = (7,)

@classmethod
Expand Down Expand Up @@ -199,6 +201,10 @@ def test_can_load_from_zip(self):
vocab = Vocab.load(ZIPPED_VOCAB_PATH)
self.assertIsInstance(vocab, Vocab)

def test_can_convert_legacy_upon_load(self):
loaded = Vocab.load(self.LEGACY_VOCAB_PATH)
self.assertIsInstance(loaded, Vocab)

def test_can_save_to_zip(self):
with tempfile.TemporaryDirectory() as temp_dir:
file_name = os.path.join(temp_dir, 'vocab.zip')
Expand Down