## Defining $\texttt{block\_matvec}$ which wraps $\texttt{matvec}$:

$$S_Av = \left[\begin{array}{cc}
0 & A \\
A^\top & 0
\end{array}\right]

\left[\begin{array}{c}
u \\
l
\end{array}\right]

= \left[\begin{array}{c}
Al \\
A^\top u
\end{array}\right]
=
\left[\begin{array}{c}
\texttt{matvec}(l) \\
\texttt{vjp(matvec, *)}(u)
\end{array}\right]

$$

In [123]:
import jax
import jax.numpy as jnp
import matfree
import numpy
import random


r, c = (2, 3)
A = jax.random.normal(jax.random.PRNGKey(int(random.random() * 1000)), shape=(r, c))
padded_v0 = jax.random.normal(
    jax.random.PRNGKey(int(random.random() * 1000)), shape=c + r
)

zeros_upper = jnp.zeros((r, r))
zeros_lower = jnp.zeros((c, c))
S = jnp.block([[zeros_upper, A], [A.T, zeros_lower]])
print("Dense matvec:")
print((S @ padded_v0).round(2))


def matvec(v, params):
    scalar = params
    return scalar * A @ v


def block_matvec(v, matvec_fun, *params):
    upper, lower = jnp.split(v, [r])
    upper_matvec, vecmat_fun = jax.vjp(lambda v: matvec_fun(v, *params), lower)
    (lower_vecmat,) = vecmat_fun(upper)
    return jnp.concat((upper_matvec, lower_vecmat))


params = jnp.array(1.0)
print("Sparse matvec:")
print(block_matvec(padded_v0, matvec, 1.0).round(2))

print("Sparse matvec with identity matrix:")
print(jax.vmap(lambda v: block_matvec(v, matvec, 1.0))(jnp.eye(5)))

Dense matvec:
[-0.85  0.62 -0.93 -2.01  0.12]
Sparse matvec:
[-0.85  0.62 -0.93 -2.01  0.12]
Sparse matvec with identity matrix:
[[ 0.          0.         -0.55573966 -1.2556321   0.12333269]
 [ 0.          0.         -0.10629336  0.91457129 -1.10012671]
 [-0.55573966 -0.10629336  0.          0.          0.        ]
 [-1.2556321   0.91457129  0.          0.          0.        ]
 [ 0.12333269 -1.10012671  0.          0.          0.        ]]


## Testing that gradients work as expected

In [124]:
def loss(vector):
    return (jnp.linalg.norm(vector) - 1.0) ** 2


loss_fun = jax.value_and_grad(
    lambda v, *p: loss(block_matvec(v, matvec, *p)), argnums=(0, 1)
)

s = 1.0
for i in range(100):
    loss_, (v_grad, s_grad) = loss_fun(padded_v0, s)
    s -= s_grad * 0.01
    padded_v0 -= v_grad * 0.01
print("loss", loss_)

print(f"s = {s.round(2)}\nv = {padded_v0.round(2)}")
output = block_matvec(padded_v0, matvec, s)
print(output)

loss 2.526851402599607e-09
s = 0.5
v = [ 1.44  0.2  -0.69  0.9   0.39]
[-0.34495453  0.23187889 -0.40998796 -0.81167483 -0.02058161]


## Bidiagonalization on $A$

In [141]:
from tests.test_bidiag_JVP_and_VJP_jax import bidiagonalize_vjpable_matvec


def matvec(v, params):
    A = params
    return A @ v


func = bidiagonalize_vjpable_matvec(5, custom_vjp=True, reorthogonalize=False)

start_vec = jax.random.normal(
    jax.random.PRNGKey(int(random.random() * 1000)), shape=(c)
)

output = func(matvec, start_vec, A)
print(output.B.round(3))
print(output.rs.round(3))
print(output.ls.round(3))

[[0.117 1.31  0.    0.    0.   ]
 [0.    0.917 1.178 0.    0.   ]
 [0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.   ]]
[[-0.747  0.496 -0.443  0.     0.   ]
 [ 0.47   0.865  0.175  0.     0.   ]
 [ 0.47  -0.077 -0.879  0.     0.   ]]
[[-0.998 -0.07   0.     0.     0.   ]
 [-0.07   0.998  0.     0.     0.   ]]


