In [145]:
import numpy as np

In [146]:
# Lx = b

L = np.tril(np.random.randn(5, 5))
print(L)

b = np.array([1, 1, 0, 0, 0], dtype=float)


def forward_substitute(L, b):
    sol = np.zeros_like(b)
    for i in range(L.shape[1]):
        sol[i] = (b[i] - np.dot(L[i, :i], sol[:i])) / L[i, i]
        print("sol:", sol)
    return sol


x = forward_substitute(L, b)
L @ x

[[ 1.22497278  0.          0.          0.          0.        ]
 [ 0.46892257  0.74807277  0.          0.          0.        ]
 [ 0.66670839 -0.55772129 -1.12222946  0.          0.        ]
 [ 0.8884206  -0.10084119 -0.14638602  0.41325135  0.        ]
 [ 0.99743608 -0.68464543  0.20622565  0.50372726 -0.30423685]]
sol: [0.81634467 0.         0.         0.         0.        ]
sol: [0.81634467 0.82505016 0.         0.         0.        ]
sol: [0.81634467 0.82505016 0.07495419 0.         0.        ]
sol: [ 0.81634467  0.82505016  0.07495419 -1.52712419  0.        ]
sol: [ 0.81634467  0.82505016  0.07495419 -1.52712419 -1.65795761]


array([1.00000000e+00, 1.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       2.22044605e-16])

# Lanczos Iteration

In [147]:
A_M = np.random.randn(5, 5)
A_M += A_M.T
q1 = np.array([1, 0, 0, 0, 0], float)

qs_v = [np.zeros_like(q1)]
as_s = []
bs_s = [1]
r = q1
for k in range(5):
    # extract q as the normalized vector
    qs_v += (r / bs_s[-1],)

    # compute next alpha
    as_s += (qs_v[-1] @ A_M @ qs_v[-1],)

    # isolate q_k-1 * beta_k-1
    r = A_M @ qs_v[-1] - as_s[-1] * qs_v[-1] - bs_s[-1] * qs_v[-2]
    norm = np.linalg.norm(r)
    if np.allclose(norm, 0, 1e-8):
        break
    # extract beta as the norm
    bs_s += (norm,)


H = np.zeros((k + 1, k + 1))
H += np.diag(as_s)
H += np.diag(bs_s[1:], k=1)
H += np.diag(bs_s[1:], k=-1)

Q = np.array(qs_v[1:]).T
print("H:")
print((H).round(10))

print("Q")
print((Q).round(10))
# print((Q.T @ Q).round(10))

# print((Q@H@Q.T-A_M).round(10))

H:
[[ 2.15894538  2.21640305  0.          0.          0.        ]
 [ 2.21640305 -1.54194762  2.32437696  0.          0.        ]
 [ 0.          2.32437696  2.61173191  0.83624853  0.        ]
 [ 0.          0.          0.83624853 -2.78087231  0.08645143]
 [ 0.          0.          0.          0.08645143  2.57972387]]
Q
[[ 1.          0.          0.         -0.         -0.        ]
 [ 0.         -0.31431629 -0.19752588 -0.27795007 -0.8859642 ]
 [ 0.         -0.34551712 -0.6739704   0.6493088   0.06913691]
 [ 0.         -0.5952016  -0.24050336 -0.61510382  0.45775591]
 [ 0.         -0.65387938  0.67000415  0.35041281 -0.02733235]]


# Bidiagonalization (cooler Lanczos)

In [148]:
import dataclasses
import jax


@dataclasses.dataclass
class BidiagResult:
    rs: list[list[float]]
    ls: list[list[float]]
    L: int
    B: int
    R: int
    alphas: list[float]
    betas: list[float]
    c: float


def bidiagonalize(A, start_vector) -> BidiagResult:
    m = A.shape[0]
    n = A.shape[1]

    any_n_vec = np.zeros(n)
    any_m_vec = np.zeros(m)
    any_number = 0
    betas = [0]
    alphas = [any_number]

    c = 1 / np.linalg.norm(start_vector)
    r_columns = [any_n_vec, start_vector * c]
    l_columns = [any_m_vec]

    for k in range(1, max(n, m) + 1):
        t = A @ r_columns[k] - betas[k - 1] * l_columns[k - 1]
        alpha_k = np.linalg.norm(t)
        alphas.append(alpha_k)
        l_k = t / alpha_k
        l_columns.append(l_k)

        w = A.T @ l_k - alpha_k * r_columns[k]
        beta_k = np.linalg.norm(w)
        betas.append(beta_k)

        r_kp1 = w / beta_k
        r_columns.append(r_kp1)

        if np.allclose(beta_k, 0, atol=1e-10) or np.isnan(beta_k):
            break

    L = np.array(l_columns[1:]).T
    R = np.array(r_columns[1:-1]).T
    B = np.diag(alphas[1:]) + np.diag(betas[1:-1], k=1)

    return BidiagResult(
        ls=l_columns, rs=r_columns, L=L, B=B, R=R, alphas=alphas, betas=betas, c=c
    )


In [149]:
m = 3
n = 2

A = np.random.randn(m, n)
start_vector = 2 * np.eye(1, n).flatten()

result = bidiagonalize(A, start_vector)
L = result.L
B = result.B
R = result.R


print()
print("L^TL and R^TR:")
print(np.array2string(L.T @ L, precision=2, floatmode="maxprec_equal"))
print(np.array2string(R.T @ R, precision=2, floatmode="maxprec_equal"))

print()
print("A vs LBR^T:")
print(np.array2string(L @ B @ R.T, precision=2, floatmode="maxprec_equal"))
print(np.array2string(A, precision=2, floatmode="maxprec_equal"))

print()
print("AR vs LB:")
print(np.array2string(A @ R, precision=2, floatmode="maxprec_equal"))
print(np.array2string(L @ B, precision=2, floatmode="maxprec_equal"))


are_close = np.allclose(A, L @ B @ R.T, atol=1e-5)
print("Reconstruction is good:", are_close)
assert are_close


L^TL and R^TR:
[[1.00e+00 2.78e-17]
 [2.78e-17 1.00e+00]]
[[1. 0.]
 [0. 1.]]

A vs LBR^T:
[[ 0.63  1.82]
 [-1.21  0.23]
 [ 1.20  0.36]]
[[ 0.63  1.82]
 [-1.21  0.23]
 [ 1.20  0.36]]

AR vs LB:
[[ 0.63  1.82]
 [-1.21  0.23]
 [ 1.20  0.36]]
[[ 0.63  1.82]
 [-1.21  0.23]
 [ 1.20  0.36]]
Reconstruction is good: True


In [150]:
def bidiagonalize_jvp(primals, tangents) -> tuple[BidiagResult, BidiagResult]:
    A, start_vector = primals
    dA, d_start_vector = tangents

    m = A.shape[0]
    n = A.shape[1]

    any_n_vec = np.zeros(n)
    any_m_vec = np.zeros(m)
    any_number = 0
    bs = [0]
    as_ = [any_number]

    c = 1 / np.linalg.norm(start_vector)
    rs = [any_n_vec, start_vector * c]
    ls = [any_m_vec]

    for k in range(1, max(n, m) + 1):
        t = A @ rs[k] - bs[k - 1] * ls[k - 1]
        alpha_k = np.linalg.norm(t)
        as_.append(alpha_k)
        l_k = t / alpha_k
        ls.append(l_k)

        w = A.T @ l_k - alpha_k * rs[k]
        beta_k = np.linalg.norm(w)
        bs.append(beta_k)

        r_kp1 = w / beta_k
        rs.append(r_kp1)

        if np.allclose(beta_k, 0, atol=1e-10) or np.isnan(beta_k):
            break

    L = np.array(ls[1:]).T
    R = np.array(rs[1:-1]).T
    B = np.diag(as_[1:]) + np.diag(bs[1:-1], k=1)

    primal_output = BidiagResult(rs=rs, ls=ls, L=L, B=B, R=R, alphas=as_, betas=bs, c=c)

    d_as = [0] * len(as_)
    d_bs = [0] * len(bs)
    d_rs = [any_n_vec] * (len(rs))
    d_rs[1] = (
        d_start_vector
        - start_vector
        * (
            start_vector.T @ d_start_vector
        )  # This should be corrected for when start vector is not unit length
    ) / np.linalg.norm(start_vector)
    d_ls = [any_m_vec * 0] * (len(ls))

    # d_rs[1] = d_start_vector, known
    # d_ls[0] = doesn't matter because bs_[0] = 0
    # d_bs[0] = 0

    # In each iteration, assume we already know d_rs[n], d_ls[n-1], d_bs[n-1]
    for n in range(1, len(as_)):
        d_a_n = ls[n].T @ (A @ d_rs[n] + dA @ rs[n] - d_ls[n - 1] * bs[n - 1])
        d_as[n] = d_a_n
        d_l_n = (
            A @ d_rs[n]
            + dA @ rs[n]
            - ls[n] * d_as[n]
            - ls[n - 1] * d_bs[n - 1]
            + d_ls[n - 1] * bs[n - 1]
        ) / as_[n]
        d_ls[n] = d_l_n
        d_b_n = (
            rs[n + 1].T @ A.T @ d_ls[n]
            + rs[n + 1].T @ dA.T @ ls[n]
            - rs[n + 1].T @ d_rs[n] * as_[n]
        )
        d_bs[n] = d_b_n
        d_rs[n + 1] = (
            A.T @ d_ls[n]
            + dA.T @ ls[n]
            - rs[n] * d_as[n]
            - rs[n + 1] * d_bs[n]
            - d_rs[n] * as_[n]
        ) / d_bs[n]

    d_c = (
        -(start_vector @ d_start_vector)
        / (start_vector @ start_vector)
        * np.linalg.norm(start_vector)
    )

    dL = np.array(d_ls[1:]).T
    dR = np.array(d_rs[1:-1]).T
    dB = np.diag(d_as[1:]) + np.diag(d_bs[1:-1], k=1)  # beta index -1?

    tangent_output = BidiagResult(
        rs=d_rs, ls=d_ls, L=dL, B=dB, R=dR, alphas=d_as, betas=d_bs, c=d_c
    )

    return primal_output, tangent_output

