In [1]:
# You will need to install iqm-benchmarks from the github repo to access all the mGST functions: https://github.com/iqm-finland/iqm-benchmarks
from mGST.low_level_jit import dK, objf
import shelve
import numpy as np

%load_ext autoreload
%autoreload 2

In [2]:
# Loading parameters
# NOTE(review): shelve files are pickle-based; only open shelves from a trusted source.
filename = "./Minimal_gradient/shelve_out"
# flag="r" opens the existing shelf read-only, so a wrong path raises instead of
# silently creating an empty shelf; the context manager closes the shelf even
# when one of the keys is missing.
with shelve.open(filename, flag="r") as my_shelf:
    for key in ["K", "E_new", "rho", "y", "J", "d", "r", "rK", "fixed_elements"]:
        print(key)
        # Expose each stored parameter as a notebook-level variable of the same name.
        globals()[key] = my_shelf[key]

K
E_new
rho
y
J
d
r
rK
fixed_elements


In [3]:
# One print call, one value per line (a tuple of print() calls would also
# echo a stray (None, None, None) as the cell result).
print(d, r, fixed_elements, sep="\n")

3
4
[]


(None, None, None)

In [4]:
# Setting some additional parameters derived from the loaded values
E = E_new
# assumes r is a perfect square (r = pdim**2) — TODO confirm
pdim = int(np.sqrt(r))
n = rK * pdim
# Allocate directly as complex128 instead of creating a float array and copying it.
Delta = np.zeros((d, n, pdim), dtype=np.complex128)
# Contract K with its complex conjugate over the shared index j and flatten to (d, r, r).
X = np.einsum("ijkl,ijnm -> iknlm", K, K.conj()).reshape((d, r, r))

In [5]:
# Objective value at the loaded parameters; used as the "Old f(x)" baseline further down.
print(objf(X, E, rho, J, y))

0.0019762572998321807


