diff --git a/README.md b/README.md index 1295ad6..7ff9fc9 100644 --- a/README.md +++ b/README.md @@ -114,8 +114,9 @@ Options: default = 0. [required] --record-limit INTEGER Limit number of records after --run-record- offset, default = None (unlimited). [required] - --strategy TEXT Pre-embedding record transformation strategy to - use. Repeatable. [required] + --strategy [full_record] Pre-embedding record transformation strategy. + Repeatable to apply multiple strategies. + [required] --output-jsonl TEXT Optionally write embeddings to local JSONLines file (primarily for testing). --help Show this message and exit. diff --git a/embeddings/cli.py b/embeddings/cli.py index 3c459ee..11f203b 100644 --- a/embeddings/cli.py +++ b/embeddings/cli.py @@ -1,4 +1,5 @@ import functools +import json import logging import time from collections.abc import Callable @@ -12,6 +13,8 @@ from embeddings.config import configure_logger, configure_sentry from embeddings.models.registry import get_model_class +from embeddings.strategies.processor import create_embedding_inputs +from embeddings.strategies.registry import STRATEGY_REGISTRY logger = logging.getLogger(__name__) @@ -181,10 +184,13 @@ def test_model_load(ctx: click.Context) -> None: ) @click.option( "--strategy", - type=str, # WIP: establish an enum of supported strategies + type=click.Choice(list(STRATEGY_REGISTRY.keys())), required=True, multiple=True, - help="Pre-embedding record transformation strategy to use. Repeatable.", + help=( + "Pre-embedding record transformation strategy. " + "Repeatable to apply multiple strategies." + ), ) @click.option( "--output-jsonl", @@ -222,48 +228,11 @@ def create_embeddings( action="index", ) - # create an iterator of InputTexts applying all requested strategies to all records - # WIP NOTE: this will leverage some kind of pre-embedding transformer class(es) that - # create texts based on the requested strategies (e.g. "full record"), which are - # captured in --strategy CLI args - # WIP NOTE: the following simulates that... - # DEBUG ------------------------------------------------------------------------------ - import json # noqa: PLC0415 - - from embeddings.embedding import EmbeddingInput # noqa: PLC0415 - - input_records = ( - EmbeddingInput( - timdex_record_id=timdex_record["timdex_record_id"], - run_id=timdex_record["run_id"], - run_record_offset=timdex_record["run_record_offset"], - embedding_strategy=_strategy, - text=json.dumps(timdex_record["transformed_record"].decode()), - ) - for timdex_record in timdex_records - for _strategy in strategy - ) - # DEBUG ------------------------------------------------------------------------------ - - # create an iterator of Embeddings via the embedding model - # WIP NOTE: this will use the embedding class .create_embeddings() bulk method - # WIP NOTE: the following simulates that... - # DEBUG ------------------------------------------------------------------------------ - from embeddings.embedding import Embedding # noqa: PLC0415 - - embeddings = ( - Embedding( - timdex_record_id=input_record.timdex_record_id, - run_id=input_record.run_id, - run_record_offset=input_record.run_record_offset, - embedding_strategy=input_record.embedding_strategy, - model_uri=model.model_uri, - embedding_vector=[0.1, 0.2, 0.3], - embedding_token_weights={"coffee": 0.9, "seattle": 0.5}, - ) - for input_record in input_records - ) - # DEBUG ------------------------------------------------------------------------------ + # create an iterator of EmbeddingInputs applying all requested strategies + input_records = create_embedding_inputs(timdex_records, list(strategy)) + + # create embeddings via the embedding model + embeddings = model.create_embeddings(input_records) # if requested, write embeddings to a local JSONLines file if output_jsonl: diff --git a/embeddings/strategies/__init__.py b/embeddings/strategies/__init__.py new file mode 100644 index 0000000..19116b1 --- /dev/null +++ b/embeddings/strategies/__init__.py @@ -0,0 +1 @@ +"""Strategies for transforming TIMDEX records into EmbeddingInputs.""" diff --git a/embeddings/strategies/base.py b/embeddings/strategies/base.py new file mode 100644 index 0000000..0e27bd5 --- /dev/null +++ b/embeddings/strategies/base.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod + + +class BaseStrategy(ABC): + """Base class for embedding input strategies. + + All child classes must set class level attribute STRATEGY_NAME. + """ + + STRATEGY_NAME: str # type hint to document the requirement + + def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105 + super().__init_subclass__(**kwargs) + + # require class level STRATEGY_NAME to be set + if not hasattr(cls, "STRATEGY_NAME"): + raise TypeError(f"{cls.__name__} must define 'STRATEGY_NAME' class attribute") + if not isinstance(cls.STRATEGY_NAME, str): + raise TypeError( + f"{cls.__name__} must override 'STRATEGY_NAME' with a valid string" + ) + + @abstractmethod + def extract_text(self, timdex_record: dict) -> str: + """Extract text to be embedded from transformed_record. + + Args: + timdex_record: TIMDEX JSON record ("transformed_record" in TIMDEX dataset) + """ diff --git a/embeddings/strategies/full_record.py b/embeddings/strategies/full_record.py new file mode 100644 index 0000000..5dae9d6 --- /dev/null +++ b/embeddings/strategies/full_record.py @@ -0,0 +1,13 @@ +import json + +from embeddings.strategies.base import BaseStrategy + + +class FullRecordStrategy(BaseStrategy): + """Serialize entire TIMDEX record JSON as embedding input.""" + + STRATEGY_NAME = "full_record" + + def extract_text(self, timdex_record: dict) -> str: + """Serialize the entire transformed_record as JSON.""" + return json.dumps(timdex_record) diff --git a/embeddings/strategies/processor.py b/embeddings/strategies/processor.py new file mode 100644 index 0000000..5316729 --- /dev/null +++ b/embeddings/strategies/processor.py @@ -0,0 +1,49 @@ +import json +from collections.abc import Iterator + +from embeddings.embedding import EmbeddingInput +from embeddings.strategies.registry import get_strategy_class + + +def create_embedding_inputs( + timdex_dataset_records: Iterator[dict], + strategies: list[str], +) -> Iterator[EmbeddingInput]: + """Yield EmbeddingInput instances for all records x all strategies. + + Creates a cartesian product: each record is processed by each strategy, + yielding one EmbeddingInput per combination. + + Args: + timdex_dataset_records: Iterator of TIMDEX records. + Expected keys: timdex_record_id, run_id, run_record_offset, + transformed_record (bytes) + strategies: List of strategy names to apply + + Yields: + EmbeddingInput instances ready for embedding model + + Example: + 100 records x 3 strategies = 300 EmbeddingInput instances + """ + # instantiate strategy transformers + transformers = [get_strategy_class(strategy)() for strategy in strategies] + + # loop through records and apply all strategies, yielding an EmbeddingInput for each + for timdex_dataset_record in timdex_dataset_records: + + # decode and parse the TIMDEX JSON record once for all requested strategies + timdex_record = json.loads(timdex_dataset_record["transformed_record"].decode()) + + for transformer in transformers: + # prepare text for embedding from transformer strategy + text = transformer.extract_text(timdex_record) + + # emit an EmbeddingInput instance + yield EmbeddingInput( + timdex_record_id=timdex_dataset_record["timdex_record_id"], + run_id=timdex_dataset_record["run_id"], + run_record_offset=timdex_dataset_record["run_record_offset"], + embedding_strategy=transformer.STRATEGY_NAME, + text=text, + ) diff --git a/embeddings/strategies/registry.py b/embeddings/strategies/registry.py new file mode 100644 index 0000000..3dc12bb --- /dev/null +++ b/embeddings/strategies/registry.py @@ -0,0 +1,29 @@ +import logging + +from embeddings.strategies.base import BaseStrategy +from embeddings.strategies.full_record import FullRecordStrategy + +logger = logging.getLogger(__name__) + +STRATEGY_CLASSES = [ + FullRecordStrategy, +] + +STRATEGY_REGISTRY: dict[str, type[BaseStrategy]] = { + strategy.STRATEGY_NAME: strategy for strategy in STRATEGY_CLASSES +} + + +def get_strategy_class(strategy_name: str) -> type[BaseStrategy]: + """Get strategy class by name. + + Args: + strategy_name: Name of the strategy to retrieve + """ + if strategy_name not in STRATEGY_REGISTRY: + available = ", ".join(sorted(STRATEGY_REGISTRY.keys())) + msg = f"Unknown strategy: {strategy_name}. Available: {available}" + logger.error(msg) + raise ValueError(msg) + + return STRATEGY_REGISTRY[strategy_name] diff --git a/tests/test_strategies.py b/tests/test_strategies.py new file mode 100644 index 0000000..ede712a --- /dev/null +++ b/tests/test_strategies.py @@ -0,0 +1,75 @@ +import json + +import pytest + +from embeddings.strategies.base import BaseStrategy +from embeddings.strategies.full_record import FullRecordStrategy +from embeddings.strategies.processor import create_embedding_inputs +from embeddings.strategies.registry import get_strategy_class + + +def test_full_record_strategy_extracts_text(): + timdex_record = {"timdex_record_id": "test-123", "title": ["Test Title"]} + strategy = FullRecordStrategy() + + text = strategy.extract_text(timdex_record) + + assert text == json.dumps(timdex_record) + assert strategy.STRATEGY_NAME == "full_record" + + +def test_create_embedding_inputs_yields_cartesian_product(): + # two records + timdex_records = iter( + [ + { + "timdex_record_id": "id-1", + "run_id": "run-1", + "run_record_offset": 0, + "transformed_record": b'{"title": ["Record 1"]}', + }, + { + "timdex_record_id": "id-2", + "run_id": "run-1", + "run_record_offset": 1, + "transformed_record": b'{"title": ["Record 2"]}', + }, + ] + ) + + # single strategy (for now) + strategies = ["full_record"] + + embedding_inputs = list(create_embedding_inputs(timdex_records, strategies)) + + assert len(embedding_inputs) == 2 + assert embedding_inputs[0].timdex_record_id == "id-1" + assert embedding_inputs[0].embedding_strategy == "full_record" + assert embedding_inputs[1].timdex_record_id == "id-2" + assert embedding_inputs[1].embedding_strategy == "full_record" + + +def test_get_strategy_class_returns_correct_class(): + strategy_class = get_strategy_class("full_record") + assert strategy_class is FullRecordStrategy + + +def test_get_strategy_class_raises_for_unknown_strategy(): + with pytest.raises(ValueError, match="Unknown strategy"): + get_strategy_class("nonexistent_strategy") + + +def test_subclass_without_strategy_name_raises_type_error(): + with pytest.raises(TypeError, match="must define 'STRATEGY_NAME' class attribute"): + + class InvalidStrategy(BaseStrategy): + pass + + +def test_subclass_with_non_string_strategy_name_raises_type_error(): + with pytest.raises( + TypeError, match="must override 'STRATEGY_NAME' with a valid string" + ): + + class InvalidStrategy(BaseStrategy): + STRATEGY_NAME = 123