From 361a7d50821d0a8b6cf0f35cf5c4651b2575a5a8 Mon Sep 17 00:00:00 2001 From: stephantul Date: Tue, 21 Apr 2026 07:10:31 +0200 Subject: [PATCH 1/2] clean up post-processing code --- model2vec/distill/distillation.py | 14 +++++++------- model2vec/distill/inference.py | 28 +++++++++++++++------------- tests/test_distillation.py | 6 ++++-- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py index fd00a2e..589d791 100644 --- a/model2vec/distill/distillation.py +++ b/model2vec/distill/distillation.py @@ -5,14 +5,13 @@ import re from typing import cast -import numpy as np from huggingface_hub.hf_api import model_info from skeletoken import TokenizerModel from skeletoken.external.transformers import reshape_embeddings from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast from transformers.modeling_utils import PreTrainedModel -from model2vec.distill.inference import PCADimType, PoolingMode, create_embeddings, post_process_embeddings +from model2vec.distill.inference import PCADimType, PoolingMode, apply_pca, compute_weights, create_embeddings from model2vec.distill.utils import select_optimal_device from model2vec.model import StaticModel from model2vec.quantization import DType, quantize_embeddings @@ -108,16 +107,17 @@ def distill_from_model( pooling=pooling, ) - # Maybe apply quantization + # Apply quantization if vocabulary_quantization is not None: - _, weights = post_process_embeddings(np.asarray(embeddings), None, sif_coefficient=sif_coefficient) + weights = compute_weights(len(embeddings), sif_coefficient=sif_coefficient) embeddings, token_mapping, weights = quantize_vocabulary( - n_clusters=vocabulary_quantization, weights=weights, embeddings=np.asarray(embeddings) + n_clusters=vocabulary_quantization, weights=weights, embeddings=embeddings ) - embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient=sif_coefficient) + embeddings = 
apply_pca(embeddings, pca_dims) else: # Post-process the embeddings. - embeddings, weights = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient) + weights = compute_weights(len(embeddings), sif_coefficient=sif_coefficient) + embeddings = apply_pca(embeddings, pca_dims) embeddings = embeddings * weights[:, None] weights = None token_mapping = None diff --git a/model2vec/distill/inference.py b/model2vec/distill/inference.py index b08a021..7d83bd1 100644 --- a/model2vec/distill/inference.py +++ b/model2vec/distill/inference.py @@ -207,10 +207,20 @@ def _encode_pooler_with_model(model: PreTrainedModel, encodings: dict[str, torch return pooler.cpu() -def post_process_embeddings( - embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4 -) -> tuple[np.ndarray, np.ndarray]: - """Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law.""" +def compute_weights(n_embeddings: int, sif_coefficient: float | None) -> np.ndarray: + """Compute the weights based on Zipf's law and a SIF coefficient.""" + if sif_coefficient is None: + return np.ones(n_embeddings) + logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.") + inv_rank = 1 / (np.arange(2, n_embeddings + 2)) + proba = inv_rank / np.sum(inv_rank) + weight = sif_coefficient / (sif_coefficient + proba) + + return weight + + +def apply_pca(embeddings: np.ndarray, pca_dims: PCADimType) -> np.ndarray: + """Apply PCA to the embeddings.""" if pca_dims is not None: if pca_dims == "auto": pca_dims = embeddings.shape[1] @@ -242,12 +252,4 @@ def post_process_embeddings( logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.") logger.info(f"Explained variance: {explained_variance:.3f}.") - if sif_coefficient is not None: - logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.") - inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2)) - proba = 
inv_rank / np.sum(inv_rank) - weight = sif_coefficient / (sif_coefficient + proba) - else: - weight = np.ones(embeddings.shape[0]) - - return embeddings, weight + return embeddings diff --git a/tests/test_distillation.py b/tests/test_distillation.py index 9828a6f..f5816ec 100644 --- a/tests/test_distillation.py +++ b/tests/test_distillation.py @@ -14,7 +14,7 @@ from transformers.tokenization_utils_tokenizers import PreTrainedTokenizerFast from model2vec.distill.distillation import distill, distill_from_model -from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings +from model2vec.distill.inference import PoolingMode, apply_pca, compute_weights, create_embeddings from model2vec.model import StaticModel from model2vec.tokenizer import clean_and_create_vocabulary @@ -259,7 +259,9 @@ def test__post_process_embeddings( # The implementation logs a warning and skips reduction; no exception expected. pass - processed_embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient) + processed_embeddings = apply_pca(embeddings, pca_dims) + weights = compute_weights(len(processed_embeddings), sif_coefficient=sif_coefficient) + processed_embeddings = processed_embeddings * weights[:, None] # Assert the shape is correct assert processed_embeddings.shape == expected_shape From d1799fdb1a724807d39dbab30c39ef6d6b4e324b Mon Sep 17 00:00:00 2001 From: stephantul Date: Tue, 21 Apr 2026 07:21:06 +0200 Subject: [PATCH 2/2] add additional test --- tests/test_distillation.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_distillation.py b/tests/test_distillation.py index f5816ec..49bbcec 100644 --- a/tests/test_distillation.py +++ b/tests/test_distillation.py @@ -88,6 +88,36 @@ def test_distill_from_model( assert token in static_model.tokens or normalized in static_model.tokens +@patch.object(import_module("model2vec.distill.distillation"), "model_info") 
+@patch("transformers.AutoModel.from_pretrained") +def test_distill_quantization( + mock_auto_model: MagicMock, + mock_model_info: MagicMock, + mock_berttokenizer: PreTrainedTokenizerFast, + mock_transformer: PreTrainedModel, +) -> None: + """Test distillation with vocabulary quantization enabled.""" + # Mock the return value of model_info to avoid calling the Hugging Face API + mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}}) + mock_auto_model.return_value = mock_transformer + + static_model = distill_from_model( + model=mock_transformer, + tokenizer=mock_berttokenizer, + vocabulary=None, + device="cpu", + pca_dims="auto", + sif_coefficient=1e-4, + token_remove_pattern=None, + vocabulary_quantization=3, + ) + + assert static_model.embedding.shape == (3, 768) + assert static_model.weights is not None + assert static_model.token_mapping is not None + assert len(static_model.weights) == static_model.tokenizer.get_vocab_size() + + @patch.object(import_module("model2vec.distill.distillation"), "model_info") @patch("transformers.AutoModel.from_pretrained") def test_distill_removal_pattern_all_tokens(