In [1]:
import sys
from pyprojroot import here as project_root

sys.path.insert(0, str(project_root()))

In [2]:
import torch
from fs_mol.utils.cauchy_hypergradient import cauchy_hypergradient
from fs_mol.utils.cauchy_hypergradient_jvp import cauchy_hypergradient_jvp

In [3]:
def f_outer(params_outer, params_inner):
    """Sum of squares of every element of every tensor in both parameter tuples.

    Returns the Python float 0.0 when both tuples are empty, otherwise a
    scalar tensor.
    """
    squared_norms = (
        (param**2).sum()
        for group in (params_outer, params_inner)
        for param in group
    )
    # Start from 0.0 (not the int 0) to match the float accumulator semantics.
    return sum(squared_norms, 0.0)

In [4]:
# Fix the RNG so a fresh "Restart & Run All" reproduces every random tensor
# drawn in this notebook (this is the first cell that draws any).
SEED = 0
torch.manual_seed(SEED)

a = torch.randn(3, 4, requires_grad=True)  # passed below as the outer parameter
b = torch.randn(5, requires_grad=True)     # passed below as the inner parameter

# Nothing has been backpropagated yet, so both gradients are still None.
a.grad, b.grad

(None, None)

In [5]:
# Outer and inner objectives are the same sum-of-squares function here.
cauchy_hypergradient(
    f_outer=f_outer,
    f_inner=f_outer,
    params_outer=(a,),
    params_inner=(b,),
    device=torch.device("cpu"),
)

tensor(19.8681, grad_fn=<AddBackward0>)

In [6]:
# Gradients left on a and b by the cauchy_hypergradient call above.
(a.grad, b.grad)

(tensor([[-0.8843, -3.3354,  1.8506, -0.0197],
         [ 0.8465, -1.0572,  0.3310, -1.2549],
         [-3.7115, -2.1355, -2.2661,  2.6281]]),
 tensor([-3.6669, -1.0069,  1.1539,  1.7742, -3.3623]))

In [7]:
# Clear gradients so the JVP variant below starts from a clean state.
a.grad = None
b.grad = None

In [8]:
# Same problem as above, now through the JVP-based implementation.
# Outer and inner objectives are the same sum-of-squares function here.
cauchy_hypergradient_jvp(
    f_outer=f_outer,
    f_inner=f_outer,
    params_outer=(a,),
    params_inner=(b,),
    device=torch.device("cpu"),
)

tensor(19.8681, grad_fn=<AddBackward0>)

In [9]:
# Gradients from the JVP variant; compare with the cauchy_hypergradient ones above.
(a.grad, b.grad)

(tensor([[-0.8843, -3.3354,  1.8506, -0.0197],
         [ 0.8465, -1.0572,  0.3310, -1.2549],
         [-3.7115, -2.1355, -2.2661,  2.6281]]),
 tensor([-3.6669, -1.0069,  1.1539,  1.7742, -3.3623]))

In [10]:
def f_outer(params_outer, params_inner):
    """Quadratic objective a*x**2 + b*x + c.

    params_outer: the three coefficients (a, b, c).
    params_inner: a 1-tuple holding x.
    """
    coef_a, coef_b, coef_c = params_outer
    (x_val,) = params_inner
    return coef_a * x_val**2 + coef_b * x_val + coef_c


# The inner problem minimizes the very same quadratic over x.
f_inner = f_outer

In [11]:
# Random scalar coefficients for a*x**2 + b*x + c, plus a starting point x.
a = torch.exp(torch.randn(1))[0]  # exp(.) keeps a > 0, so the quadratic is convex in x
b = torch.randn(1)[0]
c = torch.randn(1)[0]
x = torch.randn(1)[0]

params_outer = (a, b, c)
params_inner = (x,)
for param in (*params_outer, *params_inner):
    param.requires_grad_(True)

In [12]:
# Minimize w.r.t. x analytically: for a > 0 the minimizer is x* = -b / (2a).
with torch.no_grad():
    x.fill_(- b / 2 / a)

# Confirm (with an assertion, not by eye) that df/dx vanishes at x*.
# atol absorbs float32 rounding in 2*a*x* + b.
loss = f_inner(params_outer, params_inner)
loss.backward()
assert torch.allclose(x.grad, torch.zeros_like(x), atol=1e-5), x.grad

