diff --git a/model2vec/hf_utils.py b/model2vec/hf_utils.py index 0fa35148..b9e853fa 100644 --- a/model2vec/hf_utils.py +++ b/model2vec/hf_utils.py @@ -84,7 +84,7 @@ def _create_model_card( def load_pretrained( - folder_or_repo_path: str | Path, token: str | None = None + folder_or_repo_path: str | Path, token: str | None = None, from_sentence_transformers: bool = False ) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]: """ Loads a pretrained model from a folder. @@ -93,26 +93,31 @@ def load_pretrained( - If this is a local path, we will load from the local path. - If the local path is not found, we will attempt to load from the huggingface hub. :param token: The huggingface token to use. + :param from_sentence_transformers: Whether to load the model from a sentence transformers model. :raises: FileNotFoundError if the folder exists, but the file does not exist locally. :return: The embeddings, tokenizer, config, and metadata. """ + if from_sentence_transformers: + model_file = "0_StaticEmbedding/model.safetensors" + tokenizer_file = "0_StaticEmbedding/tokenizer.json" + config_name = "config_sentence_transformers.json" + else: + model_file = "model.safetensors" + tokenizer_file = "tokenizer.json" + config_name = "config.json" + folder_or_repo_path = Path(folder_or_repo_path) if folder_or_repo_path.exists(): - embeddings_path = folder_or_repo_path / "model.safetensors" + embeddings_path = folder_or_repo_path / model_file if not embeddings_path.exists(): - old_embeddings_path = folder_or_repo_path / "embeddings.safetensors" - if old_embeddings_path.exists(): - logger.warning("Old embeddings file found. Please rename to `model.safetensors` and re-save.") - embeddings_path = old_embeddings_path - else: - raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}") - - config_path = folder_or_repo_path / "config.json" + raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}") + + config_path = folder_or_repo_path / config_name if not config_path.exists(): raise FileNotFoundError(f"Config file does not exist in {folder_or_repo_path}") - tokenizer_path = folder_or_repo_path / "tokenizer.json" + tokenizer_path = folder_or_repo_path / tokenizer_file if not tokenizer_path.exists(): raise FileNotFoundError(f"Tokenizer file does not exist in {folder_or_repo_path}") @@ -122,18 +127,7 @@ def load_pretrained( else: logger.info("Folder does not exist locally, attempting to use huggingface hub.") - try: - embeddings_path = huggingface_hub.hf_hub_download( - folder_or_repo_path.as_posix(), "model.safetensors", token=token - ) - except huggingface_hub.utils.EntryNotFoundError as e: - try: - embeddings_path = huggingface_hub.hf_hub_download( - folder_or_repo_path.as_posix(), "embeddings.safetensors", token=token - ) - except huggingface_hub.utils.EntryNotFoundError: - # Raise original exception. - raise e + embeddings_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), model_file, token=token) try: readme_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "README.md", token=token) @@ -142,11 +136,14 @@ def load_pretrained( logger.info("No README found in the model folder. No model card loaded.") metadata = {} - config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "config.json", token=token) - tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "tokenizer.json", token=token) + config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), config_name, token=token) + tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), tokenizer_file, token=token) opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy")) - embeddings = opened_tensor_file.get_tensor("embeddings") + if from_sentence_transformers: + embeddings = opened_tensor_file.get_tensor("embedding.weight") + else: + embeddings = opened_tensor_file.get_tensor("embeddings") tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path)) config = json.load(open(config_path)) diff --git a/model2vec/model.py b/model2vec/model.py index 8afeb220..5d60d8b7 100644 --- a/model2vec/model.py +++ b/model2vec/model.py @@ -160,12 +160,33 @@ def from_pretrained( """ from model2vec.hf_utils import load_pretrained - embeddings, tokenizer, config, metadata = load_pretrained(path, token=token) + embeddings, tokenizer, config, metadata = load_pretrained(path, token=token, from_sentence_transformers=False) return cls( embeddings, tokenizer, config, base_model_name=metadata.get("base_model"), language=metadata.get("language") ) + @classmethod + def from_sentence_transformers( + cls: type[StaticModel], + path: PathLike, + token: str | None = None, + ) -> StaticModel: + """ + Load a StaticModel trained with sentence transformers from a local path or huggingface hub path. + + NOTE: if you load a private model from the huggingface hub, you need to pass a token. + + :param path: The path to load your static model from. + :param token: The huggingface token to use. + :return: A StaticModel + """ + from model2vec.hf_utils import load_pretrained + + embeddings, tokenizer, config, _ = load_pretrained(path, token=token, from_sentence_transformers=True) + + return cls(embeddings, tokenizer, config, base_model_name=None, language=None) + def encode_as_sequence( self, sentences: list[str] | str, diff --git a/uv.lock b/uv.lock index 596a7f1a..ee17e1ec 100644 --- a/uv.lock +++ b/uv.lock @@ -160,7 +160,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -536,7 +536,7 @@ wheels = [ [[package]] name = "model2vec" -version = "0.3.4" +version = "0.3.5" source = { editable = "." } dependencies = [ { name = "jinja2" }, @@ -1627,19 +1627,19 @@ dependencies = [ { name = "jinja2" }, { name = "networkx", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -1666,7 +1666,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [