-
Notifications
You must be signed in to change notification settings - Fork 0
USE 131 - Framework for embedding input strategies #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Strategies for transforming TIMDEX records into EmbeddingInputs.""" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| from abc import ABC, abstractmethod | ||
|
|
||
|
|
||
class BaseStrategy(ABC):
    """Base class for embedding input strategies.

    All child classes must set the class level attribute STRATEGY_NAME to a
    non-empty string. The requirement is checked when the subclass is defined,
    so the name can be read from the class itself without instantiating it.
    """

    STRATEGY_NAME: str  # type hint to document the requirement

    def __init_subclass__(cls, **kwargs: object) -> None:
        """Validate that the subclass declares a usable STRATEGY_NAME.

        Args:
            **kwargs: Class keyword arguments, forwarded unchanged to the parent
                __init_subclass__.

        Raises:
            TypeError: If STRATEGY_NAME is missing, is not a string, or is an
                empty / whitespace-only string.
        """
        super().__init_subclass__(**kwargs)

        # The bare annotation on this base class only records a type hint; it does
        # not create a class attribute. hasattr() is therefore False for a subclass
        # that omits STRATEGY_NAME, which makes this check necessary.
        #
        # An abstract classmethod returning the name would also work without an
        # instance, but abstract methods are only enforced at instantiation time;
        # this check fails as soon as the offending class is defined.
        if not hasattr(cls, "STRATEGY_NAME"):
            raise TypeError(f"{cls.__name__} must define 'STRATEGY_NAME' class attribute")

        # A blank name is rejected as well: it is not a meaningful identifier to
        # select a strategy by.
        name = cls.STRATEGY_NAME
        if not isinstance(name, str) or not name.strip():
            raise TypeError(
                f"{cls.__name__} must override 'STRATEGY_NAME' with a valid string"
            )

    @abstractmethod
    def extract_text(self, timdex_record: dict) -> str:
        """Extract text to be embedded from transformed_record.

        Args:
            timdex_record: TIMDEX JSON record ("transformed_record" in TIMDEX dataset)

        Returns:
            The text that should be passed to the embedding model.
        """
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| import json | ||
|
|
||
| from embeddings.strategies.base import BaseStrategy | ||
|
|
||
|
|
||
class FullRecordStrategy(BaseStrategy):
    """Strategy whose embedding input is the whole TIMDEX record as a JSON string."""

    STRATEGY_NAME = "full_record"

    def extract_text(self, timdex_record: dict) -> str:
        """Return the complete record serialized with json.dumps.

        Args:
            timdex_record: TIMDEX JSON record, already parsed into a dict
        """
        serialized_record = json.dumps(timdex_record)
        return serialized_record
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| import json | ||
| from collections.abc import Iterator | ||
|
|
||
| from embeddings.embedding import EmbeddingInput | ||
| from embeddings.strategies.registry import get_strategy_class | ||
|
|
||
|
|
||
def create_embedding_inputs(
    timdex_dataset_records: Iterator[dict],
    strategies: list[str],
) -> Iterator[EmbeddingInput]:
    """Yield EmbeddingInput instances for all records x all strategies.

    Creates a cartesian product: each record is processed by each strategy,
    yielding one EmbeddingInput per combination.

    Strategy names are resolved and instantiated immediately when this function
    is called, so an unknown name is reported at the call site rather than on
    the first iteration of the returned iterator. Records themselves are still
    consumed lazily.

    Args:
        timdex_dataset_records: Iterator of TIMDEX records.
            Expected keys: timdex_record_id, run_id, run_record_offset,
            transformed_record (bytes)
        strategies: List of strategy names to apply

    Returns:
        Iterator of EmbeddingInput instances ready for embedding model

    Raises:
        ValueError: If a name in strategies is not registered (raised by
            get_strategy_class).

    Example:
        100 records x 3 strategies = 300 EmbeddingInput instances
    """
    # instantiate strategy transformers up front; doing this inside the generator
    # body would defer any lookup error until the caller starts iterating
    transformers = [get_strategy_class(strategy)() for strategy in strategies]

    def _generate() -> Iterator[EmbeddingInput]:
        # loop through records and apply all strategies, yielding one
        # EmbeddingInput for each (record, strategy) pair
        for timdex_dataset_record in timdex_dataset_records:

            # decode and parse the TIMDEX JSON record once for all strategies
            timdex_record = json.loads(
                timdex_dataset_record["transformed_record"].decode()
            )

            for transformer in transformers:
                # prepare text for embedding from transformer strategy
                text = transformer.extract_text(timdex_record)

                # emit an EmbeddingInput instance
                yield EmbeddingInput(
                    timdex_record_id=timdex_dataset_record["timdex_record_id"],
                    run_id=timdex_dataset_record["run_id"],
                    run_record_offset=timdex_dataset_record["run_record_offset"],
                    embedding_strategy=transformer.STRATEGY_NAME,
                    text=text,
                )

    return _generate()
