In [4]:
import jax.numpy as jnp
from jax import jit
from functools import partial

In [5]:
import jax

# 全デバイスを取得
devices = jax.devices()

# GPUデバイスだけを抽出
gpu_devices = [device for device in devices if device.platform == "gpu"]

# GPUの数を出力
print("使用可能なGPUの数:", len(gpu_devices))

使用可能なGPUの数: 1


In [6]:
# 良い例2

@partial(jit, static_argnums=(0,))
def variable_jax_dot_deco(size):
    x = jnp.arange(size**2, dtype=jnp.float32).reshape(size, size)
    x_gram = jnp.dot(x, x.T)
    return x_gram

%timeit variable_jax_dot_deco(5).block_until_ready()

51.1 μs ± 298 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [3]:
def identity(dim: int) -> jnp.ndarray:
    """
    Returns the Pauli X (sigma_x) operator.
    The Pauli X operator is a 2x2 matrix defined as:
    [[0, 1],
     [1, 0]]
    
    Returns:
        jnp.array: A 2x2 matrix of the Pauli X operator with complex64 precision.
    """
    return jnp.identity(dim, dtype=jnp.complex64)

In [5]:
def pauli_x() -> jnp.ndarray:
    """
    Returns the Pauli X (sigma_x) operator.
    The Pauli X operator is a 2x2 matrix defined as:
    [[0, 1],
     [1, 0]]
    
    Returns:
        jnp.array: A 2x2 matrix of the Pauli X operator with complex64 precision.
    """
    return jnp.array([[0+0j, 1j+0j], [1j+0j, 0+0j]], dtype=jnp.complex64)

In [7]:
pauli_x()

Array([[0.+0.j, 0.+1.j],
       [0.+1.j, 0.+0.j]], dtype=complex64)

In [2]:
import jax.numpy as jnp
from functools import reduce

def x_gate_on_qubit(i: int, n_qubits: int) -> jnp.ndarray:
    """
    全 n_qubits 系において、0-indexedで i 番目の量子ビットにXゲートを作用させる行列を生成します。

    Args:
        i (int): Xゲートを作用させる量子ビットのインデックス（0-indexed）
        n_qubits (int): 全体の量子ビット数

    Returns:
        jnp.ndarray: サイズ (2**n_qubits, 2**n_qubits) の作用素行列
    """
    # Xゲートの定義
    X = jnp.array([[0, 1],
                   [1, 0]], dtype=jnp.complex64)
    # 2次元単位行列
    I = jnp.eye(2, dtype=jnp.complex64)
    
    # 各量子ビットに対する作用素のリストを作成
    ops = [X if j == i else I for j in range(n_qubits)]
    
    # functools.reduce を用いてテンソル積（Kronecker積）で全体作用素を構成
    full_operator = reduce(jnp.kron, ops)
    return full_operator

# 使用例: 3量子ビット系の2番目（0-indexed: i=1）にXゲートを作用させる場合
if __name__ == "__main__":
    operator_matrix = x_gate_on_qubit(1, 2)
    print(operator_matrix)

[[0.+0.j 1.+0.j 0.+0.j 0.+0.j]
 [1.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 1.+0.j]
 [0.+0.j 0.+0.j 1.+0.j 0.+0.j]]


In [None]:
import jax.numpy as jnp
from jax import jit
from functools import partial

def X(n_qubit: int, target_qubit_idx: int) -> jnp.ndarray:
    

In [11]:
jax.local_device_count()

1

In [12]:
jax.devices()

[CudaDevice(id=0)]

In [15]:
import jax.numpy as jnp
from jax import random, pmap

# 8枚のGPUがある環境で行列積を並列実行します。
# 8つのランダムな5000x6000 行列を定義します。
keys = random.split(random.PRNGKey(0), jax.local_device_count())
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# 明示的なデータ転送なく、行列積を各GPUで並列実行できます。
%timeit result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape = (8, 5000, 5000)

# 各GPUの計算結果の平均を求めます。
print(pmap(jnp.mean)(result))

9.5 ms ± 57.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[1.2109671]


In [17]:
%timeit jnp.dot(mats, mats.T)

