In [153]:
# Re-import edited project modules (bsde_solver) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [201]:
from bsde_solver.tensor.tensor_train import TensorTrain, left_unfold, right_unfold
from bsde_solver.tensor.tensor_core import TensorCore

import numpy as np

# Compact fixed-point display for every array printed in this notebook:
# two decimals, no scientific notation for tiny round-off values.
np.set_printoptions(precision=2, suppress=True)


In [155]:
# Random order-4 tensor train used throughout: physical dimensions 4 and
# TT ranks (1, 3, 3, 3, 1), as shown by the structure printed below.
# NOTE(review): no seed is set here, so the numbers in later outputs will change
# between runs unless TensorTrain.randomize seeds itself — confirm.
tt = TensorTrain(shape=[4, 4, 4, 4], ranks=[1, 3, 3, 3, 1])
tt.randomize()

We implement several functions and algorithms for the tensor train (TT) decomposition and illustrate them on the following random tensor train (also known as a matrix product state):

<center><img src="images/test_tt.png" style="width: 40%"/></center>

In [156]:
# Show each core with its index names and dimensions.
tt

TensorNetwork(
    core_0: TensorCore(r_0 {1}, m_1 {4}, r_1 {3}),
    core_1: TensorCore(r_1 {3}, m_2 {4}, r_2 {3}),
    core_2: TensorCore(r_2 {3}, m_3 {4}, r_3 {3}),
    core_3: TensorCore(r_3 {3}, m_4 {4}, r_4 {1})
)

### Left and Right unfoldings

Left unfolding:
<center><img src="images/left_unfold.png" style="width: 40%"/></center>

Right unfolding:
<center><img src="images/right_unfold.png" style="width: 40%"/></center>

In [175]:
# Matricize core 1 (indices r_1, m_2, r_2), matching the printed output:
#   right unfolding -> rows r_1,        columns (m_2, r_2)
#   left unfolding  -> rows (r_1, m_2), columns r_2
R = right_unfold(tt[1])
L = left_unfold(tt[1])

print(R)
print(L)

core_1: TensorCore(r_1 {3}, m_2+r_2 {12})
core_1: TensorCore(r_1+m_2 {12}, r_2 {3})


### Left and Right orthogonalizations


Left orthogonalization:
<center><img src="images/tt_left_ortho.png" style="width: 40%"/></center>

Right orthogonalization:
<center><img src="images/tt_right_ortho.png" style="width: 40%"/></center>


In [176]:
# One copy per orthonormalization demo below, so `tt` keeps its original cores.
# NOTE(review): assumes TensorTrain.copy() copies the core data (deep copy) — confirm.
tt1 = tt.copy()
tt2 = tt.copy()

In [177]:
# Sanity check: before any orthonormalization, the first core of each copy equals the original's.
tt[0].view(np.ndarray), tt1[0].view(np.ndarray), tt2[0].view(np.ndarray)

(array([[[-0.43,  0.53,  0.03],
         [-0.05, -0.29, -0.29],
         [ 0.16, -0.76, -0.33],
         [ 0.04, -0.51, -0.23]]]),
 array([[[-0.43,  0.53,  0.03],
         [-0.05, -0.29, -0.29],
         [ 0.16, -0.76, -0.33],
         [ 0.04, -0.51, -0.23]]]),
 array([[[-0.43,  0.53,  0.03],
         [-0.05, -0.29, -0.29],
         [ 0.16, -0.76, -0.33],
         [ 0.04, -0.51, -0.23]]]))

In [178]:
# Left-orthonormalize tt1; the next cell inspects the resulting cores.
tt1.orthonormalize(mode="left")

<center><img src="images/tt_left_ortho_id.png" style="width: 30%"/></center>

In [179]:
# After left orthonormalization, the left unfolding of core 1 has orthonormal
# columns: its Gram matrix L^T L must be the identity (up to round-off).
L1_mat = left_unfold(tt1[1]).view(np.ndarray)
L1_gram = L1_mat.T @ L1_mat
np.testing.assert_allclose(L1_gram, np.eye(L1_gram.shape[0]), atol=1e-10)
L1_gram

array([[ 1.,  0.,  0.],
       [ 0.,  1., -0.],
       [ 0., -0.,  1.]])

In [180]:
# Right-orthonormalize tt2; the next cell inspects the resulting cores.
tt2.orthonormalize(mode="right")

