In [1]:
# Reload edited project modules (e.g. mondeq, utils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import jax
from jax import random, numpy as jnp
from mondeq.modules import MonLinear, Relu, deq
from mondeq.splittings import build_peacemanrachford_update, solve
from functools import partial

In [3]:
# Problem sizes: batch size, input dimension (x is (num_samples, d)),
# and state dimension (z0 is (num_samples, k)).
num_samples, d, k = 1, 3, 10

In [4]:
# Fixed seed so every JAX draw below is reproducible.
key = random.PRNGKey(seed=0)



In [5]:
# Linear DEQ layer mapping d-dimensional inputs to a k-dimensional state, followed by a ReLU.
# Sizes come from the config cell instead of hard-coded literals, so changing d/k propagates.
key, skey = random.split(key)
lin_module = MonLinear(d, k, key=skey)
nonlin_module = Relu()

In [6]:
# Split off fresh subkeys; only skey[0] is consumed (to draw the input batch x).
key, *skey = random.split(key, 3)
x = random.normal(skey[0], (num_samples, d))
# Initial state of shape (num_samples, k), passed to the solver below.
z0 = jnp.zeros((num_samples, k))

In [7]:
# Print output shapes of: the layer, calW_inv's operator, its application via inverse(), and the ReLU.
# NOTE(review): the meaning of the calW_inv(2., -1.) arguments is not visible here — confirm.
Winv_trans = lin_module.calW_inv(2., -1.)
for out in (lin_module(z0, x), Winv_trans, lin_module.inverse(Winv_trans, z0), nonlin_module(z0)):
    print(out.shape)


(1, 10)
(10, 10)
(1, 10)
(1, 10)


In [8]:
# Build the Peaceman-Rachford splitting (first argument 1. — presumably the step size alpha; confirm)
# and initialise it from z0 with a zero second state variable.
u0 = jnp.zeros_like(z0)
dyn_init, dyn_update = build_peacemanrachford_update(1., lin_module, nonlin_module)
dyn_state = dyn_init(z0, u0, x)

In [9]:
# Advance the splitting by one step and report progress.
# NOTE: not idempotent — re-running this cell keeps advancing dyn_state.
dyn_state = dyn_update(dyn_state)
for quantity in (dyn_state.n_step, dyn_state.objective, jnp.linalg.norm(dyn_state.z)):
    print(quantity)

1
0.9999991
1.0846218


In [10]:
# Iterate the splitting (at most 40 steps, tolerance 1e-5) and show the state stored in min_step.
solver_state = solve(dyn_state, dyn_update, max_iter=40, tol=1e-5)
solver_state.min_step.z

DeviceArray([[0.37759763, 0.29325515, 0.19382189, 0.        , 0.        ,
              0.63357264, 0.        , 0.        , 1.1282516 , 0.54578096]],            dtype=float32)

In [11]:
dyn = partial(build_peacemanrachford_update, 1.)


def deq_fn(lin_module, x):
    """Equilibrium state for input ``x`` under ``lin_module``.

    Uses the global ``dyn`` factory and ``nonlin_module`` (looked up at call time),
    with at most 40 solver iterations and tolerance 1e-5.
    """
    return deq(dyn, 40, 1e-5, nonlin_module, lin_module, x)


deq_fn(lin_module, x)

DeviceArray([[0.37759763, 0.29325515, 0.19382189, 0.        , 0.        ,
              0.63357264, 0.        , 0.        , 1.1282516 , 0.54578096]],            dtype=float32)

In [12]:
# Pull a unit cotangent on z*[0, 0] back through the DEQ to get layer and input gradients.
z_star, vjp_fun = jax.vjp(deq_fn, lin_module, x)
e0 = jnp.zeros_like(z_star).at[0, 0].set(1.0)
d_lin_mdl, dx = vjp_fun(e0)
for item in (z_star, d_lin_mdl.p_A.shape):
    print(item)

[[0.37759763 0.29325515 0.19382189 0.         0.         0.63357264
  0.         0.         1.1282516  0.54578096]]
(10, 10)


## Closed form

In [13]:
from utils import commutation_matrix_sp, vec
import numpy as np

