diff --git a/Dockerfile b/Dockerfile
index cd3fc39..5c34b88 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -15,16 +15,16 @@ COPY pyproject.toml uv.lock* ./
 # Copy CLI application
 COPY embeddings ./embeddings
 
-# Install package into system python, includes "marimo-launcher" script
+# Install package into system python
 RUN uv pip install --system .
 
 # Download the model and include in the Docker image
-# NOTE: The env vars "TE_MODEL_URI" and "TE_MODEL_DOWNLOAD_PATH" are set here to support
-# the downloading of the model into this image build, but persist in the container and
-# effectively also set this as the default model.
+# NOTE: The env vars "TE_MODEL_URI" and "TE_MODEL_PATH" are set here to support
+# the downloading of the model during image build, but also persist in the container and
+# effectively set the default model.
 ENV HF_HUB_DISABLE_PROGRESS_BARS=true
 ENV TE_MODEL_URI=opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
-ENV TE_MODEL_DOWNLOAD_PATH=/model
+ENV TE_MODEL_PATH=/model
 RUN python -m embeddings.cli --verbose download-model
 
 ENTRYPOINT ["python", "-m", "embeddings.cli"]
diff --git a/README.md b/README.md
index 054aa7a..0521d50 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ WORKSPACE=### Set to `dev` for local development, this will be set to `stage` an
 
 ```shell
 TE_MODEL_URI=# HuggingFace model URI
-TE_MODEL_DOWNLOAD_PATH=# Download location for model
+TE_MODEL_PATH=# Path where the model will be downloaded to and loaded from
 HF_HUB_DISABLE_PROGRESS_BARS=#boolean to use progress bars for HuggingFace model downloads; defaults to 'true' in deployed contexts
 ```
 
@@ -34,7 +34,7 @@ This CLI application is designed to create embeddings for input texts. To do th
 
 To this end, there is a base embedding class `BaseEmbeddingModel` that is designed to be extended and customized for a particular embedding model.
 
-Once an embedding class has been created, the preferred approach is to set env vars `TE_MODEL_URI` and `TE_MODEL_DOWNLOAD_PATH` directly in the `Dockerfile` to a) download a local snapshot of the model during image build, and b) set this model as the default for the CLI.
+Once an embedding class has been created, the preferred approach is to set env vars `TE_MODEL_URI` and `TE_MODEL_PATH` directly in the `Dockerfile` to a) download a local snapshot of the model during image build, and b) set this model as the default for the CLI.
 
 This allows invoking the CLI without specifying a model URI or local location, allowing this model to serve as the default, e.g.:
 
@@ -61,18 +61,38 @@ Usage: embeddings ping [OPTIONS]
 
 ```text
 Usage: embeddings download-model [OPTIONS]
 
-  Download a model from HuggingFace and save as zip file.
+  Download a model from HuggingFace and save locally.
 
 Options:
-  --model-uri TEXT  HuggingFace model URI (e.g., 'org/model-name')  [required]
-  --output PATH     Output path for zipped model (e.g., '/path/to/model.zip')
-                    [required]
-  --help            Show this message and exit.
+  --model-uri TEXT   HuggingFace model URI (e.g., 'org/model-name')
+                     [required]
+  --model-path PATH  Path where the model will be downloaded to and loaded
+                     from, e.g. '/path/to/model'.  [required]
+  --help             Show this message and exit.
 ```
 
-### `create-embeddings`
+### `test-model-load`
 
 ```text
-TODO...
+Usage: embeddings test-model-load [OPTIONS]
+
+  Test loading of embedding class and local model based on env vars.
+
+  In a deployed context, the following env vars are expected: -
+  TE_MODEL_URI - TE_MODEL_PATH
+
+  With these set, the embedding class should be registered successfully and
+  initialized, and the model loaded from a local copy.
+
+  This CLI command is NOT used during normal workflows. This is used primarily
+  during development and after model downloading/loading changes to ensure the
+  model loads correctly.
+
+Options:
+  --model-uri TEXT   HuggingFace model URI (e.g., 'org/model-name')
+                     [required]
+  --model-path PATH  Path where the model will be downloaded to and loaded
+                     from, e.g. '/path/to/model'.  [required]
+  --help             Show this message and exit.
 ```
