In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from jax_unirep.utils import get_embeddings
from jax_unirep.params import add_dense_params
from jax_unirep.layers import mlstm1900, dense
from jax_unirep.activations import softmax, identity
from jax_unirep.utils import aa_seq_to_int, aa_to_int, load_params_1900

import numpy as np

from functools import partial
from jax import grad, jit
from jax.experimental.optimizers import adam
from typing import List, Dict
from jax_unirep.utils import load_embeddings
from typing import Tuple



In [8]:
oh_arrs = np.eye(len(aa_to_int))
one_hots = {aa_to_int[k]: oh_arrs[i] for i, k in enumerate(aa_to_int.keys())}

In [44]:
def evotuning_pairs(s: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Given a sequence, return input-output pairs for evotuning.
    
    The goal of evotuning is to get the RNN to accurately predict
    the next character in a sequence.
    This convenience function exists to prep a single sequence
    into its corresponding input-output tensor pairs.
    
    Given a 1D sequence of length `k`,
    it gets represented as a 2D array of shape (k, 10),
    where 10 is the size of the embedding of each amino acid,
    and k-1 ranges from the zeroth a.a. to the nth a.a.
    This is the first element in the returned tuple.
    
    Given the same 1D sequence,
    the output is defined as a 2D array of shape (k-1, 28),
    where 28 is the number of possible characters
    present in the ``aa_to_int`` dictionary keys,
    and k-1 corresponds to the first a.a. to the nth a.a.
    This is the second element in the returned tuple.
    
    :param s: The protein sequence to featurize.
    :returns: Two 2D NumPy arrays,
        the first corresponding to the input to evotuning with shape (n_letters, 10),
        and the second corresponding to the output amino acid to predict with shape (n_letters, 28).
    """
    seq_int = aa_seq_to_int(s[:-1])
    next_letters_int = aa_seq_to_int(s[1:])
    embeddings = load_embeddings()
    x = np.stack([embeddings[i] for i in seq_int])
    y = np.stack([one_hots[i] for i in next_letters_int])
    return x, y


def input_output_pairs(sequences: List[str]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate input-output tensor pairs for evo-tuning.
    
    We check that lengths of sequences are identical,
    as this is necessary to ensure stacking of tensors happens correctly.
    
    :param sequences: A list of sequences 
        to generate input-output tensor pairs.
    :returns: Two NumPy arrays,
        the first corresponding to the input to evotuning 
        with shape (n_sequences, n_letters, 10),
        and the second corresponding to the output amino acids to predict 
        with shape (n_sequences, n_letters, 28).
        Both will have an additional "sample" dimension as the first dim.
    """
    seqlengths = set(map(len, sequences))
    if not len(seqlengths) == 1:
        raise ValueError("""
Sequences should be of uniform length, but are not. 
Please ensure that they are all of the same length before passing them in.
""")

    xs = []
    ys = []
    for s in sequences:
        x, y = evotuning_pairs(s)
        xs.append(x)
        ys.append(y)
    return np.stack(xs), np.stack(ys)

In [45]:
sequences = [
    "HASTA",
    "HASTH",
    "VISTA",
]

x, y = input_output_pairs(sequences)

In [46]:
x.shape, y.shape

((3, 6, 10), (3, 6, 28))

In [56]:
params = dict()
params["mlstm1900"] = load_params_1900()
params = add_dense_params(params, "dense", 1900, 28)

In [57]:
def predict(params, x) -> np.ndarray:
    """
    Prediction model for evotuning.
    
    Architecture is a single softmax layer on top of the RNN.
    
    :param params: Dictionary of parameters.
        Should have keys ``mlstm1900`` and ``dense`` in there.
    :param x: Input tensor.
        Should be the result of calling ``input_output_pairs``,
        and be of shape (n_sequences, n_letters, 10).
    :returns: Prediction tensor, of shape (n_sequences, n_letters, 28).
    """
    # Defensive programming checks.
    if not len(x.shape) == 3:
        raise ValueError("Input tensor should be 3-dimensional.")
    if not x.shape[-1] == 10:
        raise ValueError("Input tensor's 3rd dimension should be of length 10.")

    # Actual forward model happens here.
    _, _, x = mlstm1900(params["mlstm1900"], x)
    x = dense(params["dense"], x, activation=softmax)
    return x

In [58]:
from jax_unirep.losses import neg_cross_entropy_loss, mseloss, _neg_cross_entropy_loss


In [59]:
def evotune_loop(params: Dict[str, Dict[str, np.ndarray]], x: np.ndarray, y: np.ndarray, n: int, verbose=False):
    """
    Master function for tuning.

    :param x: Input tensor.
    :param y: Output tensor to train against.
    :param n: Number of epochs (iterations) to train model for.
    """
    # `predict` must be defined in the same source file as this function.
    loss = partial(neg_cross_entropy_loss, model=predict)
    dloss = jit(grad(loss))

    init, update, get_params = adam(step_size=0.005)

    state = init(params)
    
    for i in range(20):
        l = loss(params, x=x, y=y)
        if np.isnan(l):
            break
        if verbose:
            print(f"Iteration: {i}, Loss: {l:.4f}")

        g = dloss(params, x=x, y=y)

        state = update(i, g, state)
        params = get_params(state)
    return params

In [60]:
tuned_params = evotune_loop(params, x, y, 20, verbose=True)

Iteration: 0, Loss: 0.1536
Iteration: 1, Loss: 0.0788
Iteration: 2, Loss: 0.0362
Iteration: 3, Loss: 0.0135
Iteration: 4, Loss: 0.0076
Iteration: 5, Loss: 0.0081
Iteration: 6, Loss: 0.0083
Iteration: 7, Loss: 0.0058
Iteration: 8, Loss: 0.0061
Iteration: 9, Loss: 0.0056
Iteration: 10, Loss: 0.0056
Iteration: 11, Loss: 0.0057
Iteration: 12, Loss: 0.0056
Iteration: 13, Loss: 0.0056
Iteration: 14, Loss: 0.0055
Iteration: 15, Loss: 0.0055
Iteration: 16, Loss: 0.0055
Iteration: 17, Loss: 0.0055
Iteration: 18, Loss: 0.0055
Iteration: 19, Loss: 0.0055


## DEBUGGING/DIAGNOSIS

In [63]:
y_hat = predict(tuned_params, x)

In [64]:
np.round(y_hat)

array([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

       [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        

In [67]:
def arr_to_letter(arr) -> str:
    """
    Convert a 1D one-hot array into a letter.
    
    TODO: More docstrings needed.
    """
    for k, v in one_hots.items():
        if np.allclose(arr, v):
            break
    for key, val in aa_to_int.items():
        if k == val:    
            return key

def letter_seq(arr: np.array) -> str:
    """
    Convert a 2D one-hot array into a string representation.
    
    TODO: More docstrings needed.
    """
    sequence = ""
    for letter in arr:
        sequence += arr_to_letter(np.round(letter))
    return sequence

In [70]:
sequences = []
for seq in y_hat:
    sequence = letter_seq(seq)
    print(sequence)
    sequences.append(sequence)

startASTHstop
startASTHstop
startISTAstop


# DUMP

In [None]:
a;iofle;odvh;ilkjewklfdsshj

In [None]:
params["dense"]["b"].shape

In [36]:
batch.shape

(3, 239, 1900)

In [37]:
def softmax_noaxis(x):
    print(x.shape)
    return np.exp(x) / np.sum(np.exp(x))

def softmax_newaxis(x):
    e = np.exp(batch)
    return np.divide(e, np.sum(e, axis=-1)[:, :, None])

In [39]:
_, _, batch = mlstm1900(params["mlstm1900"], x)
batch = dense(params["dense"], batch, activation=identity)

sfm_noaxis = softmax_noaxis(batch)
sfm_newaxis = softmax_newaxis(batch)

(3, 239, 28)


In [None]:
original = np.exp() / np.sum(np.exp(x))

In [32]:
batch.shape

(3, 239, 28)