Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,5 @@ cython_debug/
.DS_Store
output/
.vscode/

CLAUDE.md
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,28 @@ Options:
--help Show this message and exit.
```

### `create-embeddings`
```text
Usage: embeddings create-embeddings [OPTIONS]

Create embeddings for TIMDEX records.

Options:
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name')
[required]
--model-path PATH Path where the model will be downloaded to and
loaded from, e.g. '/path/to/model'. [required]
-d, --dataset-location PATH TIMDEX dataset location, e.g.
's3://timdex/dataset', to read records from.
[required]
--run-id TEXT TIMDEX ETL run id. [required]
--run-record-offset INTEGER TIMDEX ETL run record offset to start from,
default = 0. [required]
--record-limit INTEGER Limit number of records after --run-record-
offset, default = None (unlimited). [required]
--strategy TEXT Pre-embedding record transformation strategy to
use. Repeatable. [required]
--output-jsonl TEXT Optionally write embeddings to local JSONLines
file (primarily for testing).
--help Show this message and exit.
```
138 changes: 136 additions & 2 deletions embeddings/cli.py
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with the approach in the commit message!

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import TYPE_CHECKING

import click
import jsonlines
from timdex_dataset_api import TIMDEXDataset

from embeddings.config import configure_logger, configure_sentry
from embeddings.models.registry import get_model_class
Expand Down Expand Up @@ -150,8 +152,140 @@ def test_model_load(ctx: click.Context) -> None:
@main.command()
@click.pass_context
@model_required
def create_embedding(ctx: click.Context) -> None:
"""Create a single embedding for a single input text."""
@click.option(
"-d",
"--dataset-location",
required=True,
type=click.Path(),
help="TIMDEX dataset location, e.g. 's3://timdex/dataset', to read records from.",
)
@click.option(
"--run-id",
required=True,
type=str,
help="TIMDEX ETL run id.",
)
@click.option(
"--run-record-offset",
required=True,
type=int,
default=0,
help="TIMDEX ETL run record offset to start from, default = 0.",
)
@click.option(
"--record-limit",
required=True,
type=int,
default=None,
help="Limit number of records after --run-record-offset, default = None (unlimited).",
)
@click.option(
"--strategy",
type=str, # WIP: establish an enum of supported strategies
required=True,
multiple=True,
help="Pre-embedding record transformation strategy to use. Repeatable.",
)
@click.option(
"--output-jsonl",
required=False,
type=str,
default=None,
help="Optionally write embeddings to local JSONLines file (primarily for testing).",
)
def create_embeddings(
ctx: click.Context,
dataset_location: str,
run_id: str,
run_record_offset: int,
record_limit: int,
strategy: list[str],
output_jsonl: str,
) -> None:
"""Create embeddings for TIMDEX records."""
model: BaseEmbeddingModel = ctx.obj["model"]

# init TIMDEXDataset
timdex_dataset = TIMDEXDataset(dataset_location)

# query TIMDEX dataset for an iterator of records
timdex_records = timdex_dataset.read_dicts_iter(
columns=[
"timdex_record_id",
"run_id",
"run_record_offset",
"transformed_record",
],
run_id=run_id,
where=f"""run_record_offset >= {run_record_offset}""",
limit=record_limit,
action="index",
)

# create an iterator of InputTexts applying all requested strategies to all records
# WIP NOTE: this will leverage some kind of pre-embedding transformer class(es) that
# create texts based on the requested strategies (e.g. "full record"), which are
# captured in --strategy CLI args
# WIP NOTE: the following simulates that...
# DEBUG ------------------------------------------------------------------------------
import json # noqa: PLC0415

from embeddings.embedding import EmbeddingInput # noqa: PLC0415

input_records = (
EmbeddingInput(
timdex_record_id=timdex_record["timdex_record_id"],
run_id=timdex_record["run_id"],
run_record_offset=timdex_record["run_record_offset"],
embedding_strategy=_strategy,
text=json.dumps(timdex_record["transformed_record"].decode()),
)
for timdex_record in timdex_records
for _strategy in strategy
)
# DEBUG ------------------------------------------------------------------------------

# create an iterator of Embeddings via the embedding model
# WIP NOTE: this will use the embedding class .create_embeddings() bulk method
# WIP NOTE: the following simulates that...
# DEBUG ------------------------------------------------------------------------------
from embeddings.embedding import Embedding # noqa: PLC0415

embeddings = (
Embedding(
timdex_record_id=input_record.timdex_record_id,
run_id=input_record.run_id,
run_record_offset=input_record.run_record_offset,
embedding_strategy=input_record.embedding_strategy,
model_uri=model.model_uri,
embedding_vector=[0.1, 0.2, 0.3],
embedding_token_weights={"coffee": 0.9, "seattle": 0.5},
)
for input_record in input_records
)
# DEBUG ------------------------------------------------------------------------------

# if requested, write embeddings to a local JSONLines file
if output_jsonl:
with jsonlines.open(
output_jsonl,
mode="w",
dumps=lambda obj: json.dumps(
obj,
default=str,
),
Comment on lines +273 to +276
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was new to me: when using jsonlines.open() to get a writer, you can define a custom dumps= serializer. We needed the option default=str to coerce datetime objects into strings on serialization.

) as writer:
for embedding in embeddings:
writer.write(embedding.to_dict())
Comment on lines +270 to +279

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: this block could be a method to improve readability

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the thinking of encapsulating this somehow, but I'd like to wait on that until the other pieces are more established.

For example, I'm unsure if it makes sense for the embedding classes to perform writing, I'm thinking not. Therefore, it's basically the CLI that does the writing. So there is nowhere for a method per se, but we could have some utility functions? But if we go the utility function route, I'm unsure if a free floating function at the bottom of the file or the hopping around to another file is better than these couple of steps here.

Duly noted, but opting to wait for now.


# else, default writing embeddings back to TIMDEX dataset
else:
# WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...)
# NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the
# Embedding instance will nearly 1:1 map to.
raise NotImplementedError

logger.info("Embeddings creation complete.")


if __name__ == "__main__": # pragma: no cover
Expand Down
5 changes: 2 additions & 3 deletions embeddings/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ def configure_logger(logger: logging.Logger, *, verbose: bool) -> str:
format="%(asctime)s %(levelname)s %(name)s.%(funcName)s() line %(lineno)d: "
"%(message)s"
)
logger.setLevel(logging.DEBUG)
for handler in logging.root.handlers:
handler.addFilter(logging.Filter("embeddings"))
Comment on lines -14 to -15

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this removed?

Copy link
Collaborator Author

@ghukill ghukill Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a great question, one that I felt deserved an entire commit 😅: 0024b8f.

I'm not sharing the commit to be glib, and I'm happy to elaborate more. In short, for any application that installs our timdex_dataset_api library, it's helpful to get logs from TDA as well. Unfortunately, some of our other conventions for setting up logging make that difficult. I would argue they are overly aggressive about "only this app shall log". But I'd also argue we went too hard in the other direction with, "every library shall log, unless directed otherwise via WARNING_ONLY_LOGGERS".

To me, and noted in the commit, this could be a happy medium:

  • in the application, explicitly configure which libraries you want --verbose to bump to DEBUG
  • all other libraries keep their default WARNING
  • not implemented, but opens the door for a DEBUG_LOGGERS env var that could toggle other libraries to DEBUG logging

TL/DR: moves to an opt-in pattern for debug logging, while putting TDA on the same footing as the application it's part of

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad, since this was in the first commit, I didn't associate it with those changes!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not at all - kind of sloppy on my part how this happened. Snuck the removal in an earlier PR, then this update is basically building on that.

logging.getLogger("embeddings").setLevel(logging.DEBUG)
logging.getLogger("timdex_dataset_api").setLevel(logging.DEBUG)
else:
logging.basicConfig(
format="%(asctime)s %(levelname)s %(name)s.%(funcName)s(): %(message)s"
Expand Down
58 changes: 58 additions & 0 deletions embeddings/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import datetime
import json
from dataclasses import asdict, dataclass, field


@dataclass
class EmbeddingInput:
"""Encapsulates the inputs for an embedding.

When creating an embedding, we need to note what TIMDEX record the embedding is
associated with and what strategy was used to prepare the embedding input text from
the record itself.
Comment on lines +7 to +12
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much better!


Args:
(timdex_record_id, run_id, run_record_offset): composite key for TIMDEX record
embedding_strategy: strategy used to create text for embedding
text: text to embed, created from the TIMDEX record via the embedding_strategy
Comment on lines +16 to +17

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These docstrings and names could be a little clearer, is there a more descriptive name than text?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I agree that the class names and docstrings may need some touches, I do feel like text is a succinct and accurate property for this class. The class and docstring should clearly communicate that an instance of this class is:

  1. came from a specific TIMDEX record
  2. we prepared a string of text to create the embedding from via strategy XYZ
  3. the actual string of text we'll send to the embedding model is found at .text

I'm unsure if we benefit from something like .text_to_embed, as I think .text on this class should kind of imply that. It's like a Meal class with a .dessert property, where the relationship feels implied and you wouldn't expect .dessert_to_eat.

"""

timdex_record_id: str
run_id: str
run_record_offset: int
embedding_strategy: str
text: str


@dataclass
class Embedding:
"""Encapsulates a single embedding.

Args:
(timdex_record_id, run_id, run_record_offset): composite key for TIMDEX record
model_uri: model URI used to create the embedding
embedding_strategy: strategy used to create text for embedding
embedding_vector: vector representation of embedding
embedding_token_weights: decoded token:weight pairs from sparse vector
- only applicable to models that produce this output
"""

timdex_record_id: str
run_id: str
run_record_offset: int
model_uri: str
embedding_strategy: str
embedding_vector: list[float]
embedding_token_weights: dict

timestamp: datetime.datetime = field(
default_factory=lambda: datetime.datetime.now(datetime.UTC)
)

def to_dict(self) -> dict:
"""Marshal to dictionary."""
return asdict(self)

def to_json(self) -> str:
"""Serialize to JSON."""
return json.dumps(self.to_dict(), default=str)
22 changes: 22 additions & 0 deletions embeddings/models/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Base class for embedding models."""

from abc import ABC, abstractmethod
from collections.abc import Iterator
from pathlib import Path

from embeddings.embedding import Embedding, EmbeddingInput


class BaseEmbeddingModel(ABC):
"""Abstract base class for embedding models.
Expand Down Expand Up @@ -46,3 +49,22 @@ def download(self) -> Path:
@abstractmethod
def load(self) -> None:
    """Load the model found at self.model_path."""

@abstractmethod
def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
    """Produce the Embedding that corresponds to a single EmbeddingInput.

    Args:
        input_record: the EmbeddingInput instance to embed
    """

def create_embeddings(
    self, input_records: Iterator[EmbeddingInput]
) -> Iterator[Embedding]:
    """Lazily yield one Embedding per EmbeddingInput, in input order.

    Each input is handed to self.create_embedding() only as the returned
    generator is consumed.

    Args:
        input_records: iterator of EmbeddingInput instances
    """
    yield from (
        self.create_embedding(input_record) for input_record in input_records
    )
4 changes: 4 additions & 0 deletions embeddings/models/os_neural_sparse_doc_v3_gte.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from huggingface_hub import snapshot_download
from transformers import AutoModelForMaskedLM, AutoTokenizer

from embeddings.embedding import Embedding, EmbeddingInput
from embeddings.models.base import BaseEmbeddingModel

if TYPE_CHECKING:
Expand Down Expand Up @@ -161,3 +162,6 @@ def load(self) -> None:
self._id_to_token[token_id] = token

logger.info(f"Model loaded successfully, {time.perf_counter()-start_time}s")

def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
    """Placeholder: embedding creation is not implemented for this model yet.

    Raises:
        NotImplementedError: always
    """
    raise NotImplementedError
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ requires-python = ">=3.12"
dependencies = [
"click>=8.2.1",
"huggingface-hub>=0.26.0",
"jsonlines>=4.0.0",
"sentry-sdk>=2.34.1",
"timdex-dataset-api",
"torch>=2.9.0",
Expand Down Expand Up @@ -39,6 +40,11 @@ exclude = [
"output/"
]

[[tool.mypy.overrides]]
module = ["timdex_dataset_api.*"]
follow_untyped_imports = true


[tool.pytest.ini_options]
log_level = "INFO"

Expand Down Expand Up @@ -88,6 +94,7 @@ fixture-parentheses = false
"tests/**/*" = [
"ANN",
"ARG001",
"PLR2004",
"S101",
]

Expand Down
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pytest
from click.testing import CliRunner

from embeddings.embedding import Embedding, EmbeddingInput
from embeddings.models import registry
from embeddings.models.base import BaseEmbeddingModel

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -43,13 +45,31 @@ def download(self) -> Path:
def load(self) -> None:
    """Pretend to load the model by logging a fixed success message."""
    message = "Model loaded successfully, 1.5s"
    logger.info(message)

def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
    """Return a canned Embedding keyed to the given input record."""
    # fixed, fake model outputs; only the record identifiers vary per input
    fake_vector = [0.1, 0.2, 0.3]
    fake_token_weights = {"coffee": 0.9, "seattle": 0.5}

    return Embedding(
        timdex_record_id=input_record.timdex_record_id,
        run_id=input_record.run_id,
        run_record_offset=input_record.run_record_offset,
        embedding_strategy=input_record.embedding_strategy,
        model_uri=self.model_uri,
        embedding_vector=fake_vector,
        embedding_token_weights=fake_token_weights,
    )


@pytest.fixture
def mock_model(tmp_path):
    """Provide a MockEmbeddingModel rooted at a "model" path under tmp_path."""
    model_path = tmp_path / "model"
    return MockEmbeddingModel(model_path)


@pytest.fixture
def register_mock_model(monkeypatch):
    """Patch the model registry and environment to expose MockEmbeddingModel.

    Both changes go through monkeypatch, so they are undone after the test.
    """
    mock_model_uri = "test/mock-model"
    monkeypatch.setitem(registry.MODEL_REGISTRY, mock_model_uri, MockEmbeddingModel)

    env_var, fake_model_path = "TE_MODEL_PATH", "/fake/path"
    monkeypatch.setenv(env_var, fake_model_path)


@pytest.fixture
def neural_sparse_doc_v3_gte_fake_model_directory(tmp_path):
"""Create a fake downloaded model directory with required files."""
Expand Down
Loading