TypeError: dot_general requires contracting dimensions to have the same shape, got (6000,) and (5000,).

In [18]:
import jax.numpy as jnp
from jax import jit
from functools import partial

def X(n_qubit: int, target_qubit_idx: int) -> jnp.ndarray:
    

# 1-qubit gate
## pauli X
@partial(jit, static_argnums=(0, 0, ))
def X(n_qubit: int, target_qubit_idx: int) -> jnp.ndarray:
    I = identity
    local_X = sigma_x()
    if target_qubit_idx==0:
        mat = local_X
    else:
        mat = I
    for i in range(n_qubit-1):
        if i+1==target_qubit_idx:
            mat = jnp.kron(mat, local_X)
        else:
            mat = jnp.kron(mat, I)
            
    return mat

In [22]:
jit(X(2, 0)).block_until_ready()

TypeError: Expected a callable value, got [[0. 0. 1. 0.]
 [0. 0. 0. 1.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]]

In [13]:
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial, reduce

def identity(dim: int) -> jnp.ndarray:
    return jnp.identity(dim, dtype=jnp.complex64)

def pauli_x() -> jnp.ndarray:
    return jnp.array([[0+0j, 1+0j], [1+0j, 0+0j]], dtype=jnp.complex64)

@partial(jit, static_argnums=(0, 1))
def X(target_qubit_idx: int, n_qubits: int) -> jnp.ndarray:
    ops = [pauli_x() if i == target_qubit_idx else identity(dim=2) for i in range(n_qubits)]
    mat = reduce(jnp.kron, ops)
    return mat

In [None]:
%timeit X(0, 10)

In [8]:
def X(target_qubit_idx: int, n_qubits: int) -> jnp.ndarray:
    ops = [pauli_x() if i == target_qubit_idx else identity(dim=2) for i in range(n_qubits)]
    mat = reduce(jnp.kron, ops)
    return mat

In [9]:
%timeit X(0, 10)

5.33 ms ± 192 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%load_ext autoreload
%autoreload 2
import gate

gate.X(0, 5)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


ModuleNotFoundError: No module named 'gate'

In [13]:
import jax

def body_fun(i, val):
    return val + i  # i番目のループで構成した値を返却 → 次のループで使用されるvalとなる

# 最終ループで構成したvalがresultに格納される
result = jax.lax.fori_loop(
    1,        # i=0から始める
    10,       # i<10まで実行する
    body_fun, # for文の中で実行する関数
    0,        # 実行する関数に渡すvalの初期値
) 

print(result)

45


In [25]:
def X(target_qubit_idx: int, n_qubits: int) -> jnp.ndarray:
    
    init_ops = jax.lax.cond(target_qubit_idx==0, lambda: pauli_x(), lambda: identity(2))

    def loop_kron(i, acc):
        ops = jax.lax.cond(target_qubit_idx==i, lambda: pauli_x(), lambda: identity(2))
        return jnp.kron(acc, ops)

    tot_mat = jax.lax.fori_loop(1, n_qubits, loop_kron, init_ops)
    return tot_mat

In [26]:
%timeit X(0, 10)

TypeError: scan body function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:

The input carry component loop_carry[1] has type complex64[2,2] but the corresponding output carry component has type complex64[4,4], so the shapes do not match.

Revise the function so that all output types (e.g. shapes and dtypes) match the corresponding input types.

In [17]:
import jax
import jax.numpy as jnp

def identity(dim: int) -> jnp.ndarray:
    return jnp.identity(dim, dtype=jnp.complex64)

def pauli_x() -> jnp.ndarray:
    return jnp.array([[0+0j, 1+0j],
                      [1+0j, 0+0j]], dtype=jnp.complex64)

