fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss#263
Merged
a710128 merged 2 commits intoOpenBMB:mainfrom Apr 21, 2026
Merged
fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss#263a710128 merged 2 commits intoOpenBMB:mainfrom
a710128 merged 2 commits intoOpenBMB:mainfrom
Conversation
VoxCPM checkpoints default to bfloat16. Following commit e4e0496 which added MPS device routing, running with `device=mps` selects bf16 on Apple Silicon. On Metal, bf16 introduces enough numerical drift in the diffusion AR loop that the synthesized audio is glitched and trips the model's badcase detector, which retries until the per-call retry budget is exhausted. Effectively MPS support is unusable in the default config. This patch adds a single helper, `pick_runtime_dtype(device, dtype)`, that promotes any low-precision dtype to float32 when the resolved device is `mps`. CUDA and CPU paths are untouched. An opt-out env var `VOXCPM_MPS_DTYPE` lets users force a specific dtype on MPS once future PyTorch / macOS releases improve bf16 stability. Both VoxCPMModel and VoxCPM2Model adopt the helper in their __init__, replacing what would otherwise be duplicated inline checks. Verified locally on Apple M5 Max, PyTorch 2.11, macOS 15: - VoxCPM2 (2B): clean output, RTF ~0.78 steady state - VoxCPM 0.5B: clean output, RTF ~0.92 - No badcase retries fired in any test - VOXCPM_MPS_DTYPE=bfloat16 round-trips and reproduces the original glitched output, confirming the override path.
8 tasks
Contributor
|
I noticed a small inconsistency in the new MPS override path.
So the override validation is a bit broader than the actual dtype parser. It looks like this could be fixed either by:
Not blocking from a design point of view, but I think this should be aligned before merge to avoid a confusing runtime failure for users trying the override. |
Drop "half" from _VALID_DTYPE_OVERRIDES / _LOW_PRECISION_DTYPES. get_dtype() has never accepted "half", so VOXCPM_MPS_DTYPE=half would pass override validation and then crash downstream with "Unsupported dtype: half". The remaining aliases (bfloat16/bf16, float16/fp16, float32/fp32) already cover the intended dtype space. Adds a standalone unit check under scripts/ to guard the invariant that every accepted override parses through get_dtype(). Addresses review feedback on OpenBMB#263.
a710128
approved these changes
Apr 21, 2026
Contributor
|
Looks good to me — the previous half inconsistency is fixed, and the added guard makes the dtype override behavior much clearer. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Following commit e4e0496 which added MPS device routing, running with
device=mpsselects the checkpoint defaultbfloat16on Apple Silicon. On Metal, bf16 introduces enough numerical drift in the diffusion AR loop that the synthesized audio is noticeably glitched and consistently trips the model's badcase detector, which then retries until the per-call retry budget is exhausted. In effect, MPS support is unusable in the default config — users either get garbled audio or a hung process.This PR adds a single helper,
pick_runtime_dtype(device, configured_dtype), inmodel/utils.pythat promotes any low-precision dtype tofloat32when the resolved device ismps. CUDA and CPU paths are untouched.An opt-out env var
VOXCPM_MPS_DTYPElets advanced users force a specific dtype on MPS for future testing as PyTorch / macOS Metal bf16 support improves.Both
VoxCPMModelandVoxCPM2Model__init__adopt the helper, replacing what would otherwise be duplicated inline checks.Reproducer (before this patch)
Observed on Apple M5 Max, macOS 15, PyTorch 2.11:
Badcase detected, audio_text_ratio=...retries fire repeatedly and exhaustretry_badcase_max_timesBehavior table
VOXCPM_MPS_DTYPE=bfloat16)VOXCPM_MPS_DTYPE=banana)ValueErrorVerified locally
Apple M5 Max, 128 GB unified memory, PyTorch 2.11.0, macOS 15.
Voice cloning, Voice Design, and Ultimate Cloning paths all verified end-to-end. Multilingual generation (12 languages) verified. No badcase retries fired in any post-patch run.
The
VOXCPM_MPS_DTYPE=bfloat16override round-trips: it reproduces the original glitched output, confirming the env path works and that the dtype is the actual cause.Notes for review
VoxCPMModel.__init__andVoxCPM2Model.__init__rather than burying the promotion insidesetup_cacheso the user-facing log clearly shows the effective dtype on startup.