Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions model2vec/distill/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast

from model2vec.distill.inference import (
create_output_embeddings_from_model_name,
create_output_embeddings_from_model_name_and_tokens,
create_output_embeddings_from_model,
create_output_embeddings_from_model_and_tokens,
)
from model2vec.distill.tokenizer import add_tokens, preprocess_vocabulary, remove_tokens
from model2vec.distill.utils import select_optimal_device
Expand Down Expand Up @@ -88,7 +88,7 @@ def distill_from_model(
tokens: list[str] = []
if use_subword:
# Create the subword embeddings.
tokens, embeddings = create_output_embeddings_from_model_name(model=model, tokenizer=tokenizer, device=device)
tokens, embeddings = create_output_embeddings_from_model(model=model, tokenizer=tokenizer, device=device)
new_tokenizer, embeddings = _remove_tokens_and_embeddings(tokenizer, token_remove_pattern, tokens, embeddings)
else:
# We need to keep the unk token in the tokenizer.
Expand All @@ -111,7 +111,7 @@ def distill_from_model(
# Only create embeddings if we have tokens to add.
if cleaned_vocabulary:
# Create the embeddings.
_, token_embeddings = create_output_embeddings_from_model_name_and_tokens(
_, token_embeddings = create_output_embeddings_from_model_and_tokens(
model=model,
tokenizer=tokenizer,
tokens=cleaned_vocabulary,
Expand Down
18 changes: 10 additions & 8 deletions model2vec/distill/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@ class ModulewithWeights(Protocol):
weight: torch.nn.Parameter


def create_output_embeddings_from_model_name_and_tokens(
def create_output_embeddings_from_model_and_tokens(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
tokens: list[str],
device: str,
) -> tuple[list[str], np.ndarray]:
"""
Create output embeddings for a bunch of tokens from a model name.
Create output embeddings for a bunch of tokens using a pretrained model.

It does a forward pass for all tokens passed in tokens.
It does a forward pass for all tokens passed in `tokens`.

:param model: The model name to use.
:param model: The model to use.
This should be a transformers model.
:param tokenizer: The tokenizer to use.
:param tokens: The tokens to use.
:param device: The torch device to use.
Expand Down Expand Up @@ -99,17 +100,18 @@ def _encode_mean_using_model(model: PreTrainedModel, tokenizer: PreTrainedTokeni
return result / divisor[:, None]


def create_output_embeddings_from_model_name(
def create_output_embeddings_from_model(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
device: str,
) -> tuple[list[str], np.ndarray]:
"""
Create output embeddings for a bunch of tokens from a model name.
Create output embeddings for a bunch of tokens using a pretrained model.

It does a forward pass for all ids in the tokenizer.
It does a forward pass for all tokens in the tokenizer vocabulary.

:param model: The model name to use.
:param model: The model to use.
This should be a transformers model.
:param tokenizer: The tokenizer to use.
:param device: The torch device to use.
:return: The tokens and output embeddings.
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading