In [6]:
import tensorly as tl
from utils import *
import numpy as np
from numpy import linalg as la
from sympy import *
from sympy.solvers.inequalities import *
from sympy.polys import Poly
from sympy.abc import x
from sympy.solvers.solveset import linsolve
import time
from joblib import Parallel, delayed

In [7]:
# Sum of three random rank-1 tensors (so the result has rank at most 3).
def rank_tree(size = 100):
    """Return the sum of three independent ``low_tensor(size)`` draws.

    Args:
        size: forwarded unchanged to every ``low_tensor`` call.

    Returns:
        The element-wise sum of the three tensors, added left to right.
    """
    total = low_tensor(size)
    for _ in range(2):
        # Out-of-place addition: never mutate a tensor returned by low_tensor.
        total = total + low_tensor(size)
    return total

# Random 3x2x2 tensor with integer-valued float entries.
def rand_tensor(size = 100000):
    """Draw a 3x2x2 tensor whose entries come from ``np.random.randint(1, size)``.

    Args:
        size: exclusive upper bound handed to ``np.random.randint``.

    Returns:
        ``tl.tensor`` of the integer draws multiplied by 1.0 (float-valued).
    """
    shape = (3, 2, 2)
    entries = np.random.randint(1, size, size=shape)
    return tl.tensor(entries) * 1.0

# Random rank-1 tensor: outer product of three random integer vectors.
def low_tensor(size = 100):
    """Build a 3x2x2 rank-1 tensor ``u (x) v (x) w`` from random integer vectors.

    Args:
        size: exclusive upper bound handed to ``np.random.randint`` for every
            vector entry (entries are drawn from ``[1, size)``).

    Returns:
        ``tl.tensor`` of the outer product multiplied by 1.0 (float-valued).
    """
    # Draw order (length 3, then 2, then 2) is kept so seeded runs are unchanged.
    u = np.random.randint(1, size, size=3)
    v = np.random.randint(1, size, size=2)
    w = np.random.randint(1, size, size=2)
    # Entry [i, j, k] is u[i] * v[j] * w[k].
    outer = u[:, None, None] * v[None, :, None] * w[None, None, :]
    return tl.tensor(outer) * 1.0

In [8]:

def proc_tensor(t):
    """Least-squares sign test on a 3x2x2 tensor.

    A ratio of two 3x3 determinants is formed from the fibres ``t[:, 0, 0]``,
    ``t[:, 0, 1]`` and either ``t[:, 1, 1]`` or ``t[:, 1, 0]``. That ratio is
    placed in a 6x5 linear system whose right-hand side stacks the rows
    ``t[k, 1]``; the system is solved through its normal equations.

    Args:
        t: tensor indexed as ``t[k, i, j]`` with ``k`` in 0..2 and ``i, j`` in 0..1.

    Returns:
        Truth value of "the determinant ratio and all five least-squares
        coefficients are >= 0". The residual of the fit is not checked.
    """
    work = t.copy()
    fibre_00 = Matrix(work[:, 0, 0])
    fibre_01 = Matrix(work[:, 0, 1])
    # Both 3x3 matrices share two rows; only the last one differs.
    with_11 = Matrix([[fibre_00, fibre_01, Matrix(work[:, 1, 1])]]).transpose()
    with_10 = Matrix([[fibre_00, fibre_01, Matrix(work[:, 1, 0])]]).transpose()

    det_10 = with_10.det()
    if abs(det_10) < 10:
        # Only a warning: the division below is still carried out.
        print("warning det too small")
    ratio = with_11.det() / det_10

    # Two equations per slice k: one carrying the constant 1, one carrying the
    # ratio, both in the column reserved for that slice (columns 2, 3, 4).
    rows = []
    for k in range(3):
        rows.append([work[k][0][0], 0] + [1 if j == k else 0 for j in range(3)])
        rows.append([0, work[k][0][1]] + [ratio if j == k else 0 for j in range(3)])
    system = Matrix(rows)
    rhs = Matrix([Matrix(work[k, 1]) for k in range(3)])

    gram = np.array(system.T @ system, dtype = "float")
    moment = np.array(system.T @ rhs, dtype = "float")
    coeffs = la.solve(gram, moment)
    stacked = np.append(np.array([ratio], dtype = "float"), coeffs)
    return np.all(stacked >= 0)


