In [13]:
import jax.numpy as jnp
import flax
import flax.linen as nn
import numpy as np
from flax.linen.initializers import constant, orthogonal
import jax
import matplotlib.pyplot as plt
import sys

In [5]:
class ContinuousQNetwork(nn.Module):

    @nn.compact
    def __call__(self, s, a):
        # x = jnp.concatenate((s, a), axis=-1)
        x = s
        x = nn.silu(nn.Dense(256, kernel_init=orthogonal(np.sqrt(2.0)))(x))
        x = nn.LayerNorm()(x)
        x = nn.silu(nn.Dense(256, kernel_init=orthogonal(np.sqrt(2.0)))(x))
        x = nn.LayerNorm()(x)
        x = nn.silu(nn.Dense(256, kernel_init=orthogonal(np.sqrt(2.0)))(x))
        x = nn.LayerNorm()(x)
        x = nn.silu(nn.Dense(256, kernel_init=orthogonal(np.sqrt(2.0)))(x))
        x = nn.LayerNorm()(x)
        q_vals = nn.Dense(1, kernel_init=orthogonal(1.0))(x)

        return jnp.squeeze(q_vals, axis=-1)

In [6]:
value_network = ContinuousQNetwork()

In [7]:
def plot_value_function(model_forward_fn, params,
                       x1_bounds=(-10, 10), x2_bounds=(-10, 10),
                       num_points=100, title="Value Function"):
    """
    Creates a 2D plot of the value function over the state space.

    Args:
        model_forward_fn: The forward function of your neural network
        params: The trained parameters of your model
        x1_bounds: Tuple of (min, max) for first state dimension
        x2_bounds: Tuple of (min, max) for second state dimension
        num_points: Number of points to evaluate in each dimension
        title: Title for the plot
    """
    # Create a grid of points
    x1 = np.linspace(x1_bounds[0], x1_bounds[1], num_points)
    x2 = np.linspace(x2_bounds[0], x2_bounds[1], num_points)
    X1, X2 = np.meshgrid(x1, x2)

    # Reshape the grid points into a batch of states
    states = jnp.array([[x1, x2] for x1, x2 in
                        zip(X1.flatten(), X2.flatten())])

    # Vectorized prediction using vmap
    batch_predict = jax.vmap(lambda x: model_forward_fn(params, x))
    values = batch_predict(states)

    # Reshape predictions back to grid
    value_grid = values.reshape(num_points, num_points)

    # Create the plot
    plt.figure(figsize=(10, 8))
    plt.contourf(X1, X2, value_grid, levels=20, cmap='viridis')
    plt.colorbar(label='Value')
    plt.xlabel('State Dimension 1')
    plt.ylabel('State Dimension 2')
    plt.title(title)
    plt.grid(True)

    return plt.gcf()

In [16]:
loaded_params = jnp.load("./ersac_critic_params.npy", allow_pickle=True)
print(loaded_params)
# sys.exit()

# Assuming you have already defined your network and loaded params:
def network_forward(params, x):
    # Your network's forward pass implementation
    return value_network.apply(params, x)

# Create the visualization
fig = plot_value_function(network_forward,
                          loaded_params,
                          x1_bounds=(-5, 5),
                          x2_bounds=(-5, 5))
plt.show()

{'params': {'Dense_0': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

IndexError: too many indices for array: array is 0-dimensional, but 1 were indexed