In [7]:
# Import specific functions
import sys
import numpy as np
from numpy.linalg import eig, inv, pinv, eigvals
# from scipy.spatial.distance import squareform as squareform_sp
from scipy.sparse import csr_matrix, random as sparse_random, find, issparse
# from scipy.stats import multivariate_normal as mvnpdf
from pynndescent import NNDescent
import time
from scipy.spatial.distance import pdist, squareform
from scipy.linalg import inv
import numpy as np
import scipy.sparse as sp
from scipy.spatial.distance import squareform
from scipy.sparse import triu, coo_matrix
from sklearn.metrics import precision_score, recall_score, f1_score, normalized_mutual_info_score
from scipy.sparse.linalg import norm as sparse_norm
from scipy.spatial.distance import squareform
from math import sqrt
import networkx as nx
import matplotlib.pyplot as plt
from scipy.linalg import cholesky


sys.path.append('/Users/paul_reitz/Documents/repos/PyAWGLMM/Smooth_AWGLMM')
from scripts.utils import (
    visualize_glmm,
    graph_learning_perf_eval,
    identify_and_compare,
    generate_connected_graph,
    normest,
    lin_map,
    squareform_sp,
    # sum_squareform,
    prox_sum_log,
    gsp_distanz,
    gsp_compute_graph_learning_theta,
    gsp_symmetrize,
)
from scripts.utils_deep import mvnpdf

np.random.seed(17307946)

In [8]:
def sum_squareform(n, mask=None):
    """
    Python version of the MATLAB function sum_squareform.
    
    Parameters:
    -----------
    n : int
        Size of the matrix W (an n-by-n matrix).
    mask : array-like or None
        If given, must be of size n(n-1)/2. Non-zero entries indicate
        which elements in w = squareform(W) are considered.
    
    Returns:
    --------
    S : scipy.sparse.csc_matrix
        Matrix so that S * w = sum(W) for w = squareform(W).
    St : scipy.sparse.csc_matrix
        The adjoint of S (S transpose).
        
    The output is consistent with the MATLAB function:
    [S, St] = sum_squareform(n, mask)
    """
    if mask is None:
        mask_given = False
    else:
        mask_given = True

    if mask_given:
        # Check mask length
        if len(mask) != n*(n-1)//2:
            raise ValueError('mask size must be n(n-1)/2')
        
        # Find nonzero indices in mask
        # MATLAB code tries to handle both column vectors and row vectors
        # In Python, let's just do a generic approach:
        mask = np.asarray(mask).flatten()
        ind_vec = np.nonzero(mask)[0]  # 0-based indices of nonzero
        
        ncols = len(ind_vec)

        I = np.zeros(ncols, dtype=np.int64)
        J = np.zeros(ncols, dtype=np.int64)

        curr_row = 1
        offset = 0
        # length of current row of matrix, counting from after the diagonal
        length = n - 1
        for ii in range(ncols):
            ind_vec_i = ind_vec[ii] + 1  # switch to 1-based for logic
            # Move down rows until we find the correct row
            while ind_vec_i > (length + offset):
                offset += length
                length -= 1
                curr_row += 1
            # curr_row, J position in MATLAB indexing
            I[ii] = curr_row
            # Convert the position within the row into an absolute column index
            # J(ii) = ind_vec_i - offset + (n - length)
            J[ii] = ind_vec_i - offset + (n - length)

    else:
        # No mask given
        ncols = (n-1)*n//2
        I = np.zeros(ncols, dtype=np.int64)
        J = np.zeros(ncols, dtype=np.int64)

        # Fill I
        k = 0
        for i in range(2, n+1):
            end_idx = k+(n - i + 1)
            # I(k : k+(n-i)) = i:n in MATLAB (1-based)
            # That means for each i, we fill a block of size (n-i+1) with values from i to n.
            I[k:end_idx] = np.arange(i, n+1)
            k = end_idx

        # Fill J
        k = 0
        for i in range(2, n+1):
            end_idx = k+(n - i + 1)
            # J(k : k+(n-i)) = i-1 in MATLAB
            # That means for each block, we fill it with i-1 repeated.
            J[k:end_idx] = (i - 1)
            k = end_idx

    # Convert from 1-based to 0-based indexing
    I -= 1
    J -= 1

    # Construct St
    # St in MATLAB: sparse([1:ncols, 1:ncols],[I, J],1,ncols,n)
    # means each row i has two ones: one at column I(i), one at column J(i).
    # In Python (0-based), rows = 0..ncols-1
    row_indices = np.repeat(np.arange(ncols), 2)
    col_indices = np.hstack((I, J))
    data = np.ones(2*ncols, dtype=np.float64)

    St = coo_matrix((data, (row_indices, col_indices)), shape=(ncols, n)).tocsc()
    S = St.transpose().tocsc()

    return S, St


