In [122]:
import random
import numpy as np
import time
import scipy 

In [278]:
def ada_cross(A, max_rank, verbose=False, r_tol=1e-6):
    m, n = A.shape
    max_rank = min(max_rank, m, n)
    U, V = np.zeros((m, 0)), np.zeros((0, n))
    seen_cols, seen_rows = np.zeros((m, 0)), np.zeros((0, n))
    seen_cols_ind, seen_rows_ind = [], []
    I = set(range(m))
    J = set(range(n))
    r = 0
    UV_norm = 0.0
    saw_cols = 0
    saw_rows = 0
    
    while r < max_rank:
        i = random.choice(tuple(I))
        j = random.choice(tuple(J))
        I -= {i}
        J -= {j}
        
        seen_rows_ind.append(i)
        seen_cols_ind.append(j)
        saw_cols += 1
        saw_rows += 1

        seen_rows = np.insert(seen_rows, seen_rows.shape[0], A[i,:] - U[i,:] @ V, axis=0)
        seen_cols = np.insert(seen_cols, seen_cols.shape[1], A[:,j] - U @ V[:,j], axis=1)

        rows_max_ind = np.unravel_index(np.abs(seen_rows).argmax(), seen_rows.shape)
        orig_rows_max_ind = seen_rows_ind[rows_max_ind[0]], rows_max_ind[1]
        
        cols_max_ind = np.unravel_index(np.abs(seen_cols).argmax(), seen_cols.shape)
        orig_cols_max_ind = cols_max_ind[0], seen_cols_ind[cols_max_ind[1]]

        u = None
        v = None
        a_max = None
        
        if orig_cols_max_ind == orig_rows_max_ind:
            a_max = seen_rows[rows_max_ind]
            u = seen_cols[:,cols_max_ind[1]] / np.sqrt(np.abs(a_max))
            v = seen_rows[rows_max_ind[0],:] * np.sqrt(np.abs(a_max)) / a_max
            
            seen_cols = np.delete(seen_cols, (cols_max_ind[1]), axis=1)
            seen_rows = np.delete(seen_rows, (rows_max_ind[0]), axis=0)
            del seen_rows_ind[rows_max_ind[0]]
            del seen_cols_ind[cols_max_ind[1]]
            
        else:
            i, j = cols_max_ind
            if np.abs(seen_rows[rows_max_ind]) > np.abs(seen_cols[cols_max_ind]):
                # search col once
                j = rows_max_ind[1]
                seen_cols_ind.append(j)
                J -= {j}
                seen_cols = np.insert(seen_cols, seen_cols.shape[1], A[:,j] - U @ V[:,j], axis=1)
                saw_cols += 1
                i = np.abs(seen_cols[:,-1]).argmax()

            #search row-col until convergence or maxiter
            n_iter = 0
            while n_iter < 10:
                seen_rows_ind.append(i)
                I -= {i}
                seen_rows = np.insert(seen_rows, seen_rows.shape[0], A[i,:] - U[i,:] @ V, axis=0)
                j = np.abs(seen_rows[-1,:]).argmax()
                max_rows = abs(seen_rows[-1,j])

                seen_cols_ind.append(j)
                J -= {j}
                seen_cols = np.insert(seen_cols, seen_cols.shape[1], A[:,j] - U @ V[:,j], axis=1)
                i = np.abs(seen_cols[:,-1]).argmax()
                max_cols = abs(seen_cols[i,-1])

                saw_cols += 1
                saw_rows += 1
                if np.isclose(max_rows, max_cols):
                    break
                n_iter += 1
            
            a_max = seen_cols[i,-1]
            if np.isclose(np.abs(a_max), 0):
                print(f'stopping iterations\na_max = {np.abs(a_max)}\n')
                break
            u = A[:,j] - U @ V[:,j]
            v = A[i,:] - U[i,:] @ V
            u = u / np.sqrt(np.abs(a_max))
            v = v * np.sqrt(np.abs(a_max)) / a_max
            del seen_cols_ind[-1]
            seen_cols = np.delete(seen_cols, -1, axis=1)
            for ind, val in enumerate(seen_rows_ind):
                if val == i:
                    seen_rows = np.delete(seen_rows, ind, axis=0)
                    del seen_rows_ind[ind]
                    break
        
        if u is None or v is None or a_max is None:
            print('very bad')
            return U, V, r

        if verbose:
            print(f'iter = {r + 1}')
            print(f'shape of U = {U.shape}')
            print(f'shape of V = {V.shape}')
            print(f'maxvol element = {a_max}')
        if verbose == 2:
            print(f'shape of seen_rows = {seen_rows.shape}')
            print(f'shape of seen_cols = {seen_cols.shape}')
            print(seen_rows_ind)
            print(seen_cols_ind)
            print(v.T.shape)
        seen_rows = seen_rows - np.outer(u[seen_rows_ind], v)
        seen_cols = seen_cols - np.outer(u, v[seen_cols_ind])
        
        cols_norm = np.power(np.linalg.norm(seen_cols), 2)
        rows_norm = np.power(np.linalg.norm(seen_rows), 2)
        intersect_norm = np.power(np.linalg.norm(seen_cols[seen_rows_ind,:]), 2)
        # remainder_norm =  np.power(np.abs(a_max) * np.sqrt(m - len(seen_rows_ind) - r) * np.sqrt(n - len(seen_cols_ind) - r), 2)
        remainder_norm = (rows_norm + cols_norm - intersect_norm) * (m - r) * (n - r)
        remainder_norm /= n * len(seen_rows_ind) + m * len(seen_cols_ind) - len(seen_rows_ind) * len(seen_cols_ind)
        if verbose == 2:
            print(f'{cols_norm=}')
            print(f'{rows_norm=}')
            print(f'{intersect_norm=}')
            print(f'{remainder_norm=}')
        A_minus_UV_norm = remainder_norm + cols_norm + rows_norm - intersect_norm
        UV_norm += (U.T @ u) @ (U.T @ u) + (V @ v.T) @ (V @ v.T) + (u.T @ u) * (v @ v.T)
        U = np.insert(U, U.shape[1], u, axis=1)
        V = np.insert(V, V.shape[0], v, axis=0)
        r += 1
        if verbose:
            print(f'{A_minus_UV_norm=}')
            print(f'{UV_norm=}')
            print(f'r_tol * UV_norm = {r_tol * UV_norm}\n')
        if A_minus_UV_norm <= r_tol * UV_norm:
            print('stopping iterations')
            print('absolute error estimate =', np.abs(A_minus_UV_norm), '\n')
            break
    if verbose:
        print('summary:')
        print(f'amount of seen cols = {saw_cols}')
        print(f'amount of seen rows = {saw_rows}')
        print(f'total seen elements = {saw_cols * m+ saw_rows * n} out of {m * n}')
        print(f'which is {(saw_cols * m+ saw_rows * n) / m / n * 100:.2f} %')
    return U, V, r

