diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py index 1c261538..4da4e503 100644 --- a/model2vec/distill/distillation.py +++ b/model2vec/distill/distillation.py @@ -13,8 +13,8 @@ from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast from model2vec.distill.inference import ( - create_output_embeddings_from_model_name, - create_output_embeddings_from_model_name_and_tokens, + create_output_embeddings_from_model, + create_output_embeddings_from_model_and_tokens, ) from model2vec.distill.tokenizer import add_tokens, preprocess_vocabulary, remove_tokens from model2vec.distill.utils import select_optimal_device @@ -88,7 +88,7 @@ def distill_from_model( tokens: list[str] = [] if use_subword: # Create the subword embeddings. - tokens, embeddings = create_output_embeddings_from_model_name(model=model, tokenizer=tokenizer, device=device) + tokens, embeddings = create_output_embeddings_from_model(model=model, tokenizer=tokenizer, device=device) new_tokenizer, embeddings = _remove_tokens_and_embeddings(tokenizer, token_remove_pattern, tokens, embeddings) else: # We need to keep the unk token in the tokenizer. @@ -111,7 +111,7 @@ def distill_from_model( # Only create embeddings if we have tokens to add. if cleaned_vocabulary: # Create the embeddings. 
- _, token_embeddings = create_output_embeddings_from_model_name_and_tokens( + _, token_embeddings = create_output_embeddings_from_model_and_tokens( model=model, tokenizer=tokenizer, tokens=cleaned_vocabulary, diff --git a/model2vec/distill/inference.py b/model2vec/distill/inference.py index a8dbfcc9..6b63519f 100644 --- a/model2vec/distill/inference.py +++ b/model2vec/distill/inference.py @@ -24,18 +24,19 @@ class ModulewithWeights(Protocol): weight: torch.nn.Parameter -def create_output_embeddings_from_model_name_and_tokens( +def create_output_embeddings_from_model_and_tokens( model: PreTrainedModel, tokenizer: PreTrainedTokenizer, tokens: list[str], device: str, ) -> tuple[list[str], np.ndarray]: """ - Create output embeddings for a bunch of tokens from a model name. + Create output embeddings for a bunch of tokens using a pretrained model. - It does a forward pass for all tokens passed in tokens. + It does a forward pass for all tokens passed in `tokens`. - :param model: The model name to use. + :param model: The model to use. + This should be a transformers model. :param tokenizer: The tokenizer to use. :param tokens: The tokens to use. :param device: The torch device to use. @@ -99,17 +100,18 @@ def _encode_mean_using_model(model: PreTrainedModel, tokenizer: PreTrainedTokeni return result / divisor[:, None] -def create_output_embeddings_from_model_name( +def create_output_embeddings_from_model( model: PreTrainedModel, tokenizer: PreTrainedTokenizer, device: str, ) -> tuple[list[str], np.ndarray]: """ - Create output embeddings for a bunch of tokens from a model name. + Create output embeddings for a bunch of tokens using a pretrained model. - It does a forward pass for all ids in the tokenizer. + It does a forward pass for all tokens in the tokenizer vocabulary. - :param model: The model name to use. + :param model: The model to use. + This should be a transformers model. :param tokenizer: The tokenizer to use. :param device: The torch device to use. 
:return: The tokens and output embeddings. diff --git a/uv.lock b/uv.lock index ee17e1ec..951b309e 100644 --- a/uv.lock +++ b/uv.lock @@ -536,7 +536,7 @@ wheels = [ [[package]] name = "model2vec" -version = "0.3.5" +version = "0.3.6" source = { editable = "." } dependencies = [ { name = "jinja2" },