<a href="https://colab.research.google.com/github/anirbanl/jax-code/blob/master/rlflax/pg/jax_flax_ddpg_pendulum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
!pip install jax jaxlib flax



In [22]:
import gym
gym.logger.set_level(40) # suppress warnings (please remove if gives error)
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
%matplotlib inline
import jax
import jax.numpy as jp
from jax.ops import index, index_add, index_update
from jax import jit, grad, vmap, random, jacrev, jacobian, jacfwd, value_and_grad
from functools import partial
from jax.tree_util import tree_multimap  # Element-wise manipulation of collections of numpy arrays
from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state
import optax                           # Optimizers
from typing import Sequence
import copy

In [23]:
env = gym.make('Pendulum-v0')
env.seed(0)
print('observation space:', env.observation_space)
print('action space:', env.action_space)

observation space: Box(-8.0, 8.0, (3,), float32)
action space: Box(-2.0, 2.0, (1,), float32)


#Normalize Action Space

In [24]:
class NormalizedActions(gym.ActionWrapper):

    def action(self, action):
        low_bound   = self.action_space.low
        upper_bound = self.action_space.high
        
        action = low_bound + (action + 1.0) * 0.5 * (upper_bound - low_bound)
        action = np.clip(action, low_bound, upper_bound)
        
        return action

    def reverse_action(self, action):
        low_bound   = self.action_space.low
        upper_bound = self.action_space.high
        
        action = 2 * (action - low_bound) / (upper_bound - low_bound) - 1
        action = np.clip(action, low_bound, upper_bound)
        
        return actions

#Ornstein-Uhlenbeck process
Adding time-correlated noise to the actions taken by the deterministic policy

In [25]:
class OUNoise(object):
    def __init__(self, action_space, mu=0.0, theta=0.15, max_sigma=0.3, min_sigma=0.3, decay_period=100000):
        self.mu           = mu
        self.theta        = theta
        self.sigma        = max_sigma
        self.max_sigma    = max_sigma
        self.min_sigma    = min_sigma
        self.decay_period = decay_period
        self.action_dim   = action_space.shape[0]
        self.low          = action_space.low
        self.high         = action_space.high
        self.reset()
        
    def reset(self):
        self.state = np.ones(self.action_dim) * self.mu
        
    def evolve_state(self):
        x  = self.state
        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim)
        self.state = x + dx
        return self.state
    
    def get_action(self, action, t=0):
        ou_state = self.evolve_state()
        self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(1.0, t / self.decay_period)
        return np.clip(action + ou_state, self.low, self.high)
    
#https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/ou_strategy.py

In [26]:
def plot(frame_idx, rewards):
    clear_output(True)
    plt.figure(figsize=(20,5))
    plt.subplot(131)
    plt.title('frame %s. reward: %s' % (frame_idx, rewards[-1]))
    plt.plot(rewards)
    plt.show()

In [27]:
from collections import deque
class Memory():
    def __init__(self, rng, max_size = 1000):
        self.buffer = deque(maxlen=max_size)
        self.key = rng
    
    def add(self, experience):
        self.buffer.append(experience)
            
    def sample(self, batch_size):
        self.key, _ = jax.random.split(self.key)
        idx = jax.random.choice(self.key,
                               jp.arange(len(self.buffer)), 
                               shape=(batch_size, ))
        return [self.buffer[ii] for ii in idx]

    def __len__(self):
        return len(self.buffer)
'''
def init_memory(env, memory_size=1000000):
    # Initialize the simulation
    env.reset()
    # Take one random step to get the pole and cart moving
    state, reward, done, _ = env.step(env.action_space.sample())

    memory = Memory(max_size=memory_size)

    # Make a bunch of random actions and store the experiences
    for ii in range(pretrain_length):
        # Uncomment the line below to watch the simulation
        # env.render()

        # Make a random action
        action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)

        if done:
            # The simulation fails so no next state
            next_state = jp.zeros(state.shape)
            # Add experience to memory
            memory.add((state, action, reward, next_state))
            
            # Start new episode
            env.reset()
            # Take one random step to get the pole and cart moving
            state, reward, done, _ = env.step(env.action_space.sample())
        else:
            # Add experience to memory
            memory.add((state, action, reward, next_state))
            state = next_state
    return memory, state
'''

