# Setup

In [None]:
import sys

# Extend the import search path with ../src and ../../EP2
# (relative entries are resolved against the current working directory).
sys.path.append("../src")
sys.path.append("../../EP2")

In [None]:
import pyrootutils
from hydra import compose, initialize
from omegaconf import DictConfig, open_dict


def get_cfg():
    """Compose the ``train.yaml`` Hydra config and shrink it to a tiny CPU run.

    Returns:
        The composed config (``return_hydra_config=True``) with the project
        root path filled in, the trainer/datamodule/extras overrides listed
        below applied, and ``logger`` set to ``None``.
    """
    # Overrides grouped by config section; applied in the order listed.
    section_overrides = {
        "trainer": {
            "max_epochs": 1,
            "limit_train_batches": 0.01,
            "limit_val_batches": 0.1,
            "limit_test_batches": 0.1,
            "accelerator": "cpu",
            "devices": 1,
        },
        "datamodule": {
            "num_workers": 0,
            "pin_memory": False,
            "batch_size": 1,
        },
        "extras": {
            "print_config": False,
            "enforce_tags": False,
        },
    }

    with initialize(version_base="1.2", config_path="../../EP2/configs"):
        cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[])
        # open_dict makes the config writable for the duration of the block.
        with open_dict(cfg):
            cfg.paths.root_dir = str(pyrootutils.find_root())
            for section, values in section_overrides.items():
                node = cfg[section]
                for key, value in values.items():
                    node[key] = value
            cfg.logger = None

        return cfg

In [None]:
import json


def print_pretty_json(json_obj):
    """Print ``json_obj`` as JSON with 4-space indentation and sorted keys."""
    rendered = json.dumps(json_obj, sort_keys=True, indent=4)
    print(rendered)

In [None]:
from src.models.components.E_minimizer import _stepsolve2

# Newton solver 1 vs 2

In [None]:
import torch

from src.models.components.eqprop_backbone import AnalogEP, AnalogEP2

In [None]:
x, y = torch.rand(1, 784).clamp_min(0.01), torch.randint(0, 10, (1,))

### 2

In [None]:
# Build the AnalogEP2 backbone with positional argument 1
# (presumably the batch size, as for AnalogEP below — TODO confirm).
ep2 = AnalogEP2(1)

In [None]:
# Inspect the first entry of ep2.model.
# NOTE(review): instances normally expose their class name via type(obj).__name__;
# confirm that this object really defines a __name__ attribute.
ep2.model[0].__name__

In [None]:
from functools import partial

# NOTE(review): "src.rqprop" looks like a typo — a later cell imports
# eqprop_util from src.utils; confirm the module path.
from src.rqprop.eqprop_util import init_params

# Run init_params with min=1e-5, max=1 through ep2.model.apply.
ep2.model.apply(partial(init_params, min=1e-5, max=1))

In [None]:
%%timeit
# Timed statement: one forward call of ep2 on the sample input x.
ep2.forward(x)

In [None]:
# Show every (name, buffer) pair that ep2.model reports.
list(ep2.model.named_buffers())

### 1

In [None]:
# Build the AnalogEP backbone using the batch size from the composed config.
cfg = get_cfg()
# NOTE: bare expression in the middle of a cell — it is evaluated but its value is not displayed.
cfg.datamodule.batch_size
ep1 = AnalogEP(cfg.datamodule.batch_size, pos_W=True, L=[1e-5] * 2, U=[1] * 2)

In [None]:
%%timeit
# Timed statement: run ep1.minimize on the sample input x.
# NOTE: names assigned inside a %%timeit cell are not kept in the notebook namespace.
nodes = ep1.minimize(x)

In [None]:
# `nodes` is only assigned inside a %%timeit cell, whose variables are not
# exported to the notebook namespace, so compute it here before unpacking.
nodes = ep1.minimize(x)
n1, n2 = nodes

## condition number

In [None]:
import torch
import torch.linalg as la

# Random 3x3 weight block and the 6x6 bipartite Laplacian built from it:
# [[diag(row sums), -A], [-A^T, diag(column sums)]].
A = torch.randn(3, 3)
Lap = torch.vstack(
    (
        torch.hstack((torch.diag(A.sum(dim=1)), -A)),
        torch.hstack((-A.T, torch.diag(A.sum(dim=0)))),
    )
)

In [None]:
# Compare the condition number of the assembled Laplacian with that of the raw weight block.
print(la.cond(Lap), la.cond(A))

