In [7]:
from typing import Optional
import numpy as np


import jax
from jax import numpy as jnp

from flax import struct
from flax.training import train_state, common_utils
from flax import linen as nn
import optax

import tensorflow as tf

In [2]:
# Seed the PRNG from 0, keep `key` for future splits, and draw a (10, 2)
# standard-normal sample from the fresh subkey `key_`.
key, key_ = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_, (10, 2))
a

DeviceArray([[ 5.1998299e-01, -2.7347109e-01],
             [ 3.1157458e-01,  7.3022044e-01],
             [ 2.0017810e+00,  4.9277100e-01],
             [ 4.5230365e-01,  8.8437307e-01],
             [ 6.8794307e-04, -7.3253053e-01],
             [ 8.3312637e-01, -5.2632433e-01],
             [-3.3986712e-01, -4.2656019e-01],
             [ 6.6787893e-01, -1.1360155e+00],
             [-3.0554804e-01, -1.9193236e+00],
             [-2.1068285e+00,  5.7977307e-01]], dtype=float32)

In [3]:
# Build the initial GRU carry for a batch of 10 with hidden size 20
# (the captured output shows it is all zeros).
# Consume the fresh subkey `key_`; `key` is carried forward only to be split
# again by later cells and must not be consumed here (no PRNG key reuse).
key, key_ = jax.random.split(key)
carry = nn.GRUCell.initialize_carry(key_, (10, ), 20)
carry

DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.],
             [0., 0

In [9]:
# Equal logits => a uniform categorical over 5 classes.
uniform_logits = np.full(5, 0.1)
# Advance the PRNG and sample a (10, 5) matrix of class indices in [0, 5).
key, key_ = jax.random.split(key)
a = jax.random.categorical(key_, uniform_logits, shape=(10, 5))
a

DeviceArray([[1, 0, 2, 2, 3],
             [4, 1, 1, 4, 1],
             [4, 3, 3, 0, 0],
             [0, 4, 1, 3, 3],
             [4, 2, 0, 4, 2],
             [0, 3, 0, 1, 2],
             [4, 4, 1, 1, 0],
             [3, 0, 0, 3, 2],
             [4, 3, 0, 2, 3],
             [1, 3, 4, 2, 3]], dtype=int32)

In [11]:
class GRU(nn.Module):
    """Unrolled single-layer GRU encoder that returns the final hidden state.

    Attributes:
        batch_size: leading dimension of the carry; must equal ``inputs.shape[0]``.
        hidden_size: size of the GRU hidden state.
    """
    batch_size : int
    hidden_size: int

    @nn.compact
    def __call__(
        self,
        key,
        inputs # (batch_size, seq_len, embed_dim)
    ):
        """Run one shared GRU cell over every time step of ``inputs``.

        Args:
            key: PRNG key forwarded to ``nn.GRUCell.initialize_carry``.
            inputs: array indexed as ``inputs[:, t, :]``; assumed to be
                (batch_size, seq_len, embed_dim).

        Returns:
            The carry after the last time step, shape (batch_size, hidden_size).
            If ``seq_len`` is 0 the initial carry is returned unchanged.
        """
        carry = nn.GRUCell.initialize_carry(key, (self.batch_size, ), self.hidden_size)

        # Create the cell ONCE and reuse it: in a compact module every
        # `nn.GRUCell()` construction registers a new submodule with its own
        # parameters, so building it inside the loop would give each time step
        # unshared weights instead of a true recurrent network.
        cell = nn.GRUCell()
        for t in range(inputs.shape[1]):
            _, carry = cell(carry, inputs[:, t, :])
        return carry


@struct.dataclass
class BCQConfig:
    """Frozen (flax ``struct``) hyper-parameter bundle, presumably for a BCQ QNet.

    NOTE(review): not referenced in the visible cells, and its field names do
    not line up one-to-one with QNet's (``hidden_dim``, ``dropout_rate``) —
    confirm the intended use before relying on it.
    """
    hidden_size: int  # recurrent hidden-state size (assumed to map to QNet.hidden_dim — TODO confirm)
    output_size: int  # number of discrete outputs (assumed: actions)
    embed_dim: int    # token-embedding width

class QNet(nn.Module):
    """Sequence model with two heads over a shared GRU encoding.

    Embeds an integer token sequence, encodes it with ``GRU``, applies dropout,
    and returns a softmax "behavior" distribution and raw Q-values, each of
    width ``output_size``.
    """
    embed_dim: int       # width of each token embedding fed to the GRU
    output_size: int     # width of both output heads; vocabulary is output_size + 1
    hidden_dim: int      # GRU hidden-state size
    dropout_rate: float  # dropout rate applied to the final GRU state

    @staticmethod
    def initialize_carry(key:jax.random.PRNGKey, batch_size:int, hidden_size:int):
        """Return an initial GRU carry of shape (batch_size, hidden_size)."""
        return nn.GRUCell.initialize_carry(key, (batch_size, ), hidden_size)


    @nn.compact
    def __call__(
        self,
        key:jax.random.PRNGKey,
        inputs:jnp.ndarray,     # (batch_size, seq_len)
        training: Optional[bool]=True
    ):
        """Encode ``inputs`` and return ``(behavior, qvalue)``.

        Args:
            key: PRNG key forwarded to the GRU's carry initialisation.
            inputs: integer token ids in ``[0, output_size]``, shape
                (batch_size, seq_len).
            training: when truthy, dropout is active (requires a ``"dropout"``
                rng); when falsy, dropout is deterministic.

        Returns:
            behavior: softmax over ``output_size`` classes, (batch_size, output_size).
            qvalue: unnormalised scores, (batch_size, output_size).
        """
        # nn.Embed(num_embeddings, features): vocabulary size first, embedding
        # width second. The vocabulary has output_size + 1 ids (assumed: the
        # actions plus one extra id such as padding/start — TODO confirm), and
        # each id maps to an embed_dim-wide vector, matching GRU's expected
        # (batch, seq_len, embed_dim) input.
        x = nn.Embed(num_embeddings=self.output_size + 1, features=self.embed_dim)(inputs)
        x = GRU(inputs.shape[0], self.hidden_dim)(key, x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not training)

        behavior = nn.Dense(self.output_size)(x)
        behavior = nn.softmax(behavior)

        qvalue = nn.Dense(self.output_size)(x)
        return behavior, qvalue
    


In [16]:
# Initialise QNet parameters using a dummy (batch=10, seq_len=5) int32 batch.
key, key1, key2, key3 = jax.random.split(key, 4)
model = QNet(embed_dim=10, output_size=5, hidden_dim=10, dropout_rate=0.5)
qnet_init_rngs = {"params": key1, "dropout": key3}
qnet_dummy_inputs = jnp.ones((10, 5), dtype=jnp.int32)
qnet_variables = model.init(qnet_init_rngs, key2, qnet_dummy_inputs)
params = qnet_variables["params"]