'\ndef init_memory(env, memory_size=1000000):\n    # Initialize the simulation\n    env.reset()\n    # Take one random step to get the pole and cart moving\n    state, reward, done, _ = env.step(env.action_space.sample())\n\n    memory = Memory(max_size=memory_size)\n\n    # Make a bunch of random actions and store the experiences\n    for ii in range(pretrain_length):\n        # Uncomment the line below to watch the simulation\n        # env.render()\n\n        # Make a random action\n        action = env.action_space.sample()\n        next_state, reward, done, _ = env.step(action)\n\n        if done:\n            # The simulation fails so no next state\n            next_state = jp.zeros(state.shape)\n            # Add experience to memory\n            memory.add((state, action, reward, next_state))\n            \n            # Start new episode\n            env.reset()\n            # Take one random step to get the pole and cart moving\n            state, reward, done, _ = env.step(e

In [28]:
class Policy:
    def __init__(self, rng, s_size=4, a_size=1, hidden_size=256, critic_lr=1e-3, actor_lr=1e-4):
        super(Policy, self).__init__()
        self.key = rng

        class Actor(nn.Module):
            features: Sequence[int]

            @nn.compact
            def __call__(self, x):
                x = nn.relu(nn.Dense(self.features[0])(x))
                x = nn.relu(nn.Dense(self.features[1])(x))
                x = 2*nn.sigmoid(2*nn.Dense(self.features[2])(x))-1
                return x

        self.actor = Actor(features=[hidden_size, hidden_size, a_size])

        class Critic(nn.Module):
            features: Sequence[int]

            @nn.compact
            def __call__(self, x):
                x = nn.relu(nn.Dense(self.features[0])(x))
                x = nn.relu(nn.Dense(self.features[1])(x))
                x = nn.Dense(self.features[2])(x)
                return x

        self.critic = Critic(features=[hidden_size, hidden_size, 1])

        def create_train_state(rng, model, learning_rate, input_size):
            """Creates initial `TrainState`."""
            params = model.init(rng, jp.ones((input_size, )))#['params']
            tx = optax.adam(learning_rate)
            return train_state.TrainState.create(
                apply_fn=model.apply, params=params, tx=tx)

        self.actor_ts = create_train_state(rng, self.actor, actor_lr, s_size)
        self.critic_ts = create_train_state(rng, self.critic, critic_lr, s_size + a_size)

        @jit
        def train_step(actor_ts, critic_ts, states, actions, targets):

            def loss_fn(actor_params, critic_params):
                critic_inputs_for_actor = jp.concatenate((states, actor_ts.apply_fn(actor_params, states)), axis=1)
                actor_loss = -jp.mean(critic_ts.apply_fn(critic_params, critic_inputs_for_actor))
                
                critic_inputs_for_critic = jp.concatenate((states, actions), axis=1)
                selectedq = critic_ts.apply_fn(critic_params, critic_inputs_for_critic)
                diff = selectedq - jax.lax.stop_gradient(targets)
                critic_loss = jp.mean(diff**2)

                return actor_loss, critic_loss

            actor_loss = lambda x: loss_fn(x, critic_ts.params)[0]
            critic_loss = lambda y: loss_fn(actor_ts.params, y)[1]
            al, ag = value_and_grad(actor_loss)(actor_ts.params)
            cl, cg = value_and_grad(critic_loss)(critic_ts.params)
            return actor_ts.apply_gradients(grads=ag), critic_ts.apply_gradients(grads=cg), al, cl

        self.train_fn = train_step

    def act(self, state):
        action = self.actor_ts.apply_fn(self.actor_ts.params, state)
        return action.item()


