Hyperbolic embeddings have proved powerful in numerous applications, such as graph embedding. The goal of this project is to make some of the current advances in Hyperbolic Neural Networks available in JAX.
The neural networks are implemented with Haiku and the optimizers are based on Optax.
Make sure to install JAX first by following the official installation guide.
Install the project with pip:
pip install git+https://github.com/Raffaelbdl/hyperbolic-nn-haiku.git
The following content is currently implemented.
Manifolds
Manifold: interface to implement a Riemannian manifold.
Stereographic: implementation of the stereographic manifold (a generalization of the Poincaré manifold that also handles positive curvature).
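To make the role of the curvature concrete, here is a minimal sketch of the basic operations of the κ-stereographic model in plain JAX. It is not this project's API: the function names (`tan_k`, `artan_k`, `mobius_add`, `expmap0`, `logmap0`) and the `EPS` constant are hypothetical, and the curvature `k` is assumed to be a plain Python float so that ordinary `if` branches can be used.

```python
import jax.numpy as jnp

EPS = 1e-7  # hypothetical guard against division by zero at the origin


def tan_k(r, k):
    """Curvature-aware tangent: tan for k > 0, tanh for k < 0, identity for k == 0."""
    if k > 0:
        s = k ** 0.5
        return jnp.tan(s * r) / s
    if k < 0:
        s = (-k) ** 0.5
        return jnp.tanh(s * r) / s
    return r


def artan_k(r, k):
    """Inverse of tan_k on its domain."""
    if k > 0:
        s = k ** 0.5
        return jnp.arctan(s * r) / s
    if k < 0:
        s = (-k) ** 0.5
        return jnp.arctanh(s * r) / s
    return r


def mobius_add(x, y, k):
    """Mobius addition; reduces to x + y when k == 0."""
    xy = jnp.sum(x * y, axis=-1, keepdims=True)
    x2 = jnp.sum(x * x, axis=-1, keepdims=True)
    y2 = jnp.sum(y * y, axis=-1, keepdims=True)
    num = (1 - 2 * k * xy - k * y2) * x + (1 + k * x2) * y
    den = 1 - 2 * k * xy + k ** 2 * x2 * y2
    return num / den


def expmap0(v, k):
    """Exponential map at the origin: tangent vector -> point on the manifold."""
    n = jnp.maximum(jnp.linalg.norm(v, axis=-1, keepdims=True), EPS)
    return tan_k(n, k) * v / n


def logmap0(y, k):
    """Logarithmic map at the origin: point on the manifold -> tangent vector."""
    n = jnp.maximum(jnp.linalg.norm(y, axis=-1, keepdims=True), EPS)
    return artan_k(n, k) * y / n
```

With `k = -1.0` these formulas describe the Poincaré ball, with `k = 0.0` they fall back to ordinary Euclidean operations, and with `k > 0` they describe the stereographic projection of a sphere.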
Layers
Activations: wrapper and the most popular activation functions for Riemannian manifolds.
StereographicLinearLayer: base linear layer in the K-stereographic model.
StereographicVanillaRNN: base RNN layer in the K-stereographic model.
StereographicGRU: GRU cell in the K-stereographic model.
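As an illustration of what a linear layer on such a manifold can look like, the sketch below applies the weights in the tangent space at the origin and adds the bias with Möbius addition. It is restricted to negative curvature `-c` (the Poincaré ball) for brevity, and it is not the implementation of `StereographicLinearLayer`: the function `hyperbolic_linear`, the `params` dictionary layout, and the choice to store the bias as a tangent vector are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp

EPS = 1e-7  # hypothetical guard against division by zero at the origin


def _norm(x):
    return jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), EPS)


def expmap0(v, c):
    """Tangent vector at the origin -> point on the ball of curvature -c."""
    n = _norm(v)
    return jnp.tanh(jnp.sqrt(c) * n) * v / (jnp.sqrt(c) * n)


def logmap0(y, c):
    """Point on the ball of curvature -c -> tangent vector at the origin."""
    n = _norm(y)
    return jnp.arctanh(jnp.sqrt(c) * n) * y / (jnp.sqrt(c) * n)


def mobius_add(x, y, c):
    """Mobius addition on the ball of curvature -c."""
    xy = jnp.sum(x * y, axis=-1, keepdims=True)
    x2 = jnp.sum(x * x, axis=-1, keepdims=True)
    y2 = jnp.sum(y * y, axis=-1, keepdims=True)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    den = 1 + 2 * c * xy + c ** 2 * x2 * y2
    return num / den


def hyperbolic_linear(params, x, c=1.0):
    """Map to the tangent space, apply the weights, map back, then add the bias."""
    h = expmap0(logmap0(x, c) @ params["w"], c)
    bias_point = expmap0(params["b"], c)  # bias is stored as a tangent vector
    return mobius_add(h, bias_point, c)


# Usage: a batch of 4 points in a 3-dimensional ball mapped to a 5-dimensional ball.
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = expmap0(0.1 * jax.random.normal(key_x, (4, 3)), 1.0)
params = {"w": 0.1 * jax.random.normal(key_w, (3, 5)), "b": jnp.zeros(5)}
y = hyperbolic_linear(params, x)
```

Applying the weights in the tangent space at the origin is equivalent to the Möbius matrix-vector product used in the hyperbolic neural network literature, and it keeps the output inside the ball by construction.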
Optimizers
rsgd: base Riemannian stochastic gradient descent.
riemannian_adagrad: Riemannian version of the Adagrad optimizer.
riemannian_adam: Riemannian version of the Adam optimizer.
riemannian_adamw: Riemannian version of the AdamW optimizer.
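The idea behind Riemannian SGD can be sketched in a few lines. The example below is a simplified, retraction-based variant on the Poincaré ball (curvature -1): the Euclidean gradient is rescaled by the inverse of the metric, a plain gradient step is taken, and the result is projected back inside the ball. It is not this project's `rsgd`, which is built on Optax; the names `project`, `rsgd_step`, and `loss_fn`, as well as the toy loss, are hypothetical.

```python
import jax
import jax.numpy as jnp


def project(x, eps=1e-5):
    """Pull points that left the unit ball back strictly inside it."""
    norm = jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), 1e-7)
    max_norm = 1.0 - eps
    return jnp.where(norm > max_norm, x / norm * max_norm, x)


def rsgd_step(x, euclidean_grad, lr):
    """One retraction-based RSGD step on the Poincare ball (curvature -1)."""
    sq_norm = jnp.sum(x * x, axis=-1, keepdims=True)
    # The metric is (2 / (1 - |x|^2))^2 times the Euclidean one, so the
    # Riemannian gradient is the Euclidean gradient divided by that factor.
    riemannian_grad = ((1.0 - sq_norm) ** 2 / 4.0) * euclidean_grad
    return project(x - lr * riemannian_grad)


def loss_fn(x, target):
    """Toy loss: squared Euclidean distance to a target inside the ball."""
    return jnp.sum((x - target) ** 2)


target = jnp.array([0.5, -0.3])
x = jnp.array([0.1, 0.2])
initial_loss = loss_fn(x, target)
for _ in range(100):
    x = rsgd_step(x, jax.grad(loss_fn)(x, target), lr=0.5)
final_loss = loss_fn(x, target)
```

A full Riemannian optimizer would typically replace the retraction `x - lr * grad` with the exponential map at `x`; the projection step is a cheaper first-order approximation.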
For a fully Euclidean architecture: