2 by 2 XOR problem

In [None]:
import torch
from omegaconf import OmegaConf

In [None]:
# XOR dataset: the four corners of the {-1, +1}^2 square and their XOR labels.
X = torch.tensor(
    [[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]],
    dtype=torch.float32,
)
y = torch.tensor([[0], [1], [1], [0]])

# YAML description of the network, parsed straight into an OmegaConf object.
# The `${...}` entries are interpolations that reuse values defined elsewhere
# in the same config (e.g. the solver's beta mirrors net.beta).
cfg = OmegaConf.create(
    """
net:
  _target_: src._eqprop.eqprop_backbone.AnalogEP2
  _partial_: true
  batch_size: 1
  dims: [2,2,1]
  beta: 0.01
  solver:
    _target_: src.core.eqprop.solver.AnalogEqPropSolver
    _partial_: true
    amp_factor: 1.0
    beta: ${net.beta}
    strategy:
      _target_: src.core.eqprop.strategy.NewtonStrategy
      clip_threshold: 0.1
      amp_factor: ${net.solver.amp_factor}
      max_iter: 5
      atol: 1e-5
      activation:
        _target_: src.core.eqprop.eqprop_util.P3OTS
        Is: 1e-6
        Vth: 1
        Vl: 0
        Vr: 0"""
)

In [None]:
from hydra.utils import instantiate

# cfg.net is marked `_partial_: true`, so instantiate() hands back a callable
# that still needs the remaining arguments; call it right away to build the net.
net = instantiate(cfg.net)(hyper_params={"bias": True})

In [None]:
# train
import torch.nn.functional as F

from src.core.eqprop.eqprop_util import init_params

# Optimiser and loss are created first; the weights are then re-initialised
# through init_params before any update is taken.
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss(reduction="sum")
net.apply(init_params(min_w=1e-6, max_w_gain=0.3))

# Online training: one sample per step, cycling through the four XOR points.
for step in range(1000):
    sample = step % 4
    x = X[sample].unsqueeze(0)
    target = F.one_hot(y[sample], 2).float()

    optimizer.zero_grad()
    yhat = F.sigmoid(net(x))
    loss = loss_fn(yhat, target)
    loss.backward()
    # Extra eqprop pass on the same input between backward() and the step.
    net.eqprop(x)
    optimizer.step()

    if step % 100 == 0:
        # NOTE(review): this reports `net.loss`, not the local `loss` computed
        # above -- presumably populated by the network itself; confirm.
        print(net.loss.item())

In [None]:
# List every registered buffer of the network together with its shape.
for name, buffer in net.named_buffers():
    print(name, buffer.shape)

# XOR2

In [None]:
import torch

from src.core.eqprop.eqprop_util import P3OTS, SymReLU

# xor for newton strategy
from src.core.eqprop.strategy import NewtonStrategy

# Hand-wired 2 -> 1 -> 1 stack of linear layers with fixed weights and biases.
model = torch.nn.Sequential(
    torch.nn.Linear(2, 1, bias=True),
    torch.nn.Linear(1, 1, bias=True),
)
fixed_params = (
    (model[0], [[1.0, 1.0]], [1.0]),
    (model[1], [[2.0]], [0.0]),
)
for layer, weight, bias in fixed_params:
    layer.weight.data = torch.tensor(weight)
    layer.bias.data = torch.tensor(bias)

# Newton strategy with a P3OTS activation thresholded at +/-0.6; the strategy
# picks up its parameters from the hand-wired model.
st = NewtonStrategy(
    activation=P3OTS(Is=1e-6, Vl=-0.6, Vr=0.6),
    clip_threshold=0.1,
    amp_factor=1.0,
    max_iter=5,
    atol=1e-5,
    add_nonlin_last=False,
)
st.set_strategy_params(model)

# Single input sample used by all residual / Jacobian evaluations below.
inputs = torch.tensor([[-1.0, -1.0]])

import matplotlib.pyplot as plt

# 2d vector map for the residual using quiver
import numpy as np

# Evaluate the residual on a regular grid over [-3, 1] x [-3, 1] and draw it
# as a vector field.
resolution = 40
x = np.linspace(-3, 1, resolution)
y = np.linspace(-3, 1, resolution)
X, Y = np.meshgrid(x, y)
U, V = np.zeros_like(X), np.zeros_like(Y)

# Walk the grid in row-major order; each point yields a 2-component residual.
for idx in np.ndindex(X.shape):
    v = torch.tensor([[X[idx], Y[idx]]]).float()
    U[idx], V[idx] = st.residual(v, inputs, None).squeeze().numpy()

