# Setup

In [None]:
import sys

sys.path.append("../src")
sys.path.append("../../EP2")

In [None]:
import pyrootutils
from hydra import compose, initialize
from omegaconf import DictConfig, open_dict


def get_cfg():
    """Compose EP2's ``train.yaml`` with overrides for a tiny CPU-only debug run.

    The config is composed from ``../../EP2/configs``. Inside ``open_dict`` (so new
    keys may be set), it then does the following:

    - points ``paths.root_dir`` at the root found by ``pyrootutils``;
    - limits the trainer to one epoch on a fraction of the batches, on one CPU device;
    - sets the datamodule to ``num_workers=0``, no pinned memory and batch size 1;
    - turns off config printing and tag enforcement;
    - sets ``logger`` to None.

    Returns:
        The composed Hydra/OmegaConf config.
    """
    trainer_overrides = {
        "max_epochs": 1,
        "limit_train_batches": 0.01,
        "limit_val_batches": 0.1,
        "limit_test_batches": 0.1,
        "accelerator": "cpu",
        "devices": 1,
    }
    datamodule_overrides = {
        "num_workers": 0,
        "pin_memory": False,
        "batch_size": 1,
    }
    extras_overrides = {
        "print_config": False,
        "enforce_tags": False,
    }
    sections = (
        ("trainer", trainer_overrides),
        ("datamodule", datamodule_overrides),
        ("extras", extras_overrides),
    )

    with initialize(version_base="1.2", config_path="../../EP2/configs"):
        cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[])
        with open_dict(cfg):
            cfg.paths.root_dir = str(pyrootutils.find_root())
            for section_name, overrides in sections:
                section = cfg[section_name]
                for key, value in overrides.items():
                    setattr(section, key, value)
            cfg.logger = None

        return cfg

In [None]:
import json


def print_pretty_json(json_obj):
    """Print ``json_obj`` serialized as JSON, indented by 4 spaces with keys sorted."""
    rendered = json.dumps(json_obj, sort_keys=True, indent=4)
    print(rendered)

# Newton solver 1 vs 2

In [None]:
import torch

from src._eqprop.eqprop_backbone import AnalogEP, AnalogEP2

In [None]:
x, y = torch.rand(1, 784).clamp_min(0.01), torch.randint(0, 10, (1,))

### 2

In [None]:
ep2 = AnalogEP2(1)

In [None]:
from functools import partial

from src.rqprop.eqprop_util import init_params

ep2.model.apply(partial(init_params, min=1e-5, max=1))

In [None]:
%%timeit
ep2.forward(x)

In [None]:
list(ep2.model.named_buffers())

### 1

In [None]:
cfg = get_cfg()
cfg.datamodule.batch_size
ep1 = AnalogEP(cfg.datamodule.batch_size, pos_W=True, L=[1e-5] * 2, U=[1] * 2)

In [None]:
%%timeit
nodes = ep1.minimize(x)

In [None]:
n1, n2 = nodes

# condition number

In [None]:
import matplotlib.pyplot as plt

# visualize weights
import torch

w = torch.randn(28, 28)
plt.imshow(w, cmap="viridis")
# add colorbar
plt.colorbar()

In [None]:
# check matrix spectral density
# A square matrix with i.i.d. N(0, 1) entries is not symmetric, so its eigenvalues
# are complex. A histogram of complex values is not meaningful (the imaginary parts
# get dropped at best), so plot them in the complex plane instead. By the circular
# law they fill a disk of radius ~sqrt(dim) for large dim.
dim = 10000
w = torch.randn(dim, dim)
# find eigenvalues (complex dtype)
eigvals = torch.linalg.eigvals(w)

# scatter the spectrum in the complex plane
plt.scatter(eigvals.real, eigvals.imag, s=1)
plt.gca().set_aspect("equal")
plt.xlabel("Re")
plt.ylabel("Im")

In [None]:
# plot marchenko-pastur distribution
import numpy as np

sigma = 1
m = 1000
n = 100
ratio = m / n

