# DDPG algorithm

## Initialize actor and critic network

In [25]:
import jax.numpy as jnp
import jax
from jax import grad, jit, vmap
import jax.tree_util as jtu
from jax import random
from flax import linen as nn  # Linen API
import optax
from collections import deque
import gymnasium as gym

env = gym.make("MountainCarContinuous-v0")
seed = 0
key = random.PRNGKey(seed)

F_CPP_MIN_LOG_LEVEL=0
action_dim = env.action_space.shape[0]
state_dim = env.observation_space.shape[0]

In [2]:
#create the actor and critic newtorks like multilayer perceptrons
class Critic(nn.Module):
    """critic model MLP"""
    
    @nn.compact
    def __call__(self, observations, actions):
        x = jnp.concatenate([observations, actions], axis=-1)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return jnp.squeeze(x, axis=-1)
    
class Actor(nn.Module):
    """actor model MLP"""
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=action_dim)(x)
        x = nn.tanh(x)
        return x

In [3]:
# Randomly initialize critic network Q(s, a|θ_Q ) and actor μ(s|θ_μ ) with weights θ_Q and θ_μ .
critic = Critic()
critic_params = critic.init(key, jnp.zeros((1,action_dim)), jnp.zeros((1,state_dim)))
actor = Actor()
actor_params = actor.init(key, jnp.zeros((1, state_dim)))

check_critic = jax.tree_util.tree_map(lambda x: x.shape, critic_params) #checking critic params
check_actor = jax.tree_util.tree_map(lambda x: x.shape, actor_params) #checking actor params

print(actor.tabulate(key, (1, state_dim) ))
print("actor parameters:\n", check_actor)

print(critic.tabulate(key, jnp.ones((1,action_dim)), jnp.ones((1,state_dim))))
print("critic parameters:\n", check_critic, "\n")




