In [208]:
import numpy as np
import time
import scipy 
from scipy.linalg import khatri_rao
import itertools
from tqdm.notebook import tqdm

In [201]:
def dumb_als(tensor, rank, tol=1e-6, maxiter=100, verbose=0, seed=42):
    ndim = tensor.ndim
    dims = tensor.shape
    np.random.seed(seed)
    factors = [np.random.randn(dim, rank) for dim in dims]
    
    # tensor -> A
    # factors -> U^k
    # A_k = U^k @ (khatri_rao(U^i), i!=k).T
    # A_k.T = (khatri_rao(U^i), i!=k) @ U^k.T
    # B_k = (khatri_rao(U^i), i!=k)
    # A_k.T = B_k @ U^k.T
    # U^k.T = least_squares(B_k @ X - A_k.T)
    
    stop = False
    for it in range(maxiter):
        stop = 0
        for k in range(ndim):
            B_k = np.ones(rank).reshape(1, -1)
            for j in range(ndim):
                if j != k:
                    B_k = khatri_rao(B_k, factors[j])
            
            A_k = np.moveaxis(tensor, k, 0).reshape(dims[k], -1)
            prev = factors[k]
            if verbose == 2:
                print(B_k.shape)
                print(A_k.T.shape)
                print()
            factors[k] = np.linalg.lstsq(B_k, A_k.T, rcond=None)[0].T
            if np.linalg.norm(prev - factors[k]) < tol:
                stop += 1
                
        if verbose==1:
            print(f'iteration {it+1}/{maxiter}')
            B_1 = np.ones(rank).reshape(1, -1)
            for j in range(1, ndim):
                B_1 = khatri_rao(B_1, factors[j])
            A_1 = factors[0] @ B_1.T
            print(f'error = {np.linalg.norm(tensor.reshape(-1) - A_1.reshape(-1)):.5f}\n')
        if ndim - stop <= 1:
            break
    print(f'stopping iterations on {it+1}/{maxiter}')
    B_1 = np.ones(rank).reshape(1, -1)
    for j in range(1, ndim):
        B_1 = khatri_rao(B_1, factors[j])
    A_1 = factors[0] @ B_1.T
    error = np.linalg.norm(tensor.reshape(-1) - A_1.reshape(-1))
    print(f'final error = {error:.5f}\n')
    return factors, error