In [7]:
# Reference gradient from the numba-compiled dK in mGST.low_level_jit.
# NOTE(review): the saved output of this cell is a numba TypingError, yet later cells
# use dK_ successfully — it presumably comes from a different run. Re-execute on a
# fresh kernel to confirm the notebook still runs top to bottom.
dK_ = dK(X, K, E, rho, J, y, d, r, rK)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1m[1mUse of unsupported NumPy function 'numpy.einsum' or unsupported use of the function.
[1m
File "../src/mGST/low_level_jit.py", line 349:[0m
[1mdef dK(X, K, E, rho, J, y, d, r, rK):
    <source elided>
    """
[1m    X = np.einsum("ijkl,ijnm -> iknlm", K, K.conj()).reshape((d, r, r))
[0m    [1m^[0m[0m
[0m
[0m[1mDuring: typing of get attribute at /Users/emiliano.godinez/Documents/phd/iqm-benchmarks/src/mGST/low_level_jit.py (349)[0m
[1m
File "../src/mGST/low_level_jit.py", line 349:[0m
[1mdef dK(X, K, E, rho, J, y, d, r, rK):
    <source elided>
    """
[1m    X = np.einsum("ijkl,ijnm -> iknlm", K, K.conj()).reshape((d, r, r))
[0m    [1m^[0m[0m


In [6]:
# Inspect shape, dtype and container type of the reference gradient.
dK_.shape, dK_.dtype, type(dK_)

NameError: name 'dK_' is not defined

# Obtaining the derivative with JAX

In [7]:
import jax
import jax.numpy as jnp

In [6]:
# Scratch check: cast an integer array with jnp.complex128.
# NOTE(review): the complex128 result presumably requires JAX's 64-bit mode to be
# enabled somewhere outside this notebook — confirm where that happens.
a = jnp.array([1,1])
jnp.complex128(a)

Array([1.+0.j, 1.+0.j], dtype=complex128)

In [8]:
from mGST.low_level_jit import cost_function_jax
# NOTE(review): these hard-coded values overwrite the d and r loaded from the shelf
# (printed as 3 and 4 above); they must be kept in sync with the shelf contents.
d = 3
r = 4

In [11]:
# Evaluate the JAX cost at K; the displayed value can be compared with the objf result above.
cost_function_jax(K, d, r, E, rho, J, y)

Array(0.00197626, dtype=float64)

**Result with single precision**: Array(0.00197626, dtype=float32)

In [12]:
# Differentiate the JAX cost with respect to its first positional argument (K).
grad_jax = jax.grad(fun=cost_function_jax, argnums=0)(K, d, r, E, rho, J, y)

In [13]:
# Shape and dtype of the JAX gradient.
grad_jax.shape, grad_jax.dtype

((3, 4, 2, 2), dtype('complex128'))

In [14]:
# Compare half of the JAX gradient with the numba gradient (default allclose tolerances).
jnp.allclose(grad_jax/2, dK_)

Array(True, dtype=bool)

In [15]:
# NOTE(review): grad_jax is not halved here, so this norm is not expected to be ~0;
# given the allclose check above it is roughly the norm of dK_ itself.
jnp.linalg.norm(grad_jax - dK_)

Array(0.06980057, dtype=float64)

In [17]:
# Dtypes of the inputs passed to both gradient implementations.
K.dtype, E.dtype, rho.dtype, J.dtype, y.dtype

(dtype('complex128'),
 dtype('complex128'),
 dtype('complex128'),
 dtype('int32'),
 dtype('float64'))

In [18]:
# First 2x2 block of the halved JAX gradient, for a visual comparison with dK_[0, 0].
grad_jax[0,0,:,:]/2

Array([[ 6.89530895e-06+1.41335293e-05j, -6.95757287e-07+2.63181640e-05j],
       [ 1.79551344e-05-1.89523399e-05j,  1.34852364e-05+4.34467736e-06j]],      dtype=complex128)

In [19]:
# Matching block of the numba gradient.
dK_[0,0,:,:]

array([[ 6.89530895e-06+1.41335293e-05j, -6.95757287e-07+2.63181640e-05j],
       [ 1.79551344e-05-1.89523399e-05j,  1.34852364e-05+4.34467736e-06j]])

## Conclusion up to now:

The gradients seem to match up to a factor of 2 (`grad_jax / 2` vs. `dK_`), with no conjugation applied in the comparison above. I will now check the results that each of them gives inside the algorithm.

### TODO:
* Once I verify they both give reasonable results, I can time both executions

# Gradient Descent

In [20]:
from mGST.algorithm import gd

In [25]:
# Boolean mask over the d gates: True where the label "G<i>" is listed in fixed_elements.
fixed_gates = np.array([f"G{i}" in fixed_elements for i in range(d)])

def get_x_from_k(k):
    """Contract a 4-index tensor with its complex conjugate into a stack of matrices.

    For every leading index ``i`` the result is
    ``sum_j kron(k[i, j], conj(k[i, j]))``.

    The output sizes are taken from ``k.shape`` rather than from the notebook
    globals ``d`` and ``r``, so the helper keeps working after those globals
    are reassigned.

    Parameters
    ----------
    k : numpy.ndarray
        Array of shape ``(d, rK, p, q)``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(d, p * p, q * q)``; for square ``p == q`` this is
        ``(d, r, r)`` with ``r = p ** 2``.
    """
    num_gates, _, rows, cols = k.shape
    contracted = np.einsum("ijkl,ijnm -> iknlm", k, k.conj())
    return contracted.reshape((num_gates, rows * rows, cols * cols))

# Run gd from mGST.algorithm starting at K, without the use_jax option; the result is evaluated in the next cell.
K_gds = gd(K, E_new, rho, y, J, d, r, rK, fixed_gates=fixed_gates, ls="COBYLA")

In [26]:
# Map the updated K back to X and compare the objective before and after the gd call.
x_gds = get_x_from_k(k=K_gds)

print('Old f(x):', objf(X, E, rho, J, y))
print('New f(x):', objf(x_gds, E, rho, J, y))

Old f(x): 0.0019762572998321807
New f(x): 0.0015298804063626297


In [32]:
# Same gd call as before, but with use_jax=True (the call prints "Using JAX power").
K_gds_jax = gd(K, E_new, rho, y, J, d, r, rK, fixed_gates=fixed_gates, ls="COBYLA", use_jax=True)

Using JAX power


In [33]:
# Map the JAX-based result back to X and compare the objective before and after the gd call.
x_gds_jax = get_x_from_k(k=K_gds_jax)

print('Old f(x):', objf(X, E, rho, J, y))
print('New f(x):', objf(x_gds_jax, E, rho, J, y))

Old f(x): 0.0019762572998321807
New f(x): 0.0015298804016647602


In [34]:
# NOTE: we are able to obtain the same result up to numerical precision.
# However, this is without taking the conj() of the gradient, which is weird!
# Could the definition already be taking the conjugate into account?
# (so that the conjugate of the gradient is actually the direction we want to follow)

# Timing: (after restarting to avoid cache)

In [5]:
# Numba-backed gradient wrapper that also performs the contraction of K into X:
def grad_numba(K, E, rho, J, y, d, r, rK):
    """Build X from K and return the gradient computed by ``dK``.

    K is contracted with its complex conjugate and reshaped to ``(d, r, r)``;
    that array is passed to ``dK`` together with the remaining arguments, and
    the value returned by ``dK`` is handed back unchanged.
    """
    contracted = np.einsum("ijkl,ijnm -> iknlm", K, K.conj())
    x_matrices = contracted.reshape((d, r, r))
    return dK(x_matrices, K, E, rho, J, y, d, r, rK)

In [7]:
# Let's time how long it takes to compute the gradients

# Without JIT
# First let's time the TN contraction:
# NOTE(review): this times dK with a precomputed X, whereas grad_numba (timed in another cell)
# also includes the einsum contraction K -> X, so the two numbers do not measure the same work.
%timeit dK(X, K, E, rho, J, y, d, r, rK)

43.8 ms ± 219 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
# Without JIT
# Now let's time JAX:
%timeit jax.grad(fun=cost_function_jax, argnums=0)(K, d, r, E, rho, J, y)

534 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
# With Numba cache and JIT
# grad_numba rebuilds X from K with np.einsum on every call before invoking dK.
%timeit grad_numba(K, E, rho, J, y, d, r, rK)

5.25 ms ± 88.4 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
# With JIT
%timeit jax.grad(fun=cost_function_jax, argnums=0)(K, d, r, E, rho, J, y)

275 ms ± 1.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
