This notebook aims to empirically validate theoretical results, to build confidence that we haven't made a math mistake!

In [1]:
import numpy as np

# Objective Manipulations

In this section, we verify that the objective (TODO: LATEX) can be manipulated into the form (TODO: LATEX)

In [6]:
def random_Mss(
    num_groups: int,
    num_factors: int,
    starting_size: int = 3
) -> tuple[list[list[np.ndarray]], np.ndarray]:
    """
    The first factor has size `starting_size`, and every subsequent factor is 1 larger.
    This is to easily catch errors of multiplying things of wrong dimensionality.

    We're not restricting to lower-triangular here!
    """
    Mss = [
        [
            np.random.normal(size=(starting_size+factor, starting_size+factor))
            for factor in range(num_factors)
        ]
        for group in range(num_groups)
    ]
    X = np.random.normal(size=[starting_size+i for i in range(num_factors)])
    return Mss, X

In [33]:
np.random.seed(0)
Mss, X = random_Mss(2, 2)

def raw_trace_objective(Mss, X):
    """
    tr[(sum_g kron_f M(g∈f)) (sum_g kron_f M(g∈f))^T S]
    """
    full_M = 0
    for group in Mss:
        cur_M = group[0]
        for idx, factor in enumerate(group):
            if idx == 0:
                continue
            cur_M = np.kron(cur_M, factor)
        full_M += cur_M
    S = X.reshape(-1, 1) @ X.reshape(1, -1)
    return np.trace(full_M @ full_M.T @ S)

def one_kp_objective(Mss, X):
    """
    sum_g1 sum_g2 tr[(kron_f M(g1∈f)) (kron_f M(g2∈f))^T S]
    """

    full_trace = 0
    S = X.reshape(-1, 1) @ X.reshape(1, -1)
    for group1 in Mss:
        for group2 in Mss:
            cur_M1 = group1[0]
            cur_M2 = group2[0]
            for idx, (factor1, factor2) in enumerate(zip(group1, group2)):
                if idx == 0:
                    continue
                cur_M1 = np.kron(cur_M1, factor1)
                cur_M2 = np.kron(cur_M2, factor2)
            full_trace += np.trace(cur_M1 @ cur_M2.T @ S)
    return full_trace

def nmode_formalization(Mss, X):
    """
    sum_g1 sum_g2 vec[[X;{M^T_{f∈g1}}_f]]^Tvec[[X;{M^T_{f∈g2}}_f]]
    """
    full_dot = 0
    for group1 in Mss:
        for group2 in Mss:
            left = X
            right = X

            for axis, (factor1, factor2) in enumerate(zip(group1, group2)):
                left = left.swapaxes(-1, axis).dot(factor1).swapaxes(-1, axis)
                right = right.swapaxes(-1, axis).dot(factor2).swapaxes(-1, axis)

            full_dot += (left * right).sum()
    return full_dot

def matricization_method(Mss, X, chosen_axis):
    """
    Pick an axis f
    sum_g1 sum_g2 tr[
        L_{f_∈g1}^T mat_f[[X;{L_{f'∈g_1}}_{f'!=f}]]
        mat_f[[X;{L_{f'∈g_1}}_{f'!=f}]]^T L_{f_∈g2}
    ]
    """
    result = 0
    for group1 in Mss:
        for group2 in Mss:
            left = X
            right = X

            for axis, (factor1, factor2) in enumerate(zip(group1, group2)):
                if axis == chosen_axis:
                    continue
                left = left.swapaxes(-1, axis).dot(factor1).swapaxes(-1, axis)
                right = right.swapaxes(-1, axis).dot(factor2).swapaxes(-1, axis)
            
            factor1 = group1[chosen_axis]
            factor2 = group2[chosen_axis]
            left = left.swapaxes(chosen_axis, 0).reshape(factor1.shape[0], -1)
            right = right.swapaxes(chosen_axis, 0).reshape(factor2.shape[0], -1)

            left = factor1.T @ left
            right = factor2.T @ right

            result += np.trace(left @ right.T)
    return result



