Add model weights #106

Merged 9 commits on Oct 2, 2021
Changes from 4 commits
160 changes: 146 additions & 14 deletions docs/fitting.ipynb
@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -29,6 +29,7 @@
"\n",
"from jax_unirep import fit\n",
"from jax_unirep.evotuning_models import mlstm64\n",
"from jax_unirep.utils import load_params\n",
"from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMHiddenStates"
]
},
@@ -45,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -78,9 +79,69 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b7be0780213041a4be8fc5f956fecacc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"right-padding sequences: 0%| | 0/35 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "698160777e9a4d5ea06ec890675b0e75",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Random batching done: All sequences padded to max sequence length of 8\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bef6662c112d4b509a33078ec36a28d3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Calculations for training set:\n",
"INFO:evotuning:Epoch 0: Estimated average loss: 0.20966953039169312. \n",
"INFO:evotuning:Calculations for training set:\n",
"INFO:evotuning:Epoch 1: Estimated average loss: 0.16754050552845. \n"
]
}
],
"source": [
"# First way: Use the default mLSTM1900 weights with mLSTM1900 model.\n",
"\n",
@@ -94,13 +155,8 @@
"## Example 2: Pre-build model architectures\n",
"\n",
"The second way is to use one of the pre-built evotuning models.\n",
"The pre-trained weights are not shipped in the repo,\n",
"because we are assuming that the major use case here\n",
"is to train a \"local\" protein model (on a subset of sequences)\n",
"for a particular application.\n",
"Rather, we provide the model architecture function\n",
"and leverage JAX to provide a convenient way\n",
"to reproducibly initialize parameters.\n",
"The pre-trained weights for the three model architectures from the paper are shipped with the repo (1900, 256, 64).\n",
"You can also leverage JAX to reproducibly initialize random parameters.\n",
"\n",
"In this example, we'll use the `mlstm64` model.\n",
"The `mlstm256` model is also available,\n",
@@ -110,9 +166,82 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6bea4254c8544750bb616f8649c0add7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"right-padding sequences: 0%| | 0/35 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6d4b9a833c74d13b34a89e5182c01e8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Random batching done: All sequences padded to max sequence length of 8\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3742fc89c30b4506971e5bd2ec05ca0c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Calculations for training set:\n",
"INFO:evotuning:Epoch 0: Estimated average loss: 0.18236996233463287. \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"created directory at temp\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Calculations for training set:\n",
"INFO:evotuning:Epoch 1: Estimated average loss: 0.1780482828617096. \n"
]
}
],
"source": [
"init_fun, apply_fun = mlstm64()\n",
"\n",
@@ -121,6 +250,9 @@
"# This creates randomly initialized parameters\n",
"_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))\n",
"\n",
"# Alternatively, you can load the paper weights\n",
"params = load_params(paper_weights=64)\n",
"\n",
"\n",
"# Now we tune the params.\n",
"tuned_params = fit(sequences, n_epochs=2, model_func=apply_fun, params=params)"
@@ -184,7 +316,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.8"
"version": "3.7.10"
}
},
"nbformat": 4,
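Taken together, a minimal sketch of the workflow these notebook changes demonstrate. Everything here is condensed from the cells above; only the toy `sequences` list is a made-up placeholder.

```python
# Minimal sketch condensed from the notebook cells in this diff.
from jax.random import PRNGKey

from jax_unirep import fit
from jax_unirep.evotuning_models import mlstm64
from jax_unirep.utils import load_params

init_fun, apply_fun = mlstm64()

# Option A: reproducibly initialize random parameters with JAX.
_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))

# Option B (new in this PR): start from the paper's mLSTM64 weights.
params = load_params(paper_weights=64)

# Placeholder sequences for illustration; use your own protein strings.
sequences = ["MKLVTAL", "MKLVSAE"]
tuned_params = fit(sequences, n_epochs=2, model_func=apply_fun, params=params)
```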
43 changes: 35 additions & 8 deletions jax_unirep/utils.py
@@ -49,18 +49,25 @@
proposal_valid_letters = "ACDEFGHIKLMNPQRSTVWY"


def get_weights_dir(folderpath: Optional[str] = None):
def get_weights_dir(
    folderpath: Optional[str] = None, paper_weights: Optional[int] = 1900
):
    """
    Fetch model weights.

    If `folderpath` is None, retrieve the mLSTM1900 weights.
    If `folderpath` is None, retrieve the paper weights selected by
    `paper_weights` (mLSTM1900 by default).

    :param folderpath: Path to the folder containing the model weights.
    :param paper_weights: Which paper model architecture to load weights for
        when `folderpath` is None. Possible values are 1900, 256 and 64.
        Defaults to 1900.
    """
    if folderpath:
        return Path(folderpath)
    else:
        return Path(
            pkg_resources.resource_filename(
                "jax_unirep", "weights/1900_weights/uniref50"
                "jax_unirep", f"weights/uniref50/{paper_weights}_weights"
            )
        )

@@ -119,9 +126,20 @@ def aa_seq_to_int(s: str) -> List[int]:
    return [24] + [aa_to_int[a] for a in s] + [25]


def load_embedding(folderpath: Optional[str] = None):
    """Load pre-trained embedding weights for UniRep1900 model."""
    weights_dir = get_weights_dir(folderpath=folderpath)
def load_embedding(
    folderpath: Optional[str] = None, paper_weights: Optional[int] = 1900
):
    """
    Load pre-trained embedding weights for the UniRep paper models.

    :param folderpath: Path to the folder containing the model weights.
    :param paper_weights: Which paper model architecture to load weights for
        when `folderpath` is None. Possible values are 1900, 256 and 64.
        Defaults to 1900.
    """
    weights_dir = get_weights_dir(
        folderpath=folderpath, paper_weights=paper_weights
    )
    with open(weights_dir / "model_weights.pkl", "rb") as f:
        params = pkl.load(f)
    return params[0]
@@ -195,7 +213,9 @@ def validate_mLSTM_params(params: Dict, n_outputs):
)


def load_params(folderpath: Optional[str] = None):
def load_params(
    folderpath: Optional[str] = None, paper_weights: Optional[int] = 1900
):
    """
    Load params for passing to evotuning stax model.

@@ -224,8 +244,15 @@ def load_params(
    The return should be identical to the following:

    MD5 (model_weights.pkl) = 87c89ab62929485e43474c8b24cda5c8

    :param folderpath: Path to the folder containing the model weights.
    :param paper_weights: Which paper model architecture to load weights for
        when `folderpath` is None. Possible values are 1900, 256 and 64.
        Defaults to 1900.
    """
    weights_dir = get_weights_dir(folderpath=folderpath)
    weights_dir = get_weights_dir(
        folderpath=folderpath, paper_weights=paper_weights
    )
    with open(weights_dir / "model_weights.pkl", "rb") as f:
        params = pkl.load(f)
    return params
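For orientation, a short usage sketch of the three helpers gaining the `paper_weights` argument; the custom folder path is a made-up placeholder.

```python
# Usage sketch for the new `paper_weights` argument (valid values: 1900, 256, 64).
from jax_unirep.utils import get_weights_dir, load_embedding, load_params

# With folderpath=None, weights resolve to the packaged
# weights/uniref50/<paper_weights>_weights directory.
print(get_weights_dir(paper_weights=256))

# An explicit folderpath takes precedence over paper_weights.
print(get_weights_dir(folderpath="my/custom/weights"))  # placeholder path

# load_params returns the full parameter set; load_embedding returns params[0].
params_64 = load_params(paper_weights=64)
embedding_64 = load_embedding(paper_weights=64)
```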
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion setup.py
@@ -17,7 +17,7 @@
    version=version,
    packages=["jax_unirep"],
    package_data={
        "jax_unirep": ["weights/1900_weights/uniref50/*.pkl"],
        "jax_unirep": ["weights/uniref50/*/*.pkl"],
    },
    install_requires=[
        "jax",
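A quick, hedged way to sanity-check that the widened glob ships one pickle per architecture in an installed package:

```python
# Sketch: confirm the widened package_data glob ("weights/uniref50/*/*.pkl")
# ships a model_weights.pkl for each of the three architectures.
from pathlib import Path

import pkg_resources

base = Path(pkg_resources.resource_filename("jax_unirep", "weights/uniref50"))
for pkl_path in sorted(base.glob("*/*.pkl")):
    print(pkl_path)  # expect 1900_weights, 256_weights and 64_weights entries
```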
8 changes: 5 additions & 3 deletions tests/test_activations.py
@@ -2,12 +2,14 @@

In particular, we are looking for tests that cause NaN errors in grads.
"""
from hypothesis import strategies as st, given, settings
from jax_unirep.activations import sigmoid
import pytest
import jax.numpy as np
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import grad

from jax_unirep.activations import sigmoid


@pytest.mark.parametrize("version", ["tanh", "exp"])
@given(x=st.floats(allow_nan=False, allow_infinity=False))
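The test body itself is truncated out of this diff, so the following is only a hypothetical reconstruction of the kind of NaN-hunting gradient test these imports support; the `version` keyword on `sigmoid` is inferred from the `parametrize` decorator, and the test name is invented.

```python
# Hypothetical sketch only: the actual test body is not shown in this diff.
import jax.numpy as np
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import grad

from jax_unirep.activations import sigmoid


@pytest.mark.parametrize("version", ["tanh", "exp"])
@given(x=st.floats(allow_nan=False, allow_infinity=False))
@settings(deadline=None)
def test_sigmoid_grad_is_finite(version, x):
    # Gradients of sigmoid should never be NaN, for either implementation.
    dsig = grad(sigmoid)
    assert not np.isnan(dsig(x, version=version))
```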