diff --git a/model2vec/train/README.md b/model2vec/train/README.md index 2d7aad2b..e48acbb2 100644 --- a/model2vec/train/README.md +++ b/model2vec/train/README.md @@ -22,7 +22,7 @@ from model2vec.train import StaticModelForClassification # From a distilled model distilled_model = distill("baai/bge-base-en-v1.5") -classifier = StaticModelForClassification.from_static_model(distilled_model) +classifier = StaticModelForClassification.from_static_model(model=distilled_model) # From a pre-trained model: potion is the default classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32m") diff --git a/model2vec/train/base.py b/model2vec/train/base.py index db1d5394..60749fe0 100644 --- a/model2vec/train/base.py +++ b/model2vec/train/base.py @@ -45,14 +45,14 @@ def construct_head(self) -> nn.Sequential: @classmethod def from_pretrained( - cls: type[ModelType], out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any + cls: type[ModelType], *, out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any ) -> ModelType: """Load the model from a pretrained model2vec model.""" model = StaticModel.from_pretrained(model_name) - return cls.from_static_model(model, out_dim, **kwargs) + return cls.from_static_model(model=model, out_dim=out_dim, **kwargs) @classmethod - def from_static_model(cls: type[ModelType], model: StaticModel, out_dim: int = 2, **kwargs: Any) -> ModelType: + def from_static_model(cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, **kwargs: Any) -> ModelType: """Load the model from a static model.""" model.embedding = np.nan_to_num(model.embedding) embeddings_converted = torch.from_numpy(model.embedding) diff --git a/tests/test_trainable.py b/tests/test_trainable.py index 9add8b25..ea0e354d 100644 --- a/tests/test_trainable.py +++ b/tests/test_trainable.py @@ -41,7 +41,7 @@ def test_init_base_class(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> def test_init_base_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: """Test initializion from a static model.""" model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer) - s = FinetunableStaticModel.from_static_model(model) + s = FinetunableStaticModel.from_static_model(model=model) assert s.vectors.shape == mock_vectors.shape assert s.w.shape[0] == mock_vectors.shape[0] @@ -55,7 +55,7 @@ def test_init_base_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenize def test_init_classifier_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: """Test initializion from a static model.""" model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer) - s = StaticModelForClassification.from_static_model(model) + s = StaticModelForClassification.from_static_model(model=model) assert s.vectors.shape == mock_vectors.shape assert s.w.shape[0] == mock_vectors.shape[0]