In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt
# from snmfem.experiments import load_samples, print_results, load_data, run_experiment
# from snmfem.measures import KL


In [5]:
def create_toy_problem(l = 25, k = 3, p = 100, c = 10, n_poisson=200, force_simplex=True, rng=None):
    """Generate a random toy instance of the model X = G @ P @ A plus Poisson noise.

    Parameters
    ----------
    l : int
        Number of rows of ``G`` (and of ``X``).
    k : int
        Number of rows of ``A`` (= number of columns of ``P``).
    p : int
        Number of columns of ``A`` (and of ``X``).
    c : int
        Number of columns of ``G`` (= number of rows of ``P``).
    n_poisson : int or float
        Noise scale: ``Xdot = Poisson(n_poisson * X) / n_poisson``.
        Larger values give a less noisy ``Xdot``.
    force_simplex : bool
        If True, each column of ``A`` is normalized to sum to 1.
    rng : None, int or numpy.random.Generator, optional
        Source of randomness. ``None`` (default) draws from the legacy global
        ``np.random`` state, exactly as earlier versions did, so
        ``np.random.seed`` still controls it. An int seeds a fresh
        ``Generator``; a ``Generator`` is used as is.

    Returns
    -------
    G : (l, c) ndarray
    P : (c, k) ndarray
    A : (k, p) ndarray
    X : (l, p) ndarray
        Noiseless data ``G @ P @ A``.
    Xdot : (l, p) ndarray
        Poisson-noisy version of ``X``.
    """
    # np.random.random draws the same stream as np.random.rand, so the
    # default path is bit-identical to the legacy implementation.
    rand = np.random if rng is None else np.random.default_rng(rng)

    A = rand.random((k, p))
    if force_simplex:
        A = A / np.sum(A, axis=0, keepdims=True)

    G = rand.random((l, c))
    P = rand.random((c, k))
    GP = G @ P

    X = GP @ A

    Xdot = 1 / n_poisson * rand.poisson(n_poisson * X)

    return G, P, A, X, Xdot

In [6]:
# Fix the global NumPy RNG so the toy problem (drawn from np.random by
# default) is identical on every Restart & Run All.
SEED = 0
np.random.seed(SEED)

shape_2d = [6, 10]
k = 5
n_poisson = 30
G, P, true_maps, Xtrue, Xflat = create_toy_problem(p=shape_2d[0] * shape_2d[1], k=k, n_poisson=n_poisson)
true_spectra = (G @ P).T

assert true_maps.shape == (k, shape_2d[0] * shape_2d[1])
assert Xtrue.shape == Xflat.shape



In [8]:
# Default solver settings (the solve_P calls below currently pass their own
# maxit / rtol values explicitly).
maxit, tol = 10, 1e-7

In [None]:
from pyunlocbox import functions
from pyunlocbox.solvers import mlfbf, solve
from snmfem.conf import log_shift
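# We solve the following two subproblems of X \approx G P A with a KL
# (Poisson) data-fidelity term:
#
# \argmin_P \; -\sum_{ij} X_{ij} \log (GPA)_{ij} + (GPA)_{ij} \quad \text{s.t. } P \geq 0
#
# \argmin_A \; -\sum_{ij} X_{ij} \log (GPA)_{ij} + (GPA)_{ij} + f(A) \quad \text{s.t. } A \geq 0
#
# where f(A) = \| \mathbf{1}^T A - \mathbf{1}^T \|_2^2 softly enforces that each column of A sums to one.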


# class smooth(functions.func):
#     def __init__(self, G, A):
#         self.A = A
#         self.G = G
#         self.nabla = np.sum(self.G.T, axis=1, keepdims=True) @ np.sum(self.A.T, axis=0, keepdims=True)

#     def _eval(self, P):
#         return np.sum(self.G @ P @ self.A)
#     def _grad(self, P):
#         return self.nabla

# f_smooth = smooth(G, A)
# assert(f_smooth.grad(P).shape == P.shape)
# np.testing.assert_allclose(f_smooth.grad(P),f_smooth.grad(2*P))


class mxlogvpv(functions.func):
    r"""Scaled generalized Kullback-Leibler data term ("minus x log v plus v").

    For data ``X`` and a nonnegative argument ``V`` it evaluates

        lambda_ * ( offset - sum(X * log(V + log_shift)) + sum(V) ),

    with ``offset = sum(X * log(X + log_shift)) - sum(X)``. The constant
    offset makes the value a KL divergence, approximately 0 at ``V == X``.
    It does not change the minimizer.

    Parameters
    ----------
    X : ndarray
        Observed (assumed nonnegative) data.
    lambda_ : float
        Weight applied to the whole term. ``_eval`` and ``_prox`` use it
        consistently.
    log_shift : float
        Small constant added inside the logarithms to avoid log(0).
    **kwargs
        Forwarded to ``pyunlocbox.functions.func``.
    """

    def __init__(self, X, lambda_=1, log_shift=log_shift, **kwargs):
        super(mxlogvpv, self).__init__(**kwargs)
        self.lambda_ = lambda_
        self.X = X
        self.log_shift = log_shift
        # Constant part of the KL divergence; independent of the argument.
        self.offset = np.sum(self.X * np.log(self.X + log_shift)) - np.sum(self.X)

    def _eval(self, P):
        # lambda_ scales the whole term, matching the prox below.
        return self.lambda_ * (
            self.offset
            - np.sum(self.X * np.log(P + self.log_shift))
            + np.sum(P)
        )

    def _prox(self, P, T):
        # prox of gamma * (v - x log v): positive root of
        # v**2 + (gamma - p) v - gamma x = 0, elementwise.
        gamma = self.lambda_ * T
        delta = (P - gamma)**2 + 4 * gamma * self.X
        return (P - gamma + np.sqrt(delta)) / 2

def solve_P(G, A, X, **kwargs):
    """Estimate a nonnegative P in X ~ G @ P @ A under the KL data term.

    Solves  argmin_P  mxlogvpv(X)(G P A)  s.t.  P >= 0  with pyunlocbox's
    MLFBF primal-dual solver, using L(P) = G @ P @ A as the linear operator.

    Parameters
    ----------
    G : (l, c) ndarray
    A : (k, p) ndarray
    X : (l, p) ndarray
        Observed data.
    **kwargs
        Forwarded to ``pyunlocbox.solvers.solve`` (e.g. ``maxit``, ``rtol``).

    Returns
    -------
    sol : (c, k) ndarray
        Estimated P.
    objective : ndarray
        Per-iteration sum of the individual function values reported by
        ``solve``.
    """
    L = lambda P: G @ P @ A
    Lt = lambda Y: G.T @ Y @ A.T  # adjoint of L

    f = functions.proj_positive()  # nonnegativity, handled by projection
    g = mxlogvpv(X)                # KL data term, composed with L
    h = functions.dummy()          # no smooth term
    beta = 0                       # Lipschitz constant of grad h (h is dummy)
    # ||L|| <= ||G||_2 * ||A||_2; take half of the bound 1 / (beta + ||L||).
    mu = beta + np.linalg.norm(G, 2) * np.linalg.norm(A, 2)
    step = 1 / mu / 2

    # Warm start from the (unconstrained) least-squares solution, made
    # nonnegative. Only the function arguments are used (no globals).
    D = np.linalg.lstsq(A.T, X.T, rcond=None)[0].T       # D @ A ~ X
    x0 = np.abs(np.linalg.lstsq(G, D, rcond=None)[0])    # G @ P ~ D
    d0 = np.zeros(X.shape)  # dual variable has the shape of L(P)

    solver = mlfbf(step=step, L=L, Lt=Lt, d0=d0)
    ret = solve([f, g, h], x0, solver, **kwargs)

    return ret["sol"], np.array(ret["objective"]).sum(axis=1)

def solve_A(G, P, X, **kwargs):
    """Estimate a nonnegative A in X ~ G @ P @ A under the KL data term plus a
    soft sum-to-one penalty on the columns of A.

    Solves  argmin_A  mxlogvpv(X)(GP A) + ||1^T A - 1^T||_2^2  s.t.  A >= 0
    with pyunlocbox's MLFBF solver, using L(A) = (G @ P) @ A.

    Parameters
    ----------
    G : (l, c) ndarray
    P : (c, k) ndarray
    X : (l, p) ndarray
        Observed data.
    **kwargs
        Forwarded to ``pyunlocbox.solvers.solve`` (e.g. ``maxit``, ``rtol``).

    Returns
    -------
    sol : (k, p) ndarray
        Estimated A.
    objective : ndarray
        Per-iteration sum of the individual function values reported by
        ``solve``.
    """
    GP = G @ P
    L = lambda A: GP @ A
    Lt = lambda Y: GP.T @ Y  # adjoint of L

    n = GP.shape[1]  # number of rows of A
    opA = lambda A: np.sum(A, 0, keepdims=True)   # column sums, shape (1, p)
    opAt = lambda V: np.ones([n, 1]) @ V          # adjoint: broadcast to (n, p)
    y = np.ones([1, X.shape[1]])                  # target: every column sums to 1

    f = functions.proj_positive()
    g = mxlogvpv(X)
    h = functions.norm_l2(A=opA, At=opAt, y=y)
    # h(A) = ||opA(A) - y||^2 has gradient 2 * opAt(opA(A) - y), which is
    # Lipschitz with constant 2 * ||opA||^2 = 2 * n (||opA|| = ||1_n||_2).
    beta = 2 * n
    mu = beta + np.linalg.norm(GP, 2)  # ||L|| = ||GP||_2
    step = 1 / mu / 2

    # Warm start: nonnegative least-squares-like guess of GP @ A ~ X, shape (n, p).
    x0 = np.abs(np.linalg.lstsq(GP, X, rcond=None)[0])
    d0 = np.zeros(X.shape)  # dual variable has the shape of L(A)

    solver = mlfbf(step=step, L=L, Lt=Lt, d0=d0)
    ret = solve([f, g, h], x0, solver, **kwargs)

    return ret["sol"], np.array(ret["objective"]).sum(axis=1)

In [None]:
# Noiseless check: recover P from the exact data Xtrue, with G and A = true_maps known.
sol, objective = solve_P(G, true_maps, Xtrue, maxit=1000, rtol=1e-15)

fig, ax = plt.subplots()
ax.plot(objective)
ax.set(yscale="log", xlabel="iteration", ylabel="objective", title="solve_P on noiseless data")

# Generalized KL divergence D(Xtrue || G sol A).
# NOTE(review): stands in for the undefined `KLdiv` (snmfem.measures); confirm it is the intended measure.
X_rec = G @ sol @ true_maps
kl = np.sum(Xtrue * np.log((Xtrue + log_shift) / (X_rec + log_shift)) - Xtrue + X_rec)
np.linalg.norm(sol - P, "fro") / np.linalg.norm(P, "fro"), kl

In [None]:
# D = np.linalg.lstsq(A.T, Xdot.T)[0].T
# x0 = np.abs(np.linalg.lstsq(G, D)[0])
# D, np.linalg.lstsq(G, D)[0]

In [None]:
# Noisy check: recover P from the Poisson-noisy data Xflat, with G and A = true_maps known.
sol, objective = solve_P(G, true_maps, Xflat, maxit=100, rtol=1e-15)

fig, ax = plt.subplots()
ax.plot(objective)
ax.set(xlabel="iteration", ylabel="objective", title="solve_P on noisy data")

# Generalized KL divergence to the *noiseless* data, D(Xtrue || G sol A).
# NOTE(review): stands in for the undefined `KLdiv` (snmfem.measures); confirm it is the intended measure.
X_rec = G @ sol @ true_maps
kl = np.sum(Xtrue * np.log((Xtrue + log_shift) / (X_rec + log_shift)) - Xtrue + X_rec)
np.linalg.norm(sol - P, "fro") / np.linalg.norm(P, "fro"), kl