# In [None]:
import numpy as np
def awaystep_fw(
    objective_fun,
    gradient_fun,
    LMO,
    x0,
    hyperparams=None,
):
    """
    Minimize a smooth function over a polytope with Away-step Frank-Wolfe.

    The iterate is tracked as an explicit convex combination of atoms
    (vertices returned by ``LMO``) with nonnegative weights summing to 1.
    Each iteration takes either a Frank-Wolfe step toward the LMO vertex or
    an "away" step moving away from the worst atom in the active set,
    whichever has the larger gap. An away step taken at its maximal length
    ("drop step") removes that atom from the active set.

    Inputs:
        - objective_fun (callable): Function f(x) to minimize.
        - gradient_fun (callable): Gradient of f(x).
        - LMO (callable): Linear minimization oracle. Given a gradient g it
          returns a vertex s of the feasible set minimizing <g, s>
          (e.g. a corner of the L∞-ball).
        - x0 (np.ndarray): Initial feasible point. It seeds the active set
          with weight 1, so the away-step bookkeeping assumes it is a vertex
          of the feasible set.
        - hyperparams (dict, optional): Hyperparameters. Missing keys fall
          back to the defaults:
            - "max_iterations" (int, default 20): Maximum number of iterations.
            - "tolerance" (float, default 1e-6): Stop once the Frank-Wolfe
              duality gap is at or below this value.
    Outputs:
        - x_t (np.ndarray): Final solution (float64 copy; x0 is not modified).
        - t (int): Index of the last iteration executed (0 if
          max_iterations is 0).
        - history (dict): Per-iteration lists, recorded before each step:
            - 'objective': f at the current iterate.
            - 'gap': max of the FW gap and the away gap.
            - 'gradient': the gradient at the current iterate.
    """
    # Merge user hyperparameters over the defaults. A None default avoids a
    # shared mutable dict, and a partial dict no longer raises KeyError.
    params = {"max_iterations": 20, "tolerance": 1e-6}
    if hyperparams is not None:
        params.update(hyperparams)
    max_iterations = params["max_iterations"]
    tolerance = params["tolerance"]

    # np.array copies, so in-place updates never touch the caller's x0.
    x_t = np.array(x0, dtype=np.float64)

    # Active set S (list of atoms) and convex weights alpha (sum to 1).
    active_set = [x_t.copy()]
    weights = np.array([1.0])

    history = {'objective': [], 'gap': [], 'gradient': []}

    t = 0  # keeps the return well-defined when the loop body never runs
    for t in range(1, max_iterations + 1):
        objective_t = objective_fun(x_t)
        grad_t = gradient_fun(x_t)

        # Frank-Wolfe candidate. Copy the LMO output so a stored atom can't
        # alias a buffer the LMO might reuse.
        s_t = np.array(LMO(grad_t), dtype=np.float64)
        d_t_FW = s_t - x_t
        gap_t_FW = -grad_t @ d_t_FW

        # Away candidate: the active atom with the largest <grad, v>.
        v_t_idx = int(np.argmax([grad_t @ v for v in active_set]))
        v_t = active_set[v_t_idx]
        d_t_Away = x_t - v_t
        gap_t_Away = -grad_t @ d_t_Away

        history['gradient'].append(grad_t)
        history['objective'].append(objective_t)
        history['gap'].append(max(gap_t_FW, gap_t_Away))

        # The FW gap upper-bounds f(x_t) - f*, so it is the stopping criterion.
        if gap_t_FW <= tolerance:
            print(f"Duality gap below tolerance at iteration {t}: {gap_t_FW}")
            break

        is_fw_step = gap_t_FW >= gap_t_Away
        if is_fw_step:
            selected_d_t = d_t_FW
            gamma_max = 1.0
        else:
            selected_d_t = d_t_Away
            alpha_v = weights[v_t_idx]
            # Largest step that keeps alpha_v >= 0. It is unbounded when v_t
            # is the only atom (alpha_v == 1).
            gamma_max = alpha_v / (1.0 - alpha_v) if alpha_v < 1.0 else np.inf

        # Diminishing step size, clipped to stay inside the feasible set.
        gamma_t = min(2 / (t + 2), gamma_max)
        x_t += gamma_t * selected_d_t

        # Keep (active_set, weights) consistent with the new x_t.
        if is_fw_step:
            active_set, weights = _fw_weight_update(
                active_set, weights, s_t, gamma_t
            )
        else:
            active_set, weights = _away_weight_update(
                active_set, weights, v_t_idx, gamma_t, gamma_t == gamma_max
            )

    return x_t, t, history


def _vertex_index(active_set, vertex):
    """Return the index of an atom exactly equal to `vertex`, or None."""
    # list.index would compare arrays with ==, whose truth value is
    # ambiguous, so match elementwise explicitly.
    for idx, atom in enumerate(active_set):
        if np.array_equal(atom, vertex):
            return idx
    return None


def _fw_weight_update(active_set, weights, s_t, gamma_t):
    """Return (active_set, weights) after x <- (1 - gamma_t) * x + gamma_t * s_t."""
    if gamma_t >= 1.0:
        # A full step lands exactly on s_t, which becomes the only atom.
        return [s_t], np.array([1.0])
    # Every existing atom shrinks by (1 - gamma_t); s_t gains gamma_t.
    weights = weights * (1.0 - gamma_t)
    s_idx = _vertex_index(active_set, s_t)
    if s_idx is None:
        return active_set + [s_t], np.append(weights, gamma_t)
    weights[s_idx] += gamma_t
    return active_set, weights


def _away_weight_update(active_set, weights, v_idx, gamma_t, is_drop_step):
    """Return (active_set, weights) after x <- (1 + gamma_t) * x - gamma_t * v."""
    # Every atom grows by (1 + gamma_t); the away atom also loses gamma_t.
    weights = weights * (1.0 + gamma_t)
    weights[v_idx] -= gamma_t
    if is_drop_step:
        # At gamma_max the away atom's weight is zero, so remove it.
        active_set = active_set[:v_idx] + active_set[v_idx + 1:]
        weights = np.delete(weights, v_idx)
    return active_set, weights