# jit, vmap, grad, and pytrees

All `LinearOperator` objects are native JAX and PyTorch pytrees.

This means we can `vmap` over them, `jit` functions that take them as arguments, and apply other pytree-aware transformations such as `grad`.

## Example: Tree Map
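Because a `LinearOperator` is a pytree, generic tree utilities can map over the arrays it stores and rebuild an object of the same type. The sketch below illustrates the idea in JAX with a hypothetical `ToyDiagonal` class that is registered as a pytree by hand. It is a stand-in for illustration only, not CoLA's implementation, and the class name and its methods are assumptions made for this example.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map


@register_pytree_node_class
class ToyDiagonal:
    """Hypothetical stand-in for a diagonal linear operator (not a CoLA class)."""

    def __init__(self, diag):
        self.diag = diag

    def __matmul__(self, v):
        # A diagonal operator acts by elementwise scaling.
        return self.diag * v

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data.
        return (self.diag,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


D = ToyDiagonal(jnp.array([3.0, 0.2, 1.0]))

# tree_map applies the function to every array leaf and rebuilds the same type.
doubled = tree_map(lambda leaf: 2.0 * leaf, D)
leaves = tree_leaves(D)

v = jnp.ones(3)
print(type(doubled).__name__, doubled @ v, len(leaves))
```

The same pattern should carry over to real `LinearOperator` objects, since the mapped function only ever sees the underlying arrays.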



## Jit example (in jax)

As an example, let's jit a function that involves a matrix square root.

In [1]:
from jax import jit
import numpy as np
import jax.numpy as jnp
from jax import config
# run on CPU
config.update("jax_platform_name", "cpu")
import cola

# construct a 6x6 linear operator as the Kronecker product of two self-adjoint operators
A = jnp.array(np.random.randn(2, 2))
B = cola.SelfAdjoint(cola.lazify(A.T @ A + 1e-4 * jnp.eye(2)))  # 2x2, positive definite
D = cola.SelfAdjoint(cola.diag(jnp.array([3., 0.2, 1.])))  # 3x3 diagonal
K = cola.kron(B, D)

Let's verify that CoLA indeed computes the square root of this operator by applying it twice and comparing against `K@v`:

In [2]:
v = jnp.array(np.random.randn(6))
K_half_v = cola.sqrt(K, tol=1e-4)@v
Kv = cola.sqrt(K, tol=1e-4)@K_half_v
print("error:",jnp.linalg.norm(Kv - K@v))

error: 9.573923e-07




Now let's jit a function that takes a `LinearOperator` as an argument:

In [3]:
@jit
def sqrt_mvm(K, v):
    return cola.sqrt(K, tol=1e-4)@v

print(sqrt_mvm(K,v))
print(sqrt_mvm(4*K,v)/2)  # sqrt(4K) = 2 sqrt(K), so this should match the line above

[-1.9293232   0.9868831  -1.6336229  -1.9024148   0.29426882 -2.5727682 ]
[-1.9293234  0.9868836 -1.6336229 -1.9024148  0.2942687 -2.5727682]
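The two printed vectors agree because $\sqrt{4K} = 2\sqrt{K}$. For reference, here is a dense NumPy sketch of that identity, together with the Kronecker identity $\sqrt{B \otimes D} = \sqrt{B} \otimes \sqrt{D}$ for positive semidefinite factors. The helper `psd_sqrt` is a hypothetical name introduced for this example. The sketch only checks the mathematics on small dense matrices and makes no claim about which algorithm CoLA uses internally.

```python
import numpy as np


def psd_sqrt(M):
    """Dense square root of a symmetric PSD matrix via its eigendecomposition."""
    w, Q = np.linalg.eigh(M)
    # Clip tiny negative eigenvalues that can arise from round-off.
    return (Q * np.sqrt(np.clip(w, 0.0, None))) @ Q.T


rng = np.random.default_rng(0)
A = rng.standard_normal((2, 2))
B = A.T @ A + 1e-4 * np.eye(2)   # 2x2 symmetric positive definite
D = np.diag([3.0, 0.2, 1.0])     # 3x3 diagonal
K = np.kron(B, D)                # 6x6 dense Kronecker product
v = rng.standard_normal(6)

half = psd_sqrt(K) @ v
scaled = psd_sqrt(4 * K) @ v / 2                    # uses sqrt(4K) = 2 sqrt(K)
kron_half = np.kron(psd_sqrt(B), psd_sqrt(D)) @ v   # uses sqrt(B kron D) = sqrt(B) kron sqrt(D)

print(np.linalg.norm(half - scaled), np.linalg.norm(half - kron_half))
```

Both printed norms should be at the level of floating-point round-off.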


## Batched LinearOperator operations using vmap (in pytorch)

Let's consider a function that constructs some linear operators, and a separate function that applies some transformations.


In [4]:
import cola
import torch
torch.set_default_device('cuda')
import numpy as np

def construct_complicated_linops(X):
    X = cola.lazify(X)
    Y = X@X.T
    Y = cola.PSD(Y+cola.ops.I_like(Y))
    D = cola.PSD(cola.diag(torch.linspace(0.1,1,2)))
    W = cola.ops.BlockDiag(Y,D, multiplicities=[2, 1])
    diag_W = cola.diag(W)
    return W, cola.PSD(cola.diag(diag_W))

W,diag_W = construct_complicated_linops(torch.randn(3,3))
print(W[:5,:5].to_dense())
print(diag_W[:5,:5].to_dense())

tensor([[ 1.5247,  0.0422, -0.1941,  0.0000,  0.0000],
        [ 0.0422,  4.6091,  1.7247,  0.0000,  0.0000],
        [-0.1941,  1.7247,  2.0739,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.5247,  0.0422],
        [ 0.0000,  0.0000,  0.0000,  0.0422,  4.6091]], device='cuda:0')
tensor([[1.5247, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 4.6091, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 2.0739, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.5247, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 4.6091]], device='cuda:0')


For this example, let's consider a function that uses the diagonal of `W` as a symmetric preconditioner, applied explicitly rather than supplied as an argument to the inverse.

In [5]:
def perform_operations(W,D,v):
    P = cola.pow(D, -0.5) # D^{-1/2}
    # P (P W P)^{-1} P v equals W^{-1} v up to the solver tolerance
    y = P@cola.inv(P@W@P,tol=1e-4)@P@v
    return y
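To see what this function computes, note that $P (P W P)^{-1} P = W^{-1}$ whenever $P$ is invertible, so the result is the solution of $W y = v$. The dense NumPy sketch below checks this on a small random positive definite matrix. All names in it are local to the example, and it uses direct solves rather than the iterative methods a `tol` argument suggests.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 4))
W = X @ X.T + np.eye(4)            # symmetric positive definite
P = np.diag(np.diag(W) ** -0.5)    # D^{-1/2}, with D the diagonal of W
v = rng.standard_normal(4)

# Symmetrically preconditioned solve: P (P W P)^{-1} P v
y_precond = P @ np.linalg.solve(P @ W @ P, P @ v)
# Direct solve of W y = v for comparison
y_direct = np.linalg.solve(W, v)

# The preconditioned matrix P W P has unit diagonal by construction.
unit_diag = np.diag(P @ W @ P)

print(np.linalg.norm(y_precond - y_direct), unit_diag)
```

The answer is unchanged by the preconditioner; what changes is the matrix the solver sees, which here is rescaled to have a unit diagonal.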

Now suppose that we want to perform this operation over a batch of `LinearOperator`s, each with different data.

First, we can vmap over the function that constructs the `LinearOperator`s:

In [None]:
from torch.func import vmap
# vmap the constructor over a batch of three 5x5 inputs
f = vmap(construct_complicated_linops)
bW, bD = f(torch.randn(3, 5, 5))

Notice that the objects have the same types and shapes as their unbatched counterparts:

In [14]:
print(bW.shape, type(bW))
print(bD.shape, type(bD))

(12, 12) <class 'cola.ops.operators.BlockDiag[cola.ops.operators.Sum[cola.ops.operators.Product[cola.ops.operators.Dense, cola.ops.operators.Dense], cola.ops.operators.Identity], cola.ops.operators.Diagonal]'>
(12, 12) <class 'cola.ops.operators.Diagonal'>


However, the data that makes up these objects now has a batch dimension:

In [12]:
bD.diag.shape

torch.Size([3, 12])

In general, these batched objects should only be used as inputs to a function that is itself vmapped over its `LinearOperator` arguments, as shown below with `perform_operations`.

In [9]:
all_outs = torch.func.vmap(perform_operations)(bW, bD, torch.randn(3,bW.shape[0]))
print(all_outs.shape)

torch.Size([3, 12])


🚧 Note: Not all `LinearOperator`s support vmap with the PyTorch backend 🚧

For example, the Kronecker product:

In [29]:
def get_entries(M):
    return M[:5,:5].to_dense()

try:
    vmap(get_entries)(vmap(cola.kron)(bW, bD))
except RuntimeError as e:
    print("raised exception:", e)

raised exception: Batching rule not implemented for aten::moveaxis.int; the fallback path doesn't work on out= or view ops.


## Gradients and PyTrees (jax example)