X = np.random.normal(0, sigma, (m, n))
# The Marchenko-Pastur law describes eigenvalues, not singular values.
# The eigenvalues of X^T X / n are the squared singular values of X divided by n.
# For X of shape (m, n) they concentrate on
# [sigma^2 (1 - sqrt(m/n))^2, sigma^2 (1 + sqrt(m/n))^2].
s = np.linalg.svd(X, compute_uv=False) ** 2 / n


def mu_plus_minus(sigma, ratio, s=None):
    """Return the (upper, lower) edges of the Marchenko-Pastur support.

    Args:
        sigma: standard deviation of the i.i.d. matrix entries (the ``scale`` passed
            to ``np.random.normal``). The edges scale with the variance ``sigma**2``.
        ratio: aspect ratio entering the edges as ``sqrt(ratio)``. For eigenvalues of
            ``X.T @ X / n`` with ``X`` of shape ``(m, n)`` this is ``m / n``.
        s: unused; kept so existing three-argument calls keep working.

    Returns:
        ``(mu_plus, mu_minus)`` equal to
        ``(sigma**2 * (1 + sqrt(ratio))**2, sigma**2 * (1 - sqrt(ratio))**2)``.
    """
    variance = sigma**2
    root = np.sqrt(ratio)
    return variance * (1 + root) ** 2, variance * (1 - root) ** 2


# Histogram of `s` (computed above), normalized to a probability density.
plt.hist(s, bins=100, density=True)
plt.show()

# vs. scipy.optimize

In [None]:
# setup
import torch
import torch.nn as nn

# load model checkpoint
ckpt = torch.load("../logs/train/runs/2023-07-13_17-53-04/checkpoints/epoch_002.ckpt")
# get weights from lin1 & last layers
w1 = ckpt["state_dict"]["net.model.lin1.weight"]
w2 = ckpt["state_dict"]["net.model.last.weight"]
# get biases from lin1 & last layers
b1 = ckpt["state_dict"]["net.model.lin1.bias"]
b2 = ckpt["state_dict"]["net.model.last.bias"]
# get input & output dimensions

In [None]:
w1.shape, w2.shape, b1.shape, b2.shape

In [None]:
# sample input from MNIST dataset
from src.data.mnist_datamodule import MNISTDataModule

dm = MNISTDataModule(batch_size=1, data_dir="../data")
dm.setup()
x, y = next(iter(dm.train_dataloader()))
from src._eqprop.eqprop_module import EqPropLitModule

x = EqPropLitModule.preprocessing_input(x)

In [None]:
from src.core.eqprop.eqprop_util import OTS, P3OTS
from src.eqprop.E_minimizer import _stepsolve2

dims = [2 * 28 * 28, 128, 10 * 2]
W = [w1, w2]
B = [b1, b2]
v1 = _stepsolve2(x, W, dims, B, i_ext=0, OTS=OTS(), max_iter=30, atol=1e-6)
v2 = _stepsolve2(x, W, dims, B, i_ext=0, OTS=P3OTS(), max_iter=30, atol=1e-6)

In [None]:
%%timeit
_stepsolve2(x, W, [28 * 28 * 2, 128, 10 * 2], B, i_ext=0, OTS=P3OTS(), max_iter=30, atol=1e-7)

In [None]:
_stepsolve2(x, W, dims, B, i_ext=0, OTS=OTS(), max_iter=30, atol=1e-6)

In [None]:
size = sum(dims[1:])
# construct the laplacian
# Ll is strictly block-lower-triangular over the non-input layers. The first block
# row is all zeros; each later block row holds -W[k] under the columns of the
# previous layer.
# NOTE(review): the padding only adds up to `size` if W[k] has dims[k] columns,
# i.e. Linear-style (out, in) weights - confirm.
paddedG = [torch.zeros(dims[1], size).type_as(x)]
for i, g in enumerate(W[1:]):
    # `pad` lives in torch.nn.functional; torch.functional has no `pad`.
    left = sum(dims[1 : i + 1])
    right = sum(dims[2 + i :])
    paddedG.append(torch.nn.functional.pad(-g, (left, right)))