In [None]:
A = np.random.randn(3, 2)
da = np.random.randn(3, 2)
start_vector = 1 * np.eye(2, 1).flatten()
d_start_vector = np.random.randn(2)
# d_start_vector = -2 * np.eye(2, 1, k=0).flatten()


# numeric test
h = 0.0001

result, tangents = bidiagonalize_jvp(
    primals=(A, start_vector), tangents=(da, d_start_vector)
)
result = bidiagonalize(A, start_vector)
result_wiggled = bidiagonalize(A, start_vector + d_start_vector * h)


fields = ["rs", "alphas", "ls", "betas"]

for field in fields:
    print("- Field:", field)

    # print("initial:", result.__getattribute__(field)[1])
    # print("wiggled:", result_wiggled.__getattribute__(field)[1])

    print(
        "aprox:",
        (result_wiggled.__getattribute__(field)[1] - result.__getattribute__(field)[1])
        / h,
    )

    print("exact:", tangents.__getattribute__(field)[1])

print("field:", "c")
print(
    "aprox:",
    (result_wiggled.c - result.c) / h,
)
print("exact:", tangents.c)


- Field: rs
aprox: [-9.38230604e-07 -1.36983927e-01]
exact: [ 0.         -0.13698954]
- Field: alphas
aprox: 0.35490620373113124
exact: -0.009035158446323355
- Field: ls
aprox: [-0.11170777 -0.03778973 -0.26896889]
exact: [-0.62078384  0.83186898 -1.349808  ]
- Field: betas
aprox: 0.4483334061422539
exact: 0.30347325837350303
field: c
aprox: -0.40976143724003045
exact: -0.40977729010234437


In [197]:
A = np.random.randn(3, 2)
da = np.random.randn(3, 2)
start_vector = 1 * np.eye(2, 1).flatten()
d_start_vector = np.random.randn(2)
# d_start_vector = -2 * np.eye(2, 1, k=0).flatten()

result, tangents = bidiagonalize_jvp(
    primals=(A, start_vector), tangents=(da, d_start_vector)
)

result_wiggled, _ = bidiagonalize_jvp(
    primals=(A, start_vector + d_start_vector * h), tangents=(da, d_start_vector)
)


# numeric test

fields = ["rs", "alphas", "ls", "betas"]

for field in fields:
    print("-- Field:", field)

    # print("initial:", result.__getattribute__(field)[1])
    # print("wiggled:", result_wiggled.__getattribute__(field)[1])

    print(
        "aprox:",
        (result_wiggled.__getattribute__(field)[1] - result.__getattribute__(field)[1])
        / h,
    )

    print("exact:", tangents.__getattribute__(field)[1])

print("field:", "c")
print(
    "aprox:",
    (result_wiggled.c - result.c) / h,
)
print("exact:", tangents.c)


-- Field: rs
aprox: [-4.33431069e-07  9.31055351e-02]
exact: [0.         0.09307973]
-- Field: alphas
aprox: 0.07314036367889187
exact: 0.39033317241133014
-- Field: ls
aprox: [ 0.11427791  0.05928417 -0.15206785]
exact: [ 0.86352674 -0.96384005  1.5244294 ]
-- Field: betas
aprox: 0.17390904949010633
exact: -1.7229245117998953
field: c
aprox: 2.772419717826935
exact: 2.77165173306653


In [164]:
A = np.random.randn(3, 2)
da = np.random.randn(3, 2)
start_vector = 1 * np.eye(2, 1).flatten()
d_start_vector = np.random.randn(2)

result = bidiagonalize(A, start_vector)
result__, _ = bidiagonalize_jvp((A, start_vector), (A, start_vector))

print(np.array(result.alphas) - np.array(result__.alphas))
print(np.array(result.betas) - np.array(result__.betas))
print(np.array(result.rs) - np.array(result__.rs))
print(np.array(result.ls) - np.array(result__.ls))
print(np.array(result.c) - np.array(result__.c))

[0. 0. 0.]
[0. 0. 0.]
[[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
0.0
