## Imports 

In [1]:
# Needs iqm-benchmarks from the github repo to access all the mGST functions: https://github.com/iqm-finland/iqm-benchmarks
from mGST.low_level_jit import dK, objf, ddM, dK_dMdM
from mGST import additional_fns
from iqm.benchmarks.compressive_gst.compressive_gst import GSTConfiguration, CompressiveGST
from iqm.benchmarks.compressive_gst.gst_analysis import dataset_counts_to_mgst_format

from mGST.qiskit_interface import qiskit_gate_to_operator

import numpy as np

from mGST.low_level_jit import cost_function_jax, cost_function_jax_jit

# Check compilation time for contract_jax_jit.
import jax

backend = "iqmfakeapollo"

%load_ext autoreload
%autoreload 2

### Renaming the parameters for the convention used in the derivatives

In [2]:
def get_mgst_parameters_from_dataset(dataset, qubit_layout, rK):
    y = dataset_counts_to_mgst_format(dataset, qubit_layout)
    J = dataset.attrs["J"]
    l = dataset.attrs["seq_len_list"][-1]
    d = dataset.attrs["num_gates"]
    pdim = dataset.attrs["pdim"]
    r = pdim ** 2
    n_povm = dataset.attrs["num_povm"]
    bsize = dataset.attrs["batch_size"]
    meas_samples = dataset.attrs["shots"]
    # Setting some additional matrix shape parameters for the first and second derivatives
    n = rK * pdim
    nt = rK * r
    return y, J, l, d, pdim, r, n_povm, bsize, meas_samples, n, nt

### Initialization

In [3]:
## Preparing an initialization (random gate set or target gate set)
from mGST.additional_fns import random_gs

def initialize_mgst_parameters(dataset, target_init = True):
    d = dataset.attrs["num_gates"]
    pdim = dataset.attrs["pdim"]
    r = pdim ** 2
    n_povm = dataset.attrs["num_povm"]
    rK = dataset.attrs["rank"]
    
    if target_init:
        K_target = qiskit_gate_to_operator(dataset.attrs["gate_set"])
        X_target = np.einsum("ijkl,ijnm -> iknlm", K_target, K_target.conj()).reshape(
            (dataset.attrs["num_gates"], dataset.attrs["pdim"] ** 2, dataset.attrs["pdim"] ** 2)
        )  # tensor of superoperators
        
        rho = (
            np.kron(additional_fns.basis(dataset.attrs["pdim"], 0).T.conj(), additional_fns.basis(dataset.attrs["pdim"], 0))
            .reshape(-1)
            .astype(np.complex128)
        )
        
        # Computational basis measurement:
        E = np.array(
            [
                np.kron(
                    additional_fns.basis(dataset.attrs["pdim"], i).T.conj(), additional_fns.basis(dataset.attrs["pdim"], i)
                ).reshape(-1)
                for i in range(dataset.attrs["pdim"])
            ]
        ).astype(np.complex128)
        
        
        K = additional_fns.perturbed_target_init(X_target, dataset.attrs["rank"])
        X = np.einsum("ijkl,ijnm -> iknlm", K, K.conj()).reshape((d, r, r))
    else:
        K, X, E, rho = random_gs(d, r, rK, n_povm)
        
    return K, X, E, rho

In [4]:
def get_full_mgst_parameters_from_configuration(configuration:GSTConfiguration, backend):
    
    benchmark = CompressiveGST(backend, configuration)
    result = benchmark.run()
    
    rK = configuration.rank
    qubit_layout = configuration.qubit_layouts[0]
    dataset = result.dataset
    y, J, l, d, pdim, r, n_povm, bsize, meas_samples, n, nt = get_mgst_parameters_from_dataset(dataset, qubit_layout=qubit_layout, rK=rK)
    K, X, E, rho = initialize_mgst_parameters(dataset=dataset, target_init=True)
    
    return K, X, E, rho, y, J, l, d, pdim, r, n_povm, bsize, meas_samples, n, nt, rK

In [5]:
from mGST.algorithm import gd

def get_x_from_k(k, d, r):
    return np.einsum("ijkl,ijnm -> iknlm", k, k.conj()).reshape((d, r, r))

def compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates, gds_kwargs={}):
    K_gds = gd(K, E, rho, y, J, d, r, rK, fixed_gates=fixed_gates, ls="COBYLA", **gds_kwargs)
    return get_x_from_k(k=K_gds, d=d, r=r)

## Gradient and Hessian

The function dK computes the Wirtinger derivative $\frac{\partial \mathcal L}{\partial K}$.
Here $\mathcal L$ is the cost function "objf".

## 2-GST: $d = 5$

### New tests: using $df/dz$ vs $df/dz^*$ in gradient descent!

start: 08.01.25

update: moved from test_jax_gst_gradient.ipynb

goal: Currently the code from Raphael uses the df/dz instead of df/dz*, even thought the latter is the correct one (one relates to the other by conjugation)

In [6]:
# Euclidean Gradient
Q2_GST = GSTConfiguration(
    qubit_layouts=[[0, 1]],
    gate_set="2QXYCZ",
    num_circuits=800,
    shots=1000,
    rank=4,
)

K, X, E, rho, y, J, l, d, pdim, r, n_povm, bsize, meas_samples, n, nt, rK = get_full_mgst_parameters_from_configuration(
    Q2_GST, backend
)

dK_ = dK(X, K, E, rho, J, y, d, r, rK)

