In [1]:
import torch
torch.set_default_dtype(torch.float64)

from src.congrad import cg_batch_generic
from src.congrad.torch import TorchBackend

This package's default backends all treat their right-hand-side inputs as batches of *vectors* with arbitrarily many batch dimensions. However, [the original `torch_cg`](https://github.com/sbarratt/torch_cg/) treats its right-hand-side inputs as batches of *matrices* with a single batch dimension. We can use a few `reshape`s inside our batched matrix-vector product function to achieve similar behavior. If we need exact compatibility, we can instead write a custom backend, as we do here.

In fact, since all we need to change is how norms and dot products are computed, we don't even need to extend the `Backend` class directly. Instead, we subclass `TorchBackend` and override the two methods that matter to us.

Alternatively, we could have extended `Backend`, which is only a little more work.

In [2]:
class TorchMatrixBackend(TorchBackend):
    """`TorchBackend` variant that treats right-hand sides as batches of matrices.

    Mirrors the original torch_cg convention. Inputs are assumed to be
    ``(batch, N, k)`` tensors: a single batch dimension and ``k``
    right-hand-side columns per matrix. Norms and dot products are taken
    column-wise, reducing over dimension 1. Everything else is inherited
    from `TorchBackend`.
    """

    # These helpers take no ``self``. Marking them as static methods keeps
    # class-level calls (``TorchMatrixBackend.norm(X)``) behaving exactly as
    # before, and also makes them safe to call through an instance.
    @staticmethod
    def norm(X):
        """Column-wise Euclidean norms of ``X``, reduced over dimension 1.

        For a ``(batch, N, k)`` input this returns a ``(batch, k)`` tensor.
        """
        return torch.linalg.vector_norm(X, dim=1)

    @staticmethod
    def dot(X, Y):
        """Column-wise dot products ``sum_i X[b, i, j] * Y[b, i, j]``.

        Returns shape ``(batch, 1, k)``. Unlike `norm`, the reduced dimension
        is re-inserted with ``unsqueeze(1)``. This is presumably so the result
        broadcasts against ``(batch, N, k)`` tensors in the solver's update
        steps (TODO: confirm against how `cg_batch_generic` consumes it).
        No complex conjugation is applied, so real-valued inputs are assumed.
        """
        return torch.einsum("bij,bij->bj", X, Y).unsqueeze(1)

# Build a batched CG solver bound to this backend; it is called as
# cg_batch(A_bmm, B, rtol=..., monitor=...) and returns (solution, info).
cg_batch = cg_batch_generic(TorchMatrixBackend)

In [3]:
# Problem sizes: `matrix_batches` independent N x N systems, each solved for
# `vector_batches` right-hand-side columns at once.
N = 100
vector_batches = 5
matrix_batches = 2

In [4]:
# Seed the RNG so the matrices, right-hand sides, and the solver trace are
# reproducible under Restart & Run All. It also means re-running this cell
# regenerates identical data instead of silently drawing new data.
torch.manual_seed(0)

# A[b] = X[b] @ X[b]^T is symmetric positive semi-definite. For a square
# Gaussian X it is positive definite with probability one, as conjugate
# gradients requires. It can be poorly conditioned, which slows CG convergence.
X = torch.randn(matrix_batches, N, N)
A = torch.bmm(X, X.transpose(1, 2))

# Right-hand sides: one (N, vector_batches) block per matrix batch.
B = torch.randn(matrix_batches, N, vector_batches)

In [5]:
def A_bmm(x):
    """Batched matrix product with the global matrices ``A``.

    Multiplies ``A[b]`` by ``x[b]`` for every batch index ``b``. Both operands
    are rank-3 tensors: ``(b, N, N) x (b, N, k) -> (b, N, k)``.
    """
    return torch.bmm(A, x)

In [6]:
# Solve A[b] @ solution[b] = B[b] for every matrix batch; monitor=True prints
# progress while the solver runs.
cg_result = cg_batch(A_bmm, B, rtol=1e-6, monitor=True)
# The second element is the solver's info object; it is not used below.
solution, info = cg_result

020: 9.11536e+00 (3.12042e-03 seconds)
040: 1.69207e+01 (6.46067e-03 seconds)
060: 2.70843e+01 (1.17743e-02 seconds)
080: 4.94065e+01 (1.70772e-02 seconds)
100: 3.86073e+01 (2.22566e-02 seconds)
120: 1.72182e+01 (2.75743e-02 seconds)
140: 5.90912e+01 (3.27470e-02 seconds)
160: 1.78831e-05 (3.79047e-02 seconds)
Finished in 3.81114e-02 seconds after 160 iterations (4.19821e+03 iterations/second) with a maximum residual of 6.07842e-06.


In [7]:
# Relative Frobenius-norm residual ||A[b] @ solution[b] - B[b]||_F / ||B[b]||_F,
# computed for every matrix batch b at once. torch.linalg.matrix_norm reduces
# over the last two dimensions and keeps the leading batch dimension, so no
# Python loop is needed. The result is shown via rich display, not print().
relative_residuals = torch.linalg.matrix_norm(A @ solution - B) / torch.linalg.matrix_norm(B)
relative_residuals

1.0450675315219223e-09
8.660558474604321e-07