diff --git a/embeddings/cli.py b/embeddings/cli.py
index 9f5b84d..f68212b 100644
--- a/embeddings/cli.py
+++ b/embeddings/cli.py
@@ -1,10 +1,10 @@
 import functools
 import logging
-import os
 import time
 from collections.abc import Callable
 from datetime import timedelta
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 import click
 
@@ -13,21 +13,8 @@
 
 logger = logging.getLogger(__name__)
 
-
-def model_required(f: Callable) -> Callable:
-    """Decorator for commands that require a specific model."""
-
-    @click.option(
-        "--model-uri",
-        envvar="TE_MODEL_URI",
-        required=True,
-        help="HuggingFace model URI (e.g., 'org/model-name')",
-    )
-    @functools.wraps(f)
-    def wrapper(*args: list, **kwargs: dict) -> Callable:
-        return f(*args, **kwargs)
-
-    return wrapper
+if TYPE_CHECKING:
+    from embeddings.models.base import BaseEmbeddingModel
 
 
 @click.group("embeddings")
@@ -60,6 +47,60 @@ def _log_command_elapsed_time() -> None:
     ctx.call_on_close(_log_command_elapsed_time)
 
 
+def model_required(f: Callable) -> Callable:
+    """Middleware decorator for commands that require an embedding model.
+
+    This decorator adds two CLI options:
+    - "--model-uri": defaults to environment variable "TE_MODEL_URI"
+    - "--model-path": defaults to environment variable "TE_MODEL_PATH"
+
+    The decorator intercepts these parameters, uses the model URI to identify and
+    instantiate the appropriate embedding model class with the provided model path,
+    and stores the model instance in the Click context at ctx.obj["model"].
+
+    Both model_uri and model_path parameters are consumed by the decorator and not
+    passed to the decorated command function.
+    """
+
+    @click.option(
+        "--model-uri",
+        envvar="TE_MODEL_URI",
+        required=True,
+        help="HuggingFace model URI (e.g., 'org/model-name')",
+    )
+    @click.option(
+        "--model-path",
+        required=True,
+        envvar="TE_MODEL_PATH",
+        type=click.Path(path_type=Path),
+        help=(
+            "Path where the model will be downloaded to and loaded from, "
+            "e.g. '/path/to/model'."
+        ),
+    )
+    @functools.wraps(f)
+    def wrapper(*args: tuple, **kwargs: dict[str, str | Path]) -> Callable:
+        # pop "model_uri" and "model_path" from CLI args
+        model_uri: str = str(kwargs.pop("model_uri"))
+        model_path: str = str(kwargs.pop("model_path"))
+
+        # initialize embedding class
+        model_class = get_model_class(model_uri)
+        model: BaseEmbeddingModel = model_class(model_path)
+        logger.info(
+            f"Embedding class '{model.__class__.__name__}' "
+            f"initialized from model URI '{model_uri}'."
+ ) + + # save embedding class instance to Context + ctx: click.Context = args[0] # type: ignore[assignment] + ctx.obj["model"] = model + + return f(*args, **kwargs) + + return wrapper + + @main.command() def ping() -> None: """Emit 'pong' to debug logs and stdout.""" @@ -68,23 +109,16 @@ def ping() -> None: @main.command() +@click.pass_context @model_required -@click.option( - "--output", - required=True, - envvar="TE_MODEL_DOWNLOAD_PATH", - type=click.Path(path_type=Path), - help="Output path for zipped model (e.g., '/path/to/model.zip')", -) -def download_model(model_uri: str, output: Path) -> None: - """Download a model from HuggingFace and save as zip file.""" - # load embedding model class - model_class = get_model_class(model_uri) - model = model_class() +def download_model( + ctx: click.Context, +) -> None: + """Download a model from HuggingFace and save locally.""" + model: BaseEmbeddingModel = ctx.obj["model"] - # download model assets - logger.info(f"Downloading model: {model_uri}") - result_path = model.download(output) + logger.info(f"Downloading model: {model.model_uri}") + result_path = model.download() message = f"Model downloaded and saved to: {result_path}" logger.info(message) @@ -92,29 +126,32 @@ def download_model(model_uri: str, output: Path) -> None: @main.command() -def test_model_load() -> None: +@click.pass_context +@model_required +def test_model_load(ctx: click.Context) -> None: """Test loading of embedding class and local model based on env vars. In a deployed context, the following env vars are expected: - TE_MODEL_URI - - TE_MODEL_DOWNLOAD_PATH + - TE_MODEL_PATH With these set, the embedding class should be registered successfully and initialized, and the model loaded from a local copy. - """ - # load embedding model class - model_class = get_model_class(os.environ["TE_MODEL_URI"]) - model = model_class() - model.load(os.environ["TE_MODEL_DOWNLOAD_PATH"]) + This CLI command is NOT used during normal workflows. This is used primary + during development and after model downloading/loading changes to ensure the model + loads correctly. + """ + model: BaseEmbeddingModel = ctx.obj["model"] + model.load() click.echo("OK") @main.command() +@click.pass_context @model_required -def create_embeddings(_model_uri: str) -> None: - # TODO: docstring # noqa: FIX002 - raise NotImplementedError +def create_embedding(ctx: click.Context) -> None: + """Create a single embedding for a single input text.""" if __name__ == "__main__": # pragma: no cover diff --git a/embeddings/models/base.py b/embeddings/models/base.py index e64a5c2..34d55fe 100644 --- a/embeddings/models/base.py +++ b/embeddings/models/base.py @@ -12,6 +12,14 @@ class BaseEmbeddingModel(ABC): MODEL_URI: str # Type hint to document the requirement + def __init__(self, model_path: str | Path) -> None: + """Initialize the embedding model with a model path. + + Args: + model_path: Path where the model will be downloaded to and loaded from. + """ + self.model_path = Path(model_path) + def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105 super().__init_subclass__(**kwargs) @@ -28,17 +36,13 @@ def model_uri(self) -> str: return self.MODEL_URI @abstractmethod - def download(self, output_path: str | Path) -> Path: - """Download and prepare model, saving to output_path. + def download(self) -> Path: + """Download and prepare model, saving to self.model_path. - Args: - output_path: Path where the model zip should be saved. + Returns: + Path where the model was saved. 
""" @abstractmethod - def load(self, model_path: str | Path) -> None: - """Load model from local, downloaded instance. - - Args: - model_path: Path of local model directory. - """ + def load(self) -> None: + """Load model from self.model_path.""" diff --git a/embeddings/models/os_neural_sparse_doc_v3_gte.py b/embeddings/models/os_neural_sparse_doc_v3_gte.py index b4f6dee..8e72f23 100644 --- a/embeddings/models/os_neural_sparse_doc_v3_gte.py +++ b/embeddings/models/os_neural_sparse_doc_v3_gte.py @@ -30,24 +30,27 @@ class OSNeuralSparseDocV3GTE(BaseEmbeddingModel): MODEL_URI = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte" - def __init__(self) -> None: - """Initialize the model.""" - super().__init__() + def __init__(self, model_path: str | Path) -> None: + """Initialize the model. + + Args: + model_path: Path where the model will be downloaded to and loaded from. + """ + super().__init__(model_path) self._model: PreTrainedModel | None = None self._tokenizer: DistilBertTokenizerFast | None = None self._special_token_ids: list | None = None self._id_to_token: list | None = None - def download(self, output_path: str | Path) -> Path: - """Download and prepare model, saving to output_path. + def download(self) -> Path: + """Download and prepare model, saving to self.model_path. - Args: - output_path: Path where the model should be saved. + Returns: + Path where the model was saved. """ start_time = time.perf_counter() - output_path = Path(output_path) - logger.info(f"Downloading model: {self.model_uri}, saving to: {output_path}.") + logger.info(f"Downloading model: {self.model_uri}, saving to: {self.model_path}.") with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -60,19 +63,21 @@ def download(self, output_path: str | Path) -> Path: self._patch_local_model_with_alibaba_new_impl(temp_path) # compress model directory as a zip file - if output_path.suffix.lower() == ".zip": + if self.model_path.suffix.lower() == ".zip": logger.debug("Creating zip file of model contents.") - shutil.make_archive(str(output_path.with_suffix("")), "zip", temp_path) + shutil.make_archive( + str(self.model_path.with_suffix("")), "zip", temp_path + ) # copy to output directory without zipping else: - logger.debug(f"Copying model contents to {output_path}") - if output_path.exists(): - shutil.rmtree(output_path) - shutil.copytree(temp_path, output_path) + logger.debug(f"Copying model contents to {self.model_path}") + if self.model_path.exists(): + shutil.rmtree(self.model_path) + shutil.copytree(temp_path, self.model_path) logger.info(f"Model downloaded successfully, {time.perf_counter() - start_time}s") - return output_path + return self.model_path def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> None: """Patch downloaded model with required assets from Alibaba-NLP/new-impl. @@ -124,28 +129,23 @@ def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> Non logger.debug("Dependency model Alibaba-NLP/new-impl downloaded and used.") - def load(self, model_path: str | Path) -> None: - """Load the model from the specified path. - - Args: - model_path: Path to the model directory. 
- """ + def load(self) -> None: + """Load the model from self.model_path.""" start_time = time.perf_counter() - logger.info(f"Loading model from: {model_path}") - model_path = Path(model_path) + logger.info(f"Loading model from: {self.model_path}") # ensure model exists locally - if not model_path.exists(): - raise FileNotFoundError(f"Model not found at path: {model_path}") + if not self.model_path.exists(): + raise FileNotFoundError(f"Model not found at path: {self.model_path}") # load local model and tokenizer self._model = AutoModelForMaskedLM.from_pretrained( - model_path, + self.model_path, trust_remote_code=True, local_files_only=True, ) self._tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call] - model_path, + self.model_path, local_files_only=True, ) diff --git a/tests/conftest.py b/tests/conftest.py index 7eb9262..3a5fbf3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,23 +27,27 @@ class MockEmbeddingModel(BaseEmbeddingModel): MODEL_URI = "test/mock-model" - def download(self, output_path: Path) -> Path: + def __init__(self, model_path: str | Path) -> None: + """Initialize the mock model.""" + super().__init__(model_path) + + def download(self) -> Path: """Create a fake model zip file for testing.""" - output_path.parent.mkdir(parents=True, exist_ok=True) - with zipfile.ZipFile(output_path, "w") as zf: + self.model_path.parent.mkdir(parents=True, exist_ok=True) + with zipfile.ZipFile(self.model_path, "w") as zf: zf.writestr("config.json", '{"model": "mock", "vocab_size": 30000}') zf.writestr("pytorch_model.bin", b"fake model weights") zf.writestr("tokenizer.json", '{"version": "1.0"}') - return output_path + return self.model_path - def load(self, model_path: str | Path) -> None: # noqa: ARG002 + def load(self) -> None: logger.info("Model loaded successfully, 1.5s") @pytest.fixture -def mock_model(): +def mock_model(tmp_path): """Fixture providing a MockEmbeddingModel instance.""" - return MockEmbeddingModel() + return MockEmbeddingModel(tmp_path / "model") @pytest.fixture diff --git a/tests/test_cli.py b/tests/test_cli.py index e51eb35..63a466c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,6 @@ from embeddings.cli import main +from embeddings.models import registry +from tests.conftest import MockEmbeddingModel def test_cli_default_logging(caplog, runner): @@ -18,7 +20,118 @@ def test_cli_debug_logging(caplog, runner): def test_download_model_unknown_uri(caplog, runner): result = runner.invoke( - main, ["download-model", "--model-uri", "unknown/model", "--output", "out.zip"] + main, + ["download-model", "--model-uri", "unknown/model", "--model-path", "out.zip"], ) assert result.exit_code != 0 assert "Unknown model URI" in caplog.text + + +def test_model_required_decorator_with_cli_option(caplog, monkeypatch, runner, tmp_path): + """Test decorator successfully initializes model from --model-uri option.""" + monkeypatch.setitem(registry.MODEL_REGISTRY, "test/mock-model", MockEmbeddingModel) + + output_path = tmp_path / "model.zip" + result = runner.invoke( + main, + [ + "download-model", + "--model-uri", + "test/mock-model", + "--model-path", + str(output_path), + ], + ) + + assert result.exit_code == 0 + assert ( + "Embedding class 'MockEmbeddingModel' initialized from model URI " + "'test/mock-model'" in caplog.text + ) + assert output_path.exists() + + +def test_model_required_decorator_with_env_var(caplog, monkeypatch, runner, tmp_path): + """Test decorator successfully initializes model from TE_MODEL_URI env var.""" + 
+    monkeypatch.setitem(registry.MODEL_REGISTRY, "test/mock-model", MockEmbeddingModel)
+    monkeypatch.setenv("TE_MODEL_URI", "test/mock-model")
+
+    output_path = tmp_path / "model.zip"
+    result = runner.invoke(main, ["download-model", "--model-path", str(output_path)])
+
+    assert result.exit_code == 0
+    assert (
+        "Embedding class 'MockEmbeddingModel' initialized from model URI "
+        "'test/mock-model'" in caplog.text
+    )
+    assert output_path.exists()
+
+
+def test_model_required_decorator_missing_parameter(runner):
+    """Test decorator fails when --model-uri is not provided and env var is not set."""
+    result = runner.invoke(main, ["download-model", "--model-path", "out.zip"])
+
+    assert result.exit_code != 0
+    assert "Missing option '--model-uri'" in result.output
+
+
+def test_model_required_decorator_stores_model_in_context(
+    caplog, monkeypatch, runner, tmp_path
+):
+    """Test decorator stores model instance in ctx.obj['model']."""
+    monkeypatch.setitem(registry.MODEL_REGISTRY, "test/mock-model", MockEmbeddingModel)
+
+    output_path = tmp_path / "model.zip"
+    result = runner.invoke(
+        main,
+        [
+            "download-model",
+            "--model-uri",
+            "test/mock-model",
+            "--model-path",
+            str(output_path),
+        ],
+    )
+
+    assert result.exit_code == 0
+    # verify the model was used successfully (download method was called)
+    assert "Downloading model: test/mock-model" in caplog.text
+    assert output_path.exists()
+
+
+def test_model_required_decorator_log_message(caplog, monkeypatch, runner, tmp_path):
+    """Test decorator logs correct initialization message."""
+    monkeypatch.setitem(registry.MODEL_REGISTRY, "test/mock-model", MockEmbeddingModel)
+
+    output_path = tmp_path / "model.zip"
+    result = runner.invoke(
+        main,
+        [
+            "download-model",
+            "--model-uri",
+            "test/mock-model",
+            "--model-path",
+            str(output_path),
+        ],
+    )
+
+    assert result.exit_code == 0
+    assert (
+        "Embedding class 'MockEmbeddingModel' initialized from model URI "
+        "'test/mock-model'" in caplog.text
+    )
+
+
+def test_model_required_decorator_works_across_commands(caplog, monkeypatch, runner):
+    """Test decorator works for multiple commands (test_model_load)."""
+    monkeypatch.setitem(registry.MODEL_REGISTRY, "test/mock-model", MockEmbeddingModel)
+    monkeypatch.setenv("TE_MODEL_PATH", "/fake/path")
+
+    result = runner.invoke(main, ["test-model-load", "--model-uri", "test/mock-model"])
+
+    assert result.exit_code == 0
+    assert (
+        "Embedding class 'MockEmbeddingModel' initialized from model URI "
+        "'test/mock-model'" in caplog.text
+    )
+    assert "OK" in result.output
diff --git a/tests/test_models.py b/tests/test_models.py
index caeba00..9fc8f24 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -4,24 +4,27 @@
 
 from embeddings.models.base import BaseEmbeddingModel
 from embeddings.models.registry import MODEL_REGISTRY, get_model_class
+from tests.conftest import MockEmbeddingModel
 
 
 def test_mock_model_instantiation(mock_model):
     assert mock_model.model_uri == "test/mock-model"
 
 
-def test_mock_model_download_creates_zip(mock_model, tmp_path):
+def test_mock_model_download_creates_zip(tmp_path):
     output_path = tmp_path / "test_model.zip"
-    result = mock_model.download(output_path)
+    mock_model = MockEmbeddingModel(output_path)
+    result = mock_model.download()
 
     assert result == output_path
     assert output_path.exists()
     assert zipfile.is_zipfile(output_path)
 
 
-def test_mock_model_download_contains_expected_files(mock_model, tmp_path):
+def test_mock_model_download_contains_expected_files(tmp_path):
     output_path = tmp_path / "test_model.zip"
-    mock_model.download(output_path)
+    mock_model = MockEmbeddingModel(output_path)
+    mock_model.download()
 
     with zipfile.ZipFile(output_path, "r") as zf:
         file_list = zf.namelist()
diff --git a/tests/test_os_neural_sparse_doc_v3_gte.py b/tests/test_os_neural_sparse_doc_v3_gte.py
index 06268fd..3e093eb 100644
--- a/tests/test_os_neural_sparse_doc_v3_gte.py
+++ b/tests/test_os_neural_sparse_doc_v3_gte.py
@@ -10,18 +10,18 @@
 from embeddings.models.os_neural_sparse_doc_v3_gte import OSNeuralSparseDocV3GTE
 
 
-def test_init():
+def test_init(tmp_path):
     """Test model initialization."""
-    model = OSNeuralSparseDocV3GTE()
+    model = OSNeuralSparseDocV3GTE(tmp_path / "model")
     assert model._model is None
     assert model._tokenizer is None
     assert model._special_token_ids is None
     assert model._id_to_token is None
 
 
-def test_model_uri():
+def test_model_uri(tmp_path):
     """Test model_uri property returns correct HuggingFace URI."""
-    model = OSNeuralSparseDocV3GTE()
+    model = OSNeuralSparseDocV3GTE(tmp_path / "model")
     assert (
         model.model_uri
         == "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
@@ -36,10 +36,10 @@ def test_download_to_directory(
     neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path
 ):
     """Test download to directory (not zip)."""
-    model = OSNeuralSparseDocV3GTE()
     output_path = tmp_path / "model_output"
+    model = OSNeuralSparseDocV3GTE(output_path)
 
-    result = model.download(output_path)
+    result = model.download()
 
     assert result == output_path
     assert output_path.exists()
@@ -52,10 +52,10 @@ def test_download_to_zip_file(
     neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path
 ):
     """Test download creates zip when path ends in .zip."""
-    model = OSNeuralSparseDocV3GTE()
     output_path = tmp_path / "model.zip"
+    model = OSNeuralSparseDocV3GTE(output_path)
 
-    result = model.download(output_path)
+    result = model.download()
 
     assert result == output_path
     assert output_path.exists()
@@ -66,8 +66,8 @@ def test_download_calls_patch_method(
     neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path, monkeypatch
 ):
     """Test that download calls the Alibaba patching method."""
-    model = OSNeuralSparseDocV3GTE()
     output_path = tmp_path / "model_output"
+    model = OSNeuralSparseDocV3GTE(output_path)
 
     patch_called = False
 
@@ -77,7 +77,7 @@ def mock_patch(temp_path):
 
     monkeypatch.setattr(model, "_patch_local_model_with_alibaba_new_impl", mock_patch)
 
-    model.download(output_path)
+    model.download()
 
     assert patch_called
 
@@ -86,10 +86,10 @@ def test_download_returns_path(
     neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path
 ):
     """Test download returns the output path."""
-    model = OSNeuralSparseDocV3GTE()
     output_path = tmp_path / "model_output"
+    model = OSNeuralSparseDocV3GTE(output_path)
 
-    result = model.download(output_path)
+    result = model.download()
 
     assert result == output_path
     assert isinstance(result, Path)
@@ -99,7 +99,7 @@ def test_patch_downloads_alibaba_model(
     neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path
 ):
     """Test patch method downloads Alibaba-NLP/new-impl."""
-    model = OSNeuralSparseDocV3GTE()
+    model = OSNeuralSparseDocV3GTE(tmp_path / "model")
     model_temp_path = tmp_path / "temp_model"
     model_temp_path.mkdir()
     (model_temp_path / "config.json").write_text('{"model_type": "test"}')
@@ -112,7 +112,7 @@ def test_patch_downloads_alibaba_model(
 
 def test_patch_copies_files(neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path):
     """Test patch copies modeling.py and configuration.py."""
-    model = OSNeuralSparseDocV3GTE()
"model") model_temp_path = tmp_path / "temp_model" model_temp_path.mkdir() (model_temp_path / "config.json").write_text('{"model_type": "test"}') @@ -130,7 +130,7 @@ def test_patch_updates_config_json( neural_sparse_doc_v3_gte_mock_huggingface_snapshot, tmp_path ): """Test patch updates auto_map in config.json.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(tmp_path / "model") model_temp_path = tmp_path / "temp_model" model_temp_path.mkdir() initial_config = {"model_type": "test", "vocab_size": 30000} @@ -151,9 +151,9 @@ def test_load_success( neural_sparse_doc_v3_gte_mock_transformers_models, ): """Test successful load from local path.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(neural_sparse_doc_v3_gte_fake_model_directory) - model.load(neural_sparse_doc_v3_gte_fake_model_directory) + model.load() assert model._model is not None assert model._tokenizer is not None @@ -161,11 +161,11 @@ def test_load_success( def test_load_file_not_found(): """Test load raises FileNotFoundError for missing path.""" - model = OSNeuralSparseDocV3GTE() nonexistent_path = Path("/nonexistent/path") + model = OSNeuralSparseDocV3GTE(nonexistent_path) with pytest.raises(FileNotFoundError, match="Model not found at path"): - model.load(nonexistent_path) + model.load() def test_load_initializes_model_and_tokenizer( @@ -173,12 +173,12 @@ def test_load_initializes_model_and_tokenizer( neural_sparse_doc_v3_gte_mock_transformers_models, ): """Test load initializes _model and _tokenizer attributes.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(neural_sparse_doc_v3_gte_fake_model_directory) assert model._model is None assert model._tokenizer is None - model.load(neural_sparse_doc_v3_gte_fake_model_directory) + model.load() assert model._model is not None assert model._tokenizer is not None @@ -189,9 +189,9 @@ def test_load_sets_up_special_token_ids( neural_sparse_doc_v3_gte_mock_transformers_models, ): """Test load sets up _special_token_ids list.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(neural_sparse_doc_v3_gte_fake_model_directory) - model.load(neural_sparse_doc_v3_gte_fake_model_directory) + model.load() assert model._special_token_ids is not None assert isinstance(model._special_token_ids, list) @@ -206,9 +206,9 @@ def test_load_sets_up_id_to_token_mapping( neural_sparse_doc_v3_gte_mock_transformers_models, ): """Test load creates _id_to_token mapping correctly.""" - model = OSNeuralSparseDocV3GTE() + model = OSNeuralSparseDocV3GTE(neural_sparse_doc_v3_gte_fake_model_directory) - model.load(neural_sparse_doc_v3_gte_fake_model_directory) + model.load() assert model._id_to_token is not None assert isinstance(model._id_to_token, list)