In [None]:
# Needs iqm-benchmarks from the github repo to access all the mGST functions: https://github.com/iqm-finland/iqm-benchmarks
from mGST.low_level_jit import dK, objf
from mGST.low_level_jit import dK_jax, dK_jax_jit
from iqm.benchmarks.compressive_gst.compressive_gst import GSTConfiguration
from mGST.utility_functions_comparisons import get_full_mgst_parameters_from_configuration
# Check compilation time for contract_jax_jit.
import jax.numpy as jnp
import numpy as np

# Backend identifier handed to get_full_mgst_parameters_from_configuration in every
# section below; the execution logs show it resolving to IQMFakeApolloBackend.
backend = "iqmfakeapollo"

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# 1GST

In [37]:
# Single-qubit case (qubit 0, gate set "1QXYI"): the smallest problem size, used first
# to see whether the JAX compilation helps at all.
Q1_GST = GSTConfiguration(qubit_layouts=[[0]], gate_set="1QXYI", num_circuits=100, shots=1000, rank=4)

# Generate and execute the circuits on `backend`, then unpack the 16 returned mGST
# quantities that the timing cells below consume.
(
    K, X, E, rho, y, J, l, d,
    pdim, r, n_povm, bsize, meas_samples, n, nt, rK,
) = get_full_mgst_parameters_from_configuration(Q1_GST, backend)

2025-01-14 16:02:48,529 - iqm.benchmarks.logging_config - INFO - Now generating 100 random GST circuits...
2025-01-14 16:02:48,547 - iqm.benchmarks.logging_config - INFO - Will transpile all 100 circuits according to fixed physical layout
2025-01-14 16:02:48,547 - iqm.benchmarks.logging_config - INFO - Transpiling for backend IQMFakeApolloBackend with optimization level 0, sabre routing method all circuits
2025-01-14 16:02:48,653 - iqm.benchmarks.logging_config - INFO - Submitting batch with 100 circuits corresponding to qubits [0]
2025-01-14 16:02:48,653 - iqm.benchmarks.logging_config - INFO - Now executing the corresponding circuit batch
2025-01-14 16:02:48,672 - iqm.benchmarks.logging_config - INFO - Retrieving all counts


In [38]:
# Drop every -1 entry from each row of J (presumably padding — TODO confirm) so the
# JIT variant receives only the remaining indices; display the resulting row count.
J_processed = [sequence[sequence != -1] for sequence in J]
len(J_processed)

100

In [17]:
# Numba: baseline timing of `dK` on the single-qubit parameters.
# Unlike the JAX variants below, `dK` additionally takes X and rK.
%timeit dK(X, K, E, rho, J, y, d, r, rK)

5.66 ms ± 188 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
# JAX: time `dK_jax` on the single-qubit parameters (called with the raw J, not J_processed).
# NOTE(review): JAX dispatches work asynchronously; if dK_jax returns a jax.Array without
# synchronising, %timeit may under-report — confirm, or wrap the call in jax.block_until_ready(...).
%timeit dK_jax(K, E, rho, J, y, d, r)

Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
718 ms ± 7.46 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
# JAX JIT: time `dK_jax_jit` on the single-qubit parameters, fed J_processed
# (the rows of J with their -1 entries removed).
# NOTE(review): JAX dispatches work asynchronously; if dK_jax_jit returns a jax.Array without
# synchronising, %timeit may under-report — confirm, or wrap the call in jax.block_until_ready(...).
%timeit dK_jax_jit(K, E, rho, J_processed, y, d, r)
# Earlier measurement, recorded as "before using J": 213 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# NOTE(review): presumably taken before switching to J_processed — TODO confirm.

Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
168 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Correctness check on the single-qubit parameters: compare the Numba gradient with the
# JAX-JIT gradient. The Numba result is doubled before comparing; the reason for the
# factor of 2 is not shown in this notebook (TODO: document the convention difference).
numba_grad = dK(X, K, E, rho, J, y, d, r, rK)
jax_grad = dK_jax_jit(K, E, rho, J_processed, y, d, r)
jnp.allclose(2*numba_grad, jax_grad)
# Apparently the result is the same whether we use J or J_processed (?)
# NOTE(review): this cell only passes J_processed to the JAX side, so it does not by
# itself compare J against J_processed.

