1+ import functools
12import logging
23import time
4+ from collections .abc import Callable
35from datetime import timedelta
46from pathlib import Path
57
1113logger = logging .getLogger (__name__ )
1214
1315
16+ def model_required (f : Callable ) -> Callable :
17+ """Decorator for commands that require a specific model."""
18+
19+ @click .option (
20+ "--model-uri" ,
21+ envvar = "TE_MODEL_URI" ,
22+ required = True ,
23+ help = "HuggingFace model URI (e.g., 'org/model-name')" ,
24+ )
25+ @functools .wraps (f )
26+ def wrapper (* args : list , ** kwargs : dict ) -> Callable :
27+ return f (* args , ** kwargs )
28+
29+ return wrapper
30+
31+
1432@click .group ("embeddings" )
1533@click .option (
1634 "-v" ,
@@ -49,22 +67,19 @@ def ping() -> None:
4967
5068
5169@main .command ()
52- @click .option (
53- "--model-uri" ,
54- required = True ,
55- help = "HuggingFace model URI (e.g., 'org/model-name')" ,
56- )
70+ @model_required
5771@click .option (
5872 "--output" ,
5973 required = True ,
74+ envvar = "TE_MODEL_DOWNLOAD_PATH" ,
6075 type = click .Path (path_type = Path ),
6176 help = "Output path for zipped model (e.g., '/path/to/model.zip')" ,
6277)
6378def download_model (model_uri : str , output : Path ) -> None :
6479 """Download a model from HuggingFace and save as zip file."""
6580 # load embedding model class
6681 model_class = get_model_class (model_uri )
67- model = model_class (model_uri )
82+ model = model_class ()
6883
6984 # download model assets
7085 logger .info (f"Downloading model: { model_uri } " )
@@ -76,11 +91,7 @@ def download_model(model_uri: str, output: Path) -> None:
7691
7792
7893@main .command ()
79- @click .option (
80- "--model-uri" ,
81- required = True ,
82- help = "HuggingFace model URI (e.g., 'org/model-name')" ,
83- )
94+ @model_required
8495def create_embeddings (_model_uri : str ) -> None :
8596 # TODO: docstring # noqa: FIX002
8697 raise NotImplementedError
0 commit comments