In [1]:
import torch
import tensorly as tl
from scipy import linalg

Using numpy backend.


In [2]:
tl.set_backend('pytorch')

Using pytorch backend.


## The main function of SEMIBAT algorithm in paper

In [21]:
def semiBAT(X, Y, alpha, lamda, k):    ## k -- rank of CP decomposition, the size of core tensor
    Nsim = 3 ; tolerance = 10**(-4)   ## simulation times and torelance
    ## set constant
    nu_max = phi_max = psi_max = 10**6
    rho = 1.15
    
    size = X.size()
    m = size[0] ; t = size[2] ; n = size[3] ; l = Y.size(0)
    norm_X = torch.norm(X.view(-1))
    num_labels = Y.size(1)   ### c
    
    ## set initialization
    D = torch.cat((torch.eye(l), torch.zeros(l,n-l)),dim = 1)
    B = torch.rand(m,k) ; P = B
    T = torch.rand(t,k)
    A = torch.rand(k,n)                         ## since transpose after, flip  the number of rows and columns
    A = torch.nn.init.orthogonal_(A).t()        ## orthogonal transfer to row orthogonal(if not rows = cols)
    Q = A                                      ### final size: Q,A = (n,k)
    W = torch.rand(k,num_labels)
    
    y = torch.zeros(m,k)
    Phi = torch.zeros(n,k)
    Psi = torch.zeros(k,k)
    
    nu = phi = psi = 10**(-6)
    
    ## in case construct estimated X
    core = torch.zeros(k,k,k,k)
    for c in torch.arange(0,k):
        core[c,c,c,c] = 1           ### construct core matrix  
        
    for i in torch.arange(0,Nsim):
            ## update B and P
            E = tl.tenalg.khatri_rao([P,T,A])             
            ETE = P.t().matmul(P) * T.t().matmul(T) * A.t().matmul(A)
            B = (2 * tl.unfold(X,0).matmul(E) + nu * P + y).matmul(torch.inverse(2*ETE + nu * torch.ones(k,k)))
            
            F = tl.tenalg.khatri_rao([B,T,A])
            FTF = B.t().matmul(B) * T.t().matmul(T) * A.t().matmul(A)
            P = (2 * tl.unfold(X,1).matmul(F) + nu * B - y).matmul(torch.inverse(2*FTF + nu * torch.ones(k,k)))
            
            ## update T
            G = tl.tenalg.khatri_rao([B,P,A])
            GTG = B.t().matmul(B) * P.t().matmul(P) * A.t().matmul(A)
            T = (tl.unfold(X,2).matmul(G)).matmul(torch.inverse(GTG))
            
            ## update A   
            ## A_x,A_y,A_z is the intermediate in equation 12(for computing sylvester equation),
            ## not related to global parameter           
            H = tl.tenalg.khatri_rao([B,P,T])
            HTH = B.t().matmul(B) * P.t().matmul(P) * T.t().matmul(T)   
            A_x = psi * Q.matmul(Q.t())   ## n*n tensor
            A_y = 2 * HTH + 2 * alpha * W.matmul(W.t()) + phi * torch.ones(k,k)     ## k * k tensor
            A_z = (2 * tl.unfold(X,3).matmul(H) + 2 * alpha * D.t().matmul(Y).matmul(W.t()) + \
                               (phi + psi) * Q) + Phi - Q.matmul(Psi)
            
            A = linalg.solve_sylvester(A_x.numpy(),A_y.numpy(),A_z.numpy())  ##  solve sylvester equation
            A = torch.tensor(A)
            
            ## update Q
            #print(psi * A.matmul(A.t()) + phi*torch.ones(n,n))
            #AA = psi * A.matmul(A.t())
            #print(AA.size())
            Q = torch.inverse(psi * A.matmul(A.t()) + phi*torch.ones(n,n)).matmul((phi + psi) * A - \
                                                                       Phi - A.matmul(Psi.t()))
            #Q, LU = torch.gesv((phi + psi) * A - Phi - A.matmul(Psi.t()),psi * A.matmul(A.t()) + phi*torch.ones(n,n))
                        ## update W
            ## W_x,W_y,W_z is the intermediate in equation 12(for computing sylvester equation),
            ## not related to global parameter 
            epsilon = 10**(-6)
            for j in torch.arange(0,100):
                W1 = W
                Omega = torch.diag(1/ (2*(torch.norm(W.float(),dim = 0) + epsilon)))
                W_x = 2 * A.t().matmul(D.t()).matmul(D).matmul(A)   ## k*k tensor
                W_y = nu * Omega                                     ## c*c tensor
                W_z = 2 * A.t().matmul(D.t()).matmul(Y)             ## k*c tensor
                W = linalg.solve_sylvester(W_x.numpy(),W_y.numpy(),W_z.numpy())  ##  solve sylvester equation
                W = torch.tensor(W)
                # print(torch.norm((W - W1).view(-1)))
                if torch.norm((W - W1).view(-1)) < epsilon:
                    break
                
            
            ## update y,Phi, Psi
            y = y + nu * (P - B)
        
            Phi = Phi + phi * (Q - A)
            Psi = Psi + psi * (Q.t().matmul(A) - torch.zeros(k,k))
            
            ## update nu, phi, psi
            nu = min(rho * nu, nu_max)
            phi = min(rho * phi, phi_max)
            psi = min(rho * psi, psi_max)
            
            ### detect convergence
            X_new = tl.tucker_to_tensor(core,[B,P,T,A])  ## reconstruct estimated X
            # print(X_new.contiguous().view(-1))
            norm_X_new = torch.norm(X_new.contiguous().view(-1))
            if abs(norm_X - norm_X_new) <= tolerance:
                break
    
    matrices = [B,P,T,A]
    return(X_new)#,matrices)    
    