def test(tens):
    """Run every check on one tensor and return an 11-slot list of 0/1 flags.

    Slots written here:
        1 - ``proc_tensor`` accepted the tensor or one of its ``mat_trans`` /
            ``mat_inv`` / ``rotate`` images;
        2 - ``check_simple`` accepted it;
        8 - ``check_simple`` accepted it while the ``proc_tensor`` test did not;
        0 - at least one of the two tests above accepted it;
        3 - ``check_r4`` and ``r2_sub(upper=2)`` are both truthy;
        6 - ``r2_sub`` is truthy and ``check_r4`` is not;
        7 - as 6, and slot 0 is still 0.
    All other slots stay 0.
    """
    flags = [0] * 11

    # Lazily try the tensor itself and then each transformed view, stopping at
    # the first one proc_tensor accepts.
    views = (lambda z: z, mat_trans, mat_inv, rotate)
    if any(proc_tensor(view(tens)) for view in views):
        flags[0] = flags[1] = 1

    if check_simple(tens):
        flags[2] = 1
        if flags[0] == 0:
            flags[8] = 1
        flags[0] = 1

    has_r2 = r2_sub(tens, upper = 2)
    is_r4 = check_r4(tens)
    if is_r4 and has_r2:
        flags[3] = 1
    elif has_r2:
        flags[6] = 1
        if flags[0] == 0:
            flags[7] = 1
    return flags

def test_r2(tens):
    """Symbolic test built on the eigenvectors of the first two slices of ``tens``.

    Two rank-1 terms are formed from eigenvectors of ``tens[0] * tens[1]^-1``
    and of the same product of the transposed slices. A third column with an
    unknown ``a`` (two fixed layouts are tried) is appended together with the
    flattened ``tens[2]``; ``a`` is chosen so the resulting 4x4 determinant is
    zero, and the coefficients of ``tens[2]`` in the three columns are computed.

    Returns:
        Truthy when, for at least one of the two layouts, the three
        coefficients are all >= 0.

    Raises:
        AssertionError: if one of the Kronecker products has mixed signs.
        IndexError: if ``eigenvects()`` yields fewer than two entries or
            ``solve`` returns no root.
    """
    A = Matrix(tens[0]) * Matrix(tens[1]).inv()
    # eigenvects() entries are (eigenvalue, multiplicity, [vectors]); index 2 picks the vectors.
    res1 = Matrix(A.eigenvects()[0][2])
    res2 = Matrix(A.eigenvects()[1][2])

    B = Matrix(tens[0]).transpose() * Matrix(tens[1]).transpose().inv()
    res3 = Matrix(B.eigenvects()[0][2])
    res4 = Matrix(B.eigenvects()[1][2])
    
    # NOTE(review): pairing res1 with res3 and res2 with res4 presumably relies on
    # A and B listing their eigenvalues in the same order - confirm.
    T1 = np.kron(res1,res3) * 1.0
    T2 = np.kron(res2,res4) * 1.0
    # Each product must be sign-consistent so that abs() only flips a global sign.
    assert(np.all(T1 >= 0) or np.all(T1 <= 0))
    assert(np.all(T2 >= 0) or np.all(T2 <= 0))
    T1 = Matrix(abs(T1))
    T2 = Matrix(abs(T2))

    a = symbols('a')
    M = Matrix([[T1,T2]])
    # Two candidate layouts for the third column, parametrised by a.
    A1 = Matrix([a,a,1,1])
    A2 = Matrix([a,1,a,1])
    M1 = Matrix([[M,A1,Matrix(tens[2].reshape(4))]])
    M2 = Matrix([[M,A2,Matrix(tens[2].reshape(4))]])

    # First layout: take the first root of det(M1) = 0 and substitute it back.
    a1 = solve(M1.det())[0]
    A1 = Matrix([a1,a1,1,1])
    M1 = Matrix([[M,A1,Matrix(tens[2].reshape(4))]])
    # Solve only the first three equations; the fourth row is not used.
    M1_sub = M1[0:3,0:3]
    R = M1[0:3,3]
    d1 = False
    if abs(M1_sub.det()) > 0:
        c1 = np.array(M1_sub.inv() @ R, dtype = "float")
        d1 = np.all(c1 >=0)

    # Second layout: same steps with the other placement of a.
    a2 = solve(M2.det())[0]
    A2 = Matrix([a2,1,a2,1])
    M2 = Matrix([[M,A2,Matrix(tens[2].reshape(4))]])
    M2_sub = M2[0:3,0:3]
    R = M2[0:3,3]
    d2 = False
    if abs(M2_sub.det()) > 0:
        c2 = np.array(M2_sub.inv() @ R, dtype = "float")
        d2 = np.all(c2 >=0)
    return d1 or d2

