In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from bsde_solver.tensor.tensor_train import TensorTrain, left_unfold, right_unfold
from bsde_solver.tensor.tensor_core import TensorCore
from bsde_solver.tensor.tensor_network import TensorNetwork

import numpy as np

np.set_printoptions(precision=2)
np.set_printoptions(suppress=True)


In [3]:
tt = TensorTrain(shape=[4, 4, 4, 4], ranks=[1, 3, 3, 3, 1])
tt.randomize()

Implementation of several functions and algorithms with tensor train decomposition using the following random tensor train (matrix product state):

<center><img src="./images/test_tt.png" style="width: 40%"/></center>

In [4]:
tt

TensorNetwork(
    core_0: TensorCore(r_0 {1}, m_1 {4}, r_1 {3}),
    core_1: TensorCore(r_1 {3}, m_2 {4}, r_2 {3}),
    core_2: TensorCore(r_2 {3}, m_3 {4}, r_3 {3}),
    core_3: TensorCore(r_3 {3}, m_4 {4}, r_4 {1})
)

### Left and Right unfoldings

Left unfolding:
<center><img src="./images/left_unfold.png" style="width: 40%"/></center>

Right unfolding:
<center><img src="./images/right_unfold.png" style="width: 40%"/></center>

In [5]:
R = right_unfold(tt[1])
L = left_unfold(tt[1])

print(R)
print(L)

core_1: TensorCore(r_1 {3}, m_2+r_2 {12})
core_1: TensorCore(r_1+m_2 {12}, r_2 {3})


### Left and Right orthogonalizations


Left orthogonalization:
<center><img src="./images/tt_left_ortho.png" style="width: 40%"/></center>

Right orthogonalization:
<center><img src="./images/tt_right_ortho.png" style="width: 40%"/></center>


In [6]:
tt1 = tt.copy()
tt2 = tt.copy()

In [7]:
tt[0].view(np.ndarray), tt1[0].view(np.ndarray), tt2[0].view(np.ndarray)

(array([[[ 0.44,  0.92, -0.93],
         [-0.1 , -0.79, -0.74],
         [-0.45, -0.29, -0.9 ],
         [-0.7 , -0.34,  0.22]]]),
 array([[[ 0.44,  0.92, -0.93],
         [-0.1 , -0.79, -0.74],
         [-0.45, -0.29, -0.9 ],
         [-0.7 , -0.34,  0.22]]]),
 array([[[ 0.44,  0.92, -0.93],
         [-0.1 , -0.79, -0.74],
         [-0.45, -0.29, -0.9 ],
         [-0.7 , -0.34,  0.22]]]))

In [8]:
tt1.orthonormalize(mode="left")

<center><img src="./images/tt_left_ortho_id.png" style="width: 30%"/></center>

In [9]:
L = left_unfold(tt1[1]).view(np.ndarray)
L.T @ L

array([[ 1., -0., -0.],
       [-0.,  1.,  0.],
       [-0.,  0.,  1.]])

In [10]:
tt2.orthonormalize(mode="right")

<center><img src="./images/tt_right_ortho_id.png" style="width: 30%"/></center>

In [11]:
R = right_unfold(tt2[3]).view(np.ndarray)
R @ R.T

array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]])

### Left & Right part of the tensor train

Left part :
<center><img src="./images/left_contract.png" style="width: 50%"/></center>
Right part :
<center><img src="./images/right_contract.png" style="width: 50%"/></center>

Then, we can write the tensor train as:

<center><img src="./images/left_right_repr.png" style="width: 80%"/></center>


## Alternating Least Squares (ALS) algorithm

In [12]:
from opt_einsum import contract

def retraction_operator(tt, i):
    operator = tt.extract([f'core_{j}' for j in range(tt.order) if j != i])
    return operator

In [13]:
tt3 = tt.copy()
tt3.orthonormalize(mode="right")

print(tt3)

P1 = retraction_operator(tt3, 0)
print("\nRectraction operator (1st):")
print(P1)
print(P1.contract())

TensorNetwork(
    core_0: TensorCore(r_0 {1}, m_1 {4}, r_1 {3}),
    core_1: TensorCore(r_1 {3}, m_2 {4}, r_2 {3}),
    core_2: TensorCore(r_2 {3}, m_3 {4}, r_3 {3}),
    core_3: TensorCore(r_3 {3}, m_4 {4}, r_4 {1})
)

Rectraction operator (1st):
TensorNetwork(
    core_1: TensorCore(r_1 {3}, m_2 {4}, r_2 {3}),
    core_2: TensorCore(r_2 {3}, m_3 {4}, r_3 {3}),
    core_3: TensorCore(r_3 {3}, m_4 {4}, r_4 {1})
)
TensorCore(r_1 {3}, m_2 {4}, m_3 {4}, m_4 {4}, r_4 {1})


