diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py index 1d67f0faa..cd72e7652 100644 --- a/medcat-v2/medcat/components/linking/embedding_linker.py +++ b/medcat-v2/medcat/components/linking/embedding_linker.py @@ -5,16 +5,25 @@ from medcat.tokenizing.tokenizers import BaseTokenizer from typing import Optional, Iterator, Set from medcat.vocab import Vocab -from torch import Tensor -from transformers import AutoTokenizer, AutoModel from medcat.utils.postprocessing import create_main_ann from tqdm import tqdm from collections import defaultdict -import torch.nn.functional as F -import torch import logging import math +from medcat.utils.import_utils import ensure_optional_extras_installed +import medcat + +# NOTE: the below needs to be before torch/transformers imports +_EXTRA_NAME = "embed-linker" +ensure_optional_extras_installed(medcat.__name__, _EXTRA_NAME) + +# avoid linting issues due to above check +from torch import Tensor # noqa: E402 +from transformers import AutoTokenizer, AutoModel # noqa: E402 +import torch.nn.functional as F # noqa: E402 +import torch # noqa: E402 + logger = logging.getLogger(__name__) diff --git a/medcat-v2/pyproject.toml b/medcat-v2/pyproject.toml index a675ebcbe..b7bc6095d 100644 --- a/medcat-v2/pyproject.toml +++ b/medcat-v2/pyproject.toml @@ -113,6 +113,10 @@ rel_cat = [ "scikit-learn>=1.1.3,<2.0", "torch>=2.4.0,<3.0", ] +embed_linker = [ + "transformers>=4.41.0,<5.0", # avoid major bump + "torch>=2.4.0,<3.0", +] test = [] # TODO - list [project.urls]