## Jax-unirep <br> (reimplementation of the UniRep protein featurization model in JAX)

## Fine tuning for GPCR family

In [1]:
from jax.random import PRNGKey
from jax.experimental.stax import Dense, Softmax, serial

from jax_unirep import fit
from jax_unirep.evotuning_models import mlstm64
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMHiddenStates



In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set()

In [3]:
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import classification_report
from sklearn.linear_model import LogisticRegressionCV
import warnings
from warnings import simplefilter
from sklearn.exceptions import ConvergenceWarning
simplefilter(action='ignore', category=ConvergenceWarning)
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.ensemble import RandomForestRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.svm import SVR

## Sequences (from [InterPro](https://www.ebi.ac.uk/interpro/entry/InterPro/IPR000276/protein/reviewed/#table))

In [10]:
#!/usr/bin/env python3

# standard library modules
import sys, errno, re, json, ssl
from urllib import request
from urllib.error import HTTPError
from time import sleep
sequences = []        # reviewed GPCRs
sys.stdout = open('stdoutf.csv', 'a')
sys.stdout.write("name" + "," + "class"  + "," + "organism" + ","  + "sequence" + "\n")
def output_list(BASE_URL, cl):
  #disable SSL verification to avoid config issues
  context = ssl._create_unverified_context()

  next = BASE_URL
  last_page = False

  
  attempts = 0
  while next:
    try:
      req = request.Request(next, headers={"Accept": "application/json"})
      res = request.urlopen(req, context=context)
      # If the API times out due a long running query
      if res.status == 408:
        # wait just over a minute
        sleep(61)
        # then continue this loop with the same URL
        continue
      elif res.status == 204:
        #no data so leave loop
        break
      payload = json.loads(res.read().decode())
      next = payload["next"]
      attempts = 0
      if not next:
        last_page = True
    except HTTPError as e:
      if e.code == 408:
        sleep(61)
        continue
      else:
        # If there is a different HTTP error, it wil re-try 3 times before failing
        if attempts < 3:
          attempts += 1
          sleep(61)
          continue
        else:
          sys.stderr.write("LAST URL: " + next)
          raise e
    for i, item in enumerate(payload["results"]):
         seq = item["extra_fields"]["sequence"]
         if item["metadata"]["name"].find("Olfactory") == -1: # exclude Olfactory 
            sys.stdout.write(item["metadata"]["name"] + "," + str(cl) + "," + item["metadata"]["source_organism"]["scientificName"] + ","  + seq + "\n")
      # Don't overload the server, give it time before asking for more
    if next:
      sleep(1)


#### Links for different classes:

In [11]:
urls = []
urls.append("https://www.ebi.ac.uk:443/interpro/api/protein/reviewed/entry/InterPro/IPR000276/?page_size=200&extra_fields=sequence")
urls.append("https://www.ebi.ac.uk:443/interpro/api/protein/reviewed/entry/InterPro/IPR000832/?page_size=200&extra_fields=sequence")
urls.append("https://www.ebi.ac.uk:443/interpro/api/protein/reviewed/entry/InterPro/IPR000337/?page_size=200&extra_fields=sequence")
urls.append("https://www.ebi.ac.uk:443/interpro/api/protein/reviewed/entry/cdd/CD14939/?page_size=200&extra_fields=sequence")
urls.append("https://www.ebi.ac.uk:443/interpro/api/protein/reviewed/entry/InterPro/IPR015526/?page_size=200&extra_fields=sequence")
urls.append("https://www.ebi.ac.uk:443/interpro/api/protein/reviewed/entry/InterPro/IPR007960/?page_size=200&extra_fields=sequence")
urls.append("https://www.ebi.ac.uk:443/interpro/api/protein/reviewed/entry/InterPro/IPR009637/?page_size=200&extra_fields=sequence")
urls.append("https://www.ebi.ac.uk:443/interpro/api/protein/reviewed/entry/InterPro/IPR029723/?page_size=200&extra_fields=sequence")
urls.append("https://www.ebi.ac.uk:443/interpro/api/protein/reviewed/entry/InterPro/IPR018781/?page_size=200&extra_fields=sequence")
urls.append("https://www.ebi.ac.uk:443/interpro/api/protein/reviewed/entry/InterPro/IPR001414/?page_size=200&extra_fields=sequence")


In [12]:
for i, url in enumerate(urls):
    if i < 6:
        output_list(url, i + 1)
    else:
        output_list(url, 7)

### Adding TMcore

In data folder:

- human_msa.fasta msa of the cut sequences from gpcrdb.org
- msa.fasta msa of all the cut sequences from Interpro
- stdoutf.csv 