In [9]:
def generate_connected_graph(n, p, zero_thresh, maxit=10, verbose=1):
    """
    Generate a connected Erdos-Renyi graph using networkx and ensure 
    its second smallest eigenvalue of Laplacian is > zero_thresh.

    Parameters
    ----------
    n : int
        Number of nodes.
    p : float
        Probability of connection between nodes.
    zero_thresh : float
        Threshold for the second smallest eigenvalue to ensure connectivity.
    maxit : int, optional
        Maximum number of tries to get a connected graph (default 10).
    verbose : int, optional
        Verbosity level (default 1).

    Returns
    -------
    G_dict : dict
        A dictionary mimicking the G structure with keys:
        - 'W': adjacency/weight matrix
        - 'L': Laplacian matrix
        - 'N': number of nodes
        - 'type': 'erdos_renyi'
    
    Raises
    ------
    ValueError
        If after maxit attempts no connected graph is found.
    """
    for iteration in range(1, maxit+1):
        # Generate an Erdos-Renyi graph
        Gnx = nx.erdos_renyi_graph(n, p)
        W = nx.to_numpy_array(Gnx, dtype=np.float64)

        # Remove any self loops by zeroing diagonal
        np.fill_diagonal(W, 0.0)

        # Compute Laplacian
        d = np.sum(W, axis=1)
        D = np.diag(d)
        L = D - W

        # Compute eigenvalues
        e = np.linalg.eigvalsh((L + L.T)*0.5)
        e = np.sort(e)

        # Check connectivity condition
        if len(e) > 1 and e[1] > zero_thresh:
            if verbose > 1:
                print(f"A connected graph has been created in {iteration} iteration(s)")

            G_dict = {}
            G_dict['W'] = W
            G_dict['L'] = L
            G_dict['N'] = n
            G_dict['type'] = 'erdos_renyi'
            return G_dict
        else:
            if verbose > 1:
                print(f"Iteration {iteration} failed. Trying again.")

    if verbose:
        print("Warning: The graph is not strongly connected after maxit attempts.")
    raise ValueError("Could not generate a connected graph after maxit attempts.")







