In [1]:
import numpy as np
import cvxpy as cp
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

def grad_x(X, S):

    X_inv = np.linalg.inv(X)
    return S - X_inv

def arg_prox(X, t, grad_x, S):

    return X - t * grad_x(X, S)

def h_func_cp(X, gamma):
    offdiag_mask = ~np.eye(X.shape[0], dtype=bool)
    return gamma * cp.norm1(cp.multiply(offdiag_mask, X))

def h_func(X, gamma):
    offdiag_mask = ~np.eye(X.shape[0], dtype=bool)
    return gamma * np.sum(np.abs(np.multiply(offdiag_mask, X)))

def prox_h(X, h_func_cp, t, grad_x, arg_prox, S, gamma):
    """
    Compute prox_{h}(x) = argmin_y h(y) + 0.5 * ||y - x||_2^2

    Parameters:
    - x (np.ndarray): The point at which to evaluate the proximal operator
    - h_func_cp (callable): A function that accepts a cvxpy Variable y and returns h(y)

    Returns:
    - np.ndarray: Result of the proximal operator
    """
    x = arg_prox(X, t, grad_x, S)

    y = cp.Variable(x.shape)
    objective = h_func_cp(y, gamma) + 0.5 * cp.sum_squares(y - x)
    problem = cp.Problem(cp.Minimize(objective))
    problem.solve()
    return y.value



def compute_U(X, S, gamma):
    """
    Compute U where:
    U_ij = max(-gamma, min(gamma, [X_inv - S]_ij)) for i ≠ j
           0 for i == j
    
    Parameters:
    - X (np.ndarray): Square positive definite matrix
    - S (np.ndarray): Symmetric matrix of same shape as X
    - gamma (float): Threshold parameter
    
    Returns:
    - np.ndarray: Matrix U
    """
    X_inv = np.linalg.inv(X)
    diff = X_inv - S

    # Apply soft thresholding only to off-diagonal elements
    U = np.zeros_like(diff)
    for i in range(diff.shape[0]):
        for j in range(diff.shape[1]):
            if i != j:
                U[i, j] = np.clip(diff[i, j], -gamma, gamma)
            else:
                U[i, j] = 0  # optional, since we initialized with zeros

    return U

def g_func(X, S):
    return np.trace(S @ X) + np.log(np.linalg.det(X))

def compute_stopping_criterion(X, S, g_func, h_func, gamma, compute_U):

    n = X.shape[0]

    delta = g_func(X, S) + h_func(X, gamma) - np.log(np.linalg.det(S + compute_U(X, S, gamma))) - n
    
    return delta

def proximal_gradient_descend(X, h_func_cp, h_func, t, grad_x, arg_prox, S, gamma, g_func, compute_U, epsilon=1e-2): # TODO : i'm not fucking sure this work, can you guys check please 

    while True:
        X_new = prox_h(X, h_func_cp, t, grad_x, arg_prox, S, gamma)
        delta = compute_stopping_criterion(X_new, S, g_func, h_func, gamma, compute_U)

        X = X_new

        if delta <= epsilon:
            break 
    
    return X

def backtracking_line_search(phi, phi_derivative_at_0, t_init, alpha1=0.1, beta=0.7):
    """
    Perform backtracking line search to find step size t.

    Parameters:
        phi (function): A continuously differentiable function φ: R → R.
        phi_derivative_at_0 (float): The derivative φ'(0).
        t_init (float): Initial step size (t >= 0).
        alpha1 (float): Parameter in (0, 0.5], default 0.1.
        beta (float): Parameter in (0, 1), default 0.7.

    Returns:
        float: Step size t such that φ(t) ≤ φ(0) + α1 * t * φ'(0)
    """
    t = t_init
    phi_0 = phi(0)

    while phi(t) > phi_0 + alpha1 * t * phi_derivative_at_0:
        t *= beta

    return t

def proximal_gradient_descend(X, h_func_cp, h_func, t_init, grad_x, arg_prox, S, gamma, g_func, compute_U, epsilon=1e-2):

    while True:
        grad = grad_x(X, S)

        def phi(t):
            X_temp = prox_h(X, h_func_cp, t, grad_x, arg_prox, S, gamma)
            return g_func(X_temp, S) + h_func(X_temp, gamma)

        phi_derivative_at_0 = np.sum(grad * (-grad))  # directional derivative: <∇g, -∇g>

        t = backtracking_line_search(phi, phi_derivative_at_0, t_init)

        X_new = prox_h(X, h_func_cp, t, grad_x, arg_prox, S, gamma)
        delta = compute_stopping_criterion(X_new, S, g_func, h_func, gamma, compute_U)

        X = X_new

        if delta <= epsilon:
            break

    return X


In [2]:
np.random.seed(42)
n = 3
A = np.random.randn(n, n)
S = A @ A.T  # make S symmetric positive definite
X_init = np.eye(n)
gamma = 0.1
t_init = 1.0

X_result = proximal_gradient_descend(X_init, h_func_cp, h_func, t_init, grad_x, arg_prox, S, gamma, g_func, compute_U)
print("Final X:\n", X_result)


  return np.trace(S @ X) + np.log(np.linalg.det(X))


KeyboardInterrupt: 