@@ -74,3 +74,9 @@ class EmbeddingLinking(Linking):
     use_ner_link_candidates: bool = True
     """Link candidates are provided by some NER steps. This will flag if
     you want to trust them or not."""
+    learning_rate: float = 1e-4
+    """Learning rate for training the embedding linker. Only used if
+    the embedding linker is trainable."""
+    weight_decay: float = 0.01
+    """Weight decay for training the embedding linker. Only used if
+    the embedding linker is trainable."""
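The two new config fields default to the values that were previously hard-coded in the trainable linker (1e-4 and 0.01). A minimal sketch of how they would feed an optimizer, assuming `cnf_l` is the `EmbeddingLinking` instance a component holds; `build_optimizer` is an illustrative helper, not part of this PR:

import torch


def build_optimizer(model: torch.nn.Module, cnf_l) -> torch.optim.AdamW:
    # lr and weight_decay come straight from the config fields added above;
    # the defaults match the previously hard-coded 1e-4 and 0.01.
    return torch.optim.AdamW(
        model.parameters(),
        lr=cnf_l.learning_rate,
        weight_decay=cnf_l.weight_decay,
    )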
@@ -19,7 +19,7 @@


 class Linker(AbstractEntityProvidingComponent):
-    comp_name = "embedding_linker"
+    name = "embedding_linker"
     _MODEL_FOLDER_NAME = "embedding_model"
     _STATE_FILE_NAME = "state.json"

@@ -124,13 +124,21 @@ def create_embeddings(
         )
         self.max_length = max_length
         self.cnf_l.max_token_length = max_length
+        self.context_model.max_length = max_length

-        self.context_model.embed_cuis(embedding_model_name)
-        self.context_model.embed_names(embedding_model_name)
+        # Route model swaps through linker-level hook so trainable variants can
+        # refresh optimizer/scaler when underlying params change.
+        self.load_transformers(embedding_model_name)
+        self.context_model.embed_cuis()
+        self.context_model.embed_names()

         self._names_context_matrix = None
         self._cui_context_matrix = None

+    def load_transformers(self, embedding_model_name: str) -> None:
+        """Pass through to the underlying transformer model for context embedding."""
+        self.context_model.load_transformers(embedding_model_name)
+
     def get_type(self) -> CoreComponentType:
         return CoreComponentType.linking

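Since embed_cuis/embed_names no longer accept a model name, swapping the backbone is routed through the new linker-level load_transformers hook, which trainable subclasses can override (see the next file). A rough usage sketch after construction; the model name is a placeholder, not taken from this PR:

# Assumes `linker` is an already constructed Linker instance.
linker.load_transformers("sentence-transformers/all-MiniLM-L6-v2")  # placeholder model name
linker.context_model.embed_cuis()   # re-embed with the currently loaded model
linker.context_model.embed_names()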
@@ -7,7 +7,7 @@
 from medcat.tokenizing.tokenizers import BaseTokenizer
 from medcat.tokenizing.tokens import MutableDocument, MutableEntity
 from medcat.vocab import Vocab
-from medcat_embedding_linker.embedding_linker import Linker
+from medcat_embedding_linker.embedding_linker import Linker as StaticEmbeddingLinker
 from medcat.storage.serialisables import AbstractManualSerialisable
 import logging
 import torch
@@ -17,13 +17,14 @@
 logger = logging.getLogger(__name__)


-class TrainableEmbeddingLinker(Linker, AbstractManualSerialisable):
+class Linker(StaticEmbeddingLinker, AbstractManualSerialisable):
     """Trainable variant of the embedding linker.
     This class inherits inference and embedding behavior from Linker and provides
     method hooks for online/offline training.
     """

-    comp_name = "trainable_embedding_linker"
+    name = "trainable_embedding_linker"
+
     _MODEL_FOLDER_NAME = "trainable_embedding_model"
     _MODEL_STATE_FILE_NAME = "model_state.pt"

@@ -47,11 +48,39 @@ def __init__(self, cdb: CDB, config: Config) -> None:
         self.negative_sampling_candidate_pool_size = (
             self.cnf_l.negative_sampling_candidate_pool_size
         )
-        self.scaler = torch.amp.GradScaler()  # for FP16 training stability
+        self.reset_optimizer_and_scaler()
+
+    def reset_optimizer_and_scaler(
+        self,
+        learning_rate: Optional[float] = None,
+        weight_decay: Optional[float] = None,
+    ) -> None:
+        """Recreate training state bound to the current context model params.
+
+        Optionally update the learning rate and weight decay in the config.
+        If not provided, the current config values are used.
+
+        Args:
+            learning_rate: New learning rate. Updates config if provided.
+            weight_decay: New weight decay. Updates config if provided.
+        """
+        if learning_rate is not None:
+            self.cnf_l.learning_rate = learning_rate
+        if weight_decay is not None:
+            self.cnf_l.weight_decay = weight_decay
+        # Keep scaler and optimizer aligned with the currently loaded model.
+        self.scaler = torch.amp.GradScaler()
         self.optimizer = torch.optim.AdamW(
-            self.context_model.model.parameters(), lr=1e-4, weight_decay=0.01
+            self.context_model.model.parameters(),
+            lr=self.cnf_l.learning_rate,
+            weight_decay=self.cnf_l.weight_decay,
         )

+    def load_transformers(self, embedding_model_name: str) -> None:
+        """Switch embedding model and refresh optimizer/scaler to new params."""
+        self.context_model.load_transformers(embedding_model_name)
+        self.reset_optimizer_and_scaler()
+
     def _generate_negative_samples(
         self,
         candidate_indices: Tensor,
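Per the docstring above, hyperparameters passed to reset_optimizer_and_scaler are written back to the config before the optimizer and scaler are rebuilt, so the new values persist with the component. A small sketch with arbitrary values:

# Lower the learning rate and disable weight decay for a later fine-tuning pass.
linker.reset_optimizer_and_scaler(learning_rate=5e-5, weight_decay=0.0)
assert linker.cnf_l.learning_rate == 5e-5
assert linker.cnf_l.weight_decay == 0.0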
@@ -366,7 +395,7 @@ def create_new_component(
         cdb: CDB,
         vocab: Vocab,
         model_load_path: Optional[str],
-    ) -> "TrainableEmbeddingLinker":
+    ) -> "Linker":
         return cls(cdb, cdb.config)

     def serialise_to(self, folder_path: str) -> None:
@@ -382,7 +411,7 @@ def serialise_to(self, folder_path: str) -> None:
     @classmethod
     def deserialise_from(
         cls, folder_path: str, **init_kwargs
-    ) -> "TrainableEmbeddingLinker":
+    ) -> "Linker":
         cdb = init_kwargs["cdb"]
         linker = cls(cdb, cdb.config)

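A hedged round-trip sketch based only on the signatures visible in this diff (serialise_to takes a folder path; deserialise_from reads the CDB from its keyword arguments); the folder path is a placeholder:

linker.serialise_to("/tmp/trainable_linker")                          # placeholder path
restored = Linker.deserialise_from("/tmp/trainable_linker", cdb=cdb)  # assumes an existing CDB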
@@ -315,17 +315,12 @@ def embed(
         outputs = self.model(**batch_dict)
         return outputs.half()

-    def embed_cuis(
-        self, embedding_model_name: Optional[Union[str, Path]] = None
-    ) -> None:
+    def embed_cuis(self) -> None:
         """Create embeddings for each CUI's longest name and store in CDB.

-        If ``embedding_model_name`` is provided, switch/load that model first.
-        Otherwise, reuse the currently loaded model (training-friendly default).
+        Switch the model first via ``load_transformers`` if needed.
         """
-        target_model = embedding_model_name or self.cnf_l.embedding_model_name
         self._refresh_cdb_keys()  # ensure _cui_keys is up to date before embedding
-        self.load_transformers(target_model)

         cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
         total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size)
@@ -344,17 +339,12 @@ def embed_cuis(
         self.cdb.addl_info["cui_embeddings"] = all_embeddings_matrix
         logger.debug("Embedding cui names done, total: %d", len(cui_names))

-    def embed_names(
-        self, embedding_model_name: Optional[Union[str, Path]] = None
-    ) -> None:
+    def embed_names(self) -> None:
         """Create embeddings for all names and store in CDB.

-        If ``embedding_model_name`` is provided, switch/load that model first.
-        Otherwise, reuse the currently loaded model (training-friendly default).
+        Switch the model first via ``load_transformers`` if needed.
         """
-        target_model = embedding_model_name or self.cnf_l.embedding_model_name
         self._refresh_cdb_keys()  # ensure _cui_keys is up to date before embedding
-        self.load_transformers(target_model)

         names = self._name_keys
         total_batches = math.ceil(len(names) / self.cnf_l.embedding_batch_size)
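For reference, the batch count used by both embedding loops is plain ceiling division over the configured batch size; a tiny illustration with made-up numbers:

import math

names = ["fever", "pyrexia", "headache"]   # stand-ins for the CDB name keys
embedding_batch_size = 2                   # illustrative cnf_l.embedding_batch_size value
total_batches = math.ceil(len(names) / embedding_batch_size)
assert total_batches == 2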