In [1]:
from jax.experimental import sparse
import jax.numpy as jnp
import jax.scipy as jsp
import jaxquantum as jqt

In [None]:
a = jqt.destroy(100).data
vac = jqt.basis(100, 1).data

a_sp = sparse.BCOO.fromdense(a)
vac_sp = sparse.BCOO.fromdense(vac)

In [3]:
%timeit -n1 -r1 a @ vac
%timeit -n1 -r1 a_sp @ vac_sp

%timeit a @ vac
%timeit a_sp @ vac_sp

14.6 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
724 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
19.6 µs ± 3.42 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.08 ms ± 422 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
sparse.sparsify(jnp.kron)(a_sp, a_sp).todense() - jnp.kron(a, a)

Array([[0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       ...,
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j]],      dtype=complex128)

In [13]:
sparse.sparsify(jnp.conj)(a_sp.T).todense() - jnp.conj(a.T)

Array([[0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       ...,
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j]],      dtype=complex128)

In [26]:
sparse.sparsify(jsp.linalg.exp)(a_sp)

AttributeError: module 'jax.scipy.linalg' has no attribute 'exp'

In [32]:
sparse.linalg._svqb(a_sp)

TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'jax.experimental.sparse.bcoo.BCOO'> at position 0.

In [4]:
jnp.kron(a_sp,a_sp)

TypeError: kron requires ndarray or scalar arguments, got <class 'jax.experimental.sparse.bcoo.BCOO'> at position 0.

In [44]:
a_sp.shape

(100, 100)

In [45]:
import flax.struct as struct
import jax.numpy as jnp
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Union

# ------------------------------
# Abstract Implementation
# ------------------------------
class QarrayImpl(ABC):
    @abstractmethod
    def matmul(self, other: "QarrayImpl") -> "QarrayImpl":
        ...

    @abstractmethod
    def to_dense(self):
        ...


# ------------------------------
# Concrete implementations
# ------------------------------
class DenseImpl(QarrayImpl):
    def __init__(self, data):
        self.data = jnp.asarray(data)

    def matmul(self, other: "DenseImpl") -> "DenseImpl":
        return DenseImpl(self.data @ other.data)

    def to_dense(self):
        return self.data


class SparseImpl(QarrayImpl):
    def __init__(self, data):
        self.data = data  # e.g., jax.experimental.sparse.BCOO

    def matmul(self, other: "SparseImpl") -> "SparseImpl":
        return SparseImpl(self.data @ other.data)

    def to_dense(self):
        return self.data.todense()


# ------------------------------
# Type variable for typing clarity
# ------------------------------
ImplT = TypeVar("ImplT", bound=QarrayImpl)


# ------------------------------
# Public Qarray class
# ------------------------------
@struct.dataclass
class Qarray(Generic[ImplT]):
    impl: ImplT

    @classmethod
    def from_dense(cls, data) -> "Qarray[DenseImpl]":
        return cls(DenseImpl(data))

    @classmethod
    def from_sparse(cls, data) -> "Qarray[SparseImpl]":
        return cls(SparseImpl(data))

    def matmul(self: "Qarray[ImplT]", other: "Qarray[ImplT]") -> "Qarray[ImplT]":
        """For now: only allow matching implementations (Dense@Dense, Sparse@Sparse)."""
        return Qarray(self.impl.matmul(other.impl))  # type: ignore

    def to_dense(self):
        return self.impl.to_dense()


In [47]:
A: Qarray[DenseImpl] = Qarray.from_dense(jnp.eye(2))
B: Qarray[DenseImpl] = Qarray.from_dense(jnp.array([[0, 1], [1, 0]]))

C = A.matmul(B)  # inferred as Qarray[DenseImpl]
# type(C)   # -> Qarray[DenseImpl]

from jax.experimental import sparse
sp = sparse.BCOO.fromdense(jnp.eye(2))

S: Qarray[SparseImpl] = Qarray.from_sparse(sp)
T: Qarray[SparseImpl] = Qarray.from_sparse(sp)

U = S.matmul(T)  # inferred as Qarray[SparseImpl]


In [54]:
from jax import Array
isinstance(U.impl.data, Array)

False

In [55]:
U.impl.data

BCOO(float64[2, 2], nse=4)

In [8]:
# filename: diffrax_jax_sparse_schrodinger.py
# Requirements: jax, diffrax
# (pip install "jax[cpu]" diffrax)  -- or the GPU jax builds if you have one.

import jax
import jax.numpy as jnp
from jax.experimental import sparse
import diffrax

# Optional: enable 64-bit if you want double precision
# from jax.config import config
# config.update("jax_enable_x64", True)