def check_comb(tens):
    """Try ``test_r2`` on every slice pair that ``check_r2`` accepts.

    All three pairs (0,1), (0,2), (1,2) are screened with ``check_r2`` first.
    For each accepted pair, in that order, the slices are reordered so the pair
    comes first and ``test_r2`` is run; the first truthy result is returned.

    Returns:
        The first truthy ``test_r2`` result, otherwise a falsy value.
    """
    t = tens.copy()
    pairs = ((0, 1), (0, 2), (1, 2))
    # Screen every pair up front, before any test_r2 call.
    accepted = [check_r2(tl.tensor([t[i], t[j]])) for i, j in pairs]

    outcomes = [False, False, False]
    for slot, (i, j) in enumerate(pairs):
        if not accepted[slot]:
            continue
        rest = 3 - i - j  # index of the slice that is not part of the pair
        # The (0, 1) pair already has the required order, so t is used as is.
        candidate = t if slot == 0 else tl.tensor([t[i], t[j], t[rest]])
        outcomes[slot] = test_r2(candidate)
        if outcomes[slot]:
            return outcomes[slot]
    return outcomes[0] or outcomes[1] or outcomes[2]

In [9]:

# Assumes the nonnegative rank-2 subtensor consists of the first two slices.
def test_simple(tens, M, A1, tol = 0.0001):
    """Check whether a range of ``b`` makes every coefficient of ``tens[2]`` nonnegative.

    ``A1`` is a 4-entry column in the symbols ``a`` and ``b``. ``a`` is fixed
    (as a function of ``b``) so that ``det([M, A1, vec(tens[2])]) = 0``; rows
    1..3 of that matrix are then solved for the three coefficients of
    ``tens[2]``. Every quantity is replaced by numerator * denominator before
    the inequality ``>= 0`` is solved for ``b``.

    Args:
        tens: tensor whose third slice ``tens[2]`` is the target.
        M: matrix holding the two known columns (see ``init_mat``).
        A1: symbolic third column in ``a`` and ``b``.
        tol: minimum measure of the feasible set of ``b``.

    Returns:
        ``measure(feasible set) > tol``.
    """
    a, b = symbols('a,b')

    def sign_form(expr):
        # numerator * denominator has the same sign as the fraction itself
        numer, denom = fraction(simplify(expr))
        return numer * denom

    target = Matrix(tens[2].reshape(4))
    a_root = solve(Matrix([[M, A1, target]]).det(), a)[0]
    full = Matrix([[M, A1.subs(a, a_root), target]])

    # Only rows 1..3 are used to solve for the three coefficients.
    coeffs = full[1:4, 0:3].inv() @ full[1:4, 3]
    for idx in range(3):
        coeffs[idx] = sign_form(coeffs[idx])

    constraints = [sign_form(a_root), coeffs[0], coeffs[1], coeffs[2], b]
    regions = [solveset(quantity >= 0, b, S.Reals) for quantity in constraints]
    feasible = Intersection(*regions)
    return feasible.measure > tol


