CoLA is a framework for scalable linear algebra, automatically exploiting the structure often found in machine learning problems and beyond. CoLA supports both PyTorch and JAX.
```bash
pip install cola-ml
```
- Large-scale linear algebra routines for `solve(A, b)`, `eig(A)`, `logdet(A)`, `exp(A)`, `trace(A)`, `diag(A)`, `sqrt(A)`.
- Provides (user-extendible) compositional rules to exploit structure through multiple dispatch.
- Has memory-efficient autodiff rules for iterative algorithms.
- Works with PyTorch or JAX, supporting GPU hardware acceleration.
- Supports operators with complex numbers and low precision (see the short sketch after this list).
- Provides linear algebra operations for both symmetric and non-symmetric matrices.
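
Several items in this list (for example `logdet`, complex dtypes, and low precision) are not exercised in the walkthrough below. Here is a minimal, hypothetical sketch of what they might look like, assuming `cola.PSD` mirrors the `cola.SelfAdjoint` annotation used later and that `logdet` follows the same `cola.linalg.<routine>(A)` call pattern as `trace` and `eig`:

```python
import jax.numpy as jnp
import cola

# Hypothetical sketch (values and calls are illustrative, not from the README).
A = cola.PSD(cola.ops.Dense(jnp.array([[4., 1.], [1., 3.]])))  # assumed PSD annotation
print(cola.linalg.logdet(A))                                   # log-determinant

C = cola.ops.Dense(jnp.array([[1. + 1.j, 0.], [2., 1. - 1.j]]))
print(C @ jnp.ones(2, dtype=jnp.complex64))                    # complex matvec

H = cola.ops.Dense(jnp.ones((2, 2), dtype=jnp.float16))
print((H @ jnp.ones(2, dtype=jnp.float16)).dtype)              # low-precision matvec
```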
See https://cola.readthedocs.io/en/latest/ for our full documentation and many examples.
- LinearOperators. The core object in CoLA is the LinearOperator. You can add and subtract them (`+`, `-`), multiply by constants (`*`, `/`), matrix multiply them (`@`), and combine them in other ways (`kron`, `kronsum`, `block_diag`, etc.):
```python
import jax.numpy as jnp
import cola

A = cola.ops.Diagonal(jnp.arange(5) + .1)
B = cola.ops.Dense(jnp.array([[2., 1.], [-2., 1.1], [.01, .2]]))
C = B.T @ B
D = C + 0.01 * cola.ops.I_like(C)
E = cola.ops.Kronecker(A, cola.ops.Dense(jnp.ones((2, 2))))
F = cola.ops.BlockDiag(E, D)

v = jnp.ones(F.shape[-1])
print(F @ v)
```
```
[0.2 0.2 2.2 2.2 4.2 4.2 6.2
 6.2 8.2 8.2 7.8121004 2.062 ]
```
- Performing Linear Algebra. With these objects we can perform linear algebra operations even when they are very big.
```python
print(cola.linalg.trace(F))
Q = F.T @ F + 1e-3 * cola.ops.I_like(F)
b = cola.linalg.inv(Q) @ v
print(jnp.linalg.norm(Q @ b - v))
print(cola.linalg.eig(F)[0][:5])
print(cola.sqrt(A))
```
```
31.2701
0.0010193728
[ 2.0000000e-01+0.j  0.0000000e+00+0.j  2.1999998e+00+0.j
 -1.1920929e-07+0.j  4.1999998e+00+0.j]
diag([0.31622776 1.0488088 1.4491377 1.7606816 2.0248456 ])
```
For many of these functions, if we know additional information about the matrices, we can annotate them so that faster algorithms are dispatched:
```python
Qs = cola.SelfAdjoint(Q)
%timeit cola.linalg.inv(Q) @ v
%timeit cola.linalg.inv(Qs) @ v
```
- JAX and PyTorch. We support both ML frameworks.
```python
import torch
A = cola.ops.Dense(torch.Tensor([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))

import jax.numpy as jnp
A = cola.ops.Dense(jnp.array([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))
```
```
tensor(25.)
25.0
```
and both support autograd (and jit):
```python
from jax import grad, jit, vmap

def myloss(x):
    A = cola.ops.Dense(jnp.array([[1., 2.], [3., x]]))
    return jnp.ones(2) @ cola.linalg.inv(A) @ jnp.ones(2)

g = jit(vmap(grad(myloss)))(jnp.array([.5, 10.]))
print(g)
```
```
[-0.06611571 -0.12499995]
```
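
Only the JAX version is shown above. A rough PyTorch counterpart might look like the sketch below, assuming gradients flow through `cola.linalg.inv` under `torch.autograd` in the same way (the values and the way the matrix is assembled are illustrative, not code from the README):

```python
import torch
import cola

# Hypothetical PyTorch analogue of the JAX autograd example above.
x = torch.tensor(10., requires_grad=True)
row0 = torch.tensor([1., 2.])
row1 = torch.stack([torch.tensor(3.), x])      # keeps the autograd graph through x
A = cola.ops.Dense(torch.stack([row0, row1]))
ones = torch.ones(2)
loss = ones @ (cola.linalg.inv(A) @ ones)
loss.backward()
print(x.grad)
```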
If you use CoLA, please cite the following paper:
```bibtex
@article{potapczynski2023cola,
  title={{CoLA: Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra}},
  author={Andres Potapczynski and Marc Finzi and Geoff Pleiss and Andrew Gordon Wilson},
  journal={arXiv preprint arXiv:2309.03060},
  year={2023}
}
```
| Linear Algebra | inverse | eig | diag | trace | logdet | exp | sqrt | f(A) | SVD | pseudoinverse |
|---|---|---|---|---|---|---|---|---|---|---|
| Implementation | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | |
| LinearOperators | Diag | BlockDiag | Kronecker | KronSum | Sparse | Jacobian | Hessian | Fisher | Concatenated | Triangular | FFT | Tridiagonal |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Implementation | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | |
| Annotations | SelfAdjoint | PSD | Unitary |
|---|---|---|---|
| Implementation | ✓ | ✓ | ✓ |
See the contributing guidelines in `docs/CONTRIBUTING.md` for information on submitting issues and pull requests.
CoLA is Apache 2.0 licensed.
Please raise an issue if you find a bug or slow performance when using CoLA.