[3m                                Actor Summary                                [0m
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule[0m[1m [0m┃[1m [0m[1minputs      [0m[1m [0m┃[1m [0m[1moutputs     [0m[1m [0m┃[1m [0m[1mparams                  [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ Actor  │ - 1          │ [2mfloat32[0m[1]   │                          │
│         │        │ - 2          │              │                          │
├─────────┼────────┼──────────────┼──────────────┼──────────────────────────┤
│ Dense_0 │ Dense  │ - 1          │ [2mfloat32[0m[256] │ bias: [2mfloat32[0m[256]       │
│         │        │ - 2          │              │ kernel: [2mfloat32[0m[2,256]   │
│         │        │              │              │                          │
│         │        │              │              │ [1m768 

In [4]:
#foward example with the initial  parameters 
# the parameters never store in the model
critic_forward = critic.apply(critic_params, jnp.ones((1, action_dim)), jnp.ones((1, state_dim)) )
actor_forward = actor.apply(actor_params, jnp.ones((1, state_dim)))

print("critic forward test:", critic_forward, "\n")
print("actor forward test:", actor_forward)

critic forward test: [0.13340816] 

actor forward test: [[-0.4034107]]


In [5]:
# Initialize target network Q_0_target and μ_0_target with weights 
# θ_Q_target ← θ_Q , θ_μ_target ← θ_μ
target_critic = Critic()
target_actor = Actor()

target_critic_params = critic_params
target_actor_params = actor_params


In [6]:
# Initialize replay buffer R
buffer_size = 1000
buffer = deque(maxlen=buffer_size)

In [7]:
# Initialize a random process N for action exploration
def noise(noise_scale=0.1):
    return noise_scale * jax.random.normal(key, (action_dim,))
# Receive initial observation state s 1
state, info = env.reset(seed=seed)
print(state)
N = noise(0.1)
print(N)

[-0.47260767  0.        ]
[-0.02058423]


In [8]:
# Select action a_t = μ(s t |θ μ ) + N t according to the current policy and exploration noise
action = noise() + actor.apply(actor_params, state)
# action = env.action_space.sample()
print("action:", action, "\n")

# Execute action a t and observe reward r t and observe new state s t+1
next_state, reward, terminated, done, info = env.step(action)
print("next state:", next_state, "\n")

# Store transition (s t , a t , r t , s t+1 ) in R
transition = (state, action, reward, next_state)
buffer.append(transition)
print("buffer:", buffer, "\n")

# Sample a random minibatch of N transitions (s i , a i , r i , s i+1 ) from R
batch_size = 1
indices = jax.random.choice(key, len(buffer), shape=(batch_size,), replace=False)
minibatch = [buffer[i] for i in indices]
print("minibatch:", minibatch)

action: [-0.07492159] 

next state: [-0.473101   -0.00049333] 

buffer: deque([(array([-0.47260767,  0.        ], dtype=float32), Array([-0.07492159], dtype=float32), -0.0005613243991732519, array([-0.473101  , -0.00049333], dtype=float32))], maxlen=1000) 

minibatch: [(array([-0.47260767,  0.        ], dtype=float32), Array([-0.07492159], dtype=float32), -0.0005613243991732519, array([-0.473101  , -0.00049333], dtype=float32))]


In [9]:
# Set y i = r i + γQ 0 (s i+1 , μ 0 (s i+1 |θ μ )|θ Q ) P
actor_target = target_actor.apply(target_actor_params, next_state)
critic_target = target_critic.apply(target_critic_params, next_state, actor_target)
gamma = 0.1

y = reward + gamma * (1 - done) * critic_target
print(y)

-0.016732937


In [20]:
# loss of the critic: L = N 1 i (y i − Q(s i , a i |θ Q ) 2 )
Q = critic.apply(critic_params, state, action)
q_loss = ((Q - y)**2).mean() #compute loss
print(q_loss)


# loss of the actor
def compute_actor_loss(actor_params, state):
    # Compute the actor loss
    actions = actor.apply(actor_params, state)
    return -jnp.mean(actions)

actor_loss = compute_actor_loss(actor_params, next_state)
print(actor_loss)

0.021249318
0.05431732


In [23]:
# Update the actor policy using the sampled gradient

# define optimizer
actor_optimizer = optax.adam(learning_rate=100)
opt_state = actor_optimizer.init(actor_params)

# Define a function to update the actor policy
@jax.jit
def update_actor(actor_params, state, opt_state):
    # Compute the gradients of the loss with respect to the actor parameters
    grad_fn = jax.grad(compute_actor_loss)
    actor_grads = grad_fn(actor_params, state)

    # Update the actor parameters using the optimizer
    updates, opt_state = actor_optimizer.update(actor_grads, opt_state) 
    new_params = optax.apply_updates(actor_params, updates)

    return new_params

test = update_actor(actor_params, next_state, opt_state)

print(test)



FrozenDict({
    params: {
        Dense_0: {
            bias: Array([  0.      ,   0.      ,  99.99929 ,  99.99901 ,   0.      ,
                    99.999176, -99.99929 ,   0.      ,  99.999306, -99.99929 ,
                     0.      ,   0.      ,   0.      ,   0.      , -99.999306,
                     0.      , -99.9993  , -99.9993  ,   0.      ,   0.      ,
                    99.9993  ,  99.999306,   0.      ,   0.      ,   0.      ,
                     0.      ,  99.999306,   0.      ,   0.      ,   0.      ,
                     0.      ,  99.99911 , -99.9993  ,   0.      ,  99.99927 ,
                   -99.999306,   0.      ,  99.99921 , -99.9993  ,  99.99928 ,
                     0.      ,   0.      ,   0.      ,  99.99932 ,   0.      ,
                     0.      , -99.999306, -99.99874 ,  99.999306,  99.999245,
                    99.999275,   0.      , -99.99928 , -99.99921 , -99.99923 ,
                     0.      ,   0.      ,   0.      ,  99.999146, -99.99929 ,


In [27]:
# Update the target networks of the actor

# Define the soft update function
@jax.jit
def soft_update(target_params, source_params, tau):
    # Convert the source_params to a JAX-compatible data structure
    source_params_tree = jtu.tree_map(lambda x: jnp.asarray(x), source_params)

    # Compute the updated target parameters using a soft update
    updated_params_tree = jtu.tree_map(lambda x, y: tau * x + (1 - tau) * y, source_params_tree, target_params)
    updated_params = jtu.tree_leaves(updated_params_tree)

    return updated_params

#  Update the target networks using a soft update
tau = 0.001
target_actor_params = soft_update(target_actor_params, actor_params, tau)

print(target_actor_params)

[Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0