In [29]:
def train(rng, env, policy, ou_noise, n_episodes=100, max_t=500, batch_size = 128, memory_size=1000000, gamma = 0.99, soft_tau=1e-2):
    #memory, state = init_memory(env)
    memory = Memory(rng, max_size=memory_size)
    target_critic_params = copy.deepcopy(policy.critic_ts.params)
    target_actor_params = copy.deepcopy(policy.actor_ts.params)
    # print(f"Current critic params:{policy.critic_ts.params}")
    # print(f"Target critic params:{target_critic_params}")
    # print(f"Current actor params:{policy.actor_ts.params}")
    # print(f"Target actor params:{target_actor_params}")
    rewards_list = []

    for i_episode in range(1, n_episodes+1):
        state = env.reset()
        ou_noise.reset()
        episode_reward = 0
        for t in range(max_t):
            # import time
            action = policy.act(state)
            # print(f"action without ou noise:{action}")
            action = ou_noise.get_action(action, t)
            # print(f"action with ou noise:{action}")
            next_state, reward, done, _ = env.step(action)
            # print(f"Episode:{i_episode} Step:{t} State:{state} Action:{action} Reward:{reward} Done:{done}")
            
            memory.add((state, action, reward, next_state, done))

            if len(memory) > batch_size:
                # Sample mini-batch from memory
                # time.sleep(20)
                batch = memory.sample(batch_size)
                states = jp.array([each[0] for each in batch])
                actions = jp.array([each[1] for each in batch])
                rewards = jp.array([each[2] for each in batch])
                next_states = jp.array([each[3] for each in batch])

                # Train network
                target_next_actions = policy.actor_ts.apply_fn(target_actor_params, next_states)
                critic_inputs_for_critic = jp.concatenate((next_states, target_next_actions), axis=1)
                target_Qs = policy.critic_ts.apply_fn(target_critic_params, critic_inputs_for_critic)

                # Set target_Qs to 0 for states where episode ends
                episode_ends = (next_states == jp.zeros(states[0].shape)).all(axis=1)
                new_target_Qs = index_update(target_Qs, index[episode_ends], 0)
                target_Qs = new_target_Qs
                
                targets = rewards + gamma * target_Qs

                # print(f"Current critic params:{policy.critic_ts.params}")
                # print(f"Target critic params:{target_critic_params}")
                # print(f"Current actor params:{policy.actor_ts.params}")
                # print(f"Target actor params:{target_actor_params}")
                policy.actor_ts, policy.critic_ts, al, cl = policy.train_fn(policy.actor_ts, policy.critic_ts, states, actions, targets)
                # Update target network params
                update_fn = lambda current, target : soft_tau * current + (1.0 - soft_tau) * target
                target_critic_params = copy.deepcopy(jax.tree_multimap(update_fn, 
                                                                       policy.critic_ts.params,
                                                                       target_critic_params))
                target_actor_params = copy.deepcopy(jax.tree_multimap(update_fn,
                                                                      policy.actor_ts.params,
                                                                      target_actor_params))
                # print("**** UPDATED TARGET ****")
                # print(f"Current critic params:{policy.critic_ts.params}")
                # print(f"Target critic params:{target_critic_params}")
                # print(f"Current actor params:{policy.actor_ts.params}")
                # print(f"Target actor params:{target_actor_params}")


            episode_reward += reward
            state = next_state

            if done:
                break 

        rewards_list.append(episode_reward)
        print(f"Episode:{i_episode} Reward:{episode_reward} Average:{np.mean(rewards_list[-10:])}")


In [None]:
def main():
    #env = NormalizedActions(gym.make("Pendulum-v0").env)
    env = NormalizedActions(gym.make("Pendulum-v0"))
    env.seed(0)
    env.action_space.seed(0)
    np.random.seed(0)
    ou_noise = OUNoise(env.action_space)

    state_dim  = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    hidden_dim = 256

    print('observation space:', env.observation_space)
    print('action space:', env.action_space)
    rng = jax.random.PRNGKey(0)

    pi = Policy(rng, state_dim, action_dim, hidden_dim)
    scores = train(rng, env, pi, ou_noise)
    plot_scores(scores)

if __name__ == '__main__':
    main()

observation space: Box(-8.0, 8.0, (3,), float32)
action space: Box(-2.0, 2.0, (1,), float32)
Episode:1 Reward:-1566.9500056607076 Average:-1566.9500056607076
Episode:2 Reward:-1477.2577702923757 Average:-1522.1038879765415
Episode:3 Reward:-1614.7446501774293 Average:-1552.984142043504
Episode:4 Reward:-1377.9845718852332 Average:-1509.2342495039363
Episode:5 Reward:-1610.2781510766104 Average:-1529.4430298184711
Episode:6 Reward:-1508.5267176007364 Average:-1525.9569777821819
Episode:7 Reward:-966.1556651684064 Average:-1445.9853616944997
Episode:8 Reward:-1614.6203250647993 Average:-1467.0647321157871
Episode:9 Reward:-1201.047668118781 Average:-1437.5072805605641
Episode:10 Reward:-1591.7333980631565 Average:-1452.9298923108236
Episode:11 Reward:-1538.218097529041 Average:-1450.056701497657
Episode:12 Reward:-1341.9226603727966 Average:-1436.523190505699
Episode:13 Reward:-1603.9006352701397 Average:-1435.4387890149703
Episode:14 Reward:-1610.3638134480907 Average:-1458.676713171255