In [1]:
import torch
torch.set_default_dtype(torch.float64)

from congrad import cg_batch_generic
from congrad.torch import TorchBackend

This package's default backends all treat their right-hand-side inputs as batches of *matrices*, with arbitrarily many batch dimensions. Sometimes, however, we want batches of *vectors* instead. We can get similar behavior with a few `reshape`s inside our batched matrix-vector product function, or, if we need exact compatibility, we can write a custom backend, as we do here.

In fact, since all we need to change is how norms and dot products are computed, we don't even need to extend the base `Backend` class. Instead, we subclass `TorchBackend` and override the two methods that matter to us: `norm` and `dot`.

Alternatively, we could have subclassed `Backend` directly, which is only a little more work.

In [2]:
class TorchVectorBackend(TorchBackend):
    """Torch backend whose right-hand sides are batches of *vectors*.

    The last dimension of each tensor indexes vector entries; any leading
    dimensions are treated as batch dimensions. Only ``norm`` and ``dot``
    are overridden -- everything else is inherited from ``TorchBackend``.
    """

    @staticmethod
    def norm(X):
        """Euclidean norm of each vector along the last dimension.

        The last dimension is reduced away: ``X`` of shape ``(..., n)``
        yields a tensor of shape ``(...)``.
        """
        return torch.linalg.vector_norm(X, dim=-1)

    @staticmethod
    def dot(X, Y):
        """Inner product of ``X`` and ``Y`` along the last dimension.

        Unlike ``norm``, a trailing singleton dimension is kept: inputs of
        shape ``(..., n)`` yield a tensor of shape ``(..., 1)``. Presumably
        this lets the solver use the result as a per-vector scale factor
        that broadcasts against ``(..., n)`` tensors -- confirm against
        congrad's CG loop before changing the output shape.
        """
        # (..., 1, n) @ (..., n, 1) -> (..., 1, 1); drop one singleton -> (..., 1).
        return torch.matmul(X.unsqueeze(-2), Y.unsqueeze(-1)).squeeze(-1)


cg_batch = cg_batch_generic(TorchVectorBackend)

In [3]:
SEED = 0  # fixed so Restart & Run All reproduces the same system and outputs
REGULARIZATION = 0.001  # diagonal shift added to X @ X.T

N = 100
n_batches = 5

torch.manual_seed(SEED)
X = torch.randn(N, N)
# X @ X.T is symmetric positive semidefinite; the diagonal shift makes A
# strictly positive definite, which conjugate gradients requires.
A = X @ X.T + REGULARIZATION * torch.eye(N)
b = torch.randn(n_batches, N)


def A_batch(x):
    """Apply ``A`` to every vector in ``x``.

    ``x`` has shape ``(..., N)`` -- any number of leading batch dimensions,
    including the ``(n_batches, N)`` case used below. The result has the
    same shape as ``x`` and equals ``x @ A.T``.
    """
    return torch.einsum("ij,...j->...i", A, x)

In [4]:
# Solve A x = b for all n_batches right-hand sides in one call;
# monitor=True makes the solver print its progress while it runs.
solution, info = cg_batch(
    A_batch,
    b,
    monitor=True,
    rtol=1e-6,
)

020: 5.80951e+00 (4.87328e-03 seconds)
040: 1.02998e+01 (8.88085e-03 seconds)
060: 5.38851e+01 (1.13683e-02 seconds)
080: 4.04119e+01 (1.23689e-02 seconds)
100: 3.56821e+01 (1.34747e-02 seconds)
120: 2.32527e+01 (1.48065e-02 seconds)
140: 2.99176e+01 (1.57883e-02 seconds)
Finished in 1.69961e-02 seconds after 157 iterations (9.23739e+03 iterations/second) with a maximum residual of 5.11714e-06.


In [5]:
# Relative residual ||A x_i - b_i|| / ||b_i|| for every batch row at once,
# using the explicit matrix A (not A_batch) as an independent check.
rel_residuals = (
    torch.linalg.vector_norm(solution @ A.T - b, dim=-1)
    / torch.linalg.vector_norm(b, dim=-1)
)
# 1e-6 mirrors the rtol passed to cg_batch in the solve cell.
# NOTE(review): assumes congrad's rtol bounds ||r|| / ||b|| per row; the
# previously recorded residuals (all < 1e-6) are consistent with this.
assert (rel_residuals <= 1e-6).all(), rel_residuals
rel_residuals

tensor(3.1884e-07)
tensor(4.4480e-07)
tensor(3.6125e-07)
tensor(4.5572e-07)
tensor(5.1738e-07)