Ll = torch.cat(paddedG, dim=-2)
L = Ll + Ll.mT

In [None]:
yhat = v1.split(dims[1:], dim=1)[1].squeeze()
(yhat[::2] - yhat[1::2]).argmax()

In [None]:
from src.core.eqprop.eqprop_util import OTS, P3OTS

In [None]:
ots = OTS(Vl=-0, Vr=0, Is=1e-6, Vth=0.026)
p3ots = P3OTS(Vl=-0, Vr=0, Is=1e-6, Vth=0.026)
import matplotlib.pyplot as plt

# plot the OTS function
import torch

x = torch.linspace(-1, 1, 100)
plt.plot(x, ots.i(x))

In [None]:
import numpy as np
import torch

# find the root of the OTS function using Newton's method
from scipy.optimize import fsolve, root


def np_wrapper(x: np.ndarray):
    """Adapter for scipy solvers: evaluate ``p3ots.i`` on a NumPy array, return NumPy."""
    current = p3ots.i(torch.tensor(x))
    return current.detach().numpy()


x0 = np.random.rand(1000) * 2

In [None]:
%%timeit
# Start from the random x0 defined above, not the leftover plotting grid `x`.
res, info, _, __ = fsolve(np_wrapper, x0=x0, full_output=True)

In [None]:
import time

residuals_history = {}
duration_history = {}
methods = [
    "krylov",
    "df-sane",
]  # , 'broyden1', 'broyden2', 'anderson', 'linearmixing', 'diagbroyden', 'excitingmixing']

# Timestamp of the previous callback, per method. This must live outside
# `callback`: a plain local `t` is lost between calls (and was unbound on every
# call after the first, raising UnboundLocalError).
_last_callback_time = {}


def callback(residual, method):
    """Record the inf-norm of ``residual`` and the time since the previous call.

    Args:
        residual: residual vector reported by the solver at this iteration.
        method: solver name; the key into ``residuals_history`` / ``duration_history``.

    Side effects:
        Appends to ``residuals_history[method]`` and ``duration_history[method]``,
        creating either list if missing. When the residual list is empty (first
        call, or the caller reset it), the duration list is cleared too and 0.0 is
        recorded, so both series stay aligned one-to-one.
    """
    now = time.time()
    residuals = residuals_history.setdefault(method, [])
    durations = duration_history.setdefault(method, [])
    if residuals:
        previous = _last_callback_time.get(method, now)
    else:
        # Fresh (or reset) run for this method: restart the timing series as well.
        durations.clear()
        previous = now
    durations.append(now - previous)
    residuals.append(np.linalg.norm(residual, ord=np.inf))
    _last_callback_time[method] = time.time()


def modified_callback(x, residual=None, method=None):
    """Forward a scipy ``root`` callback to ``callback``.

    scipy passes ``(x, f)`` for most methods. If no residual is given, the first
    positional argument is treated as the residual itself.
    """
    callback(x if residual is None else residual, method)


for method in methods:
    # Reset both histories so each method's series starts fresh and stays aligned.
    residuals_history[method] = []
    duration_history[method] = []
    try:
        sol = root(
            np_wrapper,
            np.random.rand(10),
            method=method,
            # scipy calls callback(x, f) for these methods; bind `method` eagerly.
            callback=lambda x, res=None, method=method: modified_callback(x, res, method),
            options={"fatol": 1e-7, "disp": True},
        )
    except Exception as err:
        # Some methods might not accept the callback or may fail on this problem;
        # say why instead of silently skipping.
        print(f"skip {method}: {err!r}")

# Plotting the residuals against cumulative wall-clock time
plt.figure(figsize=(14, 7))
for method, res in residuals_history.items():
    if not res:  # Only plot methods that have residuals recorded
        continue
    res = np.asarray(res, dtype=float)
    # duration_history stores the time between consecutive callbacks. Keep the
    # entries of the current run (the last len(res)) and accumulate them.
    elapsed = np.cumsum(duration_history[method][-len(res) :])
    # Drop outliers (and NaNs) from BOTH series so x and y keep the same length;
    # filtering only `res` made plt.plot fail on a shape mismatch.
    keep = res <= 1e4
    plt.plot(elapsed[keep], res[keep], label=method, marker="o", markersize=5)


