In [452]:
import numpy as np
from scipy import linalg
from dataclasses import dataclass, field
import warnings
import math

In [23]:
# Define my types
# if working in Python 3.11 or earlier, replace:
#   type A = B
# with:
#   from typing import TypeAlias
#   A: TypeAlias = B
type DataMatrix = np.ndarray
type GramMatrix = np.ndarray
type GramMatrices = list[GramMatrix]
type Axis = int
type FactorMatrix = np.ndarray
type DiagonalOffset = int
type NumToKeep = int
type SparsityPercent = float
type StepSize = float

In [762]:
@dataclass
class KSParameterization:
    """
    (kronsum_i^K L_i) + c_i * I
    where trace(L_i) = 0
    """
    Ls: list[FactorMatrix]
    c: DiagonalOffset = field(default=0.0)
    K: int = field(init=False)
    ds: list[int] = field(init=False)

    def __post_init__(self):
        self.K = len(self.Ls)
        if self.K <= 1:
            raise ValueError("K must be greater than 1")
        
        for L in self.Ls:
            if len(L.shape) != 2:
                raise ValueError("L must be a square 2D matrix")
            if L.shape[0] != L.shape[1]:
                raise ValueError("L must be a square matrix")
            
        self.ds = [L.shape[0] for L in self.Ls]

        self.identify()

    def identify(self):
        """
        Adjust the parameterization to an identifiable form.
        """
        for L in self.Ls:
            trL = np.trace(L)
            L -= (trL / L.shape[0]) * np.eye(L.shape[0])
            self.c += trL / L.shape[0]

    def __add__(self, other: "KSParameterization") -> "KSParameterization":
        """
        Add two parameterizations.
        """
        if self.K != other.K:
            raise ValueError("K must be the same for both parameterizations")
        
        new_Ls = [self.Ls[i] + other.Ls[i] for i in range(self.K)]
        new_c = self.c + other.c
        return KSParameterization(new_Ls, new_c)

    def __mul__(self, scalar: float) -> "KSParameterization":
        """
        Scale the parameterization by a scalar.
        """
        new_Ls = [scalar * L for L in self.Ls]
        new_c = scalar * self.c
        return KSParameterization(new_Ls, new_c)
    
    def __rmul__(self, scalar: float) -> "KSParameterization":
        """
        Reverse multiplication with a scalar.
        """
        return self * scalar
    
    def __sub__(self, other: "KSParameterization") -> "KSParameterization":
        return self + (-1 * other)
    
    def actualize(self) -> np.ndarray:
        """
        Compute Kronecker sum of factor matrices.
        """
        D = np.prod(self.ds)
        out = np.zeros((D, D))
        for i in range(self.K):
            D_less = int(np.prod(self.ds[:i]))
            D_more = int(np.prod(self.ds[i + 1:]))
            L = self.Ls[i]
            out += np.kron(np.eye(D_less), np.kron(L, np.eye(D_more)))
        out += self.c * np.eye(D)
        return out

In [833]:
def partial_trace_gradient(
    L: KSParameterization,
    ell: Axis,
    X: DataMatrix,
    Ss: GramMatrices
) -> FactorMatrix:
    # We often add an L.c / L.K term as our parameterization of L splits into
    # trace-0 factor matrices and then a constant addition on diagonal;
    # this is the identifiable parameterization!
    gradient = 2 * (L.Ls[ell] + (L.c / L.K) * np.eye(L.ds[ell])) @ Ss[ell]

    # Matricize the data
    for i in range(L.K):
        if i == ell:
            # Add all other factors, i.e. not including original factor ell itself
            continue
        # This makes it so that first axis is ell, second is i
        newX = np.moveaxis(X, i, 0)
        if i < ell:
            newX = np.moveaxis(newX, ell, 0)
        else:
            # The position of ell has shifted
            newX=  np.moveaxis(newX, ell+1, 0)

        # And then we perform the indicated summation
        other = L.Ls[i]+(L.c/L.K)*np.eye(L.ds[i])
        gradient += np.einsum("ab...,bc,dc...->ad", newX, other + other.T, newX)

    # We are optimizing over triangular matrices
    return np.tril(gradient)

