We here consider the problem of logistic loss.

## Model Specification

$$
\mathbf{Y} = g\left( \mathbf{X}\mathbf{B}^\top + \mathbf{Z}\mathbf{\Gamma}^\top \right),
$$
where $g$ is the logistic loss function.

## Joint Maximum Likelihood Estimation

The negative log likelihood is,

$$
l = \sum_{i,j}\left[\log\left(1+e^{P_{ij}}\right) - Y_{ij}P_{ij}\right],
$$
where
$$
\begin{aligned}
P_{ij} &= \mathbf{X}_{i,:}\mathbf{B}_{j,:}^\top + \mathbf{Z}_{i,:}\mathbf{\Gamma}_{j,:}^\top\\
&= \sum_{k=1}^p X_{ik}B_{jk} + \sum_{l=1}^KZ_{il}\Gamma_{jl}.
\end{aligned}
$$

### Alternating Minimization

For the first step, consider $\mathbf{Z}$ is known, then this is a Multivariate logistic regression. Since the model assumes conditional independence, we can do it column-wise using sklearn package.

For the second step, consider $\mathbf{B}$ and $\mathbf{\Gamma}$ is known, then, 
$$
\frac{\partial l}{\partial Z_{il}} = \sum_{j=1}^q\frac{\partial l}{\partial P_{ij}}\frac{\partial P_{ij}}{Z_{il}} = \sum_{j=1}^p\frac{e^{P_{ij}}}{1+e^{P_{ij}}}\Gamma_{jl} = 0.
$$

There is no explicit solution but is a convex optimization problem. Even though Newton-step is preferred, I will use CVXPY to solve this convex optimization problem. Or scipy.optimize using minimize function and default setting should be good enough.

In [1]:
import numpy as np
import cvxpy as cp
import scipy
from scipy.optimize import minimize
from scipy.special import softplus

In [2]:
def JML(Y, X, K, esp = 1e-3):
    q = Y.shape[1]
    p = X.shape[1]
    n = X.shape[0]

    # Initialize Z
    Z0 = np.random.randn(n, K)
    Gamma0 = np.random.randn(q, K)
    B0 = np.random.randn(q, p)

    def iter_z_fix(Z, B_old, Gamma_old):
        def obj_fn(BGamma, shape):
            BGamma = BGamma.reshape(shape)
            B = BGamma[:,:p]
            Gamma = BGamma[:,p:]
            P = X @ B.T + Z @ Gamma.T
            return np.sum(softplus(P)-Y * P)

        shape = (q, p + K)
        BGamma0 = np.hstack([B_old, Gamma_old])
        BGamma0 = BGamma0.flatten()
        opt_res = minimize(obj_fn, BGamma0, args=(shape,), method='Newton-CG')
        
        BGamma = opt_res.x.reshape(shape)
        B = BGamma[:,:p]
        Gamma = BGamma[:,p:]
        return B, Gamma
    
    def iter_BG_fix(B, Gamma, Z_old):
        def obj_fn(Z, shape):
            Z = Z.reshape(shape)
            P = X @ B.T + Z @ Gamma.T
            return np.sum(softplus(P)-Y * P)
        
        shape = Z_old.shape
        Z_old = Z_old.flatten()
        # opt_res = minimize(obj_fn, Z_old, args=(shape,), options={'maxiter':int(1e7)})
        opt_res = minimize(obj_fn, Z_old, args=(shape,), method='Newton-CG')
        return opt_res.x.reshape(shape)
        # if opt_res.success:
        #     return opt_res.x.reshape(shape)
        # else:
        #     raise AssertionError("Not Reach Minimum")
    
    ## Stop Here. There should be a while loop to connect everything together
    # give me while loop of iter_z_fix and iter_BG_fix
    Z = Z0
    B, Gamma = iter_z_fix(Z0, B0, Gamma0)
    Z_new = iter_BG_fix(B, Gamma, Z)
    err = scipy.linalg.norm(Z_new - Z)
    iter = 0
    while err > esp:
        Z = Z_new
        B_new, Gamma_new = iter_z_fix(Z, B, Gamma)
        Z_new = iter_BG_fix(B_new, Gamma_new, Z)
        err = np.max([scipy.linalg.norm(Z_new - Z), scipy.linalg.norm(B_new - B), scipy.linalg.norm(Gamma_new - Gamma)])
        B = B_new
        Gamma = Gamma_new
        iter += 1

    return B, Z_new, Gamma, iter

