In [1]:
import copy
import numpy as np
import tensorly as tl
from tensorly.decomposition import parafac
import torch 
import time
tl.set_backend('pytorch')

In [2]:
class CP_ALS():
    """
    This class computes the Candecomp PARAFAC decomposition using 
    N-way Alternating least squares algorithm along with khatri rao product
    """
    def moveaxis(self, tensor: torch.Tensor, source: int, destination: int) -> torch.Tensor:
        """
        This method is from the implementation given in pytorch 
        https://github.com/pytorch/pytorch/issues/36048#issuecomment-652786245
        """
        dim = tensor.dim()
        perm = list(range(dim))
        if destination < 0:
            destination += dim
        perm.pop(source)
        perm.insert(destination, source)
        return tensor.permute(*perm)
    
    def unfold_tensor(self, tensor, mode):
        """ This method unfolds the given input tensor along with the specified mode.
        Input :
            tensor : Input tensor
            mode : Specified mode of unfolding
        Output :
            matrix : Unfolded matrix of the tensor with specified mode
        """
        #t = tensor.transpose(mode, 0)
        t = self.moveaxis(tensor, mode, 0)
        matrix = t.reshape(tensor.shape[mode], -1)
        return matrix
        #return torch.reshape(torch.moveaxis(tensor, mode, 0), (tensor.shape[mode], -1))

    ## Old functions
    #def perform_Kronecker_Product(self, t1, t2):
    #    t1_flatten = torch.flatten(t1)
    #    op = torch.empty((0, ))
    #    for element in t1_flatten:
    #        output = element*t2
    #        op = torch.cat((op, output))
    #    return op
    
    #def perform_Khatri_Rao_Product(self, t1, t2):
    #    # Check for criteria if the columns of both matrices are same
    #    r1, c1 = t1.shape
    #    r2, c2 = t2.shape
    #    if c1 != c2:
    #        print("Number of columns are different. Product can't be performed")
    #        return 0
    #    opt = torch.empty((r1*r2, c1))
    #    for col_no in range(0, t1.shape[-1]):
    #        x = self.perform_Kronecker_Product(t1[:, col_no], t2[:, col_no])
    #        opt[:, col_no] = x
    #    return opt
    
    # New functions
    def perform_Kronecker_Product(self, A, B):
        """ 
        This method performs the kronecker product of the two matrices
        The method is adaption of the method proposed in https://discuss.pytorch.org/t/kronecker-product/3919/10
        Input : 
            A : Input matrix 1
            B : Input matrix 2
        Output : 
            Output is the resultant matrix after kronecker product
        """
        return torch.einsum("ab,cd->acbd", A, B).view(A.size(0)*B.size(0),  A.size(1)*B.size(1))
    
    def perform_Khatri_Rao_Product(self, A, B):
        """
        This methods performs the Khatri Rao product as it is the column wise kronecker product
        Input : 
            A : Input matrix 1
            B : Input matrix 2
        Output : 
            result : The resultant Khatri-Rao product matrix
        """
        if A.shape[1] != B.shape[1]:
            print("Inputs must have same number of columns")
            return 0
        result = None
        for col in range(A.shape[1]):
            res = self.perform_Kronecker_Product(A[:, col].unsqueeze(0), B[:, col].unsqueeze(0))
            if col == 0:
                result = res
            else:
                result = torch.cat((result, res), dim = 0)
        return result.T

    def compute_MTTKRP(self, tensor_matrix, A, k_value):
        """
        This method computes the Matricized Tensor Times Khatri-Rao product
        between the unfolded tensor and the all other factors apart from kth factor.
        Input : 
            tensor_matrix : Unfolded tensor as a matrix
            A : Factor matrices
            k_value : index of kth matrix to be excluded
        Output : 
            B : Resultant MTTKRP matrix
        """
        A_matrix = copy.deepcopy(A)
        A_matrix.pop(k_value)
        krp_matrix = A_matrix[0]
        for index in range(1, len(A_matrix)):
            krp_matrix = self.perform_Khatri_Rao_Product(krp_matrix, A_matrix[index])
        B = torch.matmul(tensor_matrix, krp_matrix)
        return B
    
    def compute_V_Matrix(self, A, k_value):
        """
        This method computes the V value as a hadamard product of 
        outer product of every factort matrix apart from kth factor matrix.
        Input : 
            A : Factor matrices
            k_value : index of kth matrix to be excluded
        Output : 
            v : Resultant V matrix after the hadamard product
        """
        A_matrix = copy.deepcopy(A)
        A_matrix.pop(k_value)
        v = torch.matmul(A_matrix[0].T, A_matrix[0])
        for index in range(1, len(A_matrix)):
            p = torch.matmul(A_matrix[index].T, A_matrix[index])
            v = v*p
        return v
    
    def create_A_Matrix(self, tensor_shape, rank):
        """
        This method generates required number of factor matrices.
        Input : 
            tensor_shape : shape of the input tensor
            rank : Required rank of the factors
        Output : 
            A : Resultant list of factor matrices
        """
        A = []
        for i in tensor_shape:
            A.append(torch.randn((i, rank)))
        return A
    
    def compute_ALS(self, input_tensor, max_iter, rank):
        """
        This method is heart of this algorithm, this computes the factors and also lambdas of the algorithm.
        Input : 
            input_tensor : Tensor containing input values
            max_iter : maximum number of iterations
            rank : prescribed rank of the resultant factors
        Output : 
            A : factor matrices
            lmbds : column norms of each factor matrices
        """
        A = self.create_A_Matrix(input_tensor.shape, rank)
        lmbds = []
        for l_iter in range(0, max_iter):
            for k in range(0, len(A)):
                X_unfolded = self.unfold_tensor(input_tensor, k)
                Z = self.compute_MTTKRP(X_unfolded, A, k)
                V = self.compute_V_Matrix(A, k)
                A_k = torch.matmul(Z, torch.pinverse(V))
                l = torch.norm(A_k, dim=0)
                d_l = np.zeros((rank, rank))
                np.fill_diagonal(d_l, l)
                #A_k = np.dot(A_k, np.linalg.pinv(d_l))
                if l_iter == 0:
                    lmbds.append(np.linalg.norm(l))
                else:
                    lmbds[k] = np.linalg.norm(l)
                A[k] = A_k
        return A, lmbds
    
    def reconstruct_tensor(self, factors, norm, rank, ip_shape):
        """
        This method reconstructs the tensor given factor matrices and norms
        Input : 
            factors : factor matrices
            norm : column norms of every factor matrices
            rank : prescribed rank of the resultant factors
            ip_shape : Input tensor shape 
        Output : 
            M : Reconstructed tensor
        """
        M = 0       
        for c in range(0, rank):
            op = factors[0][:, c]
            for i in range(1, len(factors)):
                op = np.outer(op.T, factors[i][:, c])
            M += op
        M = np.reshape(M, ip_shape)
        return M

    def reconstruct_Three_Way_Tensor(self, a, b, c):
        """This method reconstructs the tensor from the rank one factor matrices
        Inputs: 
            a : First factor in CP decomposition
            b : Second factor in CP decomposition
            c : Third factor in CP decomposition
        Output:
            x_t : Reconstructed output tensor"""

        x_t = 0
        #row, col = a.shape()
        for index in range(a.shape[1]):
            x_t += torch.ger(a[:,index], b[:,index]).unsqueeze(2)*c[:,index].unsqueeze(0).unsqueeze(0)
        return x_t

    # Reconstruct the tensor from the factors
    def reconstruct_Four_Way_Tensor(self, a, b, c, d):
        """This method reconstructs the tensor from the rank one factor matrices
        Inputs: 
            a : First factor in CP decomposition
            b : Second factor in CP decomposition
            c : Third factor in CP decomposition
            d : Fourth factor in CP decomposition
        Output:
            x_t : Reconstructed output tensor"""

        x_t = 0
        #row, col = a.shape()
        for index in range(a.shape[1]):
            Y = (torch.ger(a[:, index], b[:, index]).unsqueeze(2)*c[:, index]).unsqueeze(3)*d[:,index].unsqueeze(0).unsqueeze(0)
            x_t += Y
            #x_t += torch.ger(a[:,index], b[:,index]).unsqueeze(2)*c[:,index].unsqueeze(0).unsqueeze(0)
        return x_t