plt.yscale("log")
plt.xlabel("Elapsed time (s)")
plt.ylabel("Residual")
plt.title("Convergence of Residuals for Different Methods")
plt.legend(loc="best")
plt.grid(True, which="both", ls="--", linewidth=0.5)
plt.show()

In [None]:
# import time

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import root


# Define the function using numpy again
def i_numpy(x):
    # Round-trip NumPy -> torch -> NumPy so scipy.optimize.root can evaluate p3ots.i.
    as_tensor = torch.tensor(x)
    result = p3ots.i(as_tensor)
    return result.detach().numpy()


# Define the methods to test
methods_to_test = ["hybr", "lm", "df-sane", "krylov"]

# Per-method iteration caps, keyed by the option name each scipy method accepts:
# "maxiter" for krylov/lm, "maxfev" for hybr/df-sane.
_ITERATION_LIMITS = {
    "krylov": {"maxiter": 7},
    "lm": {"maxiter": 10},
    "hybr": {"maxfev": 40},
    "df-sane": {"maxfev": 1500},
}


def maxiter_method(method):
    """Return the ``options`` dict for ``scipy.optimize.root`` for ``method``.

    Unknown methods get an empty dict, i.e. scipy's defaults. The previous fallback
    was ``0``, which is not a mapping and cannot be used as ``options``.
    A fresh dict is returned on every call, so callers may modify it.
    """
    return dict(_ITERATION_LIMITS.get(method, {}))


# Find roots and record the elapsed time and final residuals again
import time

elapsed_time_results = {}
residuals_results = {}

# One shared starting point so every method solves the same problem.
x0_shared = np.random.rand(100) * 2

for method in methods_to_test:
    start_time = time.perf_counter()
    sol = root(i_numpy, x0=x0_shared.copy(), method=method, options=maxiter_method(method))
    elapsed_time_results[method] = time.perf_counter() - start_time
    # Max-abs residual over all components (previously only sol.fun[0] was used).
    residuals_results[method] = np.linalg.norm(sol.fun, ord=np.inf)

# Plotting the results again
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 12))

ax1.bar(
    elapsed_time_results.keys(),
    elapsed_time_results.values(),
    color=["blue", "red", "green", "purple"],
)
ax1.set_ylabel("Elapsed Time (seconds)")
ax1.set_title("Elapsed Time for Different Methods")

ax2.bar(
    residuals_results.keys(), residuals_results.values(), color=["blue", "red", "green", "purple"]
)
ax2.set_yscale("log")
ax2.set_ylabel("Final Residuals")
ax2.set_title("Final Residuals for Different Methods")

fig.tight_layout()
plt.show()

### ScipyStrategy

In [None]:
from src.core.eqprop.strategy import ScipyStrategy

st = ScipyStrategy()

# OTS-stability

Diode model.

I-V curve:
$I = I_s\left(e^{V/V_t} - 1\right)$

## Piecewise linear approximation

In [None]:
import matplotlib.pyplot as plt
import torch

from src.eqprop import eqprop_util

x = torch.linspace(-1.0, 1.0, 100)
ots = eqprop_util.OTS(Vl=-0, Vr=0, Is=1e-6, Vth=0.026)
p3ots = eqprop_util.P3OTS(Vl=-0, Vr=0, Is=1e-6, Vth=0.026)
y = ots.i(x)
y2 = p3ots.i(x)
# y4 = eqprop_util.rectifier_poly_i(x, power=3)
plt.plot(x, y, label="exponential")
plt.plot(x, y2, label="quadratic")
# plt.plot(x, y4, label="cubic")
# add a legend
plt.legend()
plt.show()

