In [1]:
import numpy as np

In [2]:
def projection_onto_simplex(z: np.ndarray, d: float = 1.0) -> np.ndarray:
    """
    Euclidean projection of a vector onto the scaled simplex
    {x : x_i >= 0, sum(x) = d}.

    The vector is first projected onto the hyperplane sum(x) = d. While any
    coordinate is negative, the non-positive coordinates are fixed at zero
    and the remaining (active) coordinates are re-projected onto the
    hyperplane restricted to them.

    Parameters:
        z: np.ndarray
            Input vector (one-dimensional, non-empty). Array-likes such as
            lists are accepted; the input is never modified.
        d: float
            Sum of the entries of the projected vector, i.e. the scale of
            the simplex. Must be non-negative. Defaults to 1.0.
    Returns:
        p: np.ndarray
            Projection of z onto the simplex, as a new array.
    Raises:
        ValueError
            If z is empty or not one-dimensional, or if d is negative
            (the simplex would be empty).
    References:
        Michelot, C. (1986). A finite algorithm for finding the projection of a point onto the canonical simplex of ℝⁿ.
        Journal of Optimization Theory and Applications, 50, 195-200.
    """
    z = np.asarray(z)
    if z.ndim != 1 or z.size == 0:
        raise ValueError("z must be a non-empty one-dimensional vector")
    if d < 0:
        raise ValueError("d must be non-negative")
    if not np.issubdtype(z.dtype, np.floating):
        # Integer/bool input: do the arithmetic in floating point.
        z = z.astype(float)

    # Projection onto the hyperplane sum(x) = d; this allocates a new array,
    # so the caller's z is left untouched.
    p = z + (d - z.sum()) / z.size

    while p.min() < 0:
        # Strict inequality: coordinates already fixed at zero in an earlier
        # pass must stay out of the active set.
        active = p > 0
        p[~active] = 0.0
        # Spread the remaining deficit/excess evenly over the active
        # coordinates so that they sum to d again.
        p[active] += (d - p[active].sum()) / np.count_nonzero(active)
    return p

In [6]:
# Seeded generator so the demo yields the same numbers on every Restart & Run All.
rng = np.random.default_rng(0)
z = rng.random(5_000_000)
p = projection_onto_simplex(z, 5)

# Sanity checks: the projection must be non-negative and sum to d.
assert p.min() >= 0
np.testing.assert_allclose(p.sum(), 5)

# ndarray.sum() is vectorised; the built-in sum() would iterate over
# five million NumPy scalars in Python.
print(p.sum())

4.999999999999994
