In [3]:
print("test")
import jax
import jax.numpy as jnp
from jax import random, grad, vmap
from flax import linen as nn
import optax

import gymnasium as gym

env = gym.make(
    "LunarLander-v2",
    continuous = False,
    gravity = -10.0,
    enable_wind = False,
    wind_power = 15.0,
    turbulence_power = 1.5
    # render_mode="rgb_array"
)

class FCN(nn.Module):
    n_observations: int
    n_actions: int
    fc1_dims: int = 256
    fc2_dims: int = 256

    def setup(self):
        self.layer1 = nn.Dense(self.fc1_dims)
        self.layer2 = nn.Dense(self.fc2_dims)
        self.layer3 = nn.Dense(self.n_actions)

    def __call__(self, x):
        x = nn.relu(self.layer1(x))
        x = nn.relu(self.layer2(x))
        logits = self.layer3(x)
        return logits  # Changed softmax to logits here.

class Agent():
    def __init__(self, env):
        self.network = FCN(env.observation_space.shape[0], env.action_space.n)
        self.optimizer = optax.adam(0.01)
        self.rewards = []
        self.actions = []
        self.env = env

    def get_distribution(self, state, params):
        logits = self.network.apply(params, state)
        return jax.random.categorical(jax.random.PRNGKey(0), logits)

    def train(self, gamma, num_episodes):
        params = self.network.init(jax.random.PRNGKey(0), jnp.ones((self.env.observation_space.shape[0],)))

        for episode in range(num_episodes):
            state, _ = self.env.reset()
            done = False

            while not done:
                action = self.get_distribution(jnp.array(state), params)
                state, reward, done, _ , _= self.env.step(action)

                self.rewards.append(reward)
                self.actions.append(action)

            loss_value, grads = self.update_policy(params, gamma)
            updates, _ = self.optimizer.update(grads, params)
            params = optax.apply_updates(params, updates)

            if episode % 50 == 0:
                print(f"Completed episode {episode} and achieved score {sum(self.rewards)}")

    def update_policy(self, params, gamma):
        total_future_rewards = []
        total_future_reward = 0

        for reward in self.rewards[::-1]:
            total_future_reward = total_future_reward * gamma + reward
            total_future_rewards.append(total_future_reward)

        returns = jnp.array(total_future_rewards[::-1])
        returns = (returns - returns.mean()) / (returns.std() + 1e-9)

        actions_onehot = jnp.array([onehot(a, self.env.action_space.n) for a in self.actions])
        log_probs = jnp.sum(jnp.log(self.network.apply(params, jnp.ones((1, self.env.observation_space.shape[0])))) * actions_onehot, axis=1)
        loss = -jnp.mean(returns * log_probs)
        grads = grad(loss)(params)

        return loss, grads

agent = Agent(env)
agent.train(0.99, 200)

test


AssertionError: Array(0, dtype=int32) (<class 'jaxlib.xla_extension.ArrayImpl'>) invalid 