# Initializes the matrix with the appropriate rank-1 terms, given that the
# tensor has a nonnegative rank-2 subtensor in its first two slices.
def init_mat(tens):
    """Build the 4x2 matrix of rank-1 terms for the first two slices of ``tens``.

    Eigenvectors of ``tens[0] @ inv(tens[1])`` and of the same product of the
    transposed slices are combined with ``np.kron`` in two possible pairings.
    Each pairing is validated with ``mat_comb_small``; the entry-wise absolute
    values of an accepted pairing become the two columns of the result. If
    both pairings are accepted, the second one wins (as before).

    Args:
        tens: tensor whose slices ``tens[0]`` and ``tens[1]`` are square
            matrices, with ``tens[1]`` invertible.

    Returns:
        sympy ``Matrix`` with columns ``|T1|`` and ``|T2|``.

    Raises:
        ValueError: if neither pairing is accepted by ``mat_comb_small``.
            Previously this surfaced as an ``UnboundLocalError`` on ``ret``.
    """
    inv_second = la.inv(tens[1])
    eig = la.eig(tens[0] @ inv_second)
    eig1 = la.eig(tens[0].T @ inv_second.T)

    a1 = eig[1][:,0]
    b1 = eig[1][:,1]

    a2 = eig1[1][:,0]
    b2 = eig1[1][:,1]

    if eig[0][0] == eig[0][1]:
        print('identical eigenvalue, result should be discarded')

    # Use the two-slice subtensor to decide which eigenvector pairing
    # reconstructs it.
    temp = tl.tensor([tens[0],tens[1]])
    pairings = (
        (np.kron(a1,b2), np.kron(b1,a2)),
        (np.kron(a1,a2), np.kron(b1,b2)),
    )
    ret = None
    for T1, T2 in pairings:
        if mat_comb_small(temp,T1,T2):
            ret = Matrix([[Matrix(abs(T1)), Matrix(abs(T2))]])

    if ret is None:
        raise ValueError(
            "init_mat: no eigenvector pairing reconstructs the first two slices"
        )
    return ret

# Checks whether a candidate matrix decomposition is valid by reconstructing the tensor from it.
def mat_comb_small(tens,K1,K2):
    """Validate two candidate rank-1 terms against a two-slice tensor.

    Args:
        tens: tensor with two slices ``tens[0]`` and ``tens[1]``, each
            reshapeable to 4 entries.
        K1, K2: candidate terms, each reshapeable to 4 entries.

    Returns:
        ``False`` (after printing ``'returning'``) when either term has entries
        of mixed sign. Otherwise the truth value of "the least-squares
        reconstruction of ``tens`` from ``|K1|`` and ``|K2|`` deviates from
        ``tens`` by less than 0.1 in every entry".
    """
    vec1 = Matrix(K1.reshape(4))
    vec2 = Matrix(K2.reshape(4))

    # |sum| / sum|.| is 1 exactly when no entry disagrees in sign.
    purity1 = abs(sum(vec1) / sum(abs(vec1)))
    purity2 = abs(sum(vec2) / sum(abs(vec2)))
    if purity1 + purity2 != 2:
        print('returning')
        return False

    basis = abs(Matrix([[vec1,vec2]]))
    gram = np.array(basis.T @ basis, dtype = "float")
    # One least-squares fit (normal equations) per slice.
    fit0 = la.solve(gram, np.array(basis.T @ tens[0].reshape(4), dtype = "float"))
    fit1 = la.solve(gram, np.array(basis.T @ tens[1].reshape(4), dtype = "float"))

    weights1 = [fit0[0],fit1[0]]
    weights2 = [fit0[1],fit1[1]]
    part1 = np.kron(weights1,abs(vec1).reshape(1,4)).reshape(2,2,2)
    part2 = np.kron(weights2,abs(vec2).reshape(1,4)).reshape(2,2,2)
    return (np.max(abs(part1+part2-tens))<0.1)


    
# Checks whether a tensor with three 2x2 slices has a nonnegative rank-2
# subtensor and, if so, whether the remaining slice passes test_simple.
def check_simple(tens):
    """Run ``test_simple`` for every slice pair that ``check_r2`` accepts.

    The three pairs (0,1), (0,2), (1,2) are screened with ``check_r2`` up
    front. For each accepted pair, in that order, the slices are reordered so
    the pair comes first, ``init_mat`` builds the two known columns, and four
    symbolic layouts of the third column are tried in turn.

    Returns:
        The first truthy ``test_simple`` result, or ``False`` if there is none.
    """
    t = tens.copy()
    screened = (
        check_r2(tl.tensor([t[0],t[1]])),
        check_r2(tl.tensor([t[0],t[2]])),
        check_r2(tl.tensor([t[1],t[2]])),
    )
    orders = ((0, 1, 2), (0, 2, 1), (1, 2, 0))

    for pair_ok, order in zip(screened, orders):
        if not pair_ok:
            continue
        a, b = symbols('a,b')
        layouts = (
            Matrix([a*b,a,b,1]),
            Matrix([a,a*b,1,b]),
            Matrix([b,1,a*b,a]),
            Matrix([1,b,a,a*b]),
        )
        # The identity order needs no re-wrapping of the slices.
        temp = t if order == (0, 1, 2) else tl.tensor([t[k] for k in order])
        M = init_mat(temp)
        for layout in layouts:
            verdict = test_simple(temp,M,layout)
            if verdict:
                return verdict
    return False