In [9]:
# Experiment configuration.
ip_shape = (256, 64, 11, 11)  # shape of the synthetic 4-way input tensor
r_state = 0                   # seed for the random number generators
max_iter = 100                # iteration budget handed to the solvers
r = 64                        # CP rank requested from the solvers
# Seed from r_state rather than a second hard-coded literal so there is a
# single place to change the seed.
torch.manual_seed(r_state)
X_tensor = torch.randn(ip_shape)

In [None]:
# Instantiate the hand-written CP-ALS solver; the class keeps no state, so a
# single instance is reused by the cells below.
cp_als = CP_ALS()

In [None]:
# Time the hand-written ALS solver. perf_counter is monotonic, so the measured
# interval is not affected by system clock adjustments.
start = time.perf_counter()
A, lmbds =  cp_als.compute_ALS(X_tensor, max_iter, r)
end = time.perf_counter()
print("Run time in seconds: ", end-start)

In [None]:
# Show the shape of each of the four factor matrices, one per tensor mode.
for mode in range(4):
    print(A[mode].shape)

In [None]:
# X_tensor has four modes, so all four factors are needed; the three-way
# helper would drop A[3] and return a tensor of the wrong order.
recon_tensor = cp_als.reconstruct_Four_Way_Tensor(A[0], A[1], A[2], A[3])

In [None]:
# Display the tensor rebuilt from the CP-ALS factors.
print(recon_tensor)