Using JAX power


Array(True, dtype=bool)

# 2GST

In [21]:
# Two-qubit case (qubits 0 and 1, gate set "2QXYCZ"): inputs for timing the
# Euclidean-gradient routines on a larger problem.
Q2_GST = GSTConfiguration(qubit_layouts=[[0, 1]], gate_set="2QXYCZ", num_circuits=800, shots=1000, rank=4)

# Generate and execute the circuits on `backend`, then unpack the 16 returned mGST
# quantities; this overwrites the single-qubit values held in the same names.
(
    K, X, E, rho, y, J, l, d,
    pdim, r, n_povm, bsize, meas_samples, n, nt, rK,
) = get_full_mgst_parameters_from_configuration(Q2_GST, backend)

2025-01-14 15:39:53,802 - iqm.benchmarks.logging_config - INFO - Now generating 800 random GST circuits...
2025-01-14 15:39:53,968 - iqm.benchmarks.logging_config - INFO - Will transpile all 800 circuits according to fixed physical layout
2025-01-14 15:39:53,968 - iqm.benchmarks.logging_config - INFO - Transpiling for backend IQMFakeApolloBackend with optimization level 0, sabre routing method all circuits
2025-01-14 15:39:55,562 - iqm.benchmarks.logging_config - INFO - Submitting batch with 800 circuits corresponding to qubits [0, 1]
2025-01-14 15:39:55,564 - iqm.benchmarks.logging_config - INFO - Now executing the corresponding circuit batch
2025-01-14 15:39:55,620 - iqm.benchmarks.logging_config - INFO - Retrieving all counts


In [22]:
# Rebuild J_processed for the two-qubit J: each row with its -1 entries removed
# (presumably padding — TODO confirm); display the resulting row count.
J_processed = [sequence[sequence != -1] for sequence in J]
len(J_processed)

800

In [23]:
# Numba: baseline timing of `dK` on the two-qubit parameters.
%timeit dK(X, K, E, rho, J, y, d, r, rK)

411 ms ± 18.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [24]:
# JAX: time `dK_jax` on the two-qubit parameters (called with the raw J, not J_processed).
# NOTE(review): if dK_jax returns a jax.Array without synchronising, asynchronous dispatch
# can make %timeit under-report — confirm, or wrap the call in jax.block_until_ready(...).
%timeit dK_jax(K, E, rho, J, y, d, r)

Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
7.56 s ± 69.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [25]:
# JAX JIT: time `dK_jax_jit` on the two-qubit parameters, fed J_processed.
# NOTE(review): if dK_jax_jit returns a jax.Array without synchronising, asynchronous dispatch
# can make %timeit under-report — confirm, or wrap the call in jax.block_until_ready(...).
%timeit dK_jax_jit(K, E, rho, J_processed, y, d, r)

Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
Using JAX power
2.02 s ± 33.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# 3GST

In [26]:
# Three-qubit case (qubits 0, 1 and 3, gate set "3QXYCZ") with 1000 circuits.
Q3_GST = GSTConfiguration(qubit_layouts=[[0, 1, 3]], gate_set="3QXYCZ", num_circuits=1000, shots=1000, rank=4)

# Generate and execute the circuits on `backend`, then unpack the 16 returned mGST
# quantities; this overwrites the two-qubit values held in the same names.
(
    K, X, E, rho, y, J, l, d,
    pdim, r, n_povm, bsize, meas_samples, n, nt, rK,
) = get_full_mgst_parameters_from_configuration(Q3_GST, backend)

2025-01-14 15:42:25,133 - iqm.benchmarks.logging_config - INFO - Now generating 1000 random GST circuits...
2025-01-14 15:42:25,647 - iqm.benchmarks.logging_config - INFO - Will transpile all 1000 circuits according to fixed physical layout
2025-01-14 15:42:25,647 - iqm.benchmarks.logging_config - INFO - Transpiling for backend IQMFakeApolloBackend with optimization level 0, sabre routing method all circuits
2025-01-14 15:42:28,056 - iqm.benchmarks.logging_config - INFO - Submitting batch with 1000 circuits corresponding to qubits [0, 1, 3]
2025-01-14 15:42:28,057 - iqm.benchmarks.logging_config - INFO - Now executing the corresponding circuit batch
2025-01-14 15:42:28,114 - iqm.benchmarks.logging_config - INFO - Retrieving all counts