In [5]:
from Bio import SeqIO
fields = []
df = pd.read_csv('stdoutf.csv')
for record in SeqIO.parse("msa.fasta", "fasta"):
       fields.append(str(record.seq))
df["TMcore"]  = fields

In [6]:
df.head()

Unnamed: 0,name,class,organism,sequence,TMcore
0,Gastrin/cholecystokinin type B receptor,1,Sus scrofa,MELLKLNRSLPGPGPGAALCRPEGPLLNGSGAGNLSCEPPRIRGAG...,MNLSCEPPRIRGAGTRELELAVRITLYAA-IFLMSVAGNVLIIVVL...
1,G-protein coupled receptor 22,1,Danio rerio,MESMPSSLTHQRFGLLNKHLTRTGNTREGRMHTPPVLGFQAIMSNV...,MEEPLDFEMDLKTPYPVSFQVSLTGFLML-EIVLGLSSNLTVLALY...
2,Apelin receptor B,1,Danio rerio,MNAMDNMTADYSPDYFDDAVNSSMCEYDEWEPSYSLIPVLYMLIFI...,MNADAVNSSMCEYDEWEPSYSLIPVLYML-IFILGLTGNGVVIFTV...
3,N-arachidonyl glycine receptor,1,Rattus norvegicus,MAIPSNRDQLALSNGSHPEEYKIAALVFYSCIFLIGLLVNVTALWV...,MASNRDQLALSNGSHPEEYKIAALVFYSC-IFLIGLLVNVTALWVF...
4,Neuropeptide CCHamide-1 receptor,1,Drosophila melanogaster,MIANLVSMETDLAMNIGLDTSGEAPTALPPMPNVTETLWDLAMVVS...,MSELVTTETPYVPYGRRPETYIVPILFAL-IFVVGVLGNGTLIVVF...


## Pre-build model architecture 

 `mlstm64` model.


In [7]:
init_fun, apply_fun = mlstm64()

# The init_func always requires a PRNGKey,
# and input_shape should be set to (-1, 26)
# This creates randomly initialized parameters
_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))


# Now we tune the params.
tuned_params = fit(df["TMcore"], n_epochs=2, model_func=apply_fun, params=params)

HBox(children=(HTML(value='right-padding sequences'), FloatProgress(value=0.0, max=2484.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

INFO:evotuning:Random batching done: All sequences padded to max sequence length of 932





HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=198.0), HTML(value='')))

INFO:evotuning:Calculations for training set:
INFO:evotuning:Epoch 0: Estimated average loss: 0.16762538254261017. 


created directory at temp


INFO:evotuning:Calculations for training set:
INFO:evotuning:Epoch 1: Estimated average loss: 0.16596537828445435. 





#### I've also tried representations by mlstm256 (256-dimensional vectors) and in fact I found no big difference according to the classification results (see results below)

In [8]:
from functools import partial
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
from jax import vmap

from jax_unirep.errors import SequenceLengthsError
from jax_unirep.utils import (
    batch_sequences,
    get_embeddings,
    load_params,
    validate_mLSTM_params,
)