2025-01-10 16:25:05,814 - iqm.benchmarks.logging_config - INFO - Now generating 800 random GST circuits...
2025-01-10 16:25:06,233 - iqm.benchmarks.logging_config - INFO - Will transpile all 800 circuits according to fixed physical layout
2025-01-10 16:25:06,233 - iqm.benchmarks.logging_config - INFO - Transpiling for backend IQMFakeApolloBackend with optimization level 0, sabre routing method all circuits
2025-01-10 16:25:07,931 - iqm.benchmarks.logging_config - INFO - Submitting batch with 800 circuits corresponding to qubits [0, 1]
2025-01-10 16:25:07,939 - iqm.benchmarks.logging_config - INFO - Now executing the corresponding circuit batch
2025-01-10 16:25:08,020 - iqm.benchmarks.logging_config - INFO - Retrieving all counts


The following code computes the Wirtinger derivatives "Fyconjy" $= \frac{\partial^2 \mathcal L}{\partial K \partial K^*}$ and "Fyy" $= \frac{\partial^2 \mathcal L}{\partial K \partial K}$. \
Here $\mathcal L$ is the cost function "objf".

In [7]:
# Euclidean Hessian (can take a while to compute depending on rK)
# compute individual second derivative terms
dK_, dM10, dM11 = dK_dMdM(X, K, E, rho, J, y, d, r, rK)
dd, dconjd = ddM(X, K, E, rho, J, y, d, r, rK)

# Assemple terms
Fyconjy = dM11.reshape(d, nt, d, nt) + np.einsum("ijklmnop->ikmojlnp", dconjd).reshape((d, nt, d, nt)) # Mixed derivate by K and K.conj()
Fyy = dM10.reshape(d, nt, d, nt) + np.einsum("ijklmnop->ikmojlnp", dd).reshape((d, nt, d, nt)) # Second derivate by K

In [8]:
print(K.shape)
print(Fyy.shape, d, nt)
# The second derivative is ordered with d = "nubmer of gates" and nt = pdim*pdim*rK = "Product of all Kraus tensor dimenstions per gate"
# So for instance the second derivative just by gate 0 - parameters is stored in Fyy[0,:,0,:], while a mixed derivative by gate 0 and gate 1 - parameters is in Fyy[0,:,1,:] and Fyy[1,:,0,:]

(5, 4, 4, 4)
(5, 64, 5, 64) 5 64


In [9]:
fixed_gates = np.array([(f"G%i" % i in []) for i in range(d)])
x_unoptimized = get_x_from_k(K, d, r)
d, fixed_gates, objf(x_unoptimized, E, rho, J, y)

(5, array([False, False, False, False, False]), 0.001560237764992689)

In [13]:
x_gds_not_conjugate = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates)
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_not_conjugate, E, rho, J, y))
# without numba JIT: it takes 1.7 s approx

Un-optimized f(x): 0.001560237764992689
Optimized f(x): 0.001302074475287173


In [15]:
x_gds_conjugate = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates, {"conjugate":True})
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_conjugate, E, rho, J, y))
# without numba JIT: it takes 1.6 s approx

Un-optimized f(x): 0.001560237764992689
Optimized f(x): 0.00155161497719539


In [16]:
x_gds_not_conjugate_jax = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates, {"use_jax":True})
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_not_conjugate_jax, E, rho, J, y))

INFO:2025-01-10 13:18:57,475:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-01-10 13:18:57,475 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-01-10 13:18:57,479:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/

Using JAX power
Un-optimized f(x): 0.001560237764992689
Optimized f(x): 0.0013020744756441372


In [17]:
x_gds_conjugate_jax = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates, {"conjugate":True, "use_jax":True})
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_conjugate_jax, E, rho, J, y))

Using JAX power
Un-optimized f(x): 0.001560237764992689
Optimized f(x): 0.0015516149792380047


## 1-GST: $d = 3$

In [7]:
# Example configurations for 1 and 3 qubits:
Q1_GST = GSTConfiguration(
    qubit_layouts=[[0]],
    gate_set="1QXYI",
    num_circuits=100,
    shots=1000,
    rank=4,
)

K, X, E, rho, y, J, l, d, pdim, r, n_povm, bsize, meas_samples, n, nt, rK = get_full_mgst_parameters_from_configuration(
    Q1_GST, backend
)

2025-01-10 14:36:50,941 - iqm.benchmarks.logging_config - INFO - Now generating 100 random GST circuits...
2025-01-10 14:36:51,086 - iqm.benchmarks.logging_config - INFO - Will transpile all 100 circuits according to fixed physical layout
2025-01-10 14:36:51,087 - iqm.benchmarks.logging_config - INFO - Transpiling for backend IQMFakeApolloBackend with optimization level 0, sabre routing method all circuits
2025-01-10 14:36:51,611 - iqm.benchmarks.logging_config - INFO - Submitting batch with 100 circuits corresponding to qubits [0]
2025-01-10 14:36:51,626 - iqm.benchmarks.logging_config - INFO - Now executing the corresponding circuit batch
2025-01-10 14:36:51,711 - iqm.benchmarks.logging_config - INFO - Retrieving all counts


In [8]:
fixed_gates = np.array([(f"G%i" % i in []) for i in range(d)])
x_unoptimized = get_x_from_k(K, d, r)
d, fixed_gates, objf(x_unoptimized, E, rho, J, y)

(3, array([False, False, False]), 0.0020824648688890345)

In [9]:
x_gds_not_conjugate = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates)
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_not_conjugate, E, rho, J, y))
# without numba JIT: it takes 0.1 s approx

Un-optimized f(x): 0.0018295245647571557
Optimized f(x): 0.0014634378187419136


In [10]:
x_gds_conjugate = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates, {"conjugate":True})
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_conjugate, E, rho, J, y))
# without numba JIT: it takes 0.1 s approx

