In [None]:
import torch
from primal_dual_hybrid_gradient_step import adaptive_one_step_pdhg, fixed_one_step_pdhg
from helpers import spectral_norm_estimate_torch, KKT_error, compute_residuals_and_duality_gap, check_termination
from enhancements import primal_weight_update

def _displacement_respects_bounds(dx, c, l, u, tol_feas):
    """
    Vectorized per-coordinate bound test used by the dual-infeasibility check.

    A coordinate i passes when at least one of the following holds:
      * l_i and u_i are both finite and |dx_i| <= tol_feas;
      * u_i == +inf, c_i >= 0 and dx_i >= -tol_feas;
      * l_i == -inf, c_i <= 0 and dx_i <= tol_feas.

    Args:
        dx (torch.Tensor): Difference between two consecutive primal iterates.
        c (torch.Tensor): Coefficients for the primal objective.
        l (torch.Tensor): Lower bounds for the primal variable.
        u (torch.Tensor): Upper bounds for the primal variable.
        tol_feas (float): Feasibility tolerance applied to each coordinate.

    Returns:
        bool: True when every coordinate passes (vacuously True for empty input).
    """
    step = dx.reshape(-1)
    cost = c.reshape(-1)
    lower = l.reshape(-1)
    upper = u.reshape(-1)

    boxed = ~torch.isinf(lower) & ~torch.isinf(upper) & (step.abs() <= tol_feas)
    free_above = (upper == float('inf')) & (cost >= 0) & (step >= -tol_feas)
    free_below = (lower == -float('inf')) & (cost <= 0) & (step <= tol_feas)
    return bool((boxed | free_above | free_below).all())


