## Jax-unirep <br> (reimplementation of the UniRep protein featurization model in JAX)

## Fine-tuning for the GPCR family

In [1]:
from jax.random import PRNGKey
from jax.experimental.stax import Dense, Softmax, serial

from jax_unirep import fit
from jax_unirep.evotuning_models import mlstm64
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMHiddenStates



## Sequences (from [InterPro](https://www.ebi.ac.uk/interpro/entry/InterPro/IPR000276/protein/reviewed/#table))

In [2]:
#!/usr/bin/env python3

# standard library modules
import sys, errno, re, json, ssl
from urllib import request
from urllib.error import HTTPError
from time import sleep
# Accumulator filled by `output_list()`: one protein sequence string per kept entry.
sequences = []        # reviewed GPCRs
# First page of the InterPro API query (reviewed proteins for entry IPR000276,
# 200 results per page, with the sequence included as an extra field).
BASE_URL = "https://www.ebi.ac.uk:443/interpro/api/protein/reviewed/entry/InterPro/IPR000276/?page_size=200&extra_fields=sequence"

# NOTE(review): the two constants below are not referenced anywhere in the
# visible part of this notebook — confirm whether they can be removed.
HEADER_SEPARATOR = "|"
LINE_LENGTH = 80

def output_list():
  """Download the sequences behind ``BASE_URL`` page by page.

  Starting at ``BASE_URL``, each JSON payload's ``"next"`` link is followed
  until it is empty. For every result whose metadata name does not contain
  "Olfactory", the sequence is appended to the module-level ``sequences``
  list and written to stdout, one per line.

  HTTP 408 (timeout) is retried after 61 seconds without limit; any other
  ``HTTPError`` is retried up to three times before being re-raised.
  A 204 (no content) response ends the download.

  Raises:
    HTTPError: when a non-408 HTTP error persists after three retries.
  """
  # Disable SSL verification to avoid config issues.
  # NOTE(review): this turns off certificate checking for the request;
  # prefer ssl.create_default_context() where the environment allows it.
  context = ssl._create_unverified_context()

  next_url = BASE_URL
  attempts = 0
  while next_url:
    try:
      req = request.Request(next_url, headers={"Accept": "application/json"})
      # The context manager guarantees the connection is closed on every
      # path out of this block (normal exit, continue, break, exception).
      with request.urlopen(req, context=context) as res:
        # If the API times out due to a long-running query
        if res.status == 408:
          # wait just over a minute, then retry the same URL
          sleep(61)
          continue
        elif res.status == 204:
          # no data, so leave the loop
          break
        payload = json.loads(res.read().decode())
      next_url = payload["next"]
      attempts = 0
    except HTTPError as e:
      if e.code == 408:
        sleep(61)
        continue
      # Any other HTTP error is retried 3 times before failing.
      if attempts < 3:
        attempts += 1
        sleep(61)
        continue
      sys.stderr.write("LAST URL: " + next_url)
      raise

    for item in payload["results"]:
      # Exclude olfactory receptors.
      if item["metadata"]["name"].find("Olfactory") == -1:
        seq = item["extra_fields"]["sequence"]
        sequences.append(seq)
        sys.stdout.write(seq + "\n")

    # Don't overload the server, give it time before asking for more
    if next_url:
      sleep(1)

# Entry point: populate `sequences` (and echo each kept sequence to stdout)
# when this code is executed as the main module.
if __name__ == "__main__":
  output_list()