def rotate_sparse(X, B, Z, Gamma):
    n = Z.shape[0]
    q = B.shape[0]
    p = B.shape[1]
    K = Gamma.shape[1]
    X_cov = X[:,1:]

    def obj_fn(Ac):
        return cp.abs(B[1:,:] - Ac.T @ Gamma)
    
    Ac0 = cp.Variable((q-1, K))
    prob = cp.Problem(cp.Minimize(obj_fn(Ac0)))
    prob.solve()
    Ac = Ac0.value
    a0 = -np.mean(X_cov @ Ac.T + Z, axis=0)
    A = np.concat([a0.reshape(1, K), Ac], axis=0)

    B_new = B - Gamma @ A
    Z_mid = Z + X @ A.T

    GG_half = scipy.linalg.fractional_matrix_power(Gamma.T @ Gamma, 0.5)
    G_full = GG_half @ Z_mid.T @ Z_mid @ GG_half /n/q
    d, V = np.linalg.eigh(G_full)
    D = np.diag(d**(-1/4))
    G = GG_half @ V @ D / np.sqrt(q)

    Gamma_new = Gamma @ scipy.linalg.inv(G_full.T)
    Z_new = Z_mid @ G

    return B_new, Z_new, Gamma_new

def rotate_ortho(X, B, Z, Gamma):
    n = Z.shape[0]
    q = B.shape[0]
    p = B.shape[1]
    K = Gamma.shape[1]

    A = - Z.T @ X @ scipy.linalg.inv(X.T @ X)

    B_new = B - Gamma @ A
    Z_mid = Z + X @ A.T

    GG_half = scipy.linalg.fractional_matrix_power(Gamma.T @ Gamma, 0.5)
    G_full = GG_half @ Z_mid.T @ Z_mid @ GG_half /n/q
    d, V = np.linalg.eigh(G_full)
    D = np.diag(d**(-1/4))
    G = GG_half @ V @ D / np.sqrt(q)

    Gamma_new = Gamma @ scipy.linalg.inv(G_full.T)
    Z_new = Z_mid @ G

    return B_new, Z_new, Gamma_new

In [3]:
def denoise(Y, X, lam):
    n = Y.shape[0]
    q = Y.shape[1]
    p = X.shape[1]

    def obj_fn(B, L):
        P = X @ B.T + L
        return cp.sum(cp.logistic(P)-cp.multiply(Y, P)) + lam * cp.norm(L, "nuc")

    def constr(X, L):
        return X.T @ L == 0

    B0 = cp.Variable((q, p))
    L0 = cp.Variable((n, q))
    obj = cp.Minimize(obj_fn(B0, L0))
    constraints = [constr(X, L0)]
    prob = cp.Problem(obj, constraints)
    prob.solve()
    
    return prob.status, B0.value, L0.value

def SVD_est(L, K):
    n = L.shape[0]
    q = L.shape[1]

    U, S, Vt = scipy.sparse.linalg.svds(L, K)
    Pi = (n/q) ** (1/4) * U @ np.diag(S ** (1/2))
    Gamma = (q/n) ** (1/4) * Vt.T @ np.diag(S ** (1/2))
    
    return Pi, Gamma

In [1]:
import numpy as np
i, j = np.ogrid[:10, :10]
diff = np.abs(i - j)
matrix = 0.2 ** diff

In [6]:
from scipy.special import expit

## make L larger through making Gamma larger rather than simply rand n
## give some sturcture to it

n = 100
p = 5
q = 50
K = 2
np.random.seed(1)

X = np.random.randn(n, p)
# B = np.random.randn(q, p)
B_true = np.random.uniform(low=0.3, high=0.7, size=(q, p))
X_ortho = scipy.linalg.null_space(X.T)
Pi0 = X_ortho[:,:K] * n
Gamma0 = np.random.uniform(low=0.5, high=1.5, size=(q, K))
L = Pi0 @ Gamma0.T
U, S, Vt = scipy.sparse.linalg.svds(L, K)
Pi_true = (n/q) ** (1/4) * U @ np.diag(S ** (1/2))
Gamma_true = (q/n) ** (1/4) * Vt.T @ np.diag(S ** (1/2))
Y0 = X @ B_true.T + L
Y = np.random.binomial(1, expit(Y0))
# Y = Y0 + np.random.randn(n, q)

In [65]:
def obj_fn(B, L, lam):
    P = X @ B.T + L
    return cp.sum((Y - P)**2) + lam * cp.norm(L, "nuc")

def constr(X, L):
    return X.T @ L == 0

lam = np.sqrt((n+q)*np.log(n))
B0 = cp.Variable((q, p))
L0 = cp.Variable((n, q))
obj = cp.Minimize(obj_fn(B0, L0, lam))
constraints = [constr(X, L0)]
prob = cp.Problem(obj, constraints)
prob.solve()

np.float64(354196.89320100413)

