From 02264cb6a16288704dac2a0193fadab27efc33d1 Mon Sep 17 00:00:00 2001 From: Graham Hukill Date: Tue, 28 Oct 2025 15:07:18 -0400 Subject: [PATCH 1/2] Remove copy/paste Dockerfile comments --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index cd3fc39..f5de3ff 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,7 @@ COPY pyproject.toml uv.lock* ./ # Copy CLI application COPY embeddings ./embeddings -# Install package into system python, includes "marimo-launcher" script +# Install package into system python RUN uv pip install --system . # Download the model and include in the Docker image From 39ef93e29b26e2bb2457fa919e02ff401446ab45 Mon Sep 17 00:00:00 2001 From: Graham Hukill Date: Tue, 28 Oct 2025 15:07:34 -0400 Subject: [PATCH 2/2] Lean into model_required CLI decorator Why these changes are being introduced: Many of the CLI commands will require an embedding class and model to work. A decorator was created originally that injected a --model-uri CLI argument, but it also provides a place to load the class itself and become more of a middleware. How this addresses that need: Updates the model_required decorator to also load the embedding model class. This DRY's up the CLI commands that use it and centralizes that logic and conventions for the CLI argument, env vars, and whatnot. Lastly, it is now required to include a 'model_path' when instantiating a model class instance, and this location is used for both download and load. Side effects of this change: * None Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/USE-112 --- Dockerfile | 8 +- README.md | 38 ++++-- embeddings/cli.py | 119 ++++++++++++------ embeddings/models/base.py | 24 ++-- .../models/os_neural_sparse_doc_v3_gte.py | 56 ++++----- tests/conftest.py | 18 +-- tests/test_cli.py | 115 ++++++++++++++++- tests/test_models.py | 11 +- tests/test_os_neural_sparse_doc_v3_gte.py | 50 ++++---- 9 files changed, 310 insertions(+), 129 deletions(-) diff --git a/Dockerfile b/Dockerfile index f5de3ff..5c34b88 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,12 +19,12 @@ COPY embeddings ./embeddings RUN uv pip install --system . # Download the model and include in the Docker image -# NOTE: The env vars "TE_MODEL_URI" and "TE_MODEL_DOWNLOAD_PATH" are set here to support -# the downloading of the model into this image build, but persist in the container and -# effectively also set this as the default model. +# NOTE: The env vars "TE_MODEL_URI" and "TE_MODEL_PATH" are set here to support +# the downloading of the model during image build, but also persist in the container and +# effectively set the default model. ENV HF_HUB_DISABLE_PROGRESS_BARS=true ENV TE_MODEL_URI=opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte -ENV TE_MODEL_DOWNLOAD_PATH=/model +ENV TE_MODEL_PATH=/model RUN python -m embeddings.cli --verbose download-model ENTRYPOINT ["python", "-m", "embeddings.cli"] diff --git a/README.md b/README.md index 054aa7a..0521d50 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ WORKSPACE=### Set to `dev` for local development, this will be set to `stage` an ```shell TE_MODEL_URI=# HuggingFace model URI -TE_MODEL_DOWNLOAD_PATH=# Download location for model +TE_MODEL_PATH=# Path where the model will be downloaded to and loaded from HF_HUB_DISABLE_PROGRESS_BARS=#boolean to use progress bars for HuggingFace model downloads; defaults to 'true' in deployed contexts ``` @@ -34,7 +34,7 @@ This CLI application is designed to create embeddings for input texts. To do th To this end, there is a base embedding class `BaseEmbeddingModel` that is designed to be extended and customized for a particular embedding model. -Once an embedding class has been created, the preferred approach is to set env vars `TE_MODEL_URI` and `TE_MODEL_DOWNLOAD_PATH` directly in the `Dockerfile` to a) download a local snapshot of the model during image build, and b) set this model as the default for the CLI. +Once an embedding class has been created, the preferred approach is to set env vars `TE_MODEL_URI` and `TE_MODEL_PATH` directly in the `Dockerfile` to a) download a local snapshot of the model during image build, and b) set this model as the default for the CLI. This allows invoking the CLI without specifying a model URI or local location, allowing this model to serve as the default, e.g.: @@ -61,18 +61,38 @@ Usage: embeddings ping [OPTIONS] ```text Usage: embeddings download-model [OPTIONS] - Download a model from HuggingFace and save as zip file. + Download a model from HuggingFace and save locally. Options: - --model-uri TEXT HuggingFace model URI (e.g., 'org/model-name') [required] - --output PATH Output path for zipped model (e.g., '/path/to/model.zip') - [required] - --help Show this message and exit. + --model-uri TEXT HuggingFace model URI (e.g., 'org/model-name') + [required] + --model-path PATH Path where the model will be downloaded to and loaded + from, e.g. '/path/to/model'. [required] + --help Show this message and exit. ``` -### `create-embeddings` +### `test-model-load` ```text -TODO... +Usage: embeddings test-model-load [OPTIONS] + + Test loading of embedding class and local model based on env vars. + + In a deployed context, the following env vars are expected: - + TE_MODEL_URI - TE_MODEL_PATH + + With these set, the embedding class should be registered successfully and + initialized, and the model loaded from a local copy. + + This CLI command is NOT used during normal workflows. This is used primary + during development and after model downloading/loading changes to ensure the + model loads correctly. + +Options: + --model-uri TEXT HuggingFace model URI (e.g., 'org/model-name') + [required] + --model-path PATH Path where the model will be downloaded to and loaded + from, e.g. '/path/to/model'. [required] + --help Show this message and exit. ``` diff --git a/embeddings/cli.py b/embeddings/cli.py index 9f5b84d..f68212b 100644 --- a/embeddings/cli.py +++ b/embeddings/cli.py @@ -1,10 +1,10 @@ import functools import logging -import os import time from collections.abc import Callable from datetime import timedelta from pathlib import Path +from typing import TYPE_CHECKING import click @@ -13,21 +13,8 @@ logger = logging.getLogger(__name__) - -def model_required(f: Callable) -> Callable: - """Decorator for commands that require a specific model.""" - - @click.option( - "--model-uri", - envvar="TE_MODEL_URI", - required=True, - help="HuggingFace model URI (e.g., 'org/model-name')", - ) - @functools.wraps(f) - def wrapper(*args: list, **kwargs: dict) -> Callable: - return f(*args, **kwargs) - - return wrapper +if TYPE_CHECKING: + from embeddings.models.base import BaseEmbeddingModel @click.group("embeddings") @@ -60,6 +47,60 @@ def _log_command_elapsed_time() -> None: ctx.call_on_close(_log_command_elapsed_time) +def model_required(f: Callable) -> Callable: + """Middleware decorator for commands that require an embedding model. + + This decorator adds two CLI options: + - "--model-uri": defaults to environment variable "TE_MODEL_URI" + - "--model-path": defaults to environment variable "TE_MODEL_PATH" + + The decorator intercepts these parameters, uses the model URI to identify and + instantiate the appropriate embedding model class with the provided model path, + and stores the model instance in the Click context at ctx.obj["model"]. + + Both model_uri and model_path parameters are consumed by the decorator and not + passed to the decorated command function. + """ + + @click.option( + "--model-uri", + envvar="TE_MODEL_URI", + required=True, + help="HuggingFace model URI (e.g., 'org/model-name')", + ) + @click.option( + "--model-path", + required=True, + envvar="TE_MODEL_PATH", + type=click.Path(path_type=Path), + help=( + "Path where the model will be downloaded to and loaded from, " + "e.g. '/path/to/model'." + ), + ) + @functools.wraps(f) + def wrapper(*args: tuple, **kwargs: dict[str, str | Path]) -> Callable: + # pop "model_uri" and "model_path" from CLI args + model_uri: str = str(kwargs.pop("model_uri")) + model_path: str | Path = str(kwargs.pop("model_path")) + + # initialize embedding class + model_class = get_model_class(str(model_uri)) + model: BaseEmbeddingModel = model_class(model_path) + logger.info( + f"Embedding class '{model.__class__.__name__}' " + f"initialized from model URI '{model_uri}'." + ) + + # save embedding class instance to Context + ctx: click.Context = args[0] # type: ignore[assignment] + ctx.obj["model"] = model + + return f(*args, **kwargs) + + return wrapper + + @main.command() def ping() -> None: """Emit 'pong' to debug logs and stdout.""" @@ -68,23 +109,16 @@ def ping() -> None: @main.command() +@click.pass_context @model_required -@click.option( - "--output", - required=True, - envvar="TE_MODEL_DOWNLOAD_PATH", - type=click.Path(path_type=Path), - help="Output path for zipped model (e.g., '/path/to/model.zip')", -) -def download_model(model_uri: str, output: Path) -> None: - """Download a model from HuggingFace and save as zip file.""" - # load embedding model class - model_class = get_model_class(model_uri) - model = model_class() +def download_model( + ctx: click.Context, +) -> None: + """Download a model from HuggingFace and save locally.""" + model: BaseEmbeddingModel = ctx.obj["model"] - # download model assets - logger.info(f"Downloading model: {model_uri}") - result_path = model.download(output) + logger.info(f"Downloading model: {model.model_uri}") + result_path = model.download() message = f"Model downloaded and saved to: {result_path}" logger.info(message) @@ -92,29 +126,32 @@ def download_model(model_uri: str, output: Path) -> None: @main.command() -def test_model_load() -> None: +@click.pass_context +@model_required +def test_model_load(ctx: click.Context) -> None: """Test loading of embedding class and local model based on env vars. In a deployed context, the following env vars are expected: - TE_MODEL_URI - - TE_MODEL_DOWNLOAD_PATH + - TE_MODEL_PATH With these set, the embedding class should be registered successfully and initialized, and the model loaded from a local copy. - """ - # load embedding model class - model_class = get_model_class(os.environ["TE_MODEL_URI"]) - model = model_class() - model.load(os.environ["TE_MODEL_DOWNLOAD_PATH"]) + This CLI command is NOT used during normal workflows. This is used primary + during development and after model downloading/loading changes to ensure the model + loads correctly. + """ + model: BaseEmbeddingModel = ctx.obj["model"] + model.load() click.echo("OK") @main.command() +@click.pass_context @model_required -def create_embeddings(_model_uri: str) -> None: - # TODO: docstring # noqa: FIX002 - raise NotImplementedError +def create_embedding(ctx: click.Context) -> None: + """Create a single embedding for a single input text.""" if __name__ == "__main__": # pragma: no cover diff --git a/embeddings/models/base.py b/embeddings/models/base.py index e64a5c2..34d55fe 100644 --- a/embeddings/models/base.py +++ b/embeddings/models/base.py @@ -12,6 +12,14 @@ class BaseEmbeddingModel(ABC): MODEL_URI: str # Type hint to document the requirement + def __init__(self, model_path: str | Path) -> None: + """Initialize the embedding model with a model path. + + Args: + model_path: Path where the model will be downloaded to and loaded from. + """ + self.model_path = Path(model_path) + def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105 super().__init_subclass__(**kwargs) @@ -28,17 +36,13 @@ def model_uri(self) -> str: return self.MODEL_URI @abstractmethod - def download(self, output_path: str | Path) -> Path: - """Download and prepare model, saving to output_path. + def download(self) -> Path: + """Download and prepare model, saving to self.model_path. - Args: - output_path: Path where the model zip should be saved. + Returns: + Path where the model was saved. """ @abstractmethod - def load(self, model_path: str | Path) -> None: - """Load model from local, downloaded instance. - - Args: - model_path: Path of local model directory. - """ + def load(self) -> None: + """Load model from self.model_path.""" diff --git a/embeddings/models/os_neural_sparse_doc_v3_gte.py b/embeddings/models/os_neural_sparse_doc_v3_gte.py index b4f6dee..8e72f23 100644 --- a/embeddings/models/os_neural_sparse_doc_v3_gte.py +++ b/embeddings/models/os_neural_sparse_doc_v3_gte.py @@ -30,24 +30,27 @@ class OSNeuralSparseDocV3GTE(BaseEmbeddingModel): MODEL_URI = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte" - def __init__(self) -> None: - """Initialize the model.""" - super().__init__() + def __init__(self, model_path: str | Path) -> None: + """Initialize the model. + + Args: + model_path: Path where the model will be downloaded to and loaded from. + """ + super().__init__(model_path) self._model: PreTrainedModel | None = None self._tokenizer: DistilBertTokenizerFast | None = None self._special_token_ids: list | None = None self._id_to_token: list | None = None - def download(self, output_path: str | Path) -> Path: - """Download and prepare model, saving to output_path. + def download(self) -> Path: + """Download and prepare model, saving to self.model_path. - Args: - output_path: Path where the model should be saved. + Returns: + Path where the model was saved. """ start_time = time.perf_counter() - output_path = Path(output_path) - logger.info(f"Downloading model: {self.model_uri}, saving to: {output_path}.") + logger.info(f"Downloading model: {self.model_uri}, saving to: {self.model_path}.") with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -60,19 +63,21 @@ def download(self, output_path: str | Path) -> Path: self._patch_local_model_with_alibaba_new_impl(temp_path) # compress model directory as a zip file - if output_path.suffix.lower() == ".zip": + if self.model_path.suffix.lower() == ".zip": logger.debug("Creating zip file of model contents.") - shutil.make_archive(str(output_path.with_suffix("")), "zip", temp_path) + shutil.make_archive( + str(self.model_path.with_suffix("")), "zip", temp_path + ) # copy to output directory without zipping else: - logger.debug(f"Copying model contents to {output_path}") - if output_path.exists(): - shutil.rmtree(output_path) - shutil.copytree(temp_path, output_path) + logger.debug(f"Copying model contents to {self.model_path}") + if self.model_path.exists(): + shutil.rmtree(self.model_path) + shutil.copytree(temp_path, self.model_path) logger.info(f"Model downloaded successfully, {time.perf_counter() - start_time}s") - return output_path + return self.model_path def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> None: """Patch downloaded model with required assets from Alibaba-NLP/new-impl. @@ -124,28 +129,23 @@ def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> Non logger.debug("Dependency model Alibaba-NLP/new-impl downloaded and used.") - def load(self, model_path: str | Path) -> None: - """Load the model from the specified path. - - Args: - model_path: Path to the model directory. - """ + def load(self) -> None: + """Load the model from self.model_path.""" start_time = time.perf_counter() - logger.info(f"Loading model from: {model_path}") - model_path = Path(model_path) + logger.info(f"Loading model from: {self.model_path}") # ensure model exists locally - if not model_path.exists(): - raise FileNotFoundError(f"Model not found at path: {model_path}") + if not self.model_path.exists(): + raise FileNotFoundError(f"Model not found at path: {self.model_path}") # load local model and tokenizer self._model = AutoModelForMaskedLM.from_pretrained( - model_path, + self.model_path, trust_remote_code=True, local_files_only=True, ) self._tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call] - model_path, + self.model_path, local_files_only=True, ) diff --git a/tests/conftest.py b/tests/conftest.py index 7eb9262..3a5fbf3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,23 +27,27 @@ class MockEmbeddingModel(BaseEmbeddingModel): MODEL_URI = "test/mock-model" - def download(self, output_path: Path) -> Path: + def __init__(self, model_path: str | Path) -> None: + """Initialize the mock model.""" + super().__init__(model_path) + + def download(self) -> Path: """Create a fake model zip file for testing.""" - output_path.parent.mkdir(parents=True, exist_ok=True) - with zipfile.ZipFile(output_path, "w") as zf: + self.model_path.parent.mkdir(parents=True, exist_ok=True) + with zipfile.ZipFile(self.model_path, "w") as zf: zf.writestr("config.json", '{"model": "mock", "vocab_size": 30000}') zf.writestr("pytorch_model.bin", b"fake model weights") zf.writestr("tokenizer.json", '{"version": "1.0"}') - return output_path + return self.model_path - def load(self, model_path: str | Path) -> None: # noqa: ARG002 + def load(self) -> None: logger.info("Model loaded successfully, 1.5s") @pytest.fixture -def mock_model(): +def mock_model(tmp_path): """Fixture providing a MockEmbeddingModel instance.""" - return MockEmbeddingModel() + return MockEmbeddingModel(tmp_path / "model") @pytest.fixture diff --git a/tests/test_cli.py b/tests/test_cli.py index e51eb35..63a466c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,6 @@ from embeddings.cli import main +from embeddings.models import registry +from tests.conftest import MockEmbeddingModel def test_cli_default_logging(caplog, runner): @@ -18,7 +20,118 @@ def test_cli_debug_logging(caplog, runner): def test_download_model_unknown_uri(caplog, runner): result = runner.invoke( - main, ["download-model", "--model-uri", "unknown/model", "--output", "out.zip"] + main, + ["download-model", "--model-uri", "unknown/model", "--model-path", "out.zip"], ) assert result.exit_code != 0 assert "Unknown model URI" in caplog.text + + +def test_model_required_decorator_with_cli_option(caplog, monkeypatch, runner, tmp_path): + """Test decorator successfully initializes model from --model-uri option.""" + monkeypatch.setitem(registry.MODEL_REGISTRY, "test/mock-model", MockEmbeddingModel) + + output_path = tmp_path / "model.zip" + result = runner.invoke( + main, + [ + "download-model", + "--model-uri", + "test/mock-model", + "--model-path", + str(output_path), + ], + ) + + assert result.exit_code == 0 + assert ( + "Embedding class 'MockEmbeddingModel' initialized from model URI " + "'test/mock-model'" in caplog.text + ) + assert output_path.exists() + + +def test_model_required_decorator_with_env_var(caplog, monkeypatch, runner, tmp_path): + """Test decorator successfully initializes model from TE_MODEL_URI env var.""" + monkeypatch.setitem(registry.MODEL_REGISTRY, "test/mock-model", MockEmbeddingModel) + monkeypatch.setenv("TE_MODEL_URI", "test/mock-model") + + output_path = tmp_path / "model.zip" + result = runner.invoke(main, ["download-model", "--model-path", str(output_path)]) + + assert result.exit_code == 0 + assert ( + "Embedding class 'MockEmbeddingModel' initialized from model URI " + "'test/mock-model'" in caplog.text + ) + assert output_path.exists() + + +def test_model_required_decorator_missing_parameter(runner): + """Test decorator fails when --model-uri is not provided and env var is not set.""" + result = runner.invoke(main, ["download-model", "--model-path", "out.zip"]) + + assert result.exit_code != 0 + assert "Missing option '--model-uri'" in result.output + + +def test_model_required_decorator_stores_model_in_context( + caplog, monkeypatch, runner, tmp_path +): + """Test decorator stores model instance in ctx.obj['model'].""" + monkeypatch.setitem(registry.MODEL_REGISTRY, "test/mock-model", MockEmbeddingModel) + + output_path = tmp_path / "model.zip" + result = runner.invoke( + main, + [ + "download-model", + "--model-uri", + "test/mock-model", + "--model-path", + str(output_path), + ], + ) + + assert result.exit_code == 0 + # verify the model was used successfully (download method was called) + assert "Downloading model: test/mock-model" in caplog.text + assert output_path.exists() + + +def test_model_required_decorator_log_message(caplog, monkeypatch, runner, tmp_path): + """Test decorator logs correct initialization message.""" + monkeypatch.setitem(registry.MODEL_REGISTRY, "test/mock-model", MockEmbeddingModel) + + output_path = tmp_path / "model.zip" + result = runner.invoke( + main, + [ + "download-model", + "--model-uri", + "test/mock-model", + "--model-path", + str(output_path), + ], + ) + + assert result.exit_code == 0 + assert ( + "Embedding class 'MockEmbeddingModel' initialized from model URI " + "'test/mock-model'" in caplog.text + ) + + +def test_model_required_decorator_works_across_commands(caplog, monkeypatch, runner): + """Test decorator works for multiple commands (test_model_load).""" + monkeypatch.setitem(registry.MODEL_REGISTRY, "test/mock-model", MockEmbeddingModel) + monkeypatch.setenv("TE_MODEL_PATH", "/fake/path") + + result = runner.invoke(main, ["test-model-load", "--model-uri", "test/mock-model"]) + + assert result.exit_code == 0 + assert ( + "Embedding class 'MockEmbeddingModel' initialized from model URI " + "'test/mock-model'" in caplog.text + ) + assert "OK" in result.output diff --git a/tests/test_models.py b/tests/test_models.py index caeba00..9fc8f24 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,24 +4,27 @@ from embeddings.models.base import BaseEmbeddingModel from embeddings.models.registry import MODEL_REGISTRY, get_model_class +from tests.conftest import MockEmbeddingModel def test_mock_model_instantiation(mock_model): assert mock_model.model_uri == "test/mock-model" -def test_mock_model_download_creates_zip(mock_model, tmp_path): +def test_mock_model_download_creates_zip(tmp_path): output_path = tmp_path / "test_model.zip" - result = mock_model.download(output_path) + mock_model = MockEmbeddingModel(output_path) + result = mock_model.download() assert result == output_path assert output_path.exists() assert zipfile.is_zipfile(output_path) -def test_mock_model_download_contains_expected_files(mock_model, tmp_path): +def test_mock_model_download_contains_expected_files(tmp_path): output_path = tmp_path / "test_model.zip" - mock_model.download(output_path) + mock_model = MockEmbeddingModel(output_path) + mock_model.download() with zipfile.ZipFile(output_path, "r") as zf: file_list = zf.namelist() diff --git a/tests/test_os_neural_sparse_doc_v3_gte.py b/tests/test_os_neural_sparse_doc_v3_gte.py index 06268fd..3e093eb 100644 --- a/tests/test_os_neural_sparse_doc_v3_gte.py +++ b/tests/test_os_neural_sparse_doc_v3_gte.py @@ -10,18 +10,18 @@ from embeddings.models.os_neural_sparse_doc_v3_gte import OSNeuralSparseDocV3GTE -def test_init(): +def test_init(tmp_path): """Test model initialization.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(tmp_path / "model") assert model._model is None assert model._tokenizer is None assert model._special_token_ids is None assert model._id_to_token is None -def test_model_uri(): +def test_model_uri(tmp_path): """Test model_uri property returns correct HuggingFace URI.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(tmp_path / "model") assert ( model.model_uri == "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte" @@ -36,10 +36,10 @@ def test_download_to_directory( neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path ): """Test download to directory (not zip).""" - model = OSNeuralSparseDocV3GTE() output_path = tmp_path / "model_output" + model = OSNeuralSparseDocV3GTE(output_path) - result = model.download(output_path) + result = model.download() assert result == output_path assert output_path.exists() @@ -52,10 +52,10 @@ def test_download_to_zip_file( neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path ): """Test download creates zip when path ends in .zip.""" - model = OSNeuralSparseDocV3GTE() output_path = tmp_path / "model.zip" + model = OSNeuralSparseDocV3GTE(output_path) - result = model.download(output_path) + result = model.download() assert result == output_path assert output_path.exists() @@ -66,8 +66,8 @@ def test_download_calls_patch_method( neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path, monkeypatch ): """Test that download calls the Alibaba patching method.""" - model = OSNeuralSparseDocV3GTE() output_path = tmp_path / "model_output" + model = OSNeuralSparseDocV3GTE(output_path) patch_called = False @@ -77,7 +77,7 @@ def mock_patch(temp_path): monkeypatch.setattr(model, "_patch_local_model_with_alibaba_new_impl", mock_patch) - model.download(output_path) + model.download() assert patch_called @@ -86,10 +86,10 @@ def test_download_returns_path( neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path ): """Test download returns the output path.""" - model = OSNeuralSparseDocV3GTE() output_path = tmp_path / "model_output" + model = OSNeuralSparseDocV3GTE(output_path) - result = model.download(output_path) + result = model.download() assert result == output_path assert isinstance(result, Path) @@ -99,7 +99,7 @@ def test_patch_downloads_alibaba_model( neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path ): """Test patch method downloads Alibaba-NLP/new-impl.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(tmp_path / "model") model_temp_path = tmp_path / "temp_model" model_temp_path.mkdir() (model_temp_path / "config.json").write_text('{"model_type": "test"}') @@ -112,7 +112,7 @@ def test_patch_downloads_alibaba_model( def test_patch_copies_files(neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path): """Test patch copies modeling.py and configuration.py.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(tmp_path / "model") model_temp_path = tmp_path / "temp_model" model_temp_path.mkdir() (model_temp_path / "config.json").write_text('{"model_type": "test"}') @@ -130,7 +130,7 @@ def test_patch_updates_config_json( neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path ): """Test patch updates auto_map in config.json.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(tmp_path / "model") model_temp_path = tmp_path / "temp_model" model_temp_path.mkdir() initial_config = {"model_type": "test", "vocab_size": 30000} @@ -151,9 +151,9 @@ def test_load_success( neural_sparse_doc_v3_gte_mock_transformers_models, ): """Test successful load from local path.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(neural_sparse_doc_v3_gte_fake_model_directory) - model.load(neural_sparse_doc_v3_gte_fake_model_directory) + model.load() assert model._model is not None assert model._tokenizer is not None @@ -161,11 +161,11 @@ def test_load_success( def test_load_file_not_found(): """Test load raises FileNotFoundError for missing path.""" - model = OSNeuralSparseDocV3GTE() nonexistent_path = Path("/nonexistent/path") + model = OSNeuralSparseDocV3GTE(nonexistent_path) with pytest.raises(FileNotFoundError, match="Model not found at path"): - model.load(nonexistent_path) + model.load() def test_load_initializes_model_and_tokenizer( @@ -173,12 +173,12 @@ def test_load_initializes_model_and_tokenizer( neural_sparse_doc_v3_gte_mock_transformers_models, ): """Test load initializes _model and _tokenizer attributes.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(neural_sparse_doc_v3_gte_fake_model_directory) assert model._model is None assert model._tokenizer is None - model.load(neural_sparse_doc_v3_gte_fake_model_directory) + model.load() assert model._model is not None assert model._tokenizer is not None @@ -189,9 +189,9 @@ def test_load_sets_up_special_token_ids( neural_sparse_doc_v3_gte_mock_transformers_models, ): """Test load sets up _special_token_ids list.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(neural_sparse_doc_v3_gte_fake_model_directory) - model.load(neural_sparse_doc_v3_gte_fake_model_directory) + model.load() assert model._special_token_ids is not None assert isinstance(model._special_token_ids, list) @@ -206,9 +206,9 @@ def test_load_sets_up_id_to_token_mapping( neural_sparse_doc_v3_gte_mock_transformers_models, ): """Test load creates _id_to_token mapping correctly.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(neural_sparse_doc_v3_gte_fake_model_directory) - model.load(neural_sparse_doc_v3_gte_fake_model_directory) + model.load() assert model._id_to_token is not None assert isinstance(model._id_to_token, list)