Skip to content

Commit 9bc73e5

Browse files
authored
Merge 847f671 into 6e1a1a5
2 parents 6e1a1a5 + 847f671 commit 9bc73e5

File tree

10 files changed

+272
-43
lines changed

10 files changed

+272
-43
lines changed

README.md

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,43 @@ WORKSPACE=### Set to `dev` for local development, this will be set to `stage` an
2222

2323
### Optional
2424

25-
_None yet at this time._
25+
```shell
26+
TE_MODEL_URI=# HuggingFace model URI
27+
TE_MODEL_DOWNLOAD_PATH=# Download location for model
28+
HF_HUB_DISABLE_PROGRESS_BARS=#boolean to use progress bars for HuggingFace model downloads; defaults to 'true' in deployed contexts
29+
```
30+
31+
## CLI Commands
32+
33+
For local development, all CLI commands should be invoked with the following format to pickup environment variables from `.env`:
34+
35+
```shell
36+
uv run --env-file .env embeddings <COMMAND> <ARGS>
37+
```
2638

39+
### `ping`
40+
```text
41+
Usage: embeddings ping [OPTIONS]
2742
43+
Emit 'pong' to debug logs and stdout.
44+
```
45+
46+
### `download-model`
47+
```text
48+
Usage: embeddings download-model [OPTIONS]
49+
50+
Download a model from HuggingFace and save as zip file.
51+
52+
Options:
53+
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name') [required]
54+
--output PATH Output path for zipped model (e.g., '/path/to/model.zip')
55+
[required]
56+
--help Show this message and exit.
57+
```
58+
59+
### `create-embeddings`
60+
```text
61+
TODO...
62+
```
2863

2964

embeddings/cli.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import functools
12
import logging
23
import time
4+
from collections.abc import Callable
35
from datetime import timedelta
46
from pathlib import Path
57

@@ -11,6 +13,22 @@
1113
logger = logging.getLogger(__name__)
1214

1315

