# Import

In [2]:
import jax
from jax import numpy as jnp, random
import numpy as np
import pandas as pd
from mondeq.modules import MonLinear, Relu, deq, deq_forward, deq_backward, _fp_bwd
from mondeq.splittings import build_peacemanrachford_update
from utils import commutation_matrix_sp, vec
from functools import partial

# Verify derivative

In [3]:
# Problem sizes used throughout the notebook:
#   num_samples - number of rows in the batched inputs of the tangent-kernel section
#   d           - length of the input vector x
#   k           - length of the hidden state z
num_samples, d, k = 10, 3, 10

# Root JAX PRNG key; it is split before being handed to a module constructor.
key = random.PRNGKey(0)




In [4]:
# Draw the single test point (z, x) from a *seeded* generator so that
# Restart & Run All reproduces the same values; the previous calls used
# NumPy's unseeded global RNG and changed on every kernel restart.
rng = np.random.default_rng(0)
rand_z = rng.normal(size=(k,))  # hidden-state test vector, length k
rand_x = rng.normal(size=(d,))  # input test vector, length d

In [5]:
# Consume a fresh sub-key for parameter initialisation.
# NOTE: this rebinds `key`, so re-running the cell alone yields different parameters.
key, skey = random.split(key)

# Size the linear module from the configured dimensions rather than the
# literals 3 and 10, so that changing `d` / `k` in the configuration cell
# keeps the module, the test vectors and the Kronecker algebra below in sync.
lin_module = MonLinear(d, k, key=skey)
nonlin_module = Relu()

# K is built for a k x k matrix (presumably the commutation matrix with
# K vec(M) = vec(M^T) -- confirm against utils.commutation_matrix_sp).
K = commutation_matrix_sp(k, k)

# Linear maps applied on the right of (z^T kron J) in the closed-form
# Jacobians with respect to the parameters p_A and p_B.
P_A = (np.eye(k*k) + K) @ np.kron(np.eye(k), lin_module.p_A)
P_B = (np.eye(k*k) - K)

# Matrix that multiplies z inside the linear module (transposed back from
# the layout returned by calW_trans).
W = lin_module.calW_trans().T

Autodiff results

## $f$

In [6]:
def _layer_as_function_of_z(z):
    """One application of the layer, viewed as a function of the state z."""
    return nonlin_module(lin_module(z.T, rand_x.T))


def _layer_as_function_of_params(params):
    """One application of the layer, viewed as a function of the linear module."""
    return nonlin_module(params(rand_z.T, rand_x.T))


# Reverse-mode Jacobians of the layer output at the test point (rand_z, rand_x).
auto_Jz = jax.jacrev(_layer_as_function_of_z)(rand_z)
# allow_int=True lets jacrev differentiate with respect to a module that may
# carry integer-valued leaves.
auto_Jlin = jax.jacrev(_layer_as_function_of_params, allow_int=True)(lin_module)

Closed-form results

$\frac{\partial f}{\partial z} = J W$

In [7]:
# Pre-activation at the test point.
a = lin_module(rand_z.T, rand_x.T)

# J: diagonal matrix holding the nonlinearity's derivative at the pre-activation.
nonlin_slopes = nonlin_module.derivative(a)
Jnl = jnp.diag(nonlin_slopes)

# Closed form df/dz = J W, compared against the autodiff Jacobian.
ana_Jz = Jnl @ W
jnp.allclose(ana_Jz, auto_Jz, rtol=1e-4)

DeviceArray(True, dtype=bool)

$\frac{\partial f}{\partial A} = -(z^T \otimes J)P_A$

In [8]:
# Closed form (z^T kron J) P_A: kron of a length-k vector with the k x k
# matrix J has k rows, and P_A is (k*k) x (k*k), so the product is k x (k*k).
ana_JA = jnp.kron(rand_z.T, Jnl) @ P_A
# Unflatten to (output index, row of A, column of A). The leading axis is the
# output dimension k -- not num_samples, which only worked because both
# happen to equal 10.
ana_JA = ana_JA.reshape(k, k, k, order='C')
# The closed-form Jacobian carries a minus sign, hence the negated autodiff value.
jnp.allclose(-auto_Jlin.p_A, ana_JA)

