Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions model2vec/distill/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from huggingface_hub.hf_api import model_info
from skeletoken import TokenizerModel
from skeletoken.external.transformers import reshape_embeddings
from transformers import AutoModel, AutoTokenizer
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from model2vec.distill.inference import PCADimType, PoolingMode, create_embeddings, post_process_embeddings
from model2vec.distill.utils import select_optimal_device
Expand Down
3 changes: 3 additions & 0 deletions model2vec/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def clean_and_create_vocabulary(
logger.warning(
f"Token '{token}' was split into multiple tokens after preprocessing: [{split_into}], adding it as a multi-word token."
)
if token in model.vocabulary:
# If the unprocessed token is (incorrectly) present in the vocabulary, remove it.
model = model.remove_token_from_vocabulary(token)
added_tokens_to_add.append(token)
continue
token = preprocessed[0]
Expand Down
7 changes: 4 additions & 3 deletions tests/test_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import pytest
from pytest import LogCaptureFixture
from skeletoken import TokenizerModel
from transformers import BertTokenizerFast
from transformers import BertTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from model2vec.distill.distillation import distill, distill_from_model
from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings
Expand All @@ -38,6 +38,7 @@
(None, None, 1e-4), # No PCA, SIF on
(None, 0.9, 1e-4), # PCA as float (variance), SIF on
(["star wars"], 8, None), # Multiword vocabulary
(["..."], 8, None), # Crashing multiword vocabulary
],
)
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
Expand Down Expand Up @@ -92,7 +93,7 @@ def test_distill_from_model(
def test_distill_removal_pattern_all_tokens(
mock_auto_model: MagicMock,
mock_model_info: MagicMock,
mock_berttokenizer: BertTokenizerFast,
mock_berttokenizer: BertTokenizer,
mock_transformer: PreTrainedModel,
) -> None:
"""Test the removal pattern."""
Expand Down
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading