In [None]:
%load_ext autoreload 
%autoreload 2

# Gradients w.r.t. input

For input design, 
i.e. designing a protein sequence to maximize (or minimize) 
the output of a neural network,
one strategy is to perform gradient ascent/descent on the inputs.
To do so, we need to take the gradient of the neural network output w.r.t. the input.
At its core, this is taking the Jacobian of the neural network;
put simply, the Jacobian is a generalized form of the first derivative,
while the Hessian, correspondingly, is the second derivative.

In [None]:
from jax_unirep.evotuning_models import mlstm256
from jax.random import PRNGKey
from jax_unirep.utils import seq_to_oh

# `mlstm256()` returns a pair, unpacked here as an initialiser and a model function.
init_func, model_func = mlstm256()
# Initialise the parameters with a fixed PRNG key (42) so the run is repeatable;
# the first element of the returned pair is discarded.
# NOTE(review): 26 is presumably the width of the one-hot alphabet — confirm.
# NOTE(review): `params` is re-bound by a later cell, and `model_func` is not
# used anywhere else in the visible notebook.
_, params = init_func(PRNGKey(42), input_shape=(-1, 26))


## Soft Label

To begin, we need soft labels rather than one-hot encodings.

In [None]:
import jax.numpy as np 
from jax import vmap

def normalize_probability(v):
    """Rescale `v` so that its entries sum to 1.

    :param v: Array of values to normalise.
    :returns: `v` divided by the sum of all of its entries.
    """
    total = np.sum(v)
    return v / total

def soft_label(v: np.ndarray, delta: float = 1e-2) -> np.ndarray:
    """Turn a vector `v` into a soft label.

    `delta` is added to every entry of `v`, and the result is rescaled with
    `normalize_probability` so that it sums to 1.

    `v` is assumed to be a one-hot encoding vector.

    :param v: Vector to soften.
    :param delta: Constant added to every entry before renormalising.
    :returns: The shifted and renormalised vector.
    """
    smoothed = v + delta
    return normalize_probability(smoothed)

# Batched `soft_label`: vmap maps it over the leading axis of its input
# (axis 0 in, axis 0 out — the vmap defaults, spelled out here).
soft_label_protein = vmap(soft_label, in_axes=0, out_axes=0)

In [None]:
import jax.numpy as np

# Convert each example sequence with `seq_to_oh`, then pass the result through
# `soft_label_protein` (the vmapped `soft_label`).
sequences = [
    soft_label_protein(seq_to_oh(sequence))
    for sequence in ("HAPPYNEWYEAK", "HAPPYNEWYEAR")
]

In [None]:
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMHiddenStates, mLSTMAvgHidden
from jax.example_libraries.stax import Dense, Softmax, serial, Relu

# Layers of the model, in the order they are applied by `serial` below.
# NOTE(review): the behaviour of the jax_unirep layers (AAEmbedding, mLSTM,
# mLSTMHiddenStates, mLSTMAvgHidden) is not visible here — confirm against the library.
model_layers = (
    AAEmbedding(20),
    mLSTM(512),
    mLSTMHiddenStates(),
    mLSTM(512),
    mLSTMAvgHidden(),
    Dense(1),  # dense layer with a single output unit
    # NOTE(review): a final ReLU clamps negative pre-activations to zero, in which
    # case the gradient w.r.t. the input is zero as well — confirm this is intended.
    Relu,
)

# `serial` chains the layers into one (init_fun, apply_fun) pair.
init_fun, apply_fun = serial(*model_layers)

# Initialise parameters with a fixed PRNG key (42). This re-binds `params`,
# replacing the value assigned from `mlstm256` earlier in the notebook.
# NOTE(review): 26 is presumably the width of the one-hot alphabet — confirm.
_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))


In [None]:
from jax import jacfwd
# Sanity check: display the shape of the model output for the first sequence.
apply_fun(params, sequences[0]).shape

In [None]:
from functools import partial 

# Jacobian of the model output w.r.t. the input: `partial` fixes `params`, so
# `jacfwd` (forward-mode) differentiates only with respect to the remaining
# positional argument, here the first soft-labelled sequence.
jacfwd(partial(apply_fun, params))(sequences[0])