In [None]:
# Import specific functions
import sys
import numpy as np
from numpy.linalg import eig, inv, pinv, eigvals
from scipy.spatial.distance import squareform, pdist
from scipy.sparse import csr_matrix, random as sparse_random, find, issparse
# from scipy.stats import multivariate_normal as mvnpdf

import time
from scipy.spatial.distance import pdist, squareform
from scipy.linalg import inv
import numpy as np
import scipy.sparse as sp
from scipy.spatial.distance import squareform
from scipy.sparse import triu, coo_matrix
from sklearn.metrics import precision_score, recall_score, f1_score, normalized_mutual_info_score
from scipy.sparse.linalg import norm as sparse_norm
from scipy.spatial.distance import squareform
from math import sqrt
import networkx as nx


# NOTE(review): hard-coded, machine-specific absolute path. Presumably it is
# added so that the `scripts` package imported just below can be resolved --
# confirm, and consider deriving it from the project location instead.
sys.path.append('/Users/paul_reitz/Documents/repos/PyAWGLMM/Smooth_AWGLMM')
# Project helpers used by the graph-learning code in this file.
from scripts.utils import (
    visualize_glmm,
    graph_learning_perf_eval,
    identify_and_compare,
    generate_connected_graph,
    normest,
    lin_map,
    squareform_sp,
    # sum_squareform,  (left commented out; a function of the same name is
    # defined later in this file)
    prox_sum_log,
    gsp_distanz,
    gsp_compute_graph_learning_theta,
    normalize_data,
    
    
    
)
from scripts.utils_deep import mvnpdf

# Seed NumPy's global random number generator so runs are repeatable.
np.random.seed(17307946)