def loop_rotations(i):
    """Classify ``tensors[i]`` and return an 11-slot list of 0/1 flags.

    Slots written here:
        1 - ``proc_tensor`` accepted the tensor or one of its ``mat_trans`` /
            ``mat_inv`` / ``rotate`` images;
        2 - ``check_simple`` accepted it;
        8 - ``check_simple`` accepted it while the ``proc_tensor`` test did not;
        0 - at least one of the two tests above accepted it;
        4 - ``check_r4`` is truthy;
        3 - ``check_r4`` and ``r2_sub(upper=1)`` are both truthy;
        6 - ``r2_sub`` is truthy and ``check_r4`` is not;
        7 - as 6, and slot 0 is still 0.

    Side effect: ``tensors[i]`` is replaced by ``(tensor, ret[0])``.
    NOTE(review): when this runs through joblib with a process-based backend
    the write presumably lands in the worker's copy of ``tensors`` - confirm
    before relying on it in the parent kernel.

    Changes from the earlier version: the two throw-away random tensors that
    were generated and immediately overwritten are gone, and an entry that was
    already replaced by a ``(tensor, flag)`` tuple is unwrapped, so the
    function can be called again for the same index.
    """
    entry = tensors[i]
    tens = entry[0] if isinstance(entry, tuple) else entry
    ret = [0] *11
    if proc_tensor(tens) or proc_tensor(mat_trans(tens)) or proc_tensor(mat_inv(tens)) or proc_tensor(rotate(tens)):
        ret[1] = 1
        ret[0] = 1
    if check_simple(tens):
        ret[2] = 1
        if ret[0] == 0:
            ret[8] = 1
        ret[0] = 1
    r2 = r2_sub(tens, upper = 1)
    r4 = check_r4(tens)
    if r4:
        ret[4] = 1
    if r4 and r2:
        ret[3] = 1
    if r2 and not r4:
        ret[6] = 1
        if ret[0] == 0:
            ret[7] = 1
    tensors[i] = (tens, ret[0])
    return ret

def loop_rotations_old(i):
    """Older classifier for ``tensors[i]``; returns an 11-slot list of 0/1 flags.

    Same slot layout as ``loop_rotations``, but slot 2 / slot 8 are driven by
    ``check_comb`` instead of ``check_simple``, ``r2_sub`` is called with
    ``upper=2``, and ``tensors`` is left untouched.

    Slots written here:
        1 - ``proc_tensor`` accepted the tensor or one of its ``mat_trans`` /
            ``mat_inv`` / ``rotate`` images;
        2 - ``check_comb`` accepted it;
        8 - ``check_comb`` accepted it while the ``proc_tensor`` test did not;
        0 - at least one of the two tests above accepted it;
        4 - ``check_r4`` is truthy;
        3 - ``check_r4`` and ``r2_sub`` are both truthy;
        6 - ``r2_sub`` is truthy and ``check_r4`` is not;
        7 - as 6, and slot 0 is still 0.

    Change from the earlier version: the two throw-away random tensors that
    were generated and immediately overwritten are no longer created.
    """
    tens = tensors[i]
    ret = [0] *11
    if proc_tensor(tens) or proc_tensor(mat_trans(tens)) or proc_tensor(mat_inv(tens)) or proc_tensor(rotate(tens)):
        ret[1] = 1
        ret[0] = 1
    if check_comb(tens):
        ret[2] = 1
        if ret[0] == 0:
            ret[8] = 1
        ret[0] = 1
    r2 = r2_sub(tens, upper = 2)
    r4 = check_r4(tens)
    if r4:
        ret[4] = 1
    if r4 and r2:
        ret[3] = 1
    if not r4 and r2:
        ret[6] = 1
        if ret[0] == 0:
            ret[7] = 1
    return ret

In [10]:
# Pre-generate the pool of random tensors that loop_rotations / loop_rotations_old
# and the sequential cell below index into via the global `tensors`.
# No random seed is set, so the pool differs on every run.
total = 10000
tensors = []
for i in range(total):
    tensors.append(rand_tensor())

