Add model weights #106

Merged
merged 9 commits into from
Oct 2, 2021


ElArkk
Owner

@ElArkk ElArkk commented Mar 9, 2021

Small PR to add the pickled versions of the original 256 and 64 model weights. This adds another 10 MB to the package. @ericmjl, I think that is OK given that the 1900 weights we already ship are around 70 MB anyway?

This doesn't change the fact that the 1900 weights are loaded by default, but one can now use load_params to get the original 256 and 64 weights. Because those models are stacked mLSTMs, loading the weights manually from the .npy files is a bit of a pain.

@ElArkk ElArkk requested a review from ericmjl March 9, 2021 10:31
@ElArkk ElArkk linked an issue Mar 9, 2021 that may be closed by this pull request
@codecov

codecov bot commented Mar 9, 2021

Codecov Report

Merging #106 (91b114b) into master (4cbf902) will increase coverage by 0.38%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master     #106      +/-   ##
==========================================
+ Coverage   97.51%   97.90%   +0.38%     
==========================================
  Files          12       12              
  Lines         524      524              
==========================================
+ Hits          511      513       +2     
+ Misses         13       11       -2     
Flag      Coverage Δ
python    97.90% <100.00%> (?)


Impacted Files           Coverage Δ
jax_unirep/utils.py      97.67% <100.00%> (ø)
jax_unirep/sampler.py    98.52% <0.00%> (+2.94%) ⬆️


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 12209f5...91b114b.

@ElArkk
Owner Author

ElArkk commented Mar 9, 2021

I just realised that we need to decide on whether we want to just add the weights to the repo, or also ship them with the package. If we go for the latter, we would need to adapt the API a bit to make the new 256 and 64 weights loadable, probably by a keyword arg in load_params?
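A minimal sketch of what such a keyword argument could look like, assuming one pickled file per model size. The function name load_params_sketch, the paper_weights argument, and the file names are all hypothetical and are not taken from the jax_unirep code base:

```python
import pickle
from pathlib import Path

# Hypothetical mapping from model size to a pickled weights file name.
WEIGHT_FILES = {
    1900: "params_1900.pkl",
    256: "params_256.pkl",
    64: "params_64.pkl",
}


def load_params_sketch(weights_dir, paper_weights=1900):
    """Load pickled parameters for the requested model size.

    `paper_weights` defaults to 1900, so callers that pass nothing
    keep getting the weights that are loaded by default today.
    """
    if paper_weights not in WEIGHT_FILES:
        raise ValueError(
            f"paper_weights must be one of {sorted(WEIGHT_FILES)}, "
            f"got {paper_weights!r}"
        )
    path = Path(weights_dir) / WEIGHT_FILES[paper_weights]
    with path.open("rb") as f:
        return pickle.load(f)
```

Keeping 1900 as the default means the change is backwards compatible, and rejecting unknown sizes early gives a clearer error than a missing-file exception would.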

@ericmjl
Collaborator

ericmjl commented May 9, 2021

@ElArkk, let's get back to this PR once you're done with your thesis work :). I'm going to leave it open.

@ElArkk ElArkk linked an issue Jun 8, 2021 that may be closed by this pull request
@ElArkk
Owner Author

ElArkk commented Oct 1, 2021

I think we can ship the 256 and 64 weights with the package as well; those 10 MB more don't make much of a difference to the package size anymore. Do you agree, @ericmjl? I made them loadable through a kwarg in load_params and added them to the package data.

I wanted to migrate all weight files to Git LFS in this PR (I'm following this guide: https://notiz.dev/blog/migrate-git-repo-to-git-lfs). But doing so rewrites the branch history, and I would need to force-push the migration. I think it's better to do this in a separate PR in case stuff breaks.

Collaborator

@ericmjl ericmjl left a comment


A few small changes suggested, @ElArkk. Would you also like to add in tests to make sure that the weights are loaded correctly and that we can do a forward pass of protein sequences through them? I think once that is done we can call this a reliably written model.
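One way such a loading test could be structured is sketched below: check that the matrices of each layer have the expected shapes and chain together. The parameter layout (a list of per-layer dicts with keys "w_x" and "w_h", stored as nested lists) is an assumption made for illustration and is not the actual jax_unirep parameter structure:

```python
def shape_of(x):
    """Return the shape of a rectangular nested list as a tuple of ints."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        if not x:
            break
        x = x[0]
    return tuple(shape)


def check_layers(layers, input_size, hidden_size):
    """Assert that each layer's matrices have the expected shapes.

    `layers` is a list of dicts with hypothetical keys "w_x"
    (input-to-hidden) and "w_h" (hidden-to-hidden). The first layer
    consumes `input_size` features; every later layer consumes the
    previous layer's hidden state of size `hidden_size`.
    """
    expected_in = input_size
    for i, layer in enumerate(layers):
        assert shape_of(layer["w_x"]) == (expected_in, hidden_size), (
            f"layer {i}: unexpected w_x shape {shape_of(layer['w_x'])}"
        )
        assert shape_of(layer["w_h"]) == (hidden_size, hidden_size), (
            f"layer {i}: unexpected w_h shape {shape_of(layer['w_h'])}"
        )
        expected_in = hidden_size
```

A shape check like this catches a weights file that was assembled in the wrong layer order before any forward pass is attempted.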

Review comments (outdated, resolved): docs/fitting.ipynb (1), jax_unirep/utils.py (2)
ElArkk and others added 4 commits October 2, 2021 01:46
Co-authored-by: Eric Ma <ericmjl@users.noreply.github.com>
Co-authored-by: Eric Ma <ericmjl@users.noreply.github.com>
@ElArkk
Owner Author

ElArkk commented Oct 1, 2021

Yes, tests are on the to-do list! If we also want to let users easily rep protein sequences using the paper's 256 and 64 models, we'd need to rewrite the get_reps function a bit, since those two models are stacked LSTMs (right now only single-layer LSTMs of variable size can be used with get_reps).
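To illustrate why a stacked model needs a different code path than a single-layer one, here is a minimal sketch in which each layer's full output sequence becomes the next layer's input. For brevity it uses a plain tanh recurrent cell rather than an mLSTM, and the parameter keys ("w_x", "w_h", "b") are hypothetical:

```python
import math


def rnn_layer(params, inputs):
    """Run a simple tanh recurrent layer over a sequence.

    params: dict with "w_x" (input x hidden), "w_h" (hidden x hidden)
            and "b" (hidden), stored as nested lists.
    inputs: list of input vectors, one per time step.
    Returns the list of hidden-state vectors, one per time step.
    """
    hidden_size = len(params["b"])
    h = [0.0] * hidden_size  # initial hidden state
    outputs = []
    for x in inputs:
        # The comprehension reads the previous `h` before it is rebound.
        h = [
            math.tanh(
                sum(x[i] * params["w_x"][i][j] for i in range(len(x)))
                + sum(h[k] * params["w_h"][k][j] for k in range(hidden_size))
                + params["b"][j]
            )
            for j in range(hidden_size)
        ]
        outputs.append(h)
    return outputs


def stacked_forward(layer_params, inputs):
    """Feed the output sequence of each layer into the next layer."""
    seq = inputs
    for params in layer_params:
        seq = rnn_layer(params, seq)
    return seq
```

A single-layer model is the special case where layer_params has one entry, so a function written this way would cover both the single-layer and the stacked architectures.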

@ericmjl ericmjl merged commit 9d9e771 into master Oct 2, 2021
@ericmjl
Collaborator

ericmjl commented Oct 2, 2021

Merged!

@ElArkk ElArkk deleted the add-model-weights branch October 4, 2021 11:32
ElArkk added a commit that referenced this pull request Oct 14, 2021
* add 256 and 64 weights

* make format

* make 256 and 64 paper weights loadable

* add paper weight loading of different architectures to example

* Update jax_unirep/utils.py

Co-authored-by: Eric Ma <ericmjl@users.noreply.github.com>

* Update jax_unirep/utils.py

Co-authored-by: Eric Ma <ericmjl@users.noreply.github.com>

* clear outputs

* add tests for 256 and 64 paper weights

Co-authored-by: Eric Ma <ericmjl@users.noreply.github.com>
Development

Successfully merging this pull request may close these issues.

Question about weights and using different architectures
64 weight file formatting
2 participants