#### Playground to understand the code

In [7]:
import jax
import jax.numpy as jnp
import numpy as np
import gym
import optax
from jax import grad, jit, vmap
from jax import random
from functools import partial
import haiku as hk


In [29]:
def mlp(x, sizes, activation=jax.nn.relu, output_activation=None):
    layers = []
    for size in sizes[:-1]:
        layers.append(hk.Linear(size))
        layers.append(activation)
    layers.append(hk.Linear(sizes[-1]))
    if output_activation is not None:
        layers.append(output_activation)
    return hk.Sequential(layers)(x)

def actor_fn(x):
    return mlp(x, [64, 64, 10], jax.nn.relu, jax.nn.softmax)

def critic_fn(x):
    return mlp(x, [64, 64, 1], jax.nn.relu, None)

# Transform the function into Haiku modules
actor = hk.transform(actor_fn)
critic = hk.transform(critic_fn)

# Example usage
rng = jax.random.PRNGKey(42)
x = jnp.array([1.0, 2.0, 3.0])  # Example input

# Initialize parameters for actor and critic
actor_params = actor.init(rng, x)
critic_params = critic.init(rng, x)

# Forward pass
actor_output = actor.apply(actor_params, None, rng, x)
critic_output = critic.apply(critic_params, None, rng, x)

print("Actor output:", actor_output)
print("Critic output:", critic_output)


TypeError: actor_fn() takes 1 positional argument but 2 were given

In [19]:
class ActorCritic:
    def __init__(self, observation_space, action_space, hidden_sizes=(64, 64), seed=0):
        self.key = jax.random.PRNGKey(seed)
        self.actor_key, self.critic_key = jax.random.split(self.key)
        
        # Initialize the networks
        dummy_input = jnp.zeros(observation_space.shape[0])
        self.actor_params = actor.init(self.actor_key, dummy_input)
        self.critic_params = critic.init(self.critic_key, dummy_input)
        
    def actor(self, params, x):
        return actor.apply(params, None, self.actor_key, x)  # Updated

    def critic(self, params, x):
        return critic.apply(params, None, self.critic_key, x)  # Updated


In [20]:
def train(env_name='CartPole-v1', epochs=50, steps_per_epoch=4000, gamma=0.99, clip_ratio=0.2, pi_lr=3e-4, vf_lr=1e-3):
    env = gym.make(env_name)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    
    ac = ActorCritic(env.observation_space, env.action_space)
    
    opt_pi = optax.adam(pi_lr)
    opt_vf = optax.adam(vf_lr)
    
    opt_pi_state = opt_pi.init(ac.actor_params)
    opt_vf_state = opt_vf.init(ac.critic_params)
    
    for epoch in range(epochs):
        obs = env.reset()
        obs_buffer, act_buffer, rew_buffer, val_buffer, logp_buffer = [], [], [], [], []
        ep_ret, ep_len = 0, 0
        
        for t in range(steps_per_epoch):
            pi = ac.actor(ac.actor_params, obs)
            action = np.random.choice(act_dim, p=np.array(pi))
            value = ac.critic(ac.critic_params, obs)
            logp = jnp.log(pi[action])
            
            obs_buffer.append(obs)
            act_buffer.append(action)
            rew_buffer.append(reward)
            val_buffer.append(value)
            logp_buffer.append(logp)
            
            obs, reward, done, _ = env.step(action)
            ep_ret += reward
            ep_len += 1
            
            if done or (t == steps_per_epoch - 1):
                last_val = reward if done else ac.critic(ac.critic_params, obs)
                rew_to_go = last_val
                
                returns = []
                for r in rew_buffer[::-1]:
                    rew_to_go = r + gamma * rew_to_go
                    returns.insert(0, rew_to_go)
                
                returns = jnp.array(returns)
                obs_buffer = jnp.array(obs_buffer)
                act_buffer = jnp.array(act_buffer)
                logp_buffer = jnp.array(logp_buffer)
                advantages = returns - jnp.array(val_buffer)
                
                opt_pi_state, ac.actor_params = opt_pi.update(grad(ppo_loss)(ac.actor_params, ac.critic_params, obs_buffer, act_buffer, advantages, logp_buffer, returns, clip_ratio), ac.actor_params)
                opt_vf_state, ac.critic_params = opt_vf.update(grad(ppo_loss)(ac.actor_params, ac.critic_params, obs_buffer, act_buffer, advantages, logp_buffer, returns, clip_ratio), ac.critic_params)
                
                obs, ep_ret, ep_len = env.reset(), 0, 0
                obs_buffer, act_buffer, rew_buffer, val_buffer, logp_buffer = [], [], [], [], []
                
        print(f'Epoch {epoch+1}/{epochs}: Reward = {ep_ret}')


In [21]:
train()

TypeError: actor_fn() takes 1 positional argument but 2 were given