In [202]:
def norm_als(tensor, rank, tol=1e-6, maxiter=100, verbose=False, seed=42):
    ndim = tensor.ndim
    dims = tensor.shape
    np.random.seed(seed)
    factors = [np.random.randn(dim, rank) for dim in dims]
    
    # tensor -> A
    # factors -> U^k
    # A_k = U^k @ (khatri_rao(U^i), i!=k).T
    # A_k.T = (khatri_rao(U^i), i!=k) @ U^k.T
    # B_k = (khatri_rao(U^i), i!=k)
    # A_k.T = B_k @ U^k.T
    # U^k.T = least_squares(B_k @ X - A_k.T)
    
    # smart part:
    # X_k = B_k^* @ B_k = (khatri_rao(U^i), i!=k)^* @ (khatri_rao(U^i), i!=k) 
    # X_k = pointwise_product(U^i^*, U_i) for j != k
    # X<k = pointwise_product(U^i^*, U_i) for j < k
    # X>k = pointwise_product(U^i^*, U_i) for j > k
    # X_k = pointwise_product(X<k, X>k)
    # X<k+1 = pointwise_product(X<k, U^k+1^* @ U_k+1)
    
    # B_k * X = A_k.T - least squares
    # (B_k* @ B_k) * X = B_k* @ A_k.T
    # X = least_squares (X_k, B_k* @ A_k.T) 
    # wow! very cool!
    # very smart als
    
    stop = False
    X_less = [0] * (ndim + 1)
    X_more = [0] * (ndim + 1)    
    it = 1
    while it < maxiter:
        stop = 0
        X_less[0] = np.ones((rank, rank))
        X_more[-1] = np.ones((rank, rank))
        for j in reversed(list(range(ndim))):
            X_more[j] = np.multiply(X_more[j+1], factors[j].T @ factors[j])
        
        # left to right
        for k in range(ndim):
            X_k = np.multiply(X_less[k], X_more[k + 1])
            A_k = np.moveaxis(tensor, k, 0).reshape(dims[k], -1)
            prev = factors[k]
            
            B_k = np.ones(rank).reshape(1, -1)
            for j in range(ndim):
                if j != k:
                    B_k = khatri_rao(B_k, factors[j])
            
            if verbose == 2:
                print(k)
                print(X_k.shape)
                print(B_k.shape)
                print(A_k.T.shape)
                print()
            
            factors[k] = np.linalg.lstsq(X_k, B_k.T @ A_k.T, rcond=None)[0].T
            X_less[k + 1] = np.multiply(X_less[k], factors[k].T @ factors[k])

            if np.linalg.norm(prev - factors[k]) < tol:
                stop += 1
                
        if verbose==1:
            print(f'iteration {it}/{maxiter}')
            B_1 = np.ones(rank).reshape(1, -1)
            for j in range(1, ndim):
                B_1 = khatri_rao(B_1, factors[j])
            A_1 = factors[0] @ B_1.T
            print(f'error = {np.linalg.norm(tensor.reshape(-1) - A_1.reshape(-1)):.5f}\n')
        if ndim - stop <= 1:
            break
        it += 1
        
        stop = 0
        for j in range(1, ndim + 1):
            X_less[j] = np.multiply(X_less[j-1], factors[j-1].T @ factors[j-1])                                
        # right to left
        for k in reversed(list(range(ndim))):
            X_k = np.multiply(X_less[k], X_more[k + 1])
            A_k = np.moveaxis(tensor, k, 0).reshape(dims[k], -1)
            prev = factors[k]
            
            B_k = np.ones(rank).reshape(1, -1)
            for j in range(ndim):
                if j != k:
                    B_k = khatri_rao(B_k, factors[j])
            
            if verbose == 2:
                print(k)
                print(X_k.shape)
                print(B_k.shape)
                print(A_k.T.shape)
                print()
                
            factors[k] = np.linalg.lstsq(X_k, B_k.T @ A_k.T, rcond=None)[0].T
            X_more[k] = np.multiply(X_more[k + 1], factors[k].T @ factors[k])

            if np.linalg.norm(prev - factors[k]) < tol:
                stop += 1
                                
        if verbose==1:
            print(f'iteration {it}/{maxiter}')
            B_1 = np.ones(rank).reshape(1, -1)
            for j in range(1, ndim):
                B_1 = khatri_rao(B_1, factors[j])
            A_1 = factors[0] @ B_1.T
            print(f'error = {np.linalg.norm(tensor.reshape(-1) - A_1.reshape(-1)):.5f}\n')
        if ndim - stop <= 1:
            break
        it += 1
        
    print(f'stopping iterations on {it+1}/{maxiter}')
    B_1 = np.ones(rank).reshape(1, -1)
    for j in range(1, ndim):
        B_1 = khatri_rao(B_1, factors[j])
    A_1 = factors[0] @ B_1.T
    error = np.linalg.norm(tensor.reshape(-1) - A_1.reshape(-1))
    print(f'final error = {error:.5f}\n')
    return factors, error

In [203]:
sizes = np.array((10, 35, 20))
T = np.zeros(sizes)
for i in range(sizes[0]):
    for j in range(sizes[1]):
        for k in range(sizes[2]):
            T[i, j, k] = np.sin(i + j + k+20)
rank = 10
print('dumb_als')
factors, error = dumb_als(T, rank, verbose=1)
print(f'dumb_err={error}')

print('norm_als')
factors, error = norm_als(T, rank, verbose=1)
print(f'norm_err={error}')