In [14]:
# Linear operators used below to write the Jacobians wrt p_A and p_B in closed form.
# NOTE(review): K is presumably the (k*k, k*k) commutation matrix (vec(M.T) = K vec(M)),
# possibly sparse given the `_sp` suffix — confirm in utils.
eye_kk = np.eye(k * k)
K = commutation_matrix_sp(k, k)
P_B = eye_kk - K
P_A = (eye_kk + K) @ np.kron(np.eye(k), lin_module.p_A)
W = lin_module.calW_trans()

In [15]:
# Random probe point (z, x) at which analytic Jacobians are compared with jax.jacrev.
# Seeded so the notebook reproduces under Restart & Run All.
np.random.seed(0)
rand_z = np.random.normal(size=(k,))
rand_x = np.random.normal(size=(d,))
# `.T` on these 1-D vectors is a no-op, so they are passed directly.
auto_Jz = jax.jacrev(lambda z: nonlin_module(lin_module(z, rand_x)))(rand_z)
auto_Jlin = jax.jacrev(lambda mdl: nonlin_module(mdl(rand_z, rand_x)), allow_int=True)(lin_module)

In [16]:
# Analytic Jacobian of relu(lin(z, x)) wrt z: diag(relu'(lin(z, x))) @ W^T, checked against autodiff.
a = lin_module(rand_z, rand_x)  # `.T` on the 1-D probe vectors is a no-op
Jnl = jnp.diag(nonlin_module.derivative(a))
ana_Jz = Jnl @ W.T
jnp.allclose(ana_Jz, auto_Jz, rtol=1e-4)

DeviceArray(True, dtype=bool)

In [17]:
# Analytic Jacobian wrt p_A, reshaped to (output, k, k) to match jax.jacrev's layout.
# Uses k instead of a hard-coded 10 so it follows the config cell.
# The comparison negates the autodiff result — presumably MonLinear's sign convention for p_A.
ana_JA = (jnp.kron(rand_z.T, Jnl) @ P_A).reshape(k, k, k, order='C')
jnp.allclose(-auto_Jlin.p_A, ana_JA)

DeviceArray(True, dtype=bool)

In [18]:
# Analytic Jacobian wrt p_B, reshaped to (output, k, k) to match jax.jacrev's layout.
# Uses k instead of a hard-coded 10 so it follows the config cell.
ana_JB = (jnp.kron(rand_z.T, Jnl) @ P_B).reshape(k, k, k, order='C')
jnp.allclose(auto_Jlin.p_B, ana_JB)

DeviceArray(True, dtype=bool)

In [19]:
# Analytic Jacobian wrt p_U: kron(x, diag(relu')) reshaped to (output, d, k).
ana_JU = jnp.kron(rand_x.T, Jnl).reshape(k, d, k, order='C')
jnp.allclose(auto_Jlin.p_U, ana_JU)

DeviceArray(True, dtype=bool)

In [115]:
# First row of the p_A gradient obtained by pulling e0 back through the DEQ with jax.vjp.
d_lin_mdl.p_A[0]

DeviceArray([ 0.09708402, -0.15101041, -0.06417657, -0.01218294,
             -0.6068703 , -0.15390113, -0.08026036, -0.9055623 ,
              0.434208  ,  0.6442721 ], dtype=float32)

In [86]:
def deq_forward(dyn, max_iter, tol, nonlin_mdl, lin_mdl, x):
    """Forward pass of the DEQ: solve for the equilibrium state of ``x``.

    Args:
        dyn: factory called as ``dyn(lin_mdl, nonlin_mdl)`` returning ``(init, update)``.
        max_iter: maximum number of solver iterations, passed to ``solve``.
        tol: solver tolerance, passed to ``solve``.
        nonlin_mdl: the nonlinearity of the DEQ.
        lin_mdl: the linear layer; ``lin_mdl.out_size`` sets the state width.
        x: input batch; its leading axis is the batch axis.

    Returns:
        ``(z_star, residuals)`` where ``residuals = (min_step_state, lin_mdl, x)``.
    """
    init_fn, update_fn = dyn(lin_mdl, nonlin_mdl)
    state_shape = (x.shape[0], lin_mdl.out_size)
    start = init_fn(jnp.zeros(state_shape), jnp.zeros(state_shape), x)
    # stop_gradient blocks differentiation through the solver iterations.
    best = jax.lax.stop_gradient(solve(start, update_fn, max_iter, tol)).min_step
    return best.z, (best, lin_mdl, x)