### Small test

In [283]:
m = 15
n = 15
A = np.zeros((m, n))
for i in range(m):
    for j in range(n):
        A[i,j] = np.sin(i + j)
print('norm A =', np.linalg.norm(A)**2)
st = time.time()
U, V, r = ada_cross(A, 4, verbose=1)
end = time.time()
print('rank =', r)
print('absolute error =', np.linalg.norm(A - U @ V) ** 2)
print('relative error =', (np.linalg.norm(A - U @ V) / np.linalg.norm(A)) ** 2)
print('UV norm =', np.linalg.norm(U @ V) ** 2)
print(f'cross time = {end - st} seconds')

norm A = 112.78744283246412
iter = 1
shape of U = (15, 0)
shape of V = (0, 15)
maxvol element = -0.9999902065507035
A_minus_UV_norm=46.73648683896666
UV_norm=53.1879738908165
r_tol * UV_norm = 5.31879738908165e-05

iter = 2
shape of U = (15, 1)
shape of V = (1, 15)
maxvol element = 0.9788393263750047
A_minus_UV_norm=8.831644836283554e-31
UV_norm=112.78752916659667
r_tol * UV_norm = 0.00011278752916659667

stopping iterations
absolute error estimate = 8.831644836283554e-31 

summary:
amount of seen cols = 5
amount of seen rows = 4
total seen elements = 135 out of 225
which is 60.00 %
rank = 2
absolute error = 1.8626696952826167e-30
relative error = 1.6514867688324568e-32
UV norm = 112.78744283246412
cross time = 0.0019996166229248047 seconds