plt.quiver(X, Y, U, V)

In [None]:
# Start from (-2.0, 3.5) and take 20 steps of size 0.5 along -residual,
# drawing every iterate and the step between consecutive iterates.
v_init = torch.tensor([[-2.0, 3.5]])
v_prev = v_init
plt.plot(v_init[0, 0], v_init[0, 1], "ro", alpha=0.5)

arrow_style = dict(head_width=0.01, head_length=0.01, fc="k", ec="k")
for i in range(20):
    v = v_prev - 0.5 * st.residual(v_prev, inputs, None).squeeze()
    x0, y0 = v_prev[0]
    x1, y1 = v[0]
    # Later iterates are drawn slightly more opaque.
    plt.plot(x1, y1, "ro", alpha=0.5 + 0.01 * i)
    # Arrow from the previous iterate to the new one.
    plt.arrow(x0, y0, x1 - x0, y1 - y0, **arrow_style)
    # Every fifth iterate gets a cross and a coordinate label.
    if (i + 1) % 5 == 0:
        plt.plot(x1, y1, "rx")
        plt.text(x1, y1, f"({v[0, 0]:.3f}, {v[0, 1]:.3f})")
    v_prev = v

# Residual vector field on the window [-2, 1] x [-1.5, 1.5], drawn on the
# current figure.
resolution = 30
x = np.linspace(-2, 1, resolution)
y = np.linspace(-1.5, 1.5, resolution)
X, Y = np.meshgrid(x, y)
U, V = np.zeros_like(X), np.zeros_like(Y)

# Row-major sweep over the grid; each residual has two components.
for idx in np.ndindex(X.shape):
    v_ = torch.tensor([[X[idx], Y[idx]]]).float()
    U[idx], V[idx] = st.residual(v_, inputs, None).squeeze().numpy()

plt.quiver(X, Y, U, V)

In [None]:
# Display the current value of v_prev (the most recent iterate stored by a trajectory loop).
v_prev

In [None]:
# Newton iteration from (-2.0, 3.5): 20 full steps, with i_ext passed to the
# residual (presumably an external current at the output -- confirm against
# st.residual). Every iterate and step is drawn on the current figure.
v_init = torch.tensor([[-2.0, 3.5]]).float()
i_ext = torch.tensor([[1e-3]])
v_prev = v_init
plt.plot(v_init[0, 0], v_init[0, 1], "ro", alpha=0.5)
for i in range(20):
    # Newton step: solve J(v_prev) @ delta = -residual(v_prev) for delta.
    # torch.linalg.solve_ex is the public API; the private
    # torch._linalg_solve_ex used before may change without notice.
    # solve_ex does not raise on a singular Jacobian -- inspect res.info if
    # the iterates blow up.
    res = torch.linalg.solve_ex(st.jacobian(v_prev), -st.residual(v_prev, inputs, i_ext))
    v = v_prev + 1 * res.result.squeeze()
    # Later iterates are drawn more opaque.
    plt.plot(v[0, 0], v[0, 1], "ro", alpha=0.5 + 0.02 * i)
    # Arrow from the previous iterate to the new one.
    plt.arrow(
        v_prev[0, 0],
        v_prev[0, 1],
        v[0, 0] - v_prev[0, 0],
        v[0, 1] - v_prev[0, 1],
        head_width=0.01,
        head_length=0.01,
        fc="k",
        ec="k",
    )
    print(v)
    # Every fifth iterate gets a cross and a coordinate label.
    if i % 5 == 4:
        plt.plot(v[0, 0], v[0, 1], "rx")
        plt.text(v[0, 0], v[0, 1], f"({v[0, 0]:.2f}, {v[0, 1]:.2f})")
    v_prev = v

# Vector field of Newton steps on a square grid whose bounds on both axes are
# the two coordinates of v_init.
# NOTE(review): the field is evaluated with no external current (None), while
# the trajectory loop in this cell passes i_ext -- confirm this is intended.
resolution = 40
# Convert the 0-d tensors to plain floats so numpy gets ordinary scalars.
lo, hi = v_init[0, 0].item(), v_init[0, 1].item()
x = np.linspace(lo, hi, resolution)
y = np.linspace(lo, hi, resolution)
X, Y = np.meshgrid(x, y)
U = np.zeros_like(X)
V = np.zeros_like(Y)