def deq_backward(dyn, max_iter, tol, nonlin_mdl, res, g):
    """Backward pass of the DEQ: pull the cotangent ``g`` back to ``(lin_mdl, x)``.

    Builds a second splitting problem from the same ``dyn`` factory, with:
    - the nonlinearity replaced by an elementwise map built from
      ``nonlin_mdl.derivative(z*)``;
    - ``Winv`` transposed;
    - the bias zeroed.
    It then pulls the solution back through ``nonlin_mdl(lin_mdl(z*, x))`` with
    ``jax.vjp``.
    NOTE(review): this appears to follow the backward pass of Winston & Kolter,
    "Monotone operator equilibrium networks" (2020) — confirm.

    Args:
        dyn: splitting factory, as in ``deq_forward``.
        max_iter: maximum number of solver iterations.
        tol: solver tolerance.
        nonlin_mdl: the forward nonlinearity; must provide ``derivative``.
        res: residuals ``(min_step_state, lin_mdl, x)`` returned by ``deq_forward``.
        g: cotangent of the downstream loss wrt z*.

    Returns:
        The ``jax.vjp`` pull-back ``(d_lin_mdl, d_x)``.
    """
    dyn_state, lin_mdl, x = res[0], res[1], res[2]
    z_star = dyn_state.z
    alpha = dyn_state.alpha

    j = nonlin_mdl.derivative(z_star)
    inactive = j == 0
    # (1 - j) / j divides by zero where j == 0. Guard it so no inf/nan is produced.
    # Those entries are replaced by v below anyway, so active entries are unchanged.
    ratio = jnp.where(inactive, 0.0, (1 - j) / jnp.where(inactive, 1.0, j))
    v = j * g

    def alter_nonlin_mdl_bwd(u):
        zn = (u + alpha * (1 + ratio) * v) / (1 + alpha * ratio)
        return jnp.where(inactive, v, zn)

    dyn_init, dyn_update = dyn(lin_mdl, alter_nonlin_mdl_bwd)
    bwd_dyn_state = dyn_init(jnp.zeros_like(z_star), jnp.zeros_like(z_star), x)
    # Backward problem: transposed inverse operator and no input-dependent bias.
    bwd_dyn_state = bwd_dyn_state._replace(
        Winv=bwd_dyn_state.Winv.T,
        bias=jnp.zeros_like(bwd_dyn_state.bias),
    )
    solver_state = solve(bwd_dyn_state, dyn_update, max_iter, tol)

    # Pull the backward solution through nonlin_mdl(lin_mdl(z*, x)) wrt the layer and the input.
    _, vjp_lin_fn = jax.vjp(lambda mdl, inp: nonlin_mdl(mdl(z_star, inp)), lin_mdl, x)
    return vjp_lin_fn(solver_state.min_step.z)

In [79]:
# Forward solve for the single probe input rand_x (as a batch of one).
z_star, res = deq_forward(dyn, 40, 1e-5, nonlin_module, lin_module, rand_x[None, :])

In [80]:
# Objective stored on the min_step state of the forward solve (small value suggests convergence).
res[0].objective

DeviceArray(9.136726e-06, dtype=float32, weak_type=True)

In [82]:
# Closed-form cotangent u^T = g^T (I - J W^T)^{-1}, with J = diag(nonlin'(z*)).
# Solve the transposed system (I - J W^T)^T u = g instead of forming an explicit inverse:
# cheaper and numerically better conditioned.
g = e0[0]
Jnl_star = jnp.diag(nonlin_module.derivative(z_star)[0])
uT = jnp.linalg.solve((jnp.eye(k) - Jnl_star @ W.T).T, g)

In [88]:
# Notebook backward pass for cotangent g; its outputs are compared with the closed form further down.
d_lin, dx = deq_backward(dyn, 40, 1e-5, nonlin_module, res, g)

In [85]:
# Closed-form cotangent u^T = g^T (I - J W^T)^{-1}.
uT

DeviceArray([ 1.4669073 , -0.5217698 ,  1.334432  ,  1.3785081 ,
              1.2314013 ,  1.2478876 ,  0.07080948,  0.3268739 ,
              0.77704525,  0.85552406], dtype=float32)

In [112]:
# Closed-form parameter gradients; all three share the row vector u^T J.
uJ = uT @ Jnl_star
d_A = (-jnp.kron(z_star, uJ) @ P_A).reshape(k, k)
d_B = (jnp.kron(z_star, uJ) @ P_B).reshape(k, k)
d_U = jnp.kron(rand_x, uJ).reshape(d, k)

In [114]:
# Compare the notebook backward pass with the closed-form gradients, parameter by parameter.
for got, expected in ((d_lin.p_A, d_A), (d_lin.p_B, d_B), (d_lin.p_U, d_U)):
    print(np.allclose(got, expected, atol=1e-4))

True
True
True


## Tangent kernel

In [115]:
dyn = partial(build_peacemanrachford_update, 1.)


def f(u, nonlin_mdl, lin_mdl, X):
    """Linear readout of the DEQ equilibrium: ``z*(X) @ u.T``.

    For a 1-D ``u`` of length k this gives one scalar per row of ``X``.
    Uses the global ``dyn`` factory, 40 solver iterations and tolerance 1e-5.
    """
    equilibrium = deq(dyn, 40, 1e-5, nonlin_mdl, lin_mdl, X)
    return equilibrium @ u.T

In [118]:
# Sizes for the tangent-kernel experiment (d and k must match lin_module built earlier).
num_samples, d, k = 20, 3, 10

In [119]:
# Inputs X (num_samples, d) and readout vector u (k,) for the tangent-kernel experiment.
# Seeded so the comparison reproduces under Restart & Run All.
np.random.seed(0)
X = np.random.uniform(size=(num_samples, d))
u = np.random.uniform(size=(k,))

In [122]:
# f returns one scalar per row of X.
f(u, nonlin_module, lin_module, X).shape

(20,)

In [126]:
# Per-sample readouts z*(X_i) . u.
f(u, nonlin_module, lin_module, X)

DeviceArray([0.9323009 , 0.94532776, 0.66870034, 0.72705317, 1.1534706 ,
             0.7390838 , 0.94819844, 0.45236948, 0.8993375 , 0.5273541 ,
             0.5459217 , 0.2698567 , 0.4707398 , 0.99770194, 0.26313868,
             0.6860222 , 0.58688456, 0.6808951 , 0.8512972 , 0.45734864],            dtype=float32)

In [132]:
# Jacobians of f wrt the readout u and the layer parameters,
# each flattened to (num_samples, n_params) for the Gram matrices below.
Ju = jax.jacrev(lambda u_: f(u_, nonlin_module, lin_module, X))(u)
Jlin = jax.jacrev(lambda mdl: f(u, nonlin_module, mdl, X), allow_int=True)(lin_module)
JvA, JvB, JvU = (
    Jlin.p_A.reshape(num_samples, k * k),
    Jlin.p_B.reshape(num_samples, k * k),
    Jlin.p_U.reshape(num_samples, d * k),
)

In [138]:
# Empirical tangent kernel: sum of the per-parameter-group Gram matrices J J^T.
H = sum(J @ J.T for J in (Ju, JvA, JvB, JvU))

In [146]:
from scipy.linalg import khatri_rao