dumb_als
iteration 1/100
error = 14.88554

iteration 2/100
error = 1.10638

iteration 3/100
error = 0.13741

iteration 4/100
error = 0.01712

iteration 5/100
error = 0.00213

iteration 6/100
error = 0.00027

iteration 7/100
error = 0.00003

iteration 8/100
error = 0.00000

iteration 9/100
error = 0.00000

stopping iterations on 9/100
final error = 0.00000

norm_als
iteration 1/100
error = 14.88554

iteration 2/100
error = 3.59734

iteration 3/100
error = 0.81960

iteration 4/100
error = 0.20631

iteration 5/100
error = 0.05132

iteration 6/100
error = 0.01290

iteration 7/100
error = 0.00326

iteration 8/100
error = 0.00083

iteration 9/100
error = 0.00021

iteration 10/100
error = 0.00005

stopping iterations on 11/100
final error = 0.00005



In [204]:
sizes = np.array((10, 35, 20, 50, 40))
T = np.zeros(sizes)

st = time.time()
for I in itertools.product(*(range(i) for i in sizes)):
    T[I] = np.sin(sum(I))
print(f'filled tensor in {time.time() - st} seconds')
print(f'total elements in tensor = {np.prod(sizes)}')

print('\ndumb als')
st = time.time()
rank = 5
factors, dumb_error = dumb_als(T, rank, verbose=1)
dumb_als_time = time.time() - st

print('\nnorm als')
st = time.time()
rank = 5
factors, norm_error = norm_als(T, rank, verbose=1)
norm_als_time = time.time() - st
print(f'dumb_als_time = {dumb_als_time:.2f} seconds, error = {dumb_error:.5f}')
print(f'norm_als_time = {norm_als_time:.2f} seconds, error = {norm_error:.5f}')

filled tensor in 8.876933813095093 seconds
total elements in tensor = 14000000

dumb als
iteration 1/100
error = 2280.83467

iteration 2/100
error = 2046.71046

iteration 3/100
error = 1307.57345

iteration 4/100
error = 625.85631

iteration 5/100
error = 62.54004

iteration 6/100
error = 8.11693

iteration 7/100
error = 0.99112

iteration 8/100
error = 0.10669

iteration 9/100
error = 0.01519

iteration 10/100
error = 0.00153

iteration 11/100
error = 0.00011

iteration 12/100
error = 0.00002

iteration 13/100
error = 0.00000

stopping iterations on 13/100
final error = 0.00000


norm als
iteration 1/100
error = 2280.83467

iteration 2/100
error = 1922.39162

iteration 3/100
error = 1627.97677

iteration 4/100
error = 1261.17731

iteration 5/100
error = 1034.03022

iteration 6/100
error = 931.01424

iteration 7/100
error = 726.50284

iteration 8/100
error = 520.40127

iteration 9/100
error = 43.14724

iteration 10/100
error = 8.13925

iteration 11/100
error = 1.61804

iteration 12/100

In [205]:
sizes = np.array((10, 35, 20, 50, 40))
T = np.zeros(sizes)

st = time.time()
for I in itertools.product(*(range(i) for i in sizes)):
    T[I] = np.sum(I)
print(f'filled tensor in {time.time() - st} seconds')
print(f'total elements in tensor = {np.prod(sizes)}')

print('\ndumb als')
st = time.time()
rank = 10
factors, dumb_error = dumb_als(T, rank, maxiter=20, verbose=1)
dumb_als_time = time.time() - st

print('\nnorm als')
st = time.time()
rank = 10
factors, norm_error = norm_als(T, rank, maxiter=20, verbose=1)
norm_als_time = time.time() - st
print(f'dumb_als_time = {dumb_als_time:.2f} seconds, error = {dumb_error:.5f}')
print(f'norm_als_time = {norm_als_time:.2f} seconds, error = {norm_error:.5f}')