def pdlp_algorithm(K, m_ineq, c, q, l, u, device, max_iter=100_000, tol=1e-4, verbose=True, restart_period=40, precondition=False, primal_update=False, adaptive=False, data_precond=None):
    """
    Main PDLP algorithm implementation with integrated infeasibility detection.

    Each inner iteration takes one PDHG step (adaptive or fixed), accumulates
    step-size-weighted sums of the iterates for averaging, and runs
    difference-based infeasibility checks on consecutive iterates. Every
    `restart_period` inner iterations the KKT errors of the current, averaged
    and previous iterates are compared to decide whether to restart. After each
    inner loop the residuals and duality gap are evaluated (on the unscaled
    data when `precondition` is True) and handed to `check_termination`.

    Args:
        K (torch.Tensor): Constraint matrix. The first `m_ineq` rows are treated
            as inequality rows, the remaining rows as equality rows.
        m_ineq (int): Number of inequality constraints.
        c (torch.Tensor): Coefficients for the primal objective.
        q (torch.Tensor): Right-hand side vector for the constraints.
        l (torch.Tensor): Lower bounds for the primal variable.
        u (torch.Tensor): Upper bounds for the primal variable.
        device (torch.device): Device to run the algorithm on (CPU or GPU).
        max_iter (int): Maximum number of iterations.
        tol (float): Tolerance for convergence.
        verbose (bool): Whether to print detailed output.
        restart_period (int): Number of iterations between restart checks.
        precondition (bool): Whether to use preconditioning.
        primal_update (bool): Whether to perform primal weight updates.
        adaptive (bool): Whether to use adaptive stepsize.
        data_precond (tuple): Preconditioned data (D_col, D_row, K_unscaled, c_unscaled, q_unscaled, l_unscaled, u_unscaled) if precondition is True.
            The tensors in this tuple are not modified.

    Returns:
        x (torch.Tensor): Final primal iterate, expressed in the coordinates of the
            K, c, q, l, u that were passed in (it is not rescaled by D_col here).
        prim_obj (float): Primal objective value reported by the last residual
            evaluation, or `c.T @ x` when an infeasibility certificate ended the
            run or no iteration was performed.
        k (int): Total number of iterations.
        n (int): Number of restart loops.
        j (int): KKT pass counter.

    Raises:
        ValueError: If `precondition` is True but `data_precond` is None.
    """
    # Recover G, A, h, b for infeasibility detection
    m_eq = K.shape[0] - m_ineq
    if m_ineq > 0:
        G_mat = K[:m_ineq]
        h = q[:m_ineq]
    else:
        G_mat = torch.zeros((0, K.shape[1]), device=device)
        h = torch.zeros((0, 1), device=device)
    if m_eq > 0:
        A_mat = K[m_ineq:]
        b = q[m_ineq:]
    else:
        A_mat = torch.zeros((0, K.shape[1]), device=device)
        b = torch.zeros((0, 1), device=device)

    # Dual bound masks
    is_neg_inf = torch.isinf(l) & (l < 0)
    is_pos_inf = torch.isinf(u) & (u > 0)

    # Copies of the bounds with infinite entries replaced by 0
    l_dual = l.clone()
    u_dual = u.clone()
    l_dual[is_neg_inf] = 0
    u_dual[is_pos_inf] = 0

    # Unscaled problem data for residual evaluation. The bounds are cloned once here
    # so the caller's data_precond tensors are never overwritten and the masking is
    # not repeated on every restart.
    if precondition:
        if data_precond is None:
            raise ValueError("precondition=True requires data_precond to be provided")
        D_col, D_row, K_unscaled, c_unscaled, q_unscaled, l_unscaled, u_unscaled = data_precond
        l_unscaled_dual = l_unscaled.clone()
        u_unscaled_dual = u_unscaled.clone()
        l_unscaled_dual[is_neg_inf] = 0
        u_unscaled_dual[is_pos_inf] = 0

    q_norm = torch.linalg.norm(q, 2)
    c_norm = torch.linalg.norm(c, 2)

    # Initial step-size
    eta = 0.9 / spectral_norm_estimate_torch(K, num_iters=100)
    # Primal weight: ratio of the objective and right-hand-side norms, 1 if either is tiny
    omega = c_norm / q_norm if q_norm > 1e-6 and c_norm > 1e-6 else torch.tensor(1.0)

    theta = 1.0

    # Restart Parameters [Sufficient, Necessary, Artificial]
    beta = [0.2, 0.8, 0.36]

    # Initialize primal and dual
    x = torch.zeros((c.shape[0], 1), device=device)
    y = torch.zeros((K.shape[0], 1), device=device)

    # Objective at the starting point, so the final return is well defined even when
    # the outer loop never runs (max_iter <= 0).
    prim_obj = c.T @ x

    # Infeasibility detection parameters
    tol_eq = 1e-2
    tol_feas = 1e-2
    tol_obj = 1e-2
    tol_dual_eq = 1e-2
    tol_cert = 1e-2

    # Trackers for infeasibility detection
    x_prev_iter = x.clone()
    y_prev_iter = y.clone()
    lam_prev = torch.zeros_like(x)
    certificate_flag = None

    # Counters
    n = 0 # Outer Loop Counter
    k = 0 # Total Iteration Counter
    j = 0 # Kkt pass Counter

    # Initialize Previous KKT Error
    KKT_first = 0 # The actual KKT error of the very first point doesn't matter since the artificial criteria will always hit anyway

    # -------------- Outer Loop --------------
    while k < max_iter:
        t = 0 # Initialize inner iteration counter

        # Initialize/Reset sums for averaging
        x_eta_total = torch.zeros_like(x)
        y_eta_total = torch.zeros_like(y)
        eta_total = 0

        # Initialize/Reset Previous restart point for primal weighting
        x_last_restart = x.clone()
        y_last_restart = y.clone()

        # --------- Inner Loop ---------
        while k < max_iter:
            k += 1
            x_previous = x.clone() # For checking necessary criteria
            y_previous = y.clone()

            if adaptive:
                # Adaptive step of pdhg
                x, y, eta, eta_hat, j = adaptive_one_step_pdhg(x, y, c, q, K, l, u, m_ineq, eta, omega, theta, k, j)
            else:
                # Fixed step of pdhg
                x, y, eta, eta_hat = fixed_one_step_pdhg(x, y, c, q, K, l, u, m_ineq, eta, omega, theta)
                j += 1

            # Increase iteration counters
            t += 1

            # Update totals
            x_eta_total += eta * x
            y_eta_total += eta * y
            eta_total += eta

            # Update eta
            eta = eta_hat

            # Infeasibility detection
            # NOTE(review): x_prev_iter / y_prev_iter / lam_prev are not resynchronized when a
            # restart replaces (x, y) by the averaged iterate, so the first difference after
            # such a restart spans the restart jump as well as the PDHG step -- confirm intended.
            # NOTE(review): a vanishing displacement satisfies the tests below (dy ~ 0 and
            # dlam ~ 0 pass every primal-infeasibility test), so stalled or converging
            # iterates may be reported as infeasible -- confirm intended.
            lam = c - K.T @ y
            if k > 1:  # Need at least two points to compute differences
                dx = x - x_prev_iter
                dy = y - y_prev_iter
                dlam = lam - lam_prev

                # Positive and negative parts of the change in lam
                dlam_plus = (-dlam).clamp(min=0)
                dlam_minus = dlam.clamp(min=0)

                # Check dual infeasibility (primal unbounded)
                if m_eq == 0 or (A_mat @ dx).norm() < tol_eq:
                    Gdx = G_mat @ dx if m_ineq > 0 else torch.zeros((1, 1), device=device)
                    if (m_ineq == 0 or torch.all(Gdx >= -tol_feas)) and (c.T @ dx < tol_obj):
                        if _displacement_respects_bounds(dx, c, l, u, tol_feas):
                            certificate_flag = "DUAL_INFEASIBLE"
                            if verbose:
                                print(f"[PDLP] Dual infeasibility detected at iter {k}")
                            break

                # Primal infeasibility (dual unbounded) detection
                dy_in = dy[:m_ineq] if m_ineq > 0 else torch.zeros((0, 1), device=device)
                dy_eq = dy[m_ineq:] if m_eq > 0 else torch.zeros((0, 1), device=device)
                dual_res = torch.zeros_like(x)
                if m_ineq > 0: dual_res += G_mat.T @ dy_in
                if m_eq > 0: dual_res += A_mat.T @ dy_eq
                dual_res -= dlam

                if dual_res.norm() < tol_dual_eq and (m_ineq == 0 or torch.all(dy_in >= -tol_feas)):
                    dual_combo = 0.0
                    if m_ineq > 0: dual_combo += (h.T @ dy_in).item()
                    if m_eq > 0: dual_combo += (b.T @ dy_eq).item()

                    # Only finite, non-zero bounds contribute to the bound terms
                    finite_l = (~torch.isinf(l).view(-1)) & (l.view(-1) != 0)
                    finite_u = (~torch.isinf(u).view(-1)) & (u.view(-1) != 0)

                    if finite_l.any():
                        dual_combo -= (l[finite_l].view(1, -1) @ dlam_minus[finite_l].view(-1, 1)).item()
                    if finite_u.any():
                        dual_combo -= (u[finite_u].view(1, -1) @ dlam_plus[finite_u].view(-1, 1)).item()

                    if dual_combo > -tol_cert:
                        certificate_flag = "PRIMAL_INFEASIBLE"
                        if verbose:
                            print(f"[PDLP] Primal infeasibility detected at iter {k}")
                        break

            # Update previous iterates for next infeasibility check
            x_prev_iter.copy_(x)
            y_prev_iter.copy_(y)
            lam_prev.copy_(lam)

            # Check Restart Criteria Every restart_period iterations
            if t % restart_period == 0:
                # Compute averages
                x_avg = x_eta_total / eta_total
                y_avg = y_eta_total / eta_total

                # Compute KKT errors
                KKT_current = KKT_error(x, y, c, q, K, m_ineq, omega, is_neg_inf, is_pos_inf, l_dual, u_dual, device)
                KKT_average = KKT_error(x_avg, y_avg, c, q, K, m_ineq, omega, is_neg_inf, is_pos_inf, l_dual, u_dual, device)
                KKT_min = min(KKT_current, KKT_average)
                KKT_previous = KKT_error(x_previous, y_previous, c, q, K, m_ineq, omega, is_neg_inf, is_pos_inf, l_dual, u_dual, device)

                # Add three kkt passes
                j += 3

                # Check Restart Criteria (first match wins)
                if KKT_min <= beta[0] * KKT_first: # Sufficient Criteria
                    restart_kind = "Sufficient"
                elif KKT_min <= beta[1] * KKT_first and KKT_min > KKT_previous: # Necessary Criteria
                    restart_kind = "Necessary"
                elif t >= beta[2] * k: # Artificial Criteria
                    restart_kind = "Artificial"
                else:
                    restart_kind = None

                if restart_kind is not None:
                    # Restart candidate: the averaged iterate unless the current one has
                    # strictly smaller KKT error
                    use_average = KKT_current >= KKT_average
                    if verbose:
                        print(f"{restart_kind} restart at iteration {t} using the",
                              "Average iterate." if use_average else "Current iterate.")
                    if use_average:
                        x, y = x_avg, y_avg
                    break

        # ------------- End Inner Loop ------------

        # Handle infeasibility certificate
        if certificate_flag:
            prim_obj = c.T @ x
            if verbose:
                print(f"Terminating due to {certificate_flag} at iteration {k}")
            return x, prim_obj.cpu().item(), k, n, j

        n += 1 # Increase restart loop counter

        if primal_update: # Primal weight update
            omega = primal_weight_update(x_last_restart, x, y_last_restart, y, omega, 0.5)

        # Reference KKT error at the new restart point, used by the next restart checks
        KKT_first = KKT_error(x, y, c, q, K, m_ineq, omega, is_neg_inf, is_pos_inf, l_dual, u_dual, device)
        j += 1 # Add one kkt pass

        # Compute primal and dual residuals, and duality gap
        if precondition:
            # Evaluate on the unscaled data with the iterates scaled by D_col / D_row
            primal_residual, dual_residual, duality_gap, prim_obj, adjusted_dual = compute_residuals_and_duality_gap(
                D_col * x, D_row * y, c_unscaled, q_unscaled, K_unscaled, m_ineq,
                is_neg_inf, is_pos_inf, l_unscaled_dual, u_unscaled_dual
            )
        else:
            primal_residual, dual_residual, duality_gap, prim_obj, adjusted_dual = compute_residuals_and_duality_gap(
                x, y, c, q, K, m_ineq, is_neg_inf, is_pos_inf, l_dual, u_dual
            )
        j += 1 # Add one kkt pass

        if verbose:
            rel_gap = duality_gap.item() / (1 + abs(prim_obj.item()) + abs(adjusted_dual.item()))
            rel_primal = primal_residual.item() / (1 + q_norm)
            rel_dual = dual_residual.item() / (1 + c_norm)

            print(f"[{k}] Primal Obj: {prim_obj.item():.4f}, Adjusted Dual Obj: {adjusted_dual.item():.4f}, "
                  f"Gap: {rel_gap:.2e}, Prim Res: {rel_primal:.2e}, Dual Res: {rel_dual:.2e}")
            print("")

        # Termination conditions
        if check_termination(primal_residual, dual_residual, duality_gap, prim_obj, adjusted_dual, q_norm, c_norm, tol):
            if verbose:
                print(f"Converged at iteration {k} after {n} restart loops")
            break

    # ------------------- End Outer Loop ------------------------

    return x, prim_obj.cpu().item(), k, n, j