def X(target_qubit_idx: int, n_qubits: int) -> jnp.ndarray:
    # 最終的な行列サイズは 2^n_qubits × 2^n_qubits
    final_dim = 2 ** n_qubits

    # 初期状態：量子ビット 0 に対応するオペレータを設定し、最終サイズのゼロ行列の左上に埋め込む
    init_op = jax.lax.cond(
        target_qubit_idx == 0,
        lambda _: pauli_x(),
        lambda _: identity(2),
        operand=None
    )
    acc0 = jnp.zeros((final_dim, final_dim), dtype=jnp.complex64)
    acc0 = jax.lax.dynamic_update_slice(acc0, init_op, (0, 0))
    
    def loop_kron(i, acc):
        # 現在の有効部分のサイズ（2^i × 2^i）
        current_size = 2 ** i
        current = jax.lax.dynamic_slice(acc, (0, 0), (current_size, current_size))
        # i 番目の量子ビットに対するオペレータを選択
        op = jax.lax.cond(
            target_qubit_idx == i,
            lambda _: pauli_x(),
            lambda _: identity(2),
            operand=None
        )
        # 現在のオペレータとのクロネッカー積を計算（サイズは 2^(i+1) × 2^(i+1) になる）
        new_val = jnp.kron(current, op)
        # 最終サイズのゼロ行列を作成し、左上に new_val を埋め込むことで carry の型を固定
        new_acc = jnp.zeros((final_dim, final_dim), dtype=jnp.complex64)
        new_acc = jax.lax.dynamic_update_slice(new_acc, new_val, (0, 0))
        return new_acc

    # i = 1 から n_qubits - 1 までループ
    tot_mat = jax.lax.fori_loop(1, n_qubits, loop_kron, acc0)
    return tot_mat

In [18]:
%timeit X(0, 10)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function loop_kron at /tmp/ipykernel_637915/777855946.py:25 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].
The error occurred while tracing the function loop_kron at /tmp/ipykernel_637915/777855946.py:25 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].

## quantum gate

In [4]:
import jax.numpy as jnp
from jax import jit
from functools import partial

# 1-qubit gate
## pauli X
@partial(jit, static_argnums=(0, 0, ))
def X(n_qubit: int, target_qubit_idx: int) -> jnp.ndarray:
    I = jnp.eye(2)
    local_X = jnp.array([[0,1],[1,0]])
    if target_qubit_idx==0:
        mat = local_X
    else:
        mat = I
    for i in range(n_qubit-1):
        if i+1==target_qubit_idx:
            mat = jnp.kron(mat, local_X)
        else:
            mat = jnp.kron(mat, I)
            
    return mat

## pauli Y
def Y(n_qubit, target_qubit_idx):
    I = jnp.eye(2)
    local_Y = np.array([[0,-1j], [1j,0]])
    if target_qubit_idx==0:
        mat = local_Y
    else:
        mat = I
    for i in range(n_qubit-1):
        if i+1==target_qubit_idx:
            mat = jnp.kron(mat, local_Y)
        else:
            mat = jnp.kron(mat, I)
            
    return mat

## pauli Z
def Z(n_qubit, target_qubit_idx):
    I = np.eye(2)
    local_Z = jnp.array([[1,0], [0,-1]])
    if target_qubit_idx==0:
        mat = local_Z
    else:
        mat = I
    for i in range(n_qubit-1):
        if i+1==target_qubit_idx:
            mat = jnp.kron(mat, local_Z)
        else:
            mat = jnp.kron(mat, I)
            
    return mat

## Hadamard gate
def H(n_qubit, target_qubit_idx):
    I = np.eye(2)
    local_H = jnp.array([[1,1], [1,-1]]) / np.sqrt(2)
    if target_qubit_idx==0:
        mat = local_H
    else:
        mat = I
    for i in range(n_qubit-1):
        if i+1==target_qubit_idx:
            mat = jnp.kron(mat, local_H)
        else:
            mat = jnp.kron(mat, I)
            
    return mat

## S gate
def S(n_qubit, target_qubit_idx):
    I = np.eye(2)
    local_S = jnp.array([[1,0], [0,1j]])
    if target_qubit_idx==0:
        mat = local_S
    else:
        mat = I
    for i in range(n_qubit-1):
        if i+1==target_qubit_idx:
            mat = jnp.kron(mat, local_S)
        else:
            mat = jnp.kron(mat, I)
            
    return mat