In [None]:
def gsp_learn_graph_log_degrees(Z, a, b, params=None):
    """Learn edge weights with a log penalty on the node degrees.

    Iteratively minimises, over the vector of edge weights ``w``::

        2 * <w, z>  -  a * sum(log(S w))  +  b * ||w||^2  [ + c * ||w - w_0||^2 ]

    with ``w`` clipped to ``[0, max_w]``. ``S`` is the operator returned by
    ``sum_squareform`` (it maps edge weights to per-node sums). The bracketed
    term is only present when a prior ``w_0`` is supplied.

    Each iteration takes a gradient/forward step, applies the proximal
    operators of the first two terms, takes a second forward step and then
    corrects the primal and dual variables.

    Parameters
    ----------
    Z : ndarray
        Pairwise "distance" data. A square 2-D array is converted to vector
        form with ``squareform_sp``; anything else is used as the vector
        directly.
    a : float
        Weight of the log-degree term.
    b : float
        Weight of the squared-norm term on ``w``.
    params : dict, optional
        Optional settings (defaults in parentheses):

        * ``'verbosity'`` (1): > 0 prints a final summary, > 1 prints
          per-iteration progress and allocates the ``stat`` traces, > 2 also
          fills those traces.
        * ``'maxit'`` (1000): maximum number of iterations.
        * ``'tol'`` (1e-5): stop when both relative changes fall below it.
        * ``'step_size'`` (0.5): passed to ``lin_map`` to pick the step.
        * ``'fix_zeros'`` (False): optimise only the edges selected by
          ``'edge_mask'`` (required in that case).
        * ``'max_w'`` (inf): upper bound on each weight.
        * ``'w_0'`` (0): prior weights (scalar, vector or square matrix).
          Anything other than the scalar 0 activates the prior term and
          requires ``'c'``.
        * ``'c'``: weight of the prior term.
        * ``'W_init'``: initial ``w`` (used as given; defaults to zeros).

    Returns
    -------
    W : ndarray
        Learned weights, converted back with ``squareform_sp`` when ``Z`` was
        a square matrix, otherwise the weight vector.
    stat : dict
        Always contains ``'time'``; with ``verbosity > 1`` also the
        ``'f_eval'``, ``'g_eval'``, ``'h_eval'``, ``'fgh_eval'`` and
        ``'pos_violation'`` traces (NaN where not recorded).

    Raises
    ------
    ValueError
        If a prior ``w_0`` is given without ``c``, or ``fix_zeros`` is set
        without ``edge_mask``.
    """
    if params is None:
        params = {}
    verbosity = params.get('verbosity', 1)
    maxit = params.get('maxit', 1000)
    tol = params.get('tol', 1e-5)
    step_size = params.get('step_size', 0.5)
    fix_zeros = params.get('fix_zeros', False)
    max_w = params.get('max_w', np.inf)

    w_0 = params.get('w_0', 0)
    # Decide once whether a prior is active. Comparing an array-valued w_0
    # against 0 inside an `if` would raise "truth value is ambiguous", so
    # only the scalar 0 counts as "no prior".
    has_prior = not (np.isscalar(w_0) and w_0 == 0)
    if has_prior and 'c' not in params:
        raise ValueError("When params.w_0 is specified, c must also be specified")
    c = params.get('c', 0.0)

    z_is_matrix = Z.ndim == 2 and Z.shape[0] == Z.shape[1]
    z = squareform_sp(Z) if z_is_matrix else Z
    z = z.ravel()
    l = len(z)
    # Invert l = n*(n-1)/2 to recover the number of nodes.
    n = int(round((1 + sqrt(1 + 8*l))/2))

    if np.isscalar(w_0):
        w_0 = float(w_0)
    else:
        # Accept plain sequences as well as array-like objects.
        if not hasattr(w_0, 'ndim'):
            w_0 = np.asarray(w_0)
        if w_0.ndim == 2 and w_0.shape[0] == w_0.shape[1]:
            w_0 = squareform_sp(w_0)
        w_0 = w_0.ravel()

    if fix_zeros:
        edge_mask = params.get('edge_mask', None)
        if edge_mask is None:
            raise ValueError("edge_mask must be provided when fix_zeros is True")
        if edge_mask.ndim == 2 and edge_mask.shape[0] == edge_mask.shape[1]:
            edge_mask = squareform_sp(edge_mask)
        edge_mask = edge_mask.ravel()
        # Restrict the problem to the edges that are allowed to be non-zero.
        ind = np.flatnonzero(edge_mask)
        z = z[ind].astype(float)
        if not np.isscalar(w_0):
            w_0 = w_0[ind].astype(float)
    else:
        z = z.astype(float)
        if not np.isscalar(w_0):
            w_0 = w_0.astype(float)

    w = params.get('W_init', np.zeros_like(z, dtype=float))

    # Linear operator S (edge weights -> per-node sums) and its adjoint.
    if fix_zeros:
        S, St = sum_squareform(n, edge_mask)
    else:
        S, St = sum_squareform(n)

    def K_op(w_):
        return S.dot(w_)

    def Kt_op(v_):
        return St.dot(v_)

    # With every edge present the operator norm of S is sqrt(2*(n-1));
    # with a mask it has to be estimated.
    if fix_zeros:
        norm_K = normest(S)
    else:
        norm_K = sqrt(2*(n-1))

    def f_eval(w_):
        return 2*np.dot(w_, z)

    def f_prox(w_, c_):
        # Shift by the linear term, then clip to [0, max_w].
        return np.minimum(max_w, np.maximum(0, w_ - 2*c_*z))

    def g_eval(x):
        # The log is undefined for non-positive entries: treat as +inf.
        if np.any(x <= 0):
            return np.inf
        return -a * np.sum(np.log(x))

    def g_star_prox(z_, c_):
        # Prox of the conjugate of g, expressed through prox_sum_log.
        z_ = np.asarray(z_, dtype=float)
        sol, _ = prox_sum_log(z_/(c_*a), 1/(c_*a), param={'verbose': -3})
        return z_ - c_*a * sol

    if has_prior:
        def h_eval(w_):
            return b*np.sum(w_**2) + c*np.sum((w_-w_0)**2)

        def h_grad(w_):
            return 2*((b+c)*w_ - c*w_0)
        h_beta = 2*(b+c)
    else:
        def h_eval(w_):
            return b*np.sum(w_**2)

        def h_grad(w_):
            return 2*b*w_
        h_beta = 2*b

    # Step size gn is derived from the gradient constant of h and ||S||.
    # NOTE(review): lin_map is a project helper; presumably it rescales
    # step_size from [0, 1] into [epsilon, (1 - epsilon)/mu] -- confirm.
    mu = h_beta + norm_K
    epsilon = lin_map(0.0, [0, 1/(1+mu)], [0, 1])
    gn = lin_map(step_size, [epsilon, (1-epsilon)/mu], [0, 1])

    v_n = K_op(w)

    stat = {}
    if verbosity > 1:
        stat['f_eval'] = np.full(maxit, np.nan)
        stat['g_eval'] = np.full(maxit, np.nan)
        stat['h_eval'] = np.full(maxit, np.nan)
        stat['fgh_eval'] = np.full(maxit, np.nan)
        stat['pos_violation'] = np.full(maxit, np.nan)
        print('Relative change of primal, dual variables, and objective fun')

    # Pre-set so the summary below is well defined even if maxit == 0.
    n_iter = 0
    rel_norm_primal = np.nan
    rel_norm_dual = np.nan

    t0 = time.time()
    for i in range(maxit):
        n_iter = i + 1

        # First forward step on the primal (Y_n) and dual (y_n) variables.
        Y_n = w - gn*(h_grad(w) + Kt_op(v_n))
        y_n = v_n + gn*(K_op(w))

        # Proximal step, then second forward step evaluated at the prox point.
        P_n = f_prox(Y_n, gn)
        p_n = g_star_prox(y_n, gn)
        Q_n = P_n - gn*(h_grad(P_n) + Kt_op(p_n))
        q_n = p_n + gn*(K_op(P_n))

        if verbosity > 2:
            stat['f_eval'][i] = f_eval(w)
            stat['g_eval'][i] = g_eval(K_op(w))
            stat['h_eval'][i] = h_eval(w)
            stat['fgh_eval'][i] = stat['f_eval'][i] + stat['g_eval'][i] + stat['h_eval'][i]
            stat['pos_violation'][i] = -np.sum(np.minimum(0, w))

        # Guard the relative-change denominators against (near-)zero iterates.
        denom_w = np.linalg.norm(w)
        if denom_w < 1e-15:
            denom_w = 1e-15
        denom_v = np.linalg.norm(v_n)
        if denom_v < 1e-15:
            denom_v = 1e-15

        rel_norm_primal = np.linalg.norm(-Y_n + Q_n)/denom_w
        rel_norm_dual = np.linalg.norm(-y_n + q_n)/denom_v

        if verbosity > 1 and verbosity <= 3:
            print(f'iter {i+1:4d}: {rel_norm_primal:6.4e}   {rel_norm_dual:6.4e}')

        # Correction step.
        w = w - Y_n + Q_n
        v_n = v_n - y_n + q_n

        if rel_norm_primal < tol and rel_norm_dual < tol:
            break

    stat['time'] = time.time() - t0
    if verbosity > 0:
        obj_val = f_eval(w) + g_eval(K_op(w)) + h_eval(w)
        print(f'# iters: {n_iter:4d}. Rel primal: {rel_norm_primal:6.4e} Rel dual: {rel_norm_dual:6.4e}  OBJ {obj_val:6.3e}')
        print(f'Time needed is {stat["time"]} seconds')

    if fix_zeros:
        # Scatter the optimised entries back into a full-length weight vector.
        full_w = np.zeros(l, dtype=float)
        full_w[ind] = w
        w = full_w

    # Return in the same form (matrix or vector) as the input Z.
    if z_is_matrix:
        W = squareform_sp(w)
    else:
        W = w

    return W, stat

