In [2]:
import tensorly as tl
from utils import *
import numpy as np
from numpy import linalg as la
from sympy import *
from sympy.solvers.inequalities import *
from sympy.polys import Poly
from sympy.abc import x
from sympy.solvers.solveset import linsolve
import time
from joblib import Parallel, delayed

In [7]:
# generates random rank 3 tensors (sum of three low_tensor terms)
def rank_tree(size=100, s=(2, 3, 3)):
    """Return the sum of three independent ``low_tensor(size, s)`` draws.

    The name is historical ("tree" = three). The result has rank <= 3 if
    ``low_tensor`` (from ``utils``) returns rank-1 tensors -- presumably it
    does; TODO confirm against utils.
    """
    first, second, third = (low_tensor(size, s) for _ in range(3))
    return first + second + third

def rank_two(size=100, s=(2, 3, 3)):
    """Return the sum of two independent ``low_tensor(size, s)`` draws.

    Rank <= 2 if ``low_tensor`` (from ``utils``) is rank-1 -- TODO confirm.
    """
    first, second = (low_tensor(size, s) for _ in range(2))
    return first + second

# random float tensor with integer-valued entries; shape defaults to (3, 2, 2)
def rand_tensor(size = 100000, s = (3, 2, 2)):
    """Return a random tensor with integer-valued float entries drawn from [1, size).

    Parameters
    ----------
    size : int
        Exclusive upper bound of the entries (``high`` of ``np.random.randint``).
    s : tuple of int
        Shape of the tensor. Defaults to the 3x2x2 shape the rest of the
        notebook assumes; named ``s`` like in ``rank_tree``/``rank_two``.

    Returns
    -------
    A ``tl.tensor`` of shape ``s`` and float dtype.
    """
    # numpy's ``size`` keyword is the *shape*; our ``size`` is the value bound.
    return tl.tensor(np.random.randint(1, size, size=s)) * 1.0

In [3]:
def proc_tensor(t):
    """Test one nonnegative-decomposition ansatz for a 3x2x2 tensor.

    With ``a = 1`` and ``b = det(M_b) / det(M_a)``, where the rows of ``M_a`` are
    the fibres ``t[:,0,0], t[:,0,1], t[:,1,0]`` and ``M_b`` replaces the last row
    by ``t[:,1,1]``, fit (least squares) the unknowns ``(u, v, c_0, c_1, c_2)`` of

        t[i,0,0] * u + a * c_i = t[i,1,0]
        t[i,0,1] * v + b * c_i = t[i,1,1]        for i = 0, 1, 2.

    Parameters
    ----------
    t : array-like of shape (3, 2, 2)
        Tensor to test; it is not modified.

    Returns
    -------
    numpy.bool_
        True iff ``b`` and all fitted unknowns are nonnegative. False when
        ``M_a`` is exactly singular, because ``b`` is then undefined.
    """
    tens = np.asarray(t, dtype=float)
    M_a = np.array([tens[:, 0, 0], tens[:, 0, 1], tens[:, 1, 0]])
    M_b = np.array([tens[:, 0, 0], tens[:, 0, 1], tens[:, 1, 1]])
    det_a = la.det(M_a)
    if abs(det_a) < 10:
        print("warning det too small")
    if det_a == 0:
        # b = det(M_b) / 0 is undefined; the former sympy version crashed here.
        return np.bool_(False)
    a = 1.0
    b = la.det(M_b) / det_a
    # Row 2i encodes the first equation for index i, row 2i+1 the second.
    M = np.zeros((6, 5))
    for i in range(3):
        M[2 * i, 0] = tens[i, 0, 0]
        M[2 * i, 2 + i] = a
        M[2 * i + 1, 1] = tens[i, 0, 1]
        M[2 * i + 1, 2 + i] = b
    R = tens[:, 1, :].reshape(6)
    # lstsq avoids forming M.T @ M, which squares the condition number and
    # raises LinAlgError when M is rank deficient.
    sol, *_ = la.lstsq(M, R, rcond=None)
    ret = np.append(b, sol)
    return np.all(ret >= 0)