DeviceArray(True, dtype=bool)

$\frac{\partial f}{\partial B} = (z^T \otimes J)P_B$

In [9]:
# Closed form (z^T kron J) P_B, a k x (k*k) matrix.
ana_JB = jnp.kron(rand_z.T, Jnl) @ P_B
# Unflatten to (output index, row of B, column of B). The leading axis is the
# output dimension k -- not num_samples, which only worked because both
# happen to equal 10.
ana_JB = ana_JB.reshape(k, k, k, order='C')
jnp.allclose(auto_Jlin.p_B, ana_JB)

DeviceArray(True, dtype=bool)

$\frac{\partial f}{\partial U} = (x^T \otimes J)$

In [10]:
# Closed form (x^T kron J), unflattened in C order to shape (k, d, k)
# so it can be compared with the autodiff Jacobian with respect to p_U.
ana_JU = jnp.kron(rand_x.T, Jnl).reshape(k, d, k, order='C')
jnp.allclose(auto_Jlin.p_U, ana_JU)

DeviceArray(True, dtype=bool)

## $\mathbf{z}^*$

In [11]:
# Solver budget and tolerance passed to deq_forward / deq_backward.
max_iter, tol = 40, 1e-5

# Zero initial state and the single test input.
Z0 = jnp.zeros_like(rand_z).T
X = rand_x.T

# Incoming cotangent: the first standard basis vector, i.e. the gradient of
# the first coordinate of z*.
g = jnp.zeros_like(rand_z).at[0].set(1.)

In [12]:
# Fix the leading positional argument of the Peaceman-Rachford update builder
# at 1.0 (presumably the splitting step size -- confirm in mondeq.splittings).
dyn = partial(build_peacemanrachford_update, 1.0)

In [13]:
# Forward pass: fixed point z* plus the residuals needed by the backward pass.
forward_out = deq_forward(dyn, max_iter, tol, nonlin_module, lin_module, Z0, X)
z_star, res = forward_out

# Backward pass with cotangent g: gradients with respect to the linear module,
# an unused middle entry, and the gradient with respect to the input.
backward_out = deq_backward(dyn, max_iter, tol, nonlin_module, res, g)
d_lin, _, dx = backward_out

In [14]:
# Diagonal derivative matrix of the nonlinearity, evaluated at z*.
Jnl_star = jnp.diag(nonlin_module.derivative(z_star))

# u^T = g^T (I - J* W)^{-1}: the cotangent pushed through the implicit
# fixed-point equation.
uT = g.T @ jnp.linalg.inv(jnp.eye(k) - Jnl_star@W)

# u^T J* appears in every parameter gradient; compute it once.
uT_J = uT.T @ Jnl_star
z_kron_uJ = jnp.kron(z_star, uT_J)

# Closed-form gradients, reshaped to the shapes used by the parameter checks.
d_A = ((-z_kron_uJ) @ P_A).reshape(k, k)
d_B = (z_kron_uJ @ P_B).reshape(k, k)
d_U = jnp.kron(rand_x, uT_J).reshape(d, k)

In [15]:
# Compare each implicit-differentiation gradient with its closed form
# (absolute tolerance 1e-4); prints one boolean per parameter, in the
# order p_A, p_B, p_U.
for param_name, closed_form in (("p_A", d_A), ("p_B", d_B), ("p_U", d_U)):
    print(np.allclose(getattr(d_lin, param_name), closed_form, atol=1e-4))

True
True
True


# Tangent kernel

In [16]:
from mondeq.modules import fcmon
from scipy.linalg import khatri_rao

In [17]:
# Batched inputs for the tangent-kernel check, drawn from a *seeded* generator
# so the kernels (and the printed rows further down) are reproducible across
# kernel restarts; the previous calls used NumPy's unseeded global RNG.
# Note: X and Z0 rebind the single-sample names used in the section above.
rng_kernel = np.random.default_rng(1)
X = rng_kernel.normal(size=(num_samples, d))   # one input per row
Z0 = jnp.zeros(shape=(num_samples, k))         # zero initial state per sample
u = rng_kernel.normal(size=(k,))               # length-k vector passed to fcmon / _fp_bwd

