# Batch Computation

This tutorial shows how to use batch computation in QuAIRKit to process a series of inputs at once.

**Table of Contents**
- [Batch computation in quantum circuit](#Batch-computation-in-quantum-circuit)
- [Batch computation in measurement](#Batch-computation-in-measurement)

In [1]:
import torch

import quairkit as qkit
from quairkit import Circuit
from quairkit.database import *
from quairkit.loss import ExpecVal, Measure

qkit.set_dtype("complex128")

## Batch computation in quantum circuit

The `Circuit` class in QuAIRKit supports adding batched parameters and gates to a circuit.

For parameterized gates such as $R_x(\theta)$, $R_y(\theta)$ and $R_z(\theta)$, one can add batched parameters to the circuit by passing a 3-dimensional tensor to the gate function, where the 3 dimensions are:

- *len(qubits_idx)* : the number of qubits the gate is applied to;
- *batch_size* : the number of batched parameters;
- *num_acted_param* : the number of parameters that characterize the gate. For example, *num_acted_param* is 1 for the $R_y$ gate and 15 for the universal two-qubit gate.
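To make the three dimensions concrete, here is a plain NumPy sketch (it does not call QuAIRKit) of what a batch of $R_x$ layers computes on two qubits starting from $|00\rangle$. The helper `rx`, the random parameters, and the convention that qubit 0 is the leftmost Kronecker factor are assumptions made for illustration.

```python
import numpy as np


def rx(theta):
    """Rx(theta) = exp(-i * theta * X / 2) as a 2x2 complex matrix."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]], dtype=np.complex128)


num_qubits, batch_size = 2, 3
rng = np.random.default_rng(0)

# Parameter tensor laid out as [len(qubits_idx), batch_size, num_acted_param]
params = rng.random((num_qubits, batch_size, 1))

zero = np.zeros(2**num_qubits, dtype=np.complex128)
zero[0] = 1.0  # the input state |00>

outputs = []
for b in range(batch_size):
    # Batch entry b applies Rx(params[q, b, 0]) on each qubit q;
    # qubit 0 is taken as the leftmost Kronecker factor (assumption of this sketch)
    layer = np.kron(rx(params[0, b, 0]), rx(params[1, b, 0]))
    outputs.append(layer @ zero)
outputs = np.stack(outputs)  # shape [batch_size, 2**num_qubits]

print(outputs.shape)  # (3, 4)
```

The loop over `b` is the work that batching hides: one parameter slice per batch entry, and one output state per batch entry.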

Here is an example of feeding batched parameters into a parameterized quantum circuit.

In [2]:
num_qubits = 2
batch_size = 3

list_x = torch.rand(num_qubits, batch_size, 1)  # shape [len(qubits_idx), batch_size, num_acted_param]
cir = Circuit(num_qubits)
cir.rx(param=list_x)  # add an Rx gate on every qubit
print(f"Quantum circuit output: {cir()}")

# this is equivalent to the following loop
# for b in range(batch_size):
#     cir_1 = Circuit(num_qubits)
#     cir_1.rx(param=list_x[:, b])
#     print(f"Quantum circuit output for batch {b}: {cir_1()}")

Quantum circuit output: 
---------------------------------------------------
 Backend: state_vector
 System dimension: [2, 2]
 System sequence: [1, 0]
 Batch size: [3]

 # 0:
[ 0.95+0.j    0.  -0.12j  0.  -0.28j -0.03+0.j  ]
 # 1:
[ 0.97+0.j    0.  -0.23j  0.  -0.08j -0.02+0.j  ]
 # 2:
[ 0.98+0.j    0.  -0.14j  0.  -0.16j -0.02+0.j  ]
---------------------------------------------------



For oracles stored as a `torch.Tensor`, one can add batched unitary matrices to the circuit via `oracle` or `control_oracle`.
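Conceptually, a batch of oracles is a stack of unitaries of shape `[batch_size, d, d]`, and a single unbatched input is broadcast against that stack. The NumPy sketch below illustrates this for three random single-qubit unitaries acting on qubit 1 of $|00\rangle$; the QR-based sampling and the qubit ordering (qubit 0 as the leftmost Kronecker factor) are assumptions of the sketch, not QuAIRKit internals.

```python
import numpy as np

rng = np.random.default_rng(1)
batch_size = 3

# Three random 2x2 unitaries: the Q factor of a QR decomposition is unitary
z = rng.normal(size=(batch_size, 2, 2)) + 1j * rng.normal(size=(batch_size, 2, 2))
list_unitary, _ = np.linalg.qr(z)  # shape [3, 2, 2]

# Embed each unitary on qubit 1 of a 2-qubit register as kron(I, U)
full = np.stack([np.kron(np.eye(2), u) for u in list_unitary])  # shape [3, 4, 4]

psi = np.zeros(4, dtype=np.complex128)
psi[0] = 1.0  # a single, unbatched input |00>

# matmul broadcasts the one input state against the whole batch of operators
out = full @ psi  # shape [3, 4]

# Same result as looping over the batch one unitary at a time
looped = np.stack([m @ psi for m in full])
print(np.allclose(out, looped))  # True
```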

In [3]:
cir_ora = Circuit(2)
list_unitary = random_unitary(1, size=batch_size)
print(f"The shape of oracle unitary: {list_unitary.shape}")

cir_ora.oracle(list_unitary, [1])
print(f"Quantum circuit output: {cir_ora()}")

# this is equivalent to the following loop
# for idx, unitary in enumerate(list_unitary):
#     cir_ora2 = Circuit(2)
#     cir_ora2.oracle(unitary, [1])
#     print(f"Quantum circuit {idx}: {cir_ora2()}")

The shape of oracle unitary: torch.Size([3, 2, 2])
Quantum circuit output: 
---------------------------------------------------
 Backend: state_vector
 System dimension: [2, 2]
 System sequence: [1, 0]
 Batch size: [3]

 # 0:
[0.61-0.4j  0.  +0.j   0.54-0.42j 0.  +0.j  ]
 # 1:
[ 0.08+0.06j  0.  +0.j   -0.99+0.03j  0.  +0.j  ]
 # 2:
[ 0.03-0.74j  0.  +0.j   -0.61+0.29j  0.  +0.j  ]
---------------------------------------------------



QuAIRKit also supports batched channels, given as batches of Kraus operators or Choi matrices. One can add batched channels to the circuit via `kraus_channel` or `choi_channel`. Note that the Kraus representation is recommended for batch computation.
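A batched Kraus channel acts as $\rho_b \mapsto \sum_j K_{b,j}\, \rho_b\, K_{b,j}^\dagger$, with one set of Kraus operators per batch entry. The NumPy sketch below writes this as a single `einsum`; the function name `apply_kraus_batch`, the array layout `[batch, num_kraus, d, d]`, and the use of amplitude damping as the example channel are choices made for illustration, not QuAIRKit's implementation.

```python
import numpy as np


def apply_kraus_batch(kraus, rho):
    """Apply a batch of channels given by Kraus operators.

    kraus: shape [batch, num_kraus, d, d]
    rho:   shape [d, d] (broadcast over the batch) or [batch, d, d]
    returns shape [batch, d, d] with entry b equal to sum_j K_bj rho_b K_bj^dagger
    """
    if rho.ndim == 2:
        rho = np.broadcast_to(rho, (kraus.shape[0],) + rho.shape)
    return np.einsum("bjik,bkl,bjml->bim", kraus, rho, kraus.conj())


# Example: a batch of amplitude-damping channels with different decay rates
gammas = np.array([0.1, 0.5, 0.9])
k0 = np.zeros((3, 2, 2), dtype=np.complex128)
k0[:, 0, 0] = 1.0
k0[:, 1, 1] = np.sqrt(1 - gammas)
k1 = np.zeros((3, 2, 2), dtype=np.complex128)
k1[:, 0, 1] = np.sqrt(gammas)
kraus = np.stack([k0, k1], axis=1)  # shape [3, 2, 2, 2]

rho_excited = np.array([[0, 0], [0, 1]], dtype=np.complex128)  # |1><1|
out = apply_kraus_batch(kraus, rho_excited)
print(np.diagonal(out, axis1=1, axis2=2).real)  # row b is (gamma_b, 1 - gamma_b)
```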

In [4]:
cir_kra = Circuit(2)
list_kraus = random_channel(num_qubits=1, size=batch_size)
cir_kra.kraus_channel(list_kraus, [0])
output_state = cir_kra()
print(f"Kraus channel: {output_state}")

# this is equivalent to the following loop
# for idx, kraus in enumerate(list_kraus):
#     cir_kra2 = Circuit(2)
#     cir_kra2.kraus_channel(kraus, [0])
#     print(f"Kraus channel {idx}: {cir_kra2()}")

Kraus channel: 
---------------------------------------------------
 Backend: density_matrix
 System dimension: [2, 2]
 System sequence: [0, 1]
 Batch size: [3]

 # 0:
[[0.45+0.j   0.  +0.j   0.25+0.43j 0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]
 [0.25-0.43j 0.  +0.j   0.55+0.j   0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]]
 # 1:
[[ 0.23+0.j    0.  +0.j   -0.21-0.37j  0.  +0.j  ]
 [ 0.  +0.j    0.  +0.j    0.  +0.j    0.  +0.j  ]
 [-0.21+0.37j  0.  +0.j    0.77+0.j    0.  +0.j  ]
 [ 0.  +0.j    0.  +0.j    0.  +0.j    0.  +0.j  ]]
 # 2:
[[0.06+0.j   0.  +0.j   0.09+0.21j 0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]
 [0.09-0.21j 0.  +0.j   0.94+0.j   0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]]
---------------------------------------------------



In [5]:
cir_cho = Circuit(2)
list_choi = random_channel(num_qubits=1, target="choi", size=batch_size)
cir_cho.choi_channel(list_choi, [1])
print(f"Choi channel: {cir_cho()}")

# this is equivalent to the following loop
# for idx, choi in enumerate(list_choi):
#     cir_cho2 = Circuit(2)
#     cir_cho2.choi_channel(choi, [1])
#     print(f"Choi channel {idx}: {cir_cho2()}")

Choi channel: 
---------------------------------------------------
 Backend: density_matrix
 System dimension: [2, 2]
 System sequence: [0, 1]
 Batch size: [3]

 # 0:
[[0.4 +0.j   0.06+0.35j 0.  +0.j   0.  +0.j  ]
 [0.06-0.35j 0.6 +0.j   0.  +0.j   0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]]
 # 1:
[[0.45+0.j   0.08-0.05j 0.  +0.j   0.  +0.j  ]
 [0.08+0.05j 0.55+0.j   0.  +0.j   0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]]
 # 2:
[[ 0.22+0.j  -0.21+0.3j  0.  +0.j   0.  +0.j ]
 [-0.21-0.3j  0.78+0.j   0.  +0.j   0.  +0.j ]
 [ 0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j ]
 [ 0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j ]]
---------------------------------------------------
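For the Choi representation, one common convention (assumed here; QuAIRKit may order the input and output systems differently) is $J = \sum_{k,i} |k\rangle\langle i| \otimes \mathcal{N}(|k\rangle\langle i|)$, from which the channel acts as $\mathcal{N}(\rho) = \operatorname{Tr}_{\text{in}}[(\rho^T \otimes I)\, J]$. The NumPy sketch below builds a batch of Choi matrices from amplitude-damping Kraus operators and checks that both representations give the same outputs; all function names are hypothetical.

```python
import numpy as np


def kraus_apply(kraus, rho):
    """sum_j K_bj rho K_bj^dagger for each batch entry b; kraus has shape [batch, num_kraus, d, d]."""
    return np.einsum("bjik,kl,bjml->bim", kraus, rho, kraus.conj())


def choi_from_kraus(kraus):
    """Batched Choi matrices J_b = sum_{k,i} |k><i| (x) N_b(|k><i|), input system first."""
    batch, _, d, _ = kraus.shape
    choi = np.zeros((batch, d * d, d * d), dtype=np.complex128)
    for k in range(d):
        for i in range(d):
            e_ki = np.zeros((d, d), dtype=np.complex128)
            e_ki[k, i] = 1.0
            n_ki = kraus_apply(kraus, e_ki)  # shape [batch, d, d]
            # Batched Kronecker product |k><i| (x) N_b(|k><i|)
            choi += np.einsum("pq,bac->bpaqc", e_ki, n_ki).reshape(batch, d * d, d * d)
    return choi


def choi_apply(choi, rho):
    """N_b(rho) = Tr_in[(rho^T (x) I) J_b] for each batch entry b."""
    batch, d = choi.shape[0], rho.shape[-1]
    j = choi.reshape(batch, d, d, d, d)  # axes: [b, in_row, out_row, in_col, out_col]
    return np.einsum("ki,bkaic->bac", rho, j)


gammas = np.array([0.2, 0.6, 1.0])
k0 = np.zeros((3, 2, 2), dtype=np.complex128)
k0[:, 0, 0], k0[:, 1, 1] = 1.0, np.sqrt(1 - gammas)
k1 = np.zeros((3, 2, 2), dtype=np.complex128)
k1[:, 0, 1] = np.sqrt(gammas)
kraus = np.stack([k0, k1], axis=1)

list_choi = choi_from_kraus(kraus)  # shape [3, 4, 4]
rho = np.array([[0.3, 0.1 - 0.2j], [0.1 + 0.2j, 0.7]], dtype=np.complex128)
print(np.allclose(choi_apply(list_choi, rho), kraus_apply(kraus, rho)))  # True
```

In this convention, trace preservation shows up as $\operatorname{Tr}_{\text{out}} J_b = I$ for every batch entry.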



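Each channel generated by `random_channel` is trace preserving, i.e. its Kraus operators satisfy the completeness condition $\sum_j K_j^\dagger K_j = I$.

One can therefore check that the batched circuit `cir_kra` preserves the trace of every state in the batch.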

In [6]:
tr = output_state.trace()
torch.allclose(tr, torch.ones_like(tr))

True
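Beyond checking output traces, the completeness condition $\sum_j K_{b,j}^\dagger K_{b,j} = I$ can be checked directly on a batch of Kraus operators. The NumPy sketch below draws a batch of random channels by slicing random isometries into Kraus blocks (one standard construction, not necessarily the one `random_channel` uses), then verifies the condition for every $b$ and confirms that all output traces equal 1.

```python
import numpy as np

rng = np.random.default_rng(7)
batch, num_kraus, d = 3, 2, 2

# A random isometry V (V^dagger V = I) of shape [num_kraus * d, d] per batch entry,
# taken as the Q factor of a reduced QR decomposition
z = rng.normal(size=(batch, num_kraus * d, d)) + 1j * rng.normal(size=(batch, num_kraus * d, d))
v, _ = np.linalg.qr(z)  # shape [batch, num_kraus * d, d]

# Stacking the Kraus operators vertically gives V, so K_bj is the j-th d x d block of rows
kraus = v.reshape(batch, num_kraus, d, d)

# Completeness: sum_j K_bj^dagger K_bj for each batch entry b
completeness = np.einsum("bjki,bjkl->bil", kraus.conj(), kraus)
print(np.allclose(completeness, np.eye(d)))  # True

# Hence every output state has unit trace
a = rng.normal(size=(d, d)) + 1j * rng.normal(size=(d, d))
rho = a @ a.conj().T
rho /= np.trace(rho)
out = np.einsum("bjik,kl,bjml->bim", kraus, rho, kraus.conj())
tr = np.trace(out, axis1=1, axis2=2)
print(np.allclose(tr, np.ones_like(tr)))  # True
```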

For clarity, the following figure illustrates how batch computation works in quantum circuits.

<figure style="text-align: center;">
  <img src="./figures/batch.jpg" alt="alt text" width="700"/>
  <figcaption>Fig.1: Depiction of batched quantum circuits acting on a single input state.</figcaption>
</figure>


The code for these circuits is as follows:

In [7]:
rho = random_state(1)
list_x = torch.rand(batch_size)
list_depo = torch.stack(
    [depolarizing_kraus(torch.rand(1)) for _ in range(batch_size)]
)

batch_cir = Circuit(1)
batch_cir.ry()
batch_cir.rz(param=list_x)
batch_cir.kraus_channel(list_depo, 0)
batch_cir.ry()
batch_cir.rz(param=list_x)
print(f"Output state: {batch_cir(rho)}")

Output state: 
---------------------------------------------------
 Backend: density_matrix
 System dimension: [2]
 System sequence: [0]
 Batch size: [3]

 # 0:
[[0.72+0.j   0.07-0.33j]
 [0.07+0.33j 0.28+0.j  ]]
 # 1:
[[ 0.63+0.j   -0.05-0.23j]
 [-0.05+0.23j  0.37+0.j  ]]
 # 2:
[[ 0.6 -0.j   -0.04-0.17j]
 [-0.04+0.17j  0.4 +0.j  ]]
---------------------------------------------------
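`list_depo` above stacks one depolarizing channel per batch entry, so each entry of the output uses a different noise strength. The NumPy sketch below reproduces that stacking under the standard convention $\rho \mapsto (1-p)\rho + pI/2$, with Kraus operators $\sqrt{1-3p/4}\,I$ and $\sqrt{p/4}\,X$, $\sqrt{p/4}\,Y$, $\sqrt{p/4}\,Z$. Whether `depolarizing_kraus` uses exactly this parameterization is an assumption, and `depolarizing_kraus_np` is a hypothetical helper.

```python
import numpy as np

I2 = np.eye(2, dtype=np.complex128)
X = np.array([[0, 1], [1, 0]], dtype=np.complex128)
Y = np.array([[0, -1j], [1j, 0]], dtype=np.complex128)
Z = np.array([[1, 0], [0, -1]], dtype=np.complex128)


def depolarizing_kraus_np(p):
    """Hypothetical stand-in: Kraus operators of rho -> (1 - p) rho + p I / 2."""
    return np.stack([
        np.sqrt(1 - 3 * p / 4) * I2,
        np.sqrt(p / 4) * X,
        np.sqrt(p / 4) * Y,
        np.sqrt(p / 4) * Z,
    ])


rng = np.random.default_rng(3)
probs = rng.random(3)  # one noise strength per batch entry
list_depo = np.stack([depolarizing_kraus_np(p) for p in probs])  # shape [3, 4, 2, 2]

# A single input state, broadcast over the batch of channels
rho = np.array([[0.8, 0.3 - 0.1j], [0.3 + 0.1j, 0.2]], dtype=np.complex128)
out = np.einsum("bjik,kl,bjml->bim", list_depo, rho, list_depo.conj())

# Closed form of the same channel, evaluated for each p in the batch
expected = (1 - probs)[:, None, None] * rho + probs[:, None, None] * I2 / 2
print(np.allclose(out, expected))  # True
```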



## Batch computation in measurement

Measurement in QuAIRKit also supports batch computation. We start with an observable represented by a `Hamiltonian`, and then turn to a projection-valued measure (PVM).

In [8]:
H = random_hamiltonian_generator(num_qubits)
print(f"Hamiltonian: {H.pauli_str}")

Hamiltonian: [[0.158452521275237, 'X1'], [-0.06368870737400667, 'X0'], [0.6403721911006688, 'Z0']]


One can call the `expec_val` method of the `State` class, or apply the neural network module `ExpecVal`, on batched states. Both return one expectation value per state in the batch.
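Batched expectation values are simply $\langle H \rangle_b = \operatorname{Tr}(H\rho_b)$ for each state $\rho_b$ in the batch, and when the result is decomposed, each Pauli term $c_k P_k$ contributes $c_k \operatorname{Tr}(P_k\rho_b)$ separately. The NumPy sketch below evaluates both for a Hamiltonian in the `[[coefficient, 'term'], ...]` form printed by `H.pauli_str`. The parser (including comma-separated terms such as `'Z0,X1'`), the qubit ordering, and the `[num_terms, batch]` layout of the decomposed result are assumptions of this sketch and need not match QuAIRKit's.

```python
import numpy as np
from functools import reduce

PAULI = {
    "I": np.eye(2, dtype=np.complex128),
    "X": np.array([[0, 1], [1, 0]], dtype=np.complex128),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=np.complex128),
    "Z": np.array([[1, 0], [0, -1]], dtype=np.complex128),
}


def term_matrix(term, num_qubits):
    """Matrix of a Pauli word such as 'X1' or 'Z0,X1' (qubit 0 = leftmost factor, assumed)."""
    factors = [PAULI["I"]] * num_qubits
    for token in term.split(","):
        token = token.strip()
        factors[int(token[1:])] = PAULI[token[0].upper()]
    return reduce(np.kron, factors)


def batched_expec_val(pauli_str, rho, decompose=False):
    """Tr(H rho_b) for each batch entry b; rho has shape [batch, d, d]."""
    num_qubits = int(np.log2(rho.shape[-1]))
    per_term = np.stack([
        coef * np.einsum("ij,bji->b", term_matrix(term, num_qubits), rho).real
        for coef, term in pauli_str
    ])  # shape [num_terms, batch]
    return per_term if decompose else per_term.sum(axis=0)


def ket_dm(v):
    v = np.asarray(v, dtype=np.complex128)
    return np.outer(v, v.conj())


plus = np.array([1, 1]) / np.sqrt(2)
rho_batch = np.stack([
    ket_dm(np.kron([1, 0], [1, 0])),  # |00>
    ket_dm(np.kron([0, 1], [1, 0])),  # |10>, qubit 0 flipped
    ket_dm(np.kron([1, 0], plus)),    # |0+>
])
H = [[0.5, "Z0"], [0.25, "X1"]]
print(batched_expec_val(H, rho_batch))  # approximately [0.5, -0.5, 0.75]
```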

In [9]:
print(f"Output state: {output_state}")
op = ExpecVal(H)
print(f"expectation value: {op(output_state)}")
# this is equivalent to the following loop
# for state in output_state:
#     print(f"expectation value of each: {op(state)}")

print(f"expectation value: {output_state.expec_val(H)}")
# return the expectation value of each Pauli term
print(
    f"expectation value of each Pauli term: {output_state.expec_val(H, decompose=True)}"
)

Output state: 
---------------------------------------------------
 Backend: density_matrix
 System dimension: [2, 2]
 System sequence: [0, 1]
 Batch size: [3]

 # 0:
[[0.45+0.j   0.  +0.j   0.25+0.43j 0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]
 [0.25-0.43j 0.  +0.j   0.55+0.j   0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]]
 # 1:
[[ 0.23+0.j    0.  +0.j   -0.21-0.37j  0.  +0.j  ]
 [ 0.  +0.j    0.  +0.j    0.  +0.j    0.  +0.j  ]
 [-0.21+0.37j  0.  +0.j    0.77+0.j    0.  +0.j  ]
 [ 0.  +0.j    0.  +0.j    0.  +0.j    0.  +0.j  ]]
 # 2:
[[0.06+0.j   0.  +0.j   0.09+0.21j 0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]
 [0.09-0.21j 0.  +0.j   0.94+0.j   0.  +0.j  ]
 [0.  +0.j   0.  +0.j   0.  +0.j   0.  +0.j  ]]
---------------------------------------------------

expectation value: tensor([-0.0933, -0.3136, -0.5813])
expectation value: tensor([-0.0933, -0.3136, -0.5813])
expectation value of each Pauli term: tensor([[ 0.0000,  0.0000,  0.0000],


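Similarly, to measure the output state, one can call the `measure` method of the `State` class, or apply the neural network module `Measure`, on batched states. The following code measures the first qubit (qubit 0) of the output state in a random orthonormal basis, and returns one probability distribution over the two outcomes for each state in the batch.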
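The probabilities returned for a batch are $p_{b,k} = \operatorname{Tr}[(P_k \otimes I)\rho_b]$, one row per state. The NumPy sketch below builds a PVM the same way as the code that follows (outer products of the rows of a random unitary) and evaluates these probabilities for a batch of random 2-qubit states. The helper `measure_probs` and the convention that qubit 0 is the leftmost Kronecker factor are assumptions for illustration.

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(5)

# pvm[k] = |u_k><u_k| for the k-th row u_k of a random unitary
z = rng.normal(size=(2, 2)) + 1j * rng.normal(size=(2, 2))
u, _ = np.linalg.qr(z)
basis = u[:, :, None]                          # shape [2, 2, 1]
pvm = basis @ basis.conj().transpose(0, 2, 1)  # shape [2, 2, 2]


def measure_probs(pvm, rho, qubit, num_qubits):
    """p[b, k] = Tr[(P_k on `qubit`, identity elsewhere) rho_b]; qubit 0 is the leftmost factor (assumed)."""
    ops = []
    for proj in pvm:
        factors = [np.eye(2)] * num_qubits
        factors[qubit] = proj
        ops.append(reduce(np.kron, factors))
    ops = np.stack(ops)                              # shape [num_outcomes, d, d]
    return np.einsum("kij,bji->bk", ops, rho).real   # shape [batch, num_outcomes]


# A batch of three random 2-qubit density matrices
a = rng.normal(size=(3, 4, 4)) + 1j * rng.normal(size=(3, 4, 4))
rho = a @ a.conj().transpose(0, 2, 1)
rho /= np.trace(rho, axis1=1, axis2=2)[:, None, None]

probs = measure_probs(pvm, rho, qubit=0, num_qubits=2)
print(probs.shape)  # (3, 2)
```

Because the projectors sum to the identity, each row of `probs` sums to 1, matching the shape `[batch_size, 2]` of the outputs below.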

In [10]:
output_state = cir_kra()
basis = random_unitary(1).unsqueeze(-1)  # basis[k] is the k-th row of a random unitary, as a column vector
pvm = basis @ basis.mH  # pvm[k] = |u_k><u_k|
print(f"The shape of PVM: {pvm.shape}")

op = Measure(pvm)
print(f"measurement probabilities: {op(output_state, [0])}")
# this is equivalent to the following loop
# for state in output_state:
#     print(f"measurement probabilities: {op(state, [0])}")

print(f"measurement probabilities: {output_state.measure(pvm, [0])}")

The shape of PVM: torch.Size([2, 2, 2])
measurement probabilities: tensor([[0.0784, 0.9216],
        [0.9952, 0.0048],
        [0.4924, 0.5076]])
measurement probabilities: tensor([[0.0784, 0.9216],
        [0.9952, 0.0048],
        [0.4924, 0.5076]])


One can also keep the post-measurement (collapsed) states by setting `keep_state=True`.
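Measuring with projector $P_k$ leaves state $\rho_b$ in $P_k\rho_b P_k / p_{b,k}$ with probability $p_{b,k} = \operatorname{Tr}(P_k\rho_b)$. The NumPy sketch below computes both quantities for a computational-basis measurement of qubit 0 on a batch containing $|00\rangle$ and a Bell state. The function `collapse`, its return layout, and returning a zero matrix for zero-probability outcomes are illustrative choices, not QuAIRKit's documented behavior.

```python
import numpy as np

# Computational-basis projectors on qubit 0 of a 2-qubit register (qubit 0 = leftmost factor, assumed)
P0 = np.diag([1.0, 0.0]).astype(np.complex128)
P1 = np.diag([0.0, 1.0]).astype(np.complex128)
ops = np.stack([np.kron(P0, np.eye(2)), np.kron(P1, np.eye(2))])  # shape [2, 4, 4]


def collapse(ops, rho, eps=1e-12):
    """Return probabilities [batch, k] and post-measurement states [batch, k, d, d]."""
    probs = np.einsum("kij,bji->bk", ops, rho).real
    unnormalized = np.einsum("kij,bjl,klm->bkim", ops, rho, ops)  # P_k rho_b P_k
    # Outcomes with (numerically) zero probability keep a zero matrix instead of dividing by 0
    safe = np.where(probs > eps, probs, 1.0)
    return probs, unnormalized / safe[:, :, None, None]


ket00 = np.array([1, 0, 0, 0], dtype=np.complex128)
bell = np.array([1, 0, 0, 1], dtype=np.complex128) / np.sqrt(2)
rho = np.stack([np.outer(v, v.conj()) for v in (ket00, bell)])  # batch of 2 states

probs, states = collapse(ops, rho)
print(probs)  # rows: |00> gives (1, 0), the Bell state gives (0.5, 0.5)
```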

---

*Table: Notation conventions used in this tutorial.*

| Symbol | Batched variants | Description |
|:------:|:----------------:|-------------|
| $R_{x/y/z}(\theta)$ | | Rotation gate about the $X$/$Y$/$Z$-axis |
| $\rho_{\text{in}}$ | | Input quantum state |
| $\rho_{\text{out}}$ | $\rho_{\text{out},1}$, $\rho_{\text{out},2}$, $\rho_{\text{out},3}$ | Output quantum state |
| $\mathcal{N}$ | $\mathcal{N}_1$, $\mathcal{N}_2$, $\mathcal{N}_3$ | Quantum channel |


In [11]:
qkit.print_info()


---------VERSION---------
quairkit: 0.1.0
torch: 2.3.1+cpu
numpy: 1.26.0
scipy: 1.14.0
matplotlib: 3.9.0
---------SYSTEM---------
Python version: 3.10.14
OS: Windows
OS version: 10.0.26100
---------DEVICE---------
CPU: ARMv8 (64-bit) Family 8 Model 1 Revision 201, Qualcomm Technologies Inc