def test(tens):
    """Run the decomposition checks on ``tens`` and return an 11-slot 0/1 flag list.

    Flags set here (indices 5, 9 and 10 are never set by this function):
      0 -- some proc_tensor variant or check_simple succeeded
      1 -- proc_tensor succeeded on tens or on mat_trans/mat_inv/rotate of it
      2 -- check_simple succeeded
      3 -- r2_sub(tens, upper=1) and check_r4(tens) both truthy
      4 -- check_r4(tens) truthy
      6 -- r2_sub truthy while check_r4 falsy
      7 -- as 6, and neither proc_tensor nor check_simple succeeded
      8 -- check_simple succeeded but no proc_tensor variant did
    """
    flags = [0] * 11
    # Short-circuits: later transforms are only computed if earlier ones fail.
    proc_ok = (proc_tensor(tens)
               or proc_tensor(mat_trans(tens))
               or proc_tensor(mat_inv(tens))
               or proc_tensor(rotate(tens)))
    if proc_ok:
        flags[0] = flags[1] = 1
    if check_simple(tens):
        flags[2] = 1
        if not proc_ok:
            flags[8] = 1
        flags[0] = 1
    r2 = r2_sub(tens, upper = 1)
    r4 = check_r4(tens)
    if r4:
        flags[4] = 1
        if r2:
            flags[3] = 1
    elif r2:
        flags[6] = 1
        if not flags[0]:
            flags[7] = 1
    return flags

In [4]:
# assumes the nonnegative rank 2 subtensor is formed by the first two slices
def test_simple(tens, M, A1, tol = 0.0001):
    """Check whether the third slice of ``tens`` admits a nonnegative extension.

    ``M`` is the 4x2 matrix of rank-1 factors of the first two slices (as built
    by ``init_mat``). ``A1`` is a 4x1 symbolic candidate third factor in the
    symbols ``a`` and ``b``.

    ``a`` is eliminated by requiring ``det([M | A1 | vec(tens[2])]) == 0``. Only
    the first root returned by ``solve`` is used; ``solve`` returning no root
    raises IndexError. The three coefficients expressing ``vec(tens[2])[1:4]`` in
    terms of the rows 1..3 of ``[M | A1]`` are then rational functions of ``b``.

    Returns
    -------
    ``sol.measure > tol``: truthy iff the set of ``b >= 0`` for which the solved
    ``a`` and all three coefficients are nonnegative has measure above ``tol``.
    """
    a, b = symbols('a,b')
    third = Matrix(tens[2].reshape(4))
    M1 = Matrix([[M, A1, third]])
    a1 = solve(M1.det(), a)[0]
    A1 = A1.subs(a, a1)
    M1 = Matrix([[M, A1, third]])
    M1_sub = M1[1:4, 0:3]
    R = M1[1:4, 3]
    res1 = M1_sub.inv() @ R

    def _sign_poly(expr):
        # num * den has the sign of num / den wherever den != 0, but is a
        # polynomial; simplify is expensive, so it runs once per expression.
        num, den = fraction(simplify(expr))
        return num * den

    coeff_polys = [_sign_poly(res1[i]) for i in range(3)]
    a2 = _sign_poly(a1)
    constraints = [a2] + coeff_polys + [b]
    sol = Intersection(*[solveset(c >= 0, b, S.Reals) for c in constraints])
    return sol.measure > tol