In [5]:
lam = np.sqrt((n+q)*np.log(np.max([n,q])))
is_denoise, B_hat, L_hat = denoise(Y, X, lam)

In [7]:
Pi_hat, Gamma_hat = SVD_est(L_hat, K)

In [8]:
print(scipy.linalg.norm(B_hat - B) / scipy.linalg.norm(B))
print(scipy.linalg.norm(L_hat - L) / scipy.linalg.norm(L))
print(scipy.linalg.norm(Pi_hat - Pi) / scipy.linalg.norm(Pi))
print(scipy.linalg.norm(Gamma_hat - Gamma) / scipy.linalg.norm(Gamma))

1.702986003545672
1.0000000014478054
0.9999885349290456
1.0001011901482266


In [31]:
B_err = B_hat - B
L_err = L_hat - L
Pi_err = Pi_hat - Pi
Gamma_err = Gamma_hat - Gamma

In [19]:
print(cp.installed_solvers())

['CLARABEL', 'ECOS', 'ECOS_BB', 'OSQP', 'SCIPY', 'SCS']


In [None]:
X = np.random.randn(200, 100)
B = np.random.randn(100, 100)
Z = np.random.randn(200, 3)
Gamma = np.random.randn(100, 3)
Y = np.exp(X @ B.T + Z @ Gamma.T) / (1 + np.exp(X @ B.T + Z @ Gamma.T))
Y = np.random.binomial(1, Y)



In [9]:
B_hat, Z_hat, Gamma_hat, niter = JML(Y, X, K)

KeyboardInterrupt: 

In [21]:
q = Y.shape[1]
p = X.shape[1]
n = X.shape[0]

# Initialize Z
Z0 = np.random.randn(n, K)
Gamma0 = np.random.randn(q, K)
B0 = np.random.randn(q, p)

like_lst = []
err_b = []

def iter_z_fix(Z, B_old, Gamma_old):
    # def obj_fn(BGamma, shape):
    #     BGamma = BGamma.reshape(shape)
    #     B = BGamma[:,:p]
    #     Gamma = BGamma[:,p:]
    #     P = X @ B.T + Z @ Gamma.T
    #     return np.sum(softplus(P)-Y * P)
    def obj_fn(BGamma, shape):
        BGamma = BGamma.reshape(shape)
        B = BGamma[:,:p]
        Gamma = BGamma[:,p:]
        P = X @ B.T + Z @ Gamma.T
        loss = np.sum(softplus(P)-Y * P)
        B_grad = X.T @ (expit(P) - Y)
        Gamma_grad = Z.T @ (expit(P) - Y)
        grad = np.vstack([B_grad, Gamma_grad]).T
        return (loss, grad.flatten())

    shape = (q, p + K)
    BGamma0 = np.hstack([B_old, Gamma_old])
    BGamma0 = BGamma0.flatten()
    opt_res = minimize(obj_fn, BGamma0, args=(shape,), method='L-BFGS-B', jac=True)

    BGamma = opt_res.x.reshape(shape)
    B = BGamma[:,:p]
    Gamma = BGamma[:,p:]
    return B, Gamma
    
def iter_BG_fix(B, Gamma, Z_old):
    def obj_fn(Z, shape):
        Z = Z.reshape(shape)
        P = X @ B.T + Z @ Gamma.T
        loss = np.sum(softplus(P)-Y * P)
        # grad = (Gamma.T @ (expit(P) - Y).T).flatten()
        grad = (Gamma.T @ (expit(P) - Y).T).T
        return (loss, grad.flatten())
        # return np.sum(softplus(P)-Y * P)
    
    shape = Z_old.shape
    Z_old = Z_old.flatten()
    # opt_res = minimize(obj_fn, Z_old, args=(shape,), options={'maxiter':int(1e7)})
    opt_res = minimize(obj_fn, Z_old, args=(shape,), method='L-BFGS-B', jac=True)
    return opt_res.x.reshape(shape)
        # if opt_res.success:
        #     return opt_res.x.reshape(shape)
        # else:
        #     raise AssertionError("Not Reach Minimum")
    
    ## Stop Here. There should be a while loop to connect everything together
    # give me while loop of iter_z_fix and iter_BG_fix