Un-optimized f(x): 0.0018295245647571557
Optimized f(x): 0.0017468833116506134


In [11]:
x_gds_not_conjugate_jax = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates, {"use_jax":True})
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_not_conjugate_jax, E, rho, J, y))

INFO:2025-01-10 13:46:54,788:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-01-10 13:46:54,788 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-01-10 13:46:54,790:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/

Using JAX power
Un-optimized f(x): 0.0018295245647571557
Optimized f(x): 0.0014634378284405938


In [12]:
x_gds_conjugate_jax = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates, {"conjugate":True, "use_jax":True})
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_conjugate_jax, E, rho, J, y))

Using JAX power
Un-optimized f(x): 0.0018295245647571557
Optimized f(x): 0.0017468833441460598


## 3-GST: $d = 8$

In [13]:
Q3_GST = GSTConfiguration(
    qubit_layouts=[[0,1,3]],
    gate_set="3QXYCZ",
    num_circuits=1000,
    shots=1000,
    rank=4,
)

K, X, E, rho, y, J, l, d, pdim, r, n_povm, bsize, meas_samples, n, nt, rK = get_full_mgst_parameters_from_configuration(
    Q3_GST, backend
)

2025-01-10 13:46:58,857 - iqm.benchmarks.logging_config - INFO - Now generating 1000 random GST circuits...
2025-01-10 13:46:59,067 - iqm.benchmarks.logging_config - INFO - Will transpile all 1000 circuits according to fixed physical layout
2025-01-10 13:46:59,068 - iqm.benchmarks.logging_config - INFO - Transpiling for backend IQMFakeApolloBackend with optimization level 0, sabre routing method all circuits
2025-01-10 13:47:01,284 - iqm.benchmarks.logging_config - INFO - Submitting batch with 1000 circuits corresponding to qubits [0, 1, 3]
2025-01-10 13:47:01,286 - iqm.benchmarks.logging_config - INFO - Now executing the corresponding circuit batch
2025-01-10 13:47:01,352 - iqm.benchmarks.logging_config - INFO - Retrieving all counts


In [14]:
fixed_gates = np.array([(f"G%i" % i in []) for i in range(d)])
x_unoptimized = get_x_from_k(K, d, r)
d, fixed_gates, objf(x_unoptimized, E, rho, J, y)

(8,
 array([False, False, False, False, False, False, False, False]),
 0.000798526620967516)

In [26]:
x_gds_not_conjugate = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates)
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_not_conjugate, E, rho, J, y))
# 2m 28.8

Un-optimized f(x): 0.0007295188784255194
Optimized f(x): 0.0006669932844894213


In [26]:
# unjit numba - no compilation time
# x_gds_not_conjugate = compute_new_x()
# print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
# print('Optimized f(x):', objf(x_gds_not_conjugate, E, rho, J, y))

Un-optimized f(x): 0.000757607351976269
Optimized f(x): 0.0006827538052883298


In [27]:
x_gds_conjugate = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates, {"conjugate":True})
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_conjugate, E, rho, J, y))

Un-optimized f(x): 0.0007295188784255194
Optimized f(x): 0.0007293663469541819


In [25]:
# unjit numba - no compilation time
# x_gds_conjugate = compute_new_x({"conjugate":True})
# print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
# print('Optimized f(x):', objf(x_gds_conjugate, E, rho, J, y))

Un-optimized f(x): 0.000757607351976269
Optimized f(x): 0.0007519109514575818


In [28]:
x_gds_not_conjugate_jax = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates, {"use_jax":True})
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_not_conjugate_jax, E, rho, J, y))
# Before: 50 s

Using JAX power
Un-optimized f(x): 0.0007295188784255194
Optimized f(x): 0.0006669932844901381


In [29]:
x_gds_conjugate_jax = compute_new_x(K, E, rho, y, J, d, r, rK, fixed_gates, {"conjugate":True, "use_jax":True})
print('Un-optimized f(x):', objf(x_unoptimized, E, rho, J, y))
print('Optimized f(x):', objf(x_gds_conjugate_jax, E, rho, J, y))
# before 47s

Using JAX power
Un-optimized f(x): 0.0007295188784255194
Optimized f(x): 0.0007293663469620103


In [30]:
J, len(J), J.shape

(array([[-1, -1, -1, ..., -1, -1, -1],
        [-1, -1, -1, ..., -1, -1,  2],
        [-1, -1, -1, ..., -1, -1,  6],
        ...,
        [ 2,  5,  4, ...,  3,  2,  5],
        [ 4,  4,  2, ...,  5,  4,  5],
        [ 0,  5,  3, ...,  5,  5,  0]]),
 1000,
 (1000, 14))

## Profiling

In [30]:
%prun -D program.prof compute_new_x()

 
*** Profile stats marshalled to file 'program.prof'.


         64489 function calls in 167.512 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       68   88.259    1.298   88.259    1.298 low_level_jit.py:107(objf)
        1   60.612   60.612   60.612   60.612 low_level_jit.py:316(dK)
      552   17.969    0.033   17.976    0.033 linalg.py:492(inv)
      552    0.213    0.000    0.250    0.000 linalg.py:1193(eig)
      622    0.126    0.000    0.126    0.000 {built-in method numpy.core._multiarray_umath.c_einsum}
       69    0.120    0.002   18.552    0.269 optimization.py:60(update_K_geodesic)
      552    0.026    0.000    0.110    0.000 linalg.py:789(qr)
      552    0.023    0.000    0.061    0.000 twodim_base.py:366(tri)
     1104    0.021    0.000    0.023    0.000 twodim_base.py:33(_min_int)
      552    0.017    0.000    0.021    0.000 linalg.py:215(_assert_finite)
     1104    0.015    0.000    0.017    0.000 twodim_base.py:158(eye)
      552    0.013    0.000    0.

