In [157]:
import copy
import numpy as np
import tensorly as tl
from tensorly.decomposition import parafac
import torch 
import time
tl.set_backend('pytorch')

In [158]:
class CP_ALS():
    """
    This class computes the Candecomp PARAFAC decomposition using 
    N-way Alternating least squares algorithm along with khatri rao product
    """
    def moveaxis(self, tensor: torch.Tensor, source: int, destination: int) -> torch.Tensor:
        """
        This method is from the implementation given in pytorch 
        https://github.com/pytorch/pytorch/issues/36048#issuecomment-652786245
        """
        dim = tensor.dim()
        perm = list(range(dim))
        if destination < 0:
            destination += dim
        perm.pop(source)
        perm.insert(destination, source)
        return tensor.permute(*perm)
    
    def unfold_tensor(self, tensor, mode):
        """ This method unfolds the given input tensor along with the specified mode.
        Input :
            tensor : Input tensor
            mode : Specified mode of unfolding
        Output :
            matrix : Unfolded matrix of the tensor with specified mode
        """
        #t = tensor.transpose(mode, 0)
        t = self.moveaxis(tensor, mode, 0)
        matrix = t.reshape(tensor.shape[mode], -1)
        return matrix
        #return torch.reshape(torch.moveaxis(tensor, mode, 0), (tensor.shape[mode], -1))

    ## Old functions
    #def perform_Kronecker_Product(self, t1, t2):
    #    t1_flatten = torch.flatten(t1)
    #    op = torch.empty((0, ))
    #    for element in t1_flatten:
    #        output = element*t2
    #        op = torch.cat((op, output))
    #    return op
    
    #def perform_Khatri_Rao_Product(self, t1, t2):
    #    # Check for criteria if the columns of both matrices are same
    #    r1, c1 = t1.shape
    #    r2, c2 = t2.shape
    #    if c1 != c2:
    #        print("Number of columns are different. Product can't be performed")
    #        return 0
    #    opt = torch.empty((r1*r2, c1))
    #    for col_no in range(0, t1.shape[-1]):
    #        x = self.perform_Kronecker_Product(t1[:, col_no], t2[:, col_no])
    #        opt[:, col_no] = x
    #    return opt
    
    # New functions
    def perform_Kronecker_Product(self, A, B):
        """ 
        This method performs the kronecker product of the two matrices
        The method is adaption of the method proposed in https://discuss.pytorch.org/t/kronecker-product/3919/10
        Input : 
            A : Input matrix 1
            B : Input matrix 2
        Output : 
            Output is the resultant matrix after kronecker product
        """
        return torch.einsum("ab,cd->acbd", A, B).view(A.size(0)*B.size(0),  A.size(1)*B.size(1))
    
    def perform_Khatri_Rao_Product(self, A, B):
        """
        This methods performs the Khatri Rao product as it is the column wise kronecker product
        Input : 
            A : Input matrix 1
            B : Input matrix 2
        Output : 
            result : The resultant Khatri-Rao product matrix
        """
        if A.shape[1] != B.shape[1]:
            print("Inputs must have same number of columns")
            return 0
        result = None
        for col in range(A.shape[1]):
            res = self.perform_Kronecker_Product(A[:, col].unsqueeze(0), B[:, col].unsqueeze(0))
            if col == 0:
                result = res
            else:
                result = torch.cat((result, res), dim = 0)
        return result.T

    def compute_MTTKRP(self, tensor_matrix, A, k_value):
        """
        This method computes the Matricized Tensor Times Khatri-Rao product
        between the unfolded tensor and the all other factors apart from kth factor.
        Input : 
            tensor_matrix : Unfolded tensor as a matrix
            A : Factor matrices
            k_value : index of kth matrix to be excluded
        Output : 
            B : Resultant MTTKRP matrix
        """
        A_matrix = copy.deepcopy(A)
        A_matrix.pop(k_value)
        krp_matrix = A_matrix[0]
        for index in range(1, len(A_matrix)):
            krp_matrix = self.perform_Khatri_Rao_Product(krp_matrix, A_matrix[index])
        B = torch.matmul(tensor_matrix, krp_matrix)
        return B
    
    def compute_V_Matrix(self, A, k_value):
        """
        This method computes the V value as a hadamard product of 
        outer product of every factort matrix apart from kth factor matrix.
        Input : 
            A : Factor matrices
            k_value : index of kth matrix to be excluded
        Output : 
            v : Resultant V matrix after the hadamard product
        """
        A_matrix = copy.deepcopy(A)
        A_matrix.pop(k_value)
        v = torch.matmul(A_matrix[0].T, A_matrix[0])
        for index in range(1, len(A_matrix)):
            p = torch.matmul(A_matrix[index].T, A_matrix[index])
            v = v*p
        return v
    
    def create_A_Matrix(self, tensor_shape, rank):
        """
        This method generates required number of factor matrices.
        Input : 
            tensor_shape : shape of the input tensor
            rank : Required rank of the factors
        Output : 
            A : Resultant list of factor matrices
        """
        A = []
        for i in tensor_shape:
            A.append(torch.randn((i, rank)))
        return A
    
    def compute_ALS(self, input_tensor, max_iter, rank):
        """
        This method is heart of this algorithm, this computes the factors and also lambdas of the algorithm.
        Input : 
            input_tensor : Tensor containing input values
            max_iter : maximum number of iterations
            rank : prescribed rank of the resultant factors
        Output : 
            A : factor matrices
            lmbds : column norms of each factor matrices
        """
        A = self.create_A_Matrix(input_tensor.shape, rank)
        lmbds = []
        for l_iter in range(0, max_iter):
            for k in range(0, len(A)):
                X_unfolded = self.unfold_tensor(input_tensor, k)
                #X_unfolded = tl.unfold(input_tensor, mode = k)
                #try:
                ##print(torch.all(torch.eq(X_unfolded_mine, X_unfolded)))
                    #assert torch.all(torch.eq(X_unfolded_mine, X_unfolded))
                #except:
                    #print("Assertion failed")
                    #print("Unfolding from the implemented function: ")
                    #print(X_unfolded_mine)
                    #print("Unfolding from the tensorly library: ")
                    #print(X_unfolded)
                    #print("Iteration: ")
                    #print(l_iter)
                    #print("Mode: ")
                    #print(k)
                Z = self.compute_MTTKRP(X_unfolded, A, k)
                V = self.compute_V_Matrix(A, k)
                A_k = torch.matmul(Z, torch.pinverse(V))
                l = torch.norm(A_k, dim=0)
                d_l = np.zeros((rank, rank))
                np.fill_diagonal(d_l, l)
                #A_k = np.dot(A_k, np.linalg.pinv(d_l))
                if l_iter == 0:
                    lmbds.append(np.linalg.norm(l))
                else:
                    lmbds[k] = np.linalg.norm(l)
                A[k] = A_k
        return A, lmbds
    
    def reconstruct_tensor(self, factors, norm, rank, ip_shape):
        """
        This method reconstructs the tensor given factor matrices and norms
        Input : 
            factors : factor matrices
            norm : column norms of every factor matrices
            rank : prescribed rank of the resultant factors
            ip_shape : Input tensor shape 
        Output : 
            M : Reconstructed tensor
        """
        M = 0       
        for c in range(0, rank):
            op = factors[0][:, c]
            for i in range(1, len(factors)):
                op = np.outer(op.T, factors[i][:, c])
            M += op
        M = np.reshape(M, ip_shape)
        return M

    def reconstruct_Three_Way_Tensor(self, a, b, c):
        """This method reconstructs the tensor from the rank one factor matrices
        Inputs: 
            a : First factor in CP decomposition
            b : Second factor in CP decomposition
            c : Third factor in CP decomposition
        Output:
            x_t : Reconstructed output tensor"""

        x_t = 0
        #row, col = a.shape()
        for index in range(a.shape[1]):
            x_t += torch.ger(a[:,index], b[:,index]).unsqueeze(2)*c[:,index].unsqueeze(0).unsqueeze(0)
        return x_t

    # Reconstruct the tensor from the factors
    def reconstruct_Four_Way_Tensor(self, a, b, c, d):
        """This method reconstructs the tensor from the rank one factor matrices
        Inputs: 
            a : First factor in CP decomposition
            b : Second factor in CP decomposition
            c : Third factor in CP decomposition
            d : Fourth factor in CP decomposition
        Output:
            x_t : Reconstructed output tensor"""

        x_t = 0
        #row, col = a.shape()
        for index in range(a.shape[1]):
            Y = (torch.ger(a[:, index], b[:, index]).unsqueeze(2)*c[:, index]).unsqueeze(3)*d[:,index].unsqueeze(0).unsqueeze(0)
            x_t += Y
            #x_t += torch.ger(a[:,index], b[:,index]).unsqueeze(2)*c[:,index].unsqueeze(0).unsqueeze(0)
        return x_t

