# gradients with tensorflow

In [1]:
import numpy as np
import tensorflow as tf
from poenta.jitted import C_mu_Sigma, dC_dmu_dSigma

In [2]:
tf.__version__

'2.3.0'

In [2]:
gamma_ = 0.1+0.1j
gamma = tf.Variable(gamma_)
phi_ = 0.2
phi = tf.Variable(phi_, dtype=tf.float64)
z_ = 0.3+0.3j
z = tf.Variable(z_)
r_ = np.abs(z_)
delta_ = np.angle(z_)

In [3]:
# old
%timeit dC_dmu_dSigma(gamma_, phi_, z_)

249 µs ± 20.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [6]:
# new
%timeit dC_dmu_dSigma(gamma_, phi_, z_)

13.1 µs ± 7.01 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# old
%timeit C_mu_Sigma(gamma_, phi_, z_)

56.6 µs ± 7.64 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [5]:
# new
%timeit C_mu_Sigma(gamma_, phi_, z_)

2.33 µs ± 137 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [32]:
from numba import njit

@njit
def C_jit(gamma, phi, z):    
    r = np.abs(z)
    delta = np.angle(z)

    C = np.exp(
            -0.5 * np.abs(gamma) ** 2 - 0.5 * np.conj(gamma) ** 2 * np.exp(1j * (2 * phi + delta)) * np.tanh(r)
        ) / np.sqrt(np.cosh(r))
    
    return C 

In [39]:
@tf.function
def C(gamma, phi, z):
    r = tf.cast(tf.abs(z), z.dtype)
    phi = tf.cast(phi, z.dtype)
    delta = tf.cast(tf.math.angle(z), z.dtype)

    C = tf.exp(
        -0.5 * tf.norm(gamma)**2 - 0.5 * tf.math.conj(gamma) ** 2 * tf.exp(1j * (2 * phi + delta)) * tf.math.tanh(r)
    ) / tf.sqrt(tf.math.cosh(r))
    return C

In [40]:
with tf.GradientTape(persistent=True) as tape:
    c = C(gamma, phi, z) 
    ic = c*1j

In [41]:
%%timeit 
c_gamma, c_phi, c_z = tape.gradient(c, [gamma, phi, z])
ic_gamma, ic_phi, ic_z = tape.gradient(ic, [gamma, phi, z])

1.03 ms ± 38.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [37]:
%timeit [x.numpy() for x in [(c_gamma + 1j*ic_gamma)/2, (c_gamma - 1j*ic_gamma)/2, tf.complex(c_phi, ic_phi), (c_z + 1j*ic_z)/2, (c_z - 1j*ic_z)/2]]

279 µs ± 14.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [36]:
%timeit dC_dmu_dSigma(0.1+0.1j, 0.2, 0.3+0.3j)[0]

260 µs ± 26.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [42]:
@tf.function
def example():
  a = tf.constant(0.)
  b = 2 * a
  return tf.gradients(a + b, [a, b], stop_gradients=[a, b])
example()

[<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]

In [43]:
%timeit example()

136 µs ± 840 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [34]:
C_jit(gamma_, phi_, z_)

(0.9441936749793952+0.0014216305305803543j)

In [107]:
z

<tf.Variable 'Variable:0' shape=() dtype=complex128, numpy=(0.3+0.3j)>

In [69]:
a = 1+1j
b = 1
    
with tf.GradientTape(persistent=True) as tape:
    w = a*z + b*tf.math.conj(z)
    c = (w)*tf.math.conj(w)
    c2 = tf.abs(w)**2

In [77]:
tape.gradient(c, z)/2

<tf.Tensor: shape=(), dtype=complex128, numpy=(0.8999999999999999-0.3j)>

In [79]:
tape.gradient(c2, z)/2

<tf.Tensor: shape=(), dtype=complex128, numpy=(0.8999999999999999-0.3j)>

In [76]:
import numpy as np
(np.abs(a)**2+np.abs(b)**2)*z + 2*(b*tf.math.conj(a))*tf.math.conj(z)

<tf.Tensor: shape=(), dtype=complex128, numpy=(0.9000000000000001-0.2999999999999998j)>

In [62]:
2*z + 2*tf.math.conj(z)

<tf.Tensor: shape=(), dtype=complex128, numpy=(1.2+0j)>

In [3]:
from jax import vjp, jvp, jacobian

In [11]:
import numpy as np
np.set_printoptions(suppress=True)
from jax import numpy as np

In [30]:
def C(gammaR, gammaI, phi, zR, zI):
    r = np.sqrt(zR**2 + zI**2)
    delta = np.arctan2(zI,zR)

    C = np.exp(
        -0.5 * (gammaR**2 + gammaI**2) - 0.5 * np.conj(gammaR-1j*gammaI) ** 2 * np.exp(1j * (2 * phi + delta)) * np.tanh(r)
    ) / np.sqrt(np.cosh(r))
    return C

def C(gamma, phi, z):
    r = np.abs(z)
    delta = np.angle(z)

    C = np.exp(
        -0.5 * r**2 - 0.5 * np.conj(gamma) ** 2 * np.exp(1j * (2 * phi + delta)) * np.tanh(r)
    ) / np.sqrt(np.cosh(r))
    return C

In [32]:
a,b = vjp(C, 1.0+0.0j, 2.0, 3.0+0.0j)

In [41]:
b(1.0-1.0j)

(DeviceArray(0.00614274-0.00296379j, dtype=complex64),
 DeviceArray(-0.00296379, dtype=float32),
 DeviceArray(-0.02196645+0.00049397j, dtype=complex64))

In [2]:
from poenta.jitted import dC_dmu_dSigma

In [9]:
dC, _, _ = dC_dmu_dSigma(0.1+0.2j, 0.3, 0.4*0.5j)

In [12]:
dC

array([-0.04806664+0.09608172j, -0.06861528-0.13317848j,
       -0.00041036-0.00947422j,  0.00200461+0.14158204j,
        0.0000472 -0.09421096j])

In [43]:
jvp(C, (1.0+0.0j, 2.0, 3.0+0.0j), (1.0+0.0j, 1.0j, 0))

TypeError: primal and tangent arguments to jax.jvp must have equal types; type mismatch primal complex64 vs tangent float32

In [44]:
from jax import grad

In [46]:
grad(C, (1,), holomorphic=False)(0.1+0.2j, 0.3, 0.4+0.5j)

TypeError: grad requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got complex64. For holomorphic differentiation, pass holomorphic=True. For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly.

In [None]:
vjp