Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ COPY pyproject.toml uv.lock* ./
# Copy CLI application
COPY embeddings ./embeddings

# Install package into system python, includes "marimo-launcher" script
# Install package into system python
RUN uv pip install --system .

# Download the model and include in the Docker image
# NOTE: The env vars "TE_MODEL_URI" and "TE_MODEL_DOWNLOAD_PATH" are set here to support
# the downloading of the model into this image build, but persist in the container and
# effectively also set this as the default model.
# NOTE: The env vars "TE_MODEL_URI" and "TE_MODEL_PATH" are set here to support
# the downloading of the model during image build, but also persist in the container and
# effectively set the default model.
ENV HF_HUB_DISABLE_PROGRESS_BARS=true
ENV TE_MODEL_URI=opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
ENV TE_MODEL_DOWNLOAD_PATH=/model
ENV TE_MODEL_PATH=/model
RUN python -m embeddings.cli --verbose download-model

ENTRYPOINT ["python", "-m", "embeddings.cli"]
38 changes: 29 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ WORKSPACE=### Set to `dev` for local development, this will be set to `stage` an

```shell
TE_MODEL_URI=# HuggingFace model URI
TE_MODEL_DOWNLOAD_PATH=# Download location for model
TE_MODEL_PATH=# Path where the model will be downloaded to and loaded from
HF_HUB_DISABLE_PROGRESS_BARS=# Boolean to disable progress bars for HuggingFace model downloads; defaults to 'true' in deployed contexts
```

Expand All @@ -34,7 +34,7 @@ This CLI application is designed to create embeddings for input texts. To do th

To this end, there is a base embedding class `BaseEmbeddingModel` that is designed to be extended and customized for a particular embedding model.

Once an embedding class has been created, the preferred approach is to set env vars `TE_MODEL_URI` and `TE_MODEL_DOWNLOAD_PATH` directly in the `Dockerfile` to a) download a local snapshot of the model during image build, and b) set this model as the default for the CLI.
Once an embedding class has been created, the preferred approach is to set env vars `TE_MODEL_URI` and `TE_MODEL_PATH` directly in the `Dockerfile` to a) download a local snapshot of the model during image build, and b) set this model as the default for the CLI.

This allows invoking the CLI without specifying a model URI or local location, allowing this model to serve as the default, e.g.:

Expand All @@ -61,18 +61,38 @@ Usage: embeddings ping [OPTIONS]
```text
Usage: embeddings download-model [OPTIONS]

Download a model from HuggingFace and save as zip file.
Download a model from HuggingFace and save locally.

Options:
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name') [required]
--output PATH Output path for zipped model (e.g., '/path/to/model.zip')
[required]
--help Show this message and exit.
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name')
[required]
--model-path PATH Path where the model will be downloaded to and loaded
from, e.g. '/path/to/model'. [required]
--help Show this message and exit.
```

### `create-embeddings`
### `test-model-load`
```text
TODO...
Usage: embeddings test-model-load [OPTIONS]

Test loading of embedding class and local model based on env vars.

In a deployed context, the following env vars are expected: -
TE_MODEL_URI - TE_MODEL_PATH

With these set, the embedding class should be registered successfully and
initialized, and the model loaded from a local copy.

This CLI command is NOT used during normal workflows. This is used primarily
during development and after model downloading/loading changes to ensure the
model loads correctly.
Comment on lines +86 to +88
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated docstring for this CLI command @ehanson8 😅

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!


Options:
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name')
[required]
--model-path PATH Path where the model will be downloaded to and loaded
from, e.g. '/path/to/model'. [required]
--help Show this message and exit.
```


119 changes: 78 additions & 41 deletions embeddings/cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import functools
import logging
import os
import time
from collections.abc import Callable
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING

import click

Expand All @@ -13,21 +13,8 @@

logger = logging.getLogger(__name__)


def model_required(f: Callable) -> Callable:
"""Decorator for commands that require a specific model."""

@click.option(
"--model-uri",
envvar="TE_MODEL_URI",
required=True,
help="HuggingFace model URI (e.g., 'org/model-name')",
)
@functools.wraps(f)
def wrapper(*args: list, **kwargs: dict) -> Callable:
return f(*args, **kwargs)

return wrapper
if TYPE_CHECKING:
from embeddings.models.base import BaseEmbeddingModel


@click.group("embeddings")
Expand Down Expand Up @@ -60,6 +47,60 @@ def _log_command_elapsed_time() -> None:
ctx.call_on_close(_log_command_elapsed_time)