def full_trace_gradient(
    L: KSParameterization,
    X: DataMatrix,
    Ss: GramMatrices
) -> KSParameterization:
    return KSParameterization([
        partial_trace_gradient(L, ell, X, Ss)
        for ell in range(L.K)
    ])

def logterm_gradient(
    L: KSParameterization,
    X: DataMatrix,
    Ss: GramMatrices
) -> KSParameterization:
    
    T: np.ndarray = np.diag(L.Ls[0])
    for i in range(1, L.K):
        T = np.add.outer(T, np.diag(L.Ls[i]))
    T += L.c
    T = -2 / T

    return KSParameterization([
        np.diag(np.moveaxis(T, i, 0).reshape(L.ds[i], -1).sum(axis=1))
        for i in range(L.K)
    ])

def full_gradient(
    L: KSParameterization,
    X: DataMatrix,
    Ss: GramMatrices
) -> KSParameterization:
    """
    Compute the full gradient of the log likelihood with respect to the parameters.
    """
    return full_trace_gradient(L, X, Ss)# + logterm_gradient(L, X, Ss) # TODO: add back

def objective(
    L: KSParameterization,
    X: DataMatrix
) -> float:
    """
    The value of our objective function.
    """

    # Get value of log term
    T: np.ndarray = np.diag(L.Ls[0])
    for i in range(1, L.K):
        T = np.add.outer(T, np.diag(L.Ls[i]))
    #logdet: float = 2 * np.sum(np.log(T+L.c))

    # Get value of trace term
    to_sum = np.zeros_like(X)
    for i in range(L.K):
        X_swap = X.swapaxes(0, i)
        changed = (L.Ls[i] + (L.c / L.K) * np.eye(L.ds[i])) @ X_swap
        to_sum += changed.swapaxes(0, i)
    trace_term: float = (to_sum**2).sum()

    return trace_term# - logdet # TODO: add back

In [374]:
def hard_threshold_factor(
    L: FactorMatrix,
    s: NumToKeep
) -> FactorMatrix:
    """
    Hard thresholding of a factor matrix.
    Keep only the s largest off-diagonal elements by absolute value.
    """
    # Get the absolute values of the off-diagonal elements
    L_flat = np.abs(L[np.tril_indices(L.shape[0], -1)])
    
    # Sort the values
    sorted_indices = np.argsort(L_flat)[::-1]
    sorted_values = L_flat[sorted_indices]

    # Get the threshold value
    if s > 0:
        threshold = sorted_values[s - 1] if s < len(sorted_values) else 0
    else:
        threshold = np.inf

    # Remove all elements of L that are below the threshold
    new_L = L.copy()
    new_L[np.abs(new_L) < threshold] = 0
    
    # Ensure the diagonals got carried over correctly
    np.fill_diagonal(new_L, 0)
    new_L += np.diag(np.diag(L))
    
    return new_L

def hard_threshold(
    L: KSParameterization,
    s: NumToKeep
) -> KSParameterization:
    """
    Hard thresholding of a KSParameterization.
    Keep only the s largest off-diagonal elements by absolute value for each factor matrix.
    """
    return KSParameterization([
        hard_threshold_factor(L.Ls[i], s)
        for i in range(L.K)
    ], L.c)


