

Reimplementation of the UniRep protein featurization model in JAX.

The UniRep model was developed in George Church's lab; see the original publication (bioRxiv preprint or Nature Methods paper), as well as the repository containing the original model.

The idea of reimplementing the TensorFlow-based model in the much lighter JAX framework was proposed by Eric Ma, who also developed a first version of it inside his functional deep-learning library, fundl.

This repo is a self-contained version of the UniRep model (so far only the 1900-hidden-unit mLSTM), adapted and extended from fundl.


Installation

Ensure that your compute environment can run JAX code. (A modern Linux with GLIBC >= 2.23, or a recent macOS, is probably necessary.)

For now, jax-unirep is available by pip installing from source.

Installation from GitHub:

pip install git+


Usage

To generate representations of protein sequences, pass either a single sequence string or a list of sequence strings to jax_unirep.get_reps. It returns a tuple of three arrays, holding the following representations for each sequence:

  • h_avg: Average hidden state of the mLSTM over the whole sequence.
  • h_final: Final hidden state of the mLSTM.
  • c_final: Final cell state of the mLSTM.

Following the original paper, h_avg is considered the "representation" (or "rep") of the protein sequence.
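To make the three outputs concrete, here is a minimal NumPy sketch of a multiplicative LSTM (mLSTM) run over a sequence of already-embedded residues, returning the average hidden state, the final hidden state and the final cell state. It is an illustration under stated assumptions, not jax_unirep's code: the parameter names (wmx, wmh, wx, wm, b), the gate order, the hypothetical init_params helper with random weights, and the omission of UniRep's residue embedding layer and weight normalization are all simplifications chosen for this example. The real model uses 1900 hidden units.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))


def init_params(rng: np.random.Generator, d_in: int, hidden_size: int, scale: float = 0.1) -> dict:
    """Hypothetical helper: small random weights for a simplified mLSTM cell."""
    H = hidden_size
    return {
        "wmx": scale * rng.normal(size=(d_in, H)),   # input -> multiplicative state
        "wmh": scale * rng.normal(size=(H, H)),      # hidden -> multiplicative state
        "wx": scale * rng.normal(size=(d_in, 4 * H)),  # input -> gates
        "wm": scale * rng.normal(size=(H, 4 * H)),   # multiplicative state -> gates
        "b": np.zeros(4 * H),                        # gate biases
    }


def mlstm_step(params: dict, h: np.ndarray, c: np.ndarray, x: np.ndarray):
    """One simplified mLSTM step (assumed gate order: i, f, o, u)."""
    # Multiplicative intermediate state: elementwise product of the projected
    # input and the projected previous hidden state.
    m = (x @ params["wmx"]) * (h @ params["wmh"])
    # Pre-activations for all four gates at once, then split into equal parts.
    z = x @ params["wx"] + m @ params["wm"] + params["b"]
    i, f, o, u = np.split(z, 4)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    u = np.tanh(u)
    c = f * c + i * u      # new cell state
    h = o * np.tanh(c)     # new hidden state
    return h, c


def run_mlstm(params: dict, xs: np.ndarray, hidden_size: int):
    """Run the cell over xs of shape (seq_len, d_in); return (h_avg, h_final, c_final)."""
    h = np.zeros(hidden_size)
    c = np.zeros(hidden_size)
    hidden_states = []
    for x in xs:
        h, c = mlstm_step(params, h, c, x)
        hidden_states.append(h)
    # Averaging over positions gives a fixed-size vector for any sequence length.
    h_avg = np.stack(hidden_states).mean(axis=0)
    return h_avg, h, c
```

Averaging the hidden states is what lets sequences of different lengths map to vectors of the same size. In a JAX version, the Python loop would typically be replaced by jax.lax.scan so that the recurrence can be compiled.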

Only amino acid letters belonging to the following set are allowed as inputs to get_reps:



Sequences may be passed in as a single string or an iterable of strings, and they need not be of the same length.
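Because get_reps accepts either a single string or an iterable of strings, and only letters from the allowed set are valid, it can help to normalize and check inputs before calling it. The helper below is a hypothetical sketch, not part of jax_unirep: normalize_sequences and its allowed parameter are illustrative names, and you would pass the letter set given above as allowed.

```python
from typing import Iterable, List, Union


def normalize_sequences(sequences: Union[str, Iterable[str]], allowed: str) -> List[str]:
    """Hypothetical helper: wrap a single sequence in a list and validate letters.

    `allowed` should hold the permitted letters (assumption: supplied by the caller).
    """
    # A bare string is one sequence, not an iterable of one-letter sequences.
    if isinstance(sequences, str):
        seqs = [sequences]
    else:
        seqs = list(sequences)
    allowed_set = set(allowed)
    for idx, seq in enumerate(seqs):
        bad = sorted(set(seq) - allowed_set)
        if bad:
            raise ValueError(f"sequence {idx} contains invalid letters: {''.join(bad)}")
    return seqs
```

Checking up front gives an error message that names the offending sequence and letters, rather than relying on whatever error the model raises internally.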

In Python code, for a single sequence:

from jax_unirep import get_reps

sequence = "ASDFGHJKL"

# h_avg is the canonical "reps"
h_avg, h_final, c_final = get_reps(sequence)

And for multiple sequences:

from jax_unirep import get_reps

sequences = ["ASDF", "YJKAL", "QQLAMEHALQP"]

# h_avg is the canonical "reps"
h_avg, h_final, c_final = get_reps(sequences)

# each of the arrays will be of shape (len(sequences), 1900),
# with the correct order of sequences preserved
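Sequences of different lengths can still come back in their input order. One way to achieve this is to group sequences of equal length so each group can be stacked into a single array and run through the recurrence without padding, then write each group's rows back to their original positions. The sketch below shows that idea under assumptions; it is not necessarily what get_reps does internally. embed_in_order, group_by_length and the embed_batch callback are hypothetical names, and any embed_batch you pass in stands in for the real model.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Sequence

import numpy as np


def group_by_length(sequences: Sequence[str]) -> Dict[int, List[int]]:
    """Map each sequence length to the indices of the sequences with that length."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for idx, seq in enumerate(sequences):
        groups[len(seq)].append(idx)
    return dict(groups)


def embed_in_order(
    sequences: Sequence[str],
    embed_batch: Callable[[List[str]], np.ndarray],
    hidden_size: int,
) -> np.ndarray:
    """Embed equal-length batches, then place each row back at its input index.

    `embed_batch` (hypothetical) takes same-length sequences and returns an
    array of shape (len(batch), hidden_size).
    """
    out = np.zeros((len(sequences), hidden_size))
    for indices in group_by_length(sequences).values():
        batch = [sequences[i] for i in indices]
        out[indices] = embed_batch(batch)  # scatter rows to original positions
    return out
```

With the real model, hidden_size would be 1900, giving the (len(sequences), 1900) shape described in the comment above.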

More Details

We wrote up how we reimplemented the model in JAX; the write-up is available in both HTML and PDF.


License

All the model weights are licensed under the terms of Creative Commons Attribution-NonCommercial 4.0 International License. To view a copy of this license, visit or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

Apart from the model weights, the code in this repository is licensed under the terms of GPL v3.