## T gate
def T(n_qubit, target_qubit_idx):
    I = np.eye(2)
    local_T = jnp.array([[1,0], [0,-np.exp(1j*np.pi/4)]])
    if target_qubit_idx==0:
        mat = local_T
    else:
        mat = I
    for i in range(n_qubit-1):
        if i+1==target_qubit_idx:
            mat = jnp.kron(mat, local_T)
        else:
            mat = jnp.kron(mat, I)
            
    return mat

## Rx gate
def Rx(n_qubit, target_qubit_idx, theta):
    I = np.eye(2)
    local_Rx = jnp.array([[np.cos(theta/2),-1j*np.sin(theta/2)], [-1j*np.sin(theta/2),np.cos(theta/2)]])
    if target_qubit_idx==0:
        mat = local_Rx
    else:
        mat = I
    for i in range(n_qubit-1):
        if i+1==target_qubit_idx:
            mat = jnp.kron(mat, local_Rx)
        else:
            mat = jnp.kron(mat, I)
            
    return mat

## Ry gate
def Ry(n_qubit, target_qubit_idx, theta):
    I = np.eye(2)
    local_Ry = jnp.array([[np.cos(theta/2),-np.sin(theta/2)], [-np.sin(theta/2),np.cos(theta/2)]])
    if target_qubit_idx==0:
        mat = local_Ry
    else:
        mat = I
    for i in range(n_qubit-1):
        if i+1==target_qubit_idx:
            mat = jnp.kron(mat, local_Ry)
        else:
            mat = jnp.kron(mat, I)
            
    return mat

## Rz gate
def Rz(n_qubit, target_qubit_idx, theta):
    I = np.eye(2)
    local_Rz = np.array([[np.exp(-1j*theta/2),0], [0,np.cos(1j*theta/2)]])
    if target_qubit_idx==0:
        mat = local_Rz
    else:
        mat = I
    for i in range(n_qubit-1):
        if i+1==target_qubit_idx:
            mat = np.kron(mat, local_Rz)
        else:
            mat = np.kron(mat, I)
            
    return mat

# 2-qubit gate
## CX gate
def CX(n_qubit, control_qubit_idx, target_qubit_idx):
    I = np.eye(2)
    ket_0 = np.array([[1],[0]]) 
    ket_1 = np.array([[0],[1]])
    
    mat_00 = ket_0 @ ket_0.T.conjugate() ### |0><0|
    mat_11 = ket_1 @ ket_1.T.conjugate() ### |1><1|
    
    eye_tensor = I
    
    for i in range(n_qubit-2):
        eye_tensor = np.kron(eye_tensor, I)
    
    if control_qubit_idx < target_qubit_idx:
        ### |0><0|　\otimes I \otimes ... \otimes I
        cx_mat_term1 = np.kron(mat_00, eye_tensor)
        ### |1><1| \otimes I ... \otimes X ...
        cx_mat_term2 = np.kron(mat_11, X(n_qubit-1, target_qubit_idx-1))
        
    if control_qubit_idx > target_qubit_idx:
        ### I \otimes ... \otimes I \otimes |0><0|
        cx_mat_term1 = np.kron(eye_tensor, mat_00)
        ### ... \otimes X ... \otimes |1><1|
        cx_mat_term2 = np.kron(X(n_qubit-1, target_qubit_idx), mat_11)
    
    mat = cx_mat_term1 + cx_mat_term2
    
    return mat

## CY gate
def CY(n_qubit, control_qubit_idx, target_qubit_idx):
    I = np.eye(2)
    ket_0 = np.array([[1],[0]]) 
    ket_1 = np.array([[0],[1]])
    
    mat_00 = ket_0 @ ket_0.T.conjugate() ### |0><0|
    mat_11 = ket_1 @ ket_1.T.conjugate() ### |1><1|
    
    eye_tensor = I
    
    for i in range(n_qubit-2):
        eye_tensor = np.kron(eye_tensor, I)
    
    if control_qubit_idx < target_qubit_idx:
        ### |0><0|　\otimes I \otimes ... \otimes I
        cx_mat_term1 = np.kron(mat_00, eye_tensor)
        ### |1><1| \otimes ... \otimes Y ...
        cx_mat_term2 = np.kron(mat_11, Y(n_qubit-1, target_qubit_idx-1))
        
    if control_qubit_idx > target_qubit_idx:
        ### I \otimes ... \otimes I \otimes |0><0|
        cx_mat_term1 = np.kron(eye_tensor, mat_00)
        ### ... \otimes Y ... \otimes |1><1|
        cx_mat_term2 = np.kron(Y(n_qubit-1, target_qubit_idx), mat_11)
    
    mat = cx_mat_term1 + cx_mat_term2
    
    return mat