ghukill marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| import logging | ||
|
|
||
| from embeddings.strategies.base import BaseStrategy | ||
| from embeddings.strategies.full_record import FullRecordStrategy | ||
|
|
||
# module-level logger, named after this module
logger = logging.getLogger(__name__)

# every strategy class that can be selected by name; add new strategies here
STRATEGY_CLASSES = [
    FullRecordStrategy,
]

# lookup table from each class's STRATEGY_NAME to the class itself
STRATEGY_REGISTRY: dict[str, type[BaseStrategy]] = {
    klass.STRATEGY_NAME: klass for klass in STRATEGY_CLASSES
}
|
|
||
|
|
||
def get_strategy_class(strategy_name: str) -> type[BaseStrategy]:
    """Look up a registered strategy class by its name.

    Args:
        strategy_name: Name of the strategy to retrieve

    Returns:
        The strategy class stored in STRATEGY_REGISTRY under strategy_name.

    Raises:
        ValueError: If strategy_name is not a key of STRATEGY_REGISTRY. The same
            message is logged at error level before raising.
    """
    # happy path first: registry values are classes, so None means "not found"
    strategy_class = STRATEGY_REGISTRY.get(strategy_name)
    if strategy_class is not None:
        return strategy_class

    # list the registered names alphabetically to help the caller fix the name
    available = ", ".join(sorted(STRATEGY_REGISTRY))
    msg = f"Unknown strategy: {strategy_name}. Available: {available}"
    logger.error(msg)
    raise ValueError(msg)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| import json | ||
|
|
||
| import pytest | ||
|
|
||
| from embeddings.strategies.base import BaseStrategy | ||
| from embeddings.strategies.full_record import FullRecordStrategy | ||
| from embeddings.strategies.processor import create_embedding_inputs | ||
| from embeddings.strategies.registry import get_strategy_class | ||
|
|
||
|
|
||
def test_full_record_strategy_extracts_text():
    """FullRecordStrategy returns the JSON dump of the record it is given."""
    record = {"timdex_record_id": "test-123", "title": ["Test Title"]}
    full_record_strategy = FullRecordStrategy()

    extracted = full_record_strategy.extract_text(record)
    expected = json.dumps(record)

    assert extracted == expected
    assert full_record_strategy.STRATEGY_NAME == "full_record"
|
|
||
|
|
||
def test_create_embedding_inputs_yields_cartesian_product():
    """Two records run through one strategy yield two inputs, in record order."""
    first_record = {
        "timdex_record_id": "id-1",
        "run_id": "run-1",
        "run_record_offset": 0,
        "transformed_record": b'{"title": ["Record 1"]}',
    }
    second_record = {
        "timdex_record_id": "id-2",
        "run_id": "run-1",
        "run_record_offset": 1,
        "transformed_record": b'{"title": ["Record 2"]}',
    }
    dataset_records = iter([first_record, second_record])

    # single strategy (for now)
    results = list(create_embedding_inputs(dataset_records, ["full_record"]))

    # one (record id, strategy) pair per record, in the order records were given
    observed = [(item.timdex_record_id, item.embedding_strategy) for item in results]
    assert len(results) == 2
    assert observed == [("id-1", "full_record"), ("id-2", "full_record")]
|
|
||
|
|
||
def test_get_strategy_class_returns_correct_class():
    """The name "full_record" resolves to the FullRecordStrategy class itself."""
    assert get_strategy_class("full_record") is FullRecordStrategy
|
|
||
|
|
||
def test_get_strategy_class_raises_for_unknown_strategy():
    """An unregistered name is rejected with a ValueError naming the problem."""
    unknown_name = "nonexistent_strategy"

    with pytest.raises(ValueError, match="Unknown strategy"):
        get_strategy_class(unknown_name)
|
|
||
|
|
||
def test_subclass_without_strategy_name_raises_type_error():
    """Defining a subclass with no STRATEGY_NAME fails at class creation."""
    expected_message = "must define 'STRATEGY_NAME' class attribute"

    with pytest.raises(TypeError, match=expected_message):
        # build the class dynamically; creation itself triggers the validation
        type("InvalidStrategy", (BaseStrategy,), {})
|
|
||
|
|
||
def test_subclass_with_non_string_strategy_name_raises_type_error():
    """Defining a subclass whose STRATEGY_NAME is not a str fails at class creation."""
    expected_message = "must override 'STRATEGY_NAME' with a valid string"

    with pytest.raises(TypeError, match=expected_message):
        # build the class dynamically; creation itself triggers the validation
        type("InvalidStrategy", (BaseStrategy,), {"STRATEGY_NAME": 123})
Uh oh!
There was an error while loading. Please reload this page.