Dispatch sentence-transformers pooling for bert + xlm_roberta (fix CLS-pooled checkpoints)#63

Merged
Blaizzy merged 4 commits into
Blaizzy:mainfrom
leonnoirclerc:feat/sentence-transformers-pooling-dispatch
May 13, 2026
@leonnoirclerc
Contributor

The bug

Some models that look like they're supported by mlx-embeddings actually produce wrong embeddings. The model loads fine and text_embeds comes out the right shape, but the values diverge from what sentence-transformers would produce for the same checkpoint.

Concretely:

  • BAAI/bge-base-en-v1.5 (BERT backbone)
  • Snowflake/snowflake-arctic-embed-l-v2.0 (XLM-RoBERTa backbone)

Both checkpoints use CLS pooling per their 1_Pooling/config.json. But bert.py and xlm_roberta.py hard-code mean_pooling(...) for text_embeds regardless of what the checkpoint says. So every CLS-pooled sentence-transformers model (the whole BGE family, Snowflake's arctic-embed family, mixedbread-ai/mxbai-embed-*, etc.) silently returns mean-pooled vectors.
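To make the difference concrete, here is a minimal sketch of the two pooling strategies on a toy batch. It uses NumPy as a stand-in for MLX arrays and assumes right-padded input; the variable names are illustrative and not taken from the repository.

```python
import numpy as np

# One sequence, three tokens (the last one is padding), hidden size 2.
tokens = np.array([[[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]])
mask = np.array([[1, 1, 0]])

# CLS pooling: take the first token's vector (assumes right padding).
cls_vec = tokens[:, 0, :]

# Mean pooling: average the vectors of the non-padding tokens only.
weights = mask[:, :, None]
mean_vec = (tokens * weights).sum(axis=1) / weights.sum(axis=1)

print(cls_vec.tolist())   # [[1.0, 0.0]]
print(mean_vec.tolist())  # [[0.5, 0.5]]
```

Both results have the same shape, which is why the mismatch is easy to miss: only the values differ.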

The fix

  1. Read 1_Pooling/config.json from the model directory at load time and surface it on ModelArgs.pooling_config. Falls back to mean pooling when no 1_Pooling/config.json is present (i.e. for plain HF checkpoints), so nothing that worked before breaks.
  2. Add a small dispatcher (pool_by_config) that picks cls_pooling / mean_pooling / max_pooling / lasttoken_pooling based on the loaded mode.
  3. Wire the dispatcher into bert.py and xlm_roberta.py in place of the hard-coded mean_pooling. Other model files are unaffected by the bug (different defaults or different pooling expectations), so they're left as-is.
  4. Preserve 1_Pooling/ in convert.py so checkpoints uploaded to mlx-community/* via the conversion CLI keep their pooling sidecar (the top-level *.json glob didn't recurse into subdirectories, so pre-converted variants on the Hub today are missing it).
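Step 1 above can be sketched roughly as follows. This is an illustration under stated assumptions, not the code from the PR: the helper name `read_pooling_config` and the `DEFAULT_POOLING` constant are hypothetical, and only the sidecar path `1_Pooling/config.json` and the mean-pooling fallback come from the description.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical default, mirroring the "fall back to mean pooling" behaviour.
DEFAULT_POOLING = {"pooling_mode": "mean"}


def read_pooling_config(model_dir):
    """Return the parsed 1_Pooling/config.json, or the mean default if it is absent."""
    sidecar = Path(model_dir) / "1_Pooling" / "config.json"
    if not sidecar.is_file():
        return dict(DEFAULT_POOLING)
    with sidecar.open(encoding="utf-8") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    # A plain HF checkpoint directory without the sidecar.
    plain = Path(tmp) / "plain_hf"
    plain.mkdir()

    # A sentence-transformers style directory with a CLS-pooling sidecar.
    st = Path(tmp) / "st_checkpoint"
    (st / "1_Pooling").mkdir(parents=True)
    (st / "1_Pooling" / "config.json").write_text(
        json.dumps({"pooling_mode_cls_token": True}), encoding="utf-8"
    )

    plain_cfg = read_pooling_config(plain)
    st_cfg = read_pooling_config(st)
```

Returning a copy of the default keeps callers from mutating the shared constant by accident.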

The pool functions, the legacy-config conversion logic, and the unit tests are translated from sentence-transformers.
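For readers unfamiliar with the legacy format: sentence-transformers sidecars typically store one boolean flag per mode (`pooling_mode_cls_token`, `pooling_mode_mean_tokens`, and so on) rather than a single mode string. The sketch below shows one way such flags could be collapsed into `{"pooling_mode": ...}`. The function name and the exact error behaviour are assumptions for illustration, not the PR's `_normalize_pooling_config`.

```python
# Legacy boolean flags mapped to the four modes this PR supports.
LEGACY_FLAGS = {
    "pooling_mode_cls_token": "cls",
    "pooling_mode_mean_tokens": "mean",
    "pooling_mode_max_tokens": "max",
    "pooling_mode_lasttoken": "lasttoken",
}


def normalize_pooling_config(cfg):
    """Collapse legacy boolean flags into a single {"pooling_mode": str} dict (hypothetical helper)."""
    # Already in the single-string form: pass it through.
    if isinstance(cfg.get("pooling_mode"), str):
        return {"pooling_mode": cfg["pooling_mode"]}

    enabled = [k for k, v in cfg.items() if k.startswith("pooling_mode_") and v is True]
    unsupported = [k for k in enabled if k not in LEGACY_FLAGS]
    # Exactly one supported flag must be set; multi-mode or unknown flags are rejected.
    if unsupported or len(enabled) != 1:
        raise NotImplementedError(f"unsupported pooling flags: {enabled}")
    return {"pooling_mode": LEGACY_FLAGS[enabled[0]]}


legacy = {
    "word_embedding_dimension": 768,
    "pooling_mode_cls_token": True,
    "pooling_mode_mean_tokens": False,
    "pooling_mode_max_tokens": False,
    "pooling_mode_mean_sqrt_len_tokens": False,
}
normalized = normalize_pooling_config(legacy)
print(normalized)  # {'pooling_mode': 'cls'}
```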

Verification

I ran a small benchmark on a few inputs, comparing the embeddings produced by mlx-embeddings against those from SentenceTransformer.encode(...) (cosine similarity):

| Model | Backbone | Before | After |
| --- | --- | --- | --- |
| BAAI/bge-base-en-v1.5 | bert | 0.9587 | 1.000000 |
| Snowflake/snowflake-arctic-embed-l-v2.0 | xlm_roberta | 0.8156 | 1.000000 |
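
For reference, the cosine-similarity metric used in this comparison can be computed as in the sketch below. The vectors are toy values chosen for illustration, not output from either library, and the helper is not the benchmark harness itself.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


reference = [1.0, 0.0]   # stand-in for the SentenceTransformer output
matching = [2.0, 0.0]    # same direction, different length
drifted = [1.0, 1.0]     # different direction

same = cosine_similarity(reference, matching)
drift = cosine_similarity(reference, drifted)
```

A value of 1.0 means the two vectors point in the same direction regardless of scale; anything lower indicates the kind of drift reported in the "Before" column.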

50/50 unit tests pass (pytest mlx_embeddings/tests/).

Note about existing users

If you've stored embeddings produced by mlx-embeddings against BAAI/bge-*, Snowflake/snowflake-arctic-embed-*, mixedbread-ai/mxbai-embed-*, or any other CLS-pooled sentence-transformers checkpoint, those vectors will change after this PR lands.

First PR note

This is my first PR to this repo — I bumped into the bug while using bge-base-en-v1.5 for a retrieval pipeline, noticed the cosine drift, and traced it back here. I didn't find a CONTRIBUTING.md, so I made my best guess at the codebase's style; happy to revise anything that doesn't fit your conventions.

leonnoirclerc and others added 2 commits May 12, 2026 21:26
bert.py and xlm_roberta.py both hard-coded `mean_pooling` for `text_embeds`,
ignoring sentence-transformers' `1_Pooling/config.json`. CLS-pooled checkpoints
(bge-base-en-v1.5, snowflake-arctic-embed-l-v2.0, the rest of the bge family,
mxbai, …) silently returned mean-pooled vectors instead of CLS, with measurable
cosine drift against the SentenceTransformer reference.

Changes:

- base.py: add cls_pooling, max_pooling, lasttoken_pooling translated from
  sentence_transformers/sentence_transformer/modules/pooling.py @ 8151750.
  Add `_normalize_pooling_config` (port of `_convert_legacy_pooling_kwargs`)
  and `pool_by_config` dispatcher. Modes outside {cls, mean, max, lasttoken}
  raise NotImplementedError; tuple multi-mode and include_prompt=False raise
  too. Empirical mode coverage: ~100% of top-60 ST checkpoints on the Hub.

- utils.py: `_read_pooling_config` loads `1_Pooling/config.json` when present
  and injects it into the config dict, so the existing `model_config` override
  mechanism keeps "caller wins" precedence intact.

- bert.py / xlm_roberta.py: add `pooling_config` field on ModelArgs with
  default `{"pooling_mode": "mean"}` (visible in the dataclass signature),
  swap the hard-coded `mean_pooling` for `pool_by_config(...)`.

- tests/test_base.py: ports of the five HF unit tests that map to our
  supported surface — cls right-pad, cls left-pad, max-respects-mask,
  lasttoken finds last, lasttoken all-padding-zeros — plus the gold-standard
  `test_pooling_exact_values` (HF's shared fixture, all four supported modes)
  and the two `_convert_legacy_pooling_kwargs` conversion tests. Each carries
  a line-pinned reference back to the HF source at commit 8151750.

Verified end-to-end with the mlx-embeddings-tests harness:

  bge-base-en-v1.5 (bert, CLS):                cos_sim 0.9587 -> 1.000000
  snowflake-arctic-embed-l-v2.0 (xlm-r, CLS):  cos_sim 0.8156 -> 1.000000
  all-MiniLM-L6-v2 (bert, MEAN, control):      cos_sim 1.000000 (unchanged)
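
A rough sketch of the dispatcher idea described in the `base.py` item above. This uses NumPy in place of MLX arrays, assumes right-padded input for CLS pooling, and guesses at the signature `pool_by_config(x, mask, cfg)`; only the four mode names and the `NotImplementedError` behaviour come from the commit message.

```python
import numpy as np


def cls_pooling(x, mask):
    # First token of each sequence (right padding assumed in this sketch).
    return x[:, 0, :]


def mean_pooling(x, mask):
    m = mask[:, :, None].astype(x.dtype)
    return (x * m).sum(axis=1) / np.clip(m.sum(axis=1), 1e-9, None)


def max_pooling(x, mask):
    # Padding positions get a large negative value so they never win the max.
    m = mask[:, :, None].astype(bool)
    return np.where(m, x, -1e9).max(axis=1)


def lasttoken_pooling(x, mask):
    # Index of the last position whose mask is 1 in each row.
    last = mask.shape[1] - 1 - np.argmax(mask[:, ::-1], axis=1)
    picked = x[np.arange(x.shape[0]), last, :]
    # Rows that contain only padding pool to zeros.
    has_tokens = mask.sum(axis=1, keepdims=True) > 0
    return np.where(has_tokens, picked, 0.0)


POOLERS = {
    "cls": cls_pooling,
    "mean": mean_pooling,
    "max": max_pooling,
    "lasttoken": lasttoken_pooling,
}


def pool_by_config(x, mask, cfg):
    """Pick a pooling function from cfg["pooling_mode"] (hypothetical signature)."""
    mode = cfg["pooling_mode"]
    if not isinstance(mode, str) or mode not in POOLERS:
        raise NotImplementedError(f"pooling mode {mode!r} is not supported")
    return POOLERS[mode](x, mask)


tokens = np.array([[[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]])
mask = np.array([[1, 1, 0]])
results = {
    mode: pool_by_config(tokens, mask, {"pooling_mode": mode}).tolist()
    for mode in POOLERS
}
```

A lookup table keeps the dispatcher small and makes the supported surface explicit: adding a mode means adding one function and one dictionary entry.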

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…vert

- Move pool helpers + dispatcher from `models/base.py` to a new
  `models/pooling.py`. `base.py` goes back to its pre-PR shape (data
  classes + `normalize_embeddings`).
- Update imports across all model files that use `mean_pooling`:
  `bert`, `xlm_roberta`, `modernbert`, `gemma3_text`, `llama_bidirec`,
  `lfm2`, `llama_nemotron_vl`.
- Tighten `pool_by_config` / `_normalize_pooling_config` signatures from
  `Optional[Dict[str, Any]]` to `Dict[str, Any]`; remove the now-unreachable
  `if cfg is None` fallback (the dataclass default factory guarantees a
  dict reaches the dispatcher).
- Tighten `bert.ModelArgs.pooling_config` and `xlm_roberta.ModelArgs.pooling_config`
  from `Optional[dict]` to `dict`.
- `convert.py`: preserve `1_Pooling/` subdirectory when converting an HF
  checkpoint. The top-level `*.json` glob doesn't recurse; without this,
  converted `mlx-community/*` variants lose the pooling sidecar and the
  loader silently falls back to mean.
- Split `tests/test_base.py` -> `tests/test_pooling.py` for the new
  pool-helper / dispatcher / config-normalization tests; replace
  per-test "Port of test_xxx" annotations with the upstream comments
  verbatim and add the two HF-port dispatcher tests
  (`test_forward_all_modes`, `test_invalid_mode_raises`).
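
The `convert.py` item above hinges on the fact that a top-level `*.json` glob matches only direct children of a directory. A minimal sketch of that behaviour and of one possible remedy follows; `copy_config_files` is a hypothetical helper written for this example, not the function in `convert.py`.

```python
import shutil
import tempfile
from pathlib import Path


def copy_config_files(src, dst):
    """Copy top-level JSON files, then the 1_Pooling/ sidecar directory if present."""
    src, dst = Path(src), Path(dst)
    dst.mkdir(parents=True, exist_ok=True)
    for json_file in src.glob("*.json"):  # matches direct children only
        shutil.copy(json_file, dst / json_file.name)
    pooling_dir = src / "1_Pooling"
    if pooling_dir.is_dir():
        shutil.copytree(pooling_dir, dst / "1_Pooling", dirs_exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "hf_checkpoint"
    (src / "1_Pooling").mkdir(parents=True)
    (src / "config.json").write_text("{}", encoding="utf-8")
    (src / "1_Pooling" / "config.json").write_text("{}", encoding="utf-8")

    # The non-recursive glob sees only the top-level file.
    top_level_matches = sorted(p.relative_to(src).as_posix() for p in src.glob("*.json"))

    dst = Path(tmp) / "converted"
    copy_config_files(src, dst)
    sidecar_copied = (dst / "1_Pooling" / "config.json").is_file()
```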

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Owner

@Blaizzy Blaizzy left a comment

LGTM, thanks! 🚀

@Blaizzy
Owner

Blaizzy commented May 13, 2026

Could you run pre-commit and push the changes?

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@leonnoirclerc
Contributor Author

Pre-commit applied

@Blaizzy Blaizzy merged commit 9b28270 into Blaizzy:main May 13, 2026
1 check passed
@leonnoirclerc leonnoirclerc deleted the feat/sentence-transformers-pooling-dispatch branch May 14, 2026 07:14