Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions embeddings/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
import json
import logging
import time
from collections.abc import Callable
from collections.abc import Callable, Iterator
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING

import click
import jsonlines
import smart_open
from timdex_dataset_api import TIMDEXDataset
from timdex_dataset_api import DatasetEmbedding, TIMDEXDataset, TIMDEXEmbeddings

from embeddings.config import configure_logger, configure_sentry
from embeddings.models.base import Embedding
from embeddings.models.registry import get_model_class
from embeddings.strategies.processor import create_embedding_inputs
from embeddings.strategies.registry import STRATEGY_REGISTRY
Expand Down Expand Up @@ -222,6 +223,7 @@ def create_embeddings(
"""Create embeddings for TIMDEX records."""
model: BaseEmbeddingModel = ctx.obj["model"]
model.load()
timdex_dataset: TIMDEXDataset | None = None

# read input records from TIMDEX dataset (default) or a JSONLines file
if input_jsonl:
Expand All @@ -230,7 +232,6 @@ def create_embeddings(
jsonlines.Reader(file_obj) as reader,
):
timdex_records = iter(list(reader))

else:
if not dataset_location or not run_id:
raise click.UsageError(
Expand Down Expand Up @@ -273,14 +274,26 @@ def create_embeddings(
for embedding in embeddings:
writer.write(embedding.to_dict())
else:
# WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...)
# NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the
# Embedding instance will nearly 1:1 map to.
raise NotImplementedError
if not timdex_dataset:
# if input_jsonl, init TIMDEXDataset
timdex_dataset = TIMDEXDataset(dataset_location)
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset)
timdex_embeddings.write(_dataset_embedding_iter(embeddings))

logger.info("Embeddings creation complete.")


if __name__ == "__main__": # pragma: no cover
logger = logging.getLogger("embeddings.main")
main()
def _dataset_embedding_iter(
embeddings: Iterator[Embedding],
) -> Iterator[DatasetEmbedding]:
"""Yield DatasetEmbedding objects from model embeddings."""
for embedding in embeddings:
yield DatasetEmbedding(
timdex_record_id=embedding.timdex_record_id,
run_id=embedding.run_id,
run_record_offset=embedding.run_record_offset,
embedding_model=embedding.model_uri,
embedding_strategy=embedding.embedding_strategy,
embedding_vector=embedding.embedding_vector,
embedding_object=json.dumps(embedding.embedding_token_weights).encode(),
)
46 changes: 46 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from pathlib import Path
from unittest.mock import patch

from timdex_dataset_api import TIMDEXDataset
from timdex_dataset_api.embeddings import TIMDEXEmbeddings

from embeddings.cli import main


Expand Down Expand Up @@ -133,6 +139,46 @@ def test_model_required_decorator_works_across_commands(
assert "OK" in result.output


@patch("timdex_dataset_api.TIMDEXDataset.read_dicts_iter")
def test_create_embeddings_writes_to_timdex_dataset(
mock_timdex_dataset_read_dicts_iter, register_mock_model, runner, tmp_path
):
mock_timdex_dataset_read_dicts_iter.return_value = iter(
[
{
"timdex_record_id": "record:1",
"run_id": "run-1",
"run_record_offset": 0,
"transformed_record": '{"title":"Record 1","description":"This is a record about coffee in the mountains."}', # noqa: E501
}
]
)

# init TIMDEX Dataset and Embeddings
timdex_dataset = TIMDEXDataset(location=str(tmp_path / "dataset"))
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset)

result = runner.invoke(
main,
[
"create-embeddings",
"--model-uri",
"test/mock-model",
"--dataset-location",
str(tmp_path / "dataset"),
"--run-id",
"run-1",
"--strategy",
"full_record",
],
)

# TODO @jonavellecuerdo: Update to use TIMDEXEmbeddings # noqa: FIX002
# read method when ready
assert result.exit_code == 0
assert Path(timdex_embeddings.data_embeddings_root).exists()


def test_create_embeddings_requires_strategy(register_mock_model, runner):
result = runner.invoke(
main,
Expand Down
48 changes: 24 additions & 24 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.