Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ Options:
default = 0. [required]
--record-limit INTEGER Limit number of records after --run-record-
offset, default = None (unlimited). [required]
--strategy TEXT Pre-embedding record transformation strategy to
use. Repeatable. [required]
--strategy [full_record] Pre-embedding record transformation strategy.
Repeatable to apply multiple strategies.
[required]
--output-jsonl TEXT Optionally write embeddings to local JSONLines
file (primarily for testing).
--help Show this message and exit.
Expand Down
57 changes: 13 additions & 44 deletions embeddings/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import json
import logging
import time
from collections.abc import Callable
Expand All @@ -12,6 +13,8 @@

from embeddings.config import configure_logger, configure_sentry
from embeddings.models.registry import get_model_class
from embeddings.strategies.processor import create_embedding_inputs
from embeddings.strategies.registry import STRATEGY_REGISTRY

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -181,10 +184,13 @@ def test_model_load(ctx: click.Context) -> None:
)
@click.option(
"--strategy",
type=str, # WIP: establish an enum of supported strategies
type=click.Choice(list(STRATEGY_REGISTRY.keys())),
required=True,
multiple=True,
help="Pre-embedding record transformation strategy to use. Repeatable.",
help=(
"Pre-embedding record transformation strategy. "
"Repeatable to apply multiple strategies."
),
)
@click.option(
"--output-jsonl",
Expand Down Expand Up @@ -222,48 +228,11 @@ def create_embeddings(
action="index",
)

# create an iterator of InputTexts applying all requested strategies to all records
# WIP NOTE: this will leverage some kind of pre-embedding transformer class(es) that
# create texts based on the requested strategies (e.g. "full record"), which are
# captured in --strategy CLI args
# WIP NOTE: the following simulates that...
# DEBUG ------------------------------------------------------------------------------
import json # noqa: PLC0415

from embeddings.embedding import EmbeddingInput # noqa: PLC0415

input_records = (
EmbeddingInput(
timdex_record_id=timdex_record["timdex_record_id"],
run_id=timdex_record["run_id"],
run_record_offset=timdex_record["run_record_offset"],
embedding_strategy=_strategy,
text=json.dumps(timdex_record["transformed_record"].decode()),
)
for timdex_record in timdex_records
for _strategy in strategy
)
# DEBUG ------------------------------------------------------------------------------

# create an iterator of Embeddings via the embedding model
# WIP NOTE: this will use the embedding class .create_embeddings() bulk method
# WIP NOTE: the following simulates that...
# DEBUG ------------------------------------------------------------------------------
from embeddings.embedding import Embedding # noqa: PLC0415

embeddings = (
Embedding(
timdex_record_id=input_record.timdex_record_id,
run_id=input_record.run_id,
run_record_offset=input_record.run_record_offset,
embedding_strategy=input_record.embedding_strategy,
model_uri=model.model_uri,
embedding_vector=[0.1, 0.2, 0.3],
embedding_token_weights={"coffee": 0.9, "seattle": 0.5},
)
for input_record in input_records
)
# DEBUG ------------------------------------------------------------------------------
# create an iterator of EmbeddingInputs applying all requested strategies
input_records = create_embedding_inputs(timdex_records, list(strategy))

# create embeddings via the embedding model
embeddings = model.create_embeddings(input_records)

# if requested, write embeddings to a local JSONLines file
if output_jsonl:
Expand Down
1 change: 1 addition & 0 deletions embeddings/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Strategies for transforming TIMDEX records into EmbeddingInputs."""
29 changes: 29 additions & 0 deletions embeddings/strategies/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod


class BaseStrategy(ABC):
"""Base class for embedding input strategies.

All child classes must set class level attribute STRATEGY_NAME.
"""

STRATEGY_NAME: str # type hint to document the requirement

def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105
super().__init_subclass__(**kwargs)

# require class level STRATEGY_NAME to be set
if not hasattr(cls, "STRATEGY_NAME"):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this check necessary since the attribute exists in the base class even if it's not defined?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think it is. Without this, you could leave out the class level STRATEGY_NAME attribute on a real strategy class.

Another pattern could be something like this:

class BaseStrategy(ABC):
  @absractmethod
  @classmethod
  def strategy_name(self) -> str:
    # return strategy name...

Which would allow getting the strategy name for an uninstantiated class, which is important. We do something similar in other projects I think.

But this pattern, with a little logic in the base class, enforces that child classes define it.

raise TypeError(f"{cls.__name__} must define 'STRATEGY_NAME' class attribute")
if not isinstance(cls.STRATEGY_NAME, str):
raise TypeError(
f"{cls.__name__} must override 'STRATEGY_NAME' with a valid string"
)

@abstractmethod
def extract_text(self, timdex_record: dict) -> str:
"""Extract text to be embedded from transformed_record.

Args:
timdex_record: TIMDEX JSON record ("transformed_record" in TIMDEX dataset)
"""
13 changes: 13 additions & 0 deletions embeddings/strategies/full_record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import json

from embeddings.strategies.base import BaseStrategy


class FullRecordStrategy(BaseStrategy):
"""Serialize entire TIMDEX record JSON as embedding input."""

STRATEGY_NAME = "full_record"

def extract_text(self, timdex_record: dict) -> str:
"""Serialize the entire transformed_record as JSON."""
return json.dumps(timdex_record)
49 changes: 49 additions & 0 deletions embeddings/strategies/processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import json
from collections.abc import Iterator

from embeddings.embedding import EmbeddingInput
from embeddings.strategies.registry import get_strategy_class


def create_embedding_inputs(
timdex_dataset_records: Iterator[dict],
strategies: list[str],
) -> Iterator[EmbeddingInput]:
"""Yield EmbeddingInput instances for all records x all strategies.

Creates a cartesian product: each record is processed by each strategy,
yielding one EmbeddingInput per combination.

Args:
timdex_dataset_records: Iterator of TIMDEX records.
Expected keys: timdex_record_id, run_id, run_record_offset,
transformed_record (bytes)
strategies: List of strategy names to apply

Yields:
EmbeddingInput instances ready for embedding model

Example:
100 records x 3 strategies = 300 EmbeddingInput instances
"""
# instantiate strategy transformers
transformers = [get_strategy_class(strategy)() for strategy in strategies]

# loop through records and apply all strategies, yielding an EmbeddingInput for each
for timdex_dataset_record in timdex_dataset_records:

# decode and parse the TIMDEX JSON record once for all requested strategies
timdex_record = json.loads(timdex_dataset_record["transformed_record"].decode())

for transformer in transformers:
# prepare text for embedding from transformer strategy
text = transformer.extract_text(timdex_record)

# emit an EmbeddingInput instance
yield EmbeddingInput(
timdex_record_id=timdex_dataset_record["timdex_record_id"],
run_id=timdex_dataset_record["run_id"],
run_record_offset=timdex_dataset_record["run_record_offset"],
embedding_strategy=transformer.STRATEGY_NAME,
text=text,
)
29 changes: 29 additions & 0 deletions embeddings/strategies/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import logging

from embeddings.strategies.base import BaseStrategy
from embeddings.strategies.full_record import FullRecordStrategy

logger = logging.getLogger(__name__)

STRATEGY_CLASSES = [
FullRecordStrategy,
]

STRATEGY_REGISTRY: dict[str, type[BaseStrategy]] = {
strategy.STRATEGY_NAME: strategy for strategy in STRATEGY_CLASSES
}


def get_strategy_class(strategy_name: str) -> type[BaseStrategy]:
"""Get strategy class by name.

Args:
strategy_name: Name of the strategy to retrieve
"""
if strategy_name not in STRATEGY_REGISTRY:
available = ", ".join(sorted(STRATEGY_REGISTRY.keys()))
msg = f"Unknown strategy: {strategy_name}. Available: {available}"
logger.error(msg)
raise ValueError(msg)

return STRATEGY_REGISTRY[strategy_name]
75 changes: 75 additions & 0 deletions tests/test_strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import json

import pytest

from embeddings.strategies.base import BaseStrategy
from embeddings.strategies.full_record import FullRecordStrategy
from embeddings.strategies.processor import create_embedding_inputs
from embeddings.strategies.registry import get_strategy_class


def test_full_record_strategy_extracts_text():
timdex_record = {"timdex_record_id": "test-123", "title": ["Test Title"]}
strategy = FullRecordStrategy()

text = strategy.extract_text(timdex_record)

assert text == json.dumps(timdex_record)
assert strategy.STRATEGY_NAME == "full_record"


def test_create_embedding_inputs_yields_cartesian_product():
# two records
timdex_records = iter(
[
{
"timdex_record_id": "id-1",
"run_id": "run-1",
"run_record_offset": 0,
"transformed_record": b'{"title": ["Record 1"]}',
},
{
"timdex_record_id": "id-2",
"run_id": "run-1",
"run_record_offset": 1,
"transformed_record": b'{"title": ["Record 2"]}',
},
]
)

# single strategy (for now)
strategies = ["full_record"]

embedding_inputs = list(create_embedding_inputs(timdex_records, strategies))

assert len(embedding_inputs) == 2
assert embedding_inputs[0].timdex_record_id == "id-1"
assert embedding_inputs[0].embedding_strategy == "full_record"
assert embedding_inputs[1].timdex_record_id == "id-2"
assert embedding_inputs[1].embedding_strategy == "full_record"


def test_get_strategy_class_returns_correct_class():
strategy_class = get_strategy_class("full_record")
assert strategy_class is FullRecordStrategy


def test_get_strategy_class_raises_for_unknown_strategy():
with pytest.raises(ValueError, match="Unknown strategy"):
get_strategy_class("nonexistent_strategy")


def test_subclass_without_strategy_name_raises_type_error():
with pytest.raises(TypeError, match="must define 'STRATEGY_NAME' class attribute"):

class InvalidStrategy(BaseStrategy):
pass


def test_subclass_with_non_string_strategy_name_raises_type_error():
with pytest.raises(
TypeError, match="must override 'STRATEGY_NAME' with a valid string"
):

class InvalidStrategy(BaseStrategy):
STRATEGY_NAME = 123