A collection of JAX functions for common machine learning and deep learning tasks.
This library currently contains basic implementations of a number of losses and metrics. We intend to add more functionality as it is needed. Contributions, pull requests, and bug reports are very welcome if you discover a problem or need something that is currently missing.
Install with pip:
pip install jax_toolkit
Or, to include the additional loss function utilities:
pip install jax_toolkit[losses_utils]
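For readers new to JAX, the kind of loss function this library collects can be sketched in plain JAX. The example below is only an illustration and does not use the jax_toolkit API: the function name `mean_squared_error` and its argument order are assumptions made for this sketch, so check the package itself for the names it actually exposes.

```python
import jax
import jax.numpy as jnp


def mean_squared_error(y_true, y_pred):
    """Hypothetical loss for illustration: mean of squared differences."""
    return jnp.mean((y_true - y_pred) ** 2)


# Compile the loss with jit, and build a function returning its
# gradient with respect to the predictions (argument index 1).
mse_jit = jax.jit(mean_squared_error)
mse_grad = jax.grad(mean_squared_error, argnums=1)

y_true = jnp.array([1.0, 2.0, 3.0])
y_pred = jnp.array([1.5, 2.0, 2.0])

loss = mse_jit(y_true, y_pred)   # scalar loss value
grad = mse_grad(y_true, y_pred)  # same shape as y_pred
```

Writing losses as pure functions of arrays, as here, is what lets them compose with `jax.jit` and `jax.grad` without further changes.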