diff --git a/embeddings/cli.py b/embeddings/cli.py index 26ae866..324bd2c 100644 --- a/embeddings/cli.py +++ b/embeddings/cli.py @@ -2,7 +2,7 @@ import json import logging import time -from collections.abc import Callable +from collections.abc import Callable, Iterator from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING @@ -10,9 +10,10 @@ import click import jsonlines import smart_open -from timdex_dataset_api import TIMDEXDataset +from timdex_dataset_api import DatasetEmbedding, TIMDEXDataset, TIMDEXEmbeddings from embeddings.config import configure_logger, configure_sentry +from embeddings.models.base import Embedding from embeddings.models.registry import get_model_class from embeddings.strategies.processor import create_embedding_inputs from embeddings.strategies.registry import STRATEGY_REGISTRY @@ -222,6 +223,7 @@ def create_embeddings( """Create embeddings for TIMDEX records.""" model: BaseEmbeddingModel = ctx.obj["model"] model.load() + timdex_dataset: TIMDEXDataset | None = None # read input records from TIMDEX dataset (default) or a JSONLines file if input_jsonl: @@ -230,7 +232,6 @@ def create_embeddings( jsonlines.Reader(file_obj) as reader, ): timdex_records = iter(list(reader)) - else: if not dataset_location or not run_id: raise click.UsageError( @@ -273,14 +274,26 @@ def create_embeddings( for embedding in embeddings: writer.write(embedding.to_dict()) else: - # WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...) - # NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the - # Embedding instance will nearly 1:1 map to. - raise NotImplementedError + if not timdex_dataset: + # if input_jsonl, init TIMDEXDataset + timdex_dataset = TIMDEXDataset(dataset_location) + timdex_embeddings = TIMDEXEmbeddings(timdex_dataset) + timdex_embeddings.write(_dataset_embedding_iter(embeddings)) logger.info("Embeddings creation complete.") -if __name__ == "__main__": # pragma: no cover - logger = logging.getLogger("embeddings.main") - main() +def _dataset_embedding_iter( + embeddings: Iterator[Embedding], +) -> Iterator[DatasetEmbedding]: + """Yield DatasetEmbedding objects from model embeddings.""" + for embedding in embeddings: + yield DatasetEmbedding( + timdex_record_id=embedding.timdex_record_id, + run_id=embedding.run_id, + run_record_offset=embedding.run_record_offset, + embedding_model=embedding.model_uri, + embedding_strategy=embedding.embedding_strategy, + embedding_vector=embedding.embedding_vector, + embedding_object=json.dumps(embedding.embedding_token_weights).encode(), + ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 9f78a93..74a5efc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,3 +1,9 @@ +from pathlib import Path +from unittest.mock import patch + +from timdex_dataset_api import TIMDEXDataset +from timdex_dataset_api.embeddings import TIMDEXEmbeddings + from embeddings.cli import main @@ -133,6 +139,46 @@ def test_model_required_decorator_works_across_commands( assert "OK" in result.output +@patch("timdex_dataset_api.TIMDEXDataset.read_dicts_iter") +def test_create_embeddings_writes_to_timdex_dataset( + mock_timdex_dataset_read_dicts_iter, register_mock_model, runner, tmp_path +): + mock_timdex_dataset_read_dicts_iter.return_value = iter( + [ + { + "timdex_record_id": "record:1", + "run_id": "run-1", + "run_record_offset": 0, + "transformed_record": '{"title":"Record 1","description":"This is a record about coffee in the mountains."}', # noqa: E501 + } + ] + ) + + # init TIMDEX Dataset and Embeddings + timdex_dataset = TIMDEXDataset(location=str(tmp_path / "dataset")) + timdex_embeddings = TIMDEXEmbeddings(timdex_dataset) + + result = runner.invoke( + main, + [ + "create-embeddings", + "--model-uri", + "test/mock-model", + "--dataset-location", + str(tmp_path / "dataset"), + "--run-id", + "run-1", + "--strategy", + "full_record", + ], + ) + + # TODO @jonavellecuerdo: Update to use TIMDEXEmbeddings # noqa: FIX002 + # read method when ready + assert result.exit_code == 0 + assert Path(timdex_embeddings.data_embeddings_root).exists() + + def test_create_embeddings_requires_strategy(register_mock_model, runner): result = runner.invoke( main, diff --git a/uv.lock b/uv.lock index f63b7fa..299a794 100644 --- a/uv.lock +++ b/uv.lock @@ -338,28 +338,28 @@ sdist = { url = "https://files.pythonhosted.org/packages/a2/55/8f8cab2afd404cf57 [[package]] name = "duckdb" -version = "1.4.2.dev27" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0e/30/b1c564550861899c8db506e9782c1b2dc4f50722e51d4473e33ea4ddc60b/duckdb-1.4.2.dev27.tar.gz", hash = "sha256:055a31d715facbc8416ef01cbaad8e7c007a48f73733e1504593ece9870749e3", size = 18471564, upload-time = "2025-10-17T08:13:25.58Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/33/4ed6b03b3d1ec295e179f64b466d521791b1d365bdf3845c27fdddadf073/duckdb-1.4.2.dev27-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0b900f801b05c42b7c0b38d0f8679c5a0cafdef9a4f5f5bd0c194a24fc94312f", size = 28949386, upload-time = "2025-10-17T08:12:22.111Z" }, - { url = "https://files.pythonhosted.org/packages/a2/16/f66b6e7c50752ba47847b3433564f346d9990cb1154e7c92cb8bfc08e3b5/duckdb-1.4.2.dev27-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f7b5c1a0809e0123f4b8cc18b749d573f6e96f33472f5486283df2afd802a3f1", size = 16092169, upload-time = "2025-10-17T08:12:24.456Z" }, - { url = "https://files.pythonhosted.org/packages/d1/82/9e3d00a37a741f7d2f1d2595c726b4c4e618114e4db7d66d75d8b655d3a0/duckdb-1.4.2.dev27-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d5b51c838a2a6d3eec537e3852d432d83d1c213fc14431760734c0b5a8c86d7b", size = 13704382, upload-time = "2025-10-17T08:12:26.676Z" }, - { url = "https://files.pythonhosted.org/packages/cd/6c/8e719ad2526fb5c66bb5d79c2a83bacc46ea32ba68370f21f6a388df1c3d/duckdb-1.4.2.dev27-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9e3190bb3e696d8380ccc68ef77ca982244f5debc2cad64aaceb83f14a57438e", size = 18440409, upload-time = "2025-10-17T08:12:29.177Z" }, - { url = "https://files.pythonhosted.org/packages/d4/84/828f9aee8797eeccd01376906b84a4aadac0017421db2777ed674891889d/duckdb-1.4.2.dev27-cp312-cp312-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:753dbc58f1be72fbd111882fcb40a3ca3a11183b1fae70ee3d9ee7314312448e", size = 20443950, upload-time = "2025-10-17T08:12:31.951Z" }, - { url = "https://files.pythonhosted.org/packages/c5/a0/895892991bf9eab6a284a2e0668db3ad2e91b5a1a2716e0549ce0ac3c8c7/duckdb-1.4.2.dev27-cp312-cp312-win_amd64.whl", hash = "sha256:352afa65795588a540d414c1fd4f3aa125c320c524f26f6d9017788a7efd6245", size = 12307457, upload-time = "2025-10-17T08:12:34.483Z" }, - { url = "https://files.pythonhosted.org/packages/a9/ff/3a05601174f87e0fd83306496d6aed845c31f4fa6b72c0d940e8ebabe012/duckdb-1.4.2.dev27-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:09f0c95022615bb6593240e9aded18c825b71a7708b4d5feac8010b6835a0218", size = 28949906, upload-time = "2025-10-17T08:12:36.953Z" }, - { url = "https://files.pythonhosted.org/packages/ef/0c/d526395b56ed9fa52a395ef55b6cd4905faa3d22324b607a50232da28cd1/duckdb-1.4.2.dev27-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e1dbe3ed4e6a11bf93424f42d73f1428c12d9fdaf724996dfee8982299a5a9e4", size = 16092427, upload-time = "2025-10-17T08:12:39.686Z" }, - { url = "https://files.pythonhosted.org/packages/05/02/1db9ec43e57ce718efeb19d5950d347b667703c17392e7b1137b0ee38249/duckdb-1.4.2.dev27-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:727094fc2ee39ac2227d7cd16c6cf5eacd747c0303d9a4b8c9b8e8a7f31bcbfd", size = 13704418, upload-time = "2025-10-17T08:12:41.826Z" }, - { url = "https://files.pythonhosted.org/packages/0c/fa/3152ae03823f9524bd61dcb5dc071353d6c14b5b38d3d4838e4eca0807a9/duckdb-1.4.2.dev27-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:359fd08fa84c7ad303334778b3d20c1ef391bf0a7523753f5b4fc582eb63e309", size = 18444180, upload-time = "2025-10-17T08:12:43.998Z" }, - { url = "https://files.pythonhosted.org/packages/bc/64/f5f0bf92717c43c6711e4ac960180ca6e4c15597c040e578c7ec577c3039/duckdb-1.4.2.dev27-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:50458f77161f28928dc805578b13e874b883e1e67e686c7173ae629f01b951e9", size = 20443164, upload-time = "2025-10-17T08:12:46.336Z" }, - { url = "https://files.pythonhosted.org/packages/b9/e3/c4dc0297499c8254b08f3b6b8c19d2855f11c58b5017297a560cb2fa0b17/duckdb-1.4.2.dev27-cp313-cp313-win_amd64.whl", hash = "sha256:2c2f363942232019fae10efd9b92c9dde590a955600898c6588a2abbbf322f5d", size = 12307907, upload-time = "2025-10-17T08:12:48.959Z" }, - { url = "https://files.pythonhosted.org/packages/00/70/807f2d8ba7fc625dc6c9bd851b88251aef3649578f7f6872b9dbb716f282/duckdb-1.4.2.dev27-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:7947440cb61279b7f6acdc624f0c72d700b3591f178dd3c064d4fe40f257fc85", size = 28959364, upload-time = "2025-10-17T08:12:51.222Z" }, - { url = "https://files.pythonhosted.org/packages/c0/96/310d552b385b65fc9bac275a8ef56817527adb0494077ea5c8ffb83dd5f5/duckdb-1.4.2.dev27-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:e130343c5d7fc87dff9cd002f7ec7791e9e76020bd23a47d0ac2ba6775b9fb91", size = 16094254, upload-time = "2025-10-17T08:12:55.275Z" }, - { url = "https://files.pythonhosted.org/packages/d5/44/a7728ba77bf517cf5ba248a0b85868556e5c5919d3b9a3d21ea3508e4a3b/duckdb-1.4.2.dev27-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:e03a3b89dcf69e9f41ce0f512a307bccb1818aa2e8ebd6e98f36cc1360c20a9f", size = 13708467, upload-time = "2025-10-17T08:12:59.346Z" }, - { url = "https://files.pythonhosted.org/packages/93/1a/3ad702047a839d5ba4f5bb27a0f0cf0d900d6143a9521d0f0c1b5c814c7d/duckdb-1.4.2.dev27-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:409a2dea728953a4d7500fcbe4f96d5e021cc996cb73c2aa45d788fbcfc2ecac", size = 18449820, upload-time = "2025-10-17T08:13:01.888Z" }, - { url = "https://files.pythonhosted.org/packages/77/ac/e6679096c944ba2751f2ca8bc4832da5c6a47d59ab72cabcec6fefb015c6/duckdb-1.4.2.dev27-cp314-cp314-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:33bde9de3b5c1a70d6ba212821afaad5d49de2644c85ae021bac9c8d36d358c2", size = 20444475, upload-time = "2025-10-17T08:13:04.464Z" }, - { url = "https://files.pythonhosted.org/packages/a8/41/d09cd75c8229a4450e6d0bc79a1d6c4f7ef9108545f39e78c03067e5b4e9/duckdb-1.4.2.dev27-cp314-cp314-win_amd64.whl", hash = "sha256:934464e710ee057c9e43ed2eee60e8fad207be9fe387cf21dd42e05fde9c61f9", size = 12807107, upload-time = "2025-10-17T08:13:06.776Z" }, +version = "1.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/81/99/ac6c105118751cc3ccae980b12e44847273f3402e647ec3197aff2251e23/duckdb-1.4.2.tar.gz", hash = "sha256:df81acee3b15ecb2c72eb8f8579fb5922f6f56c71f5c8892ea3bc6fab39aa2c4", size = 18469786, upload-time = "2025-11-12T13:18:04.203Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/29/2f68c57e7c4242fedbf4b3fdc24fce2ffcf60640c936621d8a645593a161/duckdb-1.4.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9356fe17af2711e0a5ace4b20a0373e03163545fd7516e0c3c40428f44597052", size = 29015814, upload-time = "2025-11-12T13:16:59.329Z" }, + { url = "https://files.pythonhosted.org/packages/34/b7/030cc278a4ae788800a833b2901b9a7da7a6993121053c4155c359328531/duckdb-1.4.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:946a8374c0252db3fa41165ab9952b48adc8de06561a6b5fd62025ac700e492f", size = 15403892, upload-time = "2025-11-12T13:17:02.141Z" }, + { url = "https://files.pythonhosted.org/packages/f7/a2/67f4798a7a29bd0813f8a1e94a83e857e57f5d1ba14cf3edc5551aad0095/duckdb-1.4.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:389fa9abe4ca37d091332a2f8c3ebd713f18e87dc4cb5e8efd3e5aa8ddf8885f", size = 13733622, upload-time = "2025-11-12T13:17:04.502Z" }, + { url = "https://files.pythonhosted.org/packages/6e/ac/d0d0e3feae9663334b2336f15785d280b54a56c3ffa10334e20a51a87ecd/duckdb-1.4.2-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be8c0c40f2264b91500b89c688f743e1c7764966e988f680b1f19416b00052e", size = 18470220, upload-time = "2025-11-12T13:17:07.049Z" }, + { url = "https://files.pythonhosted.org/packages/a5/52/7570a50430cbffc8bd702443ac28a446b0fa4f77747a3821d4b37a852b15/duckdb-1.4.2-cp312-cp312-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c6a21732dd52a76f1e61484c06d65800b18f57fe29e8102a7466c201a2221604", size = 20481138, upload-time = "2025-11-12T13:17:09.459Z" }, + { url = "https://files.pythonhosted.org/packages/95/5e/be05f46a290ea27630c112ff9e01fd01f585e599967fc52fe2edc7bc2039/duckdb-1.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:769440f4507c20542ae2e5b87f6c6c6d3f148c0aa8f912528f6c97e9aedf6a21", size = 12330737, upload-time = "2025-11-12T13:17:12.02Z" }, + { url = "https://files.pythonhosted.org/packages/70/c4/5054dbe79cf570b0c97db0c2eba7eb541cc561037360479059a3b57e4a32/duckdb-1.4.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:de646227fc2c53101ac84e86e444e7561aa077387aca8b37052f3803ee690a17", size = 29015784, upload-time = "2025-11-12T13:17:14.409Z" }, + { url = "https://files.pythonhosted.org/packages/2c/b8/97f4f07d9459f5d262751cccfb2f4256debb8fe5ca92370cebe21aab1ee2/duckdb-1.4.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f1fac31babda2045d4cdefe6d0fd2ebdd8d4c2a333fbcc11607cfeaec202d18d", size = 15403788, upload-time = "2025-11-12T13:17:16.864Z" }, + { url = "https://files.pythonhosted.org/packages/a4/ea/112f33ace03682bafd4aaf0a3336da689b9834663e7032b3d678fd2902c9/duckdb-1.4.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:43ac632f40ab1aede9b4ce3c09ea043f26f3db97b83c07c632c84ebd7f7c0f4a", size = 13733603, upload-time = "2025-11-12T13:17:20.884Z" }, + { url = "https://files.pythonhosted.org/packages/34/83/8d6f845a9a946e8b47b6253b9edb084c45670763e815feed6cfefc957e89/duckdb-1.4.2-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:77db030b48321bf785767b7b1800bf657dd2584f6df0a77e05201ecd22017da2", size = 18473725, upload-time = "2025-11-12T13:17:23.074Z" }, + { url = "https://files.pythonhosted.org/packages/82/29/153d1b4fc14c68e6766d7712d35a7ab6272a801c52160126ac7df681f758/duckdb-1.4.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a456adbc3459c9dcd99052fad20bd5f8ef642be5b04d09590376b2eb3eb84f5c", size = 20481971, upload-time = "2025-11-12T13:17:26.703Z" }, + { url = "https://files.pythonhosted.org/packages/58/b7/8d3a58b5ebfb9e79ed4030a0f2fbd7e404c52602e977b1e7ab51651816c7/duckdb-1.4.2-cp313-cp313-win_amd64.whl", hash = "sha256:2f7c61617d2b1da3da5d7e215be616ad45aa3221c4b9e2c4d1c28ed09bc3c1c4", size = 12330535, upload-time = "2025-11-12T13:17:29.175Z" }, + { url = "https://files.pythonhosted.org/packages/25/46/0f316e4d0d6bada350b9da06691a2537c329c8948c78e8b5e0c4874bc5e2/duckdb-1.4.2-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:422be8c6bdc98366c97f464b204b81b892bf962abeae6b0184104b8233da4f19", size = 29028616, upload-time = "2025-11-12T13:17:31.599Z" }, + { url = "https://files.pythonhosted.org/packages/82/ab/e04a8f97865251b544aee9501088d4f0cb8e8b37339bd465c0d33857d411/duckdb-1.4.2-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:459b1855bd06a226a2838da4f14c8863fd87a62e63d414a7f7f682a7c616511a", size = 15410382, upload-time = "2025-11-12T13:17:34.14Z" }, + { url = "https://files.pythonhosted.org/packages/47/ec/b8229517c2f9fe88a38bb1a172a2da4d0ff34996d319d74554fda80b6358/duckdb-1.4.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:20c45b4ead1ea4d23a1be1cd4f1dfc635e58b55f0dd11e38781369be6c549903", size = 13737588, upload-time = "2025-11-12T13:17:36.515Z" }, + { url = "https://files.pythonhosted.org/packages/f2/9a/63d26da9011890a5b893e0c21845c0c0b43c634bf263af3bbca64be0db76/duckdb-1.4.2-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2e552451054534970dc999e69ca5ae5c606458548c43fb66d772117760485096", size = 18477886, upload-time = "2025-11-12T13:17:39.136Z" }, + { url = "https://files.pythonhosted.org/packages/23/35/b1fae4c5245697837f6f63e407fa81e7ccc7948f6ef2b124cd38736f4d1d/duckdb-1.4.2-cp314-cp314-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:128c97dab574a438d7c8d020670b21c68792267d88e65a7773667b556541fa9b", size = 20483292, upload-time = "2025-11-12T13:17:41.501Z" }, + { url = "https://files.pythonhosted.org/packages/25/5e/6f5ebaabc12c6db62f471f86b5c9c8debd57f11aa1b2acbbcc4c68683238/duckdb-1.4.2-cp314-cp314-win_amd64.whl", hash = "sha256:dfcc56a83420c0dec0b83e97a6b33addac1b7554b8828894f9d203955591218c", size = 12830520, upload-time = "2025-11-12T13:17:43.93Z" }, ] [[package]] @@ -1871,8 +1871,8 @@ wheels = [ [[package]] name = "timdex-dataset-api" -version = "3.5.0" -source = { git = "https://github.com/MITLibraries/timdex-dataset-api#ef34e4ad2702a0aa6fbe16ccaf3966928040d0ce" } +version = "3.6.1" +source = { git = "https://github.com/MITLibraries/timdex-dataset-api#eac061079df38272cee59a1a869a60f3094eb7b1" } dependencies = [ { name = "attrs" }, { name = "boto3" },