In [143]:
def deq_backward(dyn, max_iter, tol, nonlin_mdl, res, g):
    """Backward fixed-point solve returning ``g + lin_mdl.W_trans(u*)``.

    This redefinition shadows the earlier ``deq_backward``. It returns the value of
    ``g + lin_mdl.W_trans(...)`` instead of a ``jax.vjp`` pull-back tuple.
    The backward problem is built the same way:
    - the nonlinearity is replaced by an elementwise map from
      ``nonlin_mdl.derivative(z*)``;
    - ``Winv`` is transposed;
    - the bias is zeroed.
    NOTE(review): the saved outputs show the closed-form kernel built from this
    return value (via Q) disagreeing with the autodiff kernel. The earlier version
    feeds the backward solution itself to the vjp — confirm which quantity Q needs.

    Args:
        dyn: splitting factory, as in ``deq_forward``.
        max_iter: maximum number of solver iterations.
        tol: solver tolerance.
        nonlin_mdl: the forward nonlinearity; must provide ``derivative``.
        res: residuals ``(min_step_state, lin_mdl, x)`` returned by ``deq_forward``.
        g: cotangent wrt z*; broadcast against z* (e.g. a length-k vector for a batch).

    Returns:
        ``g + lin_mdl.W_trans(u*)``, where u* is the backward solver's min_step state.
    """
    dyn_state, lin_mdl, x = res[0], res[1], res[2]
    z_star = dyn_state.z
    alpha = dyn_state.alpha

    j = nonlin_mdl.derivative(z_star)
    inactive = j == 0
    # (1 - j) / j divides by zero where j == 0. Guard it so no inf/nan is produced.
    # Those entries are replaced by v below anyway, so active entries are unchanged.
    ratio = jnp.where(inactive, 0.0, (1 - j) / jnp.where(inactive, 1.0, j))
    v = j * g

    def alter_nonlin_mdl_bwd(u):
        zn = (u + alpha * (1 + ratio) * v) / (1 + alpha * ratio)
        return jnp.where(inactive, v, zn)

    dyn_init, dyn_update = dyn(lin_mdl, alter_nonlin_mdl_bwd)
    bwd_dyn_state = dyn_init(jnp.zeros_like(z_star), jnp.zeros_like(z_star), x)
    # Backward problem: transposed inverse operator and no input-dependent bias.
    bwd_dyn_state = bwd_dyn_state._replace(
        Winv=bwd_dyn_state.Winv.T,
        bias=jnp.zeros_like(bwd_dyn_state.bias),
    )
    solver_state = solve(bwd_dyn_state, dyn_update, max_iter, tol)

    return g + lin_mdl.W_trans(solver_state.min_step.z)

In [159]:
# Forward solve for the whole batch X; res is consumed by deq_backward in the next cell.
Z_star, res = deq_forward(dyn, 40, 1e-5, nonlin_module, lin_module, X)

In [160]:
# Per-sample factors for the closed-form kernel:
#   S = nonlin'(Z*),  Q = S * deq_backward(..., cotangent u).
# Row i of Z_Q is kron(Z_star[i], Q[i]) (column-wise Khatri-Rao of the transposes).
S = nonlin_module.derivative(Z_star)
Q = S * deq_backward(dyn, 40, 1e-5, nonlin_module, res, u)
Z_Q = khatri_rao(Z_star.T, Q.T).T

In [161]:
# Expect (num_samples, k * k): one kron(z_i, q_i) row per sample.
Z_Q.shape

(20, 100)

In [167]:
# Closed-form tangent kernel, term by term:
#   H1 ~ JvA JvA^T + JvB JvB^T (via Z_Q and the P_A / P_B operators),
#   H2 ~ JvU JvU^T,   H3 ~ Ju Ju^T (f = Z* u, so Ju = Z*).
# NOTE(review): the saved outputs show ana_H far from H (ana_H[0, 0] ~ 10.4 vs H[0, 0] ~ 214)
# and H2 far from JvU @ JvU.T, so Q likely needs revisiting.
P_gram = P_B @ P_B.T + P_A @ P_A.T
H1 = Z_Q @ P_gram @ Z_Q.T
H2 = (X @ X.T) * (Q @ Q.T)
H3 = Z_star @ Z_star.T
ana_H = H1 + H2 + H3

In [173]:
# Autodiff Gram matrix of the p_U Jacobian (the quantity H2 approximates in closed form).
JvU@JvU.T