In [18]:
def vprod(M):
    """Gram matrix of a stack of matrices under the Frobenius inner product.

    For ``M`` with three axes ``(n, h, k)``, returns the ``(n, n)`` array whose
    ``(i, j)`` entry is ``sum_{h,k} M[i, h, k] * M[j, h, k]``.
    """
    # Contract the two trailing axes of M against themselves.
    trailing_axes = (1, 2)
    return jnp.tensordot(M, M, axes=(trailing_axes, trailing_axes))

Autodiff result

In [19]:
def _model_as_function_of_u(readout):
    """fcmon output on the batch X, viewed as a function of the vector u."""
    return fcmon(dyn, max_iter, tol, readout, nonlin_module, lin_module, Z0, X)


def _model_as_function_of_params(params):
    """fcmon output on the batch X, viewed as a function of the linear module."""
    return fcmon(dyn, max_iter, tol, u, nonlin_module, params, Z0, X)


# Reverse-mode Jacobians of the model output with respect to u and to the
# linear module's parameters.
Ju = jax.jacrev(_model_as_function_of_u)(u)
Jlin = jax.jacrev(_model_as_function_of_params, allow_int=True)(lin_module)

# Tangent-kernel pieces: Gram matrices of the per-sample Jacobians.
auto_H1 = vprod(Jlin.p_A) + vprod(Jlin.p_B)   # contribution of p_A and p_B
auto_H2 = vprod(Jlin.p_U)                     # contribution of p_U
auto_H3 = Ju @ Ju.T                           # contribution of u

Closed form

In [20]:
# Fixed points for the whole batch, plus the residuals for the backward solve.
Z_star, res = deq_forward(dyn, max_iter, tol, nonlin_module, lin_module, Z0, X)

# Backward fixed-point solve seeded with u, then scaled elementwise by the
# nonlinearity's derivative at Z*.
V = _fp_bwd(dyn, max_iter, tol, nonlin_module, res, u)
S = nonlin_module.derivative(Z_star)
Q = V * S

# Row i of Z_Q is kron(Z_star[i], Q[i]); khatri_rao is the column-wise
# Kronecker product, hence the transposes going in and coming out.
Z_Q = khatri_rao(Z_star.T, Q.T).T

# Closed-form kernels matching auto_H1, auto_H2 and auto_H3 respectively.
param_gram = P_A@P_A.T + P_B@P_B.T
ana_H1 = Z_Q @ param_gram @ Z_Q.T
ana_H2 = (X@X.T) * (Q@Q.T)
ana_H3 = Z_star @ Z_star.T

In [21]:
# Spot check: second row of the closed-form kernel, to compare by eye with
# the autodiff row displayed in the next cell.
ana_H1[1]

matrix([[7.08820103e+00, 2.20743472e+02, 1.14907742e+01, 1.80920019e+01,
         8.45144949e+01, 6.20128403e+01, 2.06550030e+01, 2.73771034e+00,
         1.73839350e-01, 4.08653358e+01]])

In [22]:
# Spot check: second row of the autodiff kernel, to compare by eye with
# the closed-form row displayed in the previous cell.
auto_H1[1]

DeviceArray([7.0882053e+00, 2.2074435e+02, 1.1490833e+01, 1.8092319e+01,
             8.4515778e+01, 6.2012928e+01, 2.0655462e+01, 2.7377620e+00,
             1.7383835e-01, 4.0866005e+01], dtype=float32)

In [24]:
# Compare each autodiff kernel with its closed form; prints one boolean per
# kernel, in the order H1, H2, H3, each with its own relative tolerance.
kernel_checks = (
    (auto_H1, ana_H1, 1e-3),
    (auto_H2, ana_H2, 1e-3),
    (auto_H3, ana_H3, 1e-4),
)
for auto_kernel, ana_kernel, rel_tol in kernel_checks:
    print(jnp.allclose(auto_kernel, ana_kernel, rtol=rel_tol))

True
True
True
