Skip to content

Commit

Permalink
add tests for 256 and 64 paper weights
Browse files Browse the repository at this point in the history
  • Loading branch information
ElArkk committed Oct 2, 2021
1 parent 012f9f7 commit 91b114b
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from jax import vmap
from jax.random import PRNGKey, normal

from jax_unirep.evotuning_models import mlstm64, mlstm1900
from jax_unirep.evotuning_models import mlstm64, mlstm256, mlstm1900
from jax_unirep.utils import (
aa_seq_to_int,
batch_sequences,
Expand All @@ -27,7 +27,6 @@
one_hots,
right_pad,
seq_to_oh,
validate_mLSTM_params,
)


Expand Down Expand Up @@ -105,13 +104,21 @@ def test_load_embedding():
assert emb.shape == (26, 10)


def test_load_params():
@pytest.mark.parametrize(
"size, model",
[
(64, mlstm64),
(256, mlstm256),
(1900, mlstm1900),
],
)
def test_load_params(size, model):
"""
Make sure that all parameters needed for the evotuning stax model
Make sure that all parameters needed for the mlstm stax models
get loaded with the correct shapes.
"""
_, apply_fun = mlstm1900()
params = load_params()
_, apply_fun = model()
params = load_params(paper_weights=size)
validate_params(model_func=apply_fun, params=params)


Expand Down

0 comments on commit 91b114b

Please sign in to comment.