In [13]:
# for compatibility with CLRS, we advise using Python 3.9 or later
# install required packages (only need to run once)
#!pip install -e .

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from ltfs.common.logic import less_than, if_else

def _flatten_grads(grads, params):
    """Concatenate ``grads`` into one 1-D tensor, substituting zeros where a grad is ``None``.

    ``torch.autograd.grad(..., allow_unused=True)`` returns ``None`` for inputs the
    output does not depend on; filling those slots with ``zeros_like(param)`` keeps the
    flattened layout aligned with ``params`` in every case.
    """
    return torch.cat([
        (g if g is not None else torch.zeros_like(p)).reshape(-1)
        for g, p in zip(grads, params)
    ])


def compute_hessian(param, closure_loss):
    """Return the dense Hessian of ``closure_loss`` w.r.t. the trainable tensors in ``param``.

    Args:
        param: iterable of tensors. Only those with ``requires_grad=True`` are used,
            in iteration order.
        closure_loss: scalar tensor whose autograd graph reaches those tensors. The
            graph must still be alive (e.g. ``backward(retain_graph=True)`` if
            ``backward()`` was already called on it); this function retains it.

    Returns:
        Detached ``(n, n)`` tensor, where ``n`` is the total number of elements of the
        trainable tensors, ordered like ``torch.cat([p.reshape(-1) for p in trainable])``.
        It has the gradients' dtype and device. Second derivatives that are
        structurally zero (e.g. a parameter that enters the loss linearly, or not at
        all) are returned as 0 instead of raising.
    """
    params = [p for p in param if p.requires_grad]
    if not params:
        return torch.zeros(0, 0)

    # create_graph=True so the gradient itself can be differentiated a second time.
    grads = torch.autograd.grad(closure_loss, params, create_graph=True, allow_unused=True)
    flat_grad = _flatten_grads(grads, params)
    n = flat_grad.numel()

    rows = []
    for grad_i in flat_grad:
        if not grad_i.requires_grad:
            # The first derivative is constant in every parameter: this row is zero.
            rows.append(flat_grad.new_zeros(n))
            continue
        # retain_graph=True: the same graph is reused for every row.
        row = torch.autograd.grad(grad_i, params, retain_graph=True, allow_unused=True)
        rows.append(_flatten_grads(row, params))
    return torch.stack(rows).detach()