# initializes the matrix of rank-1 terms, given that the tensor has a nonneg rank 2 subtensor
def init_mat(tens):
    """Build the 4x2 factor matrix of the subtensor formed by ``tens[0]`` and ``tens[1]``.

    Candidate factors are the eigenvectors of ``tens[0] @ inv(tens[1])`` and of
    ``tens[0].T @ inv(tens[1]).T``. Two pairings of them (via ``np.kron``) are
    validated with ``mat_comb_small``; if both validate, the second one wins.

    Returns
    -------
    sympy.Matrix (4x2) of the entrywise absolute values of the two vectorised
    rank-1 terms, or ``None`` if neither pairing validates.
    """
    inv1 = la.inv(tens[1])
    eig = la.eig(tens[0] @ inv1)
    eig1 = la.eig(tens[0].T @ inv1.T)

    a1, b1 = eig[1][:, 0], eig[1][:, 1]
    a2, b2 = eig1[1][:, 0], eig1[1][:, 1]
    pairings = (
        (np.kron(a1, b2), np.kron(b1, a2)),
        (np.kron(a1, a2), np.kron(b1, b2)),
    )

    # Computed eigenvalues are almost never exactly equal; compare with a
    # tolerance so near-degenerate (ill-conditioned) cases are reported.
    if np.isclose(eig[0][0], eig[0][1]):
        print('identical eigenvalue, result should be discarded')

    # check the two-slice subtensor against each candidate decomposition
    temp = tl.tensor([tens[0], tens[1]])
    ret = None
    for T1, T2 in pairings:
        if mat_comb_small(temp, T1, T2):
            ret = Matrix([[Matrix(abs(T1)), Matrix(abs(T2))]])
    return ret

# as before, checks if the matrix decomposition is valid by reconstructing it
def mat_comb_small(tens, K1, K2, tol = 0.1):
    """Check whether both slices of ``tens`` are combinations of |K1| and |K2|.

    Parameters
    ----------
    tens : array-like of shape (2, 2, 2)
        The two 2x2 slices to reconstruct.
    K1, K2 : array-like with 4 entries
        Candidate vectorised rank-1 terms. Each must have entries of one common
        sign (or complex phase); otherwise ``'returning'`` is printed and False
        is returned.
    tol : float
        Largest accepted absolute entrywise reconstruction error.

    Returns
    -------
    bool-like
        True iff fitting each slice in least squares as ``x*|K1| + y*|K2|``
        reproduces ``tens`` within ``tol`` in every entry.
    """
    v1 = np.asarray(K1).reshape(4)
    v2 = np.asarray(K2).reshape(4)
    # |sum(v)| / sum(|v|) is 1 iff all entries share one sign/phase. Compare
    # with a tolerance: round-off-sized entries of the wrong sign would
    # otherwise make an essentially one-signed vector fail an exact test.
    pos1 = abs(v1.sum() / np.abs(v1).sum())
    pos2 = abs(v2.sum() / np.abs(v2).sum())
    if not np.isclose(pos1 + pos2, 2):
        print('returning')
        return False
    M = np.column_stack([np.abs(v1), np.abs(v2)])
    slices = np.asarray(tens, dtype=float)
    R = slices.reshape(2, 4).T  # column s is the vectorised slice s
    # lstsq instead of normal equations: better conditioned, no LinAlgError
    # when |K1| and |K2| are parallel.
    coef, *_ = la.lstsq(M, R, rcond=None)
    recon = (M @ coef).T.reshape(2, 2, 2)
    return np.max(np.abs(recon - slices)) < tol


    
# checks if a given 3x2x2 tensor (three 2x2 slices) has a nonneg rank 2 subtensor and, if so, whether the third slice extends it
def check_simple(tens):
    """Return the first truthy ``test_simple`` result over all slice pairs, else False.

    All three slice pairs (0,1), (0,2), (1,2) are screened with ``check_r2``
    up front. For each passing pair, in that order, the tensor is reordered so
    the pair comes first and the remaining slice last. ``init_mat`` builds the
    pair's rank-1 factors, and ``test_simple`` is tried with four symbolic
    candidate third factors, stopping at the first truthy result.
    """
    t = tens.copy()
    # (first, second, remaining) slice order for each candidate pair.
    orders = ((0, 1, 2), (0, 2, 1), (1, 2, 0))
    has_r2 = [check_r2(tl.tensor([t[i], t[j]])) for i, j, _ in orders]
    a, b = symbols('a,b')
    # Candidate third factors: vectorised outer products kron(x, y) with
    # x in {(a,1), (1,a)} and y in {(b,1), (1,b)}.
    candidates = (
        Matrix([a*b, a, b, 1]),
        Matrix([a, a*b, 1, b]),
        Matrix([b, 1, a*b, a]),
        Matrix([1, b, a, a*b]),
    )
    for (i, j, rest), ok in zip(orders, has_r2):
        if not ok:
            continue
        temp = tl.tensor([t[i], t[j], t[rest]])
        M = init_mat(temp)
        for A in candidates:
            found = test_simple(temp, M, A)
            if found:
                return found
    return False