def make_1d_tight_binding_hamiltonian(n_sites: int, hopping: float = -1.0):
    """Return a dense Hamiltonian for convenience; we'll convert to sparse BCOO below."""
    H = jnp.zeros((n_sites, n_sites), dtype=jnp.complex128)
    # nearest-neighbor hopping (tridiagonal)
    offsets = jnp.arange(n_sites - 1)
    H = H.at[offsets, offsets + 1].set(hopping)
    H = H.at[offsets + 1, offsets].set(hopping)
    # optionally add on-site potential (zeros here)
    return H

def main():
    n = 200  # number of sites (make larger to benefit from sparsity)
    H_dense = make_1d_tight_binding_hamiltonian(n)

    # Convert to sparse BCOO (compresses zeros)
    H_bcoo = sparse.BCOO.fromdense(H_dense)   # experimental sparse type in JAX. :contentReference[oaicite:1]{index=1}
    # H_bcoo = H_dense

    # Define the vector field for Schrödinger eqn:
    # i dψ/dt = H ψ  =>  dψ/dt = -i H ψ
    def vector_field(t, y, args):
        # y is complex vector shape (n,)
        # Use sparse matvec to compute H @ y -> dense result
        # For BCOO, the @ operator works as matmul-like; alternatively use sparse.matmul(...)
        Hy = H_bcoo @ y                     # returns a dense vector
        return -1j * Hy                     # RHS is complex

    term = diffrax.ODETerm(vector_field)

    # initial state: localized at center
    y0 = jnp.zeros((n,), dtype=jnp.complex128)
    y0 = y0.at[n // 2].set(1.0 + 0.0j)

    solver = diffrax.Tsit5()  # explicit solver OK for this non-stiff unitary evolution
    saveat = diffrax.SaveAt(t0=True, t1=True)  # just save start and end
    # set dt0 to a fixed initial step size to speed up compilation (optional)
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0.0,
        t1=50.0,
        dt0=0.1,
        y0=y0,
        saveat=saveat,
    )

    psi_final = sol.ys[-1]  # if you saved t0 and t1, ys will be [y(t0), y(t1)] (see docs). :contentReference[oaicite:2]{index=2}
    # check unitarity (norm conservation)
    norm0 = jnp.linalg.norm(y0)
    norm_final = jnp.linalg.norm(psi_final)
    print("||psi(t0)|| =", norm0)
    print("||psi(t1)|| =", norm_final)


In [7]:
main()

  out = fun(*args, **kwargs)


||psi(t0)|| = 1.0
||psi(t1)|| = 0.9999996348072795


In [3]:
from jax import jit
import jaxquantum as jqt 
import jax.numpy as jnp
import matplotlib.pyplot as plt

N = 100
a = jqt.destroy(N); n = a.dag() @ a

omega_a = 2.0*jnp.pi*5.0; H0 = omega_a*n # Hamiltonian

kappa = 2*jnp.pi*jnp.array([1,2]); batched_loss_op = jnp.sqrt(kappa)*a; 
c_ops = jqt.Qarray.from_list([batched_loss_op]) # collapse operators

initial_state = (jqt.displace(N, 0.1) @ jqt.basis(N,0)).to_dm() # initial state

ts = jnp.linspace(0, 4*2*jnp.pi/omega_a, 101) # Time points

solver_options = jqt.SolverOptions.create(progress_meter=True) 
states = jit(jqt.mesolve, static_argnums=(5))(
    H0, initial_state, ts, c_ops=c_ops, solver_options=solver_options) # solve

n_exp = jnp.real(jqt.overlap(n, states)); a_exp = jqt.overlap(a, states) # expectation values

# Plot

fig, axs = plt.subplots(2,1, dpi=200, figsize=(6,5))
ax = axs[0]
ax.plot(ts, jnp.real(a_exp)[:,0], label=r"$Re[\langle a(t)\rangle]$", color="blue") # Batch kappa value 0
ax.plot(ts, jnp.real(a_exp)[:,1], "--", label=r"$Re[\langle a(t)\rangle]$", color="blue") # Batch kappa value 1
ax.plot(ts, jnp.imag(a_exp)[:,0], label=r"$Im[\langle a(t)\rangle]$", color="red") # Batch kappa value 0
ax.plot(ts, jnp.imag(a_exp)[:,1], "--", label=r"$Im[\langle a(t)\rangle]$", color="red") # Batch kappa value 1
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()

ax = axs[1]
ax.plot(ts, n_exp[:,0], label=r"$Re[\langle n(t)\rangle]$", color="green") # Batch kappa value 0
ax.plot(ts, n_exp[:,1], "--", label=r"$Re[\langle n(t)\rangle]$", color="green") # Batch kappa value 1
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()
fig.tight_layout()

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int64[].

The error arose for the nse argument of bcoo_fromdense. In order for
BCOO.fromdense() to be used in traced/compiled code, you must pass a concrete
value to the nse (number of stored elements) argument.

The error occurred while tracing the function mesolve at /Users/phionx/Github/qc/EQuS/bosonic/jax/jaxquantum/jaxquantum/core/solvers.py:109 for jit. This concrete value was not available in Python because it depends on the value of the argument c_ops._data.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [2]:
from jax import jit
import jaxquantum as jqt 
import jax.numpy as jnp
import matplotlib.pyplot as plt