In [10]:
def gsp_learn_graph_log_degrees(Z, a, b, params=None):
    """
    Python version of gsp_learn_graph_log_degrees.

    Parameters
    ----------
    Z : array-like
        Matrix or vector of (squared) pairwise distances.
    a : float
        Log prior constant.
    b : float
        ||W||_F^2 prior constant.
    params : dict, optional
        Dictionary of parameters:
            - verbosity : int, default=1
            - maxit : int, default=1000
            - tol : float, default=1e-5
            - step_size : float in (0,1), default=0.5
            - fix_zeros : bool, default=False
            - max_w : float, default=inf
            - w_0 : array-like or scalar, optional
            - c : float, required if w_0 given
            - edge_mask : array-like, required if fix_zeros=True

    Returns
    -------
    W : ndarray
        Learned weighted adjacency matrix (or vector if Z was vector).
    stat : dict
        Dictionary with stats about the optimization.
    """

    if params is None:
        params = {}
    verbosity = params.get('verbosity', 1)
    maxit = params.get('maxit', 1000)
    tol = params.get('tol', 1e-5)
    step_size = params.get('step_size', 0.5)
    fix_zeros = params.get('fix_zeros', False)
    max_w = params.get('max_w', np.inf)

    # Convert Z to vector form if needed
    # Assume we have squareform_sp implemented
    if Z.ndim > 1:
        z = squareform_sp(Z)
    else:
        z = Z
    z = z.flatten()

    l = len(z)  # number of edges
    # n(n-1)/2 = l => n = (1 + sqrt(1+8*l))/2
    n = int(round((1 + np.sqrt(1+8*l))/2))

    if 'w_0' in params:
        w_0 = params['w_0']
        if 'c' not in params:
            raise ValueError("When params.w_0 is specified, params.c must also be specified")
        c = params['c']
        if w_0.ndim > 1:
            w_0 = squareform_sp(w_0)
        w_0 = w_0.flatten()
    else:
        w_0 = 0
        c = 0  # not used if w_0=0

    # If fix_zeros is set, we learn only a subset of edges
    if fix_zeros:
        edge_mask = params['edge_mask']
        if edge_mask.ndim > 1:
            edge_mask = squareform_sp(edge_mask)
        edge_mask = edge_mask.flatten()
        ind = np.nonzero(edge_mask)[0]
        z = z[ind]
        if not np.isscalar(w_0):
            w_0 = w_0[ind]
    else:
        # Make sure arrays are full
        z = np.array(z, dtype=float)
        if not np.isscalar(w_0):
            w_0 = np.array(w_0, dtype=float)

    # Initialization
    # If W_init given, use it
    if 'W_init' in params:
        w = params['W_init'].copy()
        if w.ndim > 1:
            w = squareform_sp(w)
        w = w.flatten()
        if fix_zeros:
            w = w[ind]
    else:
        w = np.zeros_like(z)

    # Get S, St from sum_squareform
    if fix_zeros:
        S, St = sum_squareform(n, edge_mask)
    else:
        S, St = sum_squareform(n)

    # Define linear operators K and Kt
    K_op = lambda w: S @ w
    Kt_op = lambda z: St @ z

    # norm_K (used for step sizing)
    if fix_zeros:
        norm_K = normest(S)
    else:
        norm_K = np.sqrt(2*(n-1))

    # Define f, g, h as in the code
    # f(w) = 2*w'*z, prox_f(w,c) = min(max_w, max(0, w - 2*c*z))
    def f_eval(w_):
        return 2 * np.dot(w_, z)

    def f_prox(w_, c_):
        return np.minimum(max_w, np.maximum(0, w_ - 2*c_*z))

    # g(w) = -a * sum(log(sum(W)))
    # We handle g through g_star_prox. We have prox_sum_log already.
    def g_eval_edges(w_):
        # g(K_op(w)) = -a * sum(log(K_op(w))) 
        tmp = K_op(w_)
        return -a * np.sum(np.log(tmp))

    param_prox_log = {'verbose': verbosity - 3}
    # def g_star_prox(z_, c_):
    #     # z - c*a * prox_sum_log(z/(c*a), 1/(c*a))
    #     return z_ - c_*a * prox_sum_log(z_/(c_*a), 1/(c_*a), param_prox_log)
    
    def g_star_prox(z_, c_):
        prox_val, _ = prox_sum_log(z_ / (c_ * a), 1 / (c_ * a), param_prox_log)  # Unpack the tuple
        prox_val = np.array(prox_val, dtype=float) 
        return z_ - c_*a * prox_val

    # h(w) = b||w||_F^2 + c||w - w_0||_F^2 if w_0 given
    if np.isscalar(w_0):
        def h_eval(w_):
            return b * np.sum(w_**2)
        def h_grad(w_):
            return 2 * b * w_
        h_beta = 2 * b
    else:
        def h_eval(w_):
            return b * np.sum(w_**2) + c * np.sum((w_ - w_0)**2)
        def h_grad(w_):
            return 2 * ((b+c)*w_ - c*w_0)
        h_beta = 2 * (b + c)

    # Parameters for FBF
    mu = h_beta + norm_K
    epsilon = lin_map(0.0, [0, 1/(1+mu)], [0,1])
    gn = lin_map(step_size, [epsilon, (1 - epsilon)/mu], [0,1])

    # Dual variable
    v_n = K_op(w)

    # statistics
    stat = {}
    if verbosity > 1 or True:
        stat['f_eval'] = np.full((maxit,), np.nan)
        stat['g_eval'] = np.full((maxit,), np.nan)
        stat['h_eval'] = np.full((maxit,), np.nan)
        stat['fgh_eval'] = np.full((maxit,), np.nan)
        stat['pos_violation'] = np.full((maxit,), np.nan)

    if verbosity > 1:
        print('Relative change of primal, dual variables, and objective fun')

    start_time = 0.0
    import time
    start_time = time.time()

    for i in range(maxit):
        # Forward-backward-forward updates
        Y_n = w - gn * (h_grad(w) + Kt_op(v_n))
        y_n = v_n + gn * (K_op(w))
        P_n = f_prox(Y_n, gn)
        p_n = g_star_prox(y_n, gn)
        Q_n = P_n - gn * (h_grad(P_n) + Kt_op(p_n))
        q_n = p_n + gn * (K_op(P_n))

        # Evaluate objective and constraints if needed
        stat['f_eval'][i] = f_eval(w)
        stat['g_eval'][i] = g_eval_edges(w)
        stat['h_eval'][i] = h_eval(w)
        stat['fgh_eval'][i] = stat['f_eval'][i] + stat['g_eval'][i] + stat['h_eval'][i]
        stat['pos_violation'][i] = -np.sum(np.minimum(0, w))

        # stopping criterion
        rel_norm_primal = np.linalg.norm(-Y_n + Q_n) / (np.linalg.norm(w) + 1e-15)
        rel_norm_dual = np.linalg.norm(-y_n + q_n) / (np.linalg.norm(v_n) + 1e-15)

        if verbosity > 2:
            print(f'iter {i+1:4d}: {rel_norm_primal:6.4e}   {rel_norm_dual:6.4e}   {stat["fgh_eval"][i]:6.3e}')
        elif verbosity > 1:
            print(f'iter {i+1:4d}: {rel_norm_primal:6.4e}   {rel_norm_dual:6.4e}')

        w = w - Y_n + Q_n
        v_n = v_n - y_n + q_n

        if rel_norm_primal < tol and rel_norm_dual < tol:
            break

    stat['time'] = time.time() - start_time
    if verbosity > 0:
        print(f'# iters: {i+1:4d}. Rel primal: {rel_norm_primal:6.4e} Rel dual: {rel_norm_dual:6.4e}  OBJ {stat["fgh_eval"][i]:6.3e}')
        print(f'Time needed is {stat["time"]:.6f} seconds')

    # If fix_zeros, we must put values back in their original positions
    if fix_zeros:
        full_w = np.zeros(l)
        full_w[ind] = w
        w = full_w

    # If input was a matrix, return W as a matrix
    if Z.ndim > 1:
        W = squareform_sp(w)
    else:
        W = w

    return W, stat


