Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 23 additions & 26 deletions model2vec/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _create_model_card(


def load_pretrained(
folder_or_repo_path: str | Path, token: str | None = None
folder_or_repo_path: str | Path, token: str | None = None, from_sentence_transformers: bool = False
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
"""
Loads a pretrained model from a folder.
Expand All @@ -93,26 +93,31 @@ def load_pretrained(
- If this is a local path, we will load from the local path.
- If the local path is not found, we will attempt to load from the huggingface hub.
:param token: The huggingface token to use.
:param from_sentence_transformers: Whether to load the model from a sentence transformers model.
:raises: FileNotFoundError if the folder exists, but the file does not exist locally.
:return: The embeddings, tokenizer, config, and metadata.

"""
if from_sentence_transformers:
model_file = "0_StaticEmbedding/model.safetensors"
tokenizer_file = "0_StaticEmbedding/tokenizer.json"
config_name = "config_sentence_transformers.json"
else:
model_file = "model.safetensors"
tokenizer_file = "tokenizer.json"
config_name = "config.json"

folder_or_repo_path = Path(folder_or_repo_path)
if folder_or_repo_path.exists():
embeddings_path = folder_or_repo_path / "model.safetensors"
embeddings_path = folder_or_repo_path / model_file
if not embeddings_path.exists():
old_embeddings_path = folder_or_repo_path / "embeddings.safetensors"
if old_embeddings_path.exists():
logger.warning("Old embeddings file found. Please rename to `model.safetensors` and re-save.")
embeddings_path = old_embeddings_path
else:
raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}")

config_path = folder_or_repo_path / "config.json"
raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}")

config_path = folder_or_repo_path / config_name
if not config_path.exists():
raise FileNotFoundError(f"Config file does not exist in {folder_or_repo_path}")

tokenizer_path = folder_or_repo_path / "tokenizer.json"
tokenizer_path = folder_or_repo_path / tokenizer_file
if not tokenizer_path.exists():
raise FileNotFoundError(f"Tokenizer file does not exist in {folder_or_repo_path}")

Expand All @@ -122,18 +127,7 @@ def load_pretrained(

else:
logger.info("Folder does not exist locally, attempting to use huggingface hub.")
try:
embeddings_path = huggingface_hub.hf_hub_download(
folder_or_repo_path.as_posix(), "model.safetensors", token=token
)
except huggingface_hub.utils.EntryNotFoundError as e:
try:
embeddings_path = huggingface_hub.hf_hub_download(
folder_or_repo_path.as_posix(), "embeddings.safetensors", token=token
)
except huggingface_hub.utils.EntryNotFoundError:
# Raise original exception.
raise e
embeddings_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), model_file, token=token)

try:
readme_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "README.md", token=token)
Expand All @@ -142,11 +136,14 @@ def load_pretrained(
logger.info("No README found in the model folder. No model card loaded.")
metadata = {}

config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "config.json", token=token)
tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "tokenizer.json", token=token)
config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), config_name, token=token)
tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), tokenizer_file, token=token)

opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
embeddings = opened_tensor_file.get_tensor("embeddings")
if from_sentence_transformers:
embeddings = opened_tensor_file.get_tensor("embedding.weight")
else:
embeddings = opened_tensor_file.get_tensor("embeddings")

tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
config = json.load(open(config_path))
Expand Down
23 changes: 22 additions & 1 deletion model2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,33 @@ def from_pretrained(
"""
from model2vec.hf_utils import load_pretrained

embeddings, tokenizer, config, metadata = load_pretrained(path, token=token)
embeddings, tokenizer, config, metadata = load_pretrained(path, token=token, from_sentence_transformers=False)

return cls(
embeddings, tokenizer, config, base_model_name=metadata.get("base_model"), language=metadata.get("language")
)

@classmethod
def from_sentence_transformers(
    cls: type[StaticModel],
    path: PathLike,
    token: str | None = None,
) -> StaticModel:
    """
    Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.

    NOTE: if you load a private model from the huggingface hub, you need to pass a token.

    :param path: The path to load your static model from.
    :param token: The huggingface token to use.
    :return: A StaticModel
    """
    # Local import, mirroring the pattern used by `from_pretrained`.
    from model2vec.hf_utils import load_pretrained

    loaded = load_pretrained(path, token=token, from_sentence_transformers=True)
    vectors, tok, cfg, _metadata = loaded  # metadata is discarded for sentence-transformers checkpoints

    # No base-model / language metadata is available for sentence-transformers models.
    return cls(vectors, tok, cfg, base_model_name=None, language=None)

def encode_as_sequence(
self,
sentences: list[str] | str,
Expand Down
30 changes: 15 additions & 15 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading