In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
from jax import random, numpy as jnp
from mondeq.modules import MonLinear, Relu, deq
from mondeq.splittings import build_peacemanrachford_update, solve
from functools import partial

In [3]:
num_samples = 1
d = 3
k = 10

In [4]:
key = random.PRNGKey(0)



In [5]:
key, skey = random.split(key)
lin_module = MonLinear(3, 10, key=skey)
nonlin_module = Relu()

In [6]:
key, *skey = random.split(key, 3)
x = random.normal(skey[0], (num_samples, d))
z0 = jnp.zeros((num_samples, k))

In [7]:
print(lin_module(z0, x).shape)
Winv_trans = lin_module.calW_inv(2., -1.)
print(Winv_trans.shape)
print(lin_module.inverse(Winv_trans, z0).shape)
print(nonlin_module(z0).shape)


(1, 10)
(10, 10)
(1, 10)
(1, 10)


In [8]:
dyn_init, dyn_update = build_peacemanrachford_update(1., lin_module, nonlin_module)
dyn_state = dyn_init(z0, jnp.zeros_like(z0), x)

In [9]:
dyn_state = dyn_update(dyn_state)
print(dyn_state.n_step)
print(dyn_state.objective)
print(jnp.linalg.norm(dyn_state.z))

1
0.9999991
1.0846218


In [10]:
solver_state = solve(dyn_state, dyn_update, max_iter=40, tol=1e-5)
solver_state.min_step.z

DeviceArray([[0.37759763, 0.29325515, 0.19382189, 0.        , 0.        ,
              0.63357264, 0.        , 0.        , 1.1282516 , 0.54578096]],            dtype=float32)

In [11]:
dyn = partial(build_peacemanrachford_update, 1.)
deq_fn = lambda lin_module, x: deq(dyn, 40, 1e-5, nonlin_module, lin_module, x)
deq_fn(lin_module, x)

DeviceArray([[0.37759763, 0.29325515, 0.19382189, 0.        , 0.        ,
              0.63357264, 0.        , 0.        , 1.1282516 , 0.54578096]],            dtype=float32)

In [13]:
z_star, vjp_fun = jax.vjp(deq_fn, lin_module, x)
e0 = jnp.zeros_like(z_star).at[0,0].set(1.)
d_lin_mdl, dx = vjp_fun(e0)
print(z_star)
print(d_lin_mdl.p_A.shape)

[[0.37759763 0.29325515 0.19382189 0.         0.         0.63357264
  0.         0.         1.1282516  0.54578096]]
(10, 10)


In [14]:
from utils import commutation_matrix_sp, vec
import numpy as np

In [97]:
K = commutation_matrix_sp(k, k)
P_A = (np.eye(k*k) + K) @ np.kron(np.eye(k), lin_module.p_A)
P_B = (np.eye(k*k) - K)
W = lin_module.calW_trans()

In [89]:
auto_Jz = jax.jacrev(lambda z: nonlin_module(lin_module(z.T, x[0].T)))(z0[0])
a = lin_module(z0[0].T, x[0].T)
Jnl = jnp.diag(nonlin_module.derivative(a))
ana_Jz = Jnl@W.T
jnp.allclose(ana_Jz, auto_Jz, rtol=1e-4)

DeviceArray(True, dtype=bool)

In [85]:
auto_Jlin = jax.jacrev(lambda lin_module: nonlin_module(lin_module(z0[0].T, x[0].T)), allow_int=True)(lin_module)

In [111]:
ana_JA = jnp.kron(z0[0].T, Jnl)@P_A
ana_JA = ana_JA.reshape(10, 10, 10)
jnp.allclose(auto_Jlin.p_A, ana_JA)

DeviceArray(True, dtype=bool)

In [110]:
ana_JB = jnp.kron(z0[0].T, Jnl)@P_B
ana_JB = ana_JB.reshape(10, 10, 10)
jnp.allclose(auto_Jlin.p_B, ana_JB)

DeviceArray(True, dtype=bool)

In [109]:
ana_JU = jnp.kron(x[0].T, Jnl)
ana_JU = ana_JU.reshape(k, d, k)
jnp.allclose(auto_Jlin.p_U, ana_JU)

DeviceArray(True, dtype=bool)

In [104]:
ana_JU.shape

(10, 30)

In [57]:
u = e0[0]

In [112]:
Jnl_star = jnp.diag(nonlin_module.derivative(a))
qT = u.T @ jnp.linalg.inv(jnp.eye(k) - Jnl@W) * jnp.diag(Jnl)

In [113]:
J_A = np.kron(z_star[0], qT).T @ P_A
J_A = J_A.reshape(10, 10, order='F')

In [116]:
J_A

matrix([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [115]:
d_lin_mdl.p_A[0]

DeviceArray([ 0.09708402, -0.15101041, -0.06417657, -0.01218294,
             -0.6068703 , -0.15390113, -0.08026036, -0.9055623 ,
              0.434208  ,  0.6442721 ], dtype=float32)