## CZ gate
def CZ(n_qubit, control_qubit_idx, target_qubit_idx):
    I = np.eye(2)
    ket_0 = np.array([[1],[0]]) 
    ket_1 = np.array([[0],[1]])
    
    mat_00 = ket_0 @ ket_0.T.conjugate() ### |0><0|
    mat_11 = ket_1 @ ket_1.T.conjugate() ### |1><1|
    
    eye_tensor = I
    
    for i in range(n_qubit-2):
        eye_tensor = np.kron(eye_tensor, I)
    
    if control_qubit_idx < target_qubit_idx:
        ### |0><0|　\otimes I \otimes ... \otimes I
        cx_mat_term1 = np.kron(mat_00, eye_tensor)
        ### |1><1| \otimes I ... \otimes Z ...
        cx_mat_term2 = np.kron(mat_11, Z(n_qubit-1, target_qubit_idx-1))
        
    if control_qubit_idx > target_qubit_idx:
        ### I \otimes ... \otimes I \otimes |0><0|
        cx_mat_term1 = np.kron(eye_tensor, mat_00)
        ### ... \otimes Z ... \otimes |1><1|
        cx_mat_term2 = np.kron(Z(n_qubit-1, target_qubit_idx), mat_11)
    
    mat = cx_mat_term1 + cx_mat_term2
    
    return mat

## CH gate
def CH(n_qubit, control_qubit_idx, target_qubit_idx):
    I = np.eye(2)
    ket_0 = np.array([[1],[0]]) 
    ket_1 = np.array([[0],[1]])
    
    mat_00 = ket_0 @ ket_0.T.conjugate() ### |0><0|
    mat_11 = ket_1 @ ket_1.T.conjugate() ### |1><1|
    
    eye_tensor = I
    
    for i in range(n_qubit-2):
        eye_tensor = np.kron(eye_tensor, I)
    
    if control_qubit_idx < target_qubit_idx:
        ### |0><0|　\otimes I \otimes ... \otimes I
        cx_mat_term1 = np.kron(mat_00, eye_tensor)
        ### |1><1| \otimes I ... \otimes H ...
        cx_mat_term2 = np.kron(mat_11, H(n_qubit-1, target_qubit_idx-1))
        
    if control_qubit_idx > target_qubit_idx:
        ### I \otimes ... \otimes I \otimes |0><0|
        cx_mat_term1 = np.kron(eye_tensor, mat_00)
        ### ... \otimes H ... \otimes |1><1|
        cx_mat_term2 = np.kron(H(n_qubit-1, target_qubit_idx), mat_11)
    
    mat = cx_mat_term1 + cx_mat_term2
    
    return mat

## CRx gate
def CRx(n_qubit, control_qubit_idx, target_qubit_idx, theta):
    I = np.eye(2)
    ket_0 = np.array([[1],[0]]) 
    ket_1 = np.array([[0],[1]])
    
    mat_00 = ket_0 @ ket_0.T.conjugate() ### |0><0|
    mat_11 = ket_1 @ ket_1.T.conjugate() ### |1><1|
    
    eye_tensor = I
    
    for i in range(n_qubit-2):
        eye_tensor = np.kron(eye_tensor, I)
    
    if control_qubit_idx < target_qubit_idx:
        ### |0><0|　\otimes I \otimes ... \otimes I
        cx_mat_term1 = np.kron(mat_00, eye_tensor)
        ### |1><1| \otimes I ... \otimes Rx ...
        cx_mat_term2 = np.kron(mat_11, Rx(n_qubit-1, target_qubit_idx-1, theta))
        
    if control_qubit_idx > target_qubit_idx:
        ### I \otimes ... \otimes I \otimes |0><0|
        cx_mat_term1 = np.kron(eye_tensor, mat_00)
        ### ... \otimes Rx ... \otimes |1><1|
        cx_mat_term2 = np.kron(Rx(n_qubit-1, target_qubit_idx, theta), mat_11)
    
    mat = cx_mat_term1 + cx_mat_term2
    
    return mat

