# Setup

In [None]:
import torch

# Seed the global RNG so the random inputs/targets below (and any randomly
# initialised layers created afterwards) are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Layer widths and batch size shared by every cell in this notebook.
input_size = 7
hidden_size = 5
output_size = 3
batch_size = 2

# x tracks gradients so the gradient w.r.t. the input can be inspected after backward().
x = torch.rand(batch_size, input_size, requires_grad=True)
# Target with output_size columns (used with the two-layer model).
y = torch.rand(batch_size, output_size)
# Target with hidden_size columns (used with the single-layer comparisons).
y2 = torch.rand(batch_size, hidden_size)

# PoC: one SGD step on a two-layer EqPropLinear network

In [None]:
import torch.nn as nn

from src.core.eqprop import nn as enn

# Two EqPropLinear layers chained input -> hidden -> output, optimised with plain SGD.
first_layer = enn.EqPropLinear(input_size, hidden_size)
second_layer = enn.EqPropLinear(hidden_size, output_size)
model = nn.Sequential(first_layer, second_layer)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

In [None]:
import torch.nn.functional as F

# One optimisation step on the two-layer model; prints the second layer's weight
# before and after the step, its gradient, and the gradient w.r.t. the input.
print(f"weight: {model[1].weight}")
logit = model(x)
y_hat = F.softmax(logit, dim=1)
loss = F.mse_loss(y_hat, y)
optimizer.zero_grad()
# optimizer.zero_grad() only clears the model's parameters. x is a leaf tensor that
# is not owned by the optimizer, so its .grad would keep accumulating every time
# this cell is re-run and the printed "input grad" would be a running sum.
x.grad = None
loss.backward()
print(model[1].weight.grad)
print(f"input grad: {x.grad}")
optimizer.step()
print(f"new weight: {model[1].weight}")

# Compare to AnalogEP2

In [None]:
from copy import deepcopy

from src._eqprop.eqprop_backbone import AnalogEP2

# model1: the new EqPropLinear layer. model2: the AnalogEP2 reference network,
# driven by an independent deep copy of model1's solver.
model1 = enn.EqPropLinear(input_size, hidden_size, bias=True)
solver = deepcopy(model1.solver)
# Reset the copied strategy's W / B / dims to empty lists before passing it on.
for strategy_attr in ("W", "B", "dims"):
    setattr(solver.strategy, strategy_attr, [])
model2 = AnalogEP2(
    cfg=[input_size, hidden_size],
    batch_size=batch_size,
    solver=solver,
    bias=True,
    scale_input=1,
    scale_output=1,
)

# Give model1 the same parameters as model2's first layer so both start identical.
reference_layer = model2.model[0]
model1.weight.data = reference_layer.weight.data.clone().detach()
model1.bias.data = reference_layer.bias.data.clone().detach()

In [None]:
# Forward/backward through the single EqPropLinear layer (model1) against target y2.
logit_1 = model1(x)
# Ask autograd to keep .grad on this output tensor after backward().
logit_1.retain_grad()
y_hat1 = F.softmax(logit_1, dim=1)
loss_1 = F.mse_loss(y_hat1, y2)
loss_1.backward()
# Same softmax + MSE loss computed through the AnalogEP2 reference model (model2).
logit_2 = model2(x)
y_hat2 = F.softmax(logit_2, dim=1)
loss_2 = F.mse_loss(y_hat2, y2)
loss_2.backward()
# NOTE(review): eqprop() is defined outside this notebook; presumably it populates
# the parameter gradients of model2 that are compared below — confirm.
model2.eqprop(x)
# Do both models produce the same softmax output?
print(torch.allclose(y_hat1, y_hat2))
# Cell output: do the weight gradients of model1 and model2's first layer agree?
torch.allclose(model1.weight.grad, model2.model[0].weight.grad)

In [None]:
# Fetch the two node buffers registered on model2's first layer, then have model1
# recompute its parameter gradients from them (after clearing the old gradients).
node_buffer_names = ("model.0.positive_node", "model.0.negative_node")
pnode, nnode = map(model2.get_buffer, node_buffer_names)
model1.zero_grad()
model1.calc_n_set_param_grad_(x, (pnode, nnode))

In [None]:
from src.utils.eqprop_utils import deltaV

# Batch-mean difference of squared deltaV between the negative and positive nodes,
# divided by the solver's beta. The last expression is displayed as the cell output.
sq_dv_negative = deltaV(x, nnode).pow(2)
sq_dv_positive = deltaV(x, pnode).pow(2)
(sq_dv_negative - sq_dv_positive).mean(0) / model1.solver.beta

# Bias validation

In [None]:
from src.core.eqprop import nn as enn
from src.utils.eqprop_utils import positive_param_init

# Intended check: a layer with a bias acting on x should behave like a bias-free
# layer acting on x extended by a constant column of ones.
model1 = enn.EqPropLinear(input_size, hidden_size)
model2 = enn.EqPropLinear(input_size + 1, hidden_size, bias=False)
model1.apply(positive_param_init)

# Append model1's bias as one extra weight column, and append a matching column
# of ones to the input; the widened weight is then pasted into model2.
w = model1.weight.data.clone().detach()
b = model1.bias.data.clone().detach()[:, None]
w_tilde = torch.cat([w, b], dim=1)
ones_column = torch.ones((x.shape[0], 1))
x_tilde = torch.cat([x, ones_column], dim=1)

model2.weight.data = w_tilde

In [None]:
import torch.nn.functional as F

# Compare the biased layer (model1 on x) with the bias-free layer on the extended
# input (model2 on x_tilde): outputs, weight gradients and bias gradients.
y2 = torch.rand(batch_size, hidden_size)

# Clear gradients left over from a previous run of this cell, so re-running it
# compares fresh gradients rather than accumulated sums.
model1.zero_grad()
model2.zero_grad()

logit_1 = model1(x)
# Ask autograd to keep .grad on this output tensor after backward().
logit_1.retain_grad()
y_hat1 = F.softmax(logit_1, dim=1)
loss_1 = F.mse_loss(y_hat1, y2)
loss_1.backward()
logit_2 = model2(x_tilde)
y_hat2 = F.softmax(logit_2, dim=1)
loss_2 = F.mse_loss(y_hat2, y2)
loss_2.backward()
print(torch.allclose(y_hat1, y_hat2))

# Split model2's weight gradient into the original weight columns and the extra
# bias column.
w_grad, b_grad = torch.split(model2.weight.grad.clone(), [input_size, 1], dim=1)
# The bias column comes out as (hidden_size, 1) while model1.bias.grad is
# (hidden_size,). torch.allclose broadcasts its arguments, so without this squeeze
# every bias gradient would be compared against every other one.
b_grad = b_grad.squeeze(1)
print(torch.allclose(model1.weight.grad, w_grad, atol=1e-4))
torch.allclose(model1.bias.grad, b_grad, atol=1e-4)

In [None]:
# Same bias-gradient comparison with a tighter tolerance. b_grad is a slice of a
# 2-D weight gradient and may be a (hidden_size, 1) column; reshape it to the bias
# gradient's shape so torch.allclose compares element-wise instead of broadcasting
# the two operands to a (hidden_size, hidden_size) grid.
torch.allclose(model1.bias.grad, b_grad.reshape(model1.bias.grad.shape), atol=1e-6)