In [5]:
import numpy as np

def filter_multinomial_rows(X, n):
    """
    Select rows from X that are valid multinomial draws with n trials.

    A row counts as a valid draw when every entry is a non-negative whole
    number and the entries sum to exactly n.

    Parameters:
    - X: 2D array-like; each row is a candidate vector of category counts
    - n: number of trials (integer)

    Returns:
    - filtered_rows: 2D array holding the rows of X that are valid draws,
      in their original order and with X's dtype (integer-valued floats
      stay floats). It has zero rows when no row qualifies.

    Notes:
    - X is expected to be 2D; the row-wise reductions use axis=1, so a 1D
      input makes NumPy raise an axis error.
    """
    # Ensure X is a numpy array (lists of lists are accepted too).
    X = np.asarray(X)

    # Condition 1: every entry is a whole number (x mod 1 == 0).
    # NaN entries make the comparison False, so such rows are rejected.
    is_integer = np.all(np.mod(X, 1) == 0, axis=1)

    # Condition 2: counts cannot be negative. Without this check a row
    # like [-1, 5, 6] would be accepted as a draw with n = 10.
    is_nonnegative = np.all(X >= 0, axis=1)

    # Condition 3: the counts account for exactly n trials.
    sums_to_n = np.sum(X, axis=1) == n

    valid_rows = is_integer & is_nonnegative & sums_to_n
    return X[valid_rows]


In [6]:
# Number of trials each candidate draw must account for.
n = 10

# Candidate draws over three categories: only rows made of whole numbers
# that add up to n should survive the filter.
X = np.array(
    [
        [3, 2, 5],
        [2, 2, 6],
        [1.5, 2.5, 6],  # fractional counts -> rejected
        [4, 4, 1],      # totals 9, not 10 -> rejected
        [2, 3, 5],
    ]
)

filtered = filter_multinomial_rows(X, n)
print(filtered)


[[3. 2. 5.]
 [2. 2. 6.]
 [2. 3. 5.]]
