In [27]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.recurrent import GRUCell

class Encoder(nn.Module):
    input_vocab_size: int
    embedding_dim: int
    enc_units: int

    def setup(self):
        # The embedding layer converts tokens to vectors
        self.embedding = nn.Embed(num_embeddings=self.input_vocab_size, features=self.embedding_dim)

        # The GRU RNN layer processes those vectors sequentially.
        self.gru = nn.scan(GRUCell,
                           variable_broadcast="params",
                           in_axes=0,
                           out_axes=0,
                           split_rngs={'params': False})(features=self.embedding_dim)

    def __call__(self, tokens, initial_state=None):
        # 2. The embedding layer looks up the embedding for each token.
        vectors = self.embedding(tokens)

        # 3. The GRU processes the embedding sequence.
        if initial_state is None:
            initial_state = jnp.zeros((tokens.shape[0], self.enc_units))

        def scan_fn(carry, x):
            state, _ = carry
            new_state, y = self.gru(state, x)
            return (new_state, y), y

        (final_state, output), outputs = jax.lax.scan(scan_fn, (initial_state, None), vectors)

        # 4. Returns the new sequence and its state.
        return output, final_state

# Example usage
input_vocab_size = 10000
embedding_dim = 256
enc_units = 512

encoder = Encoder(input_vocab_size=input_vocab_size, embedding_dim=embedding_dim, enc_units=enc_units)

# Initialize parameters
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
tokens = jax.random.randint(key1, (32, 10), 0, input_vocab_size)  # batch_size=32, seq_length=10
initial_state = jnp.zeros((32, enc_units))

params = encoder.init(key2, tokens, initial_state)['params']
output, state = encoder.apply({'params': params}, tokens, initial_state)

print("Output shape:", output.shape)
print("State shape:", state.shape)


TypeError: scan body function carry input and carry output must have the same pytree structure, but they differ:

The input carry component carry[1] is a <class 'NoneType'> but the corresponding component of the carry output is a <class 'jax._src.core.ShapedArray'>, so their Python types differ.

Revise the function so that the carry output has the same pytree structure as the carry input.