# DDPG algorithm

## Initialize actor and critic network

In [9]:
import jax.numpy as jnp
import jax
from jax import grad, jit, vmap
from jax import random
from flax import linen as nn  # Linen API
from collections import deque
import gymnasium as gym

F_CPP_MIN_LOG_LEVEL=0

In [4]:
#create the actor and critic newtorks like multilayer perceptrons
action_dim = 10
state_dim = 20

class Critic(nn.Module):
    """critic model MLP"""
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x
    
class Actor(nn.Module):
    """actor model MLP"""
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=action_dim)(x)
        x = nn.tanh(x)
        return x

In [5]:
# Randomly initialize critic network Q(s, a|θ_Q ) and actor μ(s|θ_μ ) with weights θ_Q and θ_μ .
key = random.PRNGKey(0)
critic = Critic()
critic_input = jnp.ones((1, state_dim + action_dim))
critic_params = critic.init(key, critic_input)
actor = Actor()
actor_input = jnp.ones((1, state_dim))
actor_params = actor.init(key, actor_input)

print(actor.tabulate(key, (1, state_dim) ))

check_critic = jax.tree_util.tree_map(lambda x: x.shape, critic_params) #checking critic params
check_actor = jax.tree_util.tree_map(lambda x: x.shape, actor_params) #checking actor params

print("critic parameters:\n", check_critic, "\n")
print("actor parameters:\n", check_actor)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)



[3m                                Actor Summary                                [0m
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule[0m[1m [0m┃[1m [0m[1minputs      [0m[1m [0m┃[1m [0m[1moutputs     [0m[1m [0m┃[1m [0m[1mparams                  [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ Actor  │ - 1          │ [2mfloat32[0m[10]  │                          │
│         │        │ - 20         │              │                          │
├─────────┼────────┼──────────────┼──────────────┼──────────────────────────┤
│ Dense_0 │ Dense  │ - 1          │ [2mfloat32[0m[256] │ bias: [2mfloat32[0m[256]       │
│         │        │ - 20         │              │ kernel: [2mfloat32[0m[2,256]   │
│         │        │              │              │                          │
│         │        │              │              │ [1m768 

In [6]:
#foward example with the initial  parameters 
# the parameters never store in the model
critic.apply(critic_params, critic_input)

Array([[-0.8972499]], dtype=float32)

In [7]:
actor.apply(actor_params, actor_input)

Array([[-0.28007025, -0.21052897, -0.23458375, -0.5754893 , -0.56756514,
        -0.14069036,  0.14474116, -0.06463255, -0.6299922 ,  0.8642028 ]],      dtype=float32)

In [8]:
# Initialize target network Q_0_target and μ_0_target with weights 
# θ_Q_target ← θ_Q , θ_μ_target ← θ_μ
target_critic = Critic()
target_actor = Actor()

target_critic_params = critic_params
target_actor_params = actor_params


In [19]:
# Initialize replay buffer R
buffer_size = 1000
buffer = deque(maxlen=buffer_size)