In [1]:
import torch
torch.set_default_dtype(torch.float64)

from congrad import cg_batch_generic
from congrad.torch import TorchBackend

This package's default backends all treat their right-hand-side inputs as batches of *matrices*, with arbitrarily many batch dimensions.  Sometimes, however, we want batches of *vectors* instead.  We could approximate this with a few `reshape`s inside our batched matvec function, or, if we need exact compatibility with vector-shaped inputs, we can write a custom backend, as we do here.

In fact, since all we need to change is how norms and dot products are taken, we don't even need to extend the base `Backend` class.  Instead, we subclass `TorchBackend` and override only the two methods that matter to us: `norm` and `dot`.

Alternatively, we could have extended `Backend` directly, which is only a little more work.

In [2]:
class TorchVectorBackend(TorchBackend):
    """``TorchBackend`` variant for right-hand sides that are batches of vectors.

    Only the two reductions are overridden. The vector axis is the last
    dimension (``dim=-1``); every leading dimension is a batch dimension.
    Neither function takes ``self``. The class itself (not an instance) is what
    gets handed to ``cg_batch_generic``.
    """

    def norm(X):
        # Euclidean length of each vector along the last axis: (..., n) -> (...).
        return torch.linalg.vector_norm(X, dim=-1)

    def dot(X, Y):
        # Per-vector inner product written as a (1 x n) @ (n x 1) product:
        # (..., 1, n) @ (..., n, 1) -> (..., 1, 1), then drop one axis -> (..., 1).
        # A trailing singleton is kept, presumably so the result broadcasts
        # against (..., n) vectors -- TODO confirm against TorchBackend.dot.
        row = X.unsqueeze(-2)
        col = Y.unsqueeze(-1)
        return (row @ col)[..., 0]

# Build the batched CG solver for vector right-hand sides by passing the
# backend class (not an instance) to congrad's generic solver builder.
cg_batch = cg_batch_generic(TorchVectorBackend)

In [3]:
N = 100         # length of each right-hand-side vector (A is N x N)
n_batches = 5   # number of independent right-hand sides solved together
SEED = 0        # fixed seed so "Restart & Run All" reproduces the same system

torch.manual_seed(SEED)

# Random symmetric positive-definite system matrix: X @ X.T is symmetric
# positive semi-definite, and the 0.001 * I shift makes it strictly positive
# definite (all eigenvalues >= 0.001), as conjugate gradient requires.
X = torch.randn(N, N)
A = X @ X.T + 0.001 * torch.eye(N)
b = torch.randn(n_batches, N)


def A_batch(x):
    """Apply ``A`` to every vector in a batch.

    Args:
        x: tensor of shape ``(..., N)``. Any number of leading batch dimensions
            (including none) is accepted, matching the vector backend's
            ``norm``/``dot``, which also reduce only over the last axis.

    Returns:
        Tensor of the same shape as ``x``, with ``A`` applied to each length-``N``
        vector along the last axis.
    """
    # "..." lets einsum carry through arbitrary leading batch dims; for a single
    # batch dimension this is the same contraction as "ij,bj->bi".
    return torch.einsum("ij,...j->...i", A, x)

In [4]:
# Solve A x = b for all n_batches right-hand sides in one batched CG run.
solution, info = cg_batch(
    A_batch,
    b,
    monitor=True,  # emit the periodic progress lines recorded below
    rtol=1e-6,     # stopping tolerance passed through to the solver
)

020: 1.30577e+01 (1.93238e-03 seconds)
040: 1.36897e+01 (3.00837e-03 seconds)
060: 3.29272e+01 (3.99160e-03 seconds)
080: 5.76979e+01 (4.97651e-03 seconds)
100: 5.30374e+01 (5.94234e-03 seconds)
120: 2.88872e+01 (6.94442e-03 seconds)
140: 6.34805e+00 (7.92074e-03 seconds)
160: 6.68151e-05 (8.89492e-03 seconds)
Finished in 9.03511e-03 seconds after 162 iterations (1.79301e+04 iterations/second) with a maximum residual of 6.72177e-06.


In [5]:
# Verify each solution with a plain matmul against A (independent of A_batch):
# relative residual ||A x_i - b_i|| / ||b_i|| for every right-hand side.
# NOTE(review): these can come out slightly above rtol=1e-6 (1.58e-6 in the
# recorded run), presumably because the solver's stopping test is not this exact
# quantity -- TODO confirm against congrad's convergence criterion.
for i in range(n_batches):
    rhs = b[i]
    residual = A @ solution[i] - rhs
    print(torch.linalg.vector_norm(residual) / torch.linalg.vector_norm(rhs))

tensor(1.5848e-06)
tensor(2.9012e-07)
tensor(2.0339e-07)
tensor(3.4547e-07)
tensor(1.0593e-06)