## CRy gate
def CRy(n_qubit, control_qubit_idx, target_qubit_idx, theta):
    I = np.eye(2)
    ket_0 = np.array([[1],[0]]) 
    ket_1 = np.array([[0],[1]])
    
    mat_00 = ket_0 @ ket_0.T.conjugate() ### |0><0|
    mat_11 = ket_1 @ ket_1.T.conjugate() ### |1><1|
    
    eye_tensor = I
    
    for i in range(n_qubit-2):
        eye_tensor = np.kron(eye_tensor, I)
    
    if control_qubit_idx < target_qubit_idx:
        ### |0><0|　\otimes I \otimes ... \otimes I
        cx_mat_term1 = np.kron(mat_00, eye_tensor)
        ### |1><1| \otimes I ... \otimes Ry ...
        cx_mat_term2 = np.kron(mat_11, Ry(n_qubit-1, target_qubit_idx-1, theta))
        
    if control_qubit_idx > target_qubit_idx:
        ### I \otimes ... \otimes I \otimes |0><0|
        cx_mat_term1 = np.kron(eye_tensor, mat_00)
        ### ... \otimes Ry ... \otimes |1><1|
        cx_mat_term2 = np.kron(Ry(n_qubit-1, target_qubit_idx, theta), mat_11)
    
    mat = cx_mat_term1 + cx_mat_term2
    
    return mat

## CRz gate
def CRz(n_qubit, control_qubit_idx, target_qubit_idx, theta):
    I = np.eye(2)
    ket_0 = np.array([[1],[0]]) 
    ket_1 = np.array([[0],[1]])
    
    mat_00 = ket_0 @ ket_0.T.conjugate() ### |0><0|
    mat_11 = ket_1 @ ket_1.T.conjugate() ### |1><1|
    
    eye_tensor = I
    
    for i in range(n_qubit-2):
        eye_tensor = np.kron(eye_tensor, I)
    
    if control_qubit_idx < target_qubit_idx:
        ### |0><0|　\otimes I \otimes ... \otimes I
        cx_mat_term1 = np.kron(mat_00, eye_tensor)
        ### |1><1| \otimes I ... \otimes Rz ...
        cx_mat_term2 = np.kron(mat_11, Rz(n_qubit-1, target_qubit_idx-1, theta))
        
    if control_qubit_idx > target_qubit_idx:
        ### I \otimes ... \otimes I \otimes |0><0|
        cx_mat_term1 = np.kron(eye_tensor, mat_00)
        ### ... \otimes Rz ... \otimes |1><1|
        cx_mat_term2 = np.kron(Rz(n_qubit-1, target_qubit_idx, theta), mat_11)
    
    mat = cx_mat_term1 + cx_mat_term2
    
    return mat

## SWAP gate
def SWAP(n_qubit, qubit_idx_1, qubit_idx_2):
    mat = CX(n_qubit, qubit_idx_1, qubit_idx_2) @ CX(n_qubit, qubit_idx_2, qubit_idx_1)
    mat = mat @ CX(n_qubit, qubit_idx_1, qubit_idx_2)
    
    return mat

# 3-qubit gate
## toffoli gate
"""
def toffoli(n_qubit, control_qubit_idx_1, control_qubit_idx_2, target_qubit_idx):
"""

def global_depolarizing_error(state, n_qubit, error_rate):
    return (1-error_rate)*state + error_rate*np.trace(state)*np.eye(2**n_qubit)/(2**n_qubit)
    
def unitary_error(state, n_qubit, theta, target_qubit_idx):
    return Rx(n_qubit, target_qubit_idx, theta) @ state @ Rx(n_qubit, target_qubit_idx, theta).T.conjugate()