In [50]:
# Parallel run: classify the first `total` tensors with loop_rotations on 6 workers.
te = time.time()
# `total` is rebound here, so only tensors[0:2000] of the 10000 generated are used.
total = 2000
results = Parallel(n_jobs=6)(delayed(loop_rotations)(i) for i in range(total))
# Column-wise sum of the per-tensor flag lists: res[k] counts tensors with flag k set.
res = [sum(x) for x in zip(*results)]
print(res)
# Elapsed wall-clock seconds.
print(time.time() - te)

[318, 282, 108, 86, 1010, 0, 187, 63, 36, 0, 0]
272.9294993877411


In [16]:
# Sequential run over tensors[0:1000]; the loop body repeats the classification
# done in loop_rotations inline. It overwrites the globals `results` and `res`,
# and `res` changes meaning here: it becomes a list of (tensor, flag) pairs.
# NOTE(review): this reads tensors[i] as a bare tensor, so it presumably must run
# before any in-kernel call to loop_rotations replaces entries with tuples - confirm.
results = []
res = []
te = time.time()
for i in range(1000):
    tens = tensors[i]
    ret = [0] *11
    if proc_tensor(tens) or proc_tensor(mat_trans(tens)) or proc_tensor(mat_inv(tens)) or proc_tensor(rotate(tens)):
        ret[1] = 1
        ret[0] = 1
    if check_simple(tens):
        ret[2] = 1
        if ret[0] == 0:
            ret[8] = 1
        ret[0] = 1
    r2 = r2_sub(tens, upper = 1)
    r4 = check_r4(tens)
    if r4:
        ret[4] = 1
    if r4 and r2:
        ret[3] = 1
    if not r4 and r2:
        ret[6] = 1
        if ret[0] == 0:
            ret[7] = 1
    # Flag is 1 when the tensor was accepted (ret[0]) or check_r4 was truthy (ret[4]).
    # NOTE(review): ret[0] appears twice in max(); possibly a different index was
    # intended for the third argument - confirm.
    res.append((tensors[i],max(ret[0],ret[4], ret[0])))
    results.append(ret)
print(time.time() - te)
# Per-flag totals over the 1000 tensors.
print([sum(x) for x in zip(*results)])

516.8704991340637
[191, 168, 60, 50, 480, 0, 101, 34, 23, 0, 0]


In [18]:
# For every tensor from the sequential run whose flag is not 1, record it with
# label 3 when check_rank(tensor, 3, n=20) is truthy and with label 4 otherwise.
# NOTE(review): check_rank is not defined in this notebook (presumably from
# `utils`); the meaning of `n` should be confirmed there.
wrongs = []
for tens in res:
    if tens[1] != 1:
        if check_rank(tens[0],3,n=20):
            wrongs.append((tens[0],3))
        else:
            wrongs.append((tens[0],4))

In [20]:
# Count how many of the unflagged tensors were given label 3 by the previous cell.
counter = 0
for tens in wrongs:
    if tens[1] == 3:
        counter +=1
print(counter)

20


In [130]:
# Scratch cell: displays whatever `tens` currently holds in the kernel.
# NOTE(review): `tens` is not assigned in this cell (the sampling loop below is
# commented out), so the displayed value depends on previously executed cells.
c2 = False

#while not (c2):
#    tens = rank_tree(size = 60)
#    c2 = check_r2(tl.tensor([tens[0],tens[1]]))
tens    

array([[[46165., 33523.],
        [50065., 40113.]],

       [[62570., 44942.],
        [67889., 52781.]],

       [[40435., 28681.],
        [50791., 37699.]]])

