diff --git a/tests/test_utils.py b/tests/test_utils.py index e013e65..344303a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,7 +12,7 @@ from jax import vmap from jax.random import PRNGKey, normal -from jax_unirep.evotuning_models import mlstm64, mlstm1900 +from jax_unirep.evotuning_models import mlstm64, mlstm256, mlstm1900 from jax_unirep.utils import ( aa_seq_to_int, batch_sequences, @@ -27,7 +27,6 @@ one_hots, right_pad, seq_to_oh, - validate_mLSTM_params, ) @@ -105,13 +104,21 @@ def test_load_embedding(): assert emb.shape == (26, 10) -def test_load_params(): +@pytest.mark.parametrize( + "size, model", + [ + (64, mlstm64), + (256, mlstm256), + (1900, mlstm1900), + ], +) +def test_load_params(size, model): """ - Make sure that all parameters needed for the evotuning stax model + Make sure that all parameters needed for the mlstm stax models get loaded with the correct shapes. """ - _, apply_fun = mlstm1900() - params = load_params() + _, apply_fun = model() + params = load_params(paper_weights=size) validate_params(model_func=apply_fun, params=params)