filled tensor in 31.067357063293457 seconds
total elements in tensor = 14000000

dumb als
iteration 1/20
error = 968.80154

iteration 2/20
error = 118.55878

iteration 3/20
error = 34.23975

iteration 4/20
error = 32.98620

iteration 5/20
error = 32.35108

iteration 6/20
error = 31.91948

iteration 7/20
error = 31.57903

iteration 8/20
error = 31.27703

iteration 9/20
error = 30.98773

iteration 10/20
error = 30.69899

iteration 11/20
error = 30.40507

iteration 12/20
error = 30.10312

iteration 13/20
error = 29.79148

iteration 14/20
error = 29.46896

iteration 15/20
error = 29.13452

iteration 16/20
error = 28.78712

iteration 17/20
error = 28.42570

iteration 18/20
error = 28.04911

iteration 19/20
error = 27.65621

iteration 20/20
error = 27.24581

stopping iterations on 20/20
final error = 27.24581


norm als
iteration 1/20
error = 968.80154

iteration 2/20
error = 74.64715

iteration 3/20
error = 60.58117

iteration 4/20
error = 52.67718

iteration 5/20
error = 47.55089

iteratio

In [206]:
sizes = np.array((10, 35, 20, 50, 40))
T = np.zeros(sizes)

st = time.time()
for I in itertools.product(*(range(i) for i in sizes)):
    T[I] = 1 / (np.sum(I) + 1)
print(f'filled tensor in {time.time() - st} seconds')
print(f'total elements in tensor = {np.prod(sizes)}')

print('\ndumb als')
st = time.time()
rank = 5
factors, dumb_error = dumb_als(T, rank, maxiter=20, verbose=1)
dumb_als_time = time.time() - st

print('\nnorm als')
st = time.time()
rank = 5
factors, norm_error = norm_als(T, rank, maxiter=20, verbose=1)
norm_als_time = time.time() - st
print(f'dumb_als_time = {dumb_als_time:.2f} seconds, error = {dumb_error:.5f}')
print(f'norm_als_time = {norm_als_time:.2f} seconds, error = {norm_error:.5f}')

filled tensor in 31.71893286705017 seconds
total elements in tensor = 14000000

dumb als
iteration 1/20
error = 5.31220

iteration 2/20
error = 2.43273

iteration 3/20
error = 1.70007

iteration 4/20
error = 1.54805

iteration 5/20
error = 1.34100

iteration 6/20
error = 1.22884

iteration 7/20
error = 1.17647

iteration 8/20
error = 1.14575

iteration 9/20
error = 1.12366

iteration 10/20
error = 1.10633

iteration 11/20
error = 1.09200

iteration 12/20
error = 1.07961

iteration 13/20
error = 1.06847

iteration 14/20
error = 1.05813

iteration 15/20
error = 1.04831

iteration 16/20
error = 1.03880

iteration 17/20
error = 1.02945

iteration 18/20
error = 1.02017

iteration 19/20
error = 1.01087

iteration 20/20
error = 1.00146

stopping iterations on 20/20
final error = 1.00146


norm als
iteration 1/20
error = 5.31220

iteration 2/20
error = 1.93002

iteration 3/20
error = 1.57271

iteration 4/20
error = 1.43170

iteration 5/20
error = 1.27071

iteration 6/20
error = 1.12180

iterat

In [207]:
sizes = np.array((15, 15, 15, 15, 15, 15))
T = np.zeros(sizes)

st = time.time()
for I in itertools.product(*(range(i) for i in sizes)):
    T[I] = np.sum(I)
print(f'filled tensor in {time.time() - st} seconds')
print(f'total elements in tensor = {np.prod(sizes)}')

print('\ndumb als')
st = time.time()
rank = 5
factors, dumb_error = dumb_als(T, rank, verbose=1)
dumb_als_time = time.time() - st