In [None]:
lu, piv = la.lu_factor(Lap)
# check factorization: `lu` packs L and U into one matrix and `piv` holds
# 1-based row swaps, so neither can be multiplied or used as an index directly.
# Unpack them into explicit P, L, U factors and verify Lap == P @ L @ U.
P_lu, L_lu, U_lu = torch.lu_unpack(lu, piv)
torch.allclose(Lap, P_lu @ L_lu @ U_lu, atol=1e-5)

In [None]:
# Show the permutation matrix encoded by the pivots. The pivots are 1-based
# sequential row swaps, so indexing an identity matrix with them is out of
# range and does not yield the permutation; lu_unpack decodes them correctly.
torch.lu_unpack(lu, piv)[0]

In [None]:
# Display the pivot vector returned by la.lu_factor.
piv

# OTS-stability

Diode model (Shockley equation)

I–V curve:
$I = I_s \left( e^{V / V_t} - 1 \right)$

## Piecewise linear approximation

In [None]:
import sys

# Add the parent directory to the import search path so `src` can be imported below.
sys.path.append("../")

In [None]:
from src.utils import eqprop_util

In [None]:
import matplotlib.pyplot as plt
import torch

# Evaluate two rectifier functions from eqprop_util on 100 points in [0, 1]
# and draw both curves on the same axes.
x = torch.linspace(-0.0, 1.0, 100)
y = eqprop_util.rectifier_a(x)
y2 = eqprop_util.rectifier_p3_a(x)
# y4 = eqprop_util.rectifier_poly_i(x, power=3)
# NOTE(review): the legend labels below describe the curve shapes by name only;
# confirm they match what rectifier_a / rectifier_p3_a actually compute.
plt.plot(x, y, label="exponential")
plt.plot(x, y2, label="quadratic")
# plt.plot(x, y4, label="cubic")
# add a legend
plt.legend()
plt.show()

In [None]:
# Largest value currently stored in y.
y.max()

In [None]:
def rectifier_pseudo_g(V: torch.Tensor):
    """Placeholder stub; no computation is implemented yet.

    Args:
        V: Input tensor (currently unused).

    Returns:
        None.
    """
    return None

In [None]:
# 100 evenly spaced sample points on [-2.2, 2.2] for the plots below.
x = torch.linspace(-2.2, 2.2, 100)
# plt.plot(x, rectifier_pseudo_g(x))

In [None]:
# Antisymmetric exponential characteristic: exp(x) - exp(-x).
y = x.exp() - (-x).exp()
# Same difference with each exponential truncated after the x^2 term;
# the even terms cancel, leaving the straight line 2x.
p2 = (1 + x + x.pow(2) / 2) - (1 - x + x.pow(2) / 2)
# Truncated after the x^3 term instead, leaving the cubic 2x + x^3/3.
p4 = (1 + x + x.pow(2) / 2 + x.pow(3) / 6) - (1 - x + x.pow(2) / 2 - x.pow(3) / 6)
plt.plot(x, y, label="exponential")
# Distinct labels: the two approximations previously shared one (incorrect) legend entry.
plt.plot(x, p2, label="2nd-order Taylor (linear)")
plt.plot(x, p4, label="3rd-order Taylor (cubic)")
plt.legend()

# Block Cholesky

In [None]:
# make a block laplacian matrix
import torch
import torch.linalg as la

# Initial two-layer Laplacian: [[diag(row sums), -A], [-A^T, diag(column sums)]].
A = torch.randn(3, 3)
Lap = torch.vstack(
    (
        torch.hstack((torch.diag(A.sum(dim=1)), -A)),
        torch.hstack((-A.T, torch.diag(A.sum(dim=0)))),
    )
)