<center><img src="images/tt_right_ortho_id.png" style="width: 30%"/></center>

In [181]:
# After right orthonormalization, the right unfolding of the last core has
# orthonormal rows: its Gram matrix R R^T must be the identity (up to round-off).
R3_mat = right_unfold(tt2[3]).view(np.ndarray)
R3_gram = R3_mat @ R3_mat.T
np.testing.assert_allclose(R3_gram, np.eye(R3_gram.shape[0]), atol=1e-10)
R3_gram

array([[ 1., -0.,  0.],
       [-0.,  1., -0.],
       [ 0., -0.,  1.]])

### Left and right parts of the tensor train around core $i$

Left part:
<center><img src="images/left_contract.png" style="width: 50%"/></center>
Right part:
<center><img src="images/right_contract.png" style="width: 50%"/></center>

Combining the left part, core $i$ and the right part, the tensor train can be written as:

<center><img src="images/left_right_repr.png" style="width: 80%"/></center>


## Alternating Least Squares (ALS) algorithm

In [182]:
from opt_einsum import contract

def retraction_operator(tt, i):
    """Return the sub-network of ``tt`` made of every core except ``core_i``.

    Contracting the returned network leaves these indices open:
      * the rank indices that bordered the removed core,
      * the physical indices of the remaining cores,
      * the outer boundary rank indices of the train.
    When the cores left of ``i`` are left-orthonormal and those right of ``i``
    are right-orthonormal, the unfolded operator has orthonormal rows. The
    Gram-matrix cells below check this.

    Args:
        tt: Tensor train exposing ``order`` and ``extract``.
        i: Index of the core to leave out, with ``0 <= i < tt.order``.

    Returns:
        The network returned by ``tt.extract`` for the remaining core names.

    Raises:
        IndexError: if ``i`` is outside ``[0, tt.order)``. Without this check,
            an out-of-range or negative index silently returned the whole train.
    """
    if not 0 <= i < tt.order:
        raise IndexError(
            f"core index {i} out of range for a tensor train of order {tt.order}"
        )
    return tt.extract([f"core_{j}" for j in range(tt.order) if j != i])

In [183]:
# Full right orthonormalization of a fresh copy: core 0 is the center, and the
# remaining cores form the retraction operator for core 0.
tt3 = tt.copy()
tt3.orthonormalize(mode="right")

print(tt3)

P1 = retraction_operator(tt3, 0)
print("\nRetraction operator (1st):")
print(P1)
print(P1.contract())

TensorNetwork(
    core_0: TensorCore(r_0 {1}, m_1 {4}, r_1 {3}),
    core_1: TensorCore(r_1 {3}, m_2 {4}, r_2 {3}),
    core_2: TensorCore(r_2 {3}, m_3 {4}, r_3 {3}),
    core_3: TensorCore(r_3 {3}, m_4 {4}, r_4 {1})
)

Rectraction operator (1st):
TensorNetwork(
    core_1: TensorCore(r_1 {3}, m_2 {4}, r_2 {3}),
    core_2: TensorCore(r_2 {3}, m_3 {4}, r_3 {3}),
    core_3: TensorCore(r_3 {3}, m_4 {4}, r_4 {1})
)
TensorNetwork(
    core_0: TensorCore(r_1 {3}, m_2 {4}, m_3 {4}, m_4 {4}, r_4 {1})
)


In [187]:
# Move the orthonormality center to core 1: left-orthonormalize, then
# right-orthonormalize down to core 1. The identity Gram matrix below confirms
# that the cores on both sides of core 1 are orthonormal.
tt3.orthonormalize(mode="left")
tt3.orthonormalize(mode="right", start=1)

P2 = retraction_operator(tt3, 1)
print("\nRetraction operator (2nd):")
print(P2)
P2_full = P2.contract()  # contract once and reuse below
print(P2_full)

# Rows: the open rank indices (r_1, r_2) around core 1; columns: all other indices.
P2_mat = P2_full[0].unfold(('r_1', 'r_2'), -1).view(np.ndarray)
P2_gram = P2_mat @ P2_mat.T
np.testing.assert_allclose(P2_gram, np.eye(P2_gram.shape[0]), atol=1e-10)
P2_gram


