From cb03062c510bdd23279fcbdb831d549f8437a851 Mon Sep 17 00:00:00 2001 From: Graham Hukill Date: Mon, 3 Nov 2025 14:51:15 -0500 Subject: [PATCH 1/4] Introduce strategies framework for preparing EmbeddingInputs Why these changes are being introduced: A core requirement of this application is the ability to take a TIMDEX JSON record and "transform" all or parts of it into a single string for which an embedding can be created. We are calling these "embedding strategies" in the context of this app. While our first strategy will likely be a very simple, full record approach, we want to support multiple strategies in the application, and even multiple strategies for a single record in a single invocation. How this addresses that need: * A new 'strategies' module is created * A base 'BaseStrategy' class, with a required 'extract_text()' method for implementations * Our first strategy represented in class 'FullRecordStrategy', which JSON dumps the entire TIMDEX JSON record. * A registry of strategies, similar to our models, that allows CLI level validation. Side effects of this change: * None really, but further solidifies that this application contains the opinions about how text is prepared for the embedding process. 
Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/USE-131 * https://mitlibraries.atlassian.net/browse/USE-132 --- README.md | 5 ++- embeddings/cli.py | 57 +++++++--------------------- embeddings/strategies/__init__.py | 1 + embeddings/strategies/base.py | 57 ++++++++++++++++++++++++++++ embeddings/strategies/full_record.py | 13 +++++++ embeddings/strategies/processor.py | 42 ++++++++++++++++++++ embeddings/strategies/registry.py | 29 ++++++++++++++ 7 files changed, 158 insertions(+), 46 deletions(-) create mode 100644 embeddings/strategies/__init__.py create mode 100644 embeddings/strategies/base.py create mode 100644 embeddings/strategies/full_record.py create mode 100644 embeddings/strategies/processor.py create mode 100644 embeddings/strategies/registry.py diff --git a/README.md b/README.md index 1295ad6..7ff9fc9 100644 --- a/README.md +++ b/README.md @@ -114,8 +114,9 @@ Options: default = 0. [required] --record-limit INTEGER Limit number of records after --run-record- offset, default = None (unlimited). [required] - --strategy TEXT Pre-embedding record transformation strategy to - use. Repeatable. [required] + --strategy [full_record] Pre-embedding record transformation strategy. + Repeatable to apply multiple strategies. + [required] --output-jsonl TEXT Optionally write embeddings to local JSONLines file (primarily for testing). --help Show this message and exit. 
diff --git a/embeddings/cli.py b/embeddings/cli.py index 3c459ee..11f203b 100644 --- a/embeddings/cli.py +++ b/embeddings/cli.py @@ -1,4 +1,5 @@ import functools +import json import logging import time from collections.abc import Callable @@ -12,6 +13,8 @@ from embeddings.config import configure_logger, configure_sentry from embeddings.models.registry import get_model_class +from embeddings.strategies.processor import create_embedding_inputs +from embeddings.strategies.registry import STRATEGY_REGISTRY logger = logging.getLogger(__name__) @@ -181,10 +184,13 @@ def test_model_load(ctx: click.Context) -> None: ) @click.option( "--strategy", - type=str, # WIP: establish an enum of supported strategies + type=click.Choice(list(STRATEGY_REGISTRY.keys())), required=True, multiple=True, - help="Pre-embedding record transformation strategy to use. Repeatable.", + help=( + "Pre-embedding record transformation strategy. " + "Repeatable to apply multiple strategies." + ), ) @click.option( "--output-jsonl", @@ -222,48 +228,11 @@ def create_embeddings( action="index", ) - # create an iterator of InputTexts applying all requested strategies to all records - # WIP NOTE: this will leverage some kind of pre-embedding transformer class(es) that - # create texts based on the requested strategies (e.g. "full record"), which are - # captured in --strategy CLI args - # WIP NOTE: the following simulates that... 
- # DEBUG ------------------------------------------------------------------------------ - import json # noqa: PLC0415 - - from embeddings.embedding import EmbeddingInput # noqa: PLC0415 - - input_records = ( - EmbeddingInput( - timdex_record_id=timdex_record["timdex_record_id"], - run_id=timdex_record["run_id"], - run_record_offset=timdex_record["run_record_offset"], - embedding_strategy=_strategy, - text=json.dumps(timdex_record["transformed_record"].decode()), - ) - for timdex_record in timdex_records - for _strategy in strategy - ) - # DEBUG ------------------------------------------------------------------------------ - - # create an iterator of Embeddings via the embedding model - # WIP NOTE: this will use the embedding class .create_embeddings() bulk method - # WIP NOTE: the following simulates that... - # DEBUG ------------------------------------------------------------------------------ - from embeddings.embedding import Embedding # noqa: PLC0415 - - embeddings = ( - Embedding( - timdex_record_id=input_record.timdex_record_id, - run_id=input_record.run_id, - run_record_offset=input_record.run_record_offset, - embedding_strategy=input_record.embedding_strategy, - model_uri=model.model_uri, - embedding_vector=[0.1, 0.2, 0.3], - embedding_token_weights={"coffee": 0.9, "seattle": 0.5}, - ) - for input_record in input_records - ) - # DEBUG ------------------------------------------------------------------------------ + # create an iterator of EmbeddingInputs applying all requested strategies + input_records = create_embedding_inputs(timdex_records, list(strategy)) + + # create embeddings via the embedding model + embeddings = model.create_embeddings(input_records) # if requested, write embeddings to a local JSONLines file if output_jsonl: diff --git a/embeddings/strategies/__init__.py b/embeddings/strategies/__init__.py new file mode 100644 index 0000000..19116b1 --- /dev/null +++ b/embeddings/strategies/__init__.py @@ -0,0 +1 @@ +"""Strategies for 
transforming TIMDEX records into EmbeddingInputs.""" diff --git a/embeddings/strategies/base.py b/embeddings/strategies/base.py new file mode 100644 index 0000000..c463cf5 --- /dev/null +++ b/embeddings/strategies/base.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod + +from embeddings.embedding import EmbeddingInput + + +class BaseStrategy(ABC): + """Base class for embedding input strategies. + + All child classes must set class level attribute STRATEGY_NAME. + """ + + STRATEGY_NAME: str # type hint to document the requirement + + def __init__( + self, + timdex_record_id: str, + run_id: str, + run_record_offset: int, + transformed_record: dict, + ) -> None: + """Initialize strategy with TIMDEX record metadata. + + Args: + timdex_record_id: TIMDEX record ID + run_id: TIMDEX ETL run ID + run_record_offset: record offset within the run + transformed_record: parsed TIMDEX record JSON + """ + self.timdex_record_id = timdex_record_id + self.run_id = run_id + self.run_record_offset = run_record_offset + self.transformed_record = transformed_record + + def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105 + super().__init_subclass__(**kwargs) + + # require class level STRATEGY_NAME to be set + if not hasattr(cls, "STRATEGY_NAME"): + msg = f"{cls.__name__} must define 'STRATEGY_NAME' class attribute" + raise TypeError(msg) + if not isinstance(cls.STRATEGY_NAME, str): + msg = f"{cls.__name__} must override 'STRATEGY_NAME' with a valid string" + raise TypeError(msg) + + @abstractmethod + def extract_text(self) -> str: + """Extract text to be embedded from transformed_record.""" + + def to_embedding_input(self) -> EmbeddingInput: + """Create EmbeddingInput instance with strategy-specific extracted text.""" + return EmbeddingInput( + timdex_record_id=self.timdex_record_id, + run_id=self.run_id, + run_record_offset=self.run_record_offset, + embedding_strategy=self.STRATEGY_NAME, + text=self.extract_text(), + ) diff --git a/embeddings/strategies/full_record.py 
b/embeddings/strategies/full_record.py new file mode 100644 index 0000000..81a7b7d --- /dev/null +++ b/embeddings/strategies/full_record.py @@ -0,0 +1,13 @@ +import json + +from embeddings.strategies.base import BaseStrategy + + +class FullRecordStrategy(BaseStrategy): + """Serialize entire TIMDEX record JSON as embedding input.""" + + STRATEGY_NAME = "full_record" + + def extract_text(self) -> str: + """Serialize the entire transformed_record as JSON.""" + return json.dumps(self.transformed_record) diff --git a/embeddings/strategies/processor.py b/embeddings/strategies/processor.py new file mode 100644 index 0000000..eb1b143 --- /dev/null +++ b/embeddings/strategies/processor.py @@ -0,0 +1,42 @@ +import json +from collections.abc import Iterator + +from embeddings.embedding import EmbeddingInput +from embeddings.strategies.registry import get_strategy_class + + +def create_embedding_inputs( + timdex_records: Iterator[dict], + strategies: list[str], +) -> Iterator[EmbeddingInput]: + """Yield EmbeddingInput instances for all records x all strategies. + + Creates a cartesian product: each record is processed by each strategy, + yielding one EmbeddingInput per combination. + + Args: + timdex_records: Iterator of TIMDEX records. 
+ Expected keys: timdex_record_id, run_id, run_record_offset, + transformed_record (bytes) + strategies: List of strategy names to apply + + Yields: + EmbeddingInput instances ready for embedding model + + Example: + 100 records x 3 strategies = 300 EmbeddingInput instances + """ + for timdex_record in timdex_records: + # decode and parse the TIMDEX JSON record + transformed_record = json.loads(timdex_record["transformed_record"].decode()) + + # apply all strategies to the record and yield + for strategy_name in strategies: + strategy_class = get_strategy_class(strategy_name) + strategy_instance = strategy_class( + timdex_record_id=timdex_record["timdex_record_id"], + run_id=timdex_record["run_id"], + run_record_offset=timdex_record["run_record_offset"], + transformed_record=transformed_record, + ) + yield strategy_instance.to_embedding_input() diff --git a/embeddings/strategies/registry.py b/embeddings/strategies/registry.py new file mode 100644 index 0000000..3dc12bb --- /dev/null +++ b/embeddings/strategies/registry.py @@ -0,0 +1,29 @@ +import logging + +from embeddings.strategies.base import BaseStrategy +from embeddings.strategies.full_record import FullRecordStrategy + +logger = logging.getLogger(__name__) + +STRATEGY_CLASSES = [ + FullRecordStrategy, +] + +STRATEGY_REGISTRY: dict[str, type[BaseStrategy]] = { + strategy.STRATEGY_NAME: strategy for strategy in STRATEGY_CLASSES +} + + +def get_strategy_class(strategy_name: str) -> type[BaseStrategy]: + """Get strategy class by name. + + Args: + strategy_name: Name of the strategy to retrieve + """ + if strategy_name not in STRATEGY_REGISTRY: + available = ", ".join(sorted(STRATEGY_REGISTRY.keys())) + msg = f"Unknown strategy: {strategy_name}. 
Available: {available}" + logger.error(msg) + raise ValueError(msg) + + return STRATEGY_REGISTRY[strategy_name] From 70d839443428dccba0ab1de59a229eb5abc53e01 Mon Sep 17 00:00:00 2001 From: Graham Hukill Date: Mon, 3 Nov 2025 15:09:29 -0500 Subject: [PATCH 2/4] Strategies unit tests --- tests/test_strategies.py | 83 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 tests/test_strategies.py diff --git a/tests/test_strategies.py b/tests/test_strategies.py new file mode 100644 index 0000000..fc256b1 --- /dev/null +++ b/tests/test_strategies.py @@ -0,0 +1,83 @@ +import json + +import pytest + +from embeddings.strategies.base import BaseStrategy +from embeddings.strategies.full_record import FullRecordStrategy +from embeddings.strategies.processor import create_embedding_inputs +from embeddings.strategies.registry import get_strategy_class + + +def test_full_record_strategy_creates_embedding_input(): + transformed_record = {"timdex_record_id": "test-123", "title": ["Test Title"]} + strategy = FullRecordStrategy( + timdex_record_id="test-123", + run_id="run-456", + run_record_offset=42, + transformed_record=transformed_record, + ) + + embedding_input = strategy.to_embedding_input() + + assert embedding_input.timdex_record_id == "test-123" + assert embedding_input.run_id == "run-456" + assert embedding_input.run_record_offset == 42 + assert embedding_input.embedding_strategy == "full_record" + assert embedding_input.text == json.dumps(transformed_record) + + +def test_create_embedding_inputs_yields_cartesian_product(): + # two records + timdex_records = iter( + [ + { + "timdex_record_id": "id-1", + "run_id": "run-1", + "run_record_offset": 0, + "transformed_record": b'{"title": ["Record 1"]}', + }, + { + "timdex_record_id": "id-2", + "run_id": "run-1", + "run_record_offset": 1, + "transformed_record": b'{"title": ["Record 2"]}', + }, + ] + ) + + # single strategy (for now) + strategies = ["full_record"] + + embedding_inputs = 
list(create_embedding_inputs(timdex_records, strategies)) + + assert len(embedding_inputs) == 2 + assert embedding_inputs[0].timdex_record_id == "id-1" + assert embedding_inputs[0].embedding_strategy == "full_record" + assert embedding_inputs[1].timdex_record_id == "id-2" + assert embedding_inputs[1].embedding_strategy == "full_record" + + +def test_get_strategy_class_returns_correct_class(): + strategy_class = get_strategy_class("full_record") + assert strategy_class is FullRecordStrategy + + +def test_get_strategy_class_raises_for_unknown_strategy(): + with pytest.raises(ValueError, match="Unknown strategy"): + get_strategy_class("nonexistent_strategy") + + +def test_subclass_without_strategy_name_raises_type_error(): + with pytest.raises(TypeError, match="must define 'STRATEGY_NAME' class attribute"): + + class InvalidStrategy(BaseStrategy): + pass + + +def test_subclass_with_non_string_strategy_name_raises_type_error(): + with pytest.raises( + TypeError, match="must override 'STRATEGY_NAME' with a valid string" + ): + + class InvalidStrategy(BaseStrategy): + STRATEGY_NAME = 123 From 145cd8134ee008ea4ab663d9d943a0515f35863c Mon Sep 17 00:00:00 2001 From: Graham Hukill Date: Tue, 4 Nov 2025 13:38:27 -0500 Subject: [PATCH 3/4] Streamline exception messages --- embeddings/strategies/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/embeddings/strategies/base.py b/embeddings/strategies/base.py index c463cf5..0ccce72 100644 --- a/embeddings/strategies/base.py +++ b/embeddings/strategies/base.py @@ -36,11 +36,11 @@ def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105 # require class level STRATEGY_NAME to be set if not hasattr(cls, "STRATEGY_NAME"): - msg = f"{cls.__name__} must define 'STRATEGY_NAME' class attribute" - raise TypeError(msg) + raise TypeError(f"{cls.__name__} must define 'STRATEGY_NAME' class attribute") if not isinstance(cls.STRATEGY_NAME, str): - msg = f"{cls.__name__} must override 'STRATEGY_NAME' with a 
valid string" - raise TypeError(msg) + raise TypeError( + f"{cls.__name__} must override 'STRATEGY_NAME' with a valid string" + ) @abstractmethod def extract_text(self) -> str: From 421ba710522d726ec8fa71aa5d37ec189f73d40b Mon Sep 17 00:00:00 2001 From: Graham Hukill Date: Tue, 4 Nov 2025 15:31:25 -0500 Subject: [PATCH 4/4] Init transformer strategy once for all records Why these changes are being introduced: Formerly, a transformer strategy class was instantiated in a per-record fashion, where things like the timdex_record_id and other record-level values were passed. This ultimately felt awkward, when we could just as easily instantiate it once in a more generic fashion, then build EmbeddingInput instances with the *result* of the strategy extracting text from the TIMDEX JSON record. How this addresses that need: All record-level details are removed as arguments for initializing a transformer strategy. Instead, the helper function create_embedding_inputs() is responsible for passing the TIMDEX JSON record to the transformer strategies, and then building an EmbeddingInput object before yielding. This keeps the init of those strategies much simpler, and prevents the classes from carrying properties they don't really need. 
Side effects of this change: * None Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/USE-131 * https://mitlibraries.atlassian.net/browse/USE-132 --- embeddings/strategies/base.py | 40 +++++----------------------- embeddings/strategies/full_record.py | 4 +-- embeddings/strategies/processor.py | 37 ++++++++++++++----------- tests/test_strategies.py | 20 +++++--------- 4 files changed, 36 insertions(+), 65 deletions(-) diff --git a/embeddings/strategies/base.py b/embeddings/strategies/base.py index 0ccce72..0e27bd5 100644 --- a/embeddings/strategies/base.py +++ b/embeddings/strategies/base.py @@ -1,7 +1,5 @@ from abc import ABC, abstractmethod -from embeddings.embedding import EmbeddingInput - class BaseStrategy(ABC): """Base class for embedding input strategies. @@ -11,26 +9,6 @@ class BaseStrategy(ABC): STRATEGY_NAME: str # type hint to document the requirement - def __init__( - self, - timdex_record_id: str, - run_id: str, - run_record_offset: int, - transformed_record: dict, - ) -> None: - """Initialize strategy with TIMDEX record metadata. 
- - Args: - timdex_record_id: TIMDEX record ID - run_id: TIMDEX ETL run ID - run_record_offset: record offset within the run - transformed_record: parsed TIMDEX record JSON - """ - self.timdex_record_id = timdex_record_id - self.run_id = run_id - self.run_record_offset = run_record_offset - self.transformed_record = transformed_record - def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105 super().__init_subclass__(**kwargs) @@ -43,15 +21,9 @@ def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105 ) @abstractmethod - def extract_text(self) -> str: - """Extract text to be embedded from transformed_record.""" - - def to_embedding_input(self) -> EmbeddingInput: - """Create EmbeddingInput instance with strategy-specific extracted text.""" - return EmbeddingInput( - timdex_record_id=self.timdex_record_id, - run_id=self.run_id, - run_record_offset=self.run_record_offset, - embedding_strategy=self.STRATEGY_NAME, - text=self.extract_text(), - ) + def extract_text(self, timdex_record: dict) -> str: + """Extract text to be embedded from transformed_record. 
+ + Args: + timdex_record: TIMDEX JSON record ("transformed_record" in TIMDEX dataset) + """ diff --git a/embeddings/strategies/full_record.py b/embeddings/strategies/full_record.py index 81a7b7d..5dae9d6 100644 --- a/embeddings/strategies/full_record.py +++ b/embeddings/strategies/full_record.py @@ -8,6 +8,6 @@ class FullRecordStrategy(BaseStrategy): STRATEGY_NAME = "full_record" - def extract_text(self) -> str: + def extract_text(self, timdex_record: dict) -> str: """Serialize the entire transformed_record as JSON.""" - return json.dumps(self.transformed_record) + return json.dumps(timdex_record) diff --git a/embeddings/strategies/processor.py b/embeddings/strategies/processor.py index eb1b143..5316729 100644 --- a/embeddings/strategies/processor.py +++ b/embeddings/strategies/processor.py @@ -6,7 +6,7 @@ def create_embedding_inputs( - timdex_records: Iterator[dict], + timdex_dataset_records: Iterator[dict], strategies: list[str], ) -> Iterator[EmbeddingInput]: """Yield EmbeddingInput instances for all records x all strategies. @@ -15,7 +15,7 @@ def create_embedding_inputs( yielding one EmbeddingInput per combination. Args: - timdex_records: Iterator of TIMDEX records. + timdex_dataset_records: Iterator of TIMDEX records. 
Expected keys: timdex_record_id, run_id, run_record_offset, transformed_record (bytes) strategies: List of strategy names to apply @@ -26,17 +26,24 @@ def create_embedding_inputs( Example: 100 records x 3 strategies = 300 EmbeddingInput instances """ - for timdex_record in timdex_records: - # decode and parse the TIMDEX JSON record - transformed_record = json.loads(timdex_record["transformed_record"].decode()) - - # apply all strategies to the record and yield - for strategy_name in strategies: - strategy_class = get_strategy_class(strategy_name) - strategy_instance = strategy_class( - timdex_record_id=timdex_record["timdex_record_id"], - run_id=timdex_record["run_id"], - run_record_offset=timdex_record["run_record_offset"], - transformed_record=transformed_record, + # instantiate strategy transformers + transformers = [get_strategy_class(strategy)() for strategy in strategies] + + # loop through records and apply all strategies, yielding an EmbeddingInput for each + for timdex_dataset_record in timdex_dataset_records: + + # decode and parse the TIMDEX JSON record once for all requested strategies + timdex_record = json.loads(timdex_dataset_record["transformed_record"].decode()) + + for transformer in transformers: + # prepare text for embedding from transformer strategy + text = transformer.extract_text(timdex_record) + + # emit an EmbeddingInput instance + yield EmbeddingInput( + timdex_record_id=timdex_dataset_record["timdex_record_id"], + run_id=timdex_dataset_record["run_id"], + run_record_offset=timdex_dataset_record["run_record_offset"], + embedding_strategy=transformer.STRATEGY_NAME, + text=text, ) - yield strategy_instance.to_embedding_input() diff --git a/tests/test_strategies.py b/tests/test_strategies.py index fc256b1..ede712a 100644 --- a/tests/test_strategies.py +++ b/tests/test_strategies.py @@ -8,22 +8,14 @@ from embeddings.strategies.registry import get_strategy_class -def test_full_record_strategy_creates_embedding_input(): - transformed_record = 
{"timdex_record_id": "test-123", "title": ["Test Title"]} - strategy = FullRecordStrategy( - timdex_record_id="test-123", - run_id="run-456", - run_record_offset=42, - transformed_record=transformed_record, - ) +def test_full_record_strategy_extracts_text(): + timdex_record = {"timdex_record_id": "test-123", "title": ["Test Title"]} + strategy = FullRecordStrategy() - embedding_input = strategy.to_embedding_input() + text = strategy.extract_text(timdex_record) - assert embedding_input.timdex_record_id == "test-123" - assert embedding_input.run_id == "run-456" - assert embedding_input.run_record_offset == 42 - assert embedding_input.embedding_strategy == "full_record" - assert embedding_input.text == json.dumps(transformed_record) + assert text == json.dumps(timdex_record) + assert strategy.STRATEGY_NAME == "full_record" def test_create_embedding_inputs_yields_cartesian_product():