In [8]:
# Draw random tensors until one fails both the combined check (flag 0) and check_r4 (flag 4).
while True:
    tens1 = rand_tensor()
    test_res = test(tens1)
    cond1 = test_res[0] != 1 and test_res[4] != 1
    if cond1:
        break
tens1

array([[[55643., 89848.],
        [32690., 20439.]],

       [[44297.,  6836.],
        [ 7309.,  1857.]],

       [[99857., 86158.],
        [80916., 91776.]]])

In [9]:
# Show the full flag list for the selected tensor (flags 0 and 4 were != 1 when it was selected).
test(tens1)

[0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0]

In [11]:
# Alias: later cells refer to the selected tensor as both `tens` and `tens1` (same object).
tens = tens1

In [12]:
alpha_1, alpha_2, beta_0, beta_2 = symbols('alpha_1,alpha_2, beta_0,beta_2 ')
# Two symbolic 2x2 slices, each a combination of all three slices of tens1
# (`tens` is the same object, so using tens1 throughout is equivalent).
sub1 = tl.tensor([
    tens1[0] + alpha_1 * tens1[1] + alpha_2 * tens1[2],
    beta_0 * tens1[0] + tens1[1] + beta_2 * tens1[2],
])
sub2 = tl.tensor([tens1[2], tens1[1]])
sub1

array([[[44297.0*alpha_1 + 99857.0*alpha_2 + 55643.0,
         6836.0*alpha_1 + 86158.0*alpha_2 + 89848.0],
        [7309.0*alpha_1 + 80916.0*alpha_2 + 32690.0,
         1857.0*alpha_1 + 91776.0*alpha_2 + 20439.0]],

       [[55643.0*beta_0 + 99857.0*beta_2 + 44297.0,
         89848.0*beta_0 + 86158.0*beta_2 + 6836.0],
        [32690.0*beta_0 + 80916.0*beta_2 + 7309.0,
         20439.0*beta_0 + 91776.0*beta_2 + 1857.0]]], dtype=object)

In [77]:
k = symbols('k')
# sub2 is the same pair of slices as in the alpha/beta cell.
sub2 = tl.tensor([tens1[2], tens1[1]])
# k-family: slice 0 of tens1 plus k times slice 2, with slice 1 kept as is.
# NOTE(review): the saved output below does not match the tens1 displayed
# above (e.g. its second slice is not tens1[1]) -- it comes from stale kernel state.
sub1 = tl.tensor([tens1[0] + k * tens1[2], tens1[1]])
sub1

array([[[2760.0*k + 10726.0, 51401.0*k + 73628.0],
        [66358.0*k + 93701.0, 64618.0*k + 33748.0]],

       [[89772.0, 86118.0],
        [68222.0, 64886.0]]], dtype=object)

