GPU-accelerated matrix exponential and related functions for Apple MLX.
MLX has no built-in expm. This package fills the gap.
| Function | Description | Algorithm |
|---|---|---|
| `expm(A)` | Matrix exponential | Scaling & squaring + [13/13] Pade |
| `expm_frechet(A, E)` | Frechet derivative of expm | Block-triangular (Van Loan) |
| `logm(A)` | Principal matrix logarithm | Inverse scaling & squaring |
| `sqrtm(A)` | Principal matrix square root | Denman-Beavers iteration |
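For background, the block-triangular approach behind `expm_frechet` rests on Van Loan's identity: exponentiating the 2n x 2n matrix `[[A, E], [0, A]]` yields both `expm(A)` and the Frechet derivative `L(A, E)` in one call. A minimal SciPy-based illustration of the identity (not this package's code; SciPy serves only as a reference here):

```python
import numpy as np
from scipy.linalg import expm as scipy_expm

A = np.array([[1.0, 0.5], [0.0, 2.0]])
E = np.array([[0.1, 0.0], [0.0, 0.1]])

# Van Loan: expm([[A, E], [0, A]]) = [[expm(A), L(A, E)], [0, expm(A)]]
block = np.block([[A, E], [np.zeros_like(A), A]])
big = scipy_expm(block)

expm_A = big[:2, :2]  # top-left block:  expm(A)
L_AE = big[:2, 2:]    # top-right block: Frechet derivative of expm at A in direction E
```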
Install from source:

```bash
pip install -e ".[dev]"
```

```python
import mlx.core as mx
from mlx_expm import expm, expm_frechet, logm, sqrtm
# Quantum time evolution: U = exp(-i H t)
H = mx.array([[0.0, 1.0], [1.0, 0.0]]) # Pauli X
t = 0.5
U = expm(-1j * H * t)
# Matrix logarithm (inverse of expm)
A = mx.array([[1.0, 2.0], [0.0, 3.0]])
L = logm(expm(A)) # recovers A
# Matrix square root
S = sqrtm(mx.array([[4.0, 0.0], [0.0, 9.0]])) # [[2, 0], [0, 3]]
# Frechet derivative
A = mx.array([[1.0, 0.5], [0.0, 2.0]])
E = mx.array([[0.1, 0.0], [0.0, 0.1]])
expm_A, L = expm_frechet(A, E)
```

Applications:

- Quantum mechanics: Unitary evolution `U = exp(-iHt)`
- Control theory: State transition matrix `exp(At)` (see the sketch after this list)
- Differential equations: Matrix ODE solutions
- Neural ODEs: Continuous-time dynamics
- Lie groups: Exponential map on matrix Lie algebras
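For instance, a minimal sketch of the control-theory case, propagating an LTI system `x'(t) = A x(t)` with this package's `expm` (the system matrix is made up for illustration):

```python
import mlx.core as mx
from mlx_expm import expm

# A damped oscillator written as x'(t) = A x(t)  (illustrative values)
A = mx.array([[0.0, 1.0],
              [-2.0, -3.0]])
x0 = mx.array([1.0, 0.0])  # initial state

# State transition: x(t) = exp(A t) x0
t = 1.0
x_t = expm(A * t) @ x0
```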
`expm` uses the same algorithm as `scipy.linalg.expm` (Higham 2005/2009):

- Scaling: Find `s` such that `||A / 2^s||_1 <= theta_13 ≈ 5.37`
- Pade: Evaluate the [13/13] rational approximant `R_13(A / 2^s)`
- Squaring: Recover `exp(A) = R_13(A / 2^s)^(2^s)` via repeated squaring

This requires 13 matrix multiplications plus one linear solve, all executed on the GPU through MLX.
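A plain-NumPy sketch of those three steps (coefficients and `theta_13` from Higham 2005; `mlx_expm` does the equivalent work on MLX arrays, so treat this as an illustration rather than the package source):

```python
import numpy as np

# Pade [13/13] coefficients b_0..b_13 (Higham 2005)
B = [64764752532480000.0, 32382376266240000.0, 7771770303897600.0,
     1187353796428800.0, 129060195264000.0, 10559470521600.0,
     670442572800.0, 33522128640.0, 1323241920.0, 40840800.0,
     960960.0, 16380.0, 182.0, 1.0]

def expm_sketch(A: np.ndarray) -> np.ndarray:
    """Scaling-and-squaring expm for a float matrix A."""
    theta13 = 5.371920351148152

    # 1. Scaling: pick s so that ||A / 2^s||_1 <= theta_13
    norm = np.linalg.norm(A, 1)
    s = max(0, int(np.ceil(np.log2(norm / theta13)))) if norm > theta13 else 0
    A = A / 2.0**s

    # 2. Pade: evaluate R_13(A) = (V - U)^{-1} (V + U)
    I = np.eye(A.shape[0], dtype=A.dtype)
    A2 = A @ A
    A4 = A2 @ A2
    A6 = A4 @ A2
    U = A @ (A6 @ (B[13] * A6 + B[11] * A4 + B[9] * A2)
             + B[7] * A6 + B[5] * A4 + B[3] * A2 + B[1] * I)
    V = (A6 @ (B[12] * A6 + B[10] * A4 + B[8] * A2)
         + B[6] * A6 + B[4] * A4 + B[2] * A2 + B[0] * I)
    R = np.linalg.solve(V - U, V + U)

    # 3. Squaring: exp(original A) = R^(2^s)
    for _ in range(s):
        R = R @ R
    return R
```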
`logm` uses inverse scaling & squaring:

- Repeatedly compute square roots (via Denman-Beavers) until `||X - I||` is small
- Evaluate `log(I + E)` via a degree-8 Taylor series
- Scale the result back by `2^s`
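A NumPy sketch of that recipe, with SciPy's `sqrtm` standing in for the package's Denman-Beavers step and an assumed threshold of 0.25 (the input must have no eigenvalues on the closed negative real axis for the principal logarithm to exist):

```python
import numpy as np
from scipy.linalg import sqrtm  # stand-in for the Denman-Beavers step sketched below

def logm_sketch(A: np.ndarray, tol: float = 0.25) -> np.ndarray:
    I = np.eye(A.shape[0])

    # 1. Inverse scaling: repeated square roots until X is near the identity
    X = np.asarray(A, dtype=float)
    s = 0
    while np.linalg.norm(X - I, 1) > tol:
        X = sqrtm(X)
        s += 1

    # 2. Taylor: log(I + E) ~ sum_{k=1}^{8} (-1)^{k+1} E^k / k
    E = X - I
    term = I
    L = np.zeros_like(X)
    for k in range(1, 9):
        term = term @ E           # term holds E^k
        L = L + ((-1) ** (k + 1) / k) * term

    # 3. Undo the scaling: log(A) = 2^s * log(A^(1 / 2^s))
    return (2.0 ** s) * L
```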
`sqrtm` uses the quadratically convergent Denman-Beavers iteration: starting from `Y_0 = A`, `Z_0 = I`, iterate `Y_{k+1} = (Y_k + Z_k^{-1}) / 2`, `Z_{k+1} = (Z_k + Y_k^{-1}) / 2`, which converges to `Y -> A^{1/2}` and `Z -> A^{-1/2}`.
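A sketch of the iteration (fixed iteration count for brevity; a real implementation tests for convergence):

```python
import numpy as np

def denman_beavers(A: np.ndarray, num_iters: int = 20):
    """Coupled iteration: Y_k -> A^{1/2}, Z_k -> A^{-1/2}."""
    Y = np.asarray(A, dtype=float)
    Z = np.eye(A.shape[0])
    for _ in range(num_iters):
        # Tuple assignment so both updates use the previous Y and Z
        Y, Z = 0.5 * (Y + np.linalg.inv(Z)), 0.5 * (Z + np.linalg.inv(Y))
    return Y, Z

Y, Z = denman_beavers(np.array([[4.0, 0.0], [0.0, 9.0]]))
# Y ~ [[2, 0], [0, 3]],  Z ~ [[0.5, 0], [0, 1/3]]
```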
Run the benchmark:

```bash
python benchmark.py
python benchmark.py --sizes 32 64 128 256 512 --complex
```

Measured on an Apple M1 Max (MLX 0.31.1, Python 3.11) against `scipy.linalg.expm`:
| n | real speedup | complex speedup |
|---|---|---|
| 32 | 0.03x | 0.01x |
| 128 | 0.15x | 0.26x |
| 256 | 0.53x | 1.96x |
| 512 | 0.54x | 1.33x |
| 1024 | 2.08x | - |
| 2048 | 1.87x | - |
The crossover is around n=1024 for real and n=256 for complex inputs; below that, SciPy's CPU LAPACK wins because kernel-dispatch overhead dominates. Maximum absolute error stays at the float32 noise floor (~1e-6 * n). See benchmark_results.md for full numbers and discussion.
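To spot-check the accuracy claim, a quick sketch (assumes SciPy is installed; the random test matrix is scaled so its norm stays moderate):

```python
import numpy as np
import mlx.core as mx
from scipy.linalg import expm as scipy_expm
from mlx_expm import expm

n = 256
A = (np.random.randn(n, n) / np.sqrt(n)).astype(np.float32)

ours = np.array(expm(mx.array(A)))       # float32 result from MLX
ref = scipy_expm(A.astype(np.float64))   # float64 CPU reference
print("max abs error:", np.abs(ours - ref).max())  # roughly 1e-6 * n
```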
Run the tests:

```bash
pytest tests/ -v
```

References:

- Higham, Functions of Matrices, SIAM, 2008.
- Al-Mohy & Higham, "A New Scaling and Squaring Algorithm for the Matrix Exponential", SIAM J. Matrix Anal. Appl. 31(3), 2009.
- Al-Mohy & Higham, "Computing the Frechet Derivative of the Matrix Exponential", SIAM J. Matrix Anal. Appl. 30(4), 2009.
MIT — Sheng-Kai Huang