From 58d88bc52f910f3af847f3380a8c7346474f5b13 Mon Sep 17 00:00:00 2001 From: Arkadij Kummer <43340666+ElArkk@users.noreply.github.com> Date: Wed, 12 Feb 2020 21:19:48 +0100 Subject: [PATCH] Changed return order of reps (#17) * Changed return order of reps - Changed the return order of representations in `featurize.py` to match the original implementation. Intentionally left the return order in `layers.py` as is, because of the `lax.scan` API. * Updated readme --- README.md | 6 +++--- jax_unirep/featurize.py | 22 +++++++++++----------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 900ddfee..60cafae1 100644 --- a/README.md +++ b/README.md @@ -35,9 +35,9 @@ To generate representations of protein sequences, pass a list of sequences as strings or a single sequence to `jax_unirep.get_reps`. It will return a tuple consisting of the following representations for each sequence: +- `h_avg`: Average hidden state of the mLSTM over the whole sequence. - `h_final`: Final hidden state of the mLSTM - `c_final`: Final cell state of the mLSTM -- `h_avg`: Average hidden state of the mLSTM over the whole sequence. From the original paper, `h_avg` is considered the "representation" (or "rep") of the protein sequence. 
@@ -57,7 +57,7 @@ from jax_unirep import get_reps sequence = "ASDFGHJKL" # h_avg is the canonical "reps" -h_final, c_final, h_avg = get_reps(sequence) +h_avg, h_final, c_final = get_reps(sequence) ``` And for multiple sequences: @@ -68,7 +68,7 @@ from jax_unirep import get_reps sequences = ["ASDF", "YJKAL", "QQLAMEHALQP"] # h_avg is the canonical "reps" -h_final, c_final, h_avg = get_reps(sequences) +h_avg, h_final, c_final= get_reps(sequences) # each of the arrays will be of shape (len(sequences), 1900), # with the correct order of sequences preserved diff --git a/jax_unirep/featurize.py b/jax_unirep/featurize.py index 0a6357d1..4ee3c159 100644 --- a/jax_unirep/featurize.py +++ b/jax_unirep/featurize.py @@ -25,7 +25,7 @@ def rep_same_lengths( h_final, c_final, h = mlstm1900(params, embedded_seqs) h_avg = h.mean(axis=1) - return np.array(h_final), np.array(c_final), np.array(h_avg) + return np.array(h_avg), np.array(h_final), np.array(c_final) def rep_arbitrary_lengths( @@ -43,18 +43,18 @@ def rep_arbitrary_lengths( """ order = batch_sequences(seqs) # TODO: Find a better way to do this, without code triplication - hf_list, cf_list, ha_list = [], [], [] + ha_list, hf_list, cf_list = [], [], [] # Each list in `order` contains the indexes of all sequences of a # given length from the original list of sequences. for idxs in order: subset = [seqs[i] for i in idxs] - h_final, c_final, h_avg = rep_same_lengths(subset) + h_avg, h_final, c_final = rep_same_lengths(subset) + ha_list.append(h_avg) hf_list.append(h_final) cf_list.append(c_final) - ha_list.append(h_avg) - h_final, c_final, h_avg = ( + h_avg, h_final, c_final = ( np.zeros((len(seqs), 1900)), np.zeros((len(seqs), 1900)), np.zeros((len(seqs), 1900)), @@ -62,11 +62,11 @@ def rep_arbitrary_lengths( # Re-order generated reps to match sequence order in the original list. 
for i, subset in enumerate(order): for j, rep in enumerate(subset): + h_avg[rep] = ha_list[i][j] h_final[rep] = hf_list[i][j] c_final[rep] = cf_list[i][j] - h_avg[rep] = ha_list[i][j] - return h_final, c_final, h_avg + return h_avg, h_final, c_final def get_reps( @@ -105,8 +105,8 @@ def get_reps( # 1. All sequences in the list have the same length # 2. There are sequences of different lengths in the list if len(set([len(s) for s in seqs])) == 1: - h_final, c_final, h_avg = rep_same_lengths(seqs) - return h_final, c_final, h_avg + h_avg, h_final, c_final = rep_same_lengths(seqs) + return h_avg, h_final, c_final else: - h_final, c_final, h_avg = rep_arbitrary_lengths(seqs) - return h_final, c_final, h_avg + h_avg, h_final, c_final = rep_arbitrary_lengths(seqs) + return h_avg, h_final, c_final