MELLKLNRSLPGPGPGAALCRPEGPLLNGSGAGNLSCEPPRIRGAGTRELELAVRITLYAAIFLMSVAGNVLIIVVLGLSRRLRTVTNAFLLSLAVSDLLLAVACMPFTLLPNLMGTFIFGTVVCKAVSYFMGVSVSVSTLSLVAIALERYSAICRPLQARVWQTRSHAARVIVATWMLSGLLMVPYPVYTAVQPAGPRVLQCMHRWPSARIRQTWSVLLLLLLFFVPGVVMAVAYGLISRELYLGLRFDGDSDCESQSQVGSQGGLPGGAGQGPAHPNGHCRSETRLAGEDGDGCYVQLPRSRPALEMSALTAPTPGPGSGPRPAQAKLLAKKRVVRMLLVIVVLFFLCWLPVYSANTWRAFDGPGAHRALSGAPISFIHLLSYASACVNPLVYCFMHRRFRQACLDTCARCCPRPPRARPRPLPDEDPPTPSIASLSRLSYTTISTLGPG
MESMPSSLTHQRFGLLNKHLTRTGNTREGRMHTPPVLGFQAIMSNVTVLDNIEPLDFEMDLKTPYPVSFQVSLTGFLMLEIVLGLSSNLTVLALYCMKSNLVSSVSNIVTMNLHVLDVLVCVGCIPLTIVVVLLPLEGNNALICCFHEACVSFASVATAANVLAITLDRYDISVRPANRVLTMGRAVALLGSIWALSFFSFLVPFIEEGFFSQAGNERNQTEAEEPSNEYYTELGLYYHLLAQIPIFFFTAVVMLVTYYKILQALNIRIGTRFHSVPKKKPRKKKTISMTSTQPESTDASQSSAGRNAPLGMRTSVSVIIALRRAVKRHRERRERQKRVFRMSLLIISTFLLCWTPITVLNTVILSVGPSNFTVRLRLGFLVMAYGTTIFHPLLYAFTRQKFQKVLKSKMKKRVVSVVEADPMPNNVVIHNSWIDPKRNKKVTFEETEVRQKCLSSEDVE
MNAMDNMTADYSPDYFDDAVNSSMCEYDEWEPSYSLIPVLYMLIFILGLTGNGVVIFTVWRAQSKRRAADVYIGNLALADLTFVVT

In [3]:
len(sequences)

1882

## Pre-built model architecture

We use the pre-built `mlstm64` model.


In [4]:
# `mlstm64()` returns the (init, apply) function pair for the model.
init_fun, apply_fun = mlstm64()

# `init_fun` always requires a PRNGKey,
# and input_shape should be set to (-1, 26).
# This creates randomly initialized parameters.
# The fixed key (42) keeps the initial parameters the same on every run.
_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))


# Now we tune the params on the downloaded sequences for 2 epochs.
# NOTE(review): later cells index the result as `tuned_params[1]` —
# confirm which layer's parameters that position holds.
tuned_params = fit(sequences, n_epochs=2, model_func=apply_fun, params=params)

HBox(children=(HTML(value='right-padding sequences'), FloatProgress(value=0.0, max=1882.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

INFO:evotuning:Random batching done: All sequences padded to max sequence length of 1256





HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=150.0), HTML(value='')))

INFO:evotuning:Calculations for training set:
INFO:evotuning:Epoch 0: Estimated average loss: 0.16765545308589935. 
INFO:evotuning:Calculations for training set:
INFO:evotuning:Epoch 1: Estimated average loss: 0.16634780168533325. 





In [26]:
tuned_params[1]

