# DDPG algorithm

## Initialize actor and critic network

In [40]:
import jax.numpy as jnp
import jax
from jax import grad, jit, vmap
from jax import random
from flax import linen as nn  # Linen API
from collections import deque
import gymnasium as gym

env = gym.make("CartPole-v1")
seed = 0
key = random.PRNGKey(seed)

F_CPP_MIN_LOG_LEVEL=0
action_dim = env.action_space.n
state_dim = 20

# print(action_dim)
# print(state_dim)

In [35]:
#create the actor and critic newtorks like multilayer perceptrons


class Critic(nn.Module):
    """critic model MLP"""
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x
    
class Actor(nn.Module):
    """actor model MLP"""
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=action_dim)(x)
        x = nn.tanh(x)
        return x

In [36]:
# Randomly initialize critic network Q(s, a|θ_Q ) and actor μ(s|θ_μ ) with weights θ_Q and θ_μ .
critic = Critic()
critic_input = jnp.ones((1, state_dim + action_dim))
critic_params = critic.init(key, critic_input)
actor = Actor()
actor_input = jnp.ones((1, state_dim))
actor_params = actor.init(key, actor_input)

print(actor.tabulate(key, (1, state_dim) ))

check_critic = jax.tree_util.tree_map(lambda x: x.shape, critic_params) #checking critic params
check_actor = jax.tree_util.tree_map(lambda x: x.shape, actor_params) #checking actor params

print("critic parameters:\n", check_critic, "\n")
print("actor parameters:\n", check_actor)


[3m                                Actor Summary                                [0m
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule[0m[1m [0m┃[1m [0m[1minputs      [0m[1m [0m┃[1m [0m[1moutputs     [0m[1m [0m┃[1m [0m[1mparams                  [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ Actor  │ - 1          │ [2mfloat32[0m[2]   │                          │
│         │        │ - 20         │              │                          │
├─────────┼────────┼──────────────┼──────────────┼──────────────────────────┤
│ Dense_0 │ Dense  │ - 1          │ [2mfloat32[0m[256] │ bias: [2mfloat32[0m[256]       │
│         │        │ - 20         │              │ kernel: [2mfloat32[0m[2,256]   │
│         │        │              │              │                          │
│         │        │              │              │ [1m768 

In [37]:
#foward example with the initial  parameters 
# the parameters never store in the model
critic_forward = critic.apply(critic_params, critic_input)
actor_forward = actor.apply(actor_params, actor_input)

print("critic forward test:", critic_forward, "\n")
print("actor forward test:", actor_forward)

critic forward test: [[0.18390828]] 

actor forward test: [[0.3893989  0.13889198]]


In [38]:
# Initialize target network Q_0_target and μ_0_target with weights 
# θ_Q_target ← θ_Q , θ_μ_target ← θ_μ
target_critic = Critic()
target_actor = Actor()

target_critic_params = critic_params
target_actor_params = actor_params


In [39]:
# Initialize replay buffer R
buffer_size = 1000
buffer = deque(maxlen=buffer_size)

In [43]:
# episode loop

# Initialize a random process N for action exploration
def noise(noise_scale=0.1):
    return noise_scale * jax.random.normal(key, (action_dim,))
# Receive initial observation state s 1
observation, info = env.reset(seed=seed)
print(observation)
N = noise(0.1)
print(N)

    

[ 0.01369617 -0.02302133 -0.04590265 -0.04834723]
[-0.0784766   0.08564448]
