Dispatch sentence-transformers pooling for bert + xlm_roberta (fix CLS-pooled checkpoints)#63
Merged
Blaizzy merged 4 commits intoMay 13, 2026
Conversation
bert.py and xlm_roberta.py both hard-coded `mean_pooling` for `text_embeds`,
ignoring sentence-transformers' `1_Pooling/config.json`. CLS-pooled checkpoints
(bge-base-en-v1.5, snowflake-arctic-embed-l-v2.0, the rest of the bge family,
mxbai, …) silently returned mean-pooled vectors instead of CLS, with measurable
cosine drift against the SentenceTransformer reference.
Changes:
- base.py: add cls_pooling, max_pooling, lasttoken_pooling translated from
sentence_transformers/sentence_transformer/modules/pooling.py @ 8151750.
Add `_normalize_pooling_config` (port of `_convert_legacy_pooling_kwargs`)
and `pool_by_config` dispatcher. Modes outside {cls, mean, max, lasttoken}
raise NotImplementedError; tuple multi-mode and include_prompt=False raise
too. Empirical mode coverage: ~100% of top-60 ST checkpoints on the Hub.
- utils.py: `_read_pooling_config` loads `1_Pooling/config.json` when present
and injects it into the config dict, so the existing `model_config` override
mechanism keeps "caller wins" precedence intact.
- bert.py / xlm_roberta.py: add `pooling_config` field on ModelArgs with
default `{"pooling_mode": "mean"}` (visible in the dataclass signature),
swap the hard-coded `mean_pooling` for `pool_by_config(...)`.
- tests/test_base.py: ports of the five HF unit tests that map to our
supported surface — cls right-pad, cls left-pad, max-respects-mask,
lasttoken finds last, lasttoken all-padding-zeros — plus the gold-standard
`test_pooling_exact_values` (HF's shared fixture, all four supported modes)
and the two `_convert_legacy_pooling_kwargs` conversion tests. Each carries
a line-pinned reference back to the HF source at commit 8151750.
Verified end-to-end with the mlx-embeddings-tests harness:
bge-base-en-v1.5 (bert, CLS): cos_sim 0.9587 -> 1.000000
snowflake-arctic-embed-l-v2.0 (xlm-r, CLS): cos_sim 0.8156 -> 1.000000
all-MiniLM-L6-v2 (bert, MEAN, control): cos_sim 1.000000 (unchanged)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…vert - Move pool helpers + dispatcher from `models/base.py` to a new `models/pooling.py`. `base.py` goes back to its pre-PR shape (data classes + `normalize_embeddings`). - Update imports across all model files that use `mean_pooling`: `bert`, `xlm_roberta`, `modernbert`, `gemma3_text`, `llama_bidirec`, `lfm2`, `llama_nemotron_vl`. - Tighten `pool_by_config` / `_normalize_pooling_config` signatures from `Optional[Dict[str, Any]]` to `Dict[str, Any]`; remove the now-unreachable `if cfg is None` fallback (the dataclass default factory guarantees a dict reaches the dispatcher). - Tighten `bert.ModelArgs.pooling_config` and `xlm_roberta.ModelArgs.pooling_config` from `Optional[dict]` to `dict`. - `convert.py`: preserve `1_Pooling/` subdirectory when converting an HF checkpoint. The top-level `*.json` glob doesn't recurse; without this, converted `mlx-community/*` variants lose the pooling sidecar and the loader silently falls back to mean. - Split `tests/test_base.py` -> `tests/test_pooling.py` for the new pool-helper / dispatcher / config-normalization tests; replace per-test "Port of test_xxx" annotations with the upstream comments verbatim and add the two HF-port dispatcher tests (`test_forward_all_modes`, `test_invalid_mode_raises`). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Owner
|
Could you run pre-commit and push changes ? |
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Contributor
Author
|
Pre-commit applied |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The bug
Some models that look like they're supported by
mlx-embeddingsactually produce wrong embeddings. The model loads fine andtext_embedscomes out the right shape, but the values diverge from whatsentence-transformerswould produce for the same checkpoint.Concretely:
BAAI/bge-base-en-v1.5(BERT backbone)Snowflake/snowflake-arctic-embed-l-v2.0(XLM-RoBERTa backbone)Both checkpoints use CLS pooling per their
1_Pooling/config.json. Butbert.pyandxlm_roberta.pyhard-codemean_pooling(...)fortext_embedsregardless of what the checkpoint says. So every CLS-pooledsentence-transformersmodel (the whole BGE family, Snowflake's arctic-embed family,mixedbread-ai/mxbai-embed-*, etc.) silently returns mean-pooled vectors.The fix
1_Pooling/config.jsonfrom the model directory at load time and surface it onModelArgs.pooling_config. Falls back to mean pooling when no1_Pooling/config.jsonis present (i.e. for plain HF checkpoints), so nothing that worked before breaks.pool_by_config) that pickscls_pooling/mean_pooling/max_pooling/lasttoken_poolingbased on the loaded mode.bert.pyandxlm_roberta.pyin place of the hard-codedmean_pooling. Other model files are unaffected by the bug (different defaults or different pooling expectations), so they're left as-is.1_Pooling/inconvert.pyso checkpoints uploaded tomlx-community/*via the conversion CLI keep their pooling sidecar (the top-level*.jsonglob didn't recurse into subdirectories, so pre-converted variants on the Hub today are missing it).The pool functions, the legacy-config conversion logic, and the unit tests are translated from
sentence-transformers.Verification
I ran a small benchmark on a few inputs, comparing the embeddings produced by
mlx-embeddingsagainst those fromSentenceTransformer.encode(...)(cosine similarity):BAAI/bge-base-en-v1.5Snowflake/snowflake-arctic-embed-l-v2.050/50 unit tests pass (
pytest mlx_embeddings/tests/).Note about existing users
If you've stored embeddings produced by mlx-embeddings against
BAAI/bge-*,Snowflake/snowflake-arctic-embed-*,mixedbread-ai/mxbai-embed-*, or any other CLS-pooled sentence-transformers checkpoint, those vectors will change after this PR lands.First PR note
This is my first PR to this repo — I bumped into the bug while using
bge-base-en-v1.5for a retrieval pipeline, noticed the cosine drift, and traced it back here. I didn't find a CONTRIBUTING.md, so I made my best guess at the codebase's style; happy to revise anything that doesn't fit your conventions.