Z = Z0
B, Gamma = iter_z_fix(Z0, B0, Gamma0)
Z_new = iter_BG_fix(B, Gamma, Z)
err = np.max([scipy.linalg.norm(Z_new - Z), scipy.linalg.norm(B-B0), scipy.linalg.norm(Gamma-Gamma0)])
iter = 0
while (err > 1e-3) & (iter < 100):
    Z = Z_new
    B_new, Gamma_new = iter_z_fix(Z, B, Gamma)
    Z_new = iter_BG_fix(B_new, Gamma_new, Z)
    err = np.max([scipy.linalg.norm(Z_new - Z), scipy.linalg.norm(B_new - B), scipy.linalg.norm(Gamma_new - Gamma)])
    err_b.append(np.linalg.norm(B_new - B))
    B = B_new
    Gamma = Gamma_new
    iter += 1
    like_lst.append(np.sum(softplus(X @ B.T + Z @ Gamma.T) - Y * (X @ B.T + Z @ Gamma.T)))

In [15]:
scipy.linalg.norm(Gamma_new - Gamma)

np.float64(0.0)

In [22]:
B_orth, Z_orth, Gamma_orth = rotate_ortho(X, B, Z, Gamma)

In [23]:
import scipy.linalg


print(scipy.linalg.norm(B_orth - B_true))
print(scipy.linalg.norm(Z_orth - Pi_true))
print(scipy.linalg.norm(Gamma_orth - Gamma_true))


12593.182378450827
833.9439072754107
29.060436595295226


In [40]:
print(scipy.linalg.norm(B_hat - B_true))
print(scipy.linalg.norm(Pi_hat - Pi_true))
print(scipy.linalg.norm(Gamma_hat - Gamma_true))

13.831209257218706
41.097182231067826
29.0633700522935


In [6]:
n = 200
p = 10
K = 3

X = np.random.randn(n, K)
Beta = np.random.randn(K, p)
Y = X @ Beta + np.random.randn(n,p)
lam = np.sqrt((n+p) * np.log(np.max((n,p))))

Beta0 = cp.Variable((K, p))
Gamma0 = cp.Variable((n, p))


In [15]:
obj = cp.Minimize(obj_fn(Y,X,Gamma0,Beta0,lam))
constraint = constr(X,Gamma0)
prob = cp.Problem(obj, constraint)
prob.solve()

np.float64(1926.0045456886132)

In [7]:
def rotate_sparse(X, B, Z, Gamma):
    n = Z.shape[0]
    q = B.shape[0]
    p = B.shape[1]
    K = Gamma.shape[1]
    X_cov = X[:,1:]

    def obj_fn(Ac):
        return cp.sum(cp.abs(B[:,1:] - Gamma @ Ac))
    
    # Ac0 = cp.Variable((p-1, K))
    Ac0 = cp.Variable((K, p-1))
    prob = cp.Problem(cp.Minimize(obj_fn(Ac0)))
    prob.solve()
    Ac = Ac0.value
    a0 = -np.mean(X_cov @ Ac.T + Z, axis=0)
    A = np.concat([a0.reshape(K, 1), Ac], axis=1)

    B_new = B - Gamma @ A
    Z_mid = Z + X @ A.T

    GG_half = scipy.linalg.fractional_matrix_power(Gamma.T @ Gamma, 0.5)
    G_full = GG_half @ Z_mid.T @ Z_mid @ GG_half /n/q
    d, V = np.linalg.eigh(G_full)
    D = np.diag(d**(-1/4))
    G = GG_half @ V @ D / np.sqrt(q)

    Gamma_new = Gamma @ scipy.linalg.inv(G_full.T)
    Z_new = Z_mid @ G

    return B_new, Z_new, Gamma_new

In [8]:
B, Z, Gamma = rotate_sparse(X, B_true, Pi_true, Gamma_true)

In [None]:
from utils import gen_data_sparse

n = 200
p = 10
K = 2
q = 100
Y, X, B, Pi, Gamma = gen_data_sparse(n, q, p, K, tau=0.5, rho=0.5, seed=1, sparse=True)

In [8]:
import numpy as np

rv = np.random.uniform(low=0.3, high=0.7, q/K)
print(rv)

SyntaxError: positional argument follows keyword argument (2356776484.py, line 3)

In [4]:
from multiprocessing import Pool
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor

def process_item(x):
    return x * x  # Example function

data = list(range(100))
# with Pool(processes=4) as pool:
#     results = list(tqdm(pool.imap_unordered(process_item, data), total=len(data)))
with ProcessPoolExecutor(max_workers=4) as executor:
    results = list(tqdm(executor.map(process_item, data), total=len(data)))

  0%|          | 0/100 [00:00<?, ?it/s]Process SpawnProcess-214:
Process SpawnProcess-215:
Process SpawnProcess-216:
Process SpawnProcess-213:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchbase/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchbase/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchbase/lib/python3.11/concurrent/futures/process.py", line 244, in _process_worker
    call_item = call_queue.get(block=True)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchbase/lib/python3.11/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
           ^^^^^^^^^^^^

BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.