In [14]:
tt3.orthonormalize(mode="left")
tt3.orthonormalize(mode="right", start=1)

P2 = retraction_operator(tt3, 1)
print("\nRectraction operator (2nd):")
print(P2)
print(P2.contract())

P2_mat = P2.contract().unfold(('r_1', 'r_2'), -1).view(np.ndarray)
(P2_mat @ P2_mat.T)


Rectraction operator (2nd):
TensorNetwork(
    core_0: TensorCore(r_0 {1}, m_1 {4}, r_1 {3}),
    core_2: TensorCore(r_2 {3}, m_3 {4}, r_3 {3}),
    core_3: TensorCore(r_3 {3}, m_4 {4}, r_4 {1})
)
TensorCore(r_0 {1}, m_1 {4}, r_1 {3}, r_2 {3}, m_3 {4}, m_4 {4}, r_4 {1})


array([[ 1., -0., -0.,  0.,  0., -0.,  0.,  0.,  0.],
       [-0.,  1.,  0., -0.,  0., -0.,  0.,  0., -0.],
       [-0.,  0.,  1.,  0., -0.,  0.,  0., -0.,  0.],
       [ 0., -0.,  0.,  1., -0.,  0.,  0., -0.,  0.],
       [ 0.,  0., -0., -0.,  1.,  0., -0.,  0., -0.],
       [-0., -0.,  0.,  0.,  0.,  1., -0.,  0., -0.],
       [ 0.,  0.,  0.,  0., -0., -0.,  1., -0., -0.],
       [ 0.,  0., -0., -0.,  0.,  0., -0.,  1.,  0.],
       [ 0., -0.,  0.,  0., -0., -0., -0.,  0.,  1.]])

In [15]:
tt3.orthonormalize(mode="left")
tt3.orthonormalize(mode="right", start=2)

P3 = retraction_operator(tt3, 2)
print(P3.contract())
P3_mat = P3.contract().unfold(('r_2', 'r_3'), -1).view(np.ndarray)
(P3_mat @ P3_mat.T)

TensorCore(r_0 {1}, m_1 {4}, m_2 {4}, r_2 {3}, r_3 {3}, m_4 {4}, r_4 {1})


array([[ 1., -0.,  0.,  0.,  0., -0.,  0., -0., -0.],
       [-0.,  1.,  0., -0.,  0.,  0., -0.,  0., -0.],
       [ 0.,  0.,  1.,  0., -0., -0.,  0., -0.,  0.],
       [ 0., -0.,  0.,  1., -0.,  0.,  0.,  0.,  0.],
       [ 0.,  0., -0., -0.,  1.,  0., -0.,  0.,  0.],
       [-0.,  0., -0.,  0.,  0.,  1.,  0., -0.,  0.],
       [ 0., -0.,  0.,  0., -0.,  0.,  1., -0.,  0.],
       [-0.,  0., -0.,  0.,  0., -0., -0.,  1.,  0.],
       [-0., -0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.]])

In [16]:
ttt = tt.copy()
ttt.orthonormalize(mode="left")

#### Rename indices

In [49]:
tt4 = tt.copy()
tt5 = tt4.rename("m_*", "n_*", inplace=False)
tt4, tt5

(TensorNetwork(
     core_0: TensorCore(r_0 {1}, m_1 {4}, r_1 {3}),
     core_1: TensorCore(r_1 {3}, m_2 {4}, r_2 {3}),
     core_2: TensorCore(r_2 {3}, m_3 {4}, r_3 {3}),
     core_3: TensorCore(r_3 {3}, m_4 {4}, r_4 {1})
 ),
 TensorNetwork(
     core_0: TensorCore(r_0 {1}, n_1 {4}, r_1 {3}),
     core_1: TensorCore(r_1 {3}, n_2 {4}, r_2 {3}),
     core_2: TensorCore(r_2 {3}, n_3 {4}, r_3 {3}),
     core_3: TensorCore(r_3 {3}, n_4 {4}, r_4 {1})
 ))

#### Micro-optimization

At each step of the ALS algorithm, we need to compute the following expression:

$$P_i^T A P_i 

