# Factorwise Averages

In [None]:
import numpy as np

In [None]:
def factorwise_matrix(
    mat: "Mat to be factored",
    ds: "Dimensions of factor matrices",
    idx: "Indices",
    k: "Dimension"
):
    """
    A(i, j | k) notation from TeraLasso
    """
    
    # Stride-blocking based off of StackOverflow answer:
    # https://stackoverflow.com/a/8070716/10642078
    d = ds[k]
    d_left = np.prod(ds[:k]).astype(int)
    d_right = np.prod(ds[k+1:]).astype(int)
    size = d_left * d * d_right
    
    sz = mat.itemsize
    shape = (d * d_right, d * d_right, d_left, d_left)

    strides = sz * np.array([
        mat.shape[0] * d_left,
        d_left,
        mat.shape[0],
        1
    ])
    blocks = np.lib.stride_tricks.as_strided(mat, shape=shape, strides=strides)
    
    # Now that we've got the strides, we need to pick the right one based on the idxs
    i, j = idx
    
    # Last two dimensions are the sizes of the blocks
    specific_block =  blocks[i::d, j::d]
    
    # Now we need to de-stride the block
    return np.concatenate(
        np.concatenate(
            specific_block,
            axis=1
        ),
        axis=1
    )
    
n=8
m=8
a = np.arange(n*m).reshape(n,m)
print(a)
out = factorwise_matrix(a, [2, 2, 2], (0, 0), 1)

out

[[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]
 [16 17 18 19 20 21 22 23]
 [24 25 26 27 28 29 30 31]
 [32 33 34 35 36 37 38 39]
 [40 41 42 43 44 45 46 47]
 [48 49 50 51 52 53 54 55]
 [56 57 58 59 60 61 62 63]]


array([[ 0,  1,  4,  5],
       [ 8,  9, 12, 13],
       [32, 33, 36, 37],
       [40, 41, 44, 45]])

In [None]:
def factorwise_average(
    mat: "Mat to be factored",
    ds: "Dimensions of factor matrices",
    k: "Dimension"
):
    d = ds[k]
    d_left = np.prod(ds[:k]).astype(int)
    d_right = np.prod(ds[k+1:]).astype(int)
    d_non = d_left * d_right
    out = np.zeros((d_non, d_non))
    for i in range(d):
        out += factorwise_matrix(mat, ds, (i, i), k)
    return out / d_non
factorwise_average(a, [2, 2, 2], 1)

array([[ 4.5,  5. ,  6.5,  7. ],
       [ 8.5,  9. , 10.5, 11. ],
       [20.5, 21. , 22.5, 23. ],
       [24.5, 25. , 26.5, 27. ]])

In [None]:
def kronecker_factor(
    mat: "Mat to be factored",
    ds: "Dimensions of factor matrices",
    k: "Dimension"
):
    K = len(ds)
    A = factorwise_average(mat, ds, k)
    offset = (K-1)/K * np.trace(A) / ds[k]
    return A - offset * np.eye(A.shape[0])

kronecker_factor(a, [2, 2, 2], 1)

array([[-16.66666667,   5.5       ,   7.        ,   7.5       ],
       [  9.        , -12.16666667,  11.        ,  11.5       ],
       [ 21.        ,  21.5       ,   1.33333333,  23.5       ],
       [ 25.        ,  25.5       ,  27.        ,   5.83333333]])

In [None]:
def kron_sum(A, B):
    """
    Computes the kronecker sum of two square input matrices
    
    Note: `scipy.sparse.kronsum` is a thing that would
    be useful - but it seems that `scipy.sparse` is not
    yet a mature library to use.
    """
    a, _ = A.shape
    b, _ = B.shape
    return np.kron(A, np.eye(b)) + np.kron(np.eye(a), B)

In [None]:
factor_1 = np.arange(9).reshape(3, 3) - 4
factor_2 = np.arange(9).reshape(3, 3) - 4
b = kron_sum(factor_1, factor_2)
print(factor_1 == kronecker_factor(b, [3, 3], 0))
print(factor_2 == kronecker_factor(b, [3, 3], 1))

[[ True  True  True]
 [ True  True  True]
 [ True  True  True]]
[[ True  True  True]
 [ True  True  True]
 [ True  True  True]]


In [None]:
factor_1 = np.arange(9).reshape(3, 3) - 4
factor_2 = np.arange(9).reshape(3, 3)
factor_2 = factor_2 @ factor_2 - 60
b = kron_sum(factor_1, factor_2)
print(factor_1 == kronecker_factor(b, [3, 3], 0))
print(factor_2 == kronecker_factor(b, [3, 3], 1))

[[ True  True  True]
 [ True  True  True]
 [ True  True  True]]
[[ True  True  True]
 [ True  True  True]
 [ True  True  True]]


In [None]:
# TODO: Fix this
factor_1 = np.arange(9).reshape(3, 3) - 4
factor_2 = np.arange(4).reshape(2, 2) - 1.5
b = kron_sum(factor_1, factor_2)
print(factor_1)
print(kronecker_factor(b, [3, 2], 0))
print(factor_2)
print(kronecker_factor(b, [3, 2], 1))

[[-4 -3 -2]
 [-1  0  1]
 [ 2  3  4]]
[[-4.75  0.  ]
 [ 0.    4.75]]
[[-1.5 -0.5]
 [ 0.5  1.5]]
[[-1.33333333 -0.16666667 -0.66666667]
 [ 0.16666667  0.         -0.16666667]
 [ 0.66666667  0.16666667  1.33333333]]


## Fast KS-Diag

In [None]:
def kronsum_diag(
    *lams: "1D vectors to be kronsummed"
):
    ds = [len(lam) for lam in lams]
    skips = np.cumprod([1] + ds)
    total = skips[-1]
    
    out = np.arange(total)#np.zeros(total)
    print(out)
    
    for i in range(len(lams)):
        #out[::skips[i]] += lams[i]
        amnt = lams[i].shape[0]
        sz = lams[i].itemsize
        skip = skips[i]
        toset = np.lib.stride_tricks.as_strided(
            out,
            shape=(amnt, total // amnt),
            strides=(sz * skip, sz * total // skip)
        )
        print((skip, total // skip))
        print(toset)
    
    
kronsum_diag(
    np.arange(2),
    np.arange(3),
    np.arange(4)
)

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
(1, 24)
[[0 63606911108032 63606911108224 0 0 0 0 0 0 0 0 0]
 [1 7 2043 0 0 0 0 0 0 0 0 0]]
(2, 12)
[[ 0 12 63606911108032  0 63606911108224  0  0  0]
 [ 2 14  4 14  4 14  0  0]
 [ 4 16 14 14 14 14  0  0]]
(6, 4)
[[ 0  4  8 12 16 20]
 [ 6 10 14 18 22  8]
 [12 16 20 63606911108032 16 14]
 [18 22  8  6 22 20]]