for i in range(resolution):
    for j in range(resolution):
        v_ = torch.tensor([[X[i, j], Y[i, j]]]).float()
        # Public solve API instead of the private torch._linalg_solve_ex.
        res2 = torch.linalg.solve_ex(st.jacobian(v_), -st.residual(v_, inputs, None))
        U[i, j], V[i, j] = res2.result.squeeze().numpy()


plt.quiver(X, Y, U, V)

In [None]:
# Display the residual at the current v_prev, evaluated without an external current.
st.residual(v_prev, inputs, None)

In [None]:
from src.core.eqprop.strategy import SecondOrderStrategy


def int_relu_f(v, vr=0.6):
    """Return the energy of a symmetric thresholded ReLU, summed over ``v``.

    Each element contributes ``0.5 * max(|v| - vr, 0) ** 2``: nothing inside
    the dead zone ``[-vr, vr]`` and a quadratic penalty on the excess outside
    it, identical for positive and negative values.

    The previous form multiplied the ``|v| > vr`` mask by ``v - vr``, so a
    negative ``v`` contributed ``(v - vr) ** 2`` instead of
    ``(|v| - vr) ** 2`` and the energy was asymmetric.

    Args:
        v: Tensor of node values (any shape).
        vr: Non-negative threshold below which an element contributes nothing.

    Returns:
        A 0-d tensor holding the summed energy.
    """
    excess = torch.relu(torch.abs(v) - vr)
    return 0.5 * torch.sum(excess**2)


def energy(st: SecondOrderStrategy, v, x, i_ext):
    """Evaluate the energy of state ``v`` for input ``x`` under strategy ``st``.

    The value is the sum of a quadratic term ``v @ L @ v.T`` built from
    ``st.laplacian()``, a linear term ``v @ R`` built from ``st.rhs(x)``, and
    the nonlinearity energy ``int_relu_f(v)``.

    Args:
        st: Strategy providing ``laplacian()``, ``rhs(x)``, ``dims`` and
            ``amp_factor``.
        v: Node state. NOTE(review): assumed to be a row vector so both matrix
            products are conformable -- confirm against callers.
        x: Network input forwarded to ``st.rhs``.
        i_ext: Optional term added, scaled by ``st.amp_factor``, to the last
            ``st.dims[-1]`` columns of the right-hand side; ``None`` skips it.

    Returns:
        The energy as produced by the tensor expression above.
    """
    laplacian = st.laplacian()
    rhs = st.rhs(x)
    if i_ext is not None:
        n_out = st.dims[-1]
        # In-place update of the tensor returned by st.rhs(x).
        rhs[:, -n_out:] += i_ext * st.amp_factor
    quadratic = v @ laplacian @ v.T
    linear = v @ rhs
    return quadratic + linear + int_relu_f(v)

In [None]:
# Display the energy function object (it is not called here).
energy

In [None]:
# Use successive over-relaxation (SOR) style sweeps to solve the system,
# starting from (-1.5, -1.5) and drawing each iterate on the current figure.
v_init = torch.tensor([-1.5, -1.5]).float()
v_prev = v_init
plt.plot(v_init[0], v_init[1], "ro", alpha=0.5)
omega = 0.5  # relaxation factor
for i in range(10):
    # Drop any singleton batch dimension BEFORE splitting the Jacobian.
    # torch.diagonal defaults to dim1=0, dim2=1, so on a (1, n, n) tensor it
    # would return the first row rather than the matrix diagonal.
    A = st.jacobian(v_prev).squeeze()
    L = torch.tril(A, -1)  # strictly lower part
    D = torch.diagonal(A)  # diagonal entries as a vector
    U = torch.triu(A, 1)  # strictly upper part
    B = -st.residual(v_prev, inputs, None)
    # NOTE(review): textbook SOR reads
    #   (D + w L) x_new = w b - (w U + (w - 1) D) x_old,
    # i.e. the D term enters as (1 - w) D x_old. The expression below uses
    # (w - 1) and takes b = -residual; left unchanged -- confirm the intent.
    # torch.linalg.solve_ex replaces the private torch._linalg_solve_ex.
    res = torch.linalg.solve_ex(
        D.diag_embed() + omega * L,
        ((omega - 1) * (D * v_prev).T - omega * U @ v_prev.T + omega * B).T,
    )
    v = res.result.squeeze()
    print(v)
    # Later iterates are drawn more opaque.
    plt.plot(v[0], v[1], "ro", alpha=0.5 + 0.05 * i)
    # Arrow from the previous iterate to the new one.
    plt.arrow(
        v_prev[0],
        v_prev[1],
        v[0] - v_prev[0],
        v[1] - v_prev[1],
        head_width=0.01,
        head_length=0.01,
        fc="k",
        ec="k",
    )
    # Every fifth iterate gets a cross and a coordinate label.
    if i % 5 == 4:
        plt.plot(v[0], v[1], "rx")
        plt.text(v[0], v[1], f"({v[0]:.2f}, {v[1]:.2f})")
    v_prev = v