In [11]:
# Reference run with tensorly's parafac on the same tensor, rank and iteration
# budget. random_state makes the random initialisation reproducible, and
# perf_counter gives a monotonic elapsed time. Index [1] of the result holds
# the factor matrices.
start = time.perf_counter()
X = parafac(X_tensor, r, n_iter_max=max_iter, init="random", random_state=r_state, verbose=True)[1]
end = time.perf_counter()
print("Run time: ", end-start)

reconstruction error=0.9947566390037537
iteration 1,  reconstraction error: 0.9933165311813354, decrease = 0.0014401078224182129, unnormalized = 1398.4619140625
iteration 2,  reconstraction error: 0.9925410747528076, decrease = 0.000775456428527832, unnormalized = 1397.3701171875
iteration 3,  reconstraction error: 0.9920414090156555, decrease = 0.0004996657371520996, unnormalized = 1396.6666259765625
iteration 4,  reconstraction error: 0.9916770458221436, decrease = 0.0003643631935119629, unnormalized = 1396.1536865234375
iteration 5,  reconstraction error: 0.9913944005966187, decrease = 0.00028264522552490234, unnormalized = 1395.7557373046875
iteration 6,  reconstraction error: 0.991166889667511, decrease = 0.00022751092910766602, unnormalized = 1395.4354248046875
iteration 7,  reconstraction error: 0.9909786581993103, decrease = 0.0001882314682006836, unnormalized = 1395.17041015625
iteration 8,  reconstraction error: 0.9908198118209839, decrease = 0.00015884637832641602, unnormali

iteration 67,  reconstraction error: 0.9889544248580933, decrease = 8.404254913330078e-06, unnormalized = 1392.320556640625
iteration 68,  reconstraction error: 0.9889463186264038, decrease = 8.106231689453125e-06, unnormalized = 1392.3092041015625
iteration 69,  reconstraction error: 0.9889383316040039, decrease = 7.987022399902344e-06, unnormalized = 1392.2979736328125
iteration 70,  reconstraction error: 0.9889305830001831, decrease = 7.748603820800781e-06, unnormalized = 1392.2869873046875
iteration 71,  reconstraction error: 0.9889230132102966, decrease = 7.569789886474609e-06, unnormalized = 1392.2763671875
iteration 72,  reconstraction error: 0.9889156222343445, decrease = 7.3909759521484375e-06, unnormalized = 1392.2659912109375
iteration 73,  reconstraction error: 0.9889084696769714, decrease = 7.152557373046875e-06, unnormalized = 1392.255859375
iteration 74,  reconstraction error: 0.9889016151428223, decrease = 6.854534149169922e-06, unnormalized = 1392.2462158203125
iterati

In [None]:
# The decomposed tensor has four modes, so all four tensorly factors are
# needed; the three-way helper would drop X[3].
recon_tensor_tl = cp_als.reconstruct_Four_Way_Tensor(X[0], X[1], X[2], X[3])

In [None]:
# Display the tensor rebuilt from the tensorly factors.
print(recon_tensor_tl)

In [8]:
# Dump the first three tensorly factor matrices (the fourth, X[3], is not shown).
# NOTE(review): this cell's execution count (8) is lower than that of the
# parafac cell that defines X (11) - confirm it still runs after a fresh
# Restart & Run All.
print(X[0], X[1], X[2])

tensor([[-0.0908,  0.2785,  0.0415,  ...,  0.1439,  0.0776, -0.0684],
        [ 0.1006, -0.0416, -0.1343,  ..., -0.0298, -0.1549,  0.0404],
        [ 0.0878,  0.2651,  0.0461,  ..., -0.0802,  0.1021, -0.0468],
        ...,
        [ 0.0656, -0.0098,  0.0185,  ..., -0.0786, -0.0512, -0.1057],
        [ 0.0534,  0.1074, -0.0121,  ..., -0.0820, -0.0613,  0.0450],
        [-0.0439,  0.0093, -0.2031,  ..., -0.0457,  0.0633,  0.0759]]) tensor([[ 0.4543, -0.9571,  0.8842,  ..., -0.3643, -0.4478, -0.5704],
        [ 0.3822,  0.6023,  0.6620,  ...,  0.8507,  0.6056,  0.6880],
        [ 0.4397,  0.5664, -0.0469,  ...,  0.8737,  0.1434, -0.7347],
        ...,
        [ 0.1160,  0.2364,  0.1595,  ..., -0.2168, -0.4983,  0.0615],
        [-0.0612,  0.3396, -0.3239,  ..., -0.1391, -0.1278,  0.5554],
        [-0.2399,  0.6649,  0.3770,  ...,  0.1180,  0.2870, -1.2194]]) tensor([[ 9.8575e-01,  3.5923e-01,  3.2304e-01,  4.2593e-01, -8.5133e-01,
          2.6809e-01, -3.0701e-01, -2.9293e-04,  9.8084e-0