In [114]:
def ALS(A, b, n_iter=10, ranks=None):
    tt = TensorTrain(b.shape, b.ranks if ranks is None else ranks)
    tt.randomize()
    tt.orthonormalize(mode="right", start=1)

    def get_idx(j):
        if j == 0: indices = (b[j].indices[0], *tt[j].indices[1:])
        elif j == tt.order-1: indices = (*tt[j].indices[:-1], b[j].indices[2])
        else: indices = tt[j].indices
        return indices

    A1 = A.rename("m_*", "p_*", inplace=False)
    A2 = A.rename("n_*", "q_*", inplace=False).rename("m_*", "p_*", inplace=False)
    c = b.copy().contract().rename("m_*", "q_*", inplace=False)
    T = TensorNetwork(cores=[A,A1,A2,c], names=['A','A1','A2','c']).contract()#indices=[f"m_{i+1}" for i in range(len(b.shape))])

    print(T)

    for i in range(n_iter):
        # Left half sweep
        for j in range(tt.order-1):
            # Micro optimization
            P = retraction_operator(tt, j)
            V = P.contract(T, indices=get_idx(j))

            core_curr = tt.cores[f"core_{j}"]
            core_next = tt.cores[f"core_{j+1}"]

            L = left_unfold(V).view(np.ndarray)
            R = right_unfold(core_next).view(np.ndarray)

            Q, S = np.linalg.qr(L)
            W = S @ R

            tt.cores[f"core_{j}"] = TensorCore.like(Q, core_curr)
            tt.cores[f"core_{j+1}"] = TensorCore.like(W, core_next)

        # Right half sweep
        for j in range(tt.order-1, 0, -1):
            # Micro optimization
            P = retraction_operator(tt, j)
            V = P.contract(T, indices=get_idx(j))

            core_prev = tt.cores[f"core_{j-1}"]
            core_curr = tt.cores[f"core_{j}"]

            L = left_unfold(core_prev).view(np.ndarray)
            R = right_unfold(V).view(np.ndarray)

            Q, S = np.linalg.qr(R.T)
            W = L @ S.T

            tt.cores[f"core_{j-1}"] = TensorCore.like(W, core_prev)
            tt.cores[f"core_{j}"] = TensorCore.like(Q.T, core_curr)

    return tt


# b = TensorTrain([4, 4, 4, 4, 4], [1, 3, 3, 3, 3, 1])
b = TensorTrain([4, 4, 4], [1, 3, 3, 1])
b.randomize()
b.rename('r_*', 't_*')
b.orthonormalize(mode="right")

A = TensorCore(np.random.randn(4, 4, 4, 4, 4, 4), ['n_1', 'n_2', 'n_3', 'm_1', 'm_2', 'm_3'])
# print(b, A)

import time

start_time = time.time()
x = ALS(A, b, n_iter=10)
print(f"Elapsed time: {(time.time() - start_time):.5f}s")

TensorCore(m_1 {4}, m_2 {4}, m_3 {4}, t_0 {1}, t_3 {1})
Elapsed time: 0.01300s


In [80]:
print(A.view(np.ndarray).shape, b.contract().view(np.ndarray).shape)
c = np.linalg.tensorsolve(A.view(np.ndarray), b.contract().view(np.ndarray))

(4, 4, 4, 4, 4, 4) (1, 4, 4, 4, 1)


LinAlgError: Last 2 dimensions of the array must be square

In [76]:
xx = x.contract()
bb = b.contract()
print("Reconstruction error:", np.linalg.norm(xx - bb))

Reconstruction error: 1.0707434133867275


### Tensor decomposition

In [None]:
A = np.arange(10*10*10*10*10).reshape(10, 10, 10, 10, 10)
Att = TensorTrain.from_tensor(A, [1, 2, 2, 2, 2, 1])

print("Decomposition error:", np.linalg.norm(A - Att.contract().squeeze()))

Decomposition error: 1.1397139833714047e-05


In [None]:
Id = np.eye(4*4).reshape(4, 4, 4, 4)
I = TensorTrain.from_tensor(Id, [1, 2, 2, 2, 1])

print("Identity error:", np.linalg.norm(Id - I.contract().squeeze()))


id = TensorNetwork([TensorCore(Id, ['i', 'j', 'k', 'l'])])
U = TensorNetwork([TensorCore(np.random.randn(4, 4), ['i', 'j'])])

I.rename("m_1", "i")
I.rename("m_2", "j")
I.rename("m_3", "k")
I.rename("m_4", "l")

print(I, )

Identity error: 3.7416573867739413
TensorNetwork(
    core_0: TensorCore(r_0 {1}, i {4}, r_1 {2}),
    core_1: TensorCore(r_1 {2}, j {4}, r_2 {2}),
    core_2: TensorCore(r_2 {2}, k {4}, r_3 {2}),
    core_3: TensorCore(r_3 {2}, l {4}, r_4 {1})
)