In [27]:
# Rebuild J_processed for the three-qubit J: each row with its -1 entries removed
# (presumably padding — TODO confirm); display the resulting row count.
J_processed = [sequence[sequence != -1] for sequence in J]
len(J_processed)

1000

In [28]:
# Numba: baseline timing of `dK` on the three-qubit parameters.
# Uses the default %timeit repeat count, unlike the JAX cells below, which pass -n1 -r3.
%timeit dK(X, K, E, rho, J, y, d, r, rK)

1min 20s ± 8.24 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [29]:
# JAX: time `dK_jax` on the three-qubit parameters (raw J, not J_processed).
# -n1 -r3 limits the measurement to one call per run and three runs.
# NOTE(review): if dK_jax returns a jax.Array without synchronising, asynchronous dispatch
# can make %timeit under-report — confirm, or wrap the call in jax.block_until_ready(...).
%timeit -n1 -r3 dK_jax(K, E, rho, J, y, d, r)

Using JAX power
Using JAX power
Using JAX power
15 s ± 354 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


In [30]:
# JAX JIT: time `dK_jax_jit` on the three-qubit parameters, fed J_processed.
# -n1 -r3 limits the measurement to one call per run and three runs.
# NOTE(review): if dK_jax_jit returns a jax.Array without synchronising, asynchronous dispatch
# can make %timeit under-report — confirm, or wrap the call in jax.block_until_ready(...).
%timeit -n1 -r3 dK_jax_jit(K, E, rho, J_processed, y, d, r)

Using JAX power
Using JAX power
Using JAX power
5.72 s ± 926 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


In [31]:
# Is it still correct at three qubits? Recompute the Numba gradient as the reference
# for the comparison in the following cells. This single call is slow: the %timeit
# above measured roughly 1min 20s per call.
numba_grad = dK(X, K, E, rho, J, y, d, r, rK)

In [32]:
# JAX-JIT gradient on the same three-qubit parameters, computed from J_processed.
jax_grad = dK_jax_jit(K, E, rho, J_processed, y, d, r)

Using JAX power


In [33]:
# Compare the two three-qubit gradients. The Numba result is doubled before comparing;
# the reason for the factor of 2 is not shown in this notebook (TODO: document it).
jnp.allclose(2*numba_grad, jax_grad)
# Apparently the result is the same whether we use J or J_processed (?)
# NOTE(review): jax_grad was computed from J_processed only, so this check does not by
# itself compare J against J_processed.

Array(True, dtype=bool)

# 4GST

In [39]:
from iqm.qiskit_iqm import IQMCircuit as QuantumCircuit
from qiskit.circuit.library import CZGate, RGate

# Entangling element of the custom gate set: two CZ gates in one 4-qubit circuit,
# acting on the qubit pairs (0, 1) and (2, 3).
cz_cz = QuantumCircuit(4)
for qubit_pair in ([0, 1], [2, 3]):
    cz_cz.append(CZGate(), qubit_pair)

# Four R(pi/2, 0) rotations, then four R(pi/2, pi/2) rotations (labelled Rx(pi/2) and
# Ry(pi/2) in gate_labels), one per qubit each, followed by the CZ-CZ circuit.
gate_list = [RGate(0.5 * np.pi, phase) for phase in (0, np.pi / 2) for _ in range(4)]
gate_list.append(cz_cz)
gate_qubits = [[qubit] for _ in range(2) for qubit in range(4)] + [[0, 1, 2, 3]]
gate_labels = ["Rx(pi/2)"] * 4 + ["Ry(pi/2)"] * 4 + ["CZ-CZ"]

# Embed every gate-set element in its own empty 4-qubit circuit: sub-circuits are
# composed in place, plain gates are appended on their target qubits.
gates = [QuantumCircuit(4, 0) for _ in gate_list]
for circuit, gate, qubits in zip(gates, gate_list, gate_qubits):
    if isinstance(gate, QuantumCircuit):
        circuit.compose(gate, qubits, inplace=True)
    else:
        circuit.append(gate, qubits)

