$u_\beta^* = \operatorname*{arg\,min}_u \left[ E(u) + \beta \mathcal{L}(u) \right]$

Questions: how does the shift $u_\beta^* - u_0^* = f(\beta)$ depend on $\beta$, and is $\lVert f(0.01) \rVert > \epsilon$?

# setup model & data

In [None]:
# instantiate model
from hydra import compose, initialize
from hydra_zen import instantiate, store

from configs import register_everything

# Build the network and the training dataloader from the project's Hydra config.

# Sets a private flag on hydra-zen's store.
# NOTE(review): presumably this lets configs be re-registered when the cell is
# re-run without raising on duplicates - confirm against hydra-zen's ZenStore.
store._overwrite_ok = True
# Hydra command-line style overrides: select the "ep-xor-onehot" experiment
# and set model.net.bias to false.
# NOTE(review): the variable name is misspelled ("overrrides"); kept unchanged
# in case other cells reference it.
overrrides = [
    "experiment=ep-xor-onehot",
    "model.net.bias=false",
]
# Project helper from `configs`; presumably registers the project's configs
# with the store so that compose() can find them - TODO confirm.
register_everything()
# config_path is a relative path to the config directory.
with initialize(config_path="../../configs", version_base="1.3"):
    # Compose the "train" config with the overrides above; the hydra node is
    # kept in the result (return_hydra_config=True).
    cfg = compose(config_name="train", return_hydra_config=True, overrides=overrrides)

# Instantiate the network and the data object described by the composed config.
net2 = instantiate(cfg.model.net)
dm = instantiate(cfg.data)
# Prepare the data object, then fetch the training dataloader used by later cells.
dm.setup()
dl = dm.train_dataloader()
print(net2)

In [None]:
# Dump the model section of the composed config as YAML for inspection.
from omegaconf import OmegaConf

model_cfg_yaml = OmegaConf.to_yaml(cfg.model)
print(model_cfg_yaml)

In [None]:
import torch

# Load the model checkpoint from a previous training run.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
# map_location="cpu" makes the load work on machines without the device the
# checkpoint was saved from (e.g. a GPU checkpoint on a CPU-only box); the
# tensors are copied into the model's own parameters later anyway.
ckpt = torch.load(
    "/root/workspace/ml/logs/train/runs/2024-05-07_23-21-47/checkpoints/last.ckpt",
    map_location="cpu",
)
# Reference: state-dict keys of the individual layer parameters.
# get weights from lin1 & last layers
# w1 = ckpt["state_dict"]["net.model.0.weight"]
# w2 = ckpt["state_dict"]["net.model.1.weight"]
# get biases from lin1 & last layers
# b1 = ckpt["state_dict"]["net.model.0.bias"]
# b2 = ckpt["state_dict"]["net.model.1.bias"]
# TODO: get input & output dimensions

In [None]:
# Show every entry of the checkpoint's state dict (key followed by its value).
state = ckpt["state_dict"]
for key in state:
    print(f"name:{key}, {state[key]}")

In [None]:
# Copy the checkpoint weights into net2.
# The checkpoint keys carry a leading "net." (e.g. "net.model.0.weight"), which
# must be dropped so they line up with net2's own parameter names.
from collections import OrderedDict

prefix = "net."
new_state_dict = OrderedDict()
for k, v in ckpt["state_dict"].items():
    # Strip only the leading prefix; str.replace would also rewrite any later
    # "net." occurrence inside the key (e.g. "...subnet.weight").
    name = k[len(prefix):] if k.startswith(prefix) else k
    new_state_dict[name] = v
# overwrite model weights from checkpoint
load_result = net2.load_state_dict(new_state_dict, strict=False)
# strict=False skips mismatched keys without raising, so report them: otherwise
# a naming mismatch would leave net2 on its initial weights unnoticed.
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"missing keys: {load_result.missing_keys}")
    print(f"unexpected keys: {load_result.unexpected_keys}")
# Reset the prediction attribute on the network before the next forward pass.
net2.ypred = None

# Train XOR

In [None]:
# Forward pass, then a nudge-phase solve, per batch: compare every layer state
# returned by the nudge-phase solver with that layer's "positive_node" buffer.
# NOTE(review): "positive_node" is presumably the free-phase state stored by
# the forward pass - confirm against the layer implementation.
import torch.nn as nn

criterion = nn.MSELoss()
# Row-wise cosine similarity; built once since it does not depend on the batch.
cosine = nn.CosineSimilarity(dim=1)
for x, y in dl:
    # Flatten each sample to a vector before feeding the network.
    x = x.view(x.shape[0], -1)
    ypred = net2(x)
    # Targets arrive as class indices; turn them into 2-class one-hot floats.
    y = torch.nn.functional.one_hot(y, 2).float()
    print(f"ypred: {ypred.data}, y: {y}")
    loss = criterion(ypred, y)
    loss.backward()
    # Re-run the solver in nudge mode, driven by the gradient stored on net2.ypred.
    u_nudge, _ = net2.solver(x, nudge_phase=True, grad=net2.ypred.grad)
    for layer_idx, u_nudged in enumerate(u_nudge):
        u_stored = net2.model[layer_idx].get_buffer("positive_node")
        gap = (u_nudged - u_stored).norm()
        similarity = cosine(u_nudged, u_stored)
        print(f"Layer {layer_idx} diff: {gap}, cos_sim: {similarity}")

In [None]:
# Display the first entry of u_nudge (left over from the last batch of the loop
# in the previous cell) as a Python scalar.
# NOTE(review): Tensor.item() only works on single-element tensors - confirm
# that u_nudge[0] holds exactly one value.
u_nudge[0].item()

In [None]:
# Display the "positive_node" buffer registered on net2.model[1].
net2.model[1].get_buffer("positive_node")