In [31]:
%prun -D program_jax.prof compute_new_x({"use_jax":True})

Using JAX power
 
*** Profile stats marshalled to file 'program_jax.prof'.


         57597570 function calls (56059077 primitive calls) in 110.362 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       45   72.137    1.603   72.137    1.603 low_level_jit.py:107(objf)
      368   11.132    0.030   11.139    0.030 linalg.py:492(inv)
       29    1.553    0.054    1.557    0.054 compiler.py:276(backend_compile)
   451139    1.324    0.000    2.983    0.000 abstract_arrays.py:46(canonical_concrete_aval)
    81967    1.294    0.000    3.907    0.000 pjit.py:1736(_pjit_call_impl)
145946/137954    1.057    0.000    5.997    0.000 core.py:928(process_primitive)
41109/41004    0.813    0.000    8.285    0.000 pjit.py:2154(_pjit_partial_eval)
   131623    0.726    0.000    0.897    0.000 partial_eval.py:771(newvar)
  1668990    0.549    0.000    0.550    0.000 pjit.py:2181(<genexpr>)
    55987    0.546    0.000    0.676    0.000 dispatch.py:83(apply_primitive)
6150010/6149840    0.530    0.000    0.555    0.

## Jitting jax cost function

### 1GST

In [6]:
# First using the JAX compilation with smaller qubits to see if this helps
Q1_GST = GSTConfiguration(
    qubit_layouts=[[0]],
    gate_set="1QXYI",
    num_circuits=100,
    shots=1000,
    rank=4,
)

K, X, E, rho, y, J, l, d, pdim, r, n_povm, bsize, meas_samples, n, nt, rK = get_full_mgst_parameters_from_configuration(
    Q1_GST, backend
)

2025-01-14 15:12:33,248 - iqm.benchmarks.logging_config - INFO - Now generating 100 random GST circuits...
2025-01-14 15:12:33,524 - iqm.benchmarks.logging_config - INFO - Will transpile all 100 circuits according to fixed physical layout
2025-01-14 15:12:33,525 - iqm.benchmarks.logging_config - INFO - Transpiling for backend IQMFakeApolloBackend with optimization level 0, sabre routing method all circuits
2025-01-14 15:12:34,017 - iqm.benchmarks.logging_config - INFO - Submitting batch with 100 circuits corresponding to qubits [0]
2025-01-14 15:12:34,032 - iqm.benchmarks.logging_config - INFO - Now executing the corresponding circuit batch
2025-01-14 15:12:34,109 - iqm.benchmarks.logging_config - INFO - Retrieving all counts


In [7]:
J_processed = [row[row != -1] for row in J]
# cost_function_jax_jit(K, d, r, E, rho, J_processed, y)

In [9]:
# We see this is kind of what we expected. The _mean_squared_error_inner_jit is not compiled again! This is good!
# However, the contraction gets compiled 15 times. Why?
# My theory is that it's because of the change in the elements lengths of the J. Which would make sense!

previous_length = len(J[0])
for i, element in enumerate(J_processed):
    new_length = len(element)
    if new_length != previous_length:
        print(f"differnet length at: {i}-th iteration")
        previous_length = new_length

differnet length at: 0-th iteration
differnet length at: 1-th iteration
differnet length at: 4-th iteration
differnet length at: 10-th iteration
differnet length at: 16-th iteration
differnet length at: 22-th iteration
differnet length at: 29-th iteration
differnet length at: 36-th iteration
differnet length at: 43-th iteration
differnet length at: 50-th iteration
differnet length at: 58-th iteration
differnet length at: 66-th iteration
differnet length at: 74-th iteration
differnet length at: 82-th iteration
differnet length at: 91-th iteration


In [10]:
cost_function_jax(K, d, r, E, rho, J, y)

INFO:2025-01-14 13:59:13,562:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-01-14 13:59:13,562 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-01-14 13:59:13,564:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/

Array(0.00462135, dtype=float64)

In [17]:
# J_processed[1] = [2]
# cost_function_jax_jit(K, d, r, E, rho, J_processed, y)
# Result: Changing the values of the J does make us have to recompile

Array(0.00250767, dtype=float64)

In [8]:
# not Jitting anything
%timeit cost_function_jax(K, d, r, E, rho, J, y)

INFO:2025-01-14 15:12:39,282:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-01-14 15:12:39,282 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-01-14 15:12:39,284:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/

63.5 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
# Jitting everything
%time cost_function_jax_jit(K, d, r, E, rho, J_processed, y)

CPU times: user 1.64 s, sys: 57.4 ms, total: 1.69 s
Wall time: 681 ms


Array(0.00203852, dtype=float64)

In [17]:
# Jitting everything
from mGST.low_level_jit import _cost_function_jax_jit, _cost_function_jax_unjit

# This is the time it takes to compile the whole Jittable function
cost_1gst = jax.jit(_cost_function_jax_jit)
X = _cost_function_jax_unjit(K, d, r)
%timeit cost_1gst.lower(X, E, rho, J, y).compile()

375 μs ± 842 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# Preparing for evaluation
X = _cost_function_jax_unjit(K, d, r)
cost_1gst_compiled =  cost_1gst.lower(X, E, rho, J, y).compile()

In [20]:
%%timeit X = _cost_function_jax_unjit(K, d, r)
cost_1gst_compiled(X, E, rho, J, y)
# This is how long it takes to evaluate once it was compiled once already

