-
Notifications
You must be signed in to change notification settings - Fork 0
Stub CLI command methods to create embeddings #16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ce181df
0024b8f
d433f17
de351a1
d4e6e54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -155,3 +155,5 @@ cython_debug/ | |
| .DS_Store | ||
| output/ | ||
| .vscode/ | ||
|
|
||
| CLAUDE.md | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
| from typing import TYPE_CHECKING | ||
|
|
||
| import click | ||
| import jsonlines | ||
| from timdex_dataset_api import TIMDEXDataset | ||
|
|
||
| from embeddings.config import configure_logger, configure_sentry | ||
| from embeddings.models.registry import get_model_class | ||
|
|
@@ -150,8 +152,140 @@ def test_model_load(ctx: click.Context) -> None: | |
| @main.command() | ||
| @click.pass_context | ||
| @model_required | ||
| def create_embedding(ctx: click.Context) -> None: | ||
| """Create a single embedding for a single input text.""" | ||
| @click.option( | ||
| "-d", | ||
| "--dataset-location", | ||
| required=True, | ||
| type=click.Path(), | ||
| help="TIMDEX dataset location, e.g. 's3://timdex/dataset', to read records from.", | ||
| ) | ||
| @click.option( | ||
| "--run-id", | ||
| required=True, | ||
| type=str, | ||
| help="TIMDEX ETL run id.", | ||
| ) | ||
| @click.option( | ||
| "--run-record-offset", | ||
| required=True, | ||
| type=int, | ||
| default=0, | ||
| help="TIMDEX ETL run record offset to start from, default = 0.", | ||
| ) | ||
| @click.option( | ||
| "--record-limit", | ||
| required=True, | ||
| type=int, | ||
| default=None, | ||
| help="Limit number of records after --run-record-offset, default = None (unlimited).", | ||
| ) | ||
| @click.option( | ||
| "--strategy", | ||
| type=str, # WIP: establish an enum of supported strategies | ||
ghukill marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| required=True, | ||
| multiple=True, | ||
| help="Pre-embedding record transformation strategy to use. Repeatable.", | ||
| ) | ||
| @click.option( | ||
| "--output-jsonl", | ||
| required=False, | ||
| type=str, | ||
| default=None, | ||
| help="Optionally write embeddings to local JSONLines file (primarily for testing).", | ||
| ) | ||
| def create_embeddings( | ||
| ctx: click.Context, | ||
| dataset_location: str, | ||
| run_id: str, | ||
| run_record_offset: int, | ||
| record_limit: int, | ||
| strategy: list[str], | ||
| output_jsonl: str, | ||
| ) -> None: | ||
| """Create embeddings for TIMDEX records.""" | ||
| model: BaseEmbeddingModel = ctx.obj["model"] | ||
|
|
||
| # init TIMDEXDataset | ||
| timdex_dataset = TIMDEXDataset(dataset_location) | ||
|
|
||
| # query TIMDEX dataset for an iterator of records | ||
| timdex_records = timdex_dataset.read_dicts_iter( | ||
| columns=[ | ||
| "timdex_record_id", | ||
| "run_id", | ||
| "run_record_offset", | ||
| "transformed_record", | ||
| ], | ||
| run_id=run_id, | ||
| where=f"""run_record_offset >= {run_record_offset}""", | ||
| limit=record_limit, | ||
| action="index", | ||
| ) | ||
|
|
||
| # create an iterator of InputTexts applying all requested strategies to all records | ||
| # WIP NOTE: this will leverage some kind of pre-embedding transformer class(es) that | ||
| # create texts based on the requested strategies (e.g. "full record"), which are | ||
| # captured in --strategy CLI args | ||
| # WIP NOTE: the following simulates that... | ||
| # DEBUG ------------------------------------------------------------------------------ | ||
| import json # noqa: PLC0415 | ||
|
|
||
| from embeddings.embedding import EmbeddingInput # noqa: PLC0415 | ||
|
|
||
| input_records = ( | ||
| EmbeddingInput( | ||
| timdex_record_id=timdex_record["timdex_record_id"], | ||
| run_id=timdex_record["run_id"], | ||
| run_record_offset=timdex_record["run_record_offset"], | ||
| embedding_strategy=_strategy, | ||
| text=json.dumps(timdex_record["transformed_record"].decode()), | ||
| ) | ||
| for timdex_record in timdex_records | ||
| for _strategy in strategy | ||
| ) | ||
| # DEBUG ------------------------------------------------------------------------------ | ||
|
|
||
| # create an iterator of Embeddings via the embedding model | ||
| # WIP NOTE: this will use the embedding class .create_embeddings() bulk method | ||
| # WIP NOTE: the following simulates that... | ||
| # DEBUG ------------------------------------------------------------------------------ | ||
| from embeddings.embedding import Embedding # noqa: PLC0415 | ||
|
|
||
| embeddings = ( | ||
| Embedding( | ||
| timdex_record_id=input_record.timdex_record_id, | ||
| run_id=input_record.run_id, | ||
| run_record_offset=input_record.run_record_offset, | ||
| embedding_strategy=input_record.embedding_strategy, | ||
| model_uri=model.model_uri, | ||
| embedding_vector=[0.1, 0.2, 0.3], | ||
| embedding_token_weights={"coffee": 0.9, "seattle": 0.5}, | ||
| ) | ||
| for input_record in input_records | ||
| ) | ||
| # DEBUG ------------------------------------------------------------------------------ | ||
|
|
||
| # if requested, write embeddings to a local JSONLines file | ||
| if output_jsonl: | ||
| with jsonlines.open( | ||
| output_jsonl, | ||
| mode="w", | ||
| dumps=lambda obj: json.dumps( | ||
| obj, | ||
| default=str, | ||
| ), | ||
|
Comment on lines
+273
to
+276
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This was new to me: when using […] |
||
| ) as writer: | ||
| for embedding in embeddings: | ||
| writer.write(embedding.to_dict()) | ||
|
Comment on lines
+270
to
+279
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Optional: this block could be a method to improve readability. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I like the thinking of encapsulating this somehow, but I'd like to wait on that until the other pieces are more established. For example, I'm unsure if it makes sense for the embedding classes to perform writing; I'm thinking not. Therefore, it's basically the CLI that does the writing. So there is nowhere for a method per se, but we could have some utility functions? But if we go the utility function route, I'm unsure if a free-floating function at the bottom of the file, or hopping around to another file, is better than these couple of steps here. Duly noted, but opting to wait for now. |
||
|
|
||
| # else, default writing embeddings back to TIMDEX dataset | ||
| else: | ||
| # WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...) | ||
| # NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the | ||
| # Embedding instance will nearly 1:1 map to. | ||
| raise NotImplementedError | ||
|
|
||
| logger.info("Embeddings creation complete.") | ||
|
|
||
ghukill marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if __name__ == "__main__": # pragma: no cover | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,9 +10,8 @@ def configure_logger(logger: logging.Logger, *, verbose: bool) -> str: | |
| format="%(asctime)s %(levelname)s %(name)s.%(funcName)s() line %(lineno)d: " | ||
| "%(message)s" | ||
| ) | ||
| logger.setLevel(logging.DEBUG) | ||
| for handler in logging.root.handlers: | ||
| handler.addFilter(logging.Filter("embeddings")) | ||
|
Comment on lines
-14
to
-15
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why was this removed? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's a great question, one that I felt deserved an entire commit 😅: 0024b8f. I'm not sharing the commit to be glib, and I'm happy to elaborate more. In short, any applications that install our […] To me, and noted in the commit, this could be a happy medium: […]
TL;DR: moves to an opt-in pattern for debug logging, while putting TDA on the same footing as the application it's part of. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. My bad, since this was in the first commit, I didn't associate it with those changes! There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not at all - kind of sloppy on my part how this happened. I snuck the removal into an earlier PR, and this update is basically building on that. |
||
| logging.getLogger("embeddings").setLevel(logging.DEBUG) | ||
| logging.getLogger("timdex_dataset_api").setLevel(logging.DEBUG) | ||
| else: | ||
| logging.basicConfig( | ||
| format="%(asctime)s %(levelname)s %(name)s.%(funcName)s(): %(message)s" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| import datetime | ||
| import json | ||
| from dataclasses import asdict, dataclass, field | ||
|
|
||
|
|
||
| @dataclass | ||
| class EmbeddingInput: | ||
| """Encapsulates the inputs for an embedding. | ||
|
|
||
| When creating an embedding, we need to note what TIMDEX record the embedding is | ||
| associated with and what strategy was used to prepare the embedding input text from | ||
| the record itself. | ||
|
Comment on lines
+7
to
+12
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Much better! |
||
|
|
||
| Args: | ||
| (timdex_record_id, run_id, run_record_offset): composite key for TIMDEX record | ||
| embedding_strategy: strategy used to create text for embedding | ||
| text: text to embed, created from the TIMDEX record via the embedding_strategy | ||
|
Comment on lines
+16
to
+17
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These docstrings and names could be a little clearer; is there a more descriptive name than […]? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. While I agree that the class names and docstrings may need some touches, I do feel like […]
I'm unsure if we benefit from something like […] |
||
| """ | ||
|
|
||
| timdex_record_id: str | ||
| run_id: str | ||
| run_record_offset: int | ||
| embedding_strategy: str | ||
| text: str | ||
|
|
||
|
|
||
| @dataclass | ||
| class Embedding: | ||
| """Encapsulates a single embedding. | ||
|
|
||
| Args: | ||
| (timdex_record_id, run_id, run_record_offset): composite key for TIMDEX record | ||
| model_uri: model URI used to create the embedding | ||
| embedding_strategy: strategy used to create text for embedding | ||
| embedding_vector: vector representation of embedding | ||
| embedding_token_weights: decoded token:weight pairs from sparse vector | ||
| - only applicable to models that produce this output | ||
| """ | ||
|
|
||
| timdex_record_id: str | ||
| run_id: str | ||
| run_record_offset: int | ||
| model_uri: str | ||
| embedding_strategy: str | ||
| embedding_vector: list[float] | ||
| embedding_token_weights: dict | ||
|
|
||
| timestamp: datetime.datetime = field( | ||
| default_factory=lambda: datetime.datetime.now(datetime.UTC) | ||
| ) | ||
|
|
||
| def to_dict(self) -> dict: | ||
| """Marshal to dictionary.""" | ||
| return asdict(self) | ||
|
|
||
| def to_json(self) -> str: | ||
| """Serialize to JSON.""" | ||
| return json.dumps(self.to_dict(), default=str) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with the approach in the commit message!