S. R. Jha, S. Chowdhury, M. Hays, J. A. Grover, W. D. Oliver
Docs: equs.github.io/jaxquantum
Community Discord: discord.gg/frWqbjvZ4s
jaxquantum leverages JAX to enable auto-differentiable, hardware-accelerated (CPU, GPU, TPU) simulation of quantum dynamical systems, with tooling for operator construction, unitary evolution, and master equation solving. As such, jaxquantum serves as a QuTiP drop-in replacement written entirely in JAX.
This package is also an essential dependency of bosonic and qcsys. Together, these packages form an end-to-end toolkit for quantum circuit design, simulation, and control.
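To make "auto-differentiable simulation" concrete, here is a minimal sketch written in plain JAX rather than with the jaxquantum API. It assumes a single qubit driven by a hypothetical Hamiltonian H = (Omega/2) σx, evolves the state with a matrix exponential, and differentiates the final <σz> with respect to the drive strength. The function name `sz_after_drive` and the parameter values are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Pauli matrices and the initial state |0>, written as plain JAX arrays.
sigma_x = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
sigma_z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)
psi0 = jnp.array([1.0, 0.0], dtype=jnp.complex64)


def sz_after_drive(Omega, t):
    """Return <sigma_z> after evolving |0> under H = (Omega / 2) * sigma_x for time t."""
    H = 0.5 * Omega * sigma_x
    U = expm(-1j * H * t)  # unitary propagator exp(-i H t)
    psi_t = U @ psi0
    return jnp.real(jnp.vdot(psi_t, sigma_z @ psi_t))


# Gradient of the final expectation value with respect to the drive strength.
dsz_dOmega = jax.grad(sz_after_drive)(0.1, 5.0)
```

For this toy model the result can be checked by hand: <σz> = cos(Omega t), so the gradient should match -t sin(Omega t). The same pattern (wrap a simulation in a function, then apply `jax.grad`) is what makes gradient-based parameter optimization possible on top of a JAX-based simulator.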
Recommended: As this is a rapidly evolving project, we recommend installing the latest version of jaxquantum from source as follows:
pip install git+https://github.com/EQuS/jaxquantum.git
If you are interested in contributing to the package, please clone this repository, change into its root directory, and install the package in editable mode:
pip install -e ".[dev,docs]"
This will also install the dev and docs extras, which are useful when developing the package. Because the package is installed in editable mode, it picks up new changes automatically after you pull them from the repository.
jaxquantum is also published on PyPI. Run the following command to install it:
pip install jaxquantum
For more details, please visit the getting started > installation section of our docs.
Here's an example of how to set up a simulation in jaxquantum.
from jax import jit
import jaxquantum as jqt
import jax.numpy as jnp
import matplotlib.pyplot as plt

omega_q = 5.0  # qubit frequency (GHz)
Omega = .1     # drive amplitude

# Two-qubit initial state; ^ combines the single-qubit objects into a tensor product.
g_state = jqt.basis(2,0) ^ jqt.basis(2,0)
g_state_dm = g_state.to_dm()

ts = jnp.linspace(0,5*jnp.pi/Omega,101)

# Collapse operator acting on the first qubit only.
c_ops = [0.1*jqt.sigmam()^jqt.identity(N=2)]

# Observable: sigma_z on the first qubit.
sz0 = jqt.sigmaz() ^ jqt.identity(N=2)

@jit
def Ht(t):
    # Static qubit terms plus a resonant cosine drive on both qubits.
    H0 = omega_q/2.0*((jqt.sigmaz()^jqt.identity(N=2)) + (jqt.identity(N=2)^jqt.sigmaz()))
    H1 = Omega*jnp.cos((omega_q)*t)*((jqt.sigmax()^jqt.identity(N=2)) + (jqt.identity(N=2)^jqt.sigmax()))
    return H0 + H1

# Solve the master equation and compute <sigma_z> of the first qubit over time.
states = jqt.mesolve(g_state_dm, ts, c_ops=c_ops, Ht=Ht)
szt = jnp.real(jqt.calc_expect(sz0, states))

fig, ax = plt.subplots(1, dpi=200, figsize=(4,3))
ax.plot(ts, szt)
ax.set_xlabel("Time (ns)")
ax.set_ylabel("<σz(t)>")
fig.tight_layout()
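For reference, the model set up in the example above can be written out as follows. The Hamiltonian and collapse operator are read directly from the code. The form of the master equation is an assumption: it follows the QuTiP convention, in which each collapse operator enters the dissipator as given. Check the jaxquantum docs for the exact convention that mesolve uses.

```latex
H(t) = \frac{\omega_q}{2}\left(\sigma_z \otimes I + I \otimes \sigma_z\right)
     + \Omega \cos(\omega_q t)\left(\sigma_x \otimes I + I \otimes \sigma_x\right),
\qquad
c = 0.1\,\sigma_- \otimes I

% Assumed Lindblad form (QuTiP convention, with \hbar = 1):
\frac{d\rho}{dt} = -i\,[H(t), \rho]
     + c\,\rho\,c^\dagger - \frac{1}{2}\left\{c^\dagger c,\ \rho\right\}
```

The plotted quantity is the expectation value of σz ⊗ I, that is, σz on the first qubit, evaluated at each time in ts.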
Core Devs: Shantanu R. Jha, Shoumik Chowdhury
This package was initially a small part of bosonic. In early 2022, jaxquantum was extracted and made into its own package. It was briefly announced at the APS March Meeting 2023 and released to a select few academic groups shortly after. Since then, it has been open sourced and developed as part of research in the Engineering Quantum Systems Group at MIT, with invaluable advice from Prof. William D. Oliver.
Thank you for taking the time to try our package out. If you found it useful in your research, please cite us as follows:
@software{jha2024jaxquantum,
author = {Shantanu R. Jha and Shoumik Chowdhury and Max Hays and Jeff A. Grover and William D. Oliver},
title = {An auto differentiable and hardware accelerated software toolkit for quantum circuit design, simulation and control},
url = {https://github.com/EQuS/jaxquantum, https://github.com/EQuS/bosonic, https://github.com/EQuS/qcsys},
version = {0.1.0},
year = {2024},
}
S. R. Jha, S. Chowdhury, M. Hays, J. A. Grover, W. D. Oliver. An auto differentiable and hardware accelerated software toolkit for quantum circuit design, simulation and control (2024), in preparation.
This package is open source and, as such, very open to contributions. Please don't hesitate to open an issue, report a bug, request a feature, or create a pull request. We are also open to deeper collaborations to create a tool that is more useful for everyone. If a discussion would be helpful, please email shanjha@mit.edu to set up a meeting.