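# CP Decomposition using ALS

This notebook implements the Khatri-Rao product, mode-n matricization and the CP (CANDECOMP/PARAFAC) decomposition via alternating least squares (ALS) in PyTorch, and checks each piece against TensorLy.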

In [104]:
import torch
import numpy as np
import tensorly as tl # Used for verification

## Khatri-Rao Product

In [105]:
"""
Input: A: Array of matrices whose Khatri-Rao product is to be calculated
Output: B: Khatri-Rao product of the matrices in A
"""
def khatri_rao_product(A):
    N = len(A)
    num_cols = A[0].shape[1]
    x_dim = 1
    for i in range(N):
        x_dim *= A[i].shape[0]
    B = torch.empty((x_dim, num_cols))

    for i in range(num_cols):
        for j in range(N):
            if j == 0:
                kron_prod = A[j][:, i]
            else:
                kron_prod = torch.kron(kron_prod, A[j][:, i])
        B[:, i] = kron_prod

    return B

## Test Khatri-Rao Product

In [106]:
"""Khatri-Rao product using tensorly"""
print("Khatri-Rao product using tensorly")
A = [torch.rand(2, 3), torch.rand(1, 3), torch.rand(3, 3)]
Result = tl.tenalg.khatri_rao(A)
print("Result Shape: ", Result.shape)
print("Result:")
print(Result)

"""Khatri Rao product using khatri_rao_product()"""
print("\nKhatri Rao product using khatri_rao_product()")
B = khatri_rao_product(A)
print("Result Shape: ", B.shape)
print("Result:")
print(B)

Khatri-Rao product using tensorly
Result Shape:  (6, 3)
Result:
[[0.05185665 0.09278145 0.00429458]
 [0.03109804 0.131765   0.01447489]
 [0.06528176 0.14203328 0.01315795]
 [0.08305731 0.05487684 0.02318068]
 [0.04980883 0.07793418 0.07813057]
 [0.10455993 0.0840075  0.07102218]]

Khatri Rao product using khatri_rao_product()
Result Shape:  torch.Size([6, 3])
Result:
tensor([[0.0519, 0.0928, 0.0043],
        [0.0311, 0.1318, 0.0145],
        [0.0653, 0.1420, 0.0132],
        [0.0831, 0.0549, 0.0232],
        [0.0498, 0.0779, 0.0781],
        [0.1046, 0.0840, 0.0710]])


## Mode-N Matricization

In [107]:
"""
Input:  1) tensor: Input tensor
        2) n: mode along which to matricize the tensor (mode is 0-indexed)
Output: matrix: n-mode matricization of the tensor
"""
# Take mode = 0 for the first mode, mode = 1 for the second mode, ...., mode = n-1 for nth mode
def mode_n_matricization(tensor, n):
    mode = n
    # Get the size of the original tensor
    sz = tensor.size()
    # print(f"tensor size: {sz}")

    # Permute the dimensions of the tensor to bring the chosen mode to the front
    # This will make it easy to reshape the tensor into a matrix along the chosen mode
    permuted_dimensions = list(range(len(sz)))
#     print('Before Permuation, dimensions: ', permuted_dimensions)
    permuted_dimensions.remove(mode)
    permuted_dimensions.insert(0, mode)
#     print('After Permuation, dimensions: ', permuted_dimensions)
    permuted_tensor = tensor.permute(*permuted_dimensions)
#     print(f"permuted tensor size: {permuted_tensor.size()}")

    # Reshape the permuted tensor into a matrix along the chosen mode
    matrix = permuted_tensor.reshape(sz[mode], -1)
#     print(f"matrix size: {matrix.size()}")

#     print(f"n-mode matricization along mode {mode}:")
#     print(matrix)

    return matrix

## Test Mode-N Matricization

In [108]:
# 2 x 3 x 4 example tensor: slice k along mode 0 holds the values 12k+1 ... 12k+12,
# filled column by column (column j holds 12k+3j+1 ... 12k+3j+3).
my_tensor = torch.arange(1., 25.).reshape(2, 4, 3).transpose(1, 2).contiguous()
print(f"Original Tensor: {my_tensor.size()}")

# Show the unfolding along every mode in turn.
for n in range(my_tensor.dim()):
    print(f"\nn-mode matricization along mode {n}:")
    matrix = mode_n_matricization(my_tensor, n)
    print(matrix)

Original Tensor: torch.Size([2, 3, 4])

n-mode matricization along mode 0:
tensor([[ 1.,  4.,  7., 10.,  2.,  5.,  8., 11.,  3.,  6.,  9., 12.],
        [13., 16., 19., 22., 14., 17., 20., 23., 15., 18., 21., 24.]])

n-mode matricization along mode 1:
tensor([[ 1.,  4.,  7., 10., 13., 16., 19., 22.],
        [ 2.,  5.,  8., 11., 14., 17., 20., 23.],
        [ 3.,  6.,  9., 12., 15., 18., 21., 24.]])

n-mode matricization along mode 2:
tensor([[ 1.,  2.,  3., 13., 14., 15.],
        [ 4.,  5.,  6., 16., 17., 18.],
        [ 7.,  8.,  9., 19., 20., 21.],
        [10., 11., 12., 22., 23., 24.]])


## Einsum Equation Helper

In [109]:
"""
Input: N: Length of Array of factor matrices
Output: equation: equation string for torch.einsum()
"""
def einsum_equation(N):
    equation = ''    
    for i in range(N):
        if(i != 0):
            equation += ','
        equation += f'{chr(97 + i)}z'

    equation += '->'

    for i in range(N):
        equation += f'{chr(97 + i)}'
    # print(f"equation: {equation}")
    return equation

# Three random factor matrices (rank 3); only their count matters for the equation.
A = [torch.rand(rows, 3) for rows in (2, 1, 3)]
eqn = einsum_equation(len(A))
print("equation: " + eqn)

equation: az,bz,cz->abc


## CP-ALS

In [110]:
"""
Input:  1) T: Tensor of size (s_1, s_2, ..., s_N)
        2) R: CP-Rank of the Tensor
Output: 1) A: Array of N factor matrices of size (s_i, R)
"""
def CP_ALS(T, R, max_iter=100, tol=1e-5):
    N = len(T.shape)
    sz = T.shape
    A = []
    for n in range(N):
        An = torch.rand(sz[n], R)
        A.append(An)

    # Iterate until convergence
    for iter in range(max_iter):
        prev_factors = A.copy()

        for n in range(N):
            # Calculate the Khatri-Rao product of all the factor matrices except the n-th factor matrix
            B = khatri_rao_product(A[:n] + A[n+1:])
            B_pinv = torch.pinverse(B)

            # Calculate the n-mode matricization of the tensor
            Tn = mode_n_matricization(T, n)
            
            # Calculate the n-th factor matrix
            AnT = torch.mm(B_pinv, Tn.T)
            A[n] = AnT.T

        # Check for convergence
        if all(torch.norm(A[i] - prev_factors[i]) < tol for i in range(T.dim())):
            print(f"Converged at iteration {iter}")
            break

    return A  
    

## Test CP-ALS

In [116]:
"""CP Decomposition using tensorly"""
print("CP Decomposition using tensorly")
tensor = torch.rand((3, 4, 5, 6))
rank = 3

print(f"Rank: {rank}")
print(f"Tensor: {tensor.shape}")
print('Original Tensor:')
print(tensor)

T = tensor.numpy()
cp_decomp = tl.decomposition.CP(rank)
weights, factors = cp_decomp.fit_transform(T)
factor_matrices = [torch.from_numpy(f) for f in factors]

print('\nFactor Matrices:')
for i in range(len(factor_matrices)):
    print("\nFactor Matrix {} Shape: {}".format(i+1, factor_matrices[i].shape))
    print("Factor Matrix {}: ".format(i+1))
    print(factor_matrices[i])

reconstructed_tensor = tl.cp_tensor.cp_to_tensor((weights, factors))

# Norm difference between original tensor and reconstructed tensor
print('\nNorm difference between original tensor and reconstructed tensor: ', torch.norm(tensor - reconstructed_tensor).item())


"""CP Decomposition using CP_ALS()"""
print("\nCP Decomposition using CP_ALS()")
A = CP_ALS(tensor, rank)

print('Factor Matrices:')
for i in range(len(A)):
    print("\nFactor Matrix {} Shape: {}".format(i+1, A[i].shape))
    print("Factor Matrix {}: ".format(i+1))
    print(A[i])

# Norm difference between original tensor and reconstructed tensor
eqn = einsum_equation(len(A))
print('\nNorm difference between original tensor and reconstructed tensor: ', torch.norm(tensor - torch.einsum(eqn, *A)).item())

CP Decomposition using tensorly
Rank: 3
Tensor: torch.Size([3, 4, 5, 6])
Original Tensor:
tensor([[[[7.3303e-01, 2.7122e-01, 9.7729e-01, 1.4882e-01, 3.5822e-01,
           6.2768e-01],
          [4.2788e-01, 7.2460e-02, 3.1299e-03, 5.6017e-01, 3.8844e-01,
           6.4930e-01],
          [7.3361e-01, 5.9463e-02, 9.5667e-01, 9.0106e-01, 5.6100e-01,
           4.0784e-01],
          [4.9124e-02, 8.6707e-01, 5.6155e-01, 8.1165e-01, 4.1121e-02,
           3.8511e-01],
          [8.8898e-01, 2.9411e-01, 8.4099e-01, 3.9187e-01, 9.1817e-01,
           2.9660e-01]],

         [[4.5352e-01, 3.8871e-01, 1.3354e-02, 4.7324e-01, 6.5394e-02,
           7.7297e-01],
          [6.2402e-01, 6.4765e-01, 6.4291e-01, 3.0947e-04, 3.6882e-01,
           3.2977e-01],
          [5.3252e-01, 1.1666e-01, 3.0270e-01, 4.6903e-01, 9.8659e-01,
           5.3293e-01],
          [7.6254e-01, 5.6487e-01, 6.1243e-01, 3.9324e-01, 4.4052e-01,
           8.7232e-01],
          [6.7884e-01, 6.4211e-01, 3.4059e-01, 9.1122