In [None]:
# Split the (squeezed) Jacobian at v_init into its strictly-lower part L,
# diagonal D and strictly-upper part U; B is the negated residual at v_init.
A = torch.squeeze(st.jacobian(v_init))
L = A.tril(-1)
D = A.diagonal()
U = A.triu(1)
B = st.residual(v_init, inputs, None).neg()

In [None]:
# Display D expanded into a diagonal matrix.
D.squeeze().diag_embed()

In [None]:
# Display the diagonal entries D.
D

In [None]:
# Display the (omega - 1) * D * v term of the update, evaluated at v_init.
(omega - 1) * D * v_init

In [None]:
# Display the left-hand-side matrix of the update: diag(D) + omega * L.
D.diag_embed() + omega * L

In [None]:
# Display the transposed negated residual B.
B.T

In [None]:
# Display the full right-hand side of the update, evaluated at the current v_prev.
(omega - 1) * D * v_prev.T - omega * U @ v_prev.T + omega * B.T

# XOR-dummy vs direct interleave

In [None]:
import torch

from src.core.eqprop.eqprop_util import init_params, interleave
from src.models.components.simple_dense_net import SimpleDenseNet

# Configure the project's interleave helper with 2 outputs.
# NOTE(review): presumably the number of classes the interleaved outputs fold down to -- confirm.
interleave.set_num_output(2)

In [None]:
def interleave_forward(x):
    """Fold interleaved output pairs into differences.

    Columns are read as consecutive pairs; each pair ``(x[:, 2k], x[:, 2k+1])``
    is reduced to ``x[:, 2k] - x[:, 2k+1]``, halving the second dimension.
    """
    even_cols = x[:, 0::2]
    odd_cols = x[:, 1::2]
    return even_cols - odd_cols

In [None]:
# Baseline dense network built without bias or batch norm.
# NOTE(review): cfg=[2, 10, 4] is presumably the list of layer widths -- confirm in SimpleDenseNet.
net1 = SimpleDenseNet(cfg=[2, 10, 4], bias=False, batch_norm=False)
# Apply the project's weight initialiser to the network.
net1.apply(init_params(min_w=1e-6, max_w_gain=0.08))

## instantiate XOR model

In [None]:
from hydra import compose, initialize
from hydra_zen import instantiate, store

from configs import register_everything

# Sets a private flag on the hydra_zen store.
# NOTE(review): presumably allows re-registering configs when this cell is re-run -- confirm.
store._overwrite_ok = True
# Command-line style overrides applied on top of the "train" config.
overrrides = [
    "experiment=ep-xor-dummy",
    "model.positive_w=false",
    "model.bias=false",
    "model.normalize_weights=false",
    "model.clip_weights=false",
    "model.scale_output=2",
    "model.scale_input=1",
]
# Register the project's configs before composing.
register_everything()
# config_path is given relative to this notebook's location.
with initialize(config_path="../../configs", version_base="1.3"):
    cfg = compose(config_name="train", return_hydra_config=True, overrides=overrrides)

# instantiate() returns a callable for the net here; the trailing () builds the network.
net2 = instantiate(cfg.model.net)()
# Build the data module and take its training dataloader.
dm = instantiate(cfg.data)
dm.setup()
dl = dm.train_dataloader()
print(net2)

In [None]:
# Train the dense baseline (net1) with cross-entropy on the folded outputs.
# optimizer2 is created for net2, but the net2 path is currently disabled.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net1.parameters(), lr=0.02, momentum=0.9)
optimizer2 = torch.optim.SGD(net2.parameters(), lr=0.02, momentum=0.9)
for epoch in range(200):
    # Log every batch of every 50th epoch.
    verbose = epoch % 50 == 0
    for x, y in dl:
        yhat = interleave_forward(net1(x))
        # yhat = net2(x)
        # accuracy = (yhat.argmax(dim=1) == y).float().mean()
        loss = criterion(yhat, y)
        loss.backward()
        # net2.eqprop(x)
        optimizer.step()
        optimizer.zero_grad()
        if not verbose:
            continue
        # .item() on y and the argmax requires single-sample batches.
        print(f"y: {y.item()}, yhat: {yhat.argmax(dim=1).item()}")
        print(loss.item())

# d