In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from bsde_solver.core.tensor.tensor_train import TensorTrain, left_unfold, right_unfold
from bsde_solver.core.tensor.tensor_core import TensorCore
from bsde_solver.core.tensor.tensor_network import TensorNetwork

import numpy as np

# Compact array printing: 2 decimals, no scientific notation for small values.
np.set_printoptions(precision=2, suppress=True)


In [None]:
# Seed NumPy's global RNG so that Restart & Run All reproduces the same random train.
# NOTE(review): assumes TensorTrain.randomize draws from np.random — confirm.
np.random.seed(0)
# Random 4-core tensor train with physical dimension 4 and internal ranks 3.
tt = TensorTrain(shape=[4, 4, 4, 4], ranks=[1, 3, 3, 3, 1])
tt.randomize()

This notebook implements several functions and algorithms for the tensor-train (TT) decomposition, illustrated on the following random tensor train (also known as a matrix product state, MPS):

<center><img src="./images/test_tt.png" style="width: 40%"/></center>

In [None]:
# Display the random tensor train created above.
tt

### Left and Right unfoldings

Left unfolding:
<center><img src="./images/left_unfold.png" style="width: 40%"/></center>

Right unfolding:
<center><img src="./images/right_unfold.png" style="width: 40%"/></center>

In [None]:
# Right and left unfoldings of the second core (index 1), printed right first.
R, L = right_unfold(tt[1]), left_unfold(tt[1])
for unfolding in (R, L):
    print(unfolding)

### Left and Right orthogonalizations


Left orthogonalization:
<center><img src="./images/tt_left_ortho.png" style="width: 40%"/></center>

Right orthogonalization:
<center><img src="./images/tt_right_ortho.png" style="width: 40%"/></center>


In [None]:
# Two copies of tt: tt1 is left-orthonormalized and tt2 right-orthonormalized below.
tt1, tt2 = tt.copy(), tt.copy()

In [None]:
# First core of the original and of both copies, as plain ndarrays.
tuple(core.view(np.ndarray) for core in (tt[0], tt1[0], tt2[0]))

In [None]:
# Left-orthonormalize the copy tt1; later cells inspect tt1's cores, so this presumably acts in place.
tt1.orthonormalize(mode="left")

<center><img src="./images/tt_left_ortho_id.png" style="width: 30%"/></center>

In [None]:
# After left-orthonormalization the left unfolding of core 1 should have
# orthonormal columns, i.e. L^T L = I. Assert it instead of only eyeballing it.
L = left_unfold(tt1[1]).view(np.ndarray)
gram_left = L.T @ L
np.testing.assert_allclose(gram_left, np.eye(gram_left.shape[0]), atol=1e-8)
gram_left

In [None]:
# Right-orthonormalize the copy tt2; later cells inspect tt2's cores, so this presumably acts in place.
tt2.orthonormalize(mode="right")

<center><img src="./images/tt_right_ortho_id.png" style="width: 30%"/></center>

In [None]:
# After right-orthonormalization the right unfolding of the last core should
# have orthonormal rows, i.e. R R^T = I. Assert it instead of only eyeballing it.
R = right_unfold(tt2[3]).view(np.ndarray)
gram_right = R @ R.T
np.testing.assert_allclose(gram_right, np.eye(gram_right.shape[0]), atol=1e-8)
gram_right

### Left and right parts of the tensor train

Left part:
<center><img src="./images/left_contract.png" style="width: 50%"/></center>
Right part:
<center><img src="./images/right_contract.png" style="width: 50%"/></center>

Then, we can write the tensor train as:

<center><img src="./images/left_right_repr.png" style="width: 80%"/></center>


## Alternating Least Squares (ALS) algorithm

In [None]:
from opt_einsum import contract

def retraction_operator(tt, i):
    """Return every core of ``tt`` except core ``i``.

    Core names ``core_0`` ... ``core_{order-1}`` are built, ``core_i`` is
    left out, and the remaining names are passed to ``tt.extract``. Whatever
    ``extract`` returns is handed back unchanged. An ``i`` outside
    ``range(tt.order)`` simply keeps all cores.
    """
    names_without_i = [f'core_{k}' for k in range(tt.order) if k != i]
    return tt.extract(names_without_i)

In [None]:
# Right-orthonormalize a fresh copy, then build the retraction operator for
# core 0 (the network formed by all remaining cores).
tt3 = tt.copy()
tt3.orthonormalize(mode="right")

print(tt3)

P1 = retraction_operator(tt3, 0)
print("\nRetraction operator (1st):")
print(P1)
print(P1.contract())

In [None]:
# Left-orthonormalize, then re-right-orthonormalize from core 1 onward before
# extracting the retraction operator for core 1.
tt3.orthonormalize(mode="left")
tt3.orthonormalize(mode="right", start=1)

P2 = retraction_operator(tt3, 1)
P2_dense = P2.contract()  # contract once and reuse
print("\nRetraction operator (2nd):")
print(P2)
print(P2_dense)

# Matricize with the (r_1, r_2) rank indices as rows; P P^T is expected to be
# (close to) the identity.
P2_mat = P2_dense.unfold(('r_1', 'r_2'), -1).view(np.ndarray)
P2_mat @ P2_mat.T

In [None]:
# Same procedure for core 2: left-orthonormalize, re-right-orthonormalize from
# core 2 onward, then extract the retraction operator for core 2.
tt3.orthonormalize(mode="left")
tt3.orthonormalize(mode="right", start=2)

P3 = retraction_operator(tt3, 2)
P3_dense = P3.contract()  # contract once and reuse
print("\nRetraction operator (3rd):")
print(P3)
print(P3_dense)

# Matricize with the (r_2, r_3) rank indices as rows; P P^T is expected to be
# (close to) the identity.
P3_mat = P3_dense.unfold(('r_2', 'r_3'), -1).view(np.ndarray)
P3_mat @ P3_mat.T

In [None]:
# Left-orthonormalized copy of tt.
# NOTE(review): `ttt` is not used by any later cell in view; candidate for removal.
ttt = tt.copy()
ttt.orthonormalize(mode="left")

#### Rename indices

In [None]:
# rename(..., inplace=False) returns a renamed train (tt4 presumably unchanged),
# so both namings can be shown side by side: m_* in tt4, n_* in tt5.
tt4 = tt.copy()
tt5 = tt4.rename("m_*", "n_*", inplace=False)
tt4, tt5

#### Micro-optimization

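At each step of the ALS algorithm, the core $V_i$ is updated by solving the local linear system

$$P_i^T A P_i \, V_i = P_i^T b,$$

where $P_i$ is the retraction operator formed by all cores of the tensor train except the $i$-th one.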

In [None]:
def ALS(A, b, n_iter=10, ranks=None):
    """Fit a tensor train to the tensor train ``b`` with alternating sweeps.

    A random tensor train with the same physical dimensions as ``b`` is
    created and right-orthonormalized from core 1 onward. Each of the
    ``n_iter`` iterations then runs two half sweeps:
    - a left-to-right half sweep over cores ``0 .. d-2``;
    - a right-to-left half sweep over cores ``d-1 .. 1``.

    At core ``j`` the local update ``V`` is ``b`` contracted with the
    retraction operator ``P_j`` (all cores except ``j``). ``V`` is split by a
    QR factorization. The orthogonal factor replaces core ``j`` and the
    triangular factor is absorbed into the neighbouring core.

    NOTE(review): ``A`` is currently not used, so the local step is
    ``V = P_j^T b`` (ALS for the identity operator). TODO: solve the local
    system ``(P_j^T A P_j) V = P_j^T b`` to actually target ``A x = b``.

    Args:
        A: Linear operator with ``m_*`` / ``n_*`` indices (currently unused).
        b: Right-hand side tensor train. Its cores are assumed to be 3-index
            ``(left rank, physical, right rank)``; ``get_idx`` relies on
            positions 0 and 2 being the boundary rank indices.
        n_iter: Number of full (left + right) sweeps.
        ranks: TT ranks of the result, of length ``b.order + 1``. Defaults to
            ``[1, 3, ..., 3, 1]``.

    Returns:
        The fitted TensorTrain.

    Raises:
        ValueError: If ``ranks`` does not have length ``b.order + 1``.
    """
    order = b.order
    # Physical dimension of each mode: middle axis of b's cores.
    shape = [b[j].shape[1] for j in range(order)]
    if ranks is None:
        ranks = [1] + [3] * (order - 1) + [1]
    elif len(ranks) != order + 1:
        raise ValueError(f"ranks must have length {order + 1}, got {len(ranks)}")

    tt = TensorTrain(shape, ranks)
    tt.randomize()
    tt.orthonormalize(mode="right", start=1)

    def get_idx(j):
        """Output indices of the local update at core j.

        The boundary cores take b's outer rank index; interior cores keep tt's indices.
        """
        if j == 0:
            return (b[j].indices[0], *tt[j].indices[1:])
        if j == tt.order - 1:
            return (*tt[j].indices[:-1], b[j].indices[2])
        return tt[j].indices

    def micro_optimization(tt, j):
        """Local update for core j: contract b with the retraction operator P_j."""
        P = retraction_operator(tt, j)
        P.name = 'P'
        return TensorNetwork(cores=[P, b], names=['P', 'b']).contract(indices=get_idx(j))

    for _ in range(n_iter):
        # Left half sweep: Q stays at core j, the triangular factor moves into core j+1.
        for j in range(tt.order - 1):
            V = micro_optimization(tt, j)
            core_curr = tt.cores[f"core_{j}"]
            core_next = tt.cores[f"core_{j+1}"]

            Q, S = np.linalg.qr(left_unfold(V).view(np.ndarray))
            W = S @ right_unfold(core_next).view(np.ndarray)

            tt.cores[f"core_{j}"] = TensorCore.like(Q, core_curr)
            tt.cores[f"core_{j+1}"] = TensorCore.like(W, core_next)

        # Right half sweep: Q^T stays at core j, the triangular factor moves into core j-1.
        for j in range(tt.order - 1, 0, -1):
            V = micro_optimization(tt, j)
            core_prev = tt.cores[f"core_{j-1}"]
            core_curr = tt.cores[f"core_{j}"]

            Q, S = np.linalg.qr(right_unfold(V).view(np.ndarray).T)
            W = left_unfold(core_prev).view(np.ndarray) @ S.T

            tt.cores[f"core_{j-1}"] = TensorCore.like(W, core_prev)
            tt.cores[f"core_{j}"] = TensorCore.like(Q.T, core_curr)

    return tt


import time

# Physical shape shared by the right-hand side b and the operator A.
PHYS_SHAPE = (10, 10, 10)

# Right-hand side: a deterministic tensor compressed to a TT with ranks (1, 8, 8, 1).
# Its rank indices are renamed r_* -> t_*, presumably so they do not collide with
# the r_* rank indices of the ALS iterate when the two are contracted.
b = TensorTrain.from_tensor(np.arange(np.prod(PHYS_SHAPE)).reshape(PHYS_SHAPE), ranks=[1, 8, 8, 1])
b.rename('r_*', 't_*')
print(b)

# Operator A: m_* are output indices, n_* input indices. Its physical shape must
# match b's, otherwise np.linalg.tensorsolve(a, b) in a later cell cannot work.
rng = np.random.default_rng(0)
a = rng.random(PHYS_SHAPE + PHYS_SHAPE)
A = TensorCore(a, ['m_1', 'm_2', 'm_3', 'n_1', 'n_2', 'n_3'])

start_time = time.perf_counter()
x = ALS(A, b, n_iter=100)
print(f"Elapsed time: {(time.perf_counter() - start_time):.5f}s")

In [None]:
def scalar_ALS(X, b, n_iter=10, ranks=None):
    """Work-in-progress ALS variant for a scalar target ``b``.

    Builds a random tensor train with one mode per row of ``X``, each mode of
    size ``X.shape[1]``. It is right-orthonormalized from core 1 onward and
    then refined with ``n_iter`` left/right half sweeps using the same QR
    shifting as ``ALS``.

    NOTE(review): experimental and likely incorrect. See the TODOs in
    ``micro_optimization``. ``ranks=None`` is forwarded unchanged to
    ``TensorTrain``; callers currently pass explicit ranks.

    Args:
        X: 2-D array; its number of rows sets the TT order and its number of
            columns sets every physical dimension.
        b: Scalar target multiplied into the retraction-operator cores.
        n_iter: Number of full (left + right) sweeps.
        ranks: TT ranks forwarded to ``TensorTrain``.

    Returns:
        The tensor train after the sweeps.
    """
    shape = tuple(X.shape[1] for _ in range(X.shape[0]))
    tt = TensorTrain(shape, ranks)
    tt.randomize()
    tt.orthonormalize(mode="right", start=1)

    # NOTE(review): never called, and it reads a global `A` (not a parameter of
    # this function). Candidate for removal.
    def get_idx(j):
        if j == 0: indices = (A[j].indices[0], *tt[j].indices[1:])
        elif j == tt.order-1: indices = (*tt[j].indices[:-1], A[j].indices[2])
        else: indices = tt[j].indices
        return indices

    def micro_optimization(tt, j):
        # Retraction operator for core j (all other cores), scaled by b, then contracted.
        P = retraction_operator(tt, j)
        # V = TensorNetwork(cores=[P, b], names=['P','b']).contract()#indices=get_idx(j))
        # print(V)
        print(P)
        # TODO(review): tt.cores is indexed by "core_j" strings elsewhere, so P.cores
        # may be dict-like. If so, this iterates over keys and `core *= b` only
        # rebinds the loop variable (no effect). If it yields arrays shared with tt,
        # tt's cores are scaled in place. Either way, scaling every core by b scales
        # the contracted result by b**(number of cores), not by b. Confirm.
        for core in P.cores:
            core *= b
        # TODO(review): this returns the contracted retraction operator, not a
        # (rank, physical, rank) core, so the unfold/QR/TensorCore.like steps in
        # the sweeps below may not receive a core-shaped tensor. Confirm.
        return P.contract()

    for i in range(n_iter):
        # Left half sweep
        for j in range(tt.order-1):
            # Micro optimization
            V = micro_optimization(tt, j)
            print(V)

            core_curr = tt.cores[f"core_{j}"]
            core_next = tt.cores[f"core_{j+1}"]

            # QR of the left unfolding: Q replaces core j, S is pushed into core j+1.
            L = left_unfold(V).view(np.ndarray)
            R = right_unfold(core_next).view(np.ndarray)

            Q, S = np.linalg.qr(L)
            W = S @ R

            tt.cores[f"core_{j}"] = TensorCore.like(Q, core_curr)
            tt.cores[f"core_{j+1}"] = TensorCore.like(W, core_next)

        # Right half sweep
        for j in range(tt.order-1, 0, -1):
            # Micro optimization
            V = micro_optimization(tt, j)

            core_prev = tt.cores[f"core_{j-1}"]
            core_curr = tt.cores[f"core_{j}"]

            # QR of the transposed right unfolding: Q^T replaces core j, S^T is pushed into core j-1.
            L = left_unfold(core_prev).view(np.ndarray)
            R = right_unfold(V).view(np.ndarray)

            Q, S = np.linalg.qr(R.T)
            W = L @ S.T

            tt.cores[f"core_{j-1}"] = TensorCore.like(W, core_prev)
            tt.cores[f"core_{j}"] = TensorCore.like(Q.T, core_curr)

    return tt

from bsde_solver.utils import flatten

# Scalar target. Named `b_scalar` (not `b`) so the tensor train `b` used by the
# ALS cells below is not overwritten by an int.
b_scalar = 12
# Polynomial feature matrix: row i holds the basis functions 1, x_i, x_i^2, x_i^3
# evaluated at sample point x_i, giving shape (n_points, d) = (3, 4).
# Sample points are named `x_points` so the ALS solution `x` is not overwritten.
d = 4
x_points = np.array([-1, -1, 0.5])
X = np.vander(x_points, d, increasing=True)
n = X.shape[0]

V = scalar_ALS(X, b_scalar, n_iter=100, ranks=[1, 3, 3, 1])
# Tensor-product feature tensor (4, 4, 4) of the rows of X, kept for reference:
# result = contract(*flatten([[X[i], ('a_' + str(i),)] for i in range(n)]))

In [None]:
# Dense comparison of the ALS result x against the target b.
xx, bb = x.contract(), b.contract()
print("Reconstruction error:", np.linalg.norm(xx - bb))

In [None]:
# Display the ALS result and the right-hand side tensor train together.
x, b

In [None]:
# Dense reference: contract b to a plain tensor and solve a . y = c directly.
c = b.contract(indices=['m_1', 'm_2', 'm_3'])
c_arr = c.view(np.ndarray)
y = np.linalg.tensorsolve(a, c_arr)
print(c_arr)
print(a.shape, c_arr.shape)
# Apply a to the solution over its last three (input) axes; should reproduce c.
print(np.tensordot(a, y, axes=([3, 4, 5], [0, 1, 2])))

In [None]:
# Apply the dense operator a to the ALS result x (contracted to a plain tensor).
y = x.contract(indices=['m_1', 'm_2', 'm_3']).view(np.ndarray)
a_times_y = np.tensordot(a, y, axes=([3, 4, 5], [0, 1, 2]))
print(a_times_y)

In [None]:
# Re-runs the earlier reconstruction-error check (this cell duplicates it).
xx = x.contract()
bb = b.contract()
reconstruction_error = np.linalg.norm(xx - bb)
print("Reconstruction error:", reconstruction_error)

### Tensor decomposition

In [None]:
# Decompose a deterministic 5-way tensor into a TT with all internal ranks 2
# and measure how well the contracted train reproduces it.
A = np.arange(10**5).reshape((10,) * 5)
Att = TensorTrain.from_tensor(A, [1, 2, 2, 2, 2, 1])
decomposition_error = np.linalg.norm(A - Att.contract().squeeze())
print("Decomposition error:", decomposition_error)

In [None]:
# 16x16 identity viewed as a 4-way (4, 4, 4, 4) tensor, decomposed into a TT.
Id = np.eye(4 * 4).reshape(4, 4, 4, 4)
I = TensorTrain.from_tensor(Id, [1, 2, 2, 2, 1])
identity_error = np.linalg.norm(Id - I.contract().squeeze())
print("Identity error:", identity_error)

# NOTE(review): `id` shadows Python's builtin id(); consider renaming.
id = TensorNetwork([TensorCore(Id, ['i', 'j', 'k', 'l'])])
U = TensorNetwork([TensorCore(np.random.randn(4, 4), ['i', 'j'])])

# Give the TT's physical indices the same names as the dense network.
for old_name, new_name in zip(("m_1", "m_2", "m_3", "m_4"), ("i", "j", "k", "l")):
    I.rename(old_name, new_name)

print(I)