# Four-qubit configuration on qubits 0, 1, 3 and 4 with the custom gate set.
# Note rank=1 here, whereas the smaller cases above use rank=4.
Q4_GST = GSTConfiguration(
    gate_set=gates,
    gate_labels=gate_labels,
    qubit_layouts=[[0, 1, 3, 4]],
    num_circuits=2000,
    shots=1000,
    rank=1,
)

# Generate and execute the circuits on `backend`, then unpack the 16 returned mGST
# quantities into the names shared by all sections.
(
    K, X, E, rho, y, J, l, d,
    pdim, r, n_povm, bsize, meas_samples, n, nt, rK,
) = get_full_mgst_parameters_from_configuration(Q4_GST, backend)

2025-01-14 16:03:09,344 - iqm.benchmarks.logging_config - INFO - Now generating 2000 random GST circuits...
2025-01-14 16:03:10,251 - iqm.benchmarks.logging_config - INFO - Will transpile all 2000 circuits according to fixed physical layout
2025-01-14 16:03:10,251 - iqm.benchmarks.logging_config - INFO - Transpiling for backend IQMFakeApolloBackend with optimization level 0, sabre routing method all circuits
2025-01-14 16:03:13,559 - iqm.benchmarks.logging_config - INFO - Submitting batch with 2000 circuits corresponding to qubits [0, 1, 3, 4]
2025-01-14 16:03:13,561 - iqm.benchmarks.logging_config - INFO - Now executing the corresponding circuit batch
2025-01-14 16:03:13,630 - iqm.benchmarks.logging_config - INFO - Retrieving all counts


In [40]:
# Pre-filter J once, outside the timed call: drop every -1 entry from each row
# (presumably padding — TODO confirm) and display the resulting row count.
J_processed = [sequence[sequence != -1] for sequence in J]
len(J_processed)

2000

In [41]:
# JAX JIT: time `dK_jax_jit` on the four-qubit parameters, fed J_processed. It is the
# only variant timed in this section. -n1 -r3 limits it to one call per run, three runs.
# NOTE(review): if dK_jax_jit returns a jax.Array without synchronising, asynchronous dispatch
# can make %timeit under-report — confirm, or wrap the call in jax.block_until_ready(...).
%timeit -n1 -r3 dK_jax_jit(K, E, rho, J_processed, y, d, r)

Using JAX power
Using JAX power
Using JAX power
3min 1s ± 5.24 s per loop (mean ± std. dev. of 3 runs, 1 loop each)


# Results

<table>
  <tr>
    <th>Function</th>
    <th>Running time</th>
  </tr>
  <tr>
    <td colspan="2" style="text-align: center;">1 GST</td>
  </tr>
  <tr>
    <td>Numba</td> 
    <td> 5.66 ms ± 188 μs</td>
  </tr>
  <tr>
    <td>JAX </td>
    <td> 718 ms ± 7.46 ms</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + mse)</td>
    <td> 168 ms ± 1.14 ms</td>
  </tr>
  <tr>
    <td colspan="2" style="text-align: center;">2 GST</td>
  </tr>
  <tr>
    <td>Numba</td> 
    <td> 411 ms ± 18.6 ms</td>
  </tr>
  <tr>
    <td>JAX </td>
    <td> 7.56 s ± 69.1 ms</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + mse)</td>
    <td> 2.02 s ± 33.6 ms</td>
  </tr>
  <tr>
    <td colspan="2" style="text-align: center;">3 GST</td>
  </tr>
  <tr>
    <td>Numba</td> 
    <td> 1min 20s ± 8.24 s</td>
  </tr>
  <tr>
    <td>JAX </td>
    <td> 15 s ± 354 ms </td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + mse)</td>
    <td> 5.72 s ± 926 ms</td>
  </tr>
  <tr>
    <td colspan="2" style="text-align: center;">4 GST</td>
  </tr>
  <tr>
    <td>Numba</td> 
    <td> 89.63 min</td>
  </tr>
  <tr>
    <td>JAX </td>
    <td> NA </td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + mse)</td>
    <td> 3min 1s ± 5.24 s</td>
  </tr>
</table>
