Skip to content

Commit

Permalink
adapt LSTMCell usage to new RNNCellBase Upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
Lookatator committed Dec 11, 2023
1 parent d5aa679 commit a0307a3
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions qdax/core/neuroevolution/networks/seq2seq_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __call__(
) -> Tuple[Tuple[Array, Array], Array]:
"""Applies the module."""
lstm_state, is_eos = carry
new_lstm_state, y = nn.LSTMCell()(lstm_state, x)
features = lstm_state[0].shape[-1]
new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x)

def select_carried_state(new_state: Array, old_state: Array) -> Array:
return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
Expand All @@ -51,8 +52,8 @@ def select_carried_state(new_state: Array, old_state: Array) -> Array:
@staticmethod
def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]:
# Use a dummy key since the default state init fn is just zeros.
return nn.LSTMCell.initialize_carry( # type: ignore
jax.random.PRNGKey(0), (batch_size,), hidden_size
return nn.LSTMCell(hidden_size, parent=None).initialize_carry( # type: ignore
jax.random.PRNGKey(0), (batch_size, hidden_size)
)


Expand Down Expand Up @@ -101,7 +102,10 @@ def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
lstm_state, last_prediction = carry
if not self.teacher_force:
x = last_prediction
lstm_state, y = nn.LSTMCell()(lstm_state, x)

features = lstm_state[0].shape[-1]
new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x)

logits = nn.Dense(features=self.obs_size)(y)

return (lstm_state, logits), (logits, logits)
Expand Down

0 comments on commit a0307a3

Please sign in to comment.