
fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss#263

Merged
a710128 merged 2 commits into OpenBMB:main from Oumnya:fix/mps-bf16-dtype
Apr 21, 2026

Conversation

Contributor

Oumnya commented Apr 15, 2026

Summary

Following commit e4e0496 which added MPS device routing, running with device=mps selects the checkpoint default bfloat16 on Apple Silicon. On Metal, bf16 introduces enough numerical drift in the diffusion AR loop that the synthesized audio is noticeably glitched and consistently trips the model's badcase detector, which then retries until the per-call retry budget is exhausted. In effect, MPS support is unusable in the default config — users either get garbled audio or a hung process.

This PR adds a single helper, pick_runtime_dtype(device, configured_dtype), in model/utils.py that promotes any low-precision dtype to float32 when the resolved device is mps. CUDA and CPU paths are untouched.

An opt-out env var VOXCPM_MPS_DTYPE lets advanced users force a specific dtype on MPS for future testing as PyTorch / macOS Metal bf16 support improves.

Both VoxCPMModel and VoxCPM2Model __init__ adopt the helper, replacing what would otherwise be duplicated inline checks.

Reproducer (before this patch)

```python
from voxcpm import VoxCPM

m = VoxCPM.from_pretrained("openbmb/VoxCPM2", load_denoiser=False, optimize=False, device="mps")
wav = m.generate(text="Hello from Apple Silicon.", cfg_value=2.0, inference_timesteps=10)
```

Observed on Apple M5 Max, macOS 15, PyTorch 2.11:

  • Audio output is glitched / unintelligible
  • Badcase detected, audio_text_ratio=... retries fire repeatedly and exhaust retry_badcase_max_times
  • RTF inflated by ~2x due to retries even when output finally completes

Behavior table

| device | configured dtype | resolved dtype | notes |
| --- | --- | --- | --- |
| cpu | bfloat16 | bfloat16 | unchanged |
| cuda | bfloat16 | bfloat16 | unchanged |
| mps | bfloat16 | float32 | new |
| mps | float16 | float32 | new |
| mps | float32 | float32 | unchanged |
| mps (with `VOXCPM_MPS_DTYPE=bfloat16`) | any | bfloat16 | opt-out |
| mps (with `VOXCPM_MPS_DTYPE=banana`) | any | raises `ValueError` | invalid |
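The promotion logic described above can be sketched as follows. This is an illustration modeled on the PR description, not the merged implementation: the helper name `pick_runtime_dtype`, the env var `VOXCPM_MPS_DTYPE`, and the constant names come from the PR text, but the body here operates on dtype name strings (as later parsed by `get_dtype()`) purely for demonstration.

```python
import os

# Accepted override spellings and the subset considered low precision.
# These sets mirror the PR's behavior table; the real module may differ.
_VALID_DTYPE_OVERRIDES = {"bfloat16", "bf16", "float16", "fp16", "float32", "fp32"}
_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16"}

def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
    """Promote low-precision dtypes to float32 on MPS, unless overridden."""
    if device != "mps":
        return configured_dtype  # CUDA and CPU paths are untouched
    override = os.environ.get("VOXCPM_MPS_DTYPE")
    if override is not None:
        if override not in _VALID_DTYPE_OVERRIDES:
            raise ValueError(f"Invalid VOXCPM_MPS_DTYPE: {override!r}")
        return override  # advanced users opt back into bf16/fp16
    if configured_dtype in _LOW_PRECISION_DTYPES:
        return "float32"  # avoid bf16 numerical drift on Metal
    return configured_dtype
```

Keeping the override check separate from the promotion check means an explicit `VOXCPM_MPS_DTYPE` always wins, matching the "opt-out" row in the table.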

Verified locally

Apple M5 Max, 128 GB unified memory, PyTorch 2.11.0, macOS 15.

| Model | Device | Outcome | Steady-state RTF |
| --- | --- | --- | --- |
| VoxCPM 0.5B | cpu | clean | 1.45 |
| VoxCPM 0.5B | mps (bf16, before) | glitched + badcase loop | 1.73 |
| VoxCPM 0.5B | mps (fp32, after) | clean | 0.92 |
| VoxCPM2 2B | mps (fp32, after) | clean | 0.78 |

Voice cloning, Voice Design, and Ultimate Cloning paths all verified end-to-end. Multilingual generation (12 languages) verified. No badcase retries fired in any post-patch run.

The VOXCPM_MPS_DTYPE=bfloat16 override round-trips: it reproduces the original glitched output, confirming the env path works and that the dtype is the actual cause.

Notes for review

  • The helper is intentionally minimal (single function, dtype constants, env override) so it is easy to extend if more devices need similar handling later.
  • I kept the call sites in VoxCPMModel.__init__ and VoxCPM2Model.__init__ rather than burying the promotion inside setup_cache so the user-facing log clearly shows the effective dtype on startup.
  • No README changes in this PR — happy to add a short Apple Silicon section in a follow-up if maintainers want it.

VoxCPM checkpoints default to bfloat16. Following commit e4e0496 which
added MPS device routing, running with `device=mps` selects bf16 on
Apple Silicon. On Metal, bf16 introduces enough numerical drift in the
diffusion AR loop that the synthesized audio is glitched and trips the
model's badcase detector, which retries until the per-call retry budget
is exhausted. Effectively MPS support is unusable in the default config.

This patch adds a single helper, `pick_runtime_dtype(device, dtype)`,
that promotes any low-precision dtype to float32 when the resolved
device is `mps`. CUDA and CPU paths are untouched. An opt-out env var
`VOXCPM_MPS_DTYPE` lets users force a specific dtype on MPS once future
PyTorch / macOS releases improve bf16 stability.

Both VoxCPMModel and VoxCPM2Model adopt the helper in their __init__,
replacing what would otherwise be duplicated inline checks.

Verified locally on Apple M5 Max, PyTorch 2.11, macOS 15:
- VoxCPM2 (2B): clean output, RTF ~0.78 steady state
- VoxCPM 0.5B: clean output, RTF ~0.92
- No badcase retries fired in any test
- VOXCPM_MPS_DTYPE=bfloat16 round-trips and reproduces the original
  glitched output, confirming the override path.
Contributor

a710128 commented Apr 21, 2026

I noticed a small inconsistency in the new MPS override path.

pick_runtime_dtype() accepts VOXCPM_MPS_DTYPE=half, but get_dtype() does not currently recognize "half" and will later raise ValueError: Unsupported dtype: half.

So the override validation is a bit broader than the actual dtype parser. It looks like this could be fixed either by:

  • adding "half" support in get_dtype(), or
  • removing "half" from the accepted override values.

Not blocking from a design point of view, but I think this should be aligned before merge to avoid a confusing runtime failure for users trying the override.

Drop "half" from _VALID_DTYPE_OVERRIDES / _LOW_PRECISION_DTYPES.
get_dtype() has never accepted "half", so VOXCPM_MPS_DTYPE=half would
pass override validation and then crash downstream with
"Unsupported dtype: half". The remaining aliases (bfloat16/bf16,
float16/fp16, float32/fp32) already cover the intended dtype space.

Adds a standalone unit check under scripts/ to guard the invariant
that every accepted override parses through get_dtype().

Addresses review feedback on OpenBMB#263.
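The invariant guarded by the new unit check under scripts/ can be sketched as below. This is a simplified illustration of the idea, not the repository's actual code: `get_dtype()` is modeled here as returning canonical dtype name strings rather than torch dtypes, and the alias table is an assumption based on the commit message.

```python
# Every accepted VOXCPM_MPS_DTYPE override must parse through get_dtype().
# "half" previously sat in the override set but not in the parser, so it
# validated and then crashed downstream; this guard fails fast on such drift.
_VALID_DTYPE_OVERRIDES = {"bfloat16", "bf16", "float16", "fp16", "float32", "fp32"}

# Simplified alias table standing in for the real get_dtype() parser.
_CANONICAL = {
    "bfloat16": "bfloat16", "bf16": "bfloat16",
    "float16": "float16", "fp16": "float16",
    "float32": "float32", "fp32": "float32",
}

def get_dtype(name: str) -> str:
    if name not in _CANONICAL:
        raise ValueError(f"Unsupported dtype: {name}")
    return _CANONICAL[name]

def check_override_invariant() -> bool:
    # Raises ValueError if any accepted override fails to parse.
    for name in _VALID_DTYPE_OVERRIDES:
        get_dtype(name)
    return True
```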
Contributor

a710128 commented Apr 21, 2026

Looks good to me — the previous half inconsistency is fixed, and the added guard makes the dtype override behavior much clearer.

@a710128 a710128 merged commit cd79a64 into OpenBMB:main Apr 21, 2026
