[ENH] Speed up evotuning and improve evotuning ergonomics #57
Conversation
Codecov Report
@@ Coverage Diff @@
## master #57 +/- ##
==========================================
+ Coverage 89.27% 93.33% +4.06%
==========================================
Files 11 11
Lines 522 540 +18
==========================================
+ Hits 466 504 +38
+ Misses 56 36 -20
Continue to review full report at Codecov.
Hmmm, I'm a little confused as to how the code coverage was impacted. I think I need a second opinion on whether stuff could be refactored a bit better. @ElArkk?
Wow, great work @ericmjl! Very clever to optionally move average loss computation off of GPU, while retaining it in any case (if available) for the work-intensive weight updates! As for the coverage, one thing I think could be responsible for the decrease on …
jax_unirep/evotuning.py (Outdated)

    global evotune_loss  # this is necessary for JIT to reference evotune_loss
    evotune_loss_jit = jit(evotune_loss, backend=backend)

    def batch_iter(xs: np.ndarray, ys: np.ndarray, batch_size: int = 25):
Does it make sense to set a default value for `batch_size` here, when this function is only used inside `avg_loss`, and you then supply another value (100) further below?
Yeah, good catch! We should make it all uniform for simplicity, as there's much less unnecessary complexity this way.
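For context, here is a minimal sketch of the kind of batch iterator under discussion, with `batch_size` made a required argument so there is no hidden default to keep in sync with the value passed from `avg_loss`. The body is illustrative only, not the actual jax-unirep implementation:

```python
from typing import Iterator, Tuple

import numpy as np


def batch_iter(
    xs: np.ndarray, ys: np.ndarray, batch_size: int
) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield shuffled (xs, ys) minibatches of `batch_size` rows each."""
    idxs = np.random.permutation(len(xs))
    for start in range(0, len(xs), batch_size):
        batch = idxs[start : start + batch_size]
        yield xs[batch], ys[batch]
```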
@ivanjayapurna I wanted to get a second eye on the code before we merge. Can you and @ElArkk independently test the …

UPDATE 30 June 2020 9:20 AM: I just tried it out on your notebook, @ivanjayapurna, and everything runs smoothly and fast. By default, I have set it to dump on every epoch. Storage is cheap, human time is not.
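The checkpointing code itself is not shown in this thread; assuming the parameters are pickle-serialisable, a per-epoch dump could look roughly like the hypothetical helper below (the name, path layout, and file format are illustrative, not the library's API):

```python
import pickle
from pathlib import Path


def dump_params(params, epoch: int, out_dir: str = "evotuned_params") -> None:
    """Write the current parameters to disk, one pickle file per epoch."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    with open(out / f"epoch_{epoch}.pkl", "wb") as f:
        pickle.dump(params, f)
```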
I was just thinking, now that we only use one random batch of sequences to calculate the average loss on the dataset, does it still make sense to have an argument for which backend to use? Since the actual training always uses GPU if available, and needs to be able to handle the same or even larger …
Sometimes, setting the back-end explicitly can help with debugging. For example, in debugging the memory allocation issue, I found it handy to be able to switch freely between CPU and GPU backend. Hence, I think keeping the backend kwarg in there while also setting a sane default (CPU by default) makes a lot of sense, as it gives both convenience and flexibility.
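As a sketch of what that flexibility buys: compiling the loss with an explicit `backend` keyword (as in the `jit(evotune_loss, backend=backend)` snippet quoted above) lets you flip the average-loss computation between CPU and GPU while debugging. The wrapper below is illustrative; only the `jit` call mirrors code shown in this PR:

```python
from jax import jit


def compile_loss(evotune_loss, backend: str = "cpu"):
    """Compile the loss function for an explicitly chosen backend.

    `backend` is the XLA platform name ("cpu" or "gpu"). Defaulting to
    "cpu" keeps the memory-hungry average-loss pass off the GPU, while
    the weight updates can still be compiled for the GPU separately.
    """
    return jit(evotune_loss, backend=backend)
```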
Just wanted to add on the testing side: I trained on the TEM-1 sequences for 25 epochs on AWS with no memory issues.
@ivanjayapurna thank you for the thorough testing! What was the learning rate and batch size you used here? As for … It could be a good idea to change epoch indexing back to 0, since right now the loss calculations at …
Yes, let's do that. I think I messed up the epoch calculation when I put this PR up. @ElArkk do you have a spare cycle to handle it? (If not, no worries, I can get to this later in the week.)
@ericmjl @ivanjayapurna I did a quick rework of the epoch calculation, let me know if you think it makes sense this way.
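For illustration, a rough sketch of a 0-indexed epoch loop, where the loss reported at epoch 0 is the pre-training loss and epoch i is the loss after i passes over the data. The function and its arguments are placeholders, not the actual reworked `fit` code:

```python
from typing import Callable


def fit(params, xs, ys, n_epochs: int,
        avg_loss: Callable, update_one_epoch: Callable):
    """0-indexed epochs: epoch 0 reports the loss of the untouched
    starting weights; epoch i reports the loss after i full passes."""
    for epoch in range(n_epochs + 1):
        print(f"epoch {epoch}: avg loss = {avg_loss(params, xs, ys):.4f}")
        if epoch < n_epochs:
            params = update_one_epoch(params, xs, ys)
    return params
```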
It works for me. Anything else blocking this PR?
No, I guess this is ready :) I just have one concern still: Let's say someone wants to evotune on 50-100k sequences. If they use a batch size of ~100, loss after each epoch would be calculated on just 0.2% and 0.1% of the whole dataset respectively. Do you think this is enough to estimate overall loss @ericmjl?
Possibly not for only a few epochs, but in the limit of long-run numbers, it should not be too much of an issue. The key show-stopper that I think we should not compromise on is the interactive feel.
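To make the trade-off concrete, here is a sketch of single-batch loss estimation: each epoch draws a fresh random batch (~100 of 50-100k sequences, i.e. 0.1-0.2% of the data), so any one estimate is noisy, but across many epochs the estimates average out. The function name and signature are illustrative, not the jax-unirep API:

```python
import numpy as np


def estimate_avg_loss(loss_fn, params, xs, ys, batch_size: int = 100, rng=None):
    """Estimate the dataset-wide loss from one randomly drawn batch."""
    rng = np.random.default_rng() if rng is None else rng
    idx = rng.choice(len(xs), size=min(batch_size, len(xs)), replace=False)
    return loss_fn(params, xs[idx], ys[idx])
```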
By interactive feel you mean, not having to wait for too long between epochs for the loss calculation? In any case, we shouldn't delay merging the sped-up and more stable evotuning any longer! If we see any problems with avg loss calculation, we can always come back to it. @ericmjl what do you think?
Agreed, hit that button when it’s done! (And go to bed soon, it’s awfully late there for you to be responding! 😸)
Hitting that button after a good night's sleep 😄
NOICE! (Dude, you have no idea - I was knocked out for 5 hours this afternoon. I’m lacking sleep myself haha.)
* Adding pre-commit
* Fixed up GPU memory allocation, and added docstrings.
* Adding a bash script that makes it easy to install JAX on GPU.
  - The script builds a conda environment first.
  - Then it clobbers over with the GPU-based installation based on instructions given by JAX's developers.
* Update fit docstring
* Set backend to "cpu" by default
* Removed parallel kwarg
* Switched back to non-Numba-compatible dictionary definition
* Add pyproject TOML config file
  - Primarily to add black config
* Applied black
* Add flake8 to pre-commit hooks
* Remove flake8 from pre-commit
* Attempting to increase coverage without doing any actual work ^_^
* Add tests for params
  - One unit test
  - One lazy man's execution test
* Update changelog
* Fix test
* Add validate_mLSTM1900_params
  - This can be used as part of the test suite. I should have remembered!
  - Co-authored-by: Arkadij Kummer <43340666+ElArkk@users.noreply.github.com>
* Used validate_mLSTM1900_params as part of test
  - h/t @ElArkk
  - Co-authored-by: Arkadij Kummer <43340666+ElArkk@users.noreply.github.com>
* Fix batch_size in avg_loss function
* Flat is better than nested
  - At least I tried.
* Fix test
* Make format
* Change holdout_seqs to default back to None
* Set sane defaults for mLSTM1900 layer
* Changed to dumping every epoch by default.
* add backend explanation
* add backend to fit example
* change default batching method of fit function to random
* fix epoch calculations
* fix seq length choice for holdout seqs

Co-authored-by: Arkadij Kummer <43340666+ElArkk@users.noreply.github.com>
Co-authored-by: ElArkk <arkadij.kummer@gmail.com>
PR Description
Before you read on, please ignore the branch name. I thought originally that I could use numba to speed up things, but it turned out that once again, with some careful profiling, I found we didn't have to.
This PR does a few things:
In any case, this PR closes #56.
Checklist

General

- [ ] The PR is made from a feature branch (`<your_username>:<feature-branch_name>`), not `<your_username>:master`.
- [ ] The change is described in the `CHANGELOG.md` file at the top.
- [ ] The `README` is updated if needed.

Code checks

- [ ] New code is covered by tests in the `tests` directory.
- [ ] If new dependencies are required, add the packages, pinned to a version, to `environment.yml`.
- [ ] Run `make test` in a console in the top level directory to make sure all the tests pass.
- [ ] Run `make format` in a console in the top level directory to make the code comply with the formatting standards.