983 μs ± 26.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [24]:
# less than we had just running it for the first time only once (218 ms)
total_time_first_run = 375e-6 + 983e-6
total_time_first_run*1000 # ms

1.3579999999999999

In [10]:
%timeit cost_function_jax_jit(K, d, r, E, rho, J_processed, y)

6.38 ms ± 194 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
%%timeit X = get_x_from_k(k=K, d=d, r=r)
objf(X, E, rho, J, y)
# Jitting using Numba

215 μs ± 4.56 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


**Conclusions from timeit comparison**

* Numba still has an edge over JAX even when Jitting everything. After running both once to not include compilation time, wee see the comparison:
    * 1GST:
        * NUMBA: 224 μs ± 2.24 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
        * JAX (JIT all): 511 μs ± 7.08 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
        * JAX (JIT nothing): 61.2 ms ± 338 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

* Not a big difference. But already for 2GST the compilation time for JAX becomes too high. The problem is that currently we are compilling the whole function. How about instead we compile only the loops part that are the heaviest?

#### Profiling 1GST

We need to see what parts are worth compiling or not

In [16]:
J_processed = [row[row != -1] for row in J]
len(J_processed)

100

In [17]:
cost_function_jax(K, d, r, E, rho, J, y)

Array(0.0021935, dtype=float64)

In [11]:
%prun -D jit_jax_cost.prof cost_function_jax(K, d, r, E, rho, J, y)

 
*** Profile stats marshalled to file 'jit_jax_cost.prof'.


         245752 function calls in 0.126 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2437    0.017    0.000    0.022    0.000 dispatch.py:83(apply_primitive)
      918    0.011    0.000    0.060    0.000 lax_numpy.py:11834(_attempt_rewriting_take_via_slice)
      818    0.006    0.000    0.011    0.000 lax_numpy.py:2910(where)
    53127    0.005    0.000    0.009    0.000 {built-in method builtins.isinstance}
     2718    0.004    0.000    0.007    0.000 dtypes.py:742(dtype)
      818    0.004    0.000    0.004    0.000 array_methods.py:188(_dot)
     1202    0.004    0.000    0.019    0.000 array_methods.py:568(deferring_binary_op)
     2437    0.003    0.000    0.003    0.000 config.py:194(trace_context)
      500    0.003    0.000    0.005    0.000 ufunc_api.py:172(__call__)
      100    0.003    0.000    0.104    0.001 low_level_jit.py:238(contract_jax)
     2437    0.003    0.000    0.035    0.000 core.py:447(b

**Conclusions from Cprofile**

* We see that the most resource consuming components are the `contract_jax` and the `mean_squared_error_inner` functions. So it makes sense to JIT them. 

* The test below show that these do not get recompiled even when some thing changes as long as they have the same shape.

### Timing each JIT component

In [58]:
# Does the contraction get recompiled?
from mGST.low_level_jit import contract_jax_jit, contract_jax, _mean_squared_error_inner_jit, _mean_squared_error_inner

In [59]:
X = get_x_from_k(K, d, r)

In [60]:
J[10], J_processed[10]

(array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1,  0,  1]),
 array([1, 0, 1]))

In [61]:
contract_1gst = jax.jit(contract_jax)

%timeit contract_1gst.lower(X, j_vec=J_processed[10]).compile()

140 μs ± 78.8 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [62]:
contract_1gst_compiled = jax.jit(contract_jax).lower(X, j_vec=J_processed[10]).compile()
%timeit contract_1gst_compiled(X, j_vec=J_processed[10])

7.83 μs ± 60.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [63]:
%timeit contract_jax(X, j_vec=J_processed[10])

125 μs ± 1.11 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [71]:
# Compilation time: 76.6 μs ± 85.9 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each) 
# From above we know this gets compiled 15 times, so:
time_contract = 15 * 76.6e-6 + 100 * 7.83e-6
time_contract/1000

1.9320000000000003e-06

In [16]:
%time contract_jax_jit(X, j_vec=J_processed[10])
# This is to test the compilation time

CPU times: user 62.4 ms, sys: 12 ms, total: 74.4 ms
Wall time: 39.8 ms


Array([[ 0.52678199+0.j        , -0.02652781+0.49607538j,
        -0.02652781-0.49607538j,  0.47453026+0.j        ],
       [ 0.01559817+0.49634251j,  0.52360493+0.01250515j,
         0.46875118-0.03828833j, -0.01538341-0.4962317j ],
       [ 0.01559817-0.49634251j,  0.46875118+0.03828833j,
         0.52360493-0.01250515j, -0.01538341+0.4962317j ],
       [ 0.47321801+0.j        ,  0.02652781-0.49607538j,
         0.02652781+0.49607538j,  0.52546974+0.j        ]],      dtype=complex128)

In [17]:
%timeit contract_jax_jit(X, j_vec=J_processed[10])

7.55 μs ± 80.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [18]:
J_processed[11]

array([0, 0, 1])

In [22]:
%timeit contract_jax_jit(X, j_vec=J_processed[11])
# From the time order of magnitude we see that the function does not get compiled again. It just uses the previous compilation!

7.88 μs ± 248 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [20]:
X.shape

(3, 4, 4)

In [24]:
contract_jax_jit(X, j_vec=np.array([1,2,7,8,11]))

Array([[ 0.51810885+6.94743899e-18j,  0.01951509+4.89910021e-01j,
         0.01951509-4.89910021e-01j,  0.49441546+7.20783474e-18j],
       [-0.02087604+4.90825122e-01j,  0.50439206+1.49427037e-03j,
         0.478247  +4.01933403e-02j,  0.02003762-4.92207672e-01j],
       [-0.02087604-4.90825122e-01j,  0.478247  -4.01933403e-02j,
         0.50439206-1.49427037e-03j,  0.02003762+4.92207672e-01j],
       [ 0.48189115+7.35491429e-19j, -0.01951509-4.89910021e-01j,
        -0.01951509+4.89910021e-01j,  0.50558454-1.28335256e-18j]],      dtype=complex128)