# backward() also wrote direct gradients into a, b, c; drop them so the
# hypergradient cells below do not depend on whether they overwrite or accumulate.
for param in params_outer:
    param.grad = None

x.grad

tensor(0.)

In [13]:
# Hypergradient of f_outer w.r.t. (a, b, c), with x placed at the inner optimum above.
cauchy_hypergradient(
    f_outer=f_outer,
    f_inner=f_inner,
    params_outer=params_outer,
    params_inner=params_inner,
    device=torch.device("cpu"),
)

tensor(0.5018, grad_fn=<AddBackward0>)

In [14]:
# Hypergradient components w.r.t. (a, b, c).
tuple(param.grad for param in params_outer)

(tensor(0.0505), tensor(0.2246), tensor(1.))

In [15]:
# Analytic hypergradient. With x* = -b/(2a) the optimal value is
# F(a, b, c) = c - b**2/(4a), hence
#   dF/da = b**2/(4a**2),  dF/db = -b/(2a),  dF/dc = 1.
with torch.no_grad():
    expected_hypergrad = (
        b**2/4/a**2,
        -b/2/a,
        1.0,
    )

# Check (rather than eyeball) that the gradients written by the
# cauchy_hypergradient call above match the analytic values.
for param, want in zip(params_outer, expected_hypergrad):
    assert torch.allclose(
        param.grad,
        torch.as_tensor(want, dtype=param.grad.dtype),
        rtol=1e-4,
        atol=1e-6,
    ), (param.grad, want)

expected_hypergrad

(tensor(0.0505), tensor(0.2246), 1.0)


In [16]:
# Clear every gradient before trying the JVP variant.
a.grad = None
b.grad = None
c.grad = None
x.grad = None

In [17]:
# Same hypergradient, computed with the JVP-based implementation.
cauchy_hypergradient_jvp(
    f_outer=f_outer,
    f_inner=f_inner,
    params_outer=params_outer,
    params_inner=params_inner,
    device=torch.device("cpu"),
)

tensor(0.5018, grad_fn=<AddBackward0>)

In [18]:
# JVP hypergradient components w.r.t. (a, b, c).
tuple(param.grad for param in params_outer)

(tensor(0.0505), tensor(0.2246), tensor(1.))

In [19]:
# With the direct gradient ignored, only the "residual" (indirect) part remains;
# it should be 0 because df/dx = 0 at the inner optimum.
cauchy_hypergradient(
    f_outer=f_outer,
    f_inner=f_inner,
    params_outer=params_outer,
    params_inner=params_inner,
    ignore_direct_grad=True,
    device=torch.device("cpu"),
)

tensor(0.5018, grad_fn=<AddBackward0>)

In [20]:
# Residual hypergradient components w.r.t. (a, b, c); expected to be zero.
tuple(param.grad for param in params_outer)

(tensor(0.), tensor(0.), tensor(0.))

In [21]:
# Same residual check, now through the JVP-based implementation.
cauchy_hypergradient_jvp(
    f_outer=f_outer,
    f_inner=f_inner,
    params_outer=params_outer,
    params_inner=params_inner,
    ignore_direct_grad=True,
    device=torch.device("cpu"),
)

tensor(0.5018, grad_fn=<AddBackward0>)

In [22]:
# JVP residual hypergradient components w.r.t. (a, b, c); expected to be zero.
tuple(param.grad for param in params_outer)

(tensor(0.), tensor(0.), tensor(0.))

In [23]:
import numpy as np

In [24]:
def f_inner(params_outer, params_inner):
    """Mean of two summed elementwise quadratics sharing coefficients (a, b, c).

    Each of x1 and x2 is plugged into a*x**2 + b*x + c, the result is summed
    over elements, and the two sums are averaged.
    """
    a, b, c = params_outer
    x1, x2 = params_inner

    def quad_sum(x):
        # Summed elementwise quadratic with the shared coefficients.
        return torch.sum(a*(x**2) + b*x + c)

    return (quad_sum(x1) + quad_sum(x2)) / 2

