# Test the `jax.grad` function

## `VJP`

In [1]:
# First, let us test the gradient of the non-holomorphic function from Figure 3 of the Krämer paper.

import jax
import jax.numpy as jnp

# Fixed seed, so the random matrix and vectors are the same on every run.
key = jax.random.PRNGKey(1)
# One subkey each for A, x and the tangent vector dx used in the JVP section.
key1, key2, key3 = jax.random.split(key, num=3)

# Random complex 3x3 matrix and random complex input vector of length 3.
A = jax.random.normal(key1, shape=(3, 3), dtype=complex)
x = jax.random.normal(key2, shape=(3,), dtype=complex)

def function(v):
    """Return the real part of the quadratic form ``v^H A v``.

    ``A`` is the module-level matrix. The real part is taken so that the
    output is real-valued, which ``jax.grad`` needs in order to work.
    """
    quadratic_form = jnp.dot(v.conj(), A @ v)
    return jnp.real(quadratic_form)

# Complex gradient; it should have shape (3,), like x. Remember that we should take the
# conjugate of it, due to the convention in JAX.
grad_x = jax.grad(function)(x)

# Cotangent for the VJP: this is the \bar{f} of section 6, df = du + i dv, satisfying equation 24.
# It has the same shape as the output of f (a scalar in this case).
df = 1.0
_, vjp_ad = jax.vjp(function, x)
# Pull back the cotangent defined above rather than repeating the literal 1.0.
(dx_ad,) = vjp_ad(df)

# Applying the VJP to df reproduces the gradient returned by jax.grad, as expected!
print(jnp.allclose(grad_x, dx_ad))

print(grad_x)

True
[-0.42147675-0.17252079j -0.33669758-0.02341664j  0.77379787-0.578864j  ]


In [2]:
# Let me try to calculate the gradient again, but now without taking the real part!
def complex_function(v):
    """Return the complex-valued quadratic form ``v^H A v`` (``A`` is module-level)."""
    transformed = A @ v
    return jnp.dot(v.conj(), transformed)
# The cotangent must be complex to match the dtype of the output of complex_function.
df_complex = jnp.complex64(1.0)

# jax.vjp returns (primal output, pullback function); only the pullback is used here.
_, vjp_ad_complex = jax.vjp(complex_function, x)
# The pullback returns one cotangent per input, hence the 1-tuple unpacking.
(dx_ad_complex, ) = vjp_ad_complex(df_complex)
# Compare against the gradient of the real-valued version computed with jax.grad.
print(jnp.allclose(dx_ad_complex, grad_x))
# Just as we expected, they are the same 🥳
print(dx_ad_complex)

True
[-0.42147675-0.17252079j -0.33669758-0.02341664j  0.77379787-0.578864j  ]


In [4]:
# jax.grad only accepts real-valued outputs, so calling it on complex_function raises the
# TypeError shown below. For non-holomorphic functions with complex outputs, use jax.vjp
# directly (as done in the previous cell).
grad_x_complex = jax.grad(complex_function)(x)

TypeError: grad requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got complex64. For holomorphic differentiation, pass holomorphic=True. For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly.

In [61]:
# Important question, is this the correct result? Let's compare with the analytical result.
# Apparently the derivative is this expression:

def vjp_analytical(df):
    """Return the hand-derived VJP of ``v -> v^H A v`` at ``x`` for cotangent ``df``.

    Evaluates ``A^H x * df + A x * conj(df)`` using the module-level ``A`` and ``x``.

    Args:
        df: Scalar cotangent. May be a JAX or NumPy scalar, or a plain Python number.

    Returns:
        The value of the expression above (one entry per entry of ``A @ x``).
    """
    # jnp.conj also accepts plain Python scalars, which have no ``.conj()`` method,
    # so the cotangent does not have to be converted to a JAX scalar beforehand.
    return (A.conj().T @ x * df) + (A @ x * jnp.conj(df))

# Evaluate the analytical expression for the same cotangent df that was defined for the first VJP.
grad_analytical = vjp_analytical(jnp.complex64(df))
print(grad_analytical)

# We see from this that the gradient we get from JAX is the conjugate of what we expect analytically!
print(jnp.allclose(grad_analytical, grad_x))
# It is therefore important to take the conjugate before comparing!
print(jnp.allclose(grad_analytical.conj(), grad_x))

[-0.42147675+0.17252079j -0.33669758+0.02341664j  0.77379787+0.578864j  ]
False
True


## `JVP`

In [71]:
# Just testing that the JVP works for complex numbers.
dx = jax.random.normal(key3 , (3,) , dtype=complex) # random complex tangent (an earlier version presumably used 2*x)
# First for the complex function
jvp_test_jax = jax.jvp(complex_function,(x,), (dx,))[1]
print(jvp_test_jax)
# Then for the real function
jvp_test_jax_real = jax.jvp(function,(x,), (dx,))[1]
print(jvp_test_jax_real)
# Interestingly, they are not the same! Per the recorded output, the JVP of the real
# function equals the real part of the complex JVP; the imaginary part is dropped.
print(jnp.allclose(jvp_test_jax_real, jvp_test_jax))

(1.7417455-3.6587248j)
1.7417455
False


In [73]:
# Well, there is indeed a difference between the two functions: the real version
# discards the imaginary part of the complex one (see the values below).
function(x), complex_function(x)

(Array(0.30086893, dtype=float32),
 Array(0.30086893-0.32736388j, dtype=complex64))

### NOTE:

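Indeed, taking the real part does not make sense here: `A` is a generic complex matrix, not a Hermitian one (an observable), so the quadratic form is not real in general and the real part throws away its imaginary component. `complex_function` is therefore the more appropriate function to differentiate.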

In [76]:
# Check whether the complex JVP equals the plain (unconjugated) dot product between the
# JAX gradient and the tangent vector dx (a random vector, not 2x).
# Per the recorded output it does NOT: only the real part agrees with jvp_test_jax.
jvp_test_grad = jnp.dot(grad_x, dx)
print(jnp.allclose(jvp_test_grad, jvp_test_jax))
print(jvp_test_grad)

# Conjugate version?
# NOTE(review): nothing is conjugated here. dx_ad_complex was shown to equal grad_x
# (the earlier allclose check printed True), so this repeats the check above and prints
# the same value. Presumably `grad_x.conj()` was intended -- TODO confirm.
grad_x_conj = dx_ad_complex
jvp_test_grad_conj = jnp.dot(grad_x_conj, dx)
print(jnp.allclose(jvp_test_grad_conj, jvp_test_jax))
print(jvp_test_grad_conj)

False
(1.7417456-0.3855983j)
False
(1.7417456-0.3855983j)


In [77]:
# But are they the correct result? Let's find out:
def jvp_analytical(z, z_tilde):
    """Return the hand-derived JVP of ``v -> v^H A v`` at ``z`` along ``z_tilde``.

    Computes ``z_tilde^H A z + z^H A z_tilde`` using the module-level ``A``.
    """
    # jnp.dot contracts the two vectors directly, so no explicit transpose is
    # needed; only the conjugation has to be written out.
    conj_tangent_term = jnp.dot(z_tilde.conj(), A @ z)
    tangent_term = jnp.dot(z.conj(), A @ z_tilde)
    return conj_tangent_term + tangent_term

# Apparently they are! Why is grad * dx not the same, then?
# A likely reason: the function is not holomorphic, so its differential contains both a
# dx term and a conj(dx) term (the two terms summed in jvp_analytical) and cannot be
# written as a single complex dot product with dx.
jvp_test_analytical = jvp_analytical(x, dx)
print(jnp.allclose(jvp_test_analytical, jvp_test_jax))
# Bare expression: the notebook displays the analytical JVP value as the cell output.
jvp_test_analytical

True


Array(1.7417455-3.6587248j, dtype=complex64)