This repo builds on ideas and code from HazyResearch/state-spaces and reimplements them in JAX.
There are several ways to install HiPPO-Jax:
- Use a package manager:
  - poetry install (recommended for users)
  - pip install from PyPI
- Clone the repo to your local machine and install from source (recommended for developers/contributors)
Ensure your CUDA drivers are installed correctly; they affect dependencies such as JAX and PyTorch.
Note: these instructions are for Linux. Commands may be different for other platforms.
Using poetry
- Install poetry:
curl -sSL https://install.python-poetry.org | python3 -
- Ensure the Python version is 3.8:
$ python --version
> 3.8.x
- Activate the poetry virtual environment:
poetry shell
- (optional) Update the dependencies so they work with your system:
poetry update
- Install the lock file dependencies:
poetry install --with jax,torch,mltools,jupyter,additional,dataset
Using pip
- Create and activate a virtual environment:
conda create --name hippo_jax python=3.8
conda activate hippo_jax
- Install the dependencies:
pip install -r requirements.txt
From source
- Clone the repo:
via HTTPS:
git clone https://github.com/Dana-Farber-AIOS/HiPPO-Jax.git
cd HiPPO-Jax
via SSH:
git clone git@github.com:Dana-Farber-AIOS/HiPPO-Jax.git
cd HiPPO-Jax
- Create conda environment:
conda env create -f requirements.txt
conda activate hippo_jax
- Install HiPPO-Jax from source:
pip install -e .
That's it!
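Because the CUDA driver setup determines whether JAX can use a GPU, a quick sanity check after installing is to ask JAX which backend it selected. This minimal sketch uses only the public `jax.default_backend()` and `jax.devices()` calls; it does not check PyTorch.

```python
import jax

# "gpu" when a CUDA-enabled jaxlib found a usable device, otherwise usually "cpu"
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX backend: {backend}; devices: {devices}")

if backend != "gpu":
    print("No GPU visible to JAX; check the CUDA driver and the installed jaxlib build.")
```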
Usage
import jax.random as jr
# Split a PRNG key; `key` initializes the HiPPO modules below
key, subkey = jr.split(jr.PRNGKey(0), 2)
HiPPO Matrices
from src.models.hippo.transition import TransMatrix
N = 100  # state size: number of HiPPO coefficients
measure = "legs"  # scaled Legendre (LegS) measure
matrices = TransMatrix(N=N, measure=measure)
A = matrices.A
B = matrices.B
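For intuition about what `TransMatrix` returns for `measure="legs"`, the sketch below builds the HiPPO-LegS matrices directly from their closed form in the HiPPO paper, with the sign folded into A so the dynamics read dc/dt = A c + B f: A_nk = -sqrt(2n+1) sqrt(2k+1) for n > k, A_nn = -(n+1), A_nk = 0 for n < k, and B_n = sqrt(2n+1). The helper name `legs_matrices` is hypothetical, and whether `TransMatrix` uses the same sign and scaling is an assumption worth checking against `matrices.A` and `matrices.B`.

```python
import jax.numpy as jnp


def legs_matrices(N: int):
    """HiPPO-LegS transition matrices (hypothetical helper, negated-A convention)."""
    n = jnp.arange(N, dtype=jnp.float32)
    r = jnp.sqrt(2.0 * n + 1.0)              # sqrt(2n + 1) for each index
    lower = jnp.tril(jnp.outer(r, r), k=-1)  # sqrt(2n+1) * sqrt(2k+1) strictly below the diagonal
    A = -(lower + jnp.diag(n + 1.0))         # diagonal entries are -(n + 1); upper triangle is 0
    B = r[:, None]                           # B_n = sqrt(2n + 1), shape (N, 1)
    return A, B


A_small, B_small = legs_matrices(4)
```

A is lower triangular, so its eigenvalues are simply its diagonal entries -1, -2, ..., -N.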
HiPPO (LTI) Operator
from src.models.hippo.hippo import HiPPOLTI
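N = 50  # state size: number of HiPPO coefficients
T = 3
step = 1e-3  # discretization step size
measure = "legs"  # scaled Legendre (LegS) measure
desc_val = 0.0  # GBT alpha (assumed): 0.0 forward Euler, 0.5 bilinear, 1.0 backward Euler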
hippo = HiPPOLTI(
N=N,
step_size=step,
GBT_alpha=desc_val,
measure=measure,
basis_size=T,
unroll=False,
)
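The `GBT_alpha` argument suggests the generalized bilinear transform (GBT) used in the HiPPO line of work to discretize dc/dt = A c + B f with step dt: A_d = (I - α dt A)^{-1} (I + (1 - α) dt A) and B_d = (I - α dt A)^{-1} dt B. Assuming that reading, α = 0.0 (the `desc_val` above) is forward Euler, 0.5 is bilinear, and 1.0 is backward Euler. The minimal sketch below uses hypothetical `gbt_discretize` and `lti_scan` helpers, not HiPPOLTI's internals: it discretizes once, then runs the fixed (time-invariant) recurrence with `jax.lax.scan`.

```python
import jax
import jax.numpy as jnp


def gbt_discretize(A, B, step: float, alpha: float):
    """Generalized bilinear transform of dc/dt = A c + B f (hypothetical helper)."""
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    lhs = I - alpha * step * A                               # (I - alpha * dt * A)
    Ad = jnp.linalg.solve(lhs, I + (1.0 - alpha) * step * A)
    Bd = jnp.linalg.solve(lhs, step * B)
    return Ad, Bd


def lti_scan(Ad, Bd, f):
    """Run c_{k+1} = Ad c_k + Bd f_k over a 1-D signal f and return every c_k."""

    def step_fn(c, f_k):
        c = Ad @ c + (Bd * f_k).ravel()  # same Ad, Bd at every step: time-invariant
        return c, c

    c0 = jnp.zeros(Ad.shape[0], dtype=Ad.dtype)
    _, cs = jax.lax.scan(step_fn, c0, f)
    return cs  # shape (len(f), N)
```

Solving with `jnp.linalg.solve` avoids forming an explicit inverse; since the operator is LTI, the discretization happens once and is reused for the whole sequence.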
HiPPO (LSI) Operator
from src.models.hippo.hippo import HiPPOLSI
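N = 50  # state size: number of HiPPO coefficients
T = 3
step = 1e-3  # discretization step size
L = int(T / step)  # number of time steps (maximum sequence length)
measure = "legs"  # scaled Legendre (LegS) measure
desc_val = 0.0  # GBT alpha (assumed): 0.0 forward Euler, 0.5 bilinear, 1.0 backward Euler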
hippo = HiPPOLSI(
N=N,
max_length=L,
step_size=step,
GBT_alpha=desc_val,
measure=measure,
unroll=True,
)
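The scale-invariant (LSI) operator differs because its dynamics are dc/dt = (A c + B f) / t, so the discrete update changes at every step. Below is a minimal forward-Euler sketch of the LegS recurrence from the HiPPO paper, c_{k+1} = (I + A/k) c_k + (B/k) f_k with A in the negated form; the helper names are hypothetical, and HiPPOLSI's exact indexing and GBT handling may differ.

```python
import jax
import jax.numpy as jnp


def legs_matrices(N: int):
    # HiPPO-LegS matrices, negated-A convention (A_nn = -(n + 1), B_n = sqrt(2n + 1))
    n = jnp.arange(N, dtype=jnp.float32)
    r = jnp.sqrt(2.0 * n + 1.0)
    A = -(jnp.tril(jnp.outer(r, r), k=-1) + jnp.diag(n + 1.0))
    return A, r[:, None]


def legs_scale_invariant(f, N: int):
    """Forward-Euler LegS recurrence c_{k+1} = (I + A/k) c_k + (B/k) f_k, k = 1..L.

    The effective step 1/k shrinks as the history grows, so the coefficients
    summarize the whole prefix f_1..f_k rather than a fixed-length window.
    """
    A, B = legs_matrices(N)
    I = jnp.eye(N, dtype=A.dtype)

    def step_fn(c, inputs):
        k, f_k = inputs
        c = (I + A / k) @ c + (B.ravel() / k) * f_k  # update depends on k: time-varying
        return c, c

    ks = jnp.arange(1, f.shape[0] + 1, dtype=A.dtype)
    _, cs = jax.lax.scan(step_fn, jnp.zeros(N, dtype=A.dtype), (ks, f))
    return cs  # shape (L, N): coefficients after each step
```

For a constant input such as `legs_scale_invariant(jnp.ones(1000), N=8)`, the final coefficients settle at [1, 0, ..., 0], since a constant history needs only the zeroth Legendre polynomial.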
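Use the operators right out of the box; no training is needed:
# x: input signal. Hypothetical placeholder: the shape `f` expects is an assumption, check the repo's examples
x = jr.uniform(subkey, (L,))
params = hippo.init(key, f=x)
c, y = hippo.apply(params, f=x)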
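The README does not say exactly what `c` and `y` contain. Assuming `c` holds LegS coefficients of the history seen up to time t, that history can be approximately decoded with the scaled Legendre basis g_n(s) = sqrt(2n+1) P_n(2s/t - 1) on [0, t]. The helpers `legendre` and `legs_reconstruct` below are hypothetical illustrations, not part of HiPPO-Jax.

```python
import jax.numpy as jnp


def legendre(n_max: int, x):
    """Legendre polynomials P_0..P_{n_max-1} at x via the three-term recurrence."""
    P = [jnp.ones_like(x), x]
    for n in range(1, n_max - 1):
        # (n + 1) P_{n+1} = (2n + 1) x P_n - n P_{n-1}
        P.append(((2 * n + 1) * x * P[n] - n * P[n - 1]) / (n + 1))
    return jnp.stack(P[:n_max])  # shape (n_max, len(x))


def legs_reconstruct(c, s, t_end: float):
    """Evaluate sum_n c_n * sqrt(2n+1) * P_n(2s/t_end - 1) for s in [0, t_end]."""
    N = c.shape[0]
    x = 2.0 * s / t_end - 1.0  # map [0, t_end] onto [-1, 1]
    scale = jnp.sqrt(2.0 * jnp.arange(N) + 1.0)[:, None]
    return jnp.sum(c[:, None] * scale * legendre(N, x), axis=0)
```

For instance, the exact LegS coefficients of f(s) = s on [0, 1] are c = [1/2, sqrt(3)/6, 0, ...], and `legs_reconstruct(jnp.array([0.5, jnp.sqrt(3.0) / 6, 0.0, 0.0]), s, 1.0)` returns s up to rounding.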
HiPPO-Jax is an open-source project. Consider contributing to benefit the entire community!
There are many ways to contribute to HiPPO-Jax, including:
- Submitting bug reports
- Submitting feature requests
- Writing documentation and examples
- Fixing bugs
- Writing code for new features
- Sharing workflows
- Sharing trained model parameters
- Sharing HiPPO-Jax with colleagues, students, etc.
The GNU GPL v2 version of HiPPO-Jax is made available under an open-source license: you are free to use, modify, and distribute it under the terms of the GNU General Public License version 2.
Commercial license options are also available.
Questions? Comments? Suggestions? Get in touch!