In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from snmfem.estimator.nmf import NMF
from snmfem.measures import KLdiv
# from snmfem.toy import create_toy_problem

In [7]:
from snmfem import models

In [10]:
Model = getattr(models, "Toy")

AttributeError: module 'snmfem.models' has no attribute 'Toy'

In [None]:
def create_toy_problem(l = 25, k = 3, p = 100, c = 10, n_poisson=200, force_simplex=True,**kwargs):

    A = np.random.rand(k,p)
    if force_simplex:
        A = A/np.sum(A, axis=0, keepdims=True)
    
    G = np.random.rand(l,c)
    P = np.random.rand(c,k)
    GP = G @ P

    X = GP @ A

    Xdot = 1/n_poisson * np.random.poisson(n_poisson * X)

    return G, P, A, X, Xdot

In [None]:
l = 25
k = 3
p = 100
c = 10
n_poisson=2000
G, P, A, X, Xdot = create_toy_problem(l, k, p, c, n_poisson)
G.shape
# G = np.eye(c)
# A = np.eye(k)
# X = G @ P @ A

In [None]:
from pyunlocbox import functions
from pyunlocbox.solvers import mlfbf, solve
from snmfem.conf import log_shift
# Let us solve the following problem
# 
# \argmin_P  - \sum_ij  X_ij log(GPA)_ij + (GPA)_ij  s.t. P \geq 0

# \argmin_A  - \sum_ij  X_ij log(GPA)_ij + (GPA)_ij + f(A) s.t. A \geq 0


# class smooth(functions.func):
#     def __init__(self, G, A):
#         self.A = A
#         self.G = G
#         self.nabla = np.sum(self.G.T, axis=1, keepdims=True) @ np.sum(self.A.T, axis=0, keepdims=True)

#     def _eval(self, P):
#         return np.sum(self.G @ P @ self.A)
#     def _grad(self, P):
#         return self.nabla

# f_smooth = smooth(G, A)
# assert(f_smooth.grad(P).shape == P.shape)
# np.testing.assert_allclose(f_smooth.grad(P),f_smooth.grad(2*P))


class mxlogvpv(functions.func):
    def __init__(self, X, lambda_=1, log_shift=log_shift, **kwargs):
        super(mxlogvpv, self).__init__(**kwargs)
        self.lambda_ = lambda_
        self.X = X
        self.log_shift = log_shift
        self.offset = np.sum(self.X * np.log(self.X+log_shift)) - np.sum(self.X)

    def _eval(self, P):
        return self.offset * (self.offset - self.lambda_ * np.sum(self.X * np.log(P+log_shift)) + np.sum(P))
    def _prox(self, P, T):
        gamma = self.lambda_ * T
        delta = (P - gamma)**2 + 4 * gamma * self.X
        return (P-gamma + np.sqrt(delta))/2

def solve_P(G, A, X, **kwargs):
    L = lambda P : G @ P @ A
    Lt = lambda X: G.T @ X @ A.T 

    f = functions.proj_positive()
    g = mxlogvpv(X)
    h = functions.dummy()
    beta = 0
    mu = beta + np.linalg.norm(G,2) * np.linalg.norm(A,2)
    step = 1 / mu / 2

    x0 = np.zeros(P.shape)
    D = np.linalg.lstsq(A.T, X.T)[0].T
    x0 = np.abs(np.linalg.lstsq(G, D)[0])
    d0 = np.zeros(X.shape)

    solver = mlfbf(step=step, L=L, Lt=Lt, d0=d0 )
    ret = solve([f, g, h], x0, solver, **kwargs)

    sol = ret["sol"], np.array(ret["objective"]).sum(axis=1)
    return sol

def solve_A(G, P, X, **kwargs):
    GP = G @ P
    L = lambda A : GP @ A
    Lt = lambda X: GP.T @ X
    
    n = GP.shape[1]
    opA = lambda A: np.sum(A, 0, keepdims=True)
    opAt = lambda V: np.ones([n,1]) @ V
    y = np.ones([1, X.shape[1]])

    f = functions.proj_positive()
    g = mxlogvpv(X)
    h = function.norm_l2(A=OpA, At=OpAt, y=y)
    beta = 0
    mu = beta + np.linalg.norm(G,2) * np.linalg.norm(A,2)
    step = 1 / mu / 2

    x0 = np.zeros(P.shape)
    D = np.linalg.lstsq(A.T, X.T)[0].T
    x0 = np.abs(np.linalg.lstsq(G, D)[0])
    d0 = np.zeros(X.shape)

    solver = mlfbf(step=step, L=L, Lt=Lt, d0=d0 )
    ret = solve([f, g, h], x0, solver, **kwargs)

    sol = ret["sol"], np.array(ret["objective"]).sum(axis=1)
    return sol

In [None]:
sol,  objective = solve_P(G, A, X, maxit=1000, rtol=1e-15)

plt.plot(objective)
plt.yscale("log")

np.linalg.norm(sol - P, "fro") / np.linalg.norm(P, "fro"), KLdiv(X, G @ sol, A)

In [None]:
# D = np.linalg.lstsq(A.T, Xdot.T)[0].T
# x0 = np.abs(np.linalg.lstsq(G, D)[0])
# D, np.linalg.lstsq(G, D)[0]