{'b': DeviceArray([-3.01743811e-03,  9.55921691e-03, -6.70112204e-03,
              -1.34719093e-03, -7.70191802e-03, -4.02246689e-04,
              -2.07949448e-02, -1.41131552e-02,  9.23951250e-03,
              -1.39025704e-03, -7.83124566e-03,  3.68304714e-03,
               1.35498820e-02, -9.67398472e-03,  6.68355171e-03,
               1.34443054e-02,  1.06319850e-02, -1.51176313e-02,
               8.21476802e-03, -1.47431986e-02, -1.52692874e-03,
              -2.89831543e-03,  7.15851970e-03,  2.21278500e-02,
               2.11675558e-03,  1.23528698e-02,  2.91329785e-03,
              -1.32503696e-02, -2.71233311e-03,  3.19028017e-03,
              -2.35173106e-03, -9.91042610e-03,  1.19520808e-02,
               1.79033098e-03,  1.93504374e-02,  1.36826215e-02,
               1.53878480e-02,  2.99599650e-03,  1.17149763e-02,
               3.78699158e-04, -3.97172477e-03, -4.93948348e-03,
              -1.72145851e-02,  6.77454192e-03, -4.92998865e-03,
               3.304

In [6]:
from functools import partial
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
from jax import vmap

from jax_unirep.errors import SequenceLengthsError
from jax_unirep.utils import (
    batch_sequences,
    get_embeddings,
    load_params,
    validate_mLSTM_params,
)


In [7]:
# Helper: compute representations for a batch of same-length sequences
def rep_same_lengths(
    seqs: Iterable[str], params: Dict, apply_fun: Callable
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute representations for protein sequences that all share one length.

    The sequences are embedded with `get_embeddings`, then `apply_fun`
    (with `params` bound as its first argument) is mapped over the batch
    with `vmap`.

    :param seqs: Sequences of identical length.
    :param params: Parameters passed as the first argument of `apply_fun`.
    :param apply_fun: Function returning `(h_final, c_final, h)` for one
        embedded sequence.
    :returns: `(h_avg, h_final, c_final)` as NumPy arrays, where `h_avg`
        is the mean of `h` over axis 1.
    """
    embeddings = get_embeddings(seqs)

    # Bind the parameters once, then vectorise over the leading (batch) axis.
    batched_apply = vmap(partial(apply_fun, params))
    final_hidden, final_cell, hidden_states = batched_apply(embeddings)

    mean_hidden = hidden_states.mean(axis=1)
    return np.array(mean_hidden), np.array(final_hidden), np.array(final_cell)


In [8]:
def rep_arbitrary_lengths(
    seqs: Iterable[str], params: Dict, apply_fun: Callable, mlstm_size: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute representations for protein sequences of arbitrary length.

    Sequences are grouped by `batch_sequences`, each group is passed to
    `rep_same_lengths`, and the per-group results are written back into
    output arrays at the sequences' original positions.

    :param seqs: Sequences to represent. Any iterable is accepted; it is
        materialised into a list once so it can be indexed and measured.
    :param params: Parameters forwarded to `rep_same_lengths`.
    :param apply_fun: Apply function forwarded to `rep_same_lengths`.
    :param mlstm_size: Width of each representation vector.
    :returns: `(h_avg, h_final, c_final)`, each of shape
        `(len(seqs), mlstm_size)`, in the input order.
    """
    # The annotation promises Iterable, but the code below needs len() and
    # integer indexing; materialise once so generators etc. also work.
    seqs = list(seqs)

    h_avg = np.zeros((len(seqs), mlstm_size))
    h_final = np.zeros((len(seqs), mlstm_size))
    c_final = np.zeros((len(seqs), mlstm_size))
    outputs = (h_avg, h_final, c_final)

    # Each entry of the batching result holds the indexes (into `seqs`)
    # of one group of sequences that are processed together.
    for idxs in batch_sequences(seqs):
        subset = [seqs[i] for i in idxs]
        group_reps = rep_same_lengths(subset, params, apply_fun)

        # Scatter each group's rows back to the original sequence order.
        for out, rep in zip(outputs, group_reps):
            for row, original_idx in enumerate(idxs):
                out[original_idx] = rep[row]

    return h_avg, h_final, c_final


In [9]:
def get_reps(
    seqs: Union[str, Iterable[str]],
    params: Optional[Dict] = None,
    mlstm_size: int = 64,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate mLSTM representations for one or more protein sequences.

    Each element of the output 3-tuple is a `np.array`
    of shape (n_input_sequences, mlstm_size):
    - `h_avg`: Average hidden state of the mLSTM over the whole sequence.
    - `h_final`: Final hidden state of the mLSTM
    - `c_final`: Final cell state of the mLSTM

    :param seqs: A single sequence string or a collection of sequences.
    :param params: mLSTM parameters. When `None` (the default), the
        notebook-level `tuned_params[1]` is looked up at call time, so the
        function always uses the most recently tuned parameters.
    :param mlstm_size: Size of the mLSTM output; must match `params`.
    :raises SequenceLengthsError: if `seqs` is empty.
    """
    _, apply_fun = mLSTM(output_dim=mlstm_size)
    # Resolve the default lazily instead of in the signature: a default of
    # `tuned_params[1]` would be frozen at definition time and would require
    # `tuned_params` to exist before this cell is run.
    if params is None:
        params = tuned_params[1]
    # Check that params have correct keys and shapes
    validate_mLSTM_params(params, n_outputs=mlstm_size)
    # If single string sequence is passed, package it into a list
    if isinstance(seqs, str):
        seqs = [seqs]
    # Make sure list is not empty
    if len(seqs) == 0:
        raise SequenceLengthsError("Cannot pass in empty list of sequences.")

    # Differentiate between two cases:
    # 1. All sequences in the list have the same length
    # 2. There are sequences of different lengths in the list
    if len({len(s) for s in seqs}) == 1:
        return rep_same_lengths(seqs, params, apply_fun)
    return rep_arbitrary_lengths(seqs, params, apply_fun, mlstm_size)

#### I tested on Olfactory receptor 4N4C, since olfactory receptors were not included in the training set

In [13]:
sequence = "MKIANNTVVTEFILLGLTQSQDIQLLVFVLILIFYLIILPGNFLIIFTIRSDPGLTAPLYLFLGNLAFLDASYSFIVAPRMLVDFLSEKKVISYRGCITQLFFLHFLGGGEGLLLVVMAFDRYIAICRPLHCSTVMNPRACYAMMLALWLGGFVHSIIQVVLILRLPFCGPNQLDNFFCDVRQVIKLACTDMFVVELLMVFNSGLMTLLCFLGLLASYAVILCHVRRAASEGKNKAMSTCTTRVIIILLMFGPAIFIYICPFRALPADKMVSLFHTVIFPLMNPMIYTLRNQEVKTSMKRLLSRHVVCQVDFIIRN"
# h_avg is the canonical "reps"
# No `params` argument is passed, so `get_reps` uses `tuned_params[1]`.
h_avg, h_final, c_final = get_reps(sequence)

In [14]:
h_avg.shape

(1, 64)

In [15]:
h_avg

array([[ 5.84673276e-03,  7.40852859e-03,  2.05500424e-03,
        -2.51952559e-03, -8.31959397e-03, -5.05171076e-04,
        -4.78189997e-03, -6.77470304e-03,  4.48500505e-03,
         5.62339416e-03,  1.48779480e-02,  1.06186653e-02,
         1.45087978e-02,  1.51437828e-02, -1.39320502e-03,
        -5.76475309e-03, -6.60806196e-03, -1.32840429e-03,
         1.33311504e-03,  1.81338168e-03, -2.27392204e-02,
         5.71414968e-03,  4.57233022e-04, -4.84840712e-03,
         1.37638850e-02,  4.60113119e-03, -9.59245302e-03,
         3.76701908e-04, -3.23476363e-03, -6.59611906e-05,
        -2.77698021e-02, -5.23280934e-04,  1.29870679e-02,
         3.32557363e-03,  3.82985687e-03, -1.70661544e-03,
        -9.93605983e-03,  5.74699044e-03,  9.19170026e-03,
        -4.03180625e-03,  1.01292115e-02, -1.68675166e-02,
         1.13263931e-02, -2.15364760e-03, -3.10993567e-03,
         1.53990567e-03,  7.65566109e-03, -4.31377487e-03,
         9.25498083e-03, -3.07183582e-02,  2.28487607e-0