In [1]:
import jax 
import jax.numpy as jnp
import equinox as eqx
import gymnasium as gym
from jaxtyping import Float32, Int8, Int32, PyTree
from jaxtyping import PRNGKeyArray, Array
import numpy as np
from typing import Tuple
import tensorflow_probability.substrates.jax as tfp
from gymnasium.wrappers import TimeLimit
import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
class Critic(eqx.Module):
    """Two-layer MLP (``state_dim -> 32 -> 1``) with a ReLU between the linear layers.

    Calling an instance on a single state vector returns the output of the
    final linear layer, i.e. an array with one element.
    """

    # Callables applied in order by ``__call__``: linear, ReLU, linear.
    layers: list

    def __init__(self, state_dim: int, key: PRNGKeyArray):
        """Build the network.

        Args:
            state_dim: Length of the state vector fed to the first layer.
            key: PRNG key; split once so each linear layer gets its own key.
        """
        super().__init__()
        hidden_key, output_key = jax.random.split(key)

        hidden = eqx.nn.Linear(state_dim, 32, key=hidden_key)
        output = eqx.nn.Linear(32, 1, key=output_key)
        self.layers = [hidden, jax.nn.relu, output]

    def __call__(self, state: Float32[jnp.ndarray, "state_dim"]):
        """Push ``state`` through every layer in order and return the result."""
        activation = state
        for apply_layer in self.layers:
            activation = apply_layer(activation)
        return activation


In [3]:
class Actor(eqx.Module):
    """Two-layer MLP (``state_dim -> 32 -> action_dim``) with a ReLU in between.

    Calling an instance on a single state vector returns one unnormalised
    score per action (the raw output of the final linear layer).
    """

    # Callables applied in order by ``__call__``: linear, ReLU, linear.
    layers: list

    def __init__(self, state_dim: int, action_dim: int, key: PRNGKeyArray):
        """Build the network.

        Args:
            state_dim: Length of the state vector fed to the first layer.
            action_dim: Number of outputs produced by the last layer.
            key: PRNG key; split once so each linear layer gets its own key.
        """
        super().__init__()
        hidden_key, output_key = jax.random.split(key)

        hidden = eqx.nn.Linear(state_dim, 32, key=hidden_key)
        output = eqx.nn.Linear(32, action_dim, key=output_key)
        self.layers = [hidden, jax.nn.relu, output]

    def __call__(self, state: Float32[jnp.ndarray, "state_dim"]):
        """Push ``state`` through every layer in order and return the result."""
        activation = state
        for apply_layer in self.layers:
            activation = apply_layer(activation)
        return activation

In [4]:
def estimate_advantages(rewards: Float32[Array, "steps"],
                        states: Float32[Array, "steps state_dim"],
                        next_states: Float32[Array, "steps state_dim"],
                        dones: Int8[Array, "steps"],
                        critic: PyTree,
                        gamma=0.99,
                        lambda_=0.95):
    """Compute Generalised Advantage Estimates (GAE) for one trajectory.

    The critic is vmapped over the leading (time) axis of ``states`` and
    ``next_states``, so both must be laid out as ``(steps, state_dim)``.

    Args:
        rewards: Reward received at each step, shape ``(steps,)``.
        states: State each step started from, shape ``(steps, state_dim)``.
        next_states: State each step ended in, shape ``(steps, state_dim)``.
        dones: 1 where the step ended an episode, 0 otherwise, shape ``(steps,)``.
        critic: Callable mapping a single state to its value, either as a
            scalar or as a one-element array.
        gamma: Discount factor.
        lambda_: GAE smoothing factor.

    Returns:
        Array of shape ``(steps,)`` where
        ``A[t] = delta[t] + gamma * lambda_ * (1 - dones[t]) * A[t + 1]`` and
        ``delta[t] = rewards[t] + gamma * V(next_states[t]) * (1 - dones[t]) - V(states[t])``.
    """
    rewards = jnp.asarray(rewards)

    # A critic with a one-unit output layer yields (steps, 1) under vmap.
    # Reshape to the rewards' shape so the arithmetic below is element-wise
    # instead of silently broadcasting to a (steps, steps) matrix.
    values = jax.vmap(critic)(states).reshape(rewards.shape)
    next_values = jax.vmap(critic)(next_states).reshape(rewards.shape)

    not_done = 1.0 - jnp.asarray(dones).astype(values.dtype)

    td_errors = rewards + gamma * next_values * not_done - values

    def backward_step(next_advantage, step):
        td_error, mask = step
        # `mask` zeroes the bootstrap term at episode ends so that advantage
        # never leaks from one episode into the one before it.
        advantage = td_error + gamma * lambda_ * mask * next_advantage
        return advantage, advantage

    # Reverse scan: the last step starts from a zero carry, so its advantage
    # is exactly its TD error, and each earlier step accumulates from there.
    _, advantages = jax.lax.scan(
        backward_step,
        jnp.zeros((), dtype=td_errors.dtype),
        (td_errors, not_done),
        reverse=True,
    )

    return advantages
                        

