diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py index a6df4cb..8bd6a5a 100644 --- a/model2vec/distill/distillation.py +++ b/model2vec/distill/distillation.py @@ -9,9 +9,8 @@ from huggingface_hub.hf_api import model_info from skeletoken import TokenizerModel from skeletoken.external.transformers import reshape_embeddings -from transformers import AutoModel, AutoTokenizer +from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast from transformers.modeling_utils import PreTrainedModel -from transformers.tokenization_utils_fast import PreTrainedTokenizerFast from model2vec.distill.inference import PCADimType, PoolingMode, create_embeddings, post_process_embeddings from model2vec.distill.utils import select_optimal_device diff --git a/model2vec/tokenizer/tokenizer.py b/model2vec/tokenizer/tokenizer.py index 0d4f394..d78bb10 100644 --- a/model2vec/tokenizer/tokenizer.py +++ b/model2vec/tokenizer/tokenizer.py @@ -53,6 +53,9 @@ def clean_and_create_vocabulary( logger.warning( f"Token '{token}' was split into multiple tokens after preprocessing: [{split_into}], adding it as a multi-word token." ) + if token in model.vocabulary: + # If the unprocessed token (incorrectly) is in the vocabulary, we should remove it. 
+ model = model.remove_token_from_vocabulary(token) added_tokens_to_add.append(token) continue token = preprocessed[0] diff --git a/tests/test_distillation.py b/tests/test_distillation.py index 7aea79d..9828a6f 100644 --- a/tests/test_distillation.py +++ b/tests/test_distillation.py @@ -9,9 +9,9 @@ import pytest from pytest import LogCaptureFixture from skeletoken import TokenizerModel -from transformers import BertTokenizerFast +from transformers import BertTokenizer from transformers.modeling_utils import PreTrainedModel -from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from transformers import PreTrainedTokenizerFast from model2vec.distill.distillation import distill, distill_from_model from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings @@ -38,6 +38,7 @@ (None, None, 1e-4), # No PCA, SIF on (None, 0.9, 1e-4), # PCA as float (variance), SIF on (["star wars"], 8, None), # Multiword vocabulary + (["..."], 8, None), # Crashing multiword vocabulary ], ) @patch.object(import_module("model2vec.distill.distillation"), "model_info") @@ -92,7 +93,7 @@ def test_distill_from_model( def test_distill_removal_pattern_all_tokens( mock_auto_model: MagicMock, mock_model_info: MagicMock, - mock_berttokenizer: BertTokenizerFast, + mock_berttokenizer: BertTokenizer, mock_transformer: PreTrainedModel, ) -> None: """Test the removal pattern.""" diff --git a/uv.lock b/uv.lock index 4a731b5..9c5aa2f 100644 --- a/uv.lock +++ b/uv.lock @@ -1212,7 +1212,7 @@ requires-dist = [ { name = "scikit-learn", marker = "extra == 'quantization'" }, { name = "scikit-learn", marker = "extra == 'train'" }, { name = "setuptools" }, - { name = "skeletoken", marker = "extra == 'distill'", specifier = ">=0.3.2" }, + { name = "skeletoken", marker = "extra == 'distill'", specifier = ">=0.3.3" }, { name = "skops", marker = "extra == 'inference'" }, { name = "skops", marker = "extra == 'train'" }, { name = 
"tokenizers", specifier = ">=0.20" }, @@ -2776,7 +2776,7 @@ wheels = [ [[package]] name = "skeletoken" -version = "0.3.2" +version = "0.3.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, @@ -2785,9 +2785,9 @@ dependencies = [ { name = "tokenizers" }, { name = "transformers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1e/a6/71abc874578c9a3290faa04ad8db7eb60e05b3f5052f6dbec525b28bc133/skeletoken-0.3.2.tar.gz", hash = "sha256:24a423e8f789719f62f5e69e040a062f7467f932626119c14ad4b87184457b46", size = 234150, upload-time = "2026-03-12T15:03:02.096Z" } +sdist = { url = "https://files.pythonhosted.org/packages/19/23/4892fa72b6f3ba7fc38a023621216dc4fc2de221288e8a3d98d7978f9cd1/skeletoken-0.3.3.tar.gz", hash = "sha256:f96405a7583ba089fb327a65e92a650d939ea6d35621b3a05b05205246030f1a", size = 234168, upload-time = "2026-04-03T12:39:43.591Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b0/77/555503fd8e5cbef7a824481f7154bf90979c4d1b0604ce9de0221d8d38a1/skeletoken-0.3.2-py3-none-any.whl", hash = "sha256:483d6b76bb508b7de7aa2c00b17915804a2bb1b106393efc9ac73fe3de162690", size = 40302, upload-time = "2026-03-12T15:03:00.92Z" }, + { url = "https://files.pythonhosted.org/packages/2a/0d/fdef375cd6cadb706df245d2ba831eb127a3cb01cb583088aea6422c786d/skeletoken-0.3.3-py3-none-any.whl", hash = "sha256:258b801852312d6c247ca9ca495a758dd6f39523d70c92334f9c617152478b20", size = 40317, upload-time = "2026-04-03T12:39:41.968Z" }, ] [[package]]