In [198]:
%load_ext autoreload
%autoreload 2

$$\operatorname{minimize}_{C, Q} t^{T} \operatorname{diag}(C) \\
\operatorname{subject to} \begin{array}{l} {\boldsymbol{C} \boldsymbol{X}-\boldsymbol{X}+\boldsymbol{Q}=0} \\ {\|\boldsymbol{Q}\|_{\infty, 1} \leq \tau} \\ {\boldsymbol{C} \geq 0} \\ {\operatorname{diag}(\boldsymbol{C}) \leq \mathbf{1}} \\ {\operatorname{Tr}(\boldsymbol{C})=r} \\ {C_{i j} \leq C_{j j} \text { for all } i, j}\end{array}$$


In [142]:
import numpy as np
from scipy.optimize import linprog

In [159]:
f = 2
r = 2
n = 5

In [160]:
F = np.random.random((f, r))
W = np.random.random((r, n))
X = F @ W

In [161]:
def build_a_eq(X=X, f=f, n=n):
    A = np.zeros((f*n + 1, f**2))
    for i in range(f):
        # Handle CX = X
        A[i*n:(i+1)*n, i*f:(i+1)*f] = X.T
        # Handle Tr(C) = r
        A[-1, i*(f+1)] = 1
    return A

In [162]:
def build_b_eq(X=X, r=r):
    b = np.zeros((f*n+1,))
    # Handle CX = X
    b[:-1] = X.flatten()
    # Handle Tr(C) = r
    b[-1] = r
    return b

In [163]:
def check_ab_eq(A, b, c, X=X, r=r):
    res_1 = (A @ c - b).sum()
    res_2 = (c.reshape(f, f) @ X - X).sum() + np.trace(c.reshape(f, f)) - r
    return  res_1 - res_2 < 1e-5

In [164]:
def build_a_ub(X=X, f=f, n=n):
    A = np.zeros((f**2, f**2))
    index_A = 0
    # Handle diag(C) <= 1
    for i in range(f):
        A[i, i*(f+1)] = 1
        index_A += 1
        
    # Handle C_ij <= C_jj  
    for j in range(f):
        for i in range(f):
            if i == j:
                continue
            C = np.zeros((f, f))
            C[j, j] = -1
            C[i, j] = 1
            A[index_A, :] = C.flatten()
            index_A += 1
    return A

In [165]:
def build_b_ub(f=f):
    b = np.zeros(f**2)
    b[:f] += 1
    return b

In [166]:
A_eq = build_a_eq()
b_eq = build_b_eq()

In [167]:
A_ub = build_a_ub()
b_ub = build_b_ub()

In [168]:
c = np.random.random((f, f)).flatten()
check_ab_eq(A_eq, b_eq, c)

True

In [169]:
t = (np.random.random((f, f)) * np.eye(f)).flatten()

In [170]:
A_eq

array([[0.44449541, 0.26962565, 0.        , 0.        ],
       [0.68069341, 0.36725849, 0.        , 0.        ],
       [0.68494016, 0.39777066, 0.        , 0.        ],
       [0.25811041, 0.14317446, 0.        , 0.        ],
       [0.29119356, 0.15106983, 0.        , 0.        ],
       [0.        , 0.        , 0.44449541, 0.26962565],
       [0.        , 0.        , 0.68069341, 0.36725849],
       [0.        , 0.        , 0.68494016, 0.39777066],
       [0.        , 0.        , 0.25811041, 0.14317446],
       [0.        , 0.        , 0.29119356, 0.15106983],
       [1.        , 0.        , 0.        , 1.        ]])

In [171]:
opt = linprog(t, method='simplex', A_eq=A_eq, b_eq=b_eq, A_ub=A_ub, b_ub=b_ub)



In [172]:
opt

     con: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
     fun: 0.6474986583754387
 message: 'Optimization terminated successfully.'
     nit: 10
   slack: array([0., 0., 1., 1.])
  status: 0
 success: True
       x: array([1., 0., 0., 1.])

In [56]:
res = opt.x.reshape(f, f)

In [57]:
res

array([[ 1.00000000e+00,  4.84509898e-16, -4.09143168e-16,
         0.00000000e+00,  0.00000000e+00],
       [ 8.64570028e-01,  4.84509898e-16,  0.00000000e+00,
         1.65257961e-01,  6.16553661e-17],
       [ 8.07262657e-01,  4.84509898e-16,  0.00000000e+00,
         3.30523745e-01,  1.97758476e-16],
       [ 2.33815068e-16,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  1.94548645e-16],
       [ 7.41614846e-01,  0.00000000e+00,  0.00000000e+00,
         2.97397148e-01,  8.67361738e-17]])

In [58]:
res @ X

array([[0.87737854, 0.44599603, 0.72962892, 0.3062781 , 0.49816112,
        0.71967298, 0.08012752, 0.49408122, 1.05989639, 0.74910025,
        0.40473732, 0.43714062, 0.543676  , 0.89584383, 0.43382241,
        0.40420323, 0.70491957, 0.31638458, 0.61741974, 0.54584087,
        0.5420589 , 0.35992293, 0.95682839, 0.45067002, 0.7596635 ,
        0.32483646, 0.84830287, 0.88350543, 0.83686203, 0.29201092,
        0.54140802, 0.63870575, 0.26665534, 0.67532342, 0.85490648,
        0.71029236, 0.28713456, 0.78889127, 0.51223608, 0.16354085,
        0.58806081, 0.54185641, 0.54508025, 0.63035457, 0.72508343,
        0.31525106, 0.25975999, 0.6508814 , 0.45932447, 0.30651536,
        0.93039235, 0.62229188, 0.87956095, 1.09754001, 0.29930625,
        0.05511511, 0.44069781, 0.5417554 , 0.70874167, 0.51312525,
        0.44879523, 0.28220901, 1.06620143, 0.57673799, 0.56341261,
        0.25616988, 0.73827132, 0.78051997, 0.50244872, 0.75680368,
        1.06609712, 0.93968568, 0.31346303, 0.14

In [222]:
f = 10
r = 2
n = 50

In [232]:
from simplex_nmf import SimplexNMF

In [233]:
simplex = SimplexNMF(f=f, r=r, n=n)

In [234]:
simplex.run()

||CX - X|| = 0.012386713505034563
diag(C) = [9.74217246e-16 5.04859512e-16 7.08682576e-01 0.00000000e+00
 9.19808408e-01 0.00000000e+00 0.00000000e+00 6.07784789e-02
 0.00000000e+00 3.10730537e-01]