In [25]:
contract_jax_jit._cache_size()
# From this a very interesting conclusion is that the function gets compiled ONLY when there is no previous cached implementation that can be used. This only happens for different lengths of j_vec in this function

17

In [64]:
# Does the same thing happen for the _mean_squared_error_inner_jit function?
i = 10
o = 1
C = contract_jax(X, j_vec=J_processed[i])

In [65]:
mse_1gst = jax.jit(_mean_squared_error_inner)

%timeit mse_1gst.lower(E=E[o], C=C, rho=rho, y=y[o, i]).compile()

94.7 μs ± 1.08 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [66]:
mse_1gst_compiled =  mse_1gst.lower(E=E[o], C=C, rho=rho, y=y[o, i]).compile()

%timeit mse_1gst_compiled(E=E[o], C=C, rho=rho, y=y[o, i])

8.16 μs ± 50 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [74]:
# Do the compilations + running time reflect how long it takes?
time_mse = 1 * 94.7e-6 + 200 * 8.16e-6
total_time = time_mse + time_contract
total_time*1000 # in ms


3.6587

In [28]:
%time _mean_squared_error_inner(E=E[o], C=C, rho=rho, y=y[o, i])

CPU times: user 368 μs, sys: 496 μs, total: 864 μs
Wall time: 535 μs


Array(0.00120979, dtype=float64)

In [29]:
%time _mean_squared_error_inner_jit(E=E[o], C=C, rho=rho, y=y[o, i])

CPU times: user 34.3 ms, sys: 1.84 ms, total: 36.2 ms
Wall time: 24.1 ms


Array(0.00120979, dtype=float64)

In [30]:
%timeit _mean_squared_error_inner_jit(E=E[o], C=C, rho=rho, y=y[o, i])

8.01 μs ± 214 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [33]:
_mean_squared_error_inner_jit._cache_size()

2

We see it actually helps compiling if this process runs a significant amount of times. 

Now, does it get recompiled when changing the inputs?

In [34]:
%timeit _mean_squared_error_inner_jit(E=E[0], C=C, rho=rho, y=y[0, i])

8.19 μs ± 262 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [35]:
_mean_squared_error_inner_jit._cache_size()

2

In [36]:
# new C
i = 20
o = 0
E_test = E[o]
C_test = contract_jax(X, j_vec=J_processed[i])
y_test = y[o, i]

In [37]:
%timeit _mean_squared_error_inner_jit(E=E_test, C=C_test, rho=rho, y=y_test)

7.89 μs ± 181 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [38]:
_mean_squared_error_inner_jit._cache_size()

2

It seems like it doesnt! So it might be a good idea to use the JITTED version!

In [39]:
# Just checking, is doing the processing of J inside a bad idea?
%timeit [row[row != -1] for row in J]
# From this we see it is actually an important part! Ideally we would generate the J without the -1!

76.7 μs ± 331 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### 2GST

In [11]:
# Euclidean Gradient
Q2_GST = GSTConfiguration(
    qubit_layouts=[[0, 1]],
    gate_set="2QXYCZ",
    num_circuits=800,
    shots=1000,
    rank=4,
)

K, X, E, rho, y, J, l, d, pdim, r, n_povm, bsize, meas_samples, n, nt, rK = get_full_mgst_parameters_from_configuration(
    Q2_GST, backend
)

2025-01-14 15:13:44,116 - iqm.benchmarks.logging_config - INFO - Now generating 800 random GST circuits...
2025-01-14 15:13:44,281 - iqm.benchmarks.logging_config - INFO - Will transpile all 800 circuits according to fixed physical layout
2025-01-14 15:13:44,281 - iqm.benchmarks.logging_config - INFO - Transpiling for backend IQMFakeApolloBackend with optimization level 0, sabre routing method all circuits
2025-01-14 15:13:45,832 - iqm.benchmarks.logging_config - INFO - Submitting batch with 800 circuits corresponding to qubits [0, 1]
2025-01-14 15:13:45,834 - iqm.benchmarks.logging_config - INFO - Now executing the corresponding circuit batch
2025-01-14 15:13:45,903 - iqm.benchmarks.logging_config - INFO - Retrieving all counts


In [12]:
J_processed = [row[row != -1] for row in J]
len(J_processed)

800

In [36]:
f_test_2 = jax.jit(contract_jax)

%timeit f_test_2.lower(X, j_vec=J_processed[10]).compile()

75.7 μs ± 384 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [37]:
%timeit contract_jax(X, j_vec=J_processed[10])

141 μs ± 1.6 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [38]:
i = 10
o = 1
C = contract_jax(X, j_vec=J_processed[i])

f_test_3 = jax.jit(_mean_squared_error_inner)

%timeit f_test_3.lower(E=E[1], C=C, rho=rho, y=y[1, 10]).compile()

95.1 μs ± 1.33 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [41]:
%timeit _mean_squared_error_inner(E=E[o], C=C, rho=rho, y=y[o, i])

39.6 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [13]:
%timeit cost_function_jax(K, d, r, E, rho, J, y)

707 ms ± 20 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [28]:
%%timeit X = get_x_from_k(k=K, d=d, r=r)
objf(X, E, rho, J, y)
# Jitting using Numba

8.91 ms ± 88.7 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
# Jitting everything
%time cost_function_jax_jit(K, d, r, E, rho, J_processed, y)