## Get data

In [17]:
import torch.distributions as tdist
n = tdist.Normal(torch.tensor(10.0), torch.tensor(5.0))
X = n.sample([10,10,10,10])
Y = torch.zeros(10,2)
Y[0:5,0] = 1
Y[5:10,1] = 1

In [18]:
Y

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.]])

In [19]:
X[0,0]

tensor([[ 5.7318,  5.6629, 10.2096, 11.1726, 19.5394, 14.2556,  6.1664,  4.8568,
          2.5029, 17.8772],
        [10.4028,  8.1539,  0.1932, 13.5247, 21.5223, 16.2710,  6.8368,  6.4153,
          6.8005,  5.4829],
        [ 2.6562, 14.3059,  9.6156,  6.9111, 10.7612,  5.7026, -0.7631,  5.1648,
          2.2546,  2.7816],
        [11.9526,  7.2200,  6.3333, 10.0906, -0.7316, 13.2061, 15.5956,  6.6772,
         17.5114, 12.8520],
        [ 2.3988,  9.2151, 11.1017, 18.2231, -1.4917, 10.3898,  9.7182, 16.0403,
         15.7369, 10.8352],
        [12.1595,  9.7827, -2.7010, 16.2035,  7.0356,  2.5066, 10.7060,  6.8664,
         10.4408,  9.3649],
        [ 2.6479, 10.6861, 19.0438, 20.2406,  5.1012,  5.1796, 12.1425, 17.3545,
          1.9534,  7.5155],
        [ 1.6773, 12.3488, 10.4153, 13.6157,  7.6716,  7.7400,  3.3476, 13.0081,
          7.1625,  7.0255],
        [ 4.4469,  9.5234, 18.7330,  7.5996,  7.5890,  7.3923,  5.8313,  9.6815,
          8.1342,  7.5134],
        [ 8.3495,  

In [22]:
semiBAT(X, Y, alpha = 1, lamda = 1, k = 6)

tensor([[[[0.2767, 0.3320, 0.3097,  ..., 0.2794, 0.2999, 0.3019],
          [0.2824, 0.3389, 0.3161,  ..., 0.2853, 0.3061, 0.3081],
          [0.2757, 0.3308, 0.3086,  ..., 0.2785, 0.2988, 0.3008],
          ...,
          [0.2766, 0.3318, 0.3095,  ..., 0.2793, 0.2997, 0.3017],
          [0.2829, 0.3395, 0.3166,  ..., 0.2857, 0.3066, 0.3087],
          [0.2667, 0.3201, 0.2985,  ..., 0.2694, 0.2891, 0.2910]],

         [[0.2807, 0.3366, 0.3142,  ..., 0.2834, 0.3043, 0.3061],
          [0.2846, 0.3414, 0.3186,  ..., 0.2875, 0.3085, 0.3104],
          [0.2768, 0.3320, 0.3098,  ..., 0.2795, 0.3000, 0.3019],
          ...,
          [0.2814, 0.3375, 0.3149,  ..., 0.2842, 0.3050, 0.3070],
          [0.2872, 0.3444, 0.3215,  ..., 0.2900, 0.3113, 0.3132],
          [0.2698, 0.3236, 0.3020,  ..., 0.2725, 0.2925, 0.2943]],

         [[0.2818, 0.3380, 0.3154,  ..., 0.2846, 0.3054, 0.3073],
          [0.2871, 0.3445, 0.3214,  ..., 0.2900, 0.3111, 0.3131],
          [0.2774, 0.3327, 0.3104,  ..., 0