In [1]:
import torch
torch.set_default_dtype(torch.float64)

from src.congrad import *

This package's default backends all treat their right-hand-side inputs as batches of *vectors*.  However, [the original implementation](https://github.com/sbarratt/torch_cg/) treats its right-hand-side inputs as batches of *matrices*.  Matrices are generally more complicated to work with, so we avoid them by default, but they also allow for more powerful batching: every column of a right-hand-side matrix is solved against the same system matrix.  Here, we use a custom backend to recover the original's behavior for PyTorch.

In [2]:
class TorchMatrixBackend(Backend):
    """Backend whose right-hand sides are batches of matrices instead of batches of vectors.

    The reductions below act on rank-3 tensors indexed as (batch, row, column)
    and reduce over the row axis only, so every column of every batch entry is
    treated as its own vector.  The class holds no state and is handed to
    ``cg_batch_generic`` as a class, so all operations are static methods.
    """

    @staticmethod
    def norm(X):
        """Return the 2-norm of each column of ``X`` (reduction over axis 1)."""
        return torch.linalg.vector_norm(X, dim=1)

    @staticmethod
    def dot(X, Y):
        """Return the inner product of matching columns of ``X`` and ``Y``.

        The row axis is summed away by the einsum and then re-inserted with
        length 1, so the result has shape (batch, 1, column).
        """
        return torch.einsum("bij,bij->bj", X, Y).unsqueeze(1)

    @staticmethod
    def all_true(X):
        """Return whether every element of ``X`` is true."""
        return X.all()

    @staticmethod
    def max_vector_scalar(X, y):
        """Return the element-wise maximum of the tensor ``X`` and ``y``.

        ``y`` may be a tensor or a plain Python number.  A plain number is
        converted to a tensor with the dtype and device of ``X`` (rather than
        the process-wide default tensor type) before the comparison.
        """
        if not isinstance(y, torch.Tensor):
            y = torch.as_tensor(y, dtype=X.dtype, device=X.device)
        return torch.maximum(X, y)

    @staticmethod
    def presentable_norm(residual):
        """Collapse ``residual`` to its largest entry as a Python scalar.

        We don't actually have to provide this, but our monitor looks a lot
        nicer when we do.
        """
        return torch.max(residual).item()


# Build a conjugate-gradient solver that uses the matrix-valued backend above.
cg_batch = cg_batch_generic(TorchMatrixBackend)

In [3]:
# Number of independent system matrices, and number of right-hand-side columns
# solved against each of them.
matrix_batches, vector_batches = 2, 5

# Each system matrix is N x N.
N = 100

In [4]:
# Seed the global RNG so that re-running this cell (or the whole notebook on a
# fresh kernel) regenerates exactly the same test problem.
torch.manual_seed(0)

# A = X X^T is symmetric positive semi-definite by construction, and almost
# surely positive definite for a square Gaussian X, which is what conjugate
# gradient needs.
X = torch.randn(matrix_batches, N, N)
A = torch.bmm(X, X.transpose(1, 2))

# One N x vector_batches right-hand-side matrix per system matrix.
B = torch.randn(matrix_batches, N, vector_batches)

In [5]:
def A_bmm(x):
    """Apply the batched system matrices ``A`` to ``x``.

    For every batch index this is the ordinary matrix product ``A[b] @ x[b]``.
    """
    return torch.bmm(A, x)

In [6]:
# Solve all systems in one call; monitor=True makes the solver print its
# progress while it runs.
solution, info = cg_batch(
    A_bmm,
    B,
    rtol=1e-6,
    monitor=True,
)

020: 1.64440e+01 (5.13959e-03 seconds)
040: 1.96133e+01 (1.01542e-02 seconds)
060: 3.00728e+01 (1.43580e-02 seconds)
080: 3.47338e+01 (1.79937e-02 seconds)
100: 2.00813e+01 (2.20690e-02 seconds)
120: 2.21640e+01 (2.66073e-02 seconds)
140: 1.15412e+01 (3.08127e-02 seconds)
Finished in 3.76298e-02 seconds after 158 iterations (4.19880e+03 iterations/second) with a maximum residual of 3.42252e-06.


In [7]:
# Independent check: relative residual ||A X - B|| / ||B|| of each system,
# measured with the matrix norm over the whole right-hand-side matrix.
for i in range(matrix_batches):
    residual_norm = torch.linalg.matrix_norm(A[i] @ solution[i] - B[i])
    rhs_norm = torch.linalg.matrix_norm(B[i])
    print((residual_norm / rhs_norm).item())

7.577851584112634e-07
1.9529075409704955e-07
