
Commit 6e1a1a5

Merge pull request #12 from MITLibraries/USE-112-113-model-downloads
USE 112 & 113 - Base model class and stubs for downloads
2 parents 6ef490f + 8cac1a0

10 files changed: +711 −52 lines


embeddings/cli.py

Lines changed: 73 additions & 10 deletions
@@ -1,28 +1,91 @@
 import logging
+import time
 from datetime import timedelta
-from time import perf_counter
+from pathlib import Path
 
 import click
 
 from embeddings.config import configure_logger, configure_sentry
+from embeddings.models.registry import get_model_class
 
 logger = logging.getLogger(__name__)
 
 
-@click.command()
+@click.group("embeddings")
 @click.option(
-    "-v", "--verbose", is_flag=True, help="Pass to log at debug level instead of info"
+    "-v",
+    "--verbose",
+    is_flag=True,
+    help="Pass to log at debug level instead of info",
 )
-def main(*, verbose: bool) -> None:
-    start_time = perf_counter()
+@click.pass_context
+def main(
+    ctx: click.Context,
+    *,
+    verbose: bool,
+) -> None:
+    ctx.ensure_object(dict)
+    ctx.obj["start_time"] = time.perf_counter()
+
     root_logger = logging.getLogger()
     logger.info(configure_logger(root_logger, verbose=verbose))
     logger.info(configure_sentry())
     logger.info("Running process")
 
-    # Do things here!
+    def _log_command_elapsed_time() -> None:
+        elapsed_time = time.perf_counter() - ctx.obj["start_time"]
+        logger.info(
+            "Total time to complete process: %s", str(timedelta(seconds=elapsed_time))
+        )
+
+    ctx.call_on_close(_log_command_elapsed_time)
+
+
+@main.command()
+def ping() -> None:
+    """Emit 'pong' to debug logs and stdout."""
+    logger.debug("pong")
+    click.echo("pong")
+
+
+@main.command()
+@click.option(
+    "--model-uri",
+    required=True,
+    help="HuggingFace model URI (e.g., 'org/model-name')",
+)
+@click.option(
+    "--output",
+    required=True,
+    type=click.Path(path_type=Path),
+    help="Output path for zipped model (e.g., '/path/to/model.zip')",
+)
+def download_model(model_uri: str, output: Path) -> None:
+    """Download a model from HuggingFace and save as zip file."""
+    # load embedding model class
+    model_class = get_model_class(model_uri)
+    model = model_class(model_uri)
+
+    # download model assets
+    logger.info(f"Downloading model: {model_uri}")
+    result_path = model.download(output)
+
+    message = f"Model downloaded and saved to: {result_path}"
+    logger.info(message)
+    click.echo(result_path)
+
+
+@main.command()
+@click.option(
+    "--model-uri",
+    required=True,
+    help="HuggingFace model URI (e.g., 'org/model-name')",
+)
+def create_embeddings(_model_uri: str) -> None:
+    # TODO: docstring  # noqa: FIX002
+    raise NotImplementedError
+
 
-    elapsed_time = perf_counter() - start_time
-    logger.info(
-        "Total time to complete process: %s", str(timedelta(seconds=elapsed_time))
-    )
+if __name__ == "__main__":  # pragma: no cover
+    logger = logging.getLogger("embeddings.main")
+    main()
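
Since the entry point is now a click group, commands are invoked as subcommands. The sketch below is not part of this commit; it exercises the new commands the same way the test suite does, via click's CliRunner. The output path is a made-up example, and download-model currently exits non-zero because the model's download() is still a stub.

from click.testing import CliRunner

from embeddings.cli import main

runner = CliRunner()

# smoke-test the new group: the "ping" subcommand echoes "pong"
result = runner.invoke(main, ["ping"])
print(result.output)

# resolve a registered model URI and attempt a download
# (fails with NotImplementedError until the model's download() is implemented)
result = runner.invoke(
    main,
    [
        "download-model",
        "--model-uri",
        "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte",
        "--output",
        "model.zip",  # hypothetical output path
    ],
)
print(result.exit_code)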

embeddings/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+

embeddings/models/base.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+"""Base class for embedding models."""
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+
+class BaseEmbeddingModel(ABC):
+    """Abstract base class for embedding models.
+
+    Args:
+        model_uri: HuggingFace model identifier (e.g., 'org/model-name').
+    """
+
+    def __init__(self, model_uri: str) -> None:
+        self.model_uri = model_uri
+
+    @abstractmethod
+    def download(self, output_path: Path) -> Path:
+        """Download and prepare model, saving to output_path.
+
+        Args:
+            output_path: Path where the model zip should be saved.
+        """

embeddings/models/os_neural_sparse_doc_v3_gte.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+"""OpenSearch Neural Sparse Doc v3 GTE model."""
+
+import logging
+from pathlib import Path
+
+from embeddings.models.base import BaseEmbeddingModel
+
+logger = logging.getLogger(__name__)
+
+
+class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
+    """OpenSearch Neural Sparse Encoding Doc v3 GTE model.
+
+    HuggingFace URI: opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
+    """
+
+    def download(self, output_path: Path) -> Path:
+        """Download and prepare model, saving to output_path.
+
+        Args:
+            output_path: Path where the model zip should be saved.
+        """
+        logger.info(f"Downloading model: {self.model_uri}, saving to: {output_path}.")
+        raise NotImplementedError
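
download() is stubbed here: it logs and raises NotImplementedError. One plausible shape for a later implementation, sketched under the assumption that huggingface_hub.snapshot_download (from the newly added huggingface-hub dependency) and shutil.make_archive would be used; the helper and its name are hypothetical, not code from this commit.

import shutil
import tempfile
from pathlib import Path

from huggingface_hub import snapshot_download


def download_and_zip(model_uri: str, output_path: Path) -> Path:
    """Hypothetical helper: fetch model assets and package them as a zip."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        # pull the model snapshot from the HuggingFace Hub into a temp directory
        local_dir = snapshot_download(repo_id=model_uri, local_dir=tmp_dir)
        # shutil.make_archive appends ".zip", so strip the suffix from the target path
        archive = shutil.make_archive(
            str(output_path.with_suffix("")), "zip", root_dir=local_dir
        )
    return Path(archive)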

embeddings/models/registry.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+"""Registry mapping model URIs to model classes."""
+
+import logging
+
+from embeddings.models.base import BaseEmbeddingModel
+from embeddings.models.os_neural_sparse_doc_v3_gte import OSNeuralSparseDocV3GTE
+
+logger = logging.getLogger(__name__)
+
+MODEL_REGISTRY: dict[str, type[BaseEmbeddingModel]] = {
+    "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte": (
+        OSNeuralSparseDocV3GTE
+    ),
+}
+
+
+def get_model_class(model_uri: str) -> type[BaseEmbeddingModel]:
+    """Get model class for given URI.
+
+    Args:
+        model_uri: HuggingFace model identifier.
+
+    Returns:
+        Model class for the given URI.
+    """
+    if model_uri not in MODEL_REGISTRY:
+        available = ", ".join(sorted(MODEL_REGISTRY.keys()))
+        msg = f"Unknown model URI: {model_uri}. Available models: {available}"
+        logger.error(msg)
+        raise ValueError(msg)
+    return MODEL_REGISTRY[model_uri]
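
For context, a short sketch (not part of this commit) of how the registry is meant to be consumed, and how a future model class might be registered; ExampleModel and its URI are made up purely for illustration.

from pathlib import Path

from embeddings.models.base import BaseEmbeddingModel
from embeddings.models.registry import MODEL_REGISTRY, get_model_class


class ExampleModel(BaseEmbeddingModel):  # hypothetical model, illustration only
    def download(self, output_path: Path) -> Path:
        output_path.touch()  # a real model would fetch and zip assets here
        return output_path


# registering a new model is a single dict entry keyed by its HuggingFace URI
MODEL_REGISTRY["example-org/example-model"] = ExampleModel

# the CLI resolves URIs through get_model_class(), then instantiates and downloads
model_class = get_model_class("example-org/example-model")
model = model_class("example-org/example-model")
model.download(Path("example-model.zip"))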

pyproject.toml

Lines changed: 10 additions & 2 deletions
@@ -9,7 +9,9 @@ requires-python = ">=3.12"
 
 dependencies = [
     "click>=8.2.1",
+    "huggingface-hub>=0.26.0",
     "sentry-sdk>=2.34.1",
+    "timdex-dataset-api",
 ]
 
 [dependency-groups]
@@ -55,11 +57,14 @@ ignore = [
     "D101",
     "D102",
     "D103",
-    "D104",
+    "D104",
+    "G004",
     "PLR0912",
     "PLR0913",
     "PLR0915",
     "S321",
+    "TD002",
+    "TD003",
 ]
 
 # allow autofix behavior for specified rules
@@ -84,9 +89,12 @@ max-doc-length = 90
 [tool.ruff.lint.pydocstyle]
 convention = "google"
 
+[tool.uv.sources]
+timdex-dataset-api = { git = "https://github.com/MITLibraries/timdex-dataset-api" }
+
 [project.scripts]
 embeddings = "embeddings.cli:main"
 
 [build-system]
 requires = ["setuptools>=61"]
-build-backend = "setuptools.build_meta"
+build-backend = "setuptools.build_meta"

tests/conftest.py

Lines changed: 24 additions & 0 deletions
@@ -1,6 +1,11 @@
+import zipfile
+from pathlib import Path
+
 import pytest
 from click.testing import CliRunner
 
+from embeddings.models.base import BaseEmbeddingModel
+
 
 @pytest.fixture(autouse=True)
 def _test_env(monkeypatch):
@@ -11,3 +16,22 @@ def _test_env(monkeypatch):
 @pytest.fixture
 def runner():
     return CliRunner()
+
+
+class MockEmbeddingModel(BaseEmbeddingModel):
+    """Simple test model that doesn't hit external APIs."""
+
+    def download(self, output_path: Path) -> Path:
+        """Create a fake model zip file for testing."""
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        with zipfile.ZipFile(output_path, "w") as zf:
+            zf.writestr("config.json", '{"model": "mock", "vocab_size": 30000}')
+            zf.writestr("pytorch_model.bin", b"fake model weights")
+            zf.writestr("tokenizer.json", '{"version": "1.0"}')
+        return output_path
+
+
+@pytest.fixture
+def mock_model():
+    """Fixture providing a MockEmbeddingModel instance."""
+    return MockEmbeddingModel("test/mock-model")

tests/test_cli.py

Lines changed: 34 additions & 8 deletions
@@ -1,17 +1,43 @@
 from embeddings.cli import main
 
 
-def test_cli_no_options(caplog, runner):
-    result = runner.invoke(main)
+def test_cli_default_logging(caplog, runner):
+    result = runner.invoke(main, ["ping"])
     assert result.exit_code == 0
     assert "Logger 'root' configured with level=INFO" in caplog.text
-    assert "Running process" in caplog.text
-    assert "Total time to complete process" in caplog.text
 
 
-def test_cli_all_options(caplog, runner):
-    result = runner.invoke(main, ["--verbose"])
+def test_cli_debug_logging(caplog, runner):
+    with caplog.at_level("DEBUG"):
+        result = runner.invoke(main, ["--verbose", "ping"])
     assert result.exit_code == 0
     assert "Logger 'root' configured with level=DEBUG" in caplog.text
-    assert "Running process" in caplog.text
-    assert "Total time to complete process" in caplog.text
+    assert "pong" in caplog.text
+    assert "pong" in result.output
+
+
+def test_download_model_unknown_uri(caplog, runner):
+    result = runner.invoke(
+        main, ["download-model", "--model-uri", "unknown/model", "--output", "out.zip"]
+    )
+    assert result.exit_code != 0
+    assert "Unknown model URI" in caplog.text
+
+
+def test_download_model_not_implemented(caplog, runner):
+    caplog.set_level("INFO")
+    result = runner.invoke(
+        main,
+        [
+            "download-model",
+            "--model-uri",
+            "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte",
+            "--output",
+            "out.zip",
+        ],
+    )
+    assert (
+        "Downloading model: opensearch-project/"
+        "opensearch-neural-sparse-encoding-doc-v3-gte, saving to: out.zip."
+    ) in caplog.text
+    assert result.exit_code != 0

tests/test_models.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+import zipfile
+
+import pytest
+
+from embeddings.models.registry import MODEL_REGISTRY, get_model_class
+
+
+def test_mock_model_instantiation(mock_model):
+    assert mock_model.model_uri == "test/mock-model"
+
+
+def test_mock_model_download_creates_zip(mock_model, tmp_path):
+    output_path = tmp_path / "test_model.zip"
+    result = mock_model.download(output_path)
+
+    assert result == output_path
+    assert output_path.exists()
+    assert zipfile.is_zipfile(output_path)
+
+
+def test_mock_model_download_contains_expected_files(mock_model, tmp_path):
+    output_path = tmp_path / "test_model.zip"
+    mock_model.download(output_path)
+
+    with zipfile.ZipFile(output_path, "r") as zf:
+        file_list = zf.namelist()
+    assert "config.json" in file_list
+    assert "pytorch_model.bin" in file_list
+    assert "tokenizer.json" in file_list
+
+
+def test_registry_contains_opensearch_model():
+    assert (
+        "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
+        in MODEL_REGISTRY
+    )
+
+
+def test_get_model_class_returns_correct_class():
+    model_class = get_model_class(
+        "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
+    )
+    assert model_class.__name__ == "OSNeuralSparseDocV3GTE"
+
+
+def test_get_model_class_raises_for_unknown_uri():
+    with pytest.raises(ValueError, match="Unknown model URI"):
+        get_model_class("unknown/model-uri")