Rectraction operator (2nd):
TensorNetwork(
    core_0: TensorCore(r_0 {1}, m_1 {4}, r_1 {3}),
    core_2: TensorCore(r_2 {3}, m_3 {4}, r_3 {3}),
    core_3: TensorCore(r_3 {3}, m_4 {4}, r_4 {1})
)
TensorNetwork(
    core_0: TensorCore(r_0 {1}, m_1 {4}, r_1 {3}, r_2 {3}, m_3 {4}, m_4 {4}, r_4 {1})
)


array([[ 1., -0., -0., -0.,  0.,  0., -0.,  0.,  0.],
       [-0.,  1.,  0.,  0., -0., -0., -0.,  0., -0.],
       [-0.,  0.,  1., -0.,  0., -0.,  0., -0.,  0.],
       [-0.,  0., -0.,  1., -0.,  0.,  0.,  0., -0.],
       [ 0., -0.,  0., -0.,  1.,  0., -0., -0.,  0.],
       [ 0., -0., -0.,  0.,  0.,  1., -0., -0., -0.],
       [-0., -0.,  0.,  0., -0., -0.,  1., -0.,  0.],
       [ 0.,  0., -0.,  0., -0., -0., -0.,  1.,  0.],
       [ 0., -0.,  0., -0.,  0., -0.,  0.,  0.,  1.]])

In [195]:
# Same check with the orthonormality center moved to core 2.
tt3.orthonormalize(mode="left")
tt3.orthonormalize(mode="right", start=2)

P3 = retraction_operator(tt3, 2)
P3_full = P3.contract()  # contract once and reuse below
print(P3_full[0])

# Rows: the open rank indices (r_2, r_3) around core 2; columns: all other indices.
P3_mat = P3_full[0].unfold(('r_2', 'r_3'), -1).view(np.ndarray)
P3_gram = P3_mat @ P3_mat.T
np.testing.assert_allclose(P3_gram, np.eye(P3_gram.shape[0]), atol=1e-10)
P3_gram

core_0: TensorCore(r_0 {1}, m_1 {4}, m_2 {4}, r_2 {3}, r_3 {3}, m_4 {4}, r_4 {1})


array([[ 1.,  0.,  0., -0., -0., -0.,  0., -0., -0.],
       [ 0.,  1.,  0.,  0., -0., -0.,  0.,  0.,  0.],
       [ 0.,  0.,  1., -0.,  0., -0.,  0.,  0.,  0.],
       [-0.,  0., -0.,  1.,  0., -0., -0., -0., -0.],
       [-0., -0.,  0.,  0.,  1.,  0., -0., -0., -0.],
       [-0., -0., -0., -0.,  0.,  1.,  0.,  0., -0.],
       [ 0.,  0.,  0., -0., -0.,  0.,  1.,  0.,  0.],
       [-0.,  0.,  0., -0., -0.,  0.,  0.,  1.,  0.],
       [-0.,  0.,  0., -0., -0., -0.,  0.,  0.,  1.]])

In [197]:
# NOTE(review): `ttt` is not read by any later cell visible in this notebook;
# this looks like leftover scratch work and could likely be removed.
ttt = tt.copy()
ttt.orthonormalize(mode="left")