In [14]:
# Rebuild the alpha/beta family here rather than reusing `sub1`: under
# Restart & Run All, `sub1` has already been replaced by the k-family
# (In [77]) by the time this cell runs.
sub1_ab = tl.tensor([
    tens1[0] + alpha_1 * tens1[1] + alpha_2 * tens1[2],
    beta_0 * tens1[0] + tens1[1] + beta_2 * tens1[2],
])
# Coefficients expressing sub1_ab[0,0,:] in the basis {sub1_ab[1,0,:], sub1_ab[0,1,:]}.
sol = Matrix([[Matrix(sub1_ab[1, 0, :]), Matrix(sub1_ab[0, 1, :])]]).inv() @ Matrix(sub1_ab[0, 0, :])
# Solving a multivariate inequality for one variable at a time yields
# ConditionSets (see the saved output), not explicit intervals.
int1, int2, int3, int4 = [solveset(sol[0] >= 0, p, S.Reals) for p in (alpha_1, alpha_2, beta_0, beta_2)]
int5, int6, int7, int8 = [solveset(sol[1] >= 0, p, S.Reals) for p in (alpha_1, alpha_2, beta_0, beta_2)]
print(int1, int2)
Intersection(int1, int2)

ConditionSet(alpha_1, (-103329051.0*alpha_1 - 5106691968.0*alpha_2 - 1137287277.0)*(44297.0*alpha_1 + 99857.0*alpha_2 + 55643.0)/(30791165852783.0*alpha_1*beta_0 + 24721871796839.0*alpha_1*beta_2 - 1797002091815.0*alpha_1 + 120380781578400.0*alpha_2*beta_0 - 122020386260472.0*alpha_2*beta_2 - 195432666264528.0*alpha_2 + 100148710956049.0*beta_0 + 43152693208471.0*beta_2 - 37943937845149.0) + (6836.0*alpha_1 + 86158.0*alpha_2 + 89848.0)*(406694687.0*alpha_1 + 4502408988.0*alpha_2 + 1818969670.0)/(30791165852783.0*alpha_1*beta_0 + 24721871796839.0*alpha_1*beta_2 - 1797002091815.0*alpha_1 + 120380781578400.0*alpha_2*beta_0 - 122020386260472.0*alpha_2*beta_2 - 195432666264528.0*alpha_2 + 100148710956049.0*beta_0 + 43152693208471.0*beta_2 - 37943937845149.0) >= 0, Reals) ConditionSet(alpha_2, (-103329051.0*alpha_1 - 5106691968.0*alpha_2 - 1137287277.0)*(44297.0*alpha_1 + 99857.0*alpha_2 + 55643.0)/(30791165852783.0*alpha_1*beta_0 + 24721871796839.0*alpha_1*beta_2 - 1797002091815.0*alpha_1 +

Intersection(ConditionSet(alpha_1, (-103329051.0*alpha_1 - 5106691968.0*alpha_2 - 1137287277.0)*(44297.0*alpha_1 + 99857.0*alpha_2 + 55643.0)/(30791165852783.0*alpha_1*beta_0 + 24721871796839.0*alpha_1*beta_2 - 1797002091815.0*alpha_1 + 120380781578400.0*alpha_2*beta_0 - 122020386260472.0*alpha_2*beta_2 - 195432666264528.0*alpha_2 + 100148710956049.0*beta_0 + 43152693208471.0*beta_2 - 37943937845149.0) + (6836.0*alpha_1 + 86158.0*alpha_2 + 89848.0)*(406694687.0*alpha_1 + 4502408988.0*alpha_2 + 1818969670.0)/(30791165852783.0*alpha_1*beta_0 + 24721871796839.0*alpha_1*beta_2 - 1797002091815.0*alpha_1 + 120380781578400.0*alpha_2*beta_0 - 122020386260472.0*alpha_2*beta_2 - 195432666264528.0*alpha_2 + 100148710956049.0*beta_0 + 43152693208471.0*beta_2 - 37943937845149.0) >= 0, Reals), ConditionSet(alpha_2, (-103329051.0*alpha_1 - 5106691968.0*alpha_2 - 1137287277.0)*(44297.0*alpha_1 + 99857.0*alpha_2 + 55643.0)/(30791165852783.0*alpha_1*beta_0 + 24721871796839.0*alpha_1*beta_2 - 17970020918

In [23]:
# Intended to specialise the alpha/beta-parameterised `sol` at alpha_1 = 10.
sol.subs(alpha_1, 10)

Matrix([
[(-5106691968.0*alpha_2 - 2170577787.0)*(99857.0*alpha_2 + 498613.0)/(120380781578400.0*alpha_2*beta_0 - 122020386260472.0*alpha_2*beta_2 - 195432666264528.0*alpha_2 + 408060369483879.0*beta_0 + 290371411176861.0*beta_2 - 55913958763299.0) + (86158.0*alpha_2 + 158208.0)*(4502408988.0*alpha_2 + 5885916540.0)/(120380781578400.0*alpha_2*beta_0 - 122020386260472.0*alpha_2*beta_2 - 195432666264528.0*alpha_2 + 408060369483879.0*beta_0 + 290371411176861.0*beta_2 - 55913958763299.0)],
[                                               (86158.0*alpha_2 + 158208.0)*(-55643.0*beta_0 - 99857.0*beta_2 - 44297.0)/(2163448800.0*alpha_2*beta_0 - 2192915304.0*alpha_2*beta_2 - 3512259696.0*alpha_2 + 7333543653.0*beta_0 + 5218471527.0*beta_2 - 1004869593.0) + (99857.0*alpha_2 + 498613.0)*(89848.0*beta_0 + 86158.0*beta_2 + 6836.0)/(2163448800.0*alpha_2*beta_0 - 2192915304.0*alpha_2*beta_2 - 3512259696.0*alpha_2 + 7333543653.0*beta_0 + 5218471527.0*beta_2 - 1004869593.0)]])

In [None]:
# Never executed in the saved run (In [None]).
# Coefficients of sub1[:,0,0] in the basis {sub1[:,1,0], sub1[:,0,1]}, as functions of k.
sol = Matrix([[Matrix(sub1[:, 1, 0]), Matrix(sub1[:, 0, 1])]]).inv() @ Matrix(sub1[:, 0, 0])
int1, int2 = [solveset(sol[idx] >= 0, k, S.Reals) for idx in (0, 1)]
print(int1, int2)
Intersection(int1, int2)

In [96]:
# Coefficients of sub1[1,:,1] in the basis {sub1[1,:,0], sub1[0,:,1]}; keep the k with both >= 0.
sol = Matrix([[Matrix(sub1[1, :, 0]), Matrix(sub1[0, :, 1])]]).inv() @ Matrix(sub1[1, :, 1])
int1, int2 = [solveset(sol[idx] >= 0, k, S.Reals) for idx in (0, 1)]
print(int1, int2)
Intersection(int1, int2)

Union(Interval(-oo, 0.839228248611671), Interval.open(0.868894143731446, oo)) Interval.open(-oo, 0.868894143731446)


Interval(-oo, 0.839228248611671)

The next cell shows that, for every $k \in \mathbb{R}$, the $2\times 2\times 2$ tensor formed by adding $k$ times one slice to another slice (and keeping the remaining slice unchanged) has nonnegative rank three. This contradicts our earlier belief that such a combination might be sufficient.

In [100]:
# Coefficients of sub1[1,:,0] in the basis {sub1[0,:,0], sub1[1,:,1]}; keep the k with both >= 0.
sol = Matrix([[Matrix(sub1[0, :, 0]), Matrix(sub1[1, :, 1])]]).inv() @ Matrix(sub1[1, :, 0])
int1, int2 = [solveset(sol[idx] >= 0, k, S.Reals) for idx in (0, 1)]
print(int1, int2)
Intersection(int1, int2)

Interval.open(-1.33200825223388, oo) Union(Interval.open(-oo, -1.33200825223388), Interval(-1.33129595766151, oo))


Interval(-1.33129595766151, oo)

In [107]:
# `check` is not defined in this notebook -- presumably it comes from `from utils import *`; TODO document what it verifies.
check(tens1)

True

In [17]:
# `check_rank` presumably comes from `from utils import *`; the arguments suggest a rank-3 test with n=100 (trials?) -- TODO confirm.
check_rank(tens,3,n=100)

False