def add_to_laplacian(Lap: torch.Tensor, A: torch.Tensor):
    """Grow a layered (block tridiagonal) graph Laplacian by one more layer.

    The last ``m`` nodes of ``Lap`` are connected to ``n`` new nodes through
    the weight block ``A``. The result has the same structure as the seed
    matrix ``[[diag(row sums), -A], [-A^T, diag(column sums)]]``: the degree
    of each existing boundary node increases by its row sum of ``A``, the new
    off-diagonal blocks are ``-A`` / ``-A^T``, and the new diagonal block is
    ``diag`` of the column sums of ``A``.

    Args:
        Lap: Square Laplacian of the existing layers, shape ``(N, N)``.
        A: Weight block of shape ``(m, n)`` linking the last ``m`` existing
            nodes to ``n`` new nodes.

    Returns:
        A new ``(N + n, N + n)`` tensor; ``Lap`` itself is not modified.

    Raises:
        ValueError: If ``Lap`` is not a square matrix or has fewer than ``m``
            rows.
    """
    m, n = A.shape
    if Lap.dim() != 2 or Lap.shape[0] != Lap.shape[1]:
        raise ValueError("Lap must be a square 2-D tensor")
    size = Lap.shape[0]
    if m > size:
        raise ValueError("A has more rows than Lap has nodes")

    grown = Lap.new_zeros(size + n, size + n)
    grown[:size, :size] = Lap
    # Existing boundary nodes gain the weight of their new edges.
    boundary = torch.arange(size - m, size)
    grown[boundary, boundary] += A.sum(dim=1)
    # Coupling between the old boundary layer and the new layer.
    grown[size - m : size, size:] = -A
    grown[size:, size - m : size] = -A.T
    # Degrees of the newly added nodes.
    grown[size:, size:] = torch.diag(A.sum(dim=0))
    return grown

# Feed three more random 3x3 weight blocks through add_to_laplacian,
# rebinding Lap to each result.
for _ in range(3):
    A = torch.randn(3, 3)
    Lap = add_to_laplacian(Lap, A)

---

In [None]:
import torch


def block_tri_cholesky(W: list[torch.Tensor], offdiag: "list[torch.Tensor] | None" = None):
    """Blockwise Cholesky factorization of a symmetric block tridiagonal matrix.

    The matrix has diagonal blocks ``A[i, i] = W[i]`` and coupling blocks
    ``A[i, i + 1] = offdiag[i]`` (with ``A[i + 1, i]`` its transpose). When
    ``offdiag`` is omitted, ``W[i]`` doubles as the coupling block between
    block rows ``i`` and ``i + 1``, which requires equally sized blocks.

    The recursion is::

        L[0]     = chol(W[0])
        C[i]     = L[i]^{-1} A[i, i + 1]          # trsm
        L[i + 1] = chol(W[i + 1] - C[i]^T C[i])   # syrk + potrf

    so the full lower factor has ``L[i]`` on its diagonal and ``C[i]^T`` on
    its sub-diagonal.

    Args:
        W: Diagonal blocks, each a square symmetric 2-D tensor.
        offdiag: Optional coupling blocks; ``offdiag[i]`` has shape
            ``(W[i].shape[0], W[i + 1].shape[0])``. At least ``len(W) - 1``
            entries are required. Defaults to ``W[:-1]``.

    Returns:
        L: ``len(W)`` lower triangular Cholesky factors of the successive
            Schur complements.
        C: ``len(W) - 1`` coupling blocks ``L[i]^{-1} A[i, i + 1]``.

    Raises:
        ValueError: If ``W`` is empty or ``offdiag`` is too short.
        torch.linalg.LinAlgError: If a Schur complement is not positive
            definite (raised by ``torch.linalg.cholesky``).
    """
    n = len(W)
    if n == 0:
        raise ValueError("W must contain at least one block")
    # Read-only view of the inputs: the caller's list is never mutated.
    couplings = W[: n - 1] if offdiag is None else offdiag
    if len(couplings) < n - 1:
        raise ValueError("offdiag must provide len(W) - 1 coupling blocks")

    L = [torch.linalg.cholesky(W[0])]
    C = []
    for i in range(n - 1):
        coupling = torch.linalg.solve_triangular(L[i], couplings[i], upper=False)
        C.append(coupling)
        # Schur complement of the next diagonal block.
        schur = W[i + 1] - coupling.t() @ coupling
        L.append(torch.linalg.cholesky(schur))
    return L, C