def model_required(f: Callable) -> Callable:
"""Middleware decorator for commands that require an embedding model.

This decorator adds two CLI options:
- "--model-uri": defaults to environment variable "TE_MODEL_URI"
- "--model-path": defaults to environment variable "TE_MODEL_PATH"

The decorator intercepts these parameters, uses the model URI to identify and
instantiate the appropriate embedding model class with the provided model path,
and stores the model instance in the Click context at ctx.obj["model"].

Both model_uri and model_path parameters are consumed by the decorator and not
passed to the decorated command function.
"""
Comment on lines +50 to +63
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As noted in the docstring, this decorator grew a bit in responsibility.

When applied to a CLI command, the arguments are injected, and now we get a nearly fully initialized embedding class instance back, with the model_path set for use in download() or load() if the CLI commands do either of those.

This decorator was also moved in the file, hence the git diff showing all new.


@click.option(
"--model-uri",
envvar="TE_MODEL_URI",
required=True,
help="HuggingFace model URI (e.g., 'org/model-name')",
)
@click.option(
"--model-path",
required=True,
envvar="TE_MODEL_PATH",
type=click.Path(path_type=Path),
help=(
"Path where the model will be downloaded to and loaded from, "
"e.g. '/path/to/model'."
),
)
@functools.wraps(f)
def wrapper(*args: tuple, **kwargs: dict[str, str | Path]) -> Callable:
# pop "model_uri" and "model_path" from CLI args
model_uri: str = str(kwargs.pop("model_uri"))
model_path: str | Path = str(kwargs.pop("model_path"))

# initialize embedding class
model_class = get_model_class(str(model_uri))
model: BaseEmbeddingModel = model_class(model_path)
logger.info(
f"Embedding class '{model.__class__.__name__}' "
f"initialized from model URI '{model_uri}'."
)

# save embedding class instance to Context
ctx: click.Context = args[0] # type: ignore[assignment]
ctx.obj["model"] = model

return f(*args, **kwargs)

return wrapper


@main.command()
def ping() -> None:
"""Emit 'pong' to debug logs and stdout."""
Expand All @@ -68,53 +109,49 @@ def ping() -> None:


@main.command()
@click.pass_context
@model_required
@click.option(
"--output",
required=True,
envvar="TE_MODEL_DOWNLOAD_PATH",
type=click.Path(path_type=Path),
help="Output path for zipped model (e.g., '/path/to/model.zip')",
)
def download_model(model_uri: str, output: Path) -> None:
"""Download a model from HuggingFace and save as zip file."""
# load embedding model class
model_class = get_model_class(model_uri)
model = model_class()
def download_model(
ctx: click.Context,
) -> None:
"""Download a model from HuggingFace and save locally."""
model: BaseEmbeddingModel = ctx.obj["model"]

# download model assets
logger.info(f"Downloading model: {model_uri}")
result_path = model.download(output)
logger.info(f"Downloading model: {model.model_uri}")
result_path = model.download()

message = f"Model downloaded and saved to: {result_path}"
logger.info(message)
click.echo(result_path)


@main.command()
def test_model_load() -> None:
@click.pass_context
@model_required
def test_model_load(ctx: click.Context) -> None:
"""Test loading of embedding class and local model based on env vars.

In a deployed context, the following env vars are expected:
- TE_MODEL_URI
- TE_MODEL_DOWNLOAD_PATH
- TE_MODEL_PATH

With these set, the embedding class should be registered successfully and initialized,
and the model loaded from a local copy.
"""
# load embedding model class
model_class = get_model_class(os.environ["TE_MODEL_URI"])
model = model_class()

model.load(os.environ["TE_MODEL_DOWNLOAD_PATH"])
This CLI command is NOT used during normal workflows. This is used primarily
during development and after model downloading/loading changes to ensure the model
loads correctly.
"""
model: BaseEmbeddingModel = ctx.obj["model"]
model.load()
click.echo("OK")


@main.command()
@click.pass_context
@model_required
def create_embeddings(_model_uri: str) -> None:
# TODO: docstring # noqa: FIX002
raise NotImplementedError
def create_embedding(ctx: click.Context) -> None:
"""Create a single embedding for a single input text."""


if __name__ == "__main__": # pragma: no cover
Expand Down
24 changes: 14 additions & 10 deletions embeddings/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ class BaseEmbeddingModel(ABC):

MODEL_URI: str # Type hint to document the requirement

def __init__(self, model_path: str | Path) -> None:
"""Initialize the embedding model with a model path.