In [19]:
def glmm(y, iterations, classes, avg_nr_edges , spread=0.1, regul=0.15, norm_par=1.5, alpha=None):
    """
    Python version of glmm_matlab with custom cluster priors (alpha).

    Parameters
    ----------
    y : ndarray (m x n)
        Data matrix with m samples and n features.
    iterations : int
        Number of iterations.
    classes : int
        Number of classes (clusters).
    spread : float, optional
        Default 0.1
    regul : float, optional
        Default 0.15
    norm_par : float, optional
        Default 1.5
    alpha : array-like, optional
        Mixing coefficients for the clusters. Should sum to 1 and have length = classes.
        If None, alpha is initialized to 1/classes each.

    Returns
    -------
    L : ndarray (n x n x classes)
        Graph Laplacians for each class.
    gamma_hat : ndarray (m x classes)
        Cluster posterior probabilities.
    mu : ndarray (n x classes)
        Cluster means.
    log_likelihood : ndarray (iterations,)
        Log-likelihood at each iteration.
    """

    y = np.asarray(y, dtype=np.float64)
    n = y.shape[1]
    m = y.shape[0]

    L = np.zeros((n,n,classes))
    W = np.zeros((n,n,classes))
    sigma = np.zeros((n-1,n-1,classes))
    mu = np.zeros((n, classes))
    gamma_hat = np.zeros((m, classes))
    p = np.zeros(classes)
    vecl = np.zeros((n,n,classes))
    vall = np.zeros((n,n,classes))
    yl = np.zeros((m, n-1, classes))

    # Initialization
    # If alpha is provided, use it. Otherwise, initialize uniformly.
    if alpha is not None:
        alpha = np.asarray(alpha, dtype=float)
        if len(alpha) != classes:
            raise ValueError("Length of alpha must match number of classes.")
        if not np.allclose(np.sum(alpha), 1.0):
            raise ValueError("Alpha must sum to 1.")
        p = alpha
    else:
        p[:] = 1.0/classes

    for class_idx in range(classes):
        L[:,:,class_idx] = spread*np.eye(n) - (spread/n)*np.ones((n,n))
        mu_curr = np.mean(y, axis=0) + np.random.randn(n)*np.std(y,axis=0)
        mu_curr = mu_curr - np.mean(mu_curr)
        mu[:,class_idx] = mu_curr

    log_likelihood = np.zeros(iterations)

    for it in range(iterations):
        # E-step
        pall = np.zeros(m, dtype=np.float64)
        for class_idx in range(classes):
            # eigen decomposition
            eigvals, eigvecs = np.linalg.eig(L[:,:,class_idx])
            vecl[:,:,class_idx] = eigvecs
            vall[:,:,class_idx] = np.diag(eigvals)

            sub_eigvals = eigvals[1:]
            Sigma_inv = np.diag(sub_eigvals) + regul*np.eye(n-1)
            Sigma = np.linalg.inv(Sigma_inv)
            Sigma = (Sigma+Sigma.T)/2
            sigma[:,:,class_idx] = Sigma

            Y_centered = y - mu[:,class_idx]
            YL = Y_centered @ vecl[:,1:,class_idx]
            yl[:,:,class_idx] = YL

            mvn_val = mvnpdf(YL, np.zeros(n-1), Sigma)
            pall += p[class_idx]*mvn_val

        pall[pall==0] = 0.1

        for class_idx in range(classes):
            mvn_val = mvnpdf(yl[:,:,class_idx], np.zeros(n-1), sigma[:,:,class_idx])
            gamma_hat[:,class_idx] = (p[class_idx]*mvn_val)/pall

        log_likelihood[it] = np.sum(np.log(pall))

        # M-step
        for class_idx in range(classes):
            wght = gamma_hat[:,class_idx]
            mu[:,class_idx] = (wght @ y)/np.sum(wght)

            yc = (y - mu[:,class_idx])*np.sqrt(wght)[:,None]
            Z = gsp_distanz(yc)**2
            # Z = gsp_ann_distanz(yc, k=avg_nr_edges, metric_in='euclidean')**2
            # Z = gsp_symmetrize(Z, 'full')
            # theta = np.mean(Z)/norm_par
            theta = gsp_compute_graph_learning_theta(Z, avg_nr_edges, geom_mean=1)
            print(theta)

            # delta = 2 from code 
            delta = 1/theta
            W_curr, _ = gsp_learn_graph_log_degrees(theta * Z, 1, 1, params={})
            W_curr = delta*W_curr

            p[class_idx] = np.sum(wght)/m
            W_sum = np.sum(W_curr, axis=1)
            L[:,:,class_idx] = np.diag(W_sum)-W_curr
            W_curr[W_curr<1e-3] = 0
            W[:,:,class_idx] = W_curr

    return L, gamma_hat, mu, log_likelihood