def block_tri_cholesky_solve(L, C, B):
    """Solve ``A x = b`` from a blockwise Cholesky factorization of ``A``.

    ``A`` is a symmetric block tridiagonal matrix whose lower Cholesky factor
    has the blocks ``L[i]`` on its diagonal and ``C[i]^T`` on its
    sub-diagonal, i.e. ``C[i] = L[i]^{-1} A[i, i + 1]``. The solve is a block
    forward substitution followed by a block backward substitution; blocks
    may have different sizes.

    Args:
        L (List[torch.Tensor]): Lower triangular diagonal factor blocks.
        C (List[torch.Tensor]): Coupling blocks; ``C[i]`` has shape
            ``(L[i].shape[0], L[i + 1].shape[0])``. Only the first
            ``len(L) - 1`` entries are used.
        B (torch.Tensor): Right-hand side of shape ``(N,)`` or ``(batch, N)``
            with the unknowns along the last dimension.

    Returns:
        X (torch.Tensor): Solution with the same shape as ``B``.

    Raises:
        ValueError: If ``L`` is empty, ``C`` is too short, or the last
            dimension of ``B`` does not match the total block size.
    """
    n = len(L)
    if n == 0:
        raise ValueError("L must contain at least one block")
    if len(C) < n - 1:
        raise ValueError("C must provide len(L) - 1 coupling blocks")

    is_vector = B.dim() == 1
    rhs = B.unsqueeze(0) if is_vector else B
    sizes = [block.size(-1) for block in L]
    if rhs.size(-1) != sum(sizes):
        raise ValueError("last dimension of B does not match the block sizes")

    # Work on column blocks of shape (block_size, batch).
    pieces = [piece.t() for piece in rhs.split(sizes, dim=-1)]

    # Forward substitution: L[i] y[i] = b[i] - C[i-1]^T y[i-1].
    Y = []
    for i in range(n):
        residual = pieces[i]
        if i > 0:
            residual = residual - C[i - 1].t() @ Y[i - 1]
        Y.append(torch.linalg.solve_triangular(L[i], residual, upper=False))

    # Backward substitution: L[i]^T x[i] = y[i] - C[i] x[i+1].
    X = [None] * n
    for i in reversed(range(n)):
        residual = Y[i]
        if i < n - 1:
            residual = residual - C[i] @ X[i + 1]
        X[i] = torch.linalg.solve_triangular(L[i].t(), residual, upper=True)

    solution = torch.cat(X, dim=0).t().contiguous()
    return solution.squeeze(0) if is_vector else solution

In [None]:
from typing import List

import torch

# Your functions here...


def generate_block_tridiagonal(n: int, block_size: int) -> List[torch.Tensor]:
    """Generate ``n`` random symmetric positive-definite blocks.

    Each block is ``M @ M.T + block_size * I`` for a standard-normal ``M``,
    so it is symmetric with every eigenvalue at least ``block_size``. Blocks
    are built out of place: adding a tensor to its own transposed view in
    place reads and writes the same memory and does not reliably produce a
    symmetric result.

    Args:
        n: Number of blocks to generate.
        block_size: Side length of each square block.

    Returns:
        List of ``n`` tensors of shape ``(block_size, block_size)``.
    """
    shift = block_size * torch.eye(block_size)
    blocks = []
    for _ in range(n):
        m = torch.randn(block_size, block_size)
        gram = m @ m.t()
        # Average with the transpose to remove round-off asymmetry.
        blocks.append(0.5 * (gram + gram.t()) + shift)
    return blocks


# Generate a random block tridiagonal matrix
n = 5
block_size = 3
blocks = generate_block_tridiagonal(n, block_size)

# Perform blockwise Cholesky factorization
L, C = block_tri_cholesky(blocks)

# Generate a random RHS
# NOTE: B is a 1-D vector of length n * block_size.
B = torch.randn(n * block_size)

# Perform blockwise Cholesky solve
X_block = block_tri_cholesky_solve(L, C, B)

# Perform standard Cholesky factorization and solve
# Dense reference: blocks[i] is written to the i-th diagonal block and also to
# both off-diagonal blocks that couple block rows i and i + 1.
# NOTE(review): nothing here ensures the assembled A is positive definite — confirm.
A = torch.zeros(n * block_size, n * block_size)
for i in range(n):
    A[i * block_size : (i + 1) * block_size, i * block_size : (i + 1) * block_size] = blocks[i]
    if i < n - 1:
        A[
            i * block_size : (i + 1) * block_size, (i + 1) * block_size : (i + 2) * block_size
        ] = blocks[i]
        A[
            (i + 1) * block_size : (i + 2) * block_size, i * block_size : (i + 1) * block_size
        ] = blocks[i]
L_full = torch.cholesky(A)
# cholesky_solve needs a 2-D right-hand side, hence the unsqueeze/squeeze round trip.
X_full = torch.cholesky_solve(B.unsqueeze(1), L_full).squeeze()

# Compare the results
print("Blockwise solution:", X_block)
print("Full solution:", X_full)
print("Difference:", torch.norm(X_block - X_full))

# Laplacian-Tree

## Low level

In [None]:
import torch

# Query which linear-algebra backend torch currently prefers for CUDA operations.
torch.backends.cuda.preferred_linalg_library()

In [None]:
# Random 3x4 weight block with every entry floored at 0.01.
w = torch.randn(3, 4).clamp_min(0.01)