In [5]:
X(2, 0).block_until_ready()

TypeError: X() takes 2 positional arguments but 3 were given

In [3]:
import polars as pl
import numpy as np
df = pl.DataFrame(
    {
        "nrs": [1, 2, 3, None, 5],
        "names": ["foo", "ham", "spam", "egg", None],
        "random": np.random.rand(5),
        "groups": ["A", "A", "B", "C", "B"],
    }
)

shape: (5, 4)
┌──────┬───────┬──────────┬────────┐
│ nrs  ┆ names ┆ random   ┆ groups │
│ ---  ┆ ---   ┆ ---      ┆ ---    │
│ i64  ┆ str   ┆ f64      ┆ str    │
╞══════╪═══════╪══════════╪════════╡
│ 1    ┆ foo   ┆ 0.703911 ┆ A      │
│ 2    ┆ ham   ┆ 0.091943 ┆ A      │
│ 3    ┆ spam  ┆ 0.023759 ┆ B      │
│ null ┆ egg   ┆ 0.434993 ┆ C      │
│ 5    ┆ null  ┆ 0.376646 ┆ B      │
└──────┴───────┴──────────┴────────┘


In [4]:
def gen_dataset()

nrs,names,random,groups
i64,str,f64,str
1.0,"""foo""",0.703911,"""A"""
2.0,"""ham""",0.091943,"""A"""
3.0,"""spam""",0.023759,"""B"""
,"""egg""",0.434993,"""C"""
5.0,,0.376646,"""B"""


In [None]:
import jax
import jax.numpy as jnp
from flax import nnx

class RBM(nnx.Module):
    def __init__(self, num_vnodes: int, num_hnodes: int):
        """
        Define the parameters of RBM.
        Args:
            num_vnodes (int): number of visible nodes
            num_hnodes (int): number of hidden nodes
        """
        self.num_vnodes = num_vnodes
        self.num_hnodes = num_hnodes
        self.W = nnx.Param(jnp.zeros((num_vnodes, num_hnodes)))
        self.v_bias = nnx.Param(jnp.zeros(num_vnodes))
        self.h_bias = nnx.Param(jnp.zeros(num_hnodes))

    def free_energy(self, v: jnp.ndarray, h:jnp.ndarray) -> jnp.ndarray:
        """
        Compute energy function
        Args:
            v (jnp.ndarray): 
            h (jnp.ndarray): 
        """
        term_1 = jnp.dot(v, jnp.dot(self.W, h))
        term_2 = jnp.dot(self.v_bias, v)
        term_3 = jnp.dot(self.h_bias, h)
        
        return -(term_1+term_2+term3)

    def grad_energy

In [None]:
class RBM4PureState(nnx.Module):

class RBM4MixedState(nnx.Module):
    

In [1]:
import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)

In [2]:
grads

Linear(weight=f32[3,2], bias=f32[3])

In [None]:
class RBM(eqx.Module):
    def __init__(self, num_vnode: int, num_hnode: int, rand_key: int):
        self.num_vnode = num_vnode
        self.num_hnode = num_hnode

    def initialize_params(self, 

    def energy(self, v, h):
        r"""
        \mathcal{E} \left( \vb*{v}, \vb*{h}, \vb*{\theta} \right)
        &= - \vb*{b}^\top \vb*{v} - \vb*{c}^\top \vb*{h} - \vb*{v}^\top \vb*{W} \vb*{h}
        """
        return -jnp.

    def partition(self, ):
        

    def free_energy(self, v: jax.Array) -> jax.Array:
        r"""
        $$
        F() &= 
        $$
        """
        
        return 

In [4]:
jnp.zeros(2)

Array([0., 0.], dtype=float32)

$$
\begin{align}
\mathcal{E}(\boldsymbol{v}) &= -\sum_{j}b_j v_j
                        - \sum_{i}\log
                            \left\lbrack 1 +
                                  \exp\left(c_{i} + \sum_{j} W_{ij} v_j\right)
                            \right\rbrack
\end{align}
$$

In [2]:
jnp.eye(2)

Array([[1., 0.],
       [0., 1.]], dtype=float32)