In [245]:
def ALS(A, b, n_iter=10):
    """Approximate the tensor train ``b`` by a new tensor train using ALS sweeps.

    Each micro-step handles one core:
      1. Compute the local solution ``V = P^T b``, where ``P`` is the
         retraction operator of the current iterate at that core.
      2. Keep the orthonormal QR factor of ``V`` in that core.
      3. Push the triangular factor into the neighbouring core.
    One iteration is a left-to-right half sweep followed by a right-to-left
    half sweep.

    Args:
        A: Operator of the (presumed) linear system ``A x = b``. Only the
            identity case is implemented, so it must be ``None``.
        b: Target tensor train. The driver renames its rank indices to ``t_*``,
            presumably so they do not collide with the iterate's ``r_*``
            indices during contraction.
        n_iter: Number of full (left + right) sweeps.

    Returns:
        A ``TensorTrain`` with the same shape and ranks as ``b``.

    Raises:
        NotImplementedError: if ``A`` is not ``None``. Previously any operator
            was silently ignored and ``b`` itself was approximated.
    """
    if A is not None:
        raise NotImplementedError(
            "ALS currently only supports the identity operator (A=None)."
        )

    x = TensorTrain(b.shape, b.ranks)
    x.randomize()
    # The first micro-step solves for core 0, so cores 1..d-1 must already be
    # right-orthonormal. This is the full right orthonormalization used above
    # for the core-0 retraction operator. `start=1` would leave core 1
    # non-orthonormal and make the very first local projection inexact.
    x.orthonormalize(mode="right")

    def core_name(j):
        return f"core_{j}"

    def local_indices(j):
        """Output indices of the local solution for core ``j``.

        At the two ends of the train, the open boundary rank index belongs to
        ``b`` rather than to the iterate, so ``b``'s name is used there.
        """
        if j == 0:
            return (b[j].indices[0], *x[j].indices[1:])
        elif j == x.order - 1:
            return (*x[j].indices[:-1], b[j].indices[2])
        return x[j].indices

    def local_solution(j):
        """Project ``b`` onto the retraction operator of core ``j`` (local solution for A = I)."""
        P = retraction_operator(x, j)
        return P.contract(b, indices=local_indices(j))[0]

    def set_core(j, values):
        """Overwrite core ``j`` with ``values``, keeping its shape, indices and name."""
        old = x.cores[core_name(j)]
        x.cores[core_name(j)] = TensorCore(
            values.reshape(old.shape), indices=old.indices, name=old.name
        )

    def left_step(j):
        """Update core j, then move the orthonormality center from core j to core j+1."""
        V = local_solution(j)
        # Reduced QR: the reshape below assumes rows >= columns of the
        # unfolding (r_j * m_{j+1} >= r_{j+1}) — TODO confirm for the ranks used.
        Q, S = np.linalg.qr(left_unfold(V).view(np.ndarray))
        R_next = right_unfold(x.cores[core_name(j + 1)]).view(np.ndarray)
        set_core(j, Q)
        set_core(j + 1, S @ R_next)

    def right_step(j):
        """Update core j, then move the orthonormality center from core j to core j-1."""
        V = local_solution(j)
        # QR of the transposed right unfolding gives R_V = S^T Q^T.
        # Q^T has orthonormal rows and becomes core j; S^T is absorbed into core j-1.
        Q, S = np.linalg.qr(right_unfold(V).view(np.ndarray).T)
        L_prev = left_unfold(x.cores[core_name(j - 1)]).view(np.ndarray)
        set_core(j - 1, L_prev @ S.T)
        set_core(j, Q.T)

    for _ in range(n_iter):
        for j in range(x.order - 1):  # left-to-right half sweep
            left_step(j)
        for j in range(x.order - 1, 0, -1):  # right-to-left half sweep
            right_step(j)

    return x



# Target: a random order-5 tensor train with ranks (1, 3, 3, 3, 3, 1).
b = TensorTrain([4, 4, 4, 4, 4], [1, 3, 3, 3, 3, 1])
b.randomize()
# Rename b's rank indices r_* -> t_* (see the printed structure below),
# presumably so they do not clash with the r_* indices of the ALS iterate.
b.rename('r_*', 't_*')
# NOTE(review): ALS does not appear to require b to be orthonormalized — confirm
# whether this line is needed.
b.orthonormalize(mode="right")
print(b)
# ALS never reads A, so with A=None this computes an approximation of b itself.
x = ALS(None, b, n_iter=10)

TensorNetwork(
    core_0: TensorCore(t_0 {1}, m_1 {4}, t_1 {3}),
    core_1: TensorCore(t_1 {3}, m_2 {4}, t_2 {3}),
    core_2: TensorCore(t_2 {3}, m_3 {4}, t_3 {3}),
    core_3: TensorCore(t_3 {3}, m_4 {4}, t_4 {3}),
    core_4: TensorCore(t_4 {3}, m_5 {4}, t_5 {1})
)


In [268]:
# Contract both trains to full tensors and compare them.
xx = x.contract()[0]
bb = b.contract()[0]
reconstruction_error = np.linalg.norm(xx - bb)
print("Reconstruction error:", reconstruction_error)
# x has the same shape and ranks as b, so ALS should recover b exactly up to
# round-off (the recorded run reached ~1e-14). Use a tolerance relative to ||b||.
assert reconstruction_error < 1e-8 * max(1.0, float(np.linalg.norm(bb)))

Reconstruction error: 1.7476453227712166e-14