In [None]:
def sum_squareform(n, mask=None):
    """
    Python version of the MATLAB function sum_squareform.

    Builds the sparse operator that turns the vector form ``w`` of a
    symmetric, zero-diagonal ``n``-by-``n`` matrix ``W`` (upper triangle,
    row by row) into the vector of its row sums.

    Parameters:
    -----------
    n : int
        Size of the matrix W (an n-by-n matrix).
    mask : array-like or None
        If given, must contain n(n-1)/2 entries (any shape; it is flattened).
        Non-zero entries indicate which elements in w = squareform(W) are
        considered; S then has one column per selected element.

    Returns:
    --------
    S : scipy.sparse.csc_matrix
        Matrix so that S * w = sum(W) for w = squareform(W).
    St : scipy.sparse.csc_matrix
        The adjoint of S (S transpose).

    Raises:
    -------
    ValueError
        If mask does not contain exactly n(n-1)/2 entries.

    The output is consistent with the MATLAB function:
    [S, St] = sum_squareform(n, mask)
    """
    n_pairs = n * (n - 1) // 2

    # End points (row < col) of every pair, in upper-triangle row-major
    # order, i.e. the same order in which squareform lays out w.
    node_a, node_b = np.triu_indices(n, k=1)

    if mask is not None:
        # Densify sparse masks, then flatten first so that row vectors,
        # column vectors and plain sequences are all measured the same way.
        if hasattr(mask, 'toarray'):
            mask = mask.toarray()
        mask = np.asarray(mask).ravel()
        if mask.size != n_pairs:
            raise ValueError('mask size must be n(n-1)/2')
        keep = np.flatnonzero(mask)
        node_a = node_a[keep]
        node_b = node_b[keep]

    ncols = node_a.size

    # St has one row per edge with ones in the columns of its two end points
    # (MATLAB: sparse([1:ncols, 1:ncols], [I, J], 1, ncols, n)). The edge
    # indices must be tiled, not repeated, so that entry k of the row list
    # lines up with entry k of the concatenated [I, J] column list.
    edge_ids = np.arange(ncols)
    row_indices = np.tile(edge_ids, 2)
    col_indices = np.concatenate((node_a, node_b))
    data = np.ones(2 * ncols, dtype=np.float64)

    St = coo_matrix((data, (row_indices, col_indices)), shape=(ncols, n)).tocsc()
    S = St.transpose().tocsc()

    return S, St