## Tridiagonalization on $S_A$ 

In [126]:
print("A:")
print(A.round(2))
print("matvec:")
print(jax.vmap(lambda v: matvec(v, A))(jnp.eye(3)).round(2).T)


def matvec_big(v, *params):
    return block_matvec(v, matvec, *params)


print("S_A:")
print(jax.vmap(lambda v: matvec_big(v, A))(jnp.eye(5)).round(2))

A:
[[-0.56 -1.26  0.12]
 [-0.11  0.91 -1.1 ]]
matvec:
[[-0.56 -1.26  0.12]
 [-0.11  0.91 -1.1 ]]
S_A:
[[ 0.    0.   -0.56 -1.26  0.12]
 [ 0.    0.   -0.11  0.91 -1.1 ]
 [-0.56 -0.11  0.    0.    0.  ]
 [-1.26  0.91  0.    0.    0.  ]
 [ 0.12 -1.1   0.    0.    0.  ]]


In [127]:
from matfree import decomp

new_start = jnp.pad(start_vec, (r, 0))
print("padded start vector:", new_start.round(3))

print((new_start / jnp.linalg.norm(new_start)).round(3))

func_big = matfree.decomp.tridiag_sym(5, materialize=False)
output_big = func_big(matvec_big, new_start, A)

print("Symmetric tridiagonal output:")
print(output_big.J_small)
print("Q matrix, orthogonal vectors:")
print(output_big.Q_tall.round(3))

padded start vector: [ 0.     0.    -0.14  -0.45   0.927]
[ 0.     0.    -0.134 -0.433  0.891]
Symmetric tridiagonal output:
(Array([0., 0., 0., 0., 0.], dtype=float64), Array([1.54452401, 0.8098921 , 0.7597578 , 0.58307431], dtype=float64))
Q matrix, orthogonal vectors:
[[ 0.     0.471  0.     0.882  0.   ]
 [ 0.    -0.882  0.     0.471  0.   ]
 [-0.134  0.     0.049  0.    -0.99 ]
 [-0.433  0.    -0.901  0.     0.015]
 [ 0.891  0.    -0.43   0.    -0.142]]


## Wrapping `block_matvec`

In [147]:
from typing import NamedTuple
from jax import Array
import jax.numpy as jnp


class DecompResult(NamedTuple):
    rs: Array
    ls: Array
    as_: Array
    bs_: Array


def arnoldi_bidiagonalization(
    num_matvecs: int,
    shape: tuple[int, int],
    custom_vjp: bool = True,
    reortho: str = "full",
):
    def block_matvec(v0, matvec_fun, *params):
        upper, lower = jnp.split(v0, [shape[0]])
        upper_matvec, vecmat_fun = jax.vjp(lambda v: matvec_fun(v, *params), lower)
        (lower_vecmat,) = vecmat_fun(upper)
        return jnp.concat((upper_matvec, lower_vecmat))

    tridiag = matfree.decomp.tridiag_sym(
        num_matvecs,
        materialize=False,
        custom_vjp=custom_vjp,
        reortho=reortho,
    )

    def wrapped_bidiag(v0, matvec, *params):
        padded_v0 = jnp.pad(v0, (shape[0], 0))

        def block_matved_curried(v, *params):
            return block_matvec(v, matvec, *params)

        output = tridiag(block_matved_curried, padded_v0, *params)

        ababab = output.J_small[1]
        as_ = ababab[0::2]
        bs_ = ababab[1::2]

        rlrlrl = output.Q_tall
        rs = rlrlrl[shape[0] :, 0::2].round(3)
        ls = rlrlrl[: shape[0], 1::2].round(3)

        return DecompResult(rs=rs, ls=ls, as_=as_, bs_=bs_)

    return wrapped_bidiag


func = arnoldi_bidiagonalization(5, shape=(2, 3))

func(start_vec, matvec, A)

DecompResult(rs=Array([[-0.747,  0.496, -0.443],
       [ 0.47 ,  0.865,  0.175],
       [ 0.47 , -0.077, -0.879]], dtype=float64), ls=Array([[-0.998, -0.07 ],
       [-0.07 ,  0.998]], dtype=float64), as_=Array([0.11746602, 0.91740758], dtype=float64), bs_=Array([1.31022036, 1.17752914], dtype=float64))