In [None]:
# Assemble the 7x7 matrix column-block by column-block:
# left = [diag(row sums); w^T], right = [w; diag(column sums)].
# NOTE(review): the off-diagonal blocks are +w here, whereas the other
# Laplacian cells in this notebook use -A — confirm the sign is intended.
Ll = torch.cat((w.sum(dim=1).diag(), w.t()), dim=0)
Lr = torch.cat((w, w.sum(dim=0).diag()), dim=0)
L = torch.cat((Ll, Lr), dim=1)

In [None]:
# Add 1e-5 to every diagonal entry of L.
Lp = L + torch.eye(7) * 1e-5

In [None]:
# Cholesky factor of the diagonally shifted matrix.
c_2 = torch.linalg.cholesky(Lp)

In [None]:
# cholesky_ex on the unshifted matrix: returns the factor together with an info code.
c_1, info1 = torch.linalg.cholesky_ex(L)

In [None]:
# Same call on the diagonally shifted matrix, for comparison with c_2.
c_3, info2 = torch.linalg.cholesky_ex(Lp)

In [None]:
# Check whether the info code for the unshifted matrix equals 7, the matrix order
# (presumably meaning the factorization broke down at the last pivot — TODO confirm).
info1.item() == 7

In [None]:
# Display the factor returned by cholesky_ex for the shifted matrix.
c_3

In [None]:
# The factors from cholesky and cholesky_ex on the same input should agree.
torch.allclose(c_2, c_3)

In [None]:
# Largest element-wise absolute difference between the two factors.
(c_3 - c_2).abs().max()

In [None]:
# Condition number of the diagonally shifted matrix.
cond = torch.linalg.cond(Lp)

In [None]:
# Display the condition number computed above.
cond

### LAPACK

https://netlib.org/lapack/explore-html/da/dba/group__double_o_t_h_e_rcomputational_gae5d8ecd7fbd852fe3c3f71e08ec8332c.html


In [None]:
import numpy as np
import numpy.linalg as nla
import scipy.linalg as sla

# Build a random Gram matrix B = A A^T, then subtract the rank-one term of its
# first eigenpair (as returned by sla.eigh) to obtain C.
n = 3
A = np.random.randn(n, n)
B = A @ A.T
l, v = sla.eigh(B)
C = B - np.outer(l[0] * v[:, 0], v[:, 0])

In [None]:
# factorize the matrix with Cholesky decomposition
# (SciPy's linalg module is imported as `sla`; `la` refers to torch.linalg in
# this notebook, which cannot factorize NumPy arrays and has no `lapack`.)
U = sla.cholesky(C + np.eye(n) * 1e-7)

# compare with lapack wrapper
U2, piv, rank, info = sla.lapack.dpstf2(C + np.eye(n) * 1e-7)
print(info)
U3 = nla.cholesky(C + np.eye(n) * 1e-7)

In [None]:
# scipy.linalg.cholesky returns the upper factor by default while
# numpy.linalg.cholesky returns the lower one, so transpose before comparing.
np.round(U - U3.T, 4)

In [None]:
import torch
import torch.linalg as tla

tB = torch.from_numpy(C)
# cholesky_ex returns (factor, info). Use a dedicated name for the info code so
# the LAPACK pivot vector `piv`, which a later cell still indexes with, is not overwritten.
L, chol_info = tla.cholesky_ex(tB)

In [None]:
# Condition number of the torch copy of C.
tla.cond(tB)

In [None]:
# Does the factor reproduce the input, i.e. L L^T == tB?
torch.allclose(tB, L @ L.T)

In [None]:
%%timeit
# Timed statement: torch's cholesky_ex on tB.
tla.cholesky_ex(tB)

In [None]:
%%timeit
# Timed statement: SciPy's Cholesky on the NumPy matrix C
# (`sla` is scipy.linalg; `la` is torch.linalg and rejects NumPy arrays).
sla.cholesky(C)

In [None]:
%%timeit
# Timed statement: NumPy's Cholesky on C.
nla.cholesky(C)

In [None]:
%%timeit
# Timed statement: LAPACK's pivoted Cholesky (single-precision variant) via
# SciPy's wrapper; the `lapack` namespace lives in scipy.linalg (`sla`).
sla.lapack.spstf2(B)

In [None]:
# Reorder the rows of an identity matrix by `piv - 1`
# (assumes piv still holds the 1-based pivots returned by dpstf2 — TODO confirm).
P = np.eye(U2.shape[0])[piv - 1]