In [815]:
def iterative_hard_thresholding(
    X: DataMatrix,
    *,
    s: NumToKeep = None,
    sp: SparsityPercent = None,
    eta: StepSize = None,
    L_init: KSParameterization = None,
    max_iter: int = 100,
    tol: float = 1e-6,
    verbose: bool = False,
    verbose_every: int = 10,
    warn_backtrack: bool = True
) -> KSParameterization:
    """
    Iterative hard thresholding for the KSParameterization.
    """

    # Ensure at least one of s or sp is provided
    if s is None and sp is None:
        raise ValueError("Either s or sp must be provided.")
    if s is not None and sp is not None:
        raise ValueError("Only one of s or sp should be provided.")
    
    off_diagonal_elements = sum((d**2-d)/2 for d in X.shape)
    if sp is not None:
        if sp < 0 or sp > 1:
            raise ValueError(f"sp must be between 0 and 1, is {sp}.")
        # Convert sparsity percent to number of elements
        s = int(off_diagonal_elements * sp)
    if s < 0 or s > off_diagonal_elements:
        raise ValueError(
            f"s must be between 0 and the number of off-diagonal elements, is {s}."
        )

    # Initialize the parameterization
    if L_init is None:
        L = KSParameterization([np.eye(d) for d in X.shape])
    else:
        L = L_init
    
    # Choose a good step size
    if eta is None:
        # In the paper, we see that we have smoothness constant L=3*||X||^2
        # and that 2/(3L) is a good step size
        eta = 2 / (9 * (X**2).sum())

    # Compute the Gram matrices
    Ss = [None for _ in range(X.ndim)]
    for i in range(X.ndim):
        X_mat = X.swapaxes(0, i).reshape(X.shape[i], -1)
        Ss[i] = X_mat @ X_mat.T

    # Iteratively update the parameterization
    prev_objective_value = np.inf
    for iter in range(max_iter):
        grad = full_gradient(L, X, Ss)
        #grad = partial_trace_gradient(L, iter % 2, X, Ss)
        #gradL = 0 * L
        #gradL.c = np.trace(grad) / grad.shape[0]
        #gradL.Ls[iter % 2] = grad - gradL.c * np.eye(grad.shape[0])
        #grad = gradL
        print(grad)
        L_old = L
        
        #print(L.Ls[0]+(L.c/L.K)*np.eye(L.ds[0]))
        L -= eta * grad
        #print(L.Ls[0]+(L.c/L.K)*np.eye(L.ds[0]))
        #print(grad.Ls[0]+(grad.c/grad.K)*np.eye(grad.ds[0]))
        #1/0

        # Hard thresholding
        #L = hard_threshold(L, s)
        #print(L.Ls[0]+(L.c/L.K)*np.eye(L.ds[0]))
        #1/0

        # Check convergence
        objective_value = objective(L, X)

        if math.isnan(objective_value):
            print(L_old)
            print(L)
            print(grad)
            raise ValueError("Objective function returned NaN, check your input data.")

        delta = prev_objective_value - objective_value
        prev_objective_value = objective_value
        if abs(delta) < tol:
            if verbose:
                print(f"Converged after {iter} iterations.")
            break

        if delta > 0 and warn_backtrack:
            warnings.warn("Objective function increased, check your step size.")

        if verbose and iter % verbose_every == 0:
            print(f"Iteration {iter}:\n\tDelta = {delta:.6f}\n\tObjective = {objective_value:.6f}")
    else:
        if verbose:
            print("Max iterations reached without convergence.")

    return L

In [25]:
def generate_data(L: KSParameterization) -> DataMatrix:
    """
    Generates data from the Cartesian LGAM model.

    Amusingly, even though this is a "generative model", a naive
    approach is **very** computationally expensive.

    We need to solve L^-1 * e where e is i.i.d. standard normal
    and L is a Kronecker sum of matrices.

    This could be done in O(d^[3K]) time... i.e. sextic for K=2!!!

    A less-naive way is to use specialized triangular-solvers.  These
    still require the computation of L and the solve, which is O(d^(2K)) time,
    i.e. quartic for K=2.  That's what is implemented here.

    A more advanced method can be achieved with the Bartels–Stewart algorithm
    in the K=2 case, which is implemented in scipy as `scipy.linalg.solve_sylvester`.
    But scipy does not implement this in the K>2 case, so I have not either.
    In the K=2 case, though, runtime is cubic - which is about as good as we could
    have ever hoped for!
    """

    # Half-naive approach: actualize L and solve triangular system:
    L_actual = L.actualize()

    # Generate i.i.d. standard normal data
    e = np.random.normal(size=L_actual.shape[0])

    # Solve the triangular system
    x = linalg.solve_triangular(L_actual, e, lower=True)

    # Reshape the data to the original dimensions
    data_shape = tuple(L.ds)
    data = x.reshape(data_shape)
    return data