In [None]:
symtanh = eqprop_util.Symtanh(Vl=0.3, Vr=0.7, Is=1, Vth=0.2)
y = symtanh.i(x)
b = symtanh.a(x)
plt.plot(x, y, label="symtanh")
plt.plot(x, b, label="symtanh a")

In [None]:
x = torch.linspace(-2.2, 2.2, 100)
# plt.plot(x, rectifier_pseudo_g(x))

In [None]:
y = x.exp() - (-x).exp()
# Differences of Taylor expansions of e^x and e^-x. With terms up to x^2 the even
# terms cancel and p2 == 2x. With terms up to x^3, p4 == 2x + x^3 / 3.
p2 = (1 + x + x.pow(2) / 2) - (1 - x + x.pow(2) / 2)
p4 = (1 + x + x.pow(2) / 2 + x.pow(3) / 6) - (1 - x + x.pow(2) / 2 - x.pow(3) / 6)
plt.plot(x, y, label="exponential")
plt.plot(x, p2, label="Taylor, order 2")
plt.plot(x, p4, label="Taylor, order 3")
plt.legend()

## SymOTS

In [None]:
import torch

from src.core.eqprop.eqprop_util import OTS, P3OTS, SymOTS

ots = OTS(Vl=-0.5, Vr=0.5)
symots = SymOTS(Vl=-0.5, Vr=0.5)

x = torch.logspace(-0.01, 0.01, 3000)

inv_a1 = 1 / ots.a(x)
inv_a2 = 1 / symots.a(x)
inv_a3 = symots.inv_a(x)


idiva = ots.i(x) / ots.a(x)
idiva2 = symots.i_div_a(x)

import matplotlib.pyplot as plt

# plt.plot(x, idiva, label="exponential")
plt.plot(x, idiva2 - idiva, label="exponential")

In [None]:
def maxmexp(V, Vl=-0.5, Vr=0.5, Vth=0.026):
    """Stable ``Vth * (exp(xr) - exp(xl)) / (exp(xr) + exp(xl))``.

    Here ``xr = (V - Vr) / Vth`` and ``xl = (Vl - V) / Vth``. Mathematically this
    equals ``Vth * tanh((2 * V - Vr - Vl) / (2 * Vth))``.

    Both exponents are shifted by their elementwise maximum before ``exp``. This
    leaves the ratio unchanged but keeps large ``|V|`` from overflowing. The
    defaults reproduce the previously hard-coded constants (thresholds -0.5 / 0.5,
    Vth = 0.026).

    Args:
        V: voltage tensor.
        Vl: left threshold.
        Vr: right threshold.
        Vth: thermal-voltage-like scale.

    Returns:
        Tensor with the same shape as ``V``.
    """
    xr = (V - Vr) / Vth
    xl = (Vl - V) / Vth
    shift = torch.maximum(xr, xl)
    er = torch.exp(xr - shift)
    el = torch.exp(xl - shift)
    return Vth * ((er - el) / (er + el))

In [None]:
plt.plot(x, maxmexp(x) - idiva2, label="exponential")

# Block Cholesky

In [None]:
# make a block laplacian matrix
import torch
import torch.linalg as la

# Edge weights must be non-negative for a graph Laplacian to be positive
# semi-definite (randn would give negative conductances).
A = torch.rand(3, 3)
Lap = torch.cat(
    [
        torch.cat([torch.diag(A.sum(dim=1)), -A], dim=1),
        torch.cat([-A.T, torch.diag(A.sum(dim=0))], dim=1),
    ],
    dim=0,
)


