In [4]:
import jax
import jax.numpy as jnp
from jax import grad, jit
from flax import linen as nn
import collections

# NetworkType = collections.namedtuple('network', ['q_values', 'representation'])


class AtariDQNNetwork(nn.Module):
  """The convolutional network used to compute the agent's Q-values."""
  num_actions: int

  @nn.compact
  def __call__(self, x):
    initializer = nn.initializers.xavier_uniform()
    x = x.astype(jnp.float32) / 255.
    before_relu1 = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=initializer)(x)
    relu1 = nn.relu(before_relu1)
    self.sow('intermediates', 'relu1', relu1)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=initializer)(relu1)
    x = nn.relu(x)
    self.sow('intermediates', 'relu2', x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    representation = x.reshape(-1)  # flatten
    before_last_relu = nn.Dense(features=512, kernel_init=initializer)(representation)
    last_relu = nn.relu(before_last_relu)
    q_values = nn.Dense(features=self.num_actions,
                        kernel_init=initializer)(last_relu)
    return q_values, representation
# Example usage
# Initialize the model
# Assuming input shape is (batch_size, height, width, channels)
input_shape = (2, 84, 84, 4)
model = AtariDQNNetwork(num_actions=10)

rng = jax.random.PRNGKey(0)
# Create a random input tensor
x = jax.random.normal(rng, input_shape)
# Initialize the model parameters
params = model.init(rng, x=x)

intermediates_func = lambda _, name: 'relu' in name

output, variables = model.apply(params, x, mutable=['intermediates'])

# after_relu1, after_last_relu, before_relu1, before_last_relu, q_values = output

# assert after_last_relu == variables['intermediates']['Dense_0']['__call__'][0]


In [7]:
output[1].shape

(15488,)

In [74]:
dict_ = {}
for k, v in dict_.items():
  
  pass

In [66]:
variables['intermediates']['relu1'][0]

Array([[[[2.8390093e-03, 1.3419555e-03, 1.5232379e-03, ...,
          0.0000000e+00, 0.0000000e+00, 7.9099555e-05],
         [9.6281350e-04, 0.0000000e+00, 1.0437160e-03, ...,
          1.2910143e-03, 1.6324209e-03, 0.0000000e+00],
         [0.0000000e+00, 5.2411872e-04, 6.4213009e-04, ...,
          1.5542484e-03, 5.6491315e-04, 0.0000000e+00],
         ...,
         [1.7350663e-03, 0.0000000e+00, 0.0000000e+00, ...,
          0.0000000e+00, 0.0000000e+00, 1.0979945e-03],
         [7.0092370e-05, 9.8694570e-04, 8.3725603e-04, ...,
          3.4312049e-03, 2.7857597e-03, 0.0000000e+00],
         [0.0000000e+00, 7.2109204e-04, 5.1860255e-04, ...,
          0.0000000e+00, 0.0000000e+00, 9.6743827e-04]],

        [[0.0000000e+00, 0.0000000e+00, 4.3021061e-04, ...,
          4.2238974e-04, 0.0000000e+00, 8.1538083e-04],
         [0.0000000e+00, 0.0000000e+00, 3.1893165e-03, ...,
          2.4073306e-03, 1.3981439e-03, 0.0000000e+00],
         [0.0000000e+00, 5.0572320e-03, 2.1074673e-03, .

In [42]:
# check if the whole after_relu1 is the same as the output of the model
# variables['intermediates']['Conv_0']['__call__'][0] using any.all()
assert jnp.all(jnp.isclose(before_relu1, variables['intermediates']['Conv_0']['__call__'][0]))

In [70]:
dummy_arg = jnp.ones((64,))
dummy_arg = dummy_arg.astype(jnp.float32) / 255.

jnp.count_nonzero(dummy_arg <= 0.0)

Array(0, dtype=int32)