# DDPG algorithm in flax

In [2]:
import jax.numpy as jnp
import jax
from flax.training import train_state
import jax.tree_util as jtu
from jax import random
from flax import linen as nn  # Linen API
import optax
from collections import deque
import gymnasium as gym

env = gym.make("MountainCarContinuous-v0")
seed = 0
key = random.PRNGKey(seed)

F_CPP_MIN_LOG_LEVEL=0
action_dim = env.action_space.shape[0]
state_dim = env.observation_space.shape[0]

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Create the actor and critic model

In [2]:
#create the actor and critic newtorks like multilayer perceptrons
class Critic(nn.Module):
    """critic model MLP"""
    
    @nn.compact
    def __call__(self, observations, actions):
        x = jnp.concatenate([observations, actions], axis=-1)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return jnp.squeeze(x, axis=-1)
    
class Actor(nn.Module):
    """actor model MLP"""
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=action_dim)(x)
        x = nn.tanh(x)
        return x

## Create necessary methods to the training

In [3]:
#random process N for action exploration
def noise(noise_scale=0.1, key=key, action_dim=action_dim):
    return noise_scale * jax.random.normal(key, (action_dim,))

In [4]:
# Define the method to update model parameters

# update critic
@jax.jit
def update_critic(model, states, actions, y):
    def compute_critic_loss(params):
        Q = model.apply_fn(params, states, actions)
        return jnp.mean((Q - y)**2) #compute loss
    
    grad_fn = jax.grad(compute_critic_loss)
    grads = grad_fn(model.params)
    updated_model = model.apply_gradients(grads=grads)
    return updated_model

# udate actor
@jax.jit
def update_actor(model, states):
    def compute_actor_loss(params):
        actions = model.apply_fn(params, states)
        return -jnp.mean(actions)  # Compute the actor loss
    
    grad_fn = jax.grad(compute_actor_loss)
    grads = grad_fn(model.params)
    updated_model = model.apply_gradients(grads=grads)
    return updated_model

In [5]:
# Define the soft update function
@jax.jit
def soft_update(target_params, source_params, tau):
    # Convert the source_params to a JAX-compatible data structure
    source_params_tree = jtu.tree_map(lambda x: jnp.asarray(x), source_params)
    target_params_tree = jtu.tree_map(lambda x: jnp.asarray(x), target_params)


    # Compute the updated target parameters using a soft update
    updated_params = jtu.tree_map(lambda x, y: tau * x + (1 - tau) * y, source_params_tree, target_params_tree)

    return updated_params

## Algorithm

In [6]:
# Randomly initialize critic network Q(s, a|θ_Q ) and actor μ(s|θ_μ ) with weights θ_Q and θ_μ .
critic_params = Critic().init(key, jnp.zeros((1,action_dim)), jnp.zeros((1,state_dim)))
actor_params = Actor().init(key, jnp.zeros((1, state_dim)))

# define optimizers
actor_optimizer = optax.adam(learning_rate=100)
actor_opt_state = actor_optimizer.init(actor_params)

critic_optimizer = optax.adam(learning_rate=100)
critic_opt_state = critic_optimizer.init(critic_params)

# Initialize the training state for flax porpuses
critic = train_state.TrainState.create(
    apply_fn=Critic().apply,
    params=critic_params,
    tx=critic_optimizer
)

actor = train_state.TrainState.create(
    apply_fn=Actor().apply,
    params=actor_params,
    tx=actor_optimizer
)

print(Actor().tabulate(key, (1, state_dim) ))
print(Critic().tabulate(key, jnp.ones((1,action_dim)), jnp.ones((1,state_dim))))



[3m                                Actor Summary                                [0m
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule[0m[1m [0m┃[1m [0m[1minputs      [0m[1m [0m┃[1m [0m[1moutputs     [0m[1m [0m┃[1m [0m[1mparams                  [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ Actor  │ - 1          │ [2mfloat32[0m[1]   │                          │
│         │        │ - 2          │              │                          │
├─────────┼────────┼──────────────┼──────────────┼──────────────────────────┤
│ Dense_0 │ Dense  │ - 1          │ [2mfloat32[0m[256] │ bias: [2mfloat32[0m[256]       │
│         │        │ - 2          │              │ kernel: [2mfloat32[0m[2,256]   │
│         │        │              │              │                          │
│         │        │              │              │ [1m768 

In [7]:
# Initialize target network Q_0_target and μ_0_target with weights θ_Q_target ← θ_Q , θ_μ_target ← θ_μ
target_critic_params = critic_params
target_actor_params = actor_params

# Initialize the training state for flax porpuses
target_critic = train_state.TrainState.create(
    apply_fn=Critic().apply,
    params=target_critic_params,
    tx=critic_optimizer
)

target_actor = train_state.TrainState.create(
    apply_fn=Actor().apply,
    params=target_actor_params,
    tx=actor_optimizer
)

In [14]:
# Initialize replay buffer R
buffer_size = 100000
batch_size = 10

class ReplayBuffer():
    def __init__(self, buffer_size, batch_size):
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.buffer = deque(maxlen=self.buffer_size)

    def add(self, transition):
        self.buffer.append(transition)

    def sample_batch(self):
        indices = jax.random.choice(key, len(self.buffer), shape=(self.batch_size,), replace=True)
        batch = [self.buffer[i] for i in indices]

        return zip(*batch)
    
buffer = ReplayBuffer(buffer_size, batch_size)


In [13]:
episode = 1
T = 2
gamma = 0.1
tau = 0.001

for i in range(episode):
    # Initialize a random process N for action exploration
    N = noise(0.1)
    # Receive initial observation state s 1
    state, info = env.reset(seed=seed)

    for t in range(T):

        # Select action a_t = μ(s t |θ μ ) + N t according to the current policy and exploration noise
        action = noise() + actor.apply_fn(actor.params, state)

        # Execute action a t and observe reward r t and observe new state s t+1
        observation, reward, terminated, truncated, info = env.step(action)

        # Store transition (s t , a t , r t , s t+1 ) in R
        transition = (state, action, reward, observation)

        # buffer.append(transition)
        buffer.add(transition)

        # Sample a random minibatch of N transitions (s i , a i , r i , s i+1 ) from R
        states, actions, rewards, next_states = buffer.sample_batch()

        # Set y = r  + γQ^0 (s_{i+1} , μ^0 (s_{i+1} |θ^μ )|θ^Q ) P
        target_action = target_actor.apply_fn(target_actor_params, next_states)
        target_q = target_critic.apply_fn(target_critic_params, jnp.asarray(next_states), jnp.asarray(target_action))

        y = reward + gamma * (1 - terminated) * target_q

        # Update critic by minimizing the loss
        critic = update_critic(critic, jnp.asarray(states), jnp.asarray(actions), jnp.asarray(y))

        # Update the actor policy using the sampled gradient:
        actor = update_actor(actor, jnp.asarray(states))

        # Update the target networks:
        target_actor_params = soft_update(target_actor_params, actor.params, tau)
        target_critic_params = soft_update(target_critic_params, critic.params, tau)

        state = observation


## Stable baselines training

In [3]:
import gymnasium as gym
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
import numpy as np

# The noise objects for DDPG
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000, log_interval=10)
model.save("ddpg_MountainCarContinuous")
vec_env = model.get_env()

del model # remove to demonstrate saving and loading

model = DDPG.load("ddpg_MountainCarContinuous")

obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    env.render("human")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


KeyboardInterrupt: 