Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions model2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
:param vectors: The vectors to use.
:param tokenizer: The Transformers tokenizer to use.
:param config: Any metadata config.
:param normalize: Whether to normalize.
:param normalize: Whether to normalize the embeddings.
:param base_model_name: The used base model name. Used for creating a model card.
:param language: The language of the model. Used for creating a model card.
:raises: ValueError if the number of tokens does not match the number of vectors.
Expand Down Expand Up @@ -149,6 +149,7 @@ def from_pretrained(
cls: type[StaticModel],
path: PathLike,
token: str | None = None,
normalize: bool | None = None,
) -> StaticModel:
"""
Load a StaticModel from a local path or huggingface hub path.
Expand All @@ -157,21 +158,28 @@ def from_pretrained(

:param path: The path to load your static model from.
:param token: The huggingface token to use.
:param normalize: Whether to normalize the embeddings.
:return: A StaticModel
"""
from model2vec.hf_utils import load_pretrained

embeddings, tokenizer, config, metadata = load_pretrained(path, token=token, from_sentence_transformers=False)

return cls(
embeddings, tokenizer, config, base_model_name=metadata.get("base_model"), language=metadata.get("language")
embeddings,
tokenizer,
config,
normalize=normalize,
base_model_name=metadata.get("base_model"),
language=metadata.get("language"),
)

@classmethod
def from_sentence_transformers(
cls: type[StaticModel],
path: PathLike,
token: str | None = None,
normalize: bool | None = None,
) -> StaticModel:
"""
Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.
Expand All @@ -180,13 +188,14 @@ def from_sentence_transformers(

:param path: The path to load your static model from.
:param token: The huggingface token to use.
:param normalize: Whether to normalize the embeddings.
:return: A StaticModel
"""
from model2vec.hf_utils import load_pretrained

embeddings, tokenizer, config, _ = load_pretrained(path, token=token, from_sentence_transformers=True)

return cls(embeddings, tokenizer, config, base_model_name=None, language=None)
return cls(embeddings, tokenizer, config, normalize=normalize, base_model_name=None, language=None)

def encode_as_sequence(
self,
Expand Down