# DDPG algorithm

## Initialize actor and critic network

In [1]:
import jax.numpy as jnp
import jax
from jax import grad, jit, vmap
from jax import random
from flax import linen as nn  # Linen API
from collections import deque
import gymnasium as gym

env = gym.make("MountainCarContinuous-v0")
seed = 0
key = random.PRNGKey(seed)

F_CPP_MIN_LOG_LEVEL=0
action_dim = env.action_space.shape[0]
state_dim = env.observation_space.shape[0]

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


1
2


In [34]:
#create the actor and critic newtorks like multilayer perceptrons

# class Critic(nn.Module):
#     def setup(self):
#         self.q = nn.Dense(features=1)

#     def __call__(self, obs, act):
#         x = jnp.concatenate([obs, act], axis=-1)
#         q = self.q(x)
#         return jnp.squeeze(q, axis=-1)
    
class Critic(nn.Module):
    """critic model MLP"""
    
    @nn.compact
    def __call__(self, observations, actions):
        x = jnp.concatenate([observations, actions], axis=-1)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return jnp.squeeze(x, axis=-1)
    
class Actor(nn.Module):
    """actor model MLP"""
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=action_dim)(x)
        x = nn.tanh(x)
        return x

[0. 0.]
[[0. 0.]]


In [45]:
# Randomly initialize critic network Q(s, a|θ_Q ) and actor μ(s|θ_μ ) with weights θ_Q and θ_μ .
critic = Critic()
critic_params = critic.init(key, jnp.zeros((1,action_dim)), jnp.zeros((1,state_dim)))
actor = Actor()
actor_params = actor.init(key, jnp.zeros((1, state_dim)))

check_critic = jax.tree_util.tree_map(lambda x: x.shape, critic_params) #checking critic params
check_actor = jax.tree_util.tree_map(lambda x: x.shape, actor_params) #checking actor params

print(actor.tabulate(key, (1, state_dim) ))
print("actor parameters:\n", check_actor)

print(critic.tabulate(key, jnp.ones((1,action_dim)), jnp.ones((1,state_dim))))
print("critic parameters:\n", check_critic, "\n")




[3m                                Actor Summary                                [0m
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule[0m[1m [0m┃[1m [0m[1minputs      [0m[1m [0m┃[1m [0m[1moutputs     [0m[1m [0m┃[1m [0m[1mparams                  [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ Actor  │ - 1          │ [2mfloat32[0m[1]   │                          │
│         │        │ - 2          │              │                          │
├─────────┼────────┼──────────────┼──────────────┼──────────────────────────┤
│ Dense_0 │ Dense  │ - 1          │ [2mfloat32[0m[256] │ bias: [2mfloat32[0m[256]       │
│         │        │ - 2          │              │ kernel: [2mfloat32[0m[2,256]   │
│         │        │              │              │                          │
│         │        │              │              │ [1m768 

In [46]:
#foward example with the initial  parameters 
# the parameters never store in the model
critic_forward = critic.apply(critic_params, jnp.ones((1, action_dim)), jnp.ones((1, state_dim)) )
actor_forward = actor.apply(actor_params, jnp.ones((1, state_dim)))

print("critic forward test:", critic_forward, "\n")
print("actor forward test:", actor_forward)

critic forward test: [0.13340816] 

actor forward test: [[-0.4034107]]


In [47]:
# Initialize target network Q_0_target and μ_0_target with weights 
# θ_Q_target ← θ_Q , θ_μ_target ← θ_μ
target_critic = Critic()
target_actor = Actor()

target_critic_params = critic_params
target_actor_params = actor_params


In [48]:
# Initialize replay buffer R
buffer_size = 1000
buffer = deque(maxlen=buffer_size)

In [49]:
# Initialize a random process N for action exploration
def noise(noise_scale=0.1):
    return noise_scale * jax.random.normal(key, (action_dim,))
# Receive initial observation state s 1
state, info = env.reset(seed=seed)
print(state)
N = noise(0.1)
print(N)

[-0.47260767  0.        ]
[-0.02058423]


In [66]:
# Select action a_t = μ(s t |θ μ ) + N t according to the current policy and exploration noise
action = noise() + actor.apply(actor_params, state)
# action = env.action_space.sample()
print("action:", action, "\n")

# Execute action a t and observe reward r t and observe new state s t+1
next_state, reward, terminated, done, info = env.step(action)
print("next state:", next_state, "\n")

# Store transition (s t , a t , r t , s t+1 ) in R
transition = (state, action, reward, next_state)
buffer.append(transition)
print("buffer:", buffer, "\n")

# Sample a random minibatch of N transitions (s i , a i , r i , s i+1 ) from R
batch_size = 1
indices = jax.random.choice(key, len(buffer), shape=(batch_size,), replace=False)
minibatch = [buffer[i] for i in indices]
print("minibatch:", minibatch)

action: [-0.07492159] 

next state: [-0.47748625 -0.00193688] 

buffer: deque([(array([-0.47260767,  0.        ], dtype=float32), Array([-0.07492159], dtype=float32), -0.0005613243991732519, array([-0.473101  , -0.00049333], dtype=float32)), (array([-0.47260767,  0.        ], dtype=float32), Array([-0.07492159], dtype=float32), -0.0005613243991732519, array([-0.474084  , -0.00098299], dtype=float32)), (array([-0.47260767,  0.        ], dtype=float32), Array([-0.07492159], dtype=float32), -0.0005613243991732519, array([-0.47554937, -0.00146537], dtype=float32)), (array([-0.47260767,  0.        ], dtype=float32), Array([-0.07492159], dtype=float32), -0.0005613243991732519, array([-0.47748625, -0.00193688], dtype=float32))], maxlen=1000) 

minibatch: [(array([-0.47260767,  0.        ], dtype=float32), Array([-0.07492159], dtype=float32), -0.0005613243991732519, array([-0.47554937, -0.00146537], dtype=float32))]


In [67]:
# Set y i = r i + γQ 0 (s i+1 , μ 0 (s i+1 |θ μ )|θ Q ) P
u_target = target_actor.apply(target_actor_params, next_state)
Q_target = target_critic.apply(target_critic_params, next_state, u_target)
gamma = 0.1

y = reward + gamma * (1 - done) * Q_target
print(y)

-0.016791953


In [70]:
# Update critic by minimizing the loss: L = N 1 i (y i − Q(s i , a i |θ Q ) 2 )
Q = critic.apply(critic_params, state, action)
loss = ((Q - y)**2).mean() #compute loss
print(loss)

0.021232113
