[ENH] Speed up evotuning and improve evotuning ergonomics (#57)
* Adding pre-commit

* Fixed up GPU memory allocation, and added docstrings.

* Adding a bash script that makes it easy to install JAX on GPU.

- The script builds a conda environment first.
- Then it clobbers the default install with the GPU-based installation,
based on instructions given by JAX's developers.

* Update fit docstring

* Set backend to "cpu" by default

* Removed parallel kwarg

* Switched back to non-Numba-compatible dictionary definition

* Add pyproject TOML config file

Primarily to add black config

* Applied black

* Add flake8 to pre-commit hooks

* Remove flake8 from pre-commit

* Attempting to increase coverage without doing any actual work ^_^

* Add tests for params

- One unit test
- One lazy man's execution test

* Update changelog

* Fix test

* Add validate_mLSTM1900_params

This can be used as part of the test suite. I should have remembered!

Co-authored-by: Arkadij Kummer <43340666+ElArkk@users.noreply.github.com>

* Used validate_mLSTM1900_params as part of test

h/t @ElArkk

Co-authored-by: Arkadij Kummer <43340666+ElArkk@users.noreply.github.com>

* Fix batch_size in avg_loss function

* Flat is better than nested

At least I tried.

* Fix test

* Make format

* Change holdout_seqs to default back to None

* Set sane defaults for mLSTM1900 layer

* Changed to dumping every epoch by default.

* add backend explanation

* add backend to fit example

* change default batching method of fit function to random

* fix epoch calculations

* fix seq length choice for holdout seqs

Co-authored-by: Arkadij Kummer <43340666+ElArkk@users.noreply.github.com>
Co-authored-by: ElArkk <arkadij.kummer@gmail.com>
3 people committed Jul 18, 2020
1 parent 661a31b commit 54ab8e6
Showing 15 changed files with 252 additions and 82 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -137,3 +137,5 @@ paper/cabios-template/fig01.eps
default.profraw
paper/cabios-template/speed-comparison.png
paper/cabios-template/Untitled.ipynb
*/temp/*
temp/*
14 changes: 14 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,14 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.4.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
  - repo: https://github.com/psf/black
    rev: 19.10b0
    hooks:
      - id: black
13 changes: 11 additions & 2 deletions CHANGELOG.md
@@ -9,9 +9,18 @@ In the changelog, @ElArkk and @ericmjl would like to acknowledge contributors wh
1. option to supply an out-domain holdout set and print params as training progresses,
2. evotuning without Optuna by directly calling fit function,
3. added avg_loss() function for calculating and outputting the training and holdout set loss to a log file (the number and length of batches are also calculated and printed to the log file),
4. introduction of "steps_per_print" to periodically calculate losses and dump parameters

4. introduction of "epochs_per_print" to periodically calculate losses and dump parameters
5. Implemented adamW in JAX and switched optimizer to adamW,
6. added option to change the number of folds in optuna KFolds,
7. update evotuning-prototype.py example script
- 30 March 2020: Code fixes for correctness and readability, and a parameter dumping function by @ivanjayapurna,
- 9 July 2020: Add progress bar to sequence sampler.
- 28 June 2020: Improvements to evotuning ergonomics, by @ericmjl
1. Added a pre-commit configuration.
1. Added an installation script that makes it easy to install jax on GPU.
1. Provided the ability to specify the backend device (GPU/CPU).
1. Switched the preparation of sequences as input-output pairs to run exclusively on the CPU, for speed.
1. Added ergonomic UI features (progress bars!) that improve the user experience.
1. Added docs on the recommended batch size and its relationship to GPU RAM consumption.
1. Switched from exact calculation of the train/holdout loss to estimated calculation.
- 9 July 2020: Add progress bar to sequence sampler, by @ericmjl
2 changes: 1 addition & 1 deletion Makefile
@@ -13,7 +13,7 @@ style:
@printf "\033[1;34mPylint passes!\033[0m\n\n"

test: # Test code using pytest.
pytest -v . --cov=./jax_unirep
pytest -v . --cov=./jax_unirep --cov-report term-missing

paper:
cd paper && bash build.sh
7 changes: 7 additions & 0 deletions README.md
@@ -112,6 +112,13 @@ The `fit` function has further customization options,
such as different batching strategies.
Please see the function docstring for more information.

NOTE: The `fit` function will always default to using a
GPU `backend`, if available, for the forward and backward passes
during training of the LSTM.
However, for the calculation of the average loss
on the dataset after every epoch, you can choose whether
the CPU or GPU `backend` should be used (the default is CPU).
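
For orientation, here is a minimal sketch of a `fit` call that trains on
the GPU (when available) but computes the per-epoch average loss on the CPU.
Keyword names other than `backend`, `holdout_seqs`, `batch_method`,
`proj_name`, `epochs_per_print`, and `step_size` are assumptions, so check
the docstring for the exact signature:

from jax_unirep import fit

# Toy sequences purely for illustration.
sequences = ["MKLVTA", "MKLVSA", "MKLVAA"]
holdout_sequences = ["MKLVTG"]

evotuned_params = fit(
    sequences,            # training sequences; name/position is an assumption
    n_epochs=20,          # assumed keyword for the number of training epochs
    step_size=1e-4,
    holdout_seqs=holdout_sequences,
    batch_method="random",
    proj_name="temp",
    epochs_per_print=1,
    backend="cpu",        # device used for the average-loss calculation only
)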

You can find example usages of both `evotune` and `fit` [here][examples].

If you want to pass a set of mLSTM and dense weights
1 change: 1 addition & 0 deletions environment.yml
@@ -18,6 +18,7 @@ dependencies:
- pytest=5.3.4
- pandoc=2.9.1.1
- pytest-cov=2.8.1
- pre-commit
- ipykernel=5.1.4
- optuna=1.1.0
- scikit-learn=0.22.1
4 changes: 2 additions & 2 deletions examples/evotuning.py
@@ -26,7 +26,7 @@
n_splits=2,
n_epochs_config=n_epochs_config,
learning_rate_config=lr_config,
steps_per_print=1,
epochs_per_print=1,
)

dump_params(evotuned_params, PROJECT_NAME)
@@ -47,7 +47,7 @@
step_size=LEARN_RATE,
holdout_seqs=holdout_sequences,
proj_name=PROJECT_NAME,
steps_per_print=1,
epochs_per_print=1,
)

dump_params(evotuned_params, PROJECT_NAME)
5 changes: 3 additions & 2 deletions examples/fitting.py
@@ -2,9 +2,9 @@
import logging
from random import shuffle

from Bio import SeqIO
from pyprojroot import here

from Bio import SeqIO
from jax_unirep import fit
from jax_unirep.utils import dump_params

@@ -37,7 +37,8 @@
holdout_seqs=holdout_sequences,
batch_method="random",
proj_name=PROJECT_NAME,
steps_per_print=None,
epochs_per_print=None,
backend="gpu", # default is "cpu"
)

dump_params(evotuned_params, PROJECT_NAME)
10 changes: 10 additions & 0 deletions install_jaxgpu.sh
@@ -0,0 +1,10 @@
conda env update -f environment.yml

PYTHON_VERSION=cp37
CUDA_VERSION=cuda101
PLATFORM=manylinux2010_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.50-$PYTHON_VERSION-none-$PLATFORM.whl
pip install --upgrade jax

jupyter labextension install @jupyter-widgets/jupyterlab-manager
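
To confirm that the GPU-enabled jaxlib is actually the one being picked up,
a quick sanity check (a sketch, not part of the script) is:

import jax
from jax.lib import xla_bridge

# Should report "gpu" once the CUDA-enabled jaxlib wheel is active.
print(xla_bridge.get_backend().platform)
# Lists the devices JAX can see, e.g. one or more GPU devices.
print(jax.devices())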