# setup

In [None]:
import numpy as np
import torch
from scipy.optimize import fsolve

from src.eqprop.eqprop_util import OTS, P3OTS

# prepare data

In [None]:
# Dummy example: random inputs/weights/biases for a dims[0]-...-dims[-1] network.
# A seeded generator makes the data reproducible, so solver results can be
# compared between runs (and against other solvers).
seed = 0
rng = np.random.default_rng(seed)
batch_size = 32
dims = [784, 100, 20]
hdims = dims[1:]  # sizes of the hidden layers whose node voltages are solved for
x = rng.random((batch_size, dims[0]))
# W[k] has shape (dims[k + 1], dims[k]): one row per node of the layer it feeds.
W = [rng.random((d_out, d_in)) for d_in, d_out in zip(dims[:-1], dims[1:])]
# Per-layer biases, concatenated into one vector aligned with the hidden nodes.
B = np.concatenate([rng.random(d) for d in hdims], axis=-1)
# Alternative: i_ext = rng.random((batch_size, dims[-1]))
i_ext = np.zeros((batch_size, dims[-1]))  # no external current in this example
ots = P3OTS(Is=1e-6, Vth=1, Vl=0, Vr=0)

In [None]:
# Load a model checkpoint onto the CPU.
# NOTE(review): ".ckpt" looks like a placeholder path — point it at a real checkpoint.
# weights_only=True restricts what the unpickler may load (see torch.load docs).
ckpt_path = ".ckpt"
ckpt = torch.load(
    ckpt_path,
    map_location="cpu",
    weights_only=True,
)

# fsolve

In [None]:
def _lap(W: tuple[np.ndarray, ...]):
    """Build (and cache) the weighted Laplacian over all hidden-layer nodes.

    ``W[0]`` connects the input to hidden layer 0, and ``W[k]`` (k >= 1) connects
    hidden layer ``k - 1`` to hidden layer ``k``. ``W[k]`` has one row per node of
    hidden layer ``k``. The result is a symmetric ``(n, n)`` matrix, with ``n`` the
    total number of hidden nodes. Its off-diagonal blocks between adjacent layers
    are ``-W[k]``. Its diagonal holds each node's total attached weight: next
    layer + previous layer + input weights from ``W[0]``.

    The matrix is cached on the function and reused while the same weight arrays
    (compared by identity) are passed in. Passing different arrays rebuilds it.
    In-place edits to the cached arrays are NOT detected.

    Returns:
        The Laplacian. The cached array itself is returned, so do not mutate it.
    """
    cache = getattr(_lap, "_cache", None)
    if cache is not None:
        cached_W, cached_L = cache
        if len(cached_W) == len(W) and all(a is b for a, b in zip(cached_W, W)):
            return cached_L

    # Layer sizes come from the weights themselves (W[k] has one row per node of
    # hidden layer k), so the matrix always matches the W actually passed in.
    hdims = [w.shape[0] for w in W]
    size = sum(hdims)

    # Strictly block-lower-triangular part: row block k+1 holds -W[k+1] placed in
    # column block k; row block 0 (fed only by the input) is zero.
    blocks = [np.zeros((hdims[0], size))]
    for k, g in enumerate(W[1:]):
        col_pad = (sum(hdims[:k]), sum(hdims[k + 1 :]))  # (before, after) on dim 1
        blocks.append(np.pad(-g, ((0, 0), col_pad)))
    lower = np.concatenate(blocks, axis=-2)

    L = lower + lower.T
    # Diagonal: weights to the next layer (column sums of -lower), to the previous
    # layer (row sums of -lower), plus each first-layer node's input weights.
    input_w = np.pad(W[0].sum(axis=-1), (0, size - hdims[0]))
    L += np.diag(-lower.sum(axis=-2) - lower.sum(axis=-1) + input_w)

    _lap._cache = (tuple(W), L)
    _lap.L = L  # kept for code that reads the cached matrix via _lap.L
    return L


def f(v: np.ndarray, x: np.ndarray, W: tuple[np.ndarray, ...], B: np.ndarray, i_ext: np.ndarray):
    """Residual of the nonlinear node equations, for ``scipy.optimize.fsolve``.

    Computes ``L @ v - b + ots.i(v)``. Here ``L = _lap(W)``, and ``b`` is ``B`` with
    the input drive ``x @ W[0].T`` added to the first hidden layer and ``i_ext``
    added to the last hidden layer.

    Args:
        v: Hidden-node voltages, one entry per row of ``_lap(W)``.
        x: One input sample; ``x @ W[0].T`` gives one value per first-layer node.
        W: Sequence of weight matrices; ``W[k]`` has one row per node of hidden layer ``k``.
        B: Concatenated biases, same length as ``v``. Not modified.
        i_ext: Current added to the last hidden layer, or ``None`` to skip it.

    Returns:
        Residual vector, same length as ``v``; zero at a solution.
    """
    # Derive layer sizes from W rather than a global, so f stays consistent with
    # whatever weights are passed in.
    n_first, n_last = W[0].shape[0], W[-1].shape[0]
    L = _lap(W)
    b = B.copy()  # never mutate the caller's bias vector
    b[:n_first] += x @ W[0].T
    if i_ext is not None:
        b[-n_last:] += i_ext
    # Nonlinear element currents: ots works on torch tensors, so round-trip via torch.
    return L @ v - b + ots.i(torch.from_numpy(v)).numpy()


def jac(v: np.ndarray, x: np.ndarray, W: tuple[np.ndarray, ...], *args):
    """Jacobian of ``f`` with respect to ``v``, for ``fsolve(..., fprime=jac)``.

    fsolve calls ``fprime(v, *args)`` with the same ``args`` tuple as the residual,
    here ``(x, W, B, i_ext)``. The weights are therefore the *second* extra
    argument, and ``x`` comes first even though the Jacobian does not use it.
    ``B`` and ``i_ext`` only shift the residual, so they are absorbed by ``*args``.

    Returns:
        ``_lap(W) + ots.a(v)``.
    """
    L = _lap(W)
    # NOTE(review): this assumes ots.a returns the full (n, n) derivative matrix of
    # ots.i. If it returns a per-node derivative vector instead, the broadcast below
    # adds it to every row and this must become ``L + np.diag(...)`` — TODO confirm.
    return L + ots.a(torch.from_numpy(v)).numpy()

In [None]:
# Initial guess shared by every sample: solve only the linear part L v = B
# (no input drive x @ W[0].T, no external current, no nonlinear ots term).
v0 = np.linalg.solve(a=_lap(W), b=B)

In [None]:
# Solve f(v) = 0 independently for each sample, starting from the shared guess v0.
# fsolve(full_output=True) returns (solution, infodict, ier, mesg). ier == 1 means
# a solution was found, so the flag is checked instead of being discarded.
sols = np.empty((batch_size, v0.size))
failed = []
for batch_idx in range(batch_size):
    sol, info, ier, msg = fsolve(
        f, v0, fprime=jac, args=(x[batch_idx], W, B, i_ext[batch_idx]), full_output=True
    )
    sols[batch_idx] = sol
    if ier != 1:
        failed.append((batch_idx, msg))
if failed:
    print(f"fsolve did not converge for {len(failed)}/{batch_size} samples; first: {failed[0]}")

# compare with Newton solver

In [None]:
# load model

# end