DeviceArray([[69.11857 , 61.039223, 61.367313, 66.95027 , 71.91965 ,
              64.20949 , 66.30809 , 60.77679 , 65.97245 , 65.30183 ,
              59.71698 , 62.92219 , 62.581367, 66.26043 , 61.23346 ,
              66.65157 , 67.031555, 59.565083, 61.316883, 60.502037],
             [61.039223, 56.876633, 56.342506, 59.604107, 63.563034,
              58.10205 , 59.92113 , 55.715977, 59.340637, 57.984154,
              55.185734, 56.71503 , 56.68129 , 59.73848 , 55.838337,
              58.767673, 59.54737 , 55.279743, 56.90714 , 55.76693 ],
             [61.367313, 56.342506, 56.450863, 59.85481 , 63.563385,
              58.510517, 59.72461 , 55.946297, 59.466473, 58.24316 ,
              54.748837, 56.793995, 56.59811 , 59.75547 , 56.013367,
              59.104687, 59.53029 , 54.699684, 56.444675, 55.57779 ],
             [66.95027 , 59.604107, 59.85481 , 65.01376 , 69.54812 ,
              62.625984, 64.37154 , 59.399136, 64.08898 , 63.325333,
              58.31639 , 61.272

In [171]:
# Closed-form counterpart of JvU @ JvU.T.
H2

DeviceArray([[4.0062456 , 1.0844413 , 1.3001921 , 3.1380396 , 4.982688  ,
              1.5227088 , 3.15249   , 0.6211865 , 3.1686196 , 2.3417203 ,
              0.6062444 , 1.3066876 , 1.6496931 , 3.3656104 , 0.5075688 ,
              2.8543026 , 3.0134745 , 0.63231766, 1.2250704 , 0.61497104],
             [1.0844413 , 2.0795681 , 1.4331061 , 0.94945747, 1.7836207 ,
              0.57296604, 1.9232032 , 0.7180931 , 1.6944771 , 0.1816796 ,
              1.2327049 , 0.25719094, 0.90729797, 2.0013313 , 0.27014473,
              0.12801675, 0.6868834 , 1.5046914 , 1.9730439 , 1.0375789 ],
             [1.3001921 , 1.4331061 , 1.4291403 , 1.0878239 , 1.6716452 ,
              0.86909664, 1.6143494 , 0.83607763, 1.7079837 , 0.3283546 ,
              0.6834861 , 0.223831  , 0.7117868 , 1.9059931 , 0.33283687,
              0.35269225, 0.55747366, 0.81230235, 1.3982557 , 0.7361063 ],
             [3.1380396 , 0.94945747, 1.0878239 , 2.5016487 , 3.911267  ,
              1.2393447 , 2.5160673

In [168]:
# Closed-form tangent kernel H1 + H2 + H3.
ana_H

DeviceArray([[10.375988  ,  2.9927504 ,  3.495832  ,  8.104519  ,
              12.009341  ,  4.7157164 ,  7.4550004 ,  1.9976528 ,
               8.172879  ,  6.346464  ,  1.5200535 ,  3.0265582 ,
               3.6540005 ,  8.495546  ,  1.5779381 ,  8.226741  ,
               6.5376105 ,  1.7173073 ,  2.9699621 ,  1.5633265 ],
             [ 2.9927504 ,  6.1198854 ,  3.3045912 ,  2.2963073 ,
               4.69819   ,  1.7088364 ,  4.8904924 ,  1.7216413 ,
               3.5344203 ,  1.2010007 ,  3.8572502 ,  0.925689  ,
               2.7129712 ,  4.147904  ,  0.70102894,  1.4274702 ,
               2.2037253 ,  4.7118397 ,  5.807604  ,  3.0358844 ],
             [ 3.495832  ,  3.3045912 ,  3.8254926 ,  2.9228191 ,
               4.2168164 ,  2.6939468 ,  3.723384  ,  2.4803567 ,
               4.433433  ,  1.0409086 ,  1.4700868 ,  0.55998504,
               1.6004056 ,  4.739892  ,  1.019198  ,  1.2672175 ,
               1.2890918 ,  1.7583156 ,  3.081447  ,  1.6619382 ],
       

In [156]:
# First row of the autodiff tangent kernel H, for comparison with ana_H.
H[0]

DeviceArray([214.31097, 193.99396, 191.27907, 206.83755, 220.67776,
             199.69795, 206.90233, 188.57025, 205.31558, 204.30191,
             188.8966 , 193.83742, 194.5462 , 206.71457, 189.1728 ,
             210.99185, 204.8679 , 189.8243 , 193.64226, 189.02676],            dtype=float32)