# Fold-opt Layer with PGD
1. Define PGD update step with analytical projection
2. Define solver `solve_closed_form`

In [106]:
import torch
from torch import nn
import math
import cvxpy as cp

In [107]:
import sys

sys.path.insert(0, 'E:\\User\\Stevens\\MyRepo\\FDFL\\helper')
# sys.path.insert(0, 'E:\\User\\Stevens\\Code\\Fold-opt\\fold_opt')
from myutil import *
from GMRES import *
from fold_opt import *

In [108]:
# ---------------- little utilities ---------------------------------
def to_col(x):              # (…,n)  ->  (…,n,1)
    return x.unsqueeze(-1) if x.dim() == 1 else x
def from_col(x):            # (…,n,1) -> (…,n)
    return x.squeeze(-1)

def from_numpy_torch(x):  # (n,) -> (1,n) -> (1,n,1)
    return torch.from_numpy(x).float().unsqueeze(0).unsqueeze(-1)
def to_torch_numpy(x):    # (1,n,1) -> (1,n) -> (n,)
    return x.squeeze(0).squeeze(-1).numpy()

# PGD

 -------------------------------------------------------------
 1.  Projection onto ${d \geq 0 , c^ {\top} d \leq Q}$
 
     Formula:  $\Pi(d) = [d - λ^* \cdot c]_+$  with
                $λ^* \geq 0$ chosen s.t. $c^{\top}\Pi(d)=Q$   (if budget violated)
 -------------------------------------------------------------


In [109]:
# import torch
# import numpy as np
# import time
# import cvxpy as cp

# # --- Your analytical projection (as given) ---
# def proj_budget(x, cost, Q, max_iter=60):
#     batched = x.dim() == 2
#     if not batched:
#         x = x.unsqueeze(0)
#     B, n = x.shape
#     cost = cost.to(x)
#     Q = torch.as_tensor(Q, dtype=x.dtype, device=x.device).reshape(-1, 1)

#     d = x.clamp(min=0.)
#     viol = (d @ cost) > Q.squeeze(1)

#     if viol.any():
#         dv, Qv = d[viol], Q[viol]
#         lam_lo = torch.zeros_like(Qv.squeeze(1))
#         lam_hi = (dv / cost).max(1).values

#         for _ in range(max_iter):
#             lam_mid = 0.5 * (lam_lo + lam_hi)
#             trial = (dv - lam_mid[:, None] * cost).clamp(min=0.)
#             too_big = (trial @ cost) > Qv.squeeze(1)
#             lam_lo[too_big]  = lam_mid[too_big]
#             lam_hi[~too_big] = lam_mid[~too_big]

#         d[viol] = (dv - lam_hi[:, None] * cost).clamp(min=0.)

#     return d if batched else d.squeeze(0)

# # --- Solver-based projection via CVXPY ---
# def cvx_proj(x_np, cost_np, Q):
#     n = x_np.shape[0]
#     d = cp.Variable(n)
#     obj = cp.Minimize( cp.sum_squares(d - x_np) )
#     cons = [d >= 0, cost_np.T @ d <= Q]
#     prob = cp.Problem(obj, cons)
#     _ = prob.solve(solver=cp.OSQP, warm_start=True)
#     return d.value

# # --- Timing routine ---
# def benchmark(sizes, repetitions=5):
#     results = []
#     for n in sizes:
#         # generate random problem
#         x   = torch.randn(n)
#         cost= torch.rand(n) + 0.1
#         Q   = float( n * cost.mean() * 0.5 )

#         # warm‐up
#         _ = proj_budget(x, cost, Q)
#         _ = cvx_proj(x.numpy(), cost.numpy(), Q)

#         # time analytical
#         t0 = time.perf_counter()
#         for _ in range(repetitions):
#             _ = proj_budget(x, cost, Q)
#         t_custom = (time.perf_counter() - t0) / repetitions

#         # time solver
#         t0 = time.perf_counter()
#         for _ in range(repetitions):
#             _ = cvx_proj(x.numpy(), cost.numpy(), Q)
#         t_solver = (time.perf_counter() - t0) / repetitions

#         results.append((n, t_custom, t_solver))

#     # print results
#     print(f"{'n':>8} | {'custom (s)':>12} | {'solver (s)':>12}")
#     print("-" * 38)
#     for n, tc, ts in results:
#         print(f"{n:8d} | {tc:12.6f} | {ts:12.6f}")

# if __name__ == "__main__":
#     sizes = [100, 1_000, 5_000, 10_000]
#     benchmark(sizes)


