# Introduction to QAX

Author: R.Maekura (ryomaekura@g.ecc.u-tokyo.ac.jp)

## Import package

In [6]:
import jax
import jax.numpy as jnp

import sys
sys.path.append("/home/users/u0001529/ondemand/qax-project")

%load_ext autoreload
%autoreload 2
from qax import state
from qax import operator as op
from qax.utils import linalg
from qax.utils import device

## Quantum states

The matrix representations of commonly used quantum states can be accessed via `qax.state`. \
In addition, basic information such as the shape and type of an operator can be displayed using `qax.utils.display.info`

### state vector

In [7]:
# initial n-qubits state
n = 2
qubit_state = state.n_qubit(n)
qubit_state.info()

Shape: (4,)
Dtype: complex64
Norm: 1.000000


<IPython.core.display.Math object>

In [12]:
# vaccum state
sys_dim = 10
vac_state = state.vacuum(sys_dim)
vac_state.info()

Shape: (10,)
Dtype: complex64
Norm: 1.000000


<IPython.core.display.Math object>

In [9]:
# fock state
sys_dim = 10
n = 5
fock_state = state.fock(n, sys_dim)
fock_state.info()

Shape: (10,)
Dtype: complex64
Norm: 1.000000


<IPython.core.display.Math object>

### density matrix

In QAX, the function `qax.utils.linalg.vec2dm()` is used to convert the state vector into a density matrix.

In [10]:
# density matrix of initial n-qubits state
qubit_dm = linalg.vec2dm(qubit_state)
qubit_dm.info()

Shape: (4, 4)
Dtype: complex64
Hermitian: True


<IPython.core.display.Math object>

In [11]:
# density matrix of fock state
fock_dm = linalg.vec2dm(fock_state)
fock_dm.info()

Shape: (10, 10)
Dtype: complex64
Hermitian: True


<IPython.core.display.Math object>

## Quantum operators

### 2-level system

In [3]:
# pauli operators
## pauli x
pauli_x = op.pauli_x()
pauli_x.info()

Shape: (2, 2)
Dtype: complex64
Is Hermitian: True


<IPython.core.display.Math object>

In [4]:
## pauli y
pauli_y = op.pauli_y()
pauli_y.info()

Shape: (2, 2)
Dtype: complex64
Is Hermitian: True


<IPython.core.display.Math object>

In [5]:
## pauli z
pauli_z = op.pauli_z()
pauli_z.info()

Shape: (2, 2)
Dtype: complex64
Is Hermitian: True


<IPython.core.display.Math object>

### finite-level system

In [3]:
# annihilation operator
sys_dim = 5
a = op.annihilation(sys_dim)
a.info()

Shape: (5, 5)
Dtype: complex64
Hermitian: False


<IPython.core.display.Math object>

In [4]:
# creation operator
a_dag = op.annihilation(sys_dim).dagger()
a_dag.info()

Shape: (5, 5)
Dtype: complex64
Hermitian: False


<IPython.core.display.Math object>

In [6]:
# position operator
x = op.position(sys_dim)
x.info()

Shape: (5, 5)
Dtype: complex64
Hermitian: True


<IPython.core.display.Math object>

In [7]:
# momentum operator
p = op.momentum(sys_dim)
p.info()

Shape: (5, 5)
Dtype: complex64
Hermitian: True


<IPython.core.display.Math object>

## Time evolution

As a demonstration, the time evolution of the system is calculated using the following Jaynes–Cummings Hamiltonian.
$$
\begin{align*}
H &= \omega_{\rm c} \hat{a}^\dagger \hat{a} - \frac{1}{2} \omega_{\rm a} \sigma_{\rm z} + g \left( \hat{a} \sigma_+ + \hat{a}^\dagger \sigma_- \right)
\end{align*}
$$

In [None]:
# Jaynes-Cumming Hamiltonian
## physical parameter settings
omega_c = 1.0  
omega_a = 1.0  
g = 0.1

## operator settings |atom_state, cavity_state>
cavity_dim = 100
atom_dim = 2

pauli_z = linalg.kron_prod(op.pauli_z(), op.identity(cavity_dim))
a = linalg.kron_prod(op.identity(atom_dim), op.creation(cavity_dim))
a_dag = linalg.kron_prod(op.identity(atom_dim), op.annihilation(cavity_dim))

H = omega_c*a@a_dag - 0.5*omega_a*pauli_z + g*(jnp.dot()+jnp.dot())

### time-dependent Schrödinger equation

$$
\begin{align*}
ih \frac{\partial}{\partial t} \ket{\Psi} &= \hat{H} \ket{\Psi}
\end{align*}
$$

In [None]:
from qax.dynamics.solver import schrodinger
from qax.utils import linalg

In [None]:
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

class SchrodingerEquation:
    '''
    Compute time-dependent Schrödinger equation.
    '''
    def __init__(self, 
                 hbar = 1,
                 H: Operator, 
                 init_state: StateVector, 
                 t_range: tuple,
                 t_step: float,
                 device: str) -> None:
        '''
        Initializes the solver for the Schrödinger equation.

        Args:
            Hamiltonian (Operator | Callable): The system's Hamiltonian. Can be a static Operator or a time-dependent function H(t).
            initial_state (StateVector): The initial state vector at t0.
            t_span (tuple[float, float]): The start and end time (t0, t1).
            hbar (float, optional): Planck's constant. Defaults to 1.0.
            target_device (str | jax.Device | None, optional): The device to run the computation on ('cpu', 'gpu', etc.).\
                                                               If None, JAX's default device is used. Defaults to None.
        '''
        self.H = H

    @jax.jit
    def run(self, ) -> jax.Array:
        '''
        
        '''
        rhs = lambda t, y, args: -1.0j * jnp.dot(H, y)
        eq_term = ODETerm(rhs)
        method = Dopri5()
        return 

    @jax.jit
    def observable_ev(self, obs: jax.Array) -> jax.Array:
        '''
        Compute expectation value of observable.
        '''
        return 

In [None]:
class LindbladMasterEquation:
    def __init__(self,
                 hbar = 1,
                 H: Operator,
                 collapse_ops: ,
                 t_range: tuple,
                 t_step: int,
                 device: str):

    def solver(self, ):
        rhs = lambda t, y, args: -

In [12]:
@jax.jit
def heavy_computation(x: jax.Array) -> jax.Array:
    for _ in range(5):
        x = jnp.tanh(jnp.dot(x, x.T) / 1000.0)
    return x.mean()

# デバイス選択ロジックを分離した、メインの実行関数
def run_computation(data: jax.Array, on_device: str):
    """
    device.pyを使ってデバイスを取得し、計算を実行する。
    """
    try:
        # device.pyの関数でデバイスオブジェクトを取得
        target_device = device.get_device(on_device)
        print(f"--- Running computation on: {target_device} ---")
    except ValueError as e:
        print(e)
        return

    # withブロックで計算を実行
    with jax.default_device(target_device):
        data_on_device = jax.device_put(data)
        
        start_time = time.time()
        result = heavy_computation(data_on_device).block_until_ready()
        end_time = time.time()
        
        print(f"Computation finished in {end_time - start_time:.4f} seconds.")
        print(f"Result: {result}")

# --- 実行 ---
key = jax.random.PRNGKey(0)
large_array = jax.random.normal(key, (20000, 20000))

# 引数で'cpu'や'gpu'を切り替えるだけで、実行デバイスを制御できる
run_computation(large_array, on_device='cpu')
print("-" * 20)
run_computation(large_array, on_device='gpu') # GPUがなければエラーメッセージが表示される

--- Running computation on: TFRT_CPU_0 ---
Computation finished in 15.6389 seconds.
Result: 7.222556434992103e-24
--------------------
--- Running computation on: cuda:0 ---
Computation finished in 0.2711 seconds.
Result: 7.211465050664695e-24


In [1]:
import jax
import jax.numpy as jnp
import time

# JITでコンパイルする計算の本体（前回の修正版）
@jax.jit
def heavy_computation(x: jax.Array) -> jax.Array:
    """A sample heavy computation that is numerically stable."""
    for _ in range(5):
        x = jnp.tanh(jnp.dot(x, x.T) / 1000.0)
    return x.mean()