CPU times: user 2.15 s, sys: 120 ms, total: 2.27 s
Wall time: 833 ms


Array(0.00160813, dtype=float64)

In [15]:
%timeit cost_function_jax_jit(K, d, r, E, rho, J_processed, y)

117 ms ± 2.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Jitting everything
from mGST.low_level_jit import _cost_function_jax_jit, _cost_function_jax_unjit

# This is the time it takes to compile the whole Jittable function
cost_jit = jax.jit(_cost_function_jax_jit)
X = _cost_function_jax_unjit(K, d, r)
%timeit cost_jit.lower(X, E, rho, J_processed, y).compile()

INFO:2025-01-14 14:24:03,641:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-01-14 14:24:03,641 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-01-14 14:24:03,644:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/

### 3GST

In [16]:
Q3_GST = GSTConfiguration(
    qubit_layouts=[[0,1,3]],
    gate_set="3QXYCZ",
    num_circuits=1000,
    shots=1000,
    rank=4,
)

K, X, E, rho, y, J, l, d, pdim, r, n_povm, bsize, meas_samples, n, nt, rK = get_full_mgst_parameters_from_configuration(
    Q3_GST, backend
)

2025-01-14 15:15:14,038 - iqm.benchmarks.logging_config - INFO - Now generating 1000 random GST circuits...
2025-01-14 15:15:14,438 - iqm.benchmarks.logging_config - INFO - Will transpile all 1000 circuits according to fixed physical layout
2025-01-14 15:15:14,438 - iqm.benchmarks.logging_config - INFO - Transpiling for backend IQMFakeApolloBackend with optimization level 0, sabre routing method all circuits
2025-01-14 15:15:16,636 - iqm.benchmarks.logging_config - INFO - Submitting batch with 1000 circuits corresponding to qubits [0, 1, 3]
2025-01-14 15:15:16,638 - iqm.benchmarks.logging_config - INFO - Now executing the corresponding circuit batch
2025-01-14 15:15:16,693 - iqm.benchmarks.logging_config - INFO - Retrieving all counts


In [17]:
# only Jitting the contraction
%timeit cost_function_jax(K, d, r, E, rho, J, y)
# previous run: 1.35 s ± 4.54 ms per loop 
# NOTE: new time is without JITing anything

1.89 s ± 39.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
%%timeit X = get_x_from_k(k=K, d=d, r=r)
objf(X, E, rho, J, y)
# This is Jitting everything with Numba
# Previous run: 786 ms ± 68.6 ms

1.04 s ± 150 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
# This previous run was taking more time somehow: 1.06 s ± 180 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# %%timeit X = get_x_from_k(k=K, d=d, r=r)
# objf(X, E, rho, J, y)

1.06 s ± 180 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
# Process J beforehand
J_processed = [row[row != -1] for row in J]
len(J_processed)

1000

In [49]:
contract_3gst = jax.jit(contract_jax)

%timeit contract_3gst.lower(X, j_vec=J_processed[10]).compile()

76.4 μs ± 210 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [50]:
i = 10
o = 1
C = contract_jax(X, j_vec=J_processed[i])

mse_3gst = jax.jit(_mean_squared_error_inner)

%timeit mse_3gst.lower(E=E[o], C=C, rho=rho, y=y[o, i]).compile()

95.8 μs ± 985 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [19]:
%time cost_function_jax_jit(K, d, r, E, rho, J_processed, y)

CPU times: user 4.03 s, sys: 567 ms, total: 4.6 s
Wall time: 2.14 s


Array(0.0008383, dtype=float64)

In [20]:
%timeit cost_function_jax_jit(K, d, r, E, rho, J_processed, y)

991 ms ± 6.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Let's now time it with JIT
# %timeit cost_function_jax_jit(K, d, r, E, rho, J_processed, y)
# This took over 40 mins. So we need to find a way to avoid recompilation!

2025-01-10 15:01:36.190417: E external/xla/xla/service/slow_operation_alarm.cc:73] 
********************************
[Compiling module jit__cost_function_jax_jit] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


### 4GST

In [21]:
from iqm.qiskit_iqm import IQMCircuit as QuantumCircuit
from qiskit.circuit.library import CZGate, RGate

cz_cz = QuantumCircuit(4)
cz_cz.append(CZGate(), [0,1])
cz_cz.append(CZGate(), [2,3])

gate_list = [
    RGate(0.5 * np.pi, 0),
    RGate(0.5 * np.pi, 0),
    RGate(0.5 * np.pi, 0),
    RGate(0.5 * np.pi, 0),
    RGate(0.5 * np.pi, np.pi / 2),
    RGate(0.5 * np.pi, np.pi / 2),
    RGate(0.5 * np.pi, np.pi / 2),
    RGate(0.5 * np.pi, np.pi / 2),
    cz_cz,
]
gates = [QuantumCircuit(4, 0) for _ in range(len(gate_list))]
gate_qubits = [[0], [1], [2], [3], [0], [1], [2], [3], [0, 1, 2, 3]]
for i, gate in enumerate(gate_list):
    if isinstance(gate, QuantumCircuit):
        gates[i].compose(gate, gate_qubits[i], inplace = True)
    else:
        gates[i].append(gate, gate_qubits[i])
gate_labels = ["Rx(pi/2)", "Rx(pi/2)", "Rx(pi/2)", "Rx(pi/2)", 
               "Ry(pi/2)", "Ry(pi/2)", "Ry(pi/2)", "Ry(pi/2)", 
               "CZ-CZ"]

Q4_GST = GSTConfiguration(
    qubit_layouts=[[0,1,3,4]],
    gate_set=gates,
    gate_labels=gate_labels,
    num_circuits=2000,
    shots=1000,
    rank=1,
)