In [173]:
def generate_KSParameterization(
    ds: list[int],
    sp: SparsityPercent
) -> DataMatrix:
    # Generate random triangular matrices with strictly positive diagonals
    Ls = [np.random.normal(size=(d, d)) for d in ds]
    for i, L in enumerate(Ls):
        # Make the matrix triangular
        L = np.tril(L)

        # Make the diagonal positive
        L_diag = np.diag(np.diag(L))
        np.fill_diagonal(L, 0)
        L += np.abs(L_diag)

        # Make the diagonal strictly positive
        L += np.eye(L.shape[0]) * 1e-6

        # Remove ~sp% of the off-diagonal elements randomly
        # Do this by creating a Bernoulli mask
        mask = np.random.rand(*L.shape) > sp

        # Ensure diagonal is not masked
        np.fill_diagonal(mask, 0)
        L[mask] = 0
        Ls[i] = L

    # Create the KSParameterization
    L = KSParameterization(Ls)
    return L

In [841]:
sp = 1
L_star = generate_KSParameterization((3, 4), sp)
X = generate_data(L_star)
#X /= (X**2).sum()
test_eta = 1#2/(3*(X**2).sum())
#print(test_eta)

L_hat = iterative_hard_thresholding(X, sp=1, verbose=True, verbose_every=1, max_iter=100, eta = test_eta, L_init=L_star)
L_hat.Ls[0] + (L_hat.c / L_hat.K) * np.eye(L_hat.ds[0])

KSParameterization(Ls=[array([[-1.93749664,  0.        ,  0.        ],
       [ 0.12048679,  3.06548227,  0.        ],
       [ 0.1408981 ,  1.58798185, -1.12798564]]), array([[ 1.57353672,  0.        ,  0.        ,  0.        ],
       [-1.32897904,  0.0869276 ,  0.        ,  0.        ],
       [ 1.54311226, -0.24808338, -0.32758958,  0.        ],
       [ 1.93695066, -0.49409615,  1.27725104, -1.33287474]])], c=np.float64(4.155440565414407), K=2, ds=[3, 4])
Iteration 0:
	Delta = inf
	Objective = 149.721445