# デバイスを選択して計算を実行するメインの関数
def run_computation_on_device(data: jax.Array, device_str: str) -> jax.Array | None:
    """
    指定されたデバイスで計算を実行し、結果を返す関数。
    """
    print(f"--- Computation requested on: '{device_str}' ---")
    try:
        target_device = jax.devices(device_str)[0]
        print(f"Found device: {target_device}")
    except IndexError:
        print(f"Error: Device '{device_str}' not found.")
        print(f"Available devices are: {jax.devices()}")
        return None

    with jax.default_device(target_device):
        print(f"Placing data onto {target_device}...")
        data_on_device = jax.device_put(data)
        
        # ★★★ ここが修正点 ★★★
        # 古い .device() メソッドから、新しい .devices() メソッドを使った形式に変更
        device_set = data_on_device.devices()
        if device_set:
            print(f"Data is now on: {list(device_set)[0]}")
        else:
            print("Could not determine the device for the data.")

        print("Running computation...")
        start_time = time.time()
        result = heavy_computation(data_on_device).block_until_ready()
        end_time = time.time()
        
        duration = end_time - start_time
        print(f"Computation finished in {duration:.4f} seconds.")

    return result

# --- 関数の使い方 ---

# 大きなサンプルデータを作成
key = jax.random.PRNGKey(0)
large_array = jax.random.normal(key, (20000, 20000))

# CPUで実行
result_cpu = run_computation_on_device(large_array, 'cpu')
if result_cpu is not None:
    print(f"Result on CPU: {result_cpu}\n")

# GPUで実行 (GPUが利用可能な場合)
result_gpu = run_computation_on_device(large_array, 'gpu')
if result_gpu is not None:
    print(f"Result on GPU: {result_gpu}\n")

--- Computation requested on: 'cpu' ---
Found device: TFRT_CPU_0
Placing data onto TFRT_CPU_0...
Data is now on: cuda:0
Running computation...
Computation finished in 16.0898 seconds.
Result on CPU: 7.222556434992103e-24

--- Computation requested on: 'gpu' ---
Found device: cuda:0
Placing data onto cuda:0...
Data is now on: cuda:0
Running computation...
Computation finished in 48.1150 seconds.
Result on GPU: 7.211465050664695e-24



In [20]:
import jax
import jax.numpy as jnp
import time

# GPUが利用可能か確認し、デバイスオブジェクトを取得
try:
    gpu = jax.devices('gpu')[0]
    print(f"GPUが見つかりました: {gpu}")
except IndexError:
    print("GPUが見つかりませんでした。この例はCPUで実行されます。")
    gpu = jax.devices('cpu')[0]

key = jax.random.PRNGKey(0)

# 大きな行列を準備
A = jax.random.normal(key, (2000, 3000))
B = jax.random.normal(jax.random.split(key)[0], (3000, 2500))

# データを明示的にターゲットデバイスに配置
A_on_device = jax.device_put(A, gpu)
B_on_device = jax.device_put(B, gpu)
print(f"入力データはデバイス {list(A_on_device.devices())[0]} に配置されました。")

# ---------------------------------------------------
# 2つの方法をJITコンパイルした関数として定義
# ---------------------------------------------------

@jax.jit
def matmul_with_at_operator(m1, m2):
    return m1 @ m2

@jax.jit
def matmul_with_jnp_function(m1, m2):
    return jnp.matmul(m1, m2)

# --- 実行と検証 ---

# '@' 演算子を使った関数を実行
print("\n'@' 演算子で計算を実行中...")
result_at = matmul_with_at_operator(A_on_device, B_on_device)
result_at.block_until_ready() # 計算の完了を待つ
print(f"出力デバイス: {list(result_at.devices())[0]}") # 結果がどのデバイスにあるか確認

# 'jnp.matmul' を使った関数を実行
print("\n'jnp.matmul' 関数で計算を実行中...")
result_func = matmul_with_jnp_function(A_on_device, B_on_device)
result_func.block_until_ready() # 計算の完了を待つ
print(f"出力デバイス: {list(result_func.devices())[0]}")

# 結果が一致することを確認
print(f"\n両者の結果は一致するか？ {jnp.allclose(result_at, result_func)}")

GPUが見つかりました: cuda:0
入力データはデバイス cuda:0 に配置されました。

'@' 演算子で計算を実行中...
出力デバイス: cuda:0

'jnp.matmul' 関数で計算を実行中...
出力デバイス: cuda:0

両者の結果は一致するか？ True


### master equation