16+
def model_required(f: Callable) -> Callable:
17+
"""Decorator for commands that require a specific model."""
18+
19+
@click.option(
20+
"--model-uri",
21+
envvar="TE_MODEL_URI",
22+
required=True,
23+
help="HuggingFace model URI (e.g., 'org/model-name')",
24+
)
25+
@functools.wraps(f)
26+
def wrapper(*args: list, **kwargs: dict) -> Callable:
27+
return f(*args, **kwargs)
28+
29+
return wrapper
30+
31+
1432
@click.group("embeddings")
1533
@click.option(
1634
"-v",
@@ -49,22 +67,19 @@ def ping() -> None:
4967

5068

5169
@main.command()
52-
@click.option(
53-
"--model-uri",
54-
required=True,
55-
help="HuggingFace model URI (e.g., 'org/model-name')",
56-
)
70+
@model_required
5771
@click.option(
5872
"--output",
5973
required=True,
74+
envvar="TE_MODEL_DOWNLOAD_PATH",
6075
type=click.Path(path_type=Path),
6176
help="Output path for zipped model (e.g., '/path/to/model.zip')",
6277
)
6378
def download_model(model_uri: str, output: Path) -> None:
6479
"""Download a model from HuggingFace and save as zip file."""
6580
# load embedding model class
6681
model_class = get_model_class(model_uri)
67-
model = model_class(model_uri)
82+
model = model_class()
6883

6984
# download model assets
7085
logger.info(f"Downloading model: {model_uri}")
@@ -76,11 +91,7 @@ def download_model(model_uri: str, output: Path) -> None:
7691

7792

7893
@main.command()
79-
@click.option(
80-
"--model-uri",
81-
required=True,
82-
help="HuggingFace model URI (e.g., 'org/model-name')",
83-
)
94+
@model_required
8495
def create_embeddings(_model_uri: str) -> None:
8596
# TODO: docstring # noqa: FIX002
8697
raise NotImplementedError

embeddings/models/base.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,25 @@
77
class BaseEmbeddingModel(ABC):
88
"""Abstract base class for embedding models.
99
10-
Args:
11-
model_uri: HuggingFace model identifier (e.g., 'org/model-name').
10+
All child classes must set class level attribute MODEL_URI.
1211
"""
1312

14-
def __init__(self, model_uri: str) -> None:
15-
self.model_uri = model_uri
13+
MODEL_URI: str # Type hint to document the requirement
14+
15+
def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105
16+
super().__init_subclass__(**kwargs)
17+
18+
# require class level MODEL_URI to be set
19+
if not hasattr(cls, "MODEL_URI"):
20+
msg = f"{cls.__name__} must define 'MODEL_URI' class attribute"
21+
raise TypeError(msg)
22+
if not isinstance(cls.MODEL_URI, str):
23+
msg = f"{cls.__name__} must override 'MODEL_URI' with a valid string"
24+
raise TypeError(msg)
25+
26+
@property
27+
def model_uri(self) -> str:
28+
return self.MODEL_URI
1629

1730
@abstractmethod
1831
def download(self, output_path: Path) -> Path:

embeddings/models/os_neural_sparse_doc_v3_gte.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
1414
HuggingFace URI: opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
1515
"""
1616

17+
MODEL_URI = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
18+
1719
def download(self, output_path: Path) -> Path:
1820
"""Download and prepare model, saving to output_path.
1921

embeddings/models/registry.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,23 @@
77

88
logger = logging.getLogger(__name__)
99

10+
MODEL_CLASSES = [OSNeuralSparseDocV3GTE]
11+
1012
MODEL_REGISTRY: dict[str, type[BaseEmbeddingModel]] = {
11-
"opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte": (
12-
OSNeuralSparseDocV3GTE
13-
),
13+
model.MODEL_URI: model for model in MODEL_CLASSES
1414
}
1515

1616

1717
def get_model_class(model_uri: str) -> type[BaseEmbeddingModel]:
18-
"""Get model class for given URI.
18+
"""Get an embedding model class via the HuggingFace model URI.
1919
2020
Args:
2121
model_uri: HuggingFace model identifier.
22-
23-
Returns:
24-
Model class for the given URI.
2522
"""
2623
if model_uri not in MODEL_REGISTRY:
2724
available = ", ".join(sorted(MODEL_REGISTRY.keys()))
2825
msg = f"Unknown model URI: {model_uri}. Available models: {available}"
2926
logger.error(msg)
3027
raise ValueError(msg)
28+
3129
return MODEL_REGISTRY[model_uri]

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies = [
1818
dev = [
1919
"black>=25.1.0",
2020
"coveralls>=4.0.1",
21+
"ipython>=9.6.0",
2122
"mypy>=1.17.1",
2223
"pip-audit>=2.9.0",
2324
"pre-commit>=4.3.0",
@@ -58,13 +59,15 @@ ignore = [
5859
"D102",
5960
"D103",
6061
"D104",
62+
"EM102",
6163
"G004",
6264
"PLR0912",
6365
"PLR0913",
6466
"PLR0915",
6567
"S321",
6668
"TD002",
6769
"TD003",
70+
"TRY003",
6871
]
6972

7073
# allow autofix behavior for specified rules

tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def runner():
2121
class MockEmbeddingModel(BaseEmbeddingModel):
2222
"""Simple test model that doesn't hit external APIs."""
2323

24+
MODEL_URI = "test/mock-model"
25+
2426
def download(self, output_path: Path) -> Path:
2527
"""Create a fake model zip file for testing."""
2628
output_path.parent.mkdir(parents=True, exist_ok=True)
@@ -34,4 +36,4 @@ def download(self, output_path: Path) -> Path:
3436
@pytest.fixture
3537
def mock_model():
3638
"""Fixture providing a MockEmbeddingModel instance."""
37-
return MockEmbeddingModel("test/mock-model")
39+
return MockEmbeddingModel()

tests/test_cli.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,3 @@ def test_download_model_unknown_uri(caplog, runner):
2222
)
2323
assert result.exit_code != 0
2424
assert "Unknown model URI" in caplog.text
25-
26-
27-
def test_download_model_not_implemented(caplog, runner):
28-
caplog.set_level("INFO")
29-
result = runner.invoke(
30-
main,
31-
[
32-
"download-model",
33-
"--model-uri",
34-
"opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte",
35-
"--output",
36-
"out.zip",
37-
],
38-
)
39-
assert (
40-
"Downloading model: opensearch-project/"
41-
"opensearch-neural-sparse-encoding-doc-v3-gte, saving to: out.zip."
42-
) in caplog.text
43-
assert result.exit_code != 0

tests/test_models.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44

5+
from embeddings.models.base import BaseEmbeddingModel
56
from embeddings.models.registry import MODEL_REGISTRY, get_model_class
67

78

@@ -46,3 +47,17 @@ def test_get_model_class_returns_correct_class():
4647
def test_get_model_class_raises_for_unknown_uri():
4748
with pytest.raises(ValueError, match="Unknown model URI"):
4849
get_model_class("unknown/model-uri")
50+
51+
52+
def test_subclass_without_model_uri_raises_type_error():
53+
with pytest.raises(TypeError, match="must define 'MODEL_URI' class attribute"):
54+
55+
class InvalidModel(BaseEmbeddingModel):
56+
pass
57+
58+
59+
def test_subclass_with_non_string_model_uri_raises_type_error():
60+
with pytest.raises(TypeError, match="must override 'MODEL_URI' with a valid string"):
61+
62+
class InvalidModel(BaseEmbeddingModel):
63+
MODEL_URI = 123

0 commit comments

Comments
 (0)