In [None]:
import torch
import torch.autograd as autograd

def hessian_vector_product(h, w, v):
    """Return the Hessian-vector product ``(d^2 h / dw^2) v``.

    The product is obtained by double backpropagation: ``h`` is
    differentiated w.r.t. ``w`` with the graph kept, the gradient is
    contracted with ``v`` into a scalar, and that scalar is differentiated
    w.r.t. ``w`` again. The full Hessian is never materialised.

    Args:
        h: Scalar tensor whose autograd graph contains ``w``.
        w: Tensor (any shape) that requires grad.
        v: Tensor with the same shape as ``w``.

    Returns:
        Tensor with the same shape as ``w``. All zeros when the gradient of
        ``h`` does not depend on ``w`` (e.g. ``h`` is linear in ``w``).
    """
    grad_h = autograd.grad(h, w, create_graph=True, retain_graph=True)[0]
    # Elementwise product + sum equals torch.dot for vectors, but unlike
    # torch.dot it also accepts matrices and higher-rank parameters.
    grad_h_v = torch.sum(grad_h * v)
    if not grad_h_v.requires_grad:
        # The gradient is constant, so there is nothing to differentiate
        # again: the Hessian is identically zero.
        return torch.zeros_like(w)
    # retain_graph=True keeps the graph of ``h`` alive so callers can reuse
    # the same ``h`` for further products.
    hvp = autograd.grad(grad_h_v, w, retain_graph=True, allow_unused=True)[0]
    if hvp is None:
        # grad_h depends on other tensors but not on ``w``.
        return torch.zeros_like(w)
    return hvp

def conjugate_gradient(A_hvp_func, b, n_iter=10, epsilon=1e-8):
    """Approximately solve ``A x = b`` with the conjugate gradient method.

    ``A`` is never formed explicitly; it is accessed only through
    matrix-vector products. The iteration starts from ``x = 0``.

    Args:
        A_hvp_func: Callable mapping a tensor ``p`` (shaped like ``b``) to
            ``A p``.
        b: Right-hand side tensor (any shape).
        n_iter: Maximum number of CG iterations.
        epsilon: Stop as soon as the residual 2-norm drops below this value.

    Returns:
        Tensor shaped like ``b`` holding the approximate solution.
    """
    x = torch.zeros_like(b)
    r = b.clone()
    rs_old = torch.sum(r * r)
    # If the start point already solves the system (e.g. b == 0), stepping
    # would compute 0/0 and fill x with NaN.
    if torch.sqrt(rs_old) < epsilon:
        return x
    p = r.clone()
    for _ in range(n_iter):
        Ap = A_hvp_func(p)
        # Sum of elementwise products: same as torch.dot for vectors, but
        # also valid for tensors of any shape.
        curvature = torch.sum(p * Ap)
        if curvature == 0:
            # No progress is possible along p; dividing would yield inf/NaN.
            break
        alpha = rs_old / curvature
        x += alpha * p
        r -= alpha * Ap
        rs_new = torch.sum(r * r)
        if torch.sqrt(rs_new) < epsilon:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