In [9]:
# instantiate the mLSTM
def rep_same_lengths(
    seqs: Iterable[str], params: Dict, apply_fun: Callable
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    This function generates representations of protein sequences that have the same length
    """

    embedded_seqs = get_embeddings(seqs)

    h_final, c_final, h = vmap(partial(apply_fun, params))(embedded_seqs)
    h_avg = h.mean(axis=1)

    return np.array(h_avg), np.array(h_final), np.array(c_final)


In [10]:
def rep_arbitrary_lengths(
    seqs: Iterable[str], params: Dict, apply_fun: Callable, mlstm_size: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    This function generates representations of protein sequences of arbitrary length
    """
    order = batch_sequences(seqs)
    # TODO: Find a better way to do this, without code triplication
    ha_list, hf_list, cf_list = [], [], []
    # Each list in `order` contains the indexes of all sequences of a
    # given length from the original list of sequences.
    for idxs in order:
        subset = [seqs[i] for i in idxs]

        h_avg, h_final, c_final = rep_same_lengths(subset, params, apply_fun)
        ha_list.append(h_avg)
        hf_list.append(h_final)
        cf_list.append(c_final)

    h_avg, h_final, c_final = (
        np.zeros((len(seqs), mlstm_size)),
        np.zeros((len(seqs), mlstm_size)),
        np.zeros((len(seqs), mlstm_size)),
    )
    # Re-order generated reps to match sequence order in the original list.
    for i, subset in enumerate(order):
        for j, rep in enumerate(subset):
            h_avg[rep] = ha_list[i][j]
            h_final[rep] = hf_list[i][j]
            c_final[rep] = cf_list[i][j]

    return h_avg, h_final, c_final


In [11]:
def get_reps(
    seqs: Union[str, Iterable[str]],
    params: Optional[Dict] = tuned_params[1],
    mlstm_size: Optional[str] = 64,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Each element of the output 3-tuple is a `np.array`
    of shape (n_input_sequences, mlstm_size):
    - `h_avg`: Average hidden state of the mLSTM over the whole sequence.
    - `h_final`: Final hidden state of the mLSTM
    - `c_final`: Final cell state of the mLSTM
"""
    _, apply_fun = mLSTM(output_dim=mlstm_size)
    if params is None:
        params = tuned_params[1]
    # Check that params have correct keys and shapes
    validate_mLSTM_params(params, n_outputs=mlstm_size)
    # If single string sequence is passed, package it into a list
    if isinstance(seqs, str):
        seqs = [seqs]
    # Make sure list is not empty
    if len(seqs) == 0:
        raise SequenceLengthsError("Cannot pass in empty list of sequences.")

    # Differentiate between two cases:
    # 1. All sequences in the list have the same length
    # 2. There are sequences of different lengths in the list
    if len(set([len(s) for s in seqs])) == 1:
        h_avg, h_final, c_final = rep_same_lengths(
            seqs,
            params,
            apply_fun,
        )
        return h_avg, h_final, c_final
    else:
        h_avg, h_final, c_final = rep_arbitrary_lengths(
            seqs, params, apply_fun, mlstm_size
        )
        return h_avg, h_final, c_final

### Representations for GPCRs TM cores

In [12]:
fields1 = []
for seq in df["TMcore"]:
    fields1.append(get_reps(seq)[0])
df["embs"]  = fields1

In [13]:
df["embs"]

0       [[0.008311076, 0.002365658, -0.0029429803, -0....
1       [[0.008278446, 0.0025310859, -0.0031384525, -0...
2       [[0.008318783, 0.002004366, -0.0031882334, -0....
3       [[0.008310798, 0.002047191, -0.0032696477, -0....
4       [[0.008311579, 0.002151715, -0.0031045913, -0....
                              ...                        
2479    [[0.008304645, 0.0021230436, -0.0033034435, -0...
2480    [[0.0083246315, 0.0013236668, -0.0035572413, -...
2481    [[0.008303336, 0.0021339504, -0.0032998412, -0...
2482    [[0.00831899, 0.0019768146, -0.0031942346, -0....
2483    [[0.008321734, 0.0020802803, -0.0030481943, -0...
Name: embs, Length: 2484, dtype: object

### Classifications

In [14]:
import argparse
import sys
import os
import gzip
from collections import Counter

import numpy as np
from scipy.spatial.distance import cosine
from Bio import SeqIO

from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss
from sklearn import preprocessing
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, average_precision_score, coverage_error
from sklearn.model_selection import cross_val_score


In [15]:
X = []
for seq in df['embs']:
    X.append(seq[0])

### Predict the class of GPCR

In [16]:
cls = []
for cl in df['class']:
     cls.append(cl)

In [17]:
from collections import Counter
print(Counter(cls).keys()) 
Counter(cls).values() 

dict_keys([1, 2, 3, 4, 5, 6, 7])


dict_values([1882, 191, 73, 4, 81, 214, 39])

In [41]:
X = np.asarray(X)

In [42]:
vectors_train, vectors_test, cls_train, cls_test = train_test_split(X, 
                                                                                  cls, 
                                                                                  test_size=0.1, random_state=42)

In [52]:
import xgboost
xgb = xgboost.XGBClassifier(objective='multi:softprob')


In [44]:
xgb.fit(vectors_train, cls_train)
test_preds = xgb.predict(vectors_test)





In [45]:
accuracy_score(cls_test, test_preds)

0.9518072289156626

In [35]:
kn = KNeighborsClassifier()
param_grid = {'n_neighbors': np.arange(3, 25)}
models = {"lr": LogisticRegressionCV(cv=5, random_state=0, max_iter=500,solver='lbfgs', multi_class='ovr'),
          "kn": GridSearchCV(kn, param_grid, cv=5)}
for model in models:
    models[model].fit(vectors_train, cls_train)
    y_predicted = models[model].predict(vectors_test)
    print('Accuracy is {}'.format(accuracy_score(cls_test, y_predicted)))
    print(np.unique(y_predicted))
models["kn"].best_params_



Accuracy is 0.7670682730923695
[1 3]




Accuracy is 0.9437751004016064
[1 2 3 5 6 7]


{'n_neighbors': 3}

### Predict the source_organism

In [46]:
orgs = []
for org in df['organism']:
     orgs.append(org)

label_encoder = preprocessing.LabelEncoder()
label_encoder.fit(orgs)      
orgs_encoded = np.array(label_encoder.transform(orgs), dtype=np.int32)

In [47]:
print(Counter(orgs).keys())
Counter(orgs).values()

dict_keys(['Sus scrofa', 'Danio rerio', 'Rattus norvegicus', 'Drosophila melanogaster', 'Mus musculus', 'Oryctolagus cuniculus', 'Mustela putorius furo', 'Bos taurus', 'Homo sapiens', 'Xenopus tropicalis', 'Micropogonias undulatus', 'Salmo salar', 'Branchiostoma floridae', 'Macaca fascicularis', 'Xenopus laevis', 'Caenorhabditis elegans', 'Human cytomegalovirus (strain Merlin)', 'Macaca mulatta', 'Chilo suppressalis', 'Tribolium castaneum', 'Lymnaea stagnalis', 'Manduca sexta', 'Canis lupus familiaris', 'Callithrix jacchus', 'Felis catus', 'Cavia porcellus', 'Rat cytomegalovirus (strain Maastricht)', 'Gallus gallus', 'Saimiri boliviensis boliviensis', 'Conger conger', 'Mizuhopecten yessoensis', 'Sepia officinalis', 'Lacunicambarus ludovicianus', 'Cambarellus shufeldtii', 'Orconectes virilis', 'Procambarus milleri', 'Cambarus hubrichti', 'Cambarus maculatus', 'Orconectes australis', 'Procambarus orcinus', 'Procambarus seminolae', 'Equus caballus', 'Odocoileus virginianus virginianus', '

dict_values([61, 50, 308, 49, 417, 33, 3, 106, 426, 8, 1, 2, 1, 26, 47, 39, 3, 56, 1, 1, 4, 4, 53, 5, 13, 41, 1, 40, 3, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 1, 1, 12, 1, 1, 4, 1, 1, 27, 1, 1, 8, 1, 10, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 3, 3, 1, 1, 1, 1, 1, 1, 1, 7, 1, 5, 1, 1, 1, 3, 1, 2, 35, 1, 1, 1, 1, 7, 1, 3, 2, 5, 1, 6, 1, 2, 1, 5, 1, 1, 2, 10, 1, 2, 2, 1, 1, 1, 3, 1, 1, 2, 2, 1, 1, 1, 2, 1, 2, 39, 64, 1, 1, 5, 1, 1, 1, 1, 25, 3, 1, 1, 2, 1, 1, 31, 1, 12, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 7, 1, 1, 1, 4, 2, 1, 4, 1, 1, 1, 1, 2, 2, 2, 2, 4, 3, 2, 1, 1, 1, 1, 1, 8, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 3, 1, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 5, 3, 1, 1, 3, 1, 

In [48]:
vectors_train, vectors_test, orgs_train, orgs_test = train_test_split(X, 
                                                                                  orgs_encoded, 
                                                                                  test_size=0.1, random_state=42)

In [49]:
kn = KNeighborsClassifier()
param_grid = {'n_neighbors': np.arange(2, 30)}
models = {"lr": LogisticRegressionCV(cv=5, random_state=0, max_iter=500,solver='lbfgs', multi_class='ovr'),
          "kn": GridSearchCV(kn, param_grid, cv=5)}
for model in models:
    models[model].fit(vectors_train, orgs_train)
    y_predicted = models[model].predict(vectors_test)
    print('Accuracy is {}'.format(accuracy_score(orgs_test, y_predicted)))
    print(np.unique(y_predicted))



Accuracy is 0.18072289156626506
[128]




Accuracy is 0.14457831325301204
[  2   3  12  14  16  17  22  25  31  36  51  57  90  97 119 128 167 194
 219 248 287 316]


In [50]:
models["kn"].best_params_

{'n_neighbors': 28}

In [53]:
xgb.fit(vectors_train, orgs_train)
test_preds = xgb.predict(vectors_test)
accuracy_score(orgs_test, test_preds)





0.08032128514056225

### 256-dimensional representations results

#### Predict the class:

Accuracy is 0.7951685958731757
[1 3 6]

Accuracy is 0.954906391545043
[1 2 3 4 5 6 7]
{'n_neighbors': 3}

#### Predict the source organism:

Accuracy is 0.17262204328132863
[128]

Accuracy is 0.21389028686462003
[  2   3   4   5  14  16  17  22  31  36  41  48  54  56  58  67  72  73
  90  97 119 124 128 167 181 194 219 220 223 234 248 287 316]
{'n_neighbors': 29}