KSParameterization(Ls=[array([[ 12.04558083,   0.        ,   0.        ],
       [  0.95419689, -23.65797388,   0.        ],
       [ -0.31016177,  -2.81830538,  11.61239306]]), array([[ -9.66978966,   0.        ,   0.        ,   0.        ],
       [  8.38786695,   5.49299303,   0.        ,   0.        ],
       [-15.8177095 ,   5.24662946,  -3.50436094,   0.        ],
       [ -8.29057376,   2.34371168,  -6.87040178,   7.68115757]])], c=np.float64(-20.6254544709073), K=2, ds=[3,



array([[-3.18583952e+82,  0.00000000e+00,  0.00000000e+00],
       [-5.24110928e+81,  7.24881330e+83,  0.00000000e+00],
       [ 1.42340090e+81, -1.47804008e+83, -4.49201938e+81]])

In [377]:
display(L_star.Ls[0])
hard_threshold_factor(L_star.Ls[0], 2)

array([[ 0.04262894,  0.        ,  0.        ],
       [-1.98214612, -0.89980744,  0.        ],
       [ 0.70614549,  0.60471845,  0.8571785 ]])

array([[ 0.04262894,  0.        ,  0.        ],
       [-1.98214612, -0.89980744,  0.        ],
       [ 0.70614549,  0.        ,  0.8571785 ]])

In [585]:
proj_mat = np.array(
    [
        [1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 1],
    ]
)
proj_mat2 = np.array(
    [
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [1, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [0, 0, 0],
    ]
)
test = np.tril(np.arange(1,10).reshape(3, 3), k=-1)
test_full = np.kron(test, test).reshape(1, -1)
print(test_full)

1/0
print(test @ proj_mat2 @ proj_mat2.T)
print((np.random.normal(size=(3, 3)).reshape(1, -1) @ proj_mat2 @ proj_mat2.T).reshape(3, 3))

P = proj_mat2 @ proj_mat2.T

X = generate_data(L_star)
S = X.reshape(-1, 1) * X.reshape(1, -1)
#print(np.linalg.eig(P @ np.kron(S, np.eye(9)) @ P).eigenvalues.min())
print(np.linalg.eig(proj_mat2.T @ np.kron(S, np.eye(9)) @ proj_mat2).eigenvalues.min())

[[ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0 16  0  0  0  0  0  0  0  0 28 32  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 28  0  0 32  0  0  0  0  0
  49 56  0 56 64  0  0  0  0]]


ZeroDivisionError: division by zero

In [626]:
s = 5
test = np.tril(np.arange(1,s**2+1).reshape(s, s), k=-1)
test_full = np.kron(test, test).reshape(1, -1)
print(test_full)

P = np.zeros(((test_full > 0).sum(), test_full.shape[1]))
print(P.shape)
for idx, val in enumerate(np.where(test_full.reshape(-1) != 0)[0]):
    P[idx, val] = 1
print(P @ test_full.reshape(-1))

L_star = generate_KSParameterization((s, s), sp)
X = generate_data(L_star)
S = np.kron(X.reshape(-1, 1) * X.reshape(1, -1), np.eye(s**2))
print(S.shape)
np.linalg.eigh(P @ S @ P.T).eigenvalues.min()

[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0  36   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0  66  72   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0  96 102 108   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0 126 132 138 144   0   0   0   0   0
    0   0   0   0   0   0   0   0   0 

np.float64(-1.7593405853817684e-13)

In [832]:
L_star = generate_KSParameterization((3, 4), 1)
X = generate_data(L_star)
L_star *= 10000
X /= (X**2).sum()
Ss = [None for _ in range(3)]
for i in range(2):
    X_mat = X.swapaxes(0, i).reshape(X.shape[i], -1)
    Ss[i] = X_mat @ X_mat.T
print(partial_trace_gradient(L_star, 0, X, Ss))

L1 = L_star.Ls[0] + (L_star.c / L_star.K) * np.eye(L_star.ds[0])
L2 = L_star.Ls[1] + (L_star.c / L_star.K) * np.eye(L_star.ds[1])
Ss = [X@X.T, X.T@X]
print(np.tril(2 * L1 @ Ss[0] + X @ (L2+L2.T) @ X.T))

[[121.01100424   0.           0.        ]
 [-20.98231325 212.41038805   0.        ]
 [-55.96243364 188.81583027 483.53322393]]
[[111.60204568   0.           0.        ]
 [-12.83292166 141.31708691   0.        ]
 [-50.93470447  67.095682   235.10454214]]


(L1 * I + I * L2)(L1 * I + I * L2)^T X
(L1L1^T * I)vXvX^T + (I * L2L2^T)vXvX^T + (L1 * L2^T)vXvX^T + (L1^T * L2)vXvX^T
v[X L1^T]v[X L1^T]^T v[L2 X]v[L2 X]^T + v[L2 X]v[X L1^T]^T + v[X L1^T]v[L2 X]^T
X L1^T L1 X^T + L2 X X^T L2^T + L2 X L1 X^T + X L1^T X^T L2

L1^1 L1 X^T X + L2^T L2 X X^T + L1 X^T L2 X + L1 X^T L2^T X
-> L1 X^T X + X^T L2^T X + X^T L2 X
-> L1 S + X^T sym L2 X