In [31]:



# Example parameters
n = 15  # graph size
m = 150 # number of signals
k = 2   # number of clusters
zero_thresh = 10e-4
p = np.linspace(0, 1, k+1)  # p = 0:1/k:1 in MATLAB
print(p)

# Generate graphs as MATLAB does
g = []
for i in range(k):
    g.append(generate_connected_graph(n, 0.7, zero_thresh, maxit=10, verbose=2))

# Generate gamma and gamma_cut
gamma = np.random.rand(m, 1)  # gamma = rand([m,1]) in MATLAB
gamma_cut = np.zeros((m, k))

dist = 0.5
y = np.zeros((m, n))
true_y = np.zeros((m, n, k))
center = np.zeros((n, k))
gauss = np.zeros((n, n, k))
Lap = np.zeros((n, n, k))

for i in range(k):
    # In MATLAB: gc = pinv(full(g(i).L));
    # We have g[i] as a dict, use g[i]['L']
    L_mat = g[i]['L']
    gc = np.linalg.pinv(L_mat)
    gauss[:, :, i] = (gc + gc.T)/2
    Lap[:, :, i] = L_mat

    c = dist * np.random.randn(n)
    c = c - np.mean(c)
    center[:, i] = c

    # gamma_cut(p(i)<gamma & gamma<=p(i+1), i) = 1;
    mask = (gamma[:, 0] > p[i]) & (gamma[:, 0] <= p[i+1])
    gamma_cut[mask, i] = 1

    samples = np.random.multivariate_normal(center[:, i], gauss[:, :, i], m)
    samples = gamma_cut[:, i][:, np.newaxis] * samples
    true_y[:, :, i] = samples
    y += samples