In [110]:
def proj_budget(x, cost, Q, max_iter=60):
    """
    x : (B,n)   or (n,)   –– internally promoted to (B,n)
    cost : (n,) positive
    Q : scalar or length‑B tensor
    """
    batched = x.dim() == 2
    if not batched:                       # (n,)  →  (1,n)
        x = x.unsqueeze(0)

    B, n = x.shape
    cost = cost.to(x)
    Q    = torch.as_tensor(Q, dtype=x.dtype, device=x.device).reshape(-1, 1)  # (B,1)

    d    = x.clamp(min=0.)                # enforce non‑neg
    viol = (d @ cost) > Q.squeeze(1)      # which rows violate the budget?

    if viol.any():
        dv, Qv = d[viol], Q[viol]
        lam_lo = torch.zeros_like(Qv.squeeze(1))
        lam_hi = (dv / cost).max(1).values   # upper bound for λ⋆

        for _ in range(max_iter):
            lam_mid = 0.5 * (lam_lo + lam_hi)
            trial   = (dv - lam_mid[:, None] * cost).clamp(min=0.)
            too_big = (trial @ cost) > Qv.squeeze(1)
            lam_lo[too_big] = lam_mid[too_big]
            lam_hi[~too_big]= lam_mid[~too_big]

        d[viol] = (dv - lam_hi[:, None] * cost).clamp(min=0.)

    return d if batched else d.squeeze(0)   # restore original rank

 -------------------------------------------------------------
###  One PGD update that *is differentiable* wrt both r & d
 -------------------------------------------------------------

In [111]:
def alpha_fair(u, alpha):
    if alpha == 1:
        return torch.log(u).sum(-1)
    elif alpha == 0:
        return u.sum(-1)
    elif alpha == 'inf':
        return u.min(-1).values
    return (u.pow(1-alpha)/(1-alpha)).sum(-1)

def pgd_step(r, d, g, cost, Q, alpha, lr):
    d = d.clone().requires_grad_(True)
    obj     = alpha_fair(d * r * g, alpha).sum()
    grad_d, = torch.autograd.grad(obj, d, create_graph=True)
    return proj_budget(d + lr * grad_d, cost, Q)


In [112]:
def closed_form_solver_torch(r, g, cost, alpha, Q):
    if r.dim() == 1:                      # (n,) → (1,n) before looping
        r = r.unsqueeze(0)
    out = []
    for r_i in r:
        d_np, _ = solve_closed_form(g.cpu().numpy(),
                                    r_i.detach().cpu().numpy(),
                                    cost.cpu().numpy(),
                                    alpha, Q)
        out.append(torch.as_tensor(d_np, dtype=r.dtype, device=r.device))
    return torch.stack(out)               # (B,n) even if B=1

 ------------------------------------------------------------
 ###  Fold‑Opt layer
 ------------------------------------------------------------

In [113]:
def make_foldopt_layer(g, cost, alpha, Q,
                       lr=1e-2, n_fixedpt=200, rule='GMRES'):
    g    = g.detach()
    cost = cost.detach()

    # -------- solver: no gradients flow ----------------------------
    def solver_fn(r):
        return closed_form_solver_torch(r, g, cost, alpha, Q)

    # -------- one differentiable PGD step --------------------------
    def update_fn(r, x_star, *_):
        # promote to (B,n) if needed
        if r.dim() == 1:       r = r.unsqueeze(0)
        if x_star.dim() == 1:  x_star = x_star.unsqueeze(0)

        g_b = g.expand_as(r) if g.dim() == 1 else g
        return pgd_step(r, x_star, g_b, cost, Q, alpha, lr)  # (B,n)

    return FoldOptLayer(solver_fn, update_fn,
                        n_iter=n_fixedpt, backprop_rule=rule)

In [114]:
r     = torch.tensor([1.,2.,3.,4.,5.], requires_grad=True)
print(r.shape)
r = r.unsqueeze(0)
r.shape

torch.Size([5])


torch.Size([1, 5])