In [None]:
# Permutation matrix applied to the transposed dpstf2 output, rounded to 2 decimals.
np.round(P @ U2.T, 2)

In [None]:
# Display the current value of C.
C

In [None]:
# Eigenvalues of the symmetric NumPy matrix C
# (`sla` is scipy.linalg; `la` is torch.linalg and rejects NumPy arrays).
sla.eigvalsh(C)

In [None]:
# Rebuild C: subtract the rank-one term of the first eigenpair (l[0], v[:, 0]) from B.
C = B - np.outer(l[0] * v[:, 0], v[:, 0])

In [None]:
# Relative condition number of the matrix exponential at C;
# expm_cond is a scipy.linalg (`sla`) function, not a torch.linalg one.
sla.expm_cond(C)

In [None]:
# Display the product of the current factor L with its transpose.
L @ L.T

## Sparsifier

## Laplacian.jl

In [None]:
import julia
from julia import Base

# NOTE(review): julia.install() looks like PyJulia's one-time setup step; confirm
# whether it has to run before `from julia import Base` and on every execution.
julia.install()
# Create a Julia runtime handle and load the LinearAlgebra module into it.
j = julia.Julia()
j.using("LinearAlgebra")

# CNN

## conv

In [None]:
import torch
import torch.nn as nn

# Single-channel 8x8 random input and a bias-free 3x3 convolution with 3 output channels.
x = torch.rand(1, 1, 8, 8)  # batch size, channels, height, width
convlayer = nn.Conv2d(1, 3, 3, 1, bias=False)  # in_channels, out_channels, kernel_size, stride


def conv2d(x, w, stride=1):
    """2-D cross-correlation (no padding, no bias) written with ``einsum``.

    Matches ``torch.nn.functional.conv2d(x, w, stride=stride)``: every
    ``kh x kw`` window of ``x`` is extracted with ``unfold`` and contracted
    against the kernel over the input-channel and kernel axes.

    Args:
        x: Input of shape ``(batch, in_channels, height, width)``.
        w: Kernel of shape ``(out_channels, in_channels, kh, kw)``.
        stride: Step between neighbouring windows along both spatial axes.

    Returns:
        Tensor of shape ``(batch, out_channels, out_height, out_width)``.
    """
    kh, kw = w.shape[-2:]
    # (batch, in_channels, out_height, out_width, kh, kw)
    patches = x.unfold(2, kh, stride).unfold(3, kw, stride)
    return torch.einsum("bchwij,ocij->bohw", patches, w)


# check if the output is the same
# (hand-written conv2d with the layer's own weight vs. the nn.Conv2d forward pass)
torch.allclose(conv2d(x, convlayer.weight), convlayer(x))

In [None]:
# Display the reference output of the nn.Conv2d layer.
convlayer(x)

## maxpool

use gumbel trick

In [None]:
# change to avgpool

# Ridge Regression

In [None]:
import torch

ckpt_path = "./logs/train/runs/2023-07-06_21-52-51/checkpoints/last.ckpt"
# load weight from checkpoint
# map_location="cpu": later cells call .numpy() on these weights, which only
# works for CPU tensors, and it lets the file load on machines without a GPU.
# NOTE: torch.load unpickles the file — only open checkpoints from a trusted source.
model = torch.load(ckpt_path, map_location="cpu")

In [None]:
# Top-level entries of the loaded checkpoint (it is used as a mapping here).
model.keys()

In [None]:
# Names stored under the checkpoint's "state_dict" entry.
model.get("state_dict").keys()

In [None]:
# Pull the two weight tensors out of the state dict by name.
# NOTE: .get returns None for a missing key instead of raising.
w1 = model.get("state_dict").get("net.model.lin1.weight")
w2 = model.get("state_dict").get("net.model.last.weight")

In [None]:
# Difference between the even-indexed and odd-indexed columns of w1
# (presumably paired positive/negative input lines — TODO confirm).
wt = w1[:, ::2] - w1[:, 1::2]

In [None]:
# Sum over the first dimension and lay the result out as a 28x28 grid.
wt = wt.sum(dim=0).reshape(28, 28)

In [None]:
# plot the weight
import matplotlib.pyplot as plt

# Heat map of the 28x28 weight grid with a diverging ("seismic") colormap and a colorbar.
plt.imshow(wt.numpy(), interpolation="nearest", cmap="seismic")
plt.colorbar()
plt.show()

In [None]:
# Heat map of the second weight tensor with the default colormap.
plt.imshow(w2.numpy(), interpolation="nearest")