-
Notifications
You must be signed in to change notification settings - Fork 0
USE 112 - refactor model load #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,10 @@ | ||
| import functools | ||
| import logging | ||
| import os | ||
| import time | ||
| from collections.abc import Callable | ||
| from datetime import timedelta | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| import click | ||
|
|
||
|
|
@@ -13,21 +13,8 @@ | |
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def model_required(f: Callable) -> Callable: | ||
| """Decorator for commands that require a specific model.""" | ||
|
|
||
| @click.option( | ||
| "--model-uri", | ||
| envvar="TE_MODEL_URI", | ||
| required=True, | ||
| help="HuggingFace model URI (e.g., 'org/model-name')", | ||
| ) | ||
| @functools.wraps(f) | ||
| def wrapper(*args: list, **kwargs: dict) -> Callable: | ||
| return f(*args, **kwargs) | ||
|
|
||
| return wrapper | ||
| if TYPE_CHECKING: | ||
| from embeddings.models.base import BaseEmbeddingModel | ||
|
|
||
|
|
||
| @click.group("embeddings") | ||
|
|
@@ -60,6 +47,60 @@ def _log_command_elapsed_time() -> None: | |
| ctx.call_on_close(_log_command_elapsed_time) | ||
|
|
||
|
|
||
| def model_required(f: Callable) -> Callable: | ||
| """Middleware decorator for commands that require an embedding model. | ||
|
|
||
| This decorator adds two CLI options: | ||
| - "--model-uri": defaults to environment variable "TE_MODEL_URI" | ||
| - "--model-path": defaults to environment variable "TE_MODEL_PATH" | ||
|
|
||
| The decorator intercepts these parameters, uses the model URI to identify and | ||
| instantiate the appropriate embedding model class with the provided model path, | ||
| and stores the model instance in the Click context at ctx.obj["model"]. | ||
|
|
||
| Both model_uri and model_path parameters are consumed by the decorator and not | ||
| passed to the decorated command function. | ||
| """ | ||
|
Comment on lines
+50
to
+63
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As noted in the docstring, this decorator grew a bit in responsibility. When applied to a CLI command, the arguments are injected, and now we get a nearly fully initialized embedding class instance back, with the This decorator was also moved in the file, hence the git diff showing all new. |
||
|
|
||
| @click.option( | ||
| "--model-uri", | ||
| envvar="TE_MODEL_URI", | ||
| required=True, | ||
| help="HuggingFace model URI (e.g., 'org/model-name')", | ||
| ) | ||
| @click.option( | ||
| "--model-path", | ||
| required=True, | ||
| envvar="TE_MODEL_PATH", | ||
| type=click.Path(path_type=Path), | ||
| help=( | ||
| "Path where the model will be downloaded to and loaded from, " | ||
| "e.g. '/path/to/model'." | ||
| ), | ||
| ) | ||
| @functools.wraps(f) | ||
| def wrapper(*args: tuple, **kwargs: dict[str, str | Path]) -> Callable: | ||
| # pop "model_uri" and "model_path" from CLI args | ||
| model_uri: str = str(kwargs.pop("model_uri")) | ||
| model_path: str | Path = str(kwargs.pop("model_path")) | ||
|
|
||
| # initialize embedding class | ||
| model_class = get_model_class(str(model_uri)) | ||
| model: BaseEmbeddingModel = model_class(model_path) | ||
| logger.info( | ||
| f"Embedding class '{model.__class__.__name__}' " | ||
| f"initialized from model URI '{model_uri}'." | ||
| ) | ||
|
|
||
| # save embedding class instance to Context | ||
| ctx: click.Context = args[0] # type: ignore[assignment] | ||
| ctx.obj["model"] = model | ||
|
|
||
| return f(*args, **kwargs) | ||
|
|
||
| return wrapper | ||
|
|
||
|
|
||
| @main.command() | ||
| def ping() -> None: | ||
| """Emit 'pong' to debug logs and stdout.""" | ||
|
|
@@ -68,53 +109,49 @@ def ping() -> None: | |
|
|
||
|
|
||
| @main.command() | ||
| @click.pass_context | ||
| @model_required | ||
| @click.option( | ||
| "--output", | ||
| required=True, | ||
| envvar="TE_MODEL_DOWNLOAD_PATH", | ||
| type=click.Path(path_type=Path), | ||
| help="Output path for zipped model (e.g., '/path/to/model.zip')", | ||
| ) | ||
| def download_model(model_uri: str, output: Path) -> None: | ||
| """Download a model from HuggingFace and save as zip file.""" | ||
| # load embedding model class | ||
| model_class = get_model_class(model_uri) | ||
| model = model_class() | ||
| def download_model( | ||
| ctx: click.Context, | ||
| ) -> None: | ||
| """Download a model from HuggingFace and save locally.""" | ||
| model: BaseEmbeddingModel = ctx.obj["model"] | ||
|
|
||
| # download model assets | ||
| logger.info(f"Downloading model: {model_uri}") | ||
| result_path = model.download(output) | ||
| logger.info(f"Downloading model: {model.model_uri}") | ||
| result_path = model.download() | ||
|
|
||
| message = f"Model downloaded and saved to: {result_path}" | ||
| logger.info(message) | ||
| click.echo(result_path) | ||
|
|
||
|
|
||
| @main.command() | ||
| def test_model_load() -> None: | ||
| @click.pass_context | ||
| @model_required | ||
| def test_model_load(ctx: click.Context) -> None: | ||
| """Test loading of embedding class and local model based on env vars. | ||
|
|
||
| In a deployed context, the following env vars are expected: | ||
| - TE_MODEL_URI | ||
| - TE_MODEL_DOWNLOAD_PATH | ||
| - TE_MODEL_PATH | ||
|
|
||
| With these set, the embedding class should be registered successfully and initialized, | ||
| and the model loaded from a local copy. | ||
| """ | ||
| # load embedding model class | ||
| model_class = get_model_class(os.environ["TE_MODEL_URI"]) | ||
| model = model_class() | ||
|
|
||
| model.load(os.environ["TE_MODEL_DOWNLOAD_PATH"]) | ||
| This CLI command is NOT used during normal workflows. This is used primary | ||
| during development and after model downloading/loading changes to ensure the model | ||
| loads correctly. | ||
| """ | ||
| model: BaseEmbeddingModel = ctx.obj["model"] | ||
| model.load() | ||
| click.echo("OK") | ||
|
|
||
|
|
||
| @main.command() | ||
| @click.pass_context | ||
| @model_required | ||
| def create_embeddings(_model_uri: str) -> None: | ||
| # TODO: docstring # noqa: FIX002 | ||
| raise NotImplementedError | ||
| def create_embedding(ctx: click.Context) -> None: | ||
| """Create a single embedding for a single input text.""" | ||
|
|
||
|
|
||
| if __name__ == "__main__": # pragma: no cover | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,14 @@ class BaseEmbeddingModel(ABC): | |
|
|
||
| MODEL_URI: str # Type hint to document the requirement | ||
|
|
||
| def __init__(self, model_path: str | Path) -> None: | ||
| """Initialize the embedding model with a model path. | ||
|
|
||
| Args: | ||
| model_path: Path where the model will be downloaded to and loaded from. | ||
| """ | ||
| self.model_path = Path(model_path) | ||
|
Comment on lines
+15
to
+21
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Most changes can be traced back to this change. When we instantiate an embedding class, we pass a This removes the need to pass around paths, and we can assume it's always required. |
||
|
|
||
| def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105 | ||
| super().__init_subclass__(**kwargs) | ||
|
|
||
|
|
@@ -28,17 +36,13 @@ def model_uri(self) -> str: | |
| return self.MODEL_URI | ||
|
|
||
| @abstractmethod | ||
| def download(self, output_path: str | Path) -> Path: | ||
| """Download and prepare model, saving to output_path. | ||
| def download(self) -> Path: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💯 for the rename to |
||
| """Download and prepare model, saving to self.model_path. | ||
|
|
||
| Args: | ||
| output_path: Path where the model zip should be saved. | ||
| Returns: | ||
| Path where the model was saved. | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def load(self, model_path: str | Path) -> None: | ||
| """Load model from local, downloaded instance. | ||
|
|
||
| Args: | ||
| model_path: Path of local model directory. | ||
| """ | ||
| def load(self) -> None: | ||
| """Load model from self.model_path.""" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated docstring for this CLI command @ehanson8 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!