In [115]:
if __name__ == "__main__":
    torch.manual_seed(0)

    # ---------- single input vector --------------------------------
    r     = torch.tensor([1.,2.,3.,4.,5.], requires_grad=True)
    g     = torch.ones_like(r)
    cost  = torch.ones_like(r)
    alpha = 2
    Q     = 2.0

    r_b   = r.unsqueeze(0)
    print(r_b.shape)       # (1,5)

    layer = make_foldopt_layer(g, cost, alpha, Q, lr=1e-2)

    d = layer(r_b)                # d is (1,5)
    d.sum().backward()   # OK: shapes now match inside GMRES

    print("d (no‑batch) :", d)
    print("∂d/∂r, shape  :", r.grad, r.grad.shape)   # -> (5,)

    J = compute_gradient_closed_form(g, r.detach(), cost, alpha, Q)   # now J is 5×5 = ∂d/∂r
    v = torch.ones(5)                                                 # same upstream as d.sum()
    vjp_closed = (torch.tensor(J.T) @ v).detach()                                   # shape (5,)



    r.grad.zero_()
    d = layer(r.unsqueeze(0))
    d.sum().backward()                 # uses the same “ones” upstream
    vjp_foldopt = r.grad               # shape (5,)
    print("closed VJP:", vjp_closed)
    print("fold‑opt VJP:", vjp_foldopt)
    print("inf norm diff:", (vjp_closed - vjp_foldopt).abs().max())



torch.Size([1, 5])
d (no‑batch) : tensor([[0.6189, 0.4376, 0.3573, 0.3094, 0.2768]],
       grad_fn=<FixedPtDiffGMRESBackward>)
∂d/∂r, shape  : tensor([ 0.0239, -0.1094, -0.0596, -0.0387, -0.0277]) torch.Size([5])
closed VJP: tensor([-7.4506e-09, -5.5879e-09, -7.4506e-09,  1.8626e-09, -1.8626e-09])
fold‑opt VJP: tensor([ 0.0239, -0.1094, -0.0596, -0.0387, -0.0277])
inf norm diff: tensor(0.1094)


In [116]:
alpha = 1.5

In [117]:
# 1) closed‑form VJP
J        = torch.tensor(compute_gradient_closed_form(g, r.detach(), cost, alpha, Q))  # (5×5)
v        = torch.randn(5)   # your test upstream gradient
vjp_cl   = J.T @ v          # (5,)

# 2) fold‑opt VJP
r.grad = None
d = layer(r.unsqueeze(0)).squeeze(0)
d.backward(v)               # use v instead of ones
vjp_fo = r.grad
print("closed VJP:", vjp_cl)
print("fold‑opt VJP:", vjp_fo)

# 3) cosine
cos = torch.dot(vjp_cl, vjp_fo) / (vjp_cl.norm()*vjp_fo.norm()+1e-12)
print("cosine similarity:", cos.item())


closed VJP: tensor([-0.3008,  0.0117,  0.0853, -0.0198,  0.0201])
fold‑opt VJP: tensor([-0.4898,  0.0321,  0.1297, -0.0220,  0.0300])
cosine similarity: 0.9993108510971069


In [118]:
# def my_solver(c):
#     return torch.clamp(c, min=0.0)

# def my_update_step(c, x):
#     alpha = 0.1
#     grad  = x - c
#     x_new = torch.clamp(x - alpha*grad, min=0.0)
#     return x_new

# fold_layer = FoldOptLayer(
#                 solver      = my_solver,
#                 update_step = my_update_step,
#                 n_iter      = 20,
#                 backprop_rule='Jacobianx')

# c      = torch.tensor([[1.0], [ 2.0]], requires_grad=True)   # (B=2, n=1)
# target = torch.tensor([[ 1], [ 1]])

# x_star = fold_layer(c)
# print("x* from FoldOptLayer:", x_star.squeeze().tolist())     # → [0.0, 2.0]

# loss = 0.5 * torch.sum((x_star - target) ** 2)
# loss.backward()

# print("Grad wrt c:", c.grad.squeeze().tolist())

In [119]:
r.grad

tensor([-0.4898,  0.0321,  0.1297, -0.0220,  0.0300])

In [120]:
# column sum of cl_grad
grad_l_d , cl_grad.sum(0) * 1e7

NameError: name 'grad_l_d' is not defined

In [None]:
J = torch.tensor(compute_gradient_closed_form(r, g, cost, alpha, Q))  # shape (5,5)
vjp_closed = J.sum(dim=0)    # shape (5,)
print("closed‑form VJP:", vjp_closed)
print("fold‑opt  VJP:", r.grad)  # from your backward()
print("difference ∞‑norm:", (vjp_closed - r.grad).abs().max())


closed‑form VJP: tensor([-7.4506e-09, -2.2352e-08, -6.7055e-08,  2.9802e-08, -8.1956e-08])
fold‑opt  VJP: tensor([ 0.0239, -0.1094, -0.0596, -0.0387, -0.0277])
difference ∞‑norm: tensor(0.1094)


In [None]:
v1 = vjp_closed
v2 = r.grad  # your fold‑opt VJP
cos_sim = torch.dot(v1, v2) / (v1.norm() * v2.norm() + 1e-12)
print("cosine sim =", cos_sim.item())

cosine sim = 0.48399463295936584
