# setup

In [None]:
import numpy as np
import torch
from scipy.optimize import fsolve

from src.eqprop.eqprop_util import OTS, P3OTS

# prepare data

In [None]:
# network layout and batch size
batch_size = 4
n_in, n_hidden, n_out = 784 * 2, 128, 10 * 4
dims = [n_in, n_hidden, n_out]
hdims = [n_hidden, n_out]  # every layer except the input, i.e. dims[1:]

# activation object (its i/a methods are used by the solver cells) and the
# array added to the output-layer entries of the right-hand side; all zeros here
ots = OTS(Is=1e-6, Vth=0.2, Vl=0, Vr=0)
i_ext = np.zeros((batch_size, n_out))

In [None]:
# setup random weights and biases
# W[i] has shape (dims[i + 1], dims[i]), the (out, in) layout that the solver
# cells index with W[0].T. B stays a per-layer list (it is converted layer by
# layer to torch tensors later) and B_cat is the flat concatenation that the
# solver cells consume, so this cell defines the same names as the checkpoint
# cell.
W, B = [], []
for fan_in, fan_out in zip(dims[:-1], dims[1:]):
    W.append(np.random.rand(fan_in, fan_out).T)  # uniform random weights in [0, 1)
    B.append(np.random.rand(fan_out))

B_cat = np.concatenate(B, axis=-1)

In [None]:
ckpt_path = "../logs/train/runs/2023-11-19_17-59-54/checkpoints/last.ckpt"
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = ckpt["state_dict"]
# Split the parameters by name, keeping the state_dict order. A key containing
# "weight" is always treated as a weight, anything else containing "bias" as a bias.
W = [v.numpy() for k, v in state_dict.items() if "weight" in k]
B = [v.numpy() for k, v in state_dict.items() if "weight" not in k and "bias" in k]
# flat bias vector consumed by the solver cells
B_cat = np.concatenate(B, axis=-1)

In [None]:
# dummy example: uniform random input in [0, 1), one row per sample
# (the MNIST cell below rebinds x when it is run)
x = np.random.rand(batch_size, dims[0])

In [None]:
# load mnist data
from src.data.mnist_datamodule import MNISTDataModule

# build the datamodule and keep a single iterator over the training loader
dm = MNISTDataModule(batch_size=batch_size, data_dir="../data/")
dm.setup()
train_loader = dm.train_dataloader()
dl = iter(train_loader)

In [None]:
# draw one batch through the public iterator protocol rather than the
# iterator's private _next_data() hook
x, y = next(dl)
x = x.view(x.size(0), -1)  # flatten every sample to a single row
# duplicate each feature and negate the copy: columns (0, 1), (2, 3), ...
# hold (+p, -p), which doubles the feature count
x = x.repeat_interleave(2, dim=1)
x[:, 1::2] = -x[:, ::2]
x = x.numpy()
y = y.numpy()

# fsolve

In [None]:
def _lap(W: list[np.ndarray]) -> np.ndarray:
    """Assemble the dense, symmetric system matrix of the layered network.

    ``W[k]`` is expected in (out, in) layout, so ``W[k].shape[0]`` is the
    number of nodes in layer ``k``. The nodes of all layers are stacked into
    one vector. For ``k >= 1`` the negated ``W[k]`` is placed in the block
    coupling layer ``k`` with layer ``k - 1``, and the matrix is symmetrised.
    The diagonal holds, for every node, the sum of all weights attached to
    it, including the input weights ``W[0]`` for the first layer.

    The result is cached on the function object (``_lap.L``). The cache is
    reused only when the call receives the very same weight arrays that built
    it, so switching to other weights triggers a rebuild instead of silently
    returning a stale matrix. In-place edits of the cached arrays are not
    detected.

    Args:
        W: Per-layer weight matrices, input layer first.

    Returns:
        Square array with ``sum(w.shape[0] for w in W)`` rows.
    """
    cached_W = getattr(_lap, "_W", None)
    if (
        cached_W is not None
        and len(cached_W) == len(W)
        and all(old is new for old, new in zip(cached_W, W))
    ):
        return _lap.L

    # layer sizes come from the weights themselves, not from a global
    sizes = [w.shape[0] for w in W]
    size = sum(sizes)

    # the first layer has no preceding solved layer, so its block row is empty
    paddedG = [np.zeros((sizes[0], size))]
    for i, g in enumerate(W[1:]):
        padding = (
            (0, 0),
            (sum(sizes[:i]), sum(sizes[i + 1 :])),
        )  # dim 0 (before, after), dim 1 (before, after)
        paddedG.append(np.pad(-g, padding))
    lower = np.concatenate(paddedG, axis=-2)
    L = lower + lower.T
    D0 = -lower.sum(axis=-2) - lower.sum(axis=-1) + np.pad(W[0].sum(axis=-1), (0, size - sizes[0]))
    L += np.diag(D0)

    # keep references to the arrays so the identity check above stays valid
    _lap._W = tuple(W)
    _lap.L = L
    return L


def f(v: np.ndarray, x: np.ndarray, W: list[np.ndarray], B: np.ndarray, i_ext: np.ndarray):
    """Residual of the nonlinear system for one sample, for use with ``fsolve``.

    Computes ``L @ v - b + ots.i(v)`` where ``L = _lap(W)`` and ``b`` is ``B``
    with ``x @ W[0].T`` added to the first-layer entries and ``i_ext`` (when
    given) added to the last-layer entries.

    Args:
        v: Current node values, one entry per node of every layer in ``W``.
        x: Input sample, contracted with ``W[0]``.
        W: Per-layer weight matrices in (out, in) layout.
        B: Flat bias vector; it is copied, never modified.
        i_ext: Array added to the last-layer entries, or ``None`` to skip.

    Returns:
        Residual vector with the same length as ``v``.
    """
    L = _lap(W)
    rhs = B.copy()
    n_first = W[0].shape[0]
    rhs[:n_first] += x @ W[0].T
    if i_ext is not None:
        n_last = W[-1].shape[0]
        rhs[-n_last:] += i_ext
    return L @ v - rhs + ots.i(torch.from_numpy(v)).numpy()


def jac(v: np.ndarray, W, *args):
    """Jacobian of :func:`f` with respect to ``v``.

    ``fsolve`` forwards the same extra arguments to ``fprime`` as to ``f``,
    i.e. ``(x, W, B, i_ext)``. In that call the second positional argument is
    the input sample and the weights arrive as ``args[0]``. Both that call
    shape and a direct ``jac(v, W)`` call are accepted.

    Args:
        v: Current node values.
        W: Weight matrices, or the input sample when called by ``fsolve``.
        *args: Remaining ``fsolve`` arguments; ``args[0]`` holds the weights
            when ``W`` is the input sample.

    Returns:
        Square Jacobian matrix.
    """
    if isinstance(W, np.ndarray) and args:
        # called as jac(v, x, W, B, i_ext): the real weights are in args[0]
        W = args[0]
    L = _lap(W)
    a = ots.a(torch.from_numpy(v)).numpy()
    # NOTE(review): assumes ots.a returns the elementwise derivative of ots.i,
    # confirm. A 1-D result belongs on the diagonal; adding it directly would
    # broadcast it over every row of L.
    if a.ndim == 1:
        a = np.diag(a)
    return L + a

In [None]:
# linear initial guess: for every sample solve L v0 = B_cat + (x @ W[0].T on
# the first-layer entries)
n_nodes = sum(dims[1:])
lap = _lap(W)
v0_arr = np.empty((batch_size, n_nodes), dtype=np.float64)
for i in range(batch_size):
    rhs = B_cat.copy()
    rhs[: hdims[0]] += x[i] @ W[0].T
    v0_arr[i] = np.linalg.solve(lap, rhs)
# filled in by the fsolve loop
sol_arr = np.empty_like(v0_arr)

In [None]:
# solve the nonlinear system sample by sample, starting from the linear guess;
# `info` and `msg` keep the values of the last sample only
for batch_idx in range(batch_size):
    solver_args = (x[batch_idx], W, B_cat, i_ext[batch_idx])
    sol, info, _, msg = fsolve(
        f, v0_arr[batch_idx], args=solver_args, fprime=jac, full_output=True
    )
    sol_arr[batch_idx] = sol
    print(msg)

In [None]:
# infodict returned by fsolve for the last sample only (the loop above rebinds it on every iteration)
info

# compare with newton strategy

In [None]:
# per-layer torch tensors built from the numpy parameters
W_t = list(map(torch.from_numpy, W))
B_t = list(map(torch.from_numpy, B))

In [None]:
# set pylogger to debug to see the convergence
import logging
from venv import logger

from src.eqprop.strategy import NewtonStrategy

# The `logger` imported above belongs to the stdlib `venv` module, so raising
# its level never affects the solver's messages. Target the logger of the
# module that defines NewtonStrategy instead.
# NOTE(review): assumes that module logs through logging.getLogger(__name__)
# and that a handler emitting DEBUG records is configured; confirm.
logger = logging.getLogger(NewtonStrategy.__module__)
logger.setLevel(logging.DEBUG)

strategy = NewtonStrategy(clip_threshold=1, max_iter=100, atol=1e-6, activation=ots)
strategy.check_and_set_attrs({"dims": hdims})

In [None]:
# Alternative through the public solve() entry point, kept for reference:
# n_sol = strategy.solve(torch.from_numpy(x), torch.from_numpy(i_ext), params=(W_t, B_t), dims=hdims)
# n_sol = torch.cat(n_sol, dim=-1).numpy()
# NOTE(review): the active line calls the private `_densecholsol` helper and
# passes no i_ext; presumably a dense linear solve rather than the full Newton
# iteration. Confirm against NewtonStrategy.
n_sol = strategy._densecholsol(torch.from_numpy(x), W_t, B_t).numpy()

# Analysis

In [None]:
# cosine distance (1 - cosine similarity) between the fsolve solution and the
# linear initial guess v0, for the last sample of the batch
from scipy.spatial.distance import cosine

cosine(sol_arr[-1], v0_arr[-1])

In [None]:
# same distance, between the strategy's solution and the fsolve solution (last sample)
cosine(n_sol[-1], sol_arr[-1])

In [None]:
# visualize solution
import matplotlib.pyplot as plt

# overlay the three solutions for the last sample of the batch
for curve, tag in ((sol_arr, "sol"), (v0_arr, "v0"), (n_sol, "n_sol")):
    plt.plot(curve[-1], label=tag)
plt.legend()

In [None]:
# difference between the fsolve solution and the linear initial guess (last sample)
plt.plot(sol_arr[-1] - v0_arr[-1], label="err")

## linear sol + correction

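Instead of solving $Lv = b - i(v)$ iteratively, we can solve the linear system $Lv_0 = b$ and then correct that solution with a single step, $v \approx v_0 - L^\dagger i(v_0)$, where $i(\cdot)$ is the nonlinear current `ots.i` (not the residual function `f` from the fsolve section).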

In [None]:
# one-step correction: pseudo-inverse of the system matrix applied to ots.i
# evaluated at the linear solution of the last sample
L_pinv = np.linalg.pinv(_lap(W))
i_v0 = ots.i(torch.from_numpy(v0_arr[-1])).numpy()
cor = L_pinv @ i_v0
plt.plot(cor)

In [None]:
# overlay the fsolve solution and the corrected linear solution (last sample)
v_cor = v0_arr[-1] - cor
plt.plot(sol_arr[-1], marker="o", label="sol")
plt.plot(v_cor, label="v_cor")
plt.legend()

In [None]:
# remaining error of the corrected linear solution: sol - (v0 - cor), last sample
plt.plot(sol_arr[-1] - v0_arr[-1] + cor, label="err", marker="o")

# end