print(raw_trace_objective(Mss, X))
print(one_kp_objective(Mss, X))
print(nmode_formalization(Mss, X))
print(matricization_method(Mss, X, 0))
print(matricization_method(Mss, X, 1))

98.7935337739534
98.7935337739534
98.79353377395341
98.79353377395341
98.79353377395341


In [61]:
np.random.seed(0)
Mss, X = random_Mss(2, 2)

def gradient(Mss, X, chosen_axis, chosen_group):
    """
    Pick an axis f
    sum_g1 sum_g2 tr[
        L_{f_∈g1}^T mat_f[[X;{L_{f'∈g_1}}_{f'!=f}]]
        mat_f[[X;{L_{f'∈g_1}}_{f'!=f}]]^T L_{f_∈g2}
    ]

    TODO: DIFFERENTIATE
    """
    result = 0
    for g1, group1 in enumerate(Mss):
        for g2, group2 in enumerate(Mss):
            if g1 != chosen_group and g2 != chosen_group:
                continue
            left = X
            right = X

            for axis, (factor1, factor2) in enumerate(zip(group1, group2)):
                if axis == chosen_axis:
                    continue
                left = left.swapaxes(-1, axis).dot(factor1).swapaxes(-1, axis)
                right = right.swapaxes(-1, axis).dot(factor2).swapaxes(-1, axis)
            
            factor1 = group1[chosen_axis]
            factor2 = group2[chosen_axis]
            left = left.swapaxes(chosen_axis, 0).reshape(factor1.shape[0], -1)
            right = right.swapaxes(chosen_axis, 0).reshape(factor2.shape[0], -1)

            if g1 == chosen_group and g2 == chosen_group:
                right = factor2.T @ right
                result += 2 * left @ right.T
            elif g2 == chosen_group:
                left = factor1.T @ left
                result += (left @ right.T).T
            elif g1 == chosen_group:
                right = factor1.T @ right
                result += left @ right.T
            else:
                raise Exception("How?")
    return np.tril(result)

def gradient_(Lss, X):        
        grads = [[None] * len(Lss[0])] * len(Lss)

        full_log_term = 0
        for Ls in Lss:
            cur_log_term = np.diag(Ls[0])
            for L in Ls[1:]:
                cur_log_term = np.add.outer(cur_log_term, np.diag(L))
            full_log_term += cur_log_term
        full_log_term = (-2 / full_log_term)

        for group, Ls in enumerate(Lss):
            for factor, L in enumerate(Ls):
                if False:
                    grads[group][factor] = np.zeros_like(L)
                    continue

                L_sum = sum([_Ls[factor] for _Ls in Lss])
                X_mat = X
                for other_factor, other_L in enumerate(Ls):
                    if other_factor == factor:
                        continue
                    #X_mat = np.tensordot(X_mat, other_L, [other_factor, 1])
                    X_mat = X_mat.swapaxes(-1, other_factor).dot(other_L).swapaxes(-1, other_factor)
                X_mat = X_mat.swapaxes(0, factor).reshape(L.shape[0], -1)
                S = X_mat @ X_mat.T
                trace_term = S @ L_sum

                trace_term = np.tril(2 * trace_term)
                log_term = np.diag(full_log_term.swapaxes(factor, 0).reshape(L.shape[0], -1).sum(axis=1))
                frob_term = 2*0*L

                grads[group][factor] = trace_term + 0*log_term + 0*frob_term

        return grads[0][0]

print(gradient(Mss, X, 0, 0))
print(gradient_(Mss, X))

[[12.72398205  0.          0.        ]
 [-6.13481815  1.25950755  0.        ]
 [ 6.15714248  0.05041531  3.29593484]]
[[ 15.95915275   0.           0.        ]
 [  8.10775864  16.47250927   0.        ]
 [  5.89994266 -21.17256839 -10.36551601]]