class NewtonsMethod(optim.Optimizer):
    """Damped Newton optimizer: ``p <- p - lr * (H + damping * I)^{-1} g`` per param group.

    ``H`` is the dense Hessian of the closure's loss w.r.t. the group's trainable
    parameters (from ``compute_hessian``), and ``g`` is the concatenation of their
    ``.grad`` fields. The closure must therefore call ``backward(retain_graph=True)``
    on the loss it returns, so the graph is still available for the Hessian.
    The Hessian is an ``n x n`` dense matrix (``n`` = number of trainable scalars),
    so this is only practical for very small models.

    Args:
        params: iterable of parameters or param-group dicts.
        lr: step size applied to the Newton direction (must be >= 0).
        damping: multiple of the identity added to ``H`` before solving, keeping the
            system solvable when ``H`` is singular or ill-conditioned (must be >= 0).
    """

    def __init__(self, params, lr=1e-3, damping=1e-2):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if damping < 0.0:
            raise ValueError(f"Invalid damping value: {damping}")
        super().__init__(params, {'lr': lr, 'damping': damping})

    def step(self, closure=None):
        """Perform one damped Newton step and return the loss produced by ``closure``.

        Raises:
            RuntimeError: if ``closure`` is None; the Hessian needs the loss graph.
        """
        if closure is None:
            raise RuntimeError(
                "NewtonsMethod.step requires a closure that recomputes the loss, "
                "calls backward(retain_graph=True) and returns the loss"
            )
        loss = closure()

        for group in self.param_groups:
            params = [p for p in group['params'] if p.requires_grad]
            if not params:
                continue
            lr = group['lr']
            damping = group['damping']

            # A missing .grad means the parameter did not contribute to the loss.
            flat_grad = torch.cat([
                (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
                for p in params
            ])
            # Align device/dtype with the gradient so the solve below is well-defined.
            hessian = compute_hessian(params, loss).to(device=flat_grad.device, dtype=flat_grad.dtype)
            hessian = hessian + damping * torch.eye(
                hessian.size(0), device=hessian.device, dtype=hessian.dtype
            )
            # Solve (H + damping*I) d = g rather than forming the inverse explicitly.
            direction = torch.linalg.solve(hessian, flat_grad)

            with torch.no_grad():
                offset = 0
                for p in params:
                    numel = p.numel()
                    p.add_(direction[offset:offset + numel].reshape_as(p), alpha=-lr)
                    offset += numel

        return loss

def distance2ref(layer, layer_ref):
    """Sum of per-parameter norms ``||p - p_ref||`` between two modules.

    Parameters are paired in ``.parameters()`` order, so both modules must share the
    same architecture. This is a monitoring metric, so it runs under
    ``torch.no_grad()`` and builds no autograd graph.

    Returns:
        A 0-dim tensor (or the int ``0`` when the modules have no parameters).

    Raises:
        ValueError: if the modules expose a different number of parameters or a
            paired parameter has a different shape. Without this check, ``zip`` would
            silently drop extras and subtraction would broadcast.
    """
    params = list(layer.parameters())
    ref_params = list(layer_ref.parameters())
    if len(params) != len(ref_params):
        raise ValueError(
            f"parameter count mismatch: {len(params)} vs {len(ref_params)}"
        )
    for p, p_ref in zip(params, ref_params):
        if p.shape != p_ref.shape:
            raise ValueError(f"parameter shape mismatch: {tuple(p.shape)} vs {tuple(p_ref.shape)}")
    with torch.no_grad():
        return sum(torch.norm(p - p_ref) for p, p_ref in zip(params, ref_params))

SEED = 0
torch.manual_seed(SEED)  # make X and the parameter perturbations in later cells reproducible

d = 12  # input dimension
# Column layout passed to `if_else`: one column per named scalar, then a block for S.
# The layout is a placeholder the layer constructor requires; X itself is random.
index = {
    "X": slice(0, 1),
    "Y": slice(1, 2),
    "Z": slice(2, 3),
    "W": slice(3, 4),
    "B_global": slice(4, 5),
    "B_local": slice(5, 6),  # its own column; S occupies columns 6..11
    "S": slice(6, 12),
}
assert sorted(i for s in index.values() for i in range(d)[s]) == list(range(d)), \
    "index slices must cover each of the d input columns exactly once"
X = torch.randn(20000, d)

### Case 1: Using a first-order optimizer (Adam)

In [16]:
# Build two layers from the same constructor call, nudge one by ~`distance` per
# entry, and try to pull it back onto the reference with a first-order optimizer.
distance = 1e-8
omega = 5e1
layer = if_else(d, index, ["X"], ["Y"], "Z", ["W"], ifINF=omega)[0]
layer_ref = if_else(d, index, ["X"], ["Y"], "Z", ["W"], ifINF=omega)[0]

# The regression target never changes, so compute it once without autograd.
# Computing it with a live graph would also push every backward() into layer_ref's
# parameters (if they require grad), accumulating gradients nobody uses.
with torch.no_grad():
    y_ref = layer_ref(X)

with torch.no_grad():
    for p in layer.parameters():
        p.add_(torch.randn_like(p) * distance)
for p in layer.parameters():
    p.requires_grad_(True)

# NOTE(review): weight_decay=1e-3 adds an L2 pull toward 0, not toward layer_ref,
# which may explain the growing distance below; Case 2 has no such term — confirm
# this is intended for the comparison.
optimizer = optim.Adam(layer.parameters(), lr=1e-8, weight_decay=1e-3)

for it in range(1000):
    optimizer.zero_grad()
    loss = F.mse_loss(layer(X), y_ref)
    # Iteration 0 only records the loss at the perturbed starting point.
    if it > 0:
        loss.backward()
        optimizer.step()

    if it % 100 == 0:
        dist = distance2ref(layer, layer_ref)
        print(f"iter {it}, loss {loss.item()}, distance {dist}")
        if dist < 1e-8:
            break

iter 0, loss 1.0344297701261573e-13, distance 4.771546425590996e-07
iter 100, loss 4.433348204268193e-13, distance 1.4358682932081695e-05
iter 200, loss 1.314595006619079e-12, distance 3.0306186728547305e-05
iter 300, loss 2.533690824605958e-12, distance 4.679554079239725e-05
iter 400, loss 4.067522577782731e-12, distance 6.363722684239774e-05
iter 500, loss 5.900021529996412e-12, distance 8.073097445332383e-05
iter 600, loss 8.022824302043332e-12, distance 9.801216503501996e-05
iter 700, loss 1.0430615599977466e-11, distance 0.00011545012658027968
iter 800, loss 1.3119908758432733e-11, distance 0.00013301940722362075
iter 900, loss 1.609022122627473e-11, distance 0.00015068731425805197


### Case 2: Using a second-order method (Newton's method)

In [19]:
# Same setup as Case 1: two layers from the same constructor call, one perturbed.
distance = 1e-8
omega = 5e1
layer = if_else(d, index, ["X"], ["Y"], "Z", ["W"], ifINF=omega)[0]
layer_ref = if_else(d, index, ["X"], ["Y"], "Z", ["W"], ifINF=omega)[0]
# Fixed regression target, computed without autograd so the closure's
# backward(retain_graph=True) never flows into (and accumulates on) layer_ref.
with torch.no_grad():
    y = layer_ref(X)

with torch.no_grad():
    for p in layer.parameters():
        p.add_(torch.randn_like(p) * distance)
for p in layer.parameters():
    p.requires_grad_(True)

def closure(params=None, i=None):
    """Recompute the MSE loss of ``layer(X)`` against ``y``, backprop it, and return it.

    Passed to ``optimizer.step``, which calls it with no arguments. If both ``params``
    and ``i`` are given, parameter ``i`` of ``layer`` is temporarily replaced by
    ``params`` for this evaluation. The original values are restored before
    returning, even if the forward or backward pass raises.

    ``backward`` keeps the graph (``retain_graph=True``) because the Newton step
    differentiates the returned loss again to build the Hessian.
    """
    saved = None
    if params is not None and i is not None:
        # Snapshot parameter values only when something is actually overwritten.
        saved = [p.detach().clone() for p in layer.parameters()]
        list(layer.parameters())[i].data = params.data.clone()

    try:
        optimizer.zero_grad()
        output = layer(X)
        loss = F.mse_loss(output, y)
        loss.backward(retain_graph=True)
    finally:
        if saved is not None:
            for p, orig in zip(layer.parameters(), saved):
                p.data = orig

    return loss

optimizer = NewtonsMethod(layer.parameters(), lr=1e-2, damping=1e-5)

for it in range(1000):
    # closure() zeroes the gradients itself before calling backward().
    loss = optimizer.step(closure)
    if it % 100 == 0:
        dist = distance2ref(layer, layer_ref)
        print(f"iter {it}, loss {loss.item()}, distance {dist}")
        if dist < 1e-8:
            break

iter 0, loss 1.919412982515624e-13, distance 4.7118882865253995e-07
iter 100, loss 2.5716272498403237e-14, distance 3.4379886766804754e-07
iter 200, loss 3.445463288926409e-15, distance 3.166777272868899e-07
iter 300, loss 4.616228258858022e-16, distance 3.1211491451722303e-07
iter 400, loss 6.184823963151172e-17, distance 3.113663895159845e-07
iter 500, loss 8.286444885022362e-18, distance 3.1120533682987424e-07
iter 600, loss 1.1102318896171125e-18, distance 3.111434044693802e-07
iter 700, loss 1.4875860031202636e-19, distance 3.111140009591099e-07
iter 800, loss 1.9933300266410284e-20, distance 3.111031116586465e-07
iter 900, loss 2.673735190333543e-21, distance 3.110958579090009e-07