In [None]:
sol,  objective = solve_P(G, A, Xdot, maxit=100, rtol=1e-15)

plt.plot(objective)
# plt.yscale("log")

np.linalg.norm(sol - P, "fro") / np.linalg.norm(P, "fro"), KLdiv(X, G @ sol, A)

In [None]:
sol

In [None]:
from pyunlocbox import functions


In [None]:
est = NMF(G=G, n_components=k, debug=True, max_iter=200,  force_simplex=True)

In [None]:
Ps, As = est.fit_transform(X, eval_print=20)

In [None]:
# est = NMF(G=G, n_components=k, debug=True, max_iter=2000, force_simplex=False)
# Pss, Ass = est.fit_transform(Xdot, eval_print=200)
# Ainit = Ass/np.sum(Ass, axis=0, keepdims=True)
        # - 'random': non-negative random matrices, scaled with:
        #     sqrt(X.mean() / n_components)
        # - 'nndsvd': Nonnegative Double Singular Value Decomposition (NNDSVD)
        #     initialization (better for sparseness)
        # - 'nndsvda': NNDSVD with zeros filled with the average of X
        #     (better when sparsity is not desired)
        # - 'nndsvdar': NNDSVD with zeros filled with small random values
        #     (generally faster, less accurate alternative to NNDSVDa
        #     for when sparsity is not desired)
est = NMF(G=G, n_components=k, debug=True, mu=0, max_iter=2000,tol=0, force_simplex=False, init="nndsvd")
Pss, Ass = est.fit_transform(Xdot, eval_print=200)
# np.sum(Ass, axis=0)

In [None]:
# est = NMF(G=G, n_components=k, debug=True, max_iter=10, force_simplex=True)
# Pss, Ass = est.fit_transform(Xdot, A=A, eval_print=1)

In [None]:
vmin = 0
vmax = 1

plt.subplot(2,1,1)
plt.imshow(A, vmin=vmin, vmax=vmax)
plt.colorbar()
plt.subplot(2,1,2)
plt.imshow(Ass, vmin=vmin, vmax=vmax)
plt.colorbar()
plt.show()

In [None]:
np.min(Ass)

In [None]:
k = 15
num = np.random.rand(k,23)/20
denum = np.random.rand(k,23)

def f(nu):
    sol = 0
    for n,d in zip(num, denum):
        sol = sol + n/(d+nu)
    return sol - 1

In [None]:
r = np.sum(num/denum, axis=0)
# t = np.sum(num)
# if r<1:
#     bmax = 0
# else:
ind_min = np.argmax(num/denum, axis=0)
ind_min2 = np.argmin(denum, axis=0)
ind = np.arange(len(ind_min))

bmin1 = num[ind_min, ind]-denum[ind_min, ind]
bmin2 = num[ind_min2, ind]-denum[ind_min2, ind]
bmin = np.maximum(bmin1, bmin2)
bmax = r

# nu = np.arange(bmin, bmax, 1e-3)
# plt.plot(nu, f(nu))
# plt.show()

In [None]:
bmin, bmax

In [None]:
bmin

In [None]:
ind_max = np.argmin(num/denum)
num[ind_max]/denum[ind_max]

In [None]:
"""
Implements three algorithms for projecting a vector onto the simplex: sort, pivot and bisection.
For details and references, see the following paper:
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
ICPR 2014.
http://www.mblondel.org/publications/mblondel-icpr2014.pdf
"""

import numpy as np


def projection_simplex_sort(v, z=1):
    n_features = v.shape[0]
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(n_features) + 1
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    w = np.maximum(v - theta, 0)
    return w


def projection_simplex_pivot(v, z=1, random_state=None):
    rs = np.random.RandomState(random_state)
    n_features = len(v)
    U = np.arange(n_features)
    s = 0
    rho = 0
    while len(U) > 0:
        G = []
        L = []
        k = U[rs.randint(0, len(U))]
        ds = v[k]
        for j in U:
            if v[j] >= v[k]:
                if j != k:
                    ds += v[j]
                    G.append(j)
            elif v[j] < v[k]:
                L.append(j)
        drho = len(G) + 1
        if s + ds - (rho + drho) * v[k] < z:
            s += ds
            rho += drho
            U = L
        else:
            U = G
    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)


def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):
    lower = 0
    upper = np.max(v)
    current = np.inf

    for it in range(max_iter):
        if np.abs(current) / z < tau and current < 0:
            break

        theta = (upper + lower) / 2.0
        w = np.maximum(v - theta, 0)
        current = np.sum(w) - z
        if current <= 0:
            upper = theta
        else:
            lower = theta
    return w

n = 10
rs = np.random.RandomState(0)
v = rs.rand(n)
z = 1
print(z)

w1 = projection_simplex_sort(v, z)
print(np.sum(w1))

w2 = projection_simplex_pivot(v, z)
print(np.sum(w2))

w3 = projection_simplex_bisection(v, z)
print(np.sum(w3))

In [None]:
Aop = lambda x : np.sum(x)
pinvA = np.linalg.pinv(np.ones([1,n]))
pinvAop = lambda x: 1/n * x

def mysolution(x):
    return x - pinvAop(Aop(x) - 1)
mysolution(v)

In [None]:
# some plots