In [197]:
# Exploratory, step-by-step version of rank_three_comb on a single random sample.
tens = rank_tree(size = 50)
temp = tl.tensor([tens[0],tens[1]])
# Coefficients expressing the [1,1] entries of the first two slices through
# their [1,0] and [0,1] entries (2x2 linear system).
sol = Matrix([[Matrix(temp[:,1,0]),Matrix(temp[:,0,1])]]).inv() * Matrix(temp[:,1,1])
# Reconstruction of the two-slice subtensor; the value is discarded because
# this is not the last expression of the cell.
np.kron(temp[:,0,1],np.kron([1,sol[1]],[0,1])).reshape(2,2,2) + np.kron(temp[:,1,0],np.kron([0,1],[1,sol[0]])).reshape(2,2,2) + np.kron(temp[:,0,0],np.kron([1,0],[1,0])).reshape(2,2,2)
# The three flattened 2x2 patterns used by the reconstruction.
A = Matrix(np.kron([0,1],[1,sol[0]]))
B = Matrix(np.kron([1,sol[1]],[0,1]))
C = Matrix(np.kron([1,0],[1,0]))
M = Matrix([[A,B,C]])
# Least-squares fit (normal equations) of the third slice onto those patterns.
R1 = tens[2].reshape(4)
sols = la.solve(np.array(M.T @ M, dtype = "float"), np.array(M.T @ R1, dtype = "float"))
print(sols)
# Rebuild all three slices; the third slice uses its own [1,0], [0,1], [0,0]
# entries as weights (not `sols`).
a = np.kron(np.append(temp[:,1,0],[tens[2,1,0]]),np.kron([0,1],[1,sol[0]])).reshape(3,2,2) 
b = np.kron(np.append(temp[:,0,1],[tens[2,0,1]]),np.kron([1,sol[1]],[0,1])).reshape(3,2,2) 
c = np.kron(np.append(temp[:,0,0],[tens[2,0,0]]),np.kron([1,0],[1,0])).reshape(3,2,2)
# Largest absolute deviation between the reconstruction and the tensor.
np.max(abs(a+b+c - tens))

[56914.14211389 96683.89083073 91028.        ]


5801.48093540504

In [186]:
# Display the coefficient vector `sol`.
# NOTE(review): this cell is numbered In [186], lower than the In [197] cell that
# assigns `sol` above, so the value shown may come from a different random draw.
sol

Matrix([
[ 0.46100202622914],
[0.563943647094393]])

In [188]:
# Are both coefficients in `sol` nonnegative? (Same test as the second operand
# of the return expression in rank_three_comb.)
np.all(np.array(sol) >=0)

True

In [189]:
def rank_three_comb(tens):
    """Test whether the [1,1] entries of ``tens`` are a nonnegative combination
    of its [1,0] and [0,1] entries.

    The 2x2 system built from the first two slices gives coefficients ``sol``
    with ``tens[k,1,1] = sol[0]*tens[k,1,0] + sol[1]*tens[k,0,1]`` for
    ``k = 0, 1``. The tensor is then rebuilt from three terms weighted by the
    [1,0], [0,1] and [0,0] entries of all three slices.

    The earlier version also ran a least-squares solve of the third slice
    whose result was never used; that dead computation has been removed.

    Args:
        tens: tensor indexed as ``tens[k, i, j]`` with three slices of shape 2x2.

    Returns:
        Truthy when the reconstruction deviates from ``tens`` by less than 1
        in every entry and both entries of ``sol`` are >= 0.
    """
    temp = tl.tensor([tens[0],tens[1]])
    sol = Matrix([[Matrix(temp[:,1,0]),Matrix(temp[:,0,1])]]).inv() * Matrix(temp[:,1,1])
    # Term weighted by the [1,0] entries: pattern [[0, 0], [1, sol[0]]].
    a = np.kron(np.append(temp[:,1,0],[tens[2,1,0]]),np.kron([0,1],[1,sol[0]])).reshape(3,2,2)
    # Term weighted by the [0,1] entries: pattern [[0, 1], [0, sol[1]]].
    b = np.kron(np.append(temp[:,0,1],[tens[2,0,1]]),np.kron([1,sol[1]],[0,1])).reshape(3,2,2)
    # Term weighted by the [0,0] entries: pattern [[1, 0], [0, 0]].
    c = np.kron(np.append(temp[:,0,0],[tens[2,0,0]]),np.kron([1,0],[1,0])).reshape(3,2,2)
    return np.max(abs(a+b+c - tens)) < 1 and np.all(np.array(sol) >=0)
    

In [194]:
# Draw 10000 random rank_tree() tensors and count how many pass rank_three_comb
# either directly or after one, two or three applications of `rotate`.
# All four variants are evaluated for every sample (no short-circuit).
counter = 0
for i in range(10000):
    tens = rank_tree()
    a = rank_three_comb(tens)
    b = rank_three_comb(rotate(tens))
    c = rank_three_comb(rotate(rotate((tens))))
    d = rank_three_comb(rotate(rotate(rotate(tens))))
    if a or b or c or d:
        counter += 1
print(counter)

3
