Rieoptax is a library for Riemannian optimization in JAX. It is driven mainly by the need for efficient implementations of manifold-valued operations, optimization solvers, and neural network layers that run readily on GPUs and TPUs.
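Riemannian optimization considers problems of the form $\min_{w \in \mathcal{M}} f(w)$, where $\mathcal{M}$ is a Riemannian manifold and $f : \mathcal{M} \to \mathbb{R}$ is the cost function.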
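The two main differences between Euclidean optimization and Riemannian optimization are the Riemannian gradient, which gives the descent direction, and the Riemannian exponential map, which moves along the manifold in that direction.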
For a complete example, see the notebooks folder.
Rieoptax consists of three modules:
- geometry: implements several Riemannian manifolds of interest, along with useful operations such as the Riemannian exponential map, the logarithmic map, and rules for converting a Euclidean gradient to a Riemannian gradient
- mechanism: noise calibration for differentially private mechanisms with manifold-valued outputs
- optimizers: Riemannian optimization algorithms
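As an illustration of the kind of operations the geometry module covers, the sketch below implements the exponential and logarithmic maps on the unit sphere in plain JAX and checks that they invert each other. The function names and signatures are hypothetical and are not taken from Rieoptax.

```python
import jax.numpy as jnp


def sphere_exp(x, v):
    # Exponential map: point reached from x by following the geodesic
    # with initial tangent vector v.
    norm_v = jnp.linalg.norm(v)
    safe_norm = jnp.where(norm_v > 0, norm_v, 1.0)  # avoid 0/0 when v = 0
    return jnp.cos(norm_v) * x + jnp.sin(norm_v) * v / safe_norm


def sphere_log(x, y):
    # Logarithmic map: tangent vector at x that points toward y, with
    # length equal to the geodesic distance between x and y.
    inner = jnp.clip(jnp.dot(x, y), -1.0, 1.0)
    theta = jnp.arccos(inner)                  # geodesic distance
    u = y - inner * x                          # part of y tangent to x
    norm_u = jnp.linalg.norm(u)
    safe_norm = jnp.where(norm_u > 0, norm_u, 1.0)
    return theta * u / safe_norm


x = jnp.array([1.0, 0.0, 0.0])
y = jnp.array([0.0, 1.0, 0.0])

v = sphere_log(x, y)        # tangent vector at x pointing to y
y_back = sphere_exp(x, v)   # should recover y up to rounding
dist = jnp.linalg.norm(v)   # geodesic distance, a quarter circle here
```

The logarithmic map is not uniquely defined for antipodal points (`y = -x`); this sketch returns the zero vector in that case rather than handling it specially.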
Currently, installation is done directly from GitHub; the package will soon be available through PyPI.
pip install git+https://github.com/SaitejaUtpala/rieoptax.git
Preprint available at https://arxiv.org/pdf/2210.04840.pdf