In [49]:
import numpy as np

def newupdate(
    G: np.ndarray,
    *,
    zero_if_nonpositive: bool = False,
    dtype=np.uint8,
) -> np.ndarray:
    """
    Column-wise one-hot argmax selection.

    Given a 2D array G (shape m x n), return X (m x n) where each column j is
    one-hot at the index of the maximum entry of G[:, j]. If zero_if_nonpositive=True
    and the column's maximum <= 0, the whole column is set to zeros.

    - NaNs in G are treated as -inf (ignored for maxima). A column that is
      entirely NaN therefore has maximum -inf: it is one-hot at row 0 (the
      first occurrence of -inf) unless zero_if_nonpositive=True, in which case
      it is zeroed.
    - Ties are broken by np.argmax's rule (first occurrence).

    Parameters
    ----------
    G : array_like
        Gradient / environment matrix of shape (m, n). Anything accepted by
        np.asarray (e.g. a nested list) is converted to an ndarray first.
    zero_if_nonpositive : bool, default False
        If True, columns whose maximum <= 0 become all zeros.
        If False, columns are always one-hot at the argmax (even if max <= 0).
    dtype : numpy dtype, default np.uint8
        Dtype of the output (0/1).

    Returns
    -------
    X : np.ndarray
        One-hot selection matrix of shape (m, n).

    Raises
    ------
    ValueError
        If G is not 2D, or if G has zero rows (no entry can be selected).
    """
    G = np.asarray(G)
    if G.ndim != 2:
        raise ValueError("G must be a 2D array (m x n).")

    m, n = G.shape
    if m == 0:
        # np.argmax would fail here anyway with an opaque message; be explicit.
        raise ValueError("G must have at least one row (m >= 1).")

    # Treat NaNs as -inf so they never win the argmax
    G_clean = np.where(np.isnan(G), -np.inf, G)

    # Argmax per column (shape: (n,)); the first occurrence wins ties
    cols = np.arange(n)
    idx = np.argmax(G_clean, axis=0)

    # Build one-hot result (empty fancy indexing is a no-op when n == 0)
    X = np.zeros((m, n), dtype=dtype)
    X[idx, cols] = 1

    if zero_if_nonpositive:
        # The value at each column's argmax *is* that column's maximum, so
        # gather it instead of running a second full reduction.
        col_max = G_clean[idx, cols]  # (n,)
        # Keep only columns with a strictly positive max; zero out the rest
        X[:, ~(col_max > 0)] = 0

    return X


In [50]:
G = np.array([[0.1, -2.0,  0.0],
              [0.5,  0.3, -1.0],
              [0.2,  0.3,  0.0]])

X = newupdate(G)  # default zero_if_nonpositive=False: every column stays one-hot
print(X)
# Column-wise (0-indexed rows):
#   col 0: max 0.5 at row 1
#   col 1: max 0.3 tied at rows 1 and 2 -> first occurrence, row 1
#   col 2: max 0.0 tied at rows 0 and 2 -> row 0 (kept: the default does not zero it)
assert (X == np.array([[0, 0, 1],
                       [1, 1, 0],
                       [0, 0, 0]])).all()

# With zero_if_nonpositive=True, column 2 (max == 0) becomes all zeros instead.
assert not newupdate(G, zero_if_nonpositive=True)[:, 2].any()


[[0 0 1]
 [1 1 0]
 [0 0 0]]