print('\nnorm als')
st = time.time()
rank = 5
factors, norm_error = norm_als(T, rank, verbose=1)
norm_als_time = time.time() - st
print(f'dumb_als_time = {dumb_als_time:.2f} seconds, error = {dumb_error:.5f}')
print(f'norm_als_time = {norm_als_time:.2f} seconds, error = {norm_error:.5f}')

filled tensor in 25.64196467399597 seconds
total elements in tensor = 11390625

dumb als
iteration 1/100
error = 4347.31785

iteration 2/100
error = 1832.43976

iteration 3/100
error = 927.12193

iteration 4/100
error = 680.38640

iteration 5/100
error = 617.80766

iteration 6/100
error = 583.63039

iteration 7/100
error = 549.64443

iteration 8/100
error = 517.10516

iteration 9/100
error = 486.38453

iteration 10/100
error = 457.11323

iteration 11/100
error = 428.69463

iteration 12/100
error = 400.42147

iteration 13/100
error = 371.46489

iteration 14/100
error = 340.80682

iteration 15/100
error = 307.16606

iteration 16/100
error = 269.05960

iteration 17/100
error = 225.54174

iteration 18/100
error = 178.93376

iteration 19/100
error = 138.31815

iteration 20/100
error = 112.55171

iteration 21/100
error = 99.21376

iteration 22/100
error = 91.98484

iteration 23/100
error = 87.30881

iteration 24/100
error = 83.80344

iteration 25/100
error = 80.85907

iteration 26/100
error 

In [210]:
sizes = np.array((10, 10, 10, 10, 10, 10))
T = np.zeros(sizes)
tensor_t = []
dumb_e = []
dumb_t = []
norm_e = []
norm_t = []
N = 100

for i in tqdm(range(N)):
    st = time.time()
    for I in itertools.product(*(range(i) for i in sizes)):
        T[I] = np.sum(I)
    tensor_t.append(time.time() - st)

    st = time.time()
    rank = 5
    factors, dumb_error = dumb_als(T, rank, verbose=0, seed=i**3)
    dumb_e.append(dumb_error)
    dumb_t.append(time.time() - st)
    
    st = time.time()
    rank = 5
    factors, norm_error = norm_als(T, rank, verbose=0, seed=i**3)
    norm_e.append(norm_error)
    norm_t.append(time.time() - st)
    


  0%|          | 0/100 [00:00<?, ?it/s]

stopping iterations on 100/100
final error = 2.23651

stopping iterations on 102/100
final error = 4.97805

stopping iterations on 100/100
final error = 2.10944

stopping iterations on 102/100
final error = 23.27816

stopping iterations on 100/100
final error = 1.67158

stopping iterations on 102/100
final error = 0.11373

stopping iterations on 100/100
final error = 24.78722

stopping iterations on 102/100
final error = 17.55162

stopping iterations on 100/100
final error = 2.64028

stopping iterations on 102/100
final error = 0.50024

stopping iterations on 100/100
final error = 4.36220

stopping iterations on 102/100
final error = 7.70669

stopping iterations on 100/100
final error = 27.39996

stopping iterations on 102/100
final error = 15.40020

stopping iterations on 100/100
final error = 2.07043

stopping iterations on 102/100
final error = 13.14782

stopping iterations on 100/100
final error = 26.83788

stopping iterations on 102/100
final error = 107.21645

stopping iterations

In [211]:
print(f'statistics of {N} trials')
print(f'mean tensor creation time = {np.mean(tensor_t):.2f}')
print(f'mean dumb als time  = {np.mean(dumb_t):.2f}')
print(f'mean dumb als error = {np.mean(dumb_e):.5f}')
print(f'mean norm als time  = {np.mean(norm_t):.2f}')
print(f'mean norm als error = {np.mean(norm_e):.5f}')

statistics of 100 trials
mean tensor creation time = 2.42
mean dumb als time  = 6.36
mean dumb als error = 8.55113
mean norm als time  = 2.34
mean norm als error = 15.87520