def add_to_laplacian(Lap: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Grow a layered graph Laplacian by one layer of nodes.

    The last ``m`` nodes of ``Lap`` are connected to ``n`` new nodes through the
    edge-weight matrix ``A`` of shape ``(m, n)``. The result has size
    ``Lap.shape[0] + n``, with these blocks:

    - top-left: ``Lap``, with the row sums of ``A`` added to the diagonal entries of
      the last ``m`` nodes;
    - off-diagonal coupling blocks: ``-A`` and ``-A.T``;
    - new diagonal block: ``diag(A.sum(0))``.

    Row sums of the input are therefore preserved, e.g. zero rows stay zero.
    Appending ``A`` to ``torch.zeros(m, m)`` gives the two-layer block Laplacian
    built explicitly in the cell above. ``Lap`` itself is not modified (the
    previous version partially overwrote it and then discarded it).

    Args:
        Lap: square Laplacian of the existing graph.
        A: weights of the edges between the last ``m`` nodes and the new layer.

    Returns:
        The enlarged Laplacian.

    Raises:
        ValueError: if ``A`` has more rows than ``Lap`` has nodes.
    """
    m, n = A.shape
    size = Lap.shape[0]
    if m > size:
        raise ValueError(f"A has {m} rows but Lap only has {size} nodes")
    out = Lap.new_zeros(size + n, size + n)
    out[:size, :size] = Lap
    prev_nodes = slice(size - m, size)
    new_nodes = slice(size, size + n)
    out[prev_nodes, prev_nodes] += torch.diag(A.sum(dim=1))
    out[prev_nodes, new_nodes] = -A
    out[new_nodes, prev_nodes] = -A.T
    out[new_nodes, new_nodes] = torch.diag(A.sum(dim=0))
    return out


for _ in range(3):
    # Non-negative edge weights keep the grown Laplacian positive semi-definite.
    A = torch.rand(3, 3)
    Lap = add_to_laplacian(Lap, A)

---

In [None]:
import torch


def block_tri_cholesky(W: list[torch.Tensor]):
    """Blockwise cholesky decomposition for a size varying block tridiagonal matrix.
    see spftrf() in LAPACK

    As written, W[i] is used both as the right-hand side of the triangular solve
    (acting like the off-diagonal block A[i, i+1]) and, through W[i + 1], as the
    next diagonal block.

    NOTE(review): LAPACK spftrf() is Cholesky for rectangular-full-packed storage,
    not a block-tridiagonal routine; confirm the intended reference.
    NOTE(review): ``W.append(0)`` mutates the caller's list. On the last iteration
    (i = n - 1) ``D = 0 - C[n-1].t() @ C[n-1]`` is negative semi-definite, so the
    final ``torch.cholesky(D)`` is expected to fail. Presumably the loop should
    stop at n - 2; confirm.
    NOTE(review): ``torch.cholesky`` / ``torch.triangular_solve`` are deprecated
    in favour of ``torch.linalg.cholesky`` / ``torch.linalg.solve_triangular``.

    Args:
        W (List[torch.Tensor]): List of lower triangular blocks.

    Returns:
        L (List[torch.Tensor]): List of lower triangular blocks (length len(W) + 1).
        C (List[torch.Tensor]): List of diagonal blocks. as column vectors.
    """

    n = len(W)
    C = [torch.zeros_like(W[i]) for i in range(n)]
    L = [None] * (n + 1)
    W.append(0)  # sentinel so W[i + 1] exists for i = n - 1 (mutates caller's list)
    L[0] = torch.cholesky(W[0])
    for i in range(n):
        C[i] = torch.triangular_solve(
            W[i], L[i], upper=False
        ).solution  # solves L[i] @ C[i] = W[i], i.e. C[i] = L[i]^-1 @ W[i], trsm()
        D = W[i + 1] - torch.mm(C[i].t(), C[i])  # Schur complement: W[i+1] - C[i]^T @ C[i], syrk()
        L[i + 1] = torch.cholesky(D)
    return L, C


def block_tri_cholesky_solve(L, C, B):
    """Blockwise cholesky solve for a size varying block tridiagonal matrix.

    NOTE(review): if L and C come from block_tri_cholesky, len(L) == len(C) + 1.
    The loop over range(len(L)) then indexes C[len(C)] and L[len(L)] on its last
    iteration (IndexError).
    NOTE(review): each block is solved independently, ignoring the coupling between
    blocks. ``torch.cholesky_solve`` expects a Cholesky factor as its second argument,
    but receives ``L[i + 1] + C[i]^T @ C[i]``. Slicing ``B[:, ...]`` assumes a 2-D
    ``B``, and ``i * C[i].size(-1)`` assumes equal block sizes. Confirm intent before use.

    Args:
        L (List[torch.Tensor]): List of lower triangular blocks.
        C (List[torch.Tensor]): List of diagonal blocks.
        B (torch.Tensor): RHS.

    Returns:
        X (torch.Tensor): Solution.
    """

    n = len(L)
    X = torch.zeros_like(B)
    for i in range(n):
        X[:, i * C[i].size(-1) : (i + 1) * C[i].size(-1)] = torch.cholesky_solve(
            B[:, i * C[i].size(-1) : (i + 1) * C[i].size(-1)],
            L[i + 1] + torch.mm(C[i].t(), C[i]),
        )

    return X

In [None]:
import torch

# Your functions here...


def generate_block_tridiagonal(n: int, block_size: int) -> list[torch.Tensor]:
    """Generate ``n`` random symmetric positive-definite square blocks.

    Each block is ``M @ M.T + block_size * I`` with ``M`` standard normal. Its
    eigenvalues are therefore at least ``block_size``.

    Two problems with the previous version are fixed. First, the in-place
    ``block += block.t()`` read from a transposed view of the tensor being written,
    and PyTorch rejects such partially overlapping in-place ops. Second, adding
    ``block_size * I`` to a symmetric Gaussian matrix does not guarantee positive
    definiteness.

    NOTE: despite the name, only the diagonal blocks are returned.

    Args:
        n: number of blocks.
        block_size: side length of each square block.

    Returns:
        List of ``n`` tensors of shape ``(block_size, block_size)``.
    """
    eye = torch.eye(block_size)
    blocks = []
    for _ in range(n):
        m = torch.randn(block_size, block_size)
        blocks.append(m @ m.T + block_size * eye)
    return blocks


# Generate a random block tridiagonal matrix
n = 5
block_size = 3
blocks = generate_block_tridiagonal(n, block_size)

# Perform blockwise Cholesky factorization
L, C = block_tri_cholesky(blocks)

# Generate a random RHS
B = torch.randn(n * block_size)

# Perform blockwise Cholesky solve
X_block = block_tri_cholesky_solve(L, C, B)

# Perform standard Cholesky factorization and solve
A = torch.zeros(n * block_size, n * block_size)
for i in range(n):
    A[i * block_size : (i + 1) * block_size, i * block_size : (i + 1) * block_size] = blocks[i]
    if i < n - 1:
        A[i * block_size : (i + 1) * block_size, (i + 1) * block_size : (i + 2) * block_size] = (
            blocks[i]
        )
        A[(i + 1) * block_size : (i + 2) * block_size, i * block_size : (i + 1) * block_size] = (
            blocks[i]
        )
L_full = torch.cholesky(A)
X_full = torch.cholesky_solve(B.unsqueeze(1), L_full).squeeze()

# Compare the results
print("Blockwise solution:", X_block)
print("Full solution:", X_full)
print("Difference:", torch.norm(X_block - X_full))

# Laplacian-Tree

## Low level

In [None]:
import torch

torch.backends.cuda.preferred_linalg_library()

In [None]:
w = torch.randn(3, 4).clamp_min(0.01)

In [None]:
Ll = torch.concat([torch.diag(w.sum(dim=1)), w.T], dim=0)
Lr = torch.concat((w, torch.diag(w.sum(dim=0))), dim=0)
L = torch.concat((Ll, Lr), dim=1)

In [None]:
Lp = L + torch.eye(7) * 1e-5

In [None]:
c_2 = torch.linalg.cholesky(Lp)

In [None]:
c_1, info1 = torch.linalg.cholesky_ex(L)

In [None]:
c_3, info2 = torch.linalg.cholesky_ex(Lp)

In [None]:
c_3

In [None]:
torch.allclose(c_2, c_3)

In [None]:
abs(c_3 - c_2).max()

In [None]:
cond = torch.linalg.cond(Lp)

In [None]:
cond

### LAPACK

https://netlib.org/lapack/explore-html/da/dba/group__double_o_t_h_e_rcomputational_gae5d8ecd7fbd852fe3c3f71e08ec8332c.html


In [None]:
import numpy as np
import numpy.linalg as nla
import scipy.linalg as sla

# generate a random positive semi-definite matrix
n = 3
A = np.random.randn(n, n)
B = A @ A.T
left, v = sla.eigh(B)
C = B - left[0] * v[:, 0:1] @ [v[:, 0]]

In [None]:
# factorize the matrix with Cholesky decomposition
U = nla.cholesky(C + np.eye(n) * 1e-7)

# compare with lapack wrapper
U2, piv, rank, info = sla.lapack.dpstf2(C + np.eye(n) * 1e-7)
print(info)
U3 = nla.cholesky(C + np.eye(n) * 1e-7)

In [None]:
np.round(U - U3, 4)

In [None]:
import torch
import torch.linalg as tla

tB = torch.from_numpy(C)
L, piv = tla.cholesky_ex(tB)

In [None]:
tla.cond(tB)

In [None]:
torch.allclose(tB, L @ L.T)

In [None]:
%%timeit
tla.cholesky_ex(tB)

In [None]:
%%timeit
tla.cholesky(C)

In [None]:
%%timeit
nla.cholesky(C)

In [None]:
%%timeit
sla.lapack.spstf2(B)

In [None]:
P = np.eye(U2.shape[0])[piv - 1]

In [None]:
np.round(P @ U2.T, 2)

In [None]:
tla.eigvalsh(C)

In [None]:
C = B - left[0] * v[:, 0:1] @ [v[:, 0]]

In [None]:
tla.expm_cond(C)

In [None]:
L @ L.T

## Sparsifier

## Laplacian.jl

In [None]:
# use julia bindings
import julia

julia.install()
from julia import Base

Base.sind(90)

# CNN

## conv

In [None]:
import torch

x = torch.rand(1, 1, 8, 8)  # batch size, channels, height, width
convlayer = nn.Conv2d(1, 3, 3, 1, bias=False)  # in_channels, out_channels, kernel_size, stride


def conv2d(x, w, stride=1):
    """Valid (unpadded) 2-D cross-correlation, equivalent to ``nn.Conv2d(..., bias=False)``.

    Args:
        x: input of shape (batch, in_channels, height, width).
        w: weights of shape (out_channels, in_channels, kh, kw).
        stride: step between windows along both spatial axes (default 1).

    Returns:
        Tensor of shape (batch, out_channels, out_h, out_w), where
        ``out_h = (height - kh) // stride + 1`` and likewise for the width.
    """
    kh, kw = w.shape[-2:]
    # (b, c, out_h, out_w, kh, kw): every kh x kw window, stepped by `stride`.
    patches = x.unfold(2, kh, stride).unfold(3, kw, stride)
    # Contract over input channels and kernel positions for each output channel.
    return torch.einsum("bcijkl,ockl->boij", patches, w)


# check if the output is the same
torch.allclose(conv2d(x, convlayer.weight), convlayer(x))

In [None]:
convlayer(x)

## maxpool

Use the Gumbel trick.

In [None]:
# change to avgpool

# Ridge Regression

In [None]:
import torch

ckpt_path = "./logs/train/runs/2023-07-06_21-52-51/checkpoints/last.ckpt"
# load weight from checkpoint
model = torch.load(ckpt_path)

In [None]:
model.keys()

In [None]:
model.get("state_dict").keys()

In [None]:
w1 = model.get("state_dict").get("net.model.lin1.weight")
w2 = model.get("state_dict").get("net.model.last.weight")

In [None]:
wt = w1[:, ::2] - w1[:, 1::2]

In [None]:
wt = wt.sum(dim=0).reshape(28, 28)

In [None]:
# plot the weight
import matplotlib.pyplot as plt

plt.imshow(wt.numpy(), interpolation="nearest", cmap="seismic")
plt.colorbar()
plt.show()

In [None]:
plt.imshow(w2.numpy(), interpolation="nearest")