Skip to content

Fix Apple Silicon MPS compatibility and refactor native dtype/device casting#151

Merged
sansiro77 merged 7 commits intoTuringQ:mainfrom
sansiro77:fix-mac-mps
Mar 19, 2026
Merged

Fix Apple Silicon MPS compatibility and refactor native dtype/device casting#151
sansiro77 merged 7 commits intoTuringQ:mainfrom
sansiro77:fix-mac-mps

Conversation

@sansiro77
Copy link
Copy Markdown
Contributor

@sansiro77 sansiro77 commented Mar 13, 2026

Description

This PR resolves the Apple Silicon MPS backend issues reported in #149 , and significantly refactors the internal device/dtype casting mechanism to align with PyTorch's native behaviors.

Key Changes

1. Fix Apple Silicon MPS Compatibility

  • Resolve Parametric Gates Stacking: Explicitly cast tensors to complex dtype before torch.stack in parametric gates (e.g., Rx). This fixes the cat_int32_t_float_float2 RuntimeError on the MPS backend.
  • Auto-fallback to CPU for >16 dimensions: Add a guard in the forward pass. If the state tensor exceeds the 16-dimension limit, it throws a clear warning and automatically falls back to CPU.

2. Refactor PyTorch Native Casting

  • Replace Custom .to() with ._apply(): Transition all custom nn.Modules from overriding .to() to using ._apply() to enable robust, native recursive graph traversal.
  • Implement Unified Complex Probing: Add apply_complex_fix inside ._apply() to safely map floating-point actions to their complex counterparts during device/dtype transitions.

Limitations & Known Issues

  • Gaussian Backend MPS Support: Due to current limitations in PyTorch's MPS backend, torch.linalg.det() for complex-valued matrices is not yet supported, resulting in RuntimeError: linalg.lu_factor(): MPS doesn't support complex types.

@Hugh-888
Copy link
Copy Markdown
Collaborator

Test examples:

  • Fock backend with basis=True
print("torch", torch.__version__)
print("mps available", torch.backends.mps.is_available())

nmode=4
cir = dq.QumodeCircuit(nmode=4, init_state='vac', backend='fock', basis=True, cutoff=2)
for i in range(nmode-1):
    cir.bs([i, i+1])
cir.homodyne(0)
cir.to("mps")
prob = cir(is_prob=True)

  • Fock backend with basis=False
print("torch", torch.__version__)
print("mps available", torch.backends.mps.is_available())

nmode=4
cir = dq.QumodeCircuit(nmode=4, init_state='vac', backend='fock', basis=False, cutoff=2)
for i in range(nmode-1):
    cir.bs([i, i+1])
cir.homodyne(0)
cir.to("mps")
prob = cir(is_prob=True)
  • Homodyne measurement for Fock backend with basis=False
print("torch", torch.__version__)
print("mps available", torch.backends.mps.is_available())

nmode=4
cir = dq.QumodeCircuit(nmode=4, init_state='vac', backend='fock', basis=False, cutoff=2)
for i in range(nmode-1):
    cir.bs([i, i+1])
cir.homodyne(0)
cir.to("mps")
state = cir()
samples = cir.measure_homodyne()
  • Gaussian backend
    Due to current limitations in PyTorch's MPS backend, torch.linalg.det() for complex-valued matrices is not yet supported, resulting in RuntimeError: linalg.lu_factor(): MPS doesn't support complex types.
print("torch", torch.__version__)
print("mps available", torch.backends.mps.is_available())

nmode=4
cir = dq.QumodeCircuit(nmode=4, init_state='vac', backend='Gaussian', basis=True, cutoff=2)
for i in range(nmode-1):
    cir.bs([i, i+1])
cir.homodyne(0)
cir.to("mps")
prob = cir(is_prob=True)
image

@sansiro77 sansiro77 merged commit ea5469f into TuringQ:main Mar 19, 2026
@sansiro77 sansiro77 deleted the fix-mac-mps branch March 19, 2026 07:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bugfix Fix bugs

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Apple Silicon MPS issues on macOS: parametric gates fail due to mixed float/complex stack, and dense circuits fail around 16 qubits

2 participants