In [5]:
def get_action(logits: Float32[Array, "action_dim"], key: PRNGKeyArray):
    """Draw one action index from the categorical distribution defined by ``logits``.

    Args:
        logits: Unnormalised log-probabilities, one entry per action.
        key: PRNG key; it is split once here and the second half seeds the draw.

    Returns:
        The sampled action index as returned by ``Categorical.sample``.
    """
    _, sample_key = jax.random.split(key)
    action_distribution = tfp.distributions.Categorical(logits=logits)
    return action_distribution.sample(seed=sample_key)

In [6]:
def calculate_log_probs(policy: PyTree, state: Float32[Array, "state_dim"], action: int):
    """Return the log-probability the policy assigns to ``action`` in ``state``.

    Args:
        policy: Callable mapping a single state to a vector of action logits.
        state: State vector passed to ``policy``.
        action: Index of the action whose log-probability is wanted.

    Returns:
        ``log_softmax(policy(state))[action]``.
    """
    logits = policy(state)
    # log_softmax works in log-space throughout. Taking log(softmax(x))
    # underflows to log(0) = -inf for very unlikely actions, which then
    # poisons any loss or gradient built on top of it.
    log_probs = jax.nn.log_softmax(logits)

    return log_probs[action]

In [15]:
# Create the environment, then read the observation length and the number of
# discrete actions from its spaces.
env = gym.make("LunarLander-v2")

state_dim, action_dim = env.observation_space.shape[0], env.action_space.n

print(f"{state_dim=}, {action_dim=}")

state_dim=8, action_dim=4


In [16]:
# Seed the PRNG once and derive a dedicated key for each network.
# `key` itself is only ever split, never handed to a consumer. Both networks
# split the key they receive internally, so giving one of them `key` would
# make a later `jax.random.split(key)` reproduce that network's layer keys.
key = jax.random.PRNGKey(42)
key, critic_key, actor_key = jax.random.split(key, 3)
critic = Critic(state_dim, key=critic_key)
actor = Actor(state_dim, action_dim, key=actor_key)


In [26]:
def rollout(env: gym.Env, key: PRNGKeyArray, actor: PyTree, seed=None) -> Tuple[Float32[Array, "eps_steps state_dim"],
                                                                    Float32[Array, "eps_steps"],
                                                                    Int32[Array, "eps_steps"],
                                                                    Int32[Array, "eps_steps"]]:
    """Run one episode with the current policy and collect its transitions.

    Args:
        env: Environment to interact with; it is reset at the start.
        key: PRNG key; a fresh subkey is split off for every action sample.
        actor: Callable mapping an observation to action logits.
        seed: Optional seed forwarded to ``env.reset`` for a reproducible
            start state. ``None`` (the default) leaves the environment's own
            RNG untouched.

    Returns:
        A tuple ``(observations, rewards, actions, dones)`` with one entry per
        step taken:

        * ``observations[t]`` is the observation the policy saw when it chose
          ``actions[t]``, so the two can be paired directly when evaluating
          the policy's log-probabilities.
        * ``rewards[t]`` is the reward returned by that step.
        * ``dones[t]`` is 1 on the final step (the environment reported
          termination or truncation) and 0 everywhere else.
    """
    obs, _ = env.reset(seed=seed)

    # Per-step episode data.
    observations = []
    rewards = []
    actions = []
    dones = []

    done = False
    while not done:
        key, subkey = jax.random.split(key)
        logits = actor(obs)
        action = int(get_action(logits, subkey))

        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = bool(terminated or truncated)

        # Record the observation the action was conditioned on, not the one
        # the step produced. Otherwise observations and actions are off by one
        # and the reset observation is never stored.
        observations.append(obs)
        rewards.append(reward)
        actions.append(action)
        dones.append(int(done))

        obs = next_obs

    return jnp.array(observations), jnp.array(rewards), jnp.array(actions), jnp.array(dones)

In [36]:
# Split off a subkey for this episode and keep `key` for later splits.
key, subkey = jax.random.split(key, 2)
observations, rewards, actions, dones = rollout(env=env, key=subkey, actor=actor)