def f_outer(params_outer, params_inner):
    """Half the sum of a * (x1 + x2).

    Only the first outer parameter ``a`` enters the value; ``b`` and ``c`` are
    unpacked but unused (the unpacking still requires exactly three outer params).
    """
    a, _b, _c = params_outer
    x1, x2 = params_inner
    weighted = a * (x1 + x2)
    return weighted.sum() / 2

In [25]:
# Randomized test of cauchy_hypergradient on a vector-valued problem where the
# indirect term matters. With x1* = x2* = -b/(2a), the outer value is
# F(a, b, c) = sum(a * 2x*)/2 = -sum(b)/2, so the exact hypergradient is
#   dF/da = 0,  dF/db = -1/2,  dF/dc = 0.
N_TRIALS = 100
D = 3

for _ in range(N_TRIALS):
    a = torch.exp(torch.randn(D))  # ensure positive
    b = torch.randn(D)
    c = torch.randn(D)
    x1 = torch.randn(D)
    x2 = torch.randn(D)

    params_outer = (a, b, c)
    params_inner = (x1, x2)
    for param in (*params_outer, *params_inner):
        param.requires_grad_(True)

    # Minimize w.r.t. x, in this case done analytically.
    with torch.no_grad():
        x_star = -b / 2 / a
        x1.copy_(x_star)
        x2.copy_(x_star)

    # Confirm that the x gradients are (numerically) 0 here. The original cell
    # evaluated `x.grad`, which referred to the stale scalar `x` from an earlier cell.
    loss = f_inner(params_outer, params_inner)
    loss.backward()
    assert torch.allclose(x1.grad, torch.zeros_like(x1), atol=1e-5), x1.grad
    assert torch.allclose(x2.grad, torch.zeros_like(x2), atol=1e-5), x2.grad

    # Drop the gradients backward() left behind so only the hypergradient is checked.
    for param in (*params_outer, *params_inner):
        param.grad = None

    cauchy_hypergradient(
        f_outer=f_outer,
        f_inner=f_inner,
        params_outer=params_outer,
        params_inner=params_inner,
        device=torch.device("cpu")
    )

    assert np.allclose(a.grad.numpy(), 0.0), a.grad
    assert np.allclose(b.grad.numpy(), -0.5), b.grad
    assert np.allclose(c.grad.numpy(), 0.0), c.grad

In [26]:
# Randomized test of cauchy_hypergradient_jvp on the same vector-valued problem.
# With x1* = x2* = -b/(2a), the outer value is F(a, b, c) = -sum(b)/2, so the
# exact hypergradient is dF/da = 0, dF/db = -1/2, dF/dc = 0.
N_TRIALS = 100
D = 3

for _ in range(N_TRIALS):
    a = torch.exp(torch.randn(D))  # ensure positive
    b = torch.randn(D)
    c = torch.randn(D)
    x1 = torch.randn(D)
    x2 = torch.randn(D)

    params_outer = (a, b, c)
    params_inner = (x1, x2)
    for param in (*params_outer, *params_inner):
        param.requires_grad_(True)

    # Minimize w.r.t. x, in this case done analytically.
    with torch.no_grad():
        x_star = -b / 2 / a
        x1.copy_(x_star)
        x2.copy_(x_star)

    # Confirm that the x gradients are (numerically) 0 here. The original cell
    # evaluated `x.grad`, which referred to the stale scalar `x` from an earlier cell.
    loss = f_inner(params_outer, params_inner)
    loss.backward()
    assert torch.allclose(x1.grad, torch.zeros_like(x1), atol=1e-5), x1.grad
    assert torch.allclose(x2.grad, torch.zeros_like(x2), atol=1e-5), x2.grad

    # Drop the gradients backward() left behind so only the hypergradient is checked.
    for param in (*params_outer, *params_inner):
        param.grad = None

    cauchy_hypergradient_jvp(
        f_outer=f_outer,
        f_inner=f_inner,
        params_outer=params_outer,
        params_inner=params_inner,
        device=torch.device("cpu")
    )

    assert np.allclose(a.grad.numpy(), 0.0), a.grad
    assert np.allclose(b.grad.numpy(), -0.5), b.grad
    assert np.allclose(c.grad.numpy(), 0.0), c.grad