# Now we train glmm on data y
# Assuming glmm function is defined and returns (Ls, gamma_hats, mus, log_likelihood)
avg_nr_edges = 12
iterations = 200
Ls, gamma_hats, mus, log_likelihood = glmm(y, iterations, k, avg_nr_edges, spread=0.1, regul=0.15, norm_par=1.5, alpha=None)
print('Training done')

print("sum(gamma_hats,1):", np.sum(gamma_hats, axis=0))

# If identify_and_compare returns (identify, precision, recall, f, cl_errors)
identify, precision, recall, f, cl_errors ,  NMI_scores, num_of_edges_arr = identify_and_compare(Ls, Lap, gamma_hats, gamma_cut, k)

print("Identify:", identify)
print("Precision:", precision)
print("Recall:", recall)
print("F-measure:", f)
print("Cluster Errors:", cl_errors)
print("NMI Scores:", NMI_scores)
print("Number of edges:", num_of_edges_arr)

summed_gamma_hats = np.sum(gamma_hats, axis=1)
are_all_elements_one = np.allclose(summed_gamma_hats, 1.0, atol=1e-8)
print("\nAre all elements in each row of gamma_hats summing to 1:", are_all_elements_one)


[0.  0.5 1. ]
A connected graph has been created in 1 iteration(s)
A connected graph has been created in 1 iteration(s)
0.07878855346143744
# iters:   70. Rel primal: 8.9638e-06 Rel dual: 4.2302e-08  OBJ 2.742e+00
Time needed is 0.119805 seconds
0.08179785505858687
# iters:   70. Rel primal: 9.1152e-06 Rel dual: 4.9647e-08  OBJ 2.716e+00
Time needed is 0.019137 seconds
0.09095758154826267
# iters:   69. Rel primal: 9.7773e-06 Rel dual: 5.9715e-08  OBJ 2.656e+00
Time needed is 0.003894 seconds
0.07413870013280317
# iters:   69. Rel primal: 9.9635e-06 Rel dual: 9.5741e-08  OBJ 2.817e+00
Time needed is 0.003863 seconds
0.0736164649653105
# iters:   69. Rel primal: 9.7319e-06 Rel dual: 7.2359e-08  OBJ 2.806e+00
Time needed is 0.003883 seconds
0.11045166022919503
# iters:   69. Rel primal: 9.7173e-06 Rel dual: 1.0531e-07  OBJ 3.568e+00
Time needed is 0.006017 seconds
0.16133604337230675


  return -a * np.sum(np.log(tmp))


# iters:   71. Rel primal: 9.4364e-06 Rel dual: 4.8111e-07  OBJ 8.226e+00
Time needed is 0.008220 seconds
0.10171010328349242
# iters:   68. Rel primal: 9.5978e-06 Rel dual: 3.1139e-07  OBJ 5.258e+00
Time needed is 0.008358 seconds
0.13688414444368954
# iters:   71. Rel primal: 9.5574e-06 Rel dual: 5.0772e-07  OBJ 8.233e+00
Time needed is 0.006137 seconds
0.21601116945836304
# iters:   75. Rel primal: 9.4982e-06 Rel dual: 3.1904e-07  OBJ 1.149e+01
Time needed is 0.004734 seconds
0.20922506831233245
# iters:   76. Rel primal: 8.9768e-06 Rel dual: 1.0664e-06  OBJ 1.261e+01
Time needed is 0.004277 seconds
0.20144692207082765
# iters:   74. Rel primal: 9.6098e-06 Rel dual: 5.2099e-07  OBJ 1.112e+01
Time needed is 0.008017 seconds
0.2245365992987501
# iters:   77. Rel primal: 9.0894e-06 Rel dual: 1.7282e-06  OBJ 1.320e+01
Time needed is 0.004332 seconds
0.19672426840975676
# iters:   74. Rel primal: 9.9721e-06 Rel dual: 5.3245e-07  OBJ 1.101e+01
Time needed is 0.004353 seconds
0.22446938057