### Average tests

In [284]:
m = 1000
n = 1000
A = np.zeros((m, n))
for i in range(m):
    for j in range(n):
        A[i,j] = np.sin(3*i + 7*j + 17)
print('norm A =', np.linalg.norm(A)**2)
st = time.time()
U, V, r = ada_cross(A, 4, verbose=1)
end = time.time()
print('rank =', r)
print('absolute error =', np.linalg.norm(A - U @ V) ** 2)
print('relative error =', (np.linalg.norm(A - U @ V) / np.linalg.norm(A)) ** 2)
print('UV norm =', np.linalg.norm(U @ V) ** 2)
print(f'cross time = {end - st} seconds')

norm A = 500000.4078253339
iter = 1
shape of U = (1000, 0)
shape of V = (0, 1000)
maxvol element = -0.9999930346319144
A_minus_UV_norm=22005.876445223796
UV_norm=250528.91990645067
r_tol * UV_norm = 0.25052891990645065

iter = 2
shape of U = (1000, 1)
shape of V = (1, 1000)
maxvol element = -0.9999895696560521
A_minus_UV_norm=2.848272939910339e-27
UV_norm=500012.1408511504
r_tol * UV_norm = 0.5000121408511503

stopping iterations
absolute error estimate = 2.848272939910339e-27 

summary:
amount of seen cols = 5
amount of seen rows = 4
total seen elements = 9000 out of 1000000
which is 0.90 %
rank = 2
absolute error = 6.968088987310707e-27
relative error = 1.3936166607577813e-32
UV norm = 500000.40782533406
cross time = 0.001995563507080078 seconds


In [285]:
m = 1000
n = 1000
A = np.zeros((m, n))
for i in range(m):
    for j in range(n):
        A[i,j] = np.sum([i, j, 23])
print('norm A =', np.linalg.norm(A)**2)
st = time.time()
U, V, r = ada_cross(A, 10, verbose=1)
end = time.time()
print('rank =', r)
print('absolute error =', np.linalg.norm(A - U @ V) ** 2)
print('relative error =', (np.linalg.norm(A - U @ V) / np.linalg.norm(A)) ** 2)
print('UV norm =', np.linalg.norm(U @ V) ** 2)
print(f'cross time = {end - st} seconds')

norm A = 1211150499999.9998
iter = 1
shape of U = (1000, 0)
shape of V = (0, 1000)
maxvol element = 2021.0
A_minus_UV_norm=41043555364.56631
UV_norm=1408227296053.5483
r_tol * UV_norm = 1408227.2960535483

iter = 2
shape of U = (1000, 1)
shape of V = (1, 1000)
maxvol element = -493.81543790202875
A_minus_UV_norm=6.23584735682541e-21
UV_norm=1659548057791.6296
r_tol * UV_norm = 1659548.0577916296

stopping iterations
absolute error estimate = 6.23584735682541e-21 

summary:
amount of seen cols = 5
amount of seen rows = 4
total seen elements = 9000 out of 1000000
which is 0.90 %
rank = 2
absolute error = 1.3484470813110842e-20
relative error = 1.1133604628913455e-32
UV norm = 1211150499999.9998
cross time = 0.002002239227294922 seconds


### Big tests

In [286]:
m = 5000
n = 5000
A = np.zeros((m, n))
for i in range(m):
    for j in range(n):
        A[i,j] = np.sin(3*i + 7*j + 17)
print('norm A =', np.linalg.norm(A)**2)
st = time.time()
U, V, r = ada_cross(A, 10, verbose=1)
end = time.time()
print('rank =', r)
print('absolute error =', np.linalg.norm(A - U @ V) ** 2)
print('relative error =', (np.linalg.norm(A - U @ V) / np.linalg.norm(A)) ** 2)
print('UV norm =', np.linalg.norm(U @ V) ** 2)
print(f'cross time = {end - st} seconds')

