This is a reimplementation of Spectral State-Space Models in JAX/Flax. Here we only test Spectral SSMs on the Long Range Arena (LRA) benchmark.
- everything's configurable using gin
- you can turn off the auto-regressive part of the model
- multi-gpu -> data parallelization for bigger batches
- includes a stolen implementation of S5 as well
- monitor on wandb
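
To illustrate what "data parallelization for bigger batches" means in JAX, here is a minimal sketch. It is not taken from this repo: the toy least-squares loss, the `train_step` function, and the learning rate are all stand-ins, and using `jax.pmap` is an assumption about how such a setup could look. Each device gets its own shard of the batch, gradients are averaged across devices, and every replica applies the same update.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Toy least-squares loss standing in for the real model's loss (assumption).
    return jnp.mean((x @ w - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(w, x, y):
    # Each device computes the loss and gradient on its own shard of the batch.
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    # Average over devices so every replica applies an identical update.
    grad = jax.lax.pmean(grad, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    return w - 0.05 * grad, loss


n_dev = jax.local_device_count()
per_device_batch = 8

# Global batch of n_dev * per_device_batch examples; the leading axis indexes devices.
x = jax.random.normal(jax.random.PRNGKey(0), (n_dev, per_device_batch, 4))
true_w = jnp.arange(4.0)
y = x @ true_w  # shape (n_dev, per_device_batch)

# Parameters are replicated: one identical copy per device.
w = jnp.zeros((n_dev, 4))

losses = []
for _ in range(20):
    w, loss = train_step(w, x, y)
    losses.append(float(loss[0]))
```

With more devices the per-device batch stays the same, so the effective global batch grows linearly with the device count. On a single-device machine the sketch still runs, with a leading axis of size 1.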
The bash scripts inside the ./bin directory are executable. Before anything else, download the datasets:
./bin/download_lra.sh
If you're at Mila, set up your virtualenv and run:
sbatch launch.sh spec_listops
To run things inside a singularity container, first pull the docker image:
module load singularity
singularity pull docker://mahanfathi/specssm:v1.0
Set your wandb key inside claunch.sh and run:
# claunch is containerized via singularity
sbatch claunch.sh spec_listops