diff --git a/model2vec/model.py b/model2vec/model.py index 5961e180..39ca64a1 100644 --- a/model2vec/model.py +++ b/model2vec/model.py @@ -35,7 +35,7 @@ def __init__( :param vectors: The vectors to use. :param tokenizer: The Transformers tokenizer to use. :param config: Any metadata config. - :param normalize: Whether to normalize. + :param normalize: Whether to normalize the embeddings. :param base_model_name: The used base model name. Used for creating a model card. :param language: The language of the model. Used for creating a model card. :raises: ValueError if the number of tokens does not match the number of vectors. @@ -149,6 +149,7 @@ def from_pretrained( cls: type[StaticModel], path: PathLike, token: str | None = None, + normalize: bool | None = None, ) -> StaticModel: """ Load a StaticModel from a local path or huggingface hub path. @@ -157,6 +158,7 @@ def from_pretrained( :param path: The path to load your static model from. :param token: The huggingface token to use. + :param normalize: Whether to normalize the embeddings. :return: A StaticModel """ from model2vec.hf_utils import load_pretrained @@ -164,7 +166,12 @@ def from_pretrained( embeddings, tokenizer, config, metadata = load_pretrained(path, token=token, from_sentence_transformers=False) return cls( - embeddings, tokenizer, config, base_model_name=metadata.get("base_model"), language=metadata.get("language") + embeddings, + tokenizer, + config, + normalize=normalize, + base_model_name=metadata.get("base_model"), + language=metadata.get("language"), ) @classmethod @@ -172,6 +179,7 @@ def from_sentence_transformers( cls: type[StaticModel], path: PathLike, token: str | None = None, + normalize: bool | None = None, ) -> StaticModel: """ Load a StaticModel trained with sentence transformers from a local path or huggingface hub path. @@ -180,13 +188,14 @@ def from_sentence_transformers( :param path: The path to load your static model from. :param token: The huggingface token to use. + :param normalize: Whether to normalize the embeddings. :return: A StaticModel """ from model2vec.hf_utils import load_pretrained embeddings, tokenizer, config, _ = load_pretrained(path, token=token, from_sentence_transformers=True) - return cls(embeddings, tokenizer, config, base_model_name=None, language=None) + return cls(embeddings, tokenizer, config, normalize=normalize, base_model_name=None, language=None) def encode_as_sequence( self,