norm A = 12500002.0460438
iter = 1
shape of U = (5000, 0)
shape of V = (0, 5000)
maxvol element = 0.999999946835794
A_minus_UV_norm=3696954.528816059
UV_norm=6258014.112929742
r_tol * UV_norm = 6.258014112929741

iter = 2
shape of U = (5000, 1)
shape of V = (1, 5000)
maxvol element = -0.9999934469943874
A_minus_UV_norm=1.1143433030009311e-25
UV_norm=12500011.02254409
r_tol * UV_norm = 12.50001102254409

stopping iterations
absolute error estimate = 1.1143433030009311e-25 

summary:
amount of seen cols = 5
amount of seen rows = 4
total seen elements = 45000 out of 25000000
which is 0.18 %
rank = 2
absolute error = 1.621779136527103e-25
relative error = 1.297423096854924e-32
UV norm = 12500002.0460438
cross time = 0.004000186920166016 seconds


In [287]:
m = 5000
n = 5000
A = np.zeros((m, n))
for i in range(m):
    for j in range(n):
        A[i,j] = np.sum([i, j, 23])
print('norm A =', np.linalg.norm(A)**2)
st = time.time()
U, V, r = ada_cross(A, 10, verbose=1)
end = time.time()
print('rank =', r)
print('absolute error =', np.linalg.norm(A - U @ V) ** 2)
print('relative error =', (np.linalg.norm(A - U @ V) / np.linalg.norm(A)) ** 2)
print('UV norm =', np.linalg.norm(U @ V) ** 2)
print(f'cross time = {end - st} seconds')

norm A = 734678762500000.0
iter = 1
shape of U = (5000, 0)
shape of V = (0, 5000)
maxvol element = 10021.0
A_minus_UV_norm=21251373927283.84
UV_norm=856539009269865.8
r_tol * UV_norm = 856539009.2698658

iter = 2
shape of U = (5000, 1)
shape of V = (1, 5000)
maxvol element = -2493.7631972857
A_minus_UV_norm=4.1816851024918906e-18
UV_norm=1012955361407124.9
r_tol * UV_norm = 1012955361.4071249

stopping iterations
absolute error estimate = 4.1816851024918906e-18 

summary:
amount of seen cols = 5
amount of seen rows = 4
total seen elements = 45000 out of 25000000
which is 0.18 %
rank = 2
absolute error = 7.421237954005824e-18
relative error = 1.0101337254873791e-32
UV norm = 734678762500000.0
cross time = 0.00299835205078125 seconds


### Non-square tests

In [288]:
m = 500
n = 300
A = np.zeros((m, n))
for i in range(m):
    for j in range(n):
        A[i,j] = np.sin(3*i + 7*j + 17)
print('norm A =', np.linalg.norm(A)**2)
st = time.time()
U, V, r = ada_cross(A, 10, verbose=1)
end = time.time()
print('rank =', r)
print('absolute error =', np.linalg.norm(A - U @ V) ** 2)
print('relative error =', (np.linalg.norm(A - U @ V) / np.linalg.norm(A)) ** 2)
print('UV norm =', np.linalg.norm(U @ V) ** 2)
print(f'cross time = {end - st} seconds')

norm A = 75000.91099522403
iter = 1
shape of U = (500, 0)
shape of V = (0, 300)
maxvol element = -0.999991118156833
A_minus_UV_norm=27238.786659159432
UV_norm=37278.54508528938
r_tol * UV_norm = 0.037278545085289376

iter = 2
shape of U = (500, 1)
shape of V = (1, 300)
maxvol element = -0.9999892913106446
A_minus_UV_norm=5.8127229554375125e-28
UV_norm=75027.14591841237
r_tol * UV_norm = 0.07502714591841236

stopping iterations
absolute error estimate = 5.8127229554375125e-28 

