## Introduction

In this notebook, we are going to use the `fit` function to train a UniRep model.

## Imports

Here are the imports that we are going to need for the notebook.

In [None]:
from jax.random import PRNGKey

from jax_unirep import fit
from jax_unirep.evotuning_models import mlstm64

## Sequences

We'll prepare a handful of short dummy sequences to train on, plus a separate holdout set.

In your _actual_ use case, you'll probably need to find a way to load your sequences into memory as a **list of strings**. (We try our best to stick with Python idioms.)
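
If your sequences live in a FASTA file, one way to get that list of strings is sketched below. This is only an illustration in plain Python: the helper name `read_fasta_sequences` is hypothetical and is not part of `jax_unirep`, and the sketch assumes a simple FASTA layout (header lines starting with `>`, sequences possibly wrapped over several lines).

```python
from typing import Iterable, List


def read_fasta_sequences(lines: Iterable[str]) -> List[str]:
    """Collect the sequences from FASTA-formatted lines into a list of strings.

    Hypothetical helper for illustration; not part of jax_unirep.
    """
    sequences: List[str] = []
    current: List[str] = []  # chunks of the record currently being read
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith(">"):
            # A header starts a new record, so flush the previous one.
            if current:
                sequences.append("".join(current))
                current = []
        else:
            current.append(line.upper())
    if current:
        sequences.append("".join(current))
    return sequences


# Small in-memory example; the second record is wrapped over two lines.
example = """>seq1
HASTA
>seq2
VIS
TA
""".splitlines()
loaded = read_fasta_sequences(example)
```

Because the function accepts any iterable of lines, you could also pass it an open file handle, for example `read_fasta_sequences(open("my_sequences.fasta"))`, where the file name is a placeholder.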

In [None]:
sequences = ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HAS", "HASVASTA"] * 5
holdout_sequences = [
    "HASTA",
    "VISTA",
    "ALAVA",
    "LIMED",
    "HAST",
    "HASVALTA",
] * 5
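
The holdout set above is written out by hand. With real data you would more likely set aside a random subset. Here is a minimal sketch of a reproducible split using only the standard library. The function name `split_sequences` and its parameters are assumptions made for illustration and are not part of `jax_unirep`.

```python
import random
from typing import List, Sequence, Tuple


def split_sequences(
    seqs: Sequence[str],
    holdout_fraction: float = 0.2,
    seed: int = 42,
) -> Tuple[List[str], List[str]]:
    """Shuffle a copy of `seqs` and split it into (training, holdout) lists.

    Hypothetical helper for illustration; not part of jax_unirep.
    """
    shuffled = list(seqs)
    # A dedicated Random instance keeps the split reproducible
    # without touching the global random state.
    random.Random(seed).shuffle(shuffled)
    n_holdout = max(1, round(len(shuffled) * holdout_fraction))
    return shuffled[n_holdout:], shuffled[:n_holdout]


all_sequences = ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HAS", "HASVASTA"] * 5
train_sequences, heldout_sequences = split_sequences(all_sequences)
```

Note that the dummy list contains repeated sequences, so the two halves will overlap in content. With real data you would probably deduplicate before splitting.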

## Example 1: Default mLSTM model

In this first example, we'll use the default mLSTM1900 model with its shipped weights.

Nothing needs to be passed in except for:

1. The sequences to evotune against, and
2. The number of epochs.

It's the easiest and fastest way to get up and running.

In [None]:
# First way: Use the default mLSTM1900 weights with mLSTM1900 model.

tuned_params = fit(sequences, n_epochs=2)

## Example 2: Pre-built model architectures

The second way is to use one of the pre-built evotuning models.
Pre-trained weights for these models are not shipped in the repo,
because we are assuming that the major use case here
is to train a "local" protein model (on a subset of sequences)
for a particular application.
Rather, we provide the model architecture function
and leverage JAX to provide a convenient way
to reproducibly initialize parameters.

In this example, we'll use the `mlstm64` model.
The `mlstm256` model is also available,
and it might give you better performance
though at the price of longer training time.

In [None]:
init_func, apply_func = mlstm64()

# The init_func always requires a PRNGKey,
# and input_shape should be set to (-1, 10).
# This creates randomly initialized parameters.
_, params = init_func(PRNGKey(42), input_shape=(-1, 10))


# Now we tune the params.
tuned_params = fit(sequences, n_epochs=2, model_func=apply_func, params=params)

## Obviously...

...you would probably swap in/out a different set of sequences and train for longer :).