N = 100
a = jqt.destroy(N); n = a.dag() @ a

omega_a = 2.0*jnp.pi*5.0; H0 = omega_a*n # Hamiltonian

kappa = 2*jnp.pi*1.0; batched_loss_op = jnp.sqrt(kappa)*a; 
c_ops = jqt.Qarray.from_list([batched_loss_op]) # collapse operators

initial_state = (jqt.displace(N, 0.1) @ jqt.basis(N,0)).to_dm() # initial state

ts = jnp.linspace(0, 4*2*jnp.pi/omega_a, 101) # Time points

solver_options = jqt.SolverOptions.create(progress_meter=True) 
states = jqt.mesolve(
    H0, initial_state, ts, c_ops=c_ops, solver_options=solver_options) # solve

n_exp = jnp.real(jqt.overlap(n, states)); a_exp = jqt.overlap(a, states) # expectation values

# Plot

fig, axs = plt.subplots(2,1, dpi=200, figsize=(6,5))
ax = axs[0]
ax.plot(ts, jnp.real(a_exp)[:], label=r"$Re[\langle a(t)\rangle]$", color="blue") # Batch kappa value 0
ax.plot(ts, jnp.imag(a_exp)[:], label=r"$Im[\langle a(t)\rangle]$", color="red") # Batch kappa value 0
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()

ax = axs[1]
ax.plot(ts, n_exp[:], label=r"$Re[\langle n(t)\rangle]$", color="green") # Batch kappa value 0
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()
fig.tight_layout()

ValueError: Terms are not compatible with solver!

In [1]:
%load_ext autoreload
%autoreload 2

from jax import jit
import jaxquantum as jqt 
import jax.numpy as jnp
import matplotlib.pyplot as plt

N = 100
a = jqt.destroy(N); n = a.dag() @ a

omega_a = 2.0*jnp.pi*5.0; H0 = omega_a*n # Hamiltonian

kappa = 2*jnp.pi*1.0; batched_loss_op = jnp.sqrt(kappa)*a; 
c_ops = jqt.Qarray.from_list([batched_loss_op]) # collapse operators

initial_state = (jqt.displace(N, 0.1) @ jqt.basis(N,0)) # initial state

ts = jnp.linspace(0, 4*2*jnp.pi/omega_a, 101) # Time points

solver_options = jqt.SolverOptions.create(progress_meter=True, solver="Dopri5") 
states = jqt.sesolve(
    H0, initial_state, ts, solver_options=solver_options) # solve

n_exp = jnp.real(jqt.overlap(n, states)); a_exp = jqt.overlap(a, states) # expectation values

# Plot

fig, axs = plt.subplots(2,1, dpi=200, figsize=(6,5))
ax = axs[0]
ax.plot(ts, jnp.real(a_exp)[:], label=r"$Re[\langle a(t)\rangle]$", color="blue") # Batch kappa value 0
ax.plot(ts, jnp.imag(a_exp)[:], label=r"$Im[\langle a(t)\rangle]$", color="red") # Batch kappa value 0
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()

ax = axs[1]
ax.plot(ts, n_exp[:], label=r"$Re[\langle n(t)\rangle]$", color="green") # Batch kappa value 0
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()
fig.tight_layout()

ValueError: Terms are not compatible with solver!

In [None]:
import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsparse
import diffrax

# Build sparse Hamiltonian (1D Laplacian with Dirichlet BCs)
def make_hamiltonian(N, L=1.0):
    dx = L / (N+1)
    diag = -2.0 * jnp.ones(N)
    offdiag = jnp.ones(N-1)
    H_dense = (1.0 / (2.0 * dx * dx)) * (
        jnp.diag(diag) + jnp.diag(offdiag, 1) + jnp.diag(offdiag, -1)
    )
    return jsparse.BCOO.fromdense(H_dense)

# RHS of the TDSE: dψ/dt = -i H ψ
def schrodinger_rhs(t, psi, H):
    return -1j * (H @ psi)

# Parameters
N = 100
L = 1.0
H_sparse = make_hamiltonian(N, L)

# Initial wavefunction: Gaussian packet centered in the box
x = jnp.linspace(0, L, N+2)[1:-1]  # exclude boundaries
psi0 = jnp.exp(-200 * (x - 0.5*L)**2) + 0j

# Setup ODE solver
term = diffrax.ODETerm(schrodinger_rhs)
solver = diffrax.Dopri5()

sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=0.05,
    dt0=1e-4,
    y0=psi0,
    args=H_sparse,
    saveat=diffrax.SaveAt(ts=jnp.linspace(0, 0.05, 20)),
)

# sol.ys is an array of shape (20, N) with the wavefunction at each time


  out = fun(*args, **kwargs)