summary:
amount of seen cols = 5
amount of seen rows = 4
total seen elements = 3700 out of 150000
which is 2.47 %
rank = 2
absolute error = 1.0536599983760672e-27
relative error = 1.4048629335224516e-32
UV norm = 75000.91099522403
cross time = 0.0009999275207519531 seconds


In [289]:
m = 300
n = 500
A = np.zeros((m, n))
for i in range(m):
    for j in range(n):
        A[i,j] = np.sin(3*i + 7*j + 17)
print('norm A =', np.linalg.norm(A)**2)
st = time.time()
U, V, r = ada_cross(A, 10, verbose=1)
end = time.time()
print('rank =', r)
print('absolute error =', np.linalg.norm(A - U @ V) ** 2)
print('relative error =', (np.linalg.norm(A - U @ V) / np.linalg.norm(A)) ** 2)
print('UV norm =', np.linalg.norm(U @ V) ** 2)
print(f'cross time = {end - st} seconds')

norm A = 74998.86422671474
iter = 1
shape of U = (300, 0)
shape of V = (0, 500)
maxvol element = -0.9999916190837964
A_minus_UV_norm=18646.321458256258
UV_norm=36647.52257139209
r_tol * UV_norm = 0.03664752257139209

iter = 2
shape of U = (300, 1)
shape of V = (1, 500)
maxvol element = 0.9999893158452438
A_minus_UV_norm=8.859789249670206e-28
UV_norm=75004.11719833335
r_tol * UV_norm = 0.07500411719833334

stopping iterations
absolute error estimate = 8.859789249670206e-28 

summary:
amount of seen cols = 6
amount of seen rows = 4
total seen elements = 3800 out of 150000
which is 2.53 %
rank = 2
absolute error = 1.1986490038733055e-27
relative error = 1.598222874749008e-32
UV norm = 74998.86422671474
cross time = 0.0009763240814208984 seconds


### Arbitrary ranks test

In [291]:
m = 1000
n = 1000
r = 7
A = np.zeros((m, n))
rows = np.zeros((r, n))
for k in range(r):
    for j in range(n):
        rows[k,j] = np.random.rand() * 10 + 20

for i in range(m):
    res = np.zeros(n)
    for k in range(r):
        res += rows[k] * (np.random.rand() * 100)
    A[i] = res

print('norm A =', np.linalg.norm(A)**2)
st = time.time()
U, V, r = ada_cross(A, 2 * r, verbose=1)
end = time.time()
print('rank =', r)
print('absolute error =', np.linalg.norm(A - U @ V) ** 2)
print('relative error =', (np.linalg.norm(A - U @ V) / np.linalg.norm(A)) ** 2)
print('UV norm =', np.linalg.norm(U @ V) ** 2)
print(f'cross time = {end - st} seconds')

norm A = 80379839061554.84
iter = 1
shape of U = (1000, 0)
shape of V = (0, 1000)
maxvol element = 16054.728668644775
A_minus_UV_norm=28821268260.620583
UV_norm=80455803598597.7
r_tol * UV_norm = 80455803.5985977

iter = 2
shape of U = (1000, 1)
shape of V = (1, 1000)
maxvol element = 1189.6409344103831
A_minus_UV_norm=35106468286.673874
UV_norm=80811311240113.62
r_tol * UV_norm = 80811311.24011362

iter = 3
shape of U = (1000, 2)
shape of V = (2, 1000)
maxvol element = -1241.7687573910116
A_minus_UV_norm=23579031497.38481
UV_norm=81310105828166.22
r_tol * UV_norm = 81310105.82816622

iter = 4
shape of U = (1000, 3)
shape of V = (3, 1000)
maxvol element = 1196.1956356428582
A_minus_UV_norm=23428183015.08569
UV_norm=81730264534042.34
r_tol * UV_norm = 81730264.53404234

iter = 5
shape of U = (1000, 4)
shape of V = (4, 1000)
maxvol element = -768.4733676538426
A_minus_UV_norm=10983535817.95273
UV_norm=81814341596620.92
r_tol * UV_norm = 81814341.59662092

iter = 6
shape of U = (1000, 5)
