Skip to content

Commit

Permalink
Changed return order of reps (#17)
Browse files Browse the repository at this point in the history
* Changed return order of reps

- Changed the return order of representations in `featurize.py`
to match the original implementation.
Intentionally left the return order
in `layers.py` as is, because
of the `lax.scan` API.

* Updated readme
  • Loading branch information
ElArkk committed Feb 12, 2020
1 parent d305762 commit 58d88bc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
6 changes: 3 additions & 3 deletions README.md
Expand Up @@ -35,9 +35,9 @@ To generate representations of protein sequences,
pass a list of sequences as strings or a single sequence to `jax_unirep.get_reps`.
It will return a tuple consisting of the following representations for each sequence:

- `h_avg`: Average hidden state of the mLSTM over the whole sequence.
- `h_final`: Final hidden state of the mLSTM.
- `c_final`: Final cell state of the mLSTM.
- `h_avg`: Average hidden state of the mLSTM over the whole sequence.

From the original paper, `h_avg` is considered the "representation" (or "rep") of the protein sequence.

Expand All @@ -57,7 +57,7 @@ from jax_unirep import get_reps
sequence = "ASDFGHJKL"

# h_avg is the canonical "reps"
h_final, c_final, h_avg = get_reps(sequence)
h_avg, h_final, c_final = get_reps(sequence)
```

And for multiple sequences:
Expand All @@ -68,7 +68,7 @@ from jax_unirep import get_reps
sequences = ["ASDF", "YJKAL", "QQLAMEHALQP"]

# h_avg is the canonical "reps"
h_final, c_final, h_avg = get_reps(sequences)
h_avg, h_final, c_final = get_reps(sequences)

# each of the arrays will be of shape (len(sequences), 1900),
# with the correct order of sequences preserved
Expand Down
22 changes: 11 additions & 11 deletions jax_unirep/featurize.py
Expand Up @@ -25,7 +25,7 @@ def rep_same_lengths(
h_final, c_final, h = mlstm1900(params, embedded_seqs)
h_avg = h.mean(axis=1)

return np.array(h_final), np.array(c_final), np.array(h_avg)
return np.array(h_avg), np.array(h_final), np.array(c_final)


def rep_arbitrary_lengths(
Expand All @@ -43,30 +43,30 @@ def rep_arbitrary_lengths(
"""
order = batch_sequences(seqs)
# TODO: Find a better way to do this, without code triplication
hf_list, cf_list, ha_list = [], [], []
ha_list, hf_list, cf_list = [], [], []
# Each list in `order` contains the indexes of all sequences of a
# given length from the original list of sequences.
for idxs in order:
subset = [seqs[i] for i in idxs]

h_final, c_final, h_avg = rep_same_lengths(subset)
h_avg, h_final, c_final = rep_same_lengths(subset)
ha_list.append(h_avg)
hf_list.append(h_final)
cf_list.append(c_final)
ha_list.append(h_avg)

h_final, c_final, h_avg = (
h_avg, h_final, c_final = (
np.zeros((len(seqs), 1900)),
np.zeros((len(seqs), 1900)),
np.zeros((len(seqs), 1900)),
)
# Re-order generated reps to match sequence order in the original list.
for i, subset in enumerate(order):
for j, rep in enumerate(subset):
h_avg[rep] = ha_list[i][j]
h_final[rep] = hf_list[i][j]
c_final[rep] = cf_list[i][j]
h_avg[rep] = ha_list[i][j]

return h_final, c_final, h_avg
return h_avg, h_final, c_final


def get_reps(
Expand Down Expand Up @@ -105,8 +105,8 @@ def get_reps(
# 1. All sequences in the list have the same length
# 2. There are sequences of different lengths in the list
if len(set([len(s) for s in seqs])) == 1:
h_final, c_final, h_avg = rep_same_lengths(seqs)
return h_final, c_final, h_avg
h_avg, h_final, c_final = rep_same_lengths(seqs)
return h_avg, h_final, c_final
else:
h_final, c_final, h_avg = rep_arbitrary_lengths(seqs)
return h_final, c_final, h_avg
h_avg, h_final, c_final = rep_arbitrary_lengths(seqs)
return h_avg, h_final, c_final

0 comments on commit 58d88bc

Please sign in to comment.