In [234]:
ip_shape = (3, 3, 3)
r_state = 0
max_iter = 100   
r = 64
torch.manual_seed(0)
X_tensor = torch.randn(ip_shape)

In [235]:
cp_als = CP_ALS()

In [236]:
start = time.time()
A, lmbds =  cp_als.compute_ALS(X_tensor, max_iter, r)
end = time.time()
print("Run time in seconds: ", end-start)

Run time in seconds:  5.731144666671753


In [237]:
print(A[0])
print(A[1])
print(A[2])

tensor([[-1.4708e-01, -5.1026e-03,  6.4594e-01, -8.3004e-01,  6.7953e-02,
          1.2228e+00,  1.9016e+00, -5.2583e-02, -3.7303e-01, -7.6715e-01,
          3.5953e-01,  5.6924e-01, -8.6432e-01, -2.6314e-01,  7.4567e-01,
         -1.1696e+00,  3.6001e-01,  2.1652e-07,  1.5804e-08, -3.1078e-08,
         -2.1852e-08,  4.0365e-10,  7.5392e-09, -1.3658e-08,  4.0514e-09,
         -2.5962e-10, -3.0019e-10,  4.6688e-09,  6.7207e-09,  7.3426e-09,
          2.3377e-09, -1.7781e-08, -2.5469e-10,  6.1672e-11,  4.3342e-17,
         -7.0002e-20, -1.0522e-22,  7.5701e-02, -3.1923e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.9755e-01, -2.0322e-01,  1.9703e-02, -2

In [238]:
print(X_tensor)

tensor([[[-1.1258, -1.1524, -0.2506],
         [-0.4339,  0.8487,  0.6920],
         [-0.3160, -2.1152,  0.3223]],

        [[-1.2633,  0.3500,  0.2660],
         [ 0.1665,  0.8744, -0.1435],
         [-0.1116, -0.6136,  0.0316]],

        [[-0.4927,  0.0537,  0.6181],
         [-0.4128, -0.8411, -2.3160],
         [-0.1023,  0.7924, -0.2897]]])


In [239]:
recon_tensor = cp_als.reconstruct_Three_Way_Tensor(A[0], A[1], A[2])

In [240]:
print(recon_tensor)

tensor([[[-1.1750, -1.2421, -1.5101],
         [-0.2495,  1.1822,  2.5442],
         [-0.2710, -2.1294, -0.5619]],

        [[-1.2854,  0.3125, -0.3325],
         [ 0.1641,  0.8365,  0.4278],
         [-0.1397, -0.5837, -0.4065]],

        [[-0.4803,  0.0993,  1.0400],
         [-0.5284, -0.9748, -3.3126],
         [-0.1776,  0.9075, -0.0899]]])


In [245]:
start = time.time()
X = parafac(X_tensor, r, n_iter_max= max_iter)[1]
end = time.time()
print("Run time: ", end-start)

Run time:  1.3552627563476562


In [246]:
recon_tensor_tl = cp_als.reconstruct_Three_Way_Tensor(X[0], X[1], X[2])

In [247]:
print(recon_tensor_tl)

tensor([[[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]],

        [[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]],

        [[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]]])


In [248]:
print(X[0], X[1], X[2])

tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, na