K, X, E, rho, y, J, l, d, pdim, r, n_povm, bsize, meas_samples, n, nt, rK = get_full_mgst_parameters_from_configuration(
    Q4_GST, backend
)

2025-01-14 15:17:13,439 - iqm.benchmarks.logging_config - INFO - Now generating 2000 random GST circuits...
2025-01-14 15:17:13,878 - iqm.benchmarks.logging_config - INFO - Will transpile all 2000 circuits according to fixed physical layout
2025-01-14 15:17:13,878 - iqm.benchmarks.logging_config - INFO - Transpiling for backend IQMFakeApolloBackend with optimization level 0, sabre routing method all circuits
2025-01-14 15:17:17,192 - iqm.benchmarks.logging_config - INFO - Submitting batch with 2000 circuits corresponding to qubits [0, 1, 3, 4]
2025-01-14 15:17:17,194 - iqm.benchmarks.logging_config - INFO - Now executing the corresponding circuit batch
2025-01-14 15:17:17,263 - iqm.benchmarks.logging_config - INFO - Retrieving all counts


In [22]:
d, len(J)

(9, 2000)

In [8]:
%%timeit X = get_x_from_k(k=K, d=d, r=r)
objf(X, E, rho, J, y)

31 s ± 862 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%timeit -n3 -r1 cost_function_jax(K, d, r, E, rho, J, y)

INFO:2025-01-14 14:39:40,980:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-01-14 14:39:40,980 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-01-14 14:39:40,984:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/Users/emiliano.godinez/.pyenv/versions/3.11.10/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/

31.1 s ± 0 ns per loop (mean ± std. dev. of 1 run, 3 loops each)


In [23]:
# Process J beforehand
J_processed = [row[row != -1] for row in J]
len(J_processed)

2000

In [24]:
%time cost_function_jax_jit(K, d, r, E, rho, J_processed, y)

CPU times: user 2min 4s, sys: 7.79 s, total: 2min 12s
Wall time: 26.5 s


Array(0.0005246, dtype=float64)

In [25]:
%timeit -n1 -r3 cost_function_jax_jit(K, d, r, E, rho, J_processed, y)

25.8 s ± 169 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


## Further tests

In [28]:
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    m = len(x)
    value = 0
    for i in range(m):
        value += 1
    return value

print(f._cache_size())
# 0
print('-----')

print(f(jnp.arange(3)))
print(f._cache_size())
# 1
print('-----')

print(f(jnp.arange(2,5)))  # should not trigger a recompilation
print(f._cache_size())
# 1
print('-----')

print(f(np.arange(3)))  # should not trigger a recompilation
print(f._cache_size())
# 1
print('-----')


print(f(jnp.arange(5)))  # should trigger a recompilation
print(f._cache_size())
print('-----')


0
-----
3
1
-----
3
1
-----
3
2
-----
5
3
-----


In [27]:
jnp.arange(3) == jnp.arange(1,4)

Array([False, False, False], dtype=bool)

# Results

<table>
  <tr>
    <th>Function</th>
    <th>Running time</th>
  </tr>
  <tr>
    <td colspan="2", style="text-align: center;">1 GST</td>
  </tr>
  <tr>
    <td>Numba</td> 
    <td> 230 μs ± 3.63 μs</td>
  </tr>
  <tr>
    <td>JAX </td>
    <td> 60.8 ms ± 580 μs</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + mse)</td>
    <td> 6.01 ms ± 67.8 μs/ First compilation: 681 ms</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction)</td>
    <td> 12.5 ms ± 66.8 μs</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + MSE + outter function)</td>
    <td> 586 μs ± 6.07 μs (second run onwards) / 375 μs ± 842 ns (compilation) + 983 μs ± 26.1 μs (run) =  1.3579 ms (first run). Actual reported: 2.95 s </td>
  </tr>
  <tr>
    <td colspan="2", style="text-align: center;">2 GST</td>
  <tr>
    <td>Numba</td> 
    <td> 8.83 ms ± 22 μs</td>
  </tr>
  <tr>
    <td>JAX </td>
    <td> 707 ms ± 31.7 ms</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + mse)</td>
    <td> 115 ms ± 2.31 ms / First compilation: 833 ms</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction)</td>
    <td> 194 ms ± 2.65 ms</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + MSE + outter function)</td>
    <td> NA: compilation took minutes to run (stopped after >5mins).</td>
  </tr>
  <tr>
    <td colspan="2", style="text-align: center;">3 GST</td>
  <tr>
    <td>Numba</td> 
    <td> 1.04 s ± 150 ms</td>
  </tr>
  <tr>
    <td>JAX </td>
    <td> 1.84 s ± 10.4 ms </td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + mse)</td>
    <td> 991 ms ± 6.09 ms/ First compilation: 2.14 s</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction)</td>
    <td> 1.35 s ± 4.98 ms</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + MSE + outter function)</td>
    <td> NA: compilation took minutes to run (stopped after >5mins).</td>
  </tr>
  <tr>
    <td colspan="2", style="text-align: center;">4 GST</td>
  <tr>
    <td>Numba</td> 
    <td> 31 s ± 862 ms</td>
  </tr>
  <tr>
    <td>JAX </td>
    <td> 31.1 s ± 0 ns </td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + mse)</td>
    <td> 25.8 s ± 169 ms / First compilation: 26.5 s</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction)</td>
    <td> Not attempted</td>
  </tr>
  <tr>
    <td>JAX + JIT (contraction + MSE + outter function)</td>
    <td> Not attempted</td>
  </tr>
</table>