Args:
model_path: Path where the model will be downloaded to and loaded from.
"""
self.model_path = Path(model_path)
Comment on lines +15 to +21
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most changes can be traced back to this change. When we instantiate an embedding class, we pass a model_path always, which can be used for any downloading or loading of the model.

This removes the need to pass around paths, and we can assume it's always required.


def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105
super().__init_subclass__(**kwargs)

Expand All @@ -28,17 +36,13 @@ def model_uri(self) -> str:
return self.MODEL_URI

@abstractmethod
def download(self, output_path: str | Path) -> Path:
"""Download and prepare model, saving to output_path.
def download(self) -> Path:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯 for the rename to model_path, more descriptive and just fits better throughout the code

"""Download and prepare model, saving to self.model_path.

Args:
output_path: Path where the model zip should be saved.
Returns:
Path where the model was saved.
"""

@abstractmethod
def load(self, model_path: str | Path) -> None:
"""Load model from local, downloaded instance.

Args:
model_path: Path of local model directory.
"""
def load(self) -> None:
"""Load model from self.model_path."""
56 changes: 28 additions & 28 deletions embeddings/models/os_neural_sparse_doc_v3_gte.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,27 @@ class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):

MODEL_URI = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"

def __init__(self) -> None:
"""Initialize the model."""
super().__init__()
def __init__(self, model_path: str | Path) -> None:
"""Initialize the model.

Args:
model_path: Path where the model will be downloaded to and loaded from.
"""
super().__init__(model_path)
self._model: PreTrainedModel | None = None
self._tokenizer: DistilBertTokenizerFast | None = None
self._special_token_ids: list | None = None
self._id_to_token: list | None = None

def download(self, output_path: str | Path) -> Path:
"""Download and prepare model, saving to output_path.
def download(self) -> Path:
"""Download and prepare model, saving to self.model_path.

Args:
output_path: Path where the model should be saved.
Returns:
Path where the model was saved.
"""
start_time = time.perf_counter()

output_path = Path(output_path)
logger.info(f"Downloading model: {self.model_uri}, saving to: {output_path}.")
logger.info(f"Downloading model: {self.model_uri}, saving to: {self.model_path}.")

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
Expand All @@ -60,19 +63,21 @@ def download(self, output_path: str | Path) -> Path:
self._patch_local_model_with_alibaba_new_impl(temp_path)

# compress model directory as a zip file
if output_path.suffix.lower() == ".zip":
if self.model_path.suffix.lower() == ".zip":
logger.debug("Creating zip file of model contents.")
shutil.make_archive(str(output_path.with_suffix("")), "zip", temp_path)
shutil.make_archive(
str(self.model_path.with_suffix("")), "zip", temp_path
)

# copy to output directory without zipping
else:
logger.debug(f"Copying model contents to {output_path}")
if output_path.exists():
shutil.rmtree(output_path)
shutil.copytree(temp_path, output_path)
logger.debug(f"Copying model contents to {self.model_path}")
if self.model_path.exists():
shutil.rmtree(self.model_path)
shutil.copytree(temp_path, self.model_path)

logger.info(f"Model downloaded successfully, {time.perf_counter() - start_time}s")
return output_path
return self.model_path

def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> None:
"""Patch downloaded model with required assets from Alibaba-NLP/new-impl.
Expand Down Expand Up @@ -124,28 +129,23 @@ def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> Non

logger.debug("Dependency model Alibaba-NLP/new-impl downloaded and used.")

def load(self, model_path: str | Path) -> None:
"""Load the model from the specified path.

Args:
model_path: Path to the model directory.
"""
def load(self) -> None:
"""Load the model from self.model_path."""
start_time = time.perf_counter()
logger.info(f"Loading model from: {model_path}")
model_path = Path(model_path)
logger.info(f"Loading model from: {self.model_path}")

# ensure model exists locally
if not model_path.exists():
raise FileNotFoundError(f"Model not found at path: {model_path}")
if not self.model_path.exists():
raise FileNotFoundError(f"Model not found at path: {self.model_path}")

# load local model and tokenizer
self._model = AutoModelForMaskedLM.from_pretrained(
model_path,
self.model_path,
trust_remote_code=True,
local_files_only=True,
)
self._tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call]
model_path,
self.model_path,
local_files_only=True,
)

Expand Down
Loading