In [12]:
import torch

# Dummy fixed point function: F = f(F, x)
def fixed_point_map(F, x):
    # F* = x + sin(F)
    return x + torch.sin(F)

# Fake equilibrium solver (just 1 iteration for simplicity)
def equilibrium(f, guess, params):
    x = params[0].detach()
    F = f(guess, x)
    return F  # simulate that this came from a rootfinder with no tracking

# Inputs
x = torch.tensor([2.0], requires_grad=True)
F_guess = torch.tensor([0.0]).detach()  # doesn't require grad

# Run "SCF" solver
F_star = equilibrium(fixed_point_map, F_guess, params=[x])
print("F_star requires_grad?", F_star.requires_grad)
print("F_star grad_fn:", F_star.grad_fn)

# --- Case 1: E(F_star) directly → fails
E1 = (F_star**2).sum()
try:
    torch.autograd.grad(E1, x)
except RuntimeError as e:
    print("\n[Case 1] No reconnection → autograd fails:", e)

# --- Case 2: reconnect F_star to x
F_connected = F_star + x  # reconnect using a dummy op
E2 = (F_connected**2).sum()
grad2 = torch.autograd.grad(E2, x)[0]
print("\n[Case 2] After reconnecting via F_star * x:")
print("∂E/∂x =", grad2.item())


F_star requires_grad? False
F_star grad_fn: None

[Case 1] No reconnection → autograd fails: element 0 of tensors does not require grad and does not have a grad_fn

[Case 2] After reconnecting via F_star * x:
∂E/∂x = 8.0


In [16]:
import torch

# Define custom autograd Function
class CustomEquilibrium(torch.autograd.Function):
    @staticmethod
    def forward(ctx, F_guess, x):
        # Simulate fixed point iteration: F* = x + sin(F)
        with torch.no_grad():
            F = x + torch.sin(F_guess)
        ctx.save_for_backward(F, x)
        return F

    @staticmethod
    def backward(ctx, grad_output):
        F, x = ctx.saved_tensors

        # f(F, x) = x + sin(F)
        # ∂f/∂F = cos(F)
        # ∂f/∂x = 1

        dfdF = torch.cos(F)
        dfdx = torch.ones_like(x)

        # Apply implicit differentiation:
        # dF*/dx = (I - ∂f/∂F)^(-1) * ∂f/∂x
        # Since everything is scalar here, it's just:
        dFdx = 1.0 / (1.0 - dfdF) * dfdx

        grad_x = grad_output * dFdx # dsmthg/dF * dF/dx
        return None, grad_x  # no grad for F_guess

# Inputs
x = torch.tensor([2.0], requires_grad=True).detach()
F_guess = torch.tensor([0.0]).detach()  # doesn't require grad

# Run equilibrium with custom backward
F_star = CustomEquilibrium.apply(F_guess, x)
print("F_star requires_grad?", F_star.requires_grad)
print("F_star grad_fn:", F_star.grad_fn)

# --- Case: E(F_star)
E = (F_star**2).sum()
grad_x = torch.autograd.grad(E, x)[0]
print("\nGradient ∂E/∂x =", grad_x.item())


F_star requires_grad? False
F_star grad_fn: None


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn