Skip to content

Commit 8cac1a0

Browse files
committed
Normalize string formatting and CLI logging
1 parent 262217a commit 8cac1a0

File tree

4 files changed

+20
-23
lines changed

4 files changed

+20
-23
lines changed

embeddings/cli.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -62,26 +62,17 @@ def ping() -> None:
6262
)
6363
def download_model(model_uri: str, output: Path) -> None:
6464
"""Download a model from HuggingFace and save as zip file."""
65-
try:
66-
model_class = get_model_class(model_uri)
67-
except ValueError as e:
68-
logger.exception("Unknown model URI: %s", model_uri)
69-
raise click.ClickException(str(e)) from e
70-
71-
logger.info("Downloading model: %s", model_uri)
65+
# load embedding model class
66+
model_class = get_model_class(model_uri)
7267
model = model_class(model_uri)
7368

74-
try:
75-
result_path = model.download(output)
76-
logger.info("Model downloaded successfully to: %s", result_path)
77-
click.echo(f"Model saved to: {result_path}")
78-
except NotImplementedError as e:
79-
logger.exception("Download not yet implemented for model: %s", model_uri)
80-
raise click.ClickException(str(e)) from e
81-
except Exception as e:
82-
logger.exception("Failed to download model: %s", model_uri)
83-
msg = f"Download failed: {e}"
84-
raise click.ClickException(msg) from e
69+
# download model assets
70+
logger.info(f"Downloading model: {model_uri}")
71+
result_path = model.download(output)
72+
73+
message = f"Model downloaded and saved to: {result_path}"
74+
logger.info(message)
75+
click.echo(result_path)
8576

8677

8778
@main.command()
@@ -90,9 +81,8 @@ def download_model(model_uri: str, output: Path) -> None:
9081
required=True,
9182
help="HuggingFace model URI (e.g., 'org/model-name')",
9283
)
93-
def create_embeddings(model_uri: str) -> None:
94-
"""Create embeddings."""
95-
logger.info("create-embeddings command called with model: %s", model_uri)
84+
def create_embeddings(_model_uri: str) -> None:
85+
# TODO: docstring # noqa: FIX002
9686
raise NotImplementedError
9787

9888

embeddings/models/registry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
"""Registry mapping model URIs to model classes."""
22

3+
import logging
4+
35
from embeddings.models.base import BaseEmbeddingModel
46
from embeddings.models.os_neural_sparse_doc_v3_gte import OSNeuralSparseDocV3GTE
57

8+
logger = logging.getLogger(__name__)
9+
610
MODEL_REGISTRY: dict[str, type[BaseEmbeddingModel]] = {
711
"opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte": (
812
OSNeuralSparseDocV3GTE
@@ -22,5 +26,6 @@ def get_model_class(model_uri: str) -> type[BaseEmbeddingModel]:
2226
if model_uri not in MODEL_REGISTRY:
2327
available = ", ".join(sorted(MODEL_REGISTRY.keys()))
2428
msg = f"Unknown model URI: {model_uri}. Available models: {available}"
29+
logger.error(msg)
2530
raise ValueError(msg)
2631
return MODEL_REGISTRY[model_uri]

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ ignore = [
6363
"PLR0913",
6464
"PLR0915",
6565
"S321",
66+
"TD002",
67+
"TD003",
6668
]
6769

6870
# allow autofix behavior for specified rules

tests/test_cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ def test_cli_debug_logging(caplog, runner):
1616
assert "pong" in result.output
1717

1818

19-
def test_download_model_unknown_uri(runner):
19+
def test_download_model_unknown_uri(caplog, runner):
2020
result = runner.invoke(
2121
main, ["download-model", "--model-uri", "unknown/model", "--output", "out.zip"]
2222
)
2323
assert result.exit_code != 0
24-
assert "Unknown model URI" in result.output
24+
assert "Unknown model URI" in caplog.text
2525

2626

2727
def test_download_model_not_implemented(caplog, runner):

0 commit comments

Comments
 (0)