# HOAG (Hyperparameter Optimization with Approximate Gradient) implementation
def hoag_optimize(
    inner_loss_func,
    outer_loss_func,
    w,
    lambda_,
    train_data,
    val_data,
    n_outer_steps=20,
    inner_lr=0.01,
    outer_lr=0.1,
    max_inner_steps=50,
    epsilon=1e-5
):
    """Tune a scalar hyperparameter with an approximate hypergradient (HOAG).

    Each outer iteration:
      (i)   runs gradient descent on the inner loss w.r.t. ``w``;
      (ii)  solves ``(d^2 h / dw^2) q = dg/dw`` with conjugate gradient;
      (iii) forms ``p = dg/dlambda - (d^2 h / (dlambda dw))^T q``;
      (iv)  takes a gradient step on ``lambda_`` and clamps it at zero.
    Here ``h`` is the inner loss and ``g`` the outer loss.

    Progress is printed to stdout on every outer iteration.

    Args:
        inner_loss_func: Called as ``inner_loss_func(w, lambda_, train_data)``;
            must return a scalar tensor differentiable in ``w`` and
            ``lambda_``.
        outer_loss_func: Called as ``outer_loss_func(w, val_data)``; must
            return a scalar tensor differentiable in ``w``.
        w: Initial inner parameters (1-D tensor; ``torch.dot`` is used on the
            cross-derivative term).
        lambda_: Initial hyperparameter; a single-element tensor, since
            ``.item()`` is called on it.
        train_data: Passed through unchanged to ``inner_loss_func``.
        val_data: Passed through unchanged to ``outer_loss_func``.
        n_outer_steps: Number of hyperparameter updates.
        inner_lr: Step size of the inner gradient descent.
        outer_lr: Step size of the hyperparameter update.
        max_inner_steps: Cap on inner gradient-descent steps per outer step.
        epsilon: Tolerance for both the inner gradient norm and the CG solve.

    Returns:
        Tuple ``(lambda_, w, history)`` where ``history`` is a list of dicts
        with keys ``'lambda'`` and ``'val_loss'``, one per outer iteration.
    """
    history = []

    # Work on a detached leaf so the caller's tensor/graph is not involved and
    # inputs created without requires_grad are accepted.
    w = w.detach().requires_grad_()

    print("--- Starting HOAG Optimization ---")
    for k in range(n_outer_steps):

        # Step (i): Solve Inner Problem
        # Plain first-order gradient descent. The hypergradient below comes
        # from implicit differentiation, so the inner loop must NOT be
        # unrolled: no create_graph, and w is detached after each step so the
        # autograd graph does not grow with the number of inner steps.
        lambda_fixed = lambda_.detach()
        for _ in range(max_inner_steps):
            inner_loss = inner_loss_func(w, lambda_fixed, train_data)
            grad_w_inner = autograd.grad(inner_loss, w)[0]
            if torch.norm(grad_w_inner) < epsilon:
                break
            w = (w - inner_lr * grad_w_inner).detach().requires_grad_()

        # Step (ii): Solve Linear System
        outer_loss = outer_loss_func(w, val_data)
        b = autograd.grad(outer_loss, w)[0]

        w_cg = w.detach().requires_grad_()
        lambda_cg = lambda_.detach()

        def hvp_func(v):
            h = inner_loss_func(w_cg, lambda_cg, train_data)
            return hessian_vector_product(h, w_cg, v)

        q_k = conjugate_gradient(hvp_func, b, epsilon=epsilon)

        # Step (iii): Compute Approximate Gradient
        w_clean = w.detach().requires_grad_()
        lambda_clean = lambda_.detach().requires_grad_()

        inner_loss_grads = inner_loss_func(w_clean, lambda_clean, train_data)
        outer_loss_grads = outer_loss_func(w_clean, val_data)

        # First term: d g/dλ. The outer loss is not handed lambda, so autograd
        # reports the gradient as unused (None); treat that as zero.
        grad_g_lambda = autograd.grad(
            outer_loss_grads, lambda_clean, allow_unused=True
        )[0]
        if grad_g_lambda is None:
            grad_g_lambda = 0.0

        # Second term: (d^2 h / (dλ dw))^T * q_k
        grad_h_lambda = autograd.grad(
            inner_loss_grads, lambda_clean, create_graph=True
        )[0]
        cross_gradient_term = autograd.grad(grad_h_lambda, w_clean)[0]

        p_k = grad_g_lambda - torch.dot(cross_gradient_term, q_k)

        # Step (iv): Update Hyperparameter
        lambda_ = lambda_ - outer_lr * p_k
        lambda_ = torch.clamp(lambda_, min=0.0)

        w = w.detach().requires_grad_()
        lambda_ = lambda_.detach().requires_grad_()

        # Logging only: no graph is needed for the reported validation loss.
        with torch.no_grad():
            current_val_loss = outer_loss_func(w, val_data).item()
        history.append({'lambda': lambda_.item(), 'val_loss': current_val_loss})
        print(f"Iteration {k+1:02d}: Hyperparameter (lambda) = {lambda_.item():.4f}, Validation Loss = {current_val_loss:.4f}")

    print(f"\n--- Final Result ---\nOptimal hyperparameter found: lambda = {lambda_.item():.4f}")

    return lambda_, w, history

# if __name__ == '__main__':
#     def ridge_inner_loss(w, lambda_, data):
#         X, y = data
#         return torch.mean((X @ w - y)**2) + lambda_ * torch.sum(w**2)

#     def ridge_outer_loss(w, data):
#         X, y = data
#         return torch.mean((X @ w - y)**2)

#     n_features, n_samples = 10, 100
#     X_train = torch.randn(n_samples, n_features)
#     w_true = torch.randn(n_features)
#     y_train = X_train @ w_true + torch.randn(n_samples) * 0.1
#     X_val = torch.randn(n_samples, n_features)
#     y_val = X_val @ w_true + torch.randn(n_samples) * 0.1

#     w_init = torch.randn(n_features, requires_grad=True)
#     lambda_init = torch.tensor(10.0, requires_grad=True)

#     final_lambda, final_w, history = hoag_optimize(
#         inner_loss_func=ridge_inner_loss,
#         outer_loss_func=ridge_outer_loss,
#         w=w_init,
#         lambda_=lambda_init,
#         train_data=(X_train, y_train),
#         val_data=(X_val, y_val),
#         max_inner_steps=50,
#         epsilon=1e-4
#     )

--- Starting HOAG Optimization ---
Iteration 01: Hyperparameter (lambda) = 9.9909, Validation Loss = 6.0137
Iteration 02: Hyperparameter (lambda) = 9.9818, Validation Loss = 6.0129
Iteration 03: Hyperparameter (lambda) = 9.9727, Validation Loss = 6.0121
Iteration 04: Hyperparameter (lambda) = 9.9635, Validation Loss = 6.0113
Iteration 05: Hyperparameter (lambda) = 9.9544, Validation Loss = 6.0104
Iteration 06: Hyperparameter (lambda) = 9.9452, Validation Loss = 6.0096
Iteration 07: Hyperparameter (lambda) = 9.9360, Validation Loss = 6.0087
Iteration 08: Hyperparameter (lambda) = 9.9268, Validation Loss = 6.0079
Iteration 09: Hyperparameter (lambda) = 9.9176, Validation Loss = 6.0071
Iteration 10: Hyperparameter (lambda) = 9.9084, Validation Loss = 6.0062
Iteration 11: Hyperparameter (lambda) = 9.8991, Validation Loss = 6.0054
Iteration 12: Hyperparameter (lambda) = 9.8899, Validation Loss = 6.0045
Iteration 13: Hyperparameter (lambda) = 9.8806, Validation Loss = 6.0036
Iteration 14: Hy