# Imports and torch initialization

In [None]:
# https://medium.com/analytics-vidhya/rendering-openai-gym-environments-in-google-colab-9df4e7d6f99f
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!pip install -U colabgymrender
!pip install pyglet==1.5.11

!pip install box2d-py
!pip install gym[Box_2D]

In [None]:
import numpy as np
import gym

from numpy.random import random, randint
from colabgymrender.recorder import Recorder
from fastprogress.fastprogress import master_bar, progress_bar

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.distributions import Normal

In [None]:
# Seed every RNG this notebook draws from: torch initialises the network
# weights, while NumPy drives replay-buffer sampling, the epsilon-greedy
# draw and the exploration noise.
SEED = 42
_ = torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


# Creating policy and env

In [None]:
class Actor(nn.Module):
    """Policy network: maps a state to an action bounded by ``max_action``."""

    def __init__(self, state_dim, action_dim, hidden_dim, max_action):
        super().__init__()

        # Factor applied to the tanh output in forward(), so every action
        # component lies in [-max_action, max_action].
        self.max_action = max_action

        # Two ReLU hidden layers followed by a tanh-squashed output layer.
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        ]
        self.get_action = nn.Sequential(*layers)

    def forward(self, state):
        """
        Returns the action for ``state``, rescaled by ``max_action``
        """
        squashed = self.get_action(state)
        return squashed * self.max_action

class Critic(nn.Module):
    """Q-network: scores a (state, action) pair with a single value."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()

        # The network consumes the state and the action concatenated
        # along the feature axis.
        input_dim = state_dim + action_dim

        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.get_action_value = nn.Sequential(*layers)

    def forward(self, state, action):
        """
        Returns the value of taking ``action`` in ``state``
        """
        state_action = torch.cat((state, action), dim=1)
        return self.get_action_value(state_action)

In [None]:
class DDPG:
    """
    DDPG agent.

    Bundles an actor/critic pair, their softly-updated target copies, one
    Adam optimizer per trained network and a list-based replay buffer of
    ``(state, action, reward, next_state, done)`` steps.
    """

    def __init__(self,
                 state_dim, action_dim, hidden_dim, max_action,
                 buffer_size, batch_size,
                 gamma, soft_update_c):
        """
        DDPG policy

        Args:
            state_dim: size of the state vector fed to both networks.
            action_dim: size of the action vector produced by the actor.
            hidden_dim: width of the hidden layers of actor and critic.
            max_action: scale the actor applies to its tanh output.
            buffer_size: number of stored steps above which the oldest
                half of the replay buffer is discarded.
            batch_size: number of steps drawn per network update.
            gamma: discount applied to the bootstrapped next-state value.
            soft_update_c: weight kept from the target network on each
                soft update (the trained network contributes
                ``1 - soft_update_c``).
        """

        self.state_dim = state_dim
        self.action_dim = action_dim

        self.replay_buffer = []
        self.buffer_size = buffer_size
        self.batch_size = batch_size

        # Reward decay
        self.gamma = gamma
        # Coefficient used in soft model updates
        self.soft_update_c = soft_update_c

        # Actor model
        self.actor = Actor(state_dim, action_dim, hidden_dim, max_action).to(device)
        # Frozen actor
        self.target_actor = Actor(state_dim, action_dim, hidden_dim, max_action).to(device)

        # Targets start as exact copies of the trained networks.
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optim = Adam(self.actor.parameters(), lr=1e-2)

        # Critic model
        self.critic = Critic(state_dim, action_dim, hidden_dim).to(device)
        # Frozen critic
        self.target_critic = Critic(state_dim, action_dim, hidden_dim).to(device)

        self.target_critic.load_state_dict(self.critic.state_dict())
        self.critic_optim = Adam(self.critic.parameters(), lr=1e-2)

    def train(self, mode=True):
        """
        Switches all four networks between training (``mode=True``) and
        evaluation (``mode=False``)
        """

        self.actor.train(mode)
        self.target_actor.train(mode)
        self.critic.train(mode)
        self.target_critic.train(mode)

    def add_to_buffer(self, step):
        """
        Adds new step to buffer

        ``step`` is indexed by position in ``sample_step``:
        ``(state, action, reward, next_state, done)``.
        """

        self.replay_buffer.append(step)

        # Remove first half if buffer overflowed
        if len(self.replay_buffer) > self.buffer_size:
            self.replay_buffer = self.replay_buffer[len(self.replay_buffer) // 2:]

    def sample_step(self):
        """
        Samples a batch of steps from replay buffer

        Indices are drawn with replacement, so the batch always holds
        ``batch_size`` steps even when the buffer is smaller than that.

        Returns:
            A list of five numpy arrays, one per step field, each with
            ``batch_size`` rows.

        Raises:
            ValueError: if the replay buffer is empty.
        """

        if not self.replay_buffer:
            raise ValueError("Cannot sample from an empty replay buffer")

        idx = np.random.choice(len(self.replay_buffer), self.batch_size)

        # np.asarray instead of np.array(..., copy=False): under NumPy >= 2.0
        # copy=False raises whenever a copy is unavoidable, which is always
        # the case when converting a Python list.
        return [
            np.asarray([self.replay_buffer[i][field] for i in idx])
            for field in range(5)
        ]

    def select_action(self, state):
        """
        Selects and returns an action as a flat numpy array

        ``state`` may be a numpy array, a sequence or an existing tensor;
        it is converted to a float32 tensor on ``device`` before the actor
        is evaluated.
        """

        # as_tensor accepts tensors without the copy-construct warning that
        # torch.tensor(tensor) emits.
        state = torch.as_tensor(state, dtype=torch.float32, device=device)
        # Acting needs no gradients, so skip building the autograd graph.
        with torch.no_grad():
            action = self.actor(state)
        return action.cpu().numpy().reshape(-1)

    def _soft_update(self, source, target):
        """
        Moves ``target`` parameters towards ``source`` in place:
        ``target = c * target + (1 - c) * source`` with ``c = soft_update_c``
        """

        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(self.soft_update_c * target_param +
                                   (1 - self.soft_update_c) * param)

    def sync_networks(self):
        """
        Synchronizes frozen networks. Uses soft update
        """

        self._soft_update(self.actor, self.target_actor)
        self._soft_update(self.critic, self.target_critic)

    def update_networks(self):
        """
        Updates networks with steps from current replay buffer

        Performs one critic step, one actor step and one soft update of
        the target networks on a freshly sampled batch.

        Returns:
            ``(critic_loss, actor_loss)`` as Python floats, or
            ``(None, None)`` when a KeyboardInterrupt arrives during the
            update; callers must handle the ``None`` case.
        """

        try:
            raw_batch = self.sample_step()

            state, action, reward, next_state, done = (
                torch.as_tensor(field, dtype=torch.float32, device=device)
                for field in raw_batch
            )
            # Reward and done flags become column vectors so they line up
            # with the critic output, whose last layer has one unit.
            reward = reward.unsqueeze(1)
            done = done.unsqueeze(1)

            # Getting target reward
            with torch.no_grad():
                # Bootstrapped next-state value, zeroed for terminal steps.
                new_reward = (1 - done) * self.target_critic(next_state,
                                                             self.target_actor(next_state))
                target_reward = reward + self.gamma * new_reward

            # Updating critic
            predicted_reward = self.critic(state, action)
            critic_loss = F.mse_loss(predicted_reward, target_reward)

            self.critic_optim.zero_grad()
            critic_loss.backward()
            self.critic_optim.step()

            # Updating actor: raise the critic's value of the actor's actions.
            actor_loss = -self.critic(state, self.actor(state)).mean()

            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            # Synchronizes frozen models
            self.sync_networks()

            return critic_loss.item(), actor_loss.item()

        except KeyboardInterrupt:
            # NOTE(review): the interrupt is swallowed here and reported via
            # the (None, None) return value rather than re-raised.
            return None, None

In [None]:
# Environment id shared by the recorded and the training environment.
env_id = 'Pendulum-v0'
recorder_dir = './video'

# Recorder-wrapped environment, used by the evaluation cell.
env_rec = Recorder(gym.make(env_id), recorder_dir)
# Training environment: the `.env` attribute of what gym.make returns
# (presumably the environment without gym's outer wrapper - TODO confirm).
env = gym.make(env_id).env

In [None]:
# Network input/output sizes, read from the environment's spaces.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
# First component of the action space's upper bound; handed to the actor
# as its output scale.
max_action = env.action_space.high[0]

# Training code

In [None]:
def get_window_avg(arr, window_size):
    """
    Returns the mean of the last ``window_size`` entries of ``arr``.

    When ``arr`` holds fewer than ``window_size`` entries the mean is
    taken over all of them; an empty ``arr`` yields ``nan``.

    Raises:
        ValueError: if ``window_size`` is not positive. A negative value
            would otherwise slice from the wrong end and divide by a
            negative count.
    """
    if window_size <= 0:
        raise ValueError("window_size must be positive")

    window = arr[-window_size:]
    if len(window) == 0:
        return float("nan")
    return sum(window) / len(window)

In [None]:
def train(env, policy, train_noise_std, use_epsilon_greedy=True,
          warmup=1000, max_episodes=10000, max_steps=1000, episodes_per_log=100):
    """
    Collects experience in ``env`` and trains ``policy`` on it.

    Every step is stored in the policy's replay buffer. During the first
    ``warmup`` episodes actions are sampled from the action space and no
    network update happens; afterwards one ``policy.update_networks()``
    call is made at the end of each episode.

    Args:
        env: environment with the 4-tuple ``step`` API
            (``next_state, reward, done, info``).
        policy: agent exposing ``train``, ``select_action``,
            ``add_to_buffer`` and ``update_networks``.
        train_noise_std: standard deviation of the Gaussian noise added
            to the policy's actions.
        use_epsilon_greedy: if true, 5% of post-warmup steps take a
            sampled action instead of the policy's.
        warmup: number of initial episodes acting purely at random.
        max_episodes: total number of episodes to run.
        max_steps: step limit per episode.
        episodes_per_log: print running averages every this many episodes.

    Returns:
        ``(rewards, critic_losses, actor_losses)``: one entry per
        post-warmup episode.
    """

    # Running average window size
    window_size = 100

    policy.train()

    rewards = []
    critic_losses = []
    actor_losses = []

    for episode in progress_bar(range(1, max_episodes + 1)):
        # Initializing new episode
        state = env.reset()
        episode_reward = 0

        for _ in range(max_steps):
            # Choosing an action
            if episode <= warmup or (use_epsilon_greedy and np.random.rand() < 0.05):
                # Warmup, or the 5% epsilon-greedy branch: act at random
                action = env.action_space.sample()
            else:
                action = policy.select_action(state)
                action = action + np.random.normal(0, train_noise_std, action.shape)
                # Clip to the environment's own bounds; a fixed [-1, 1]
                # range would cut off part of the action range whenever
                # the bounds are wider than that.
                action = np.clip(action, env.action_space.low, env.action_space.high)

            # Taking an action
            next_state, reward, done, _ = env.step(action)
            # Adding step to buffer
            policy.add_to_buffer((state, action, reward, next_state, float(done)))
            # Updating state
            state = next_state

            # Updating rewards
            episode_reward += reward

            if done:
                break

        if episode > warmup:
            # Updating networks
            critic_loss, actor_loss = policy.update_networks()

            # update_networks reports an interrupted update as (None, None);
            # stop cleanly instead of feeding None into the running averages.
            if critic_loss is None or actor_loss is None:
                print("Training interrupted")
                return rewards, critic_losses, actor_losses

            # Updating logs
            critic_losses.append(critic_loss)
            actor_losses.append(actor_loss)
            rewards.append(episode_reward)

            if episode % episodes_per_log == 0:
                # Warmup episodes are not logged, so count what was
                # actually averaged rather than the episode index.
                n_logged = min(window_size, len(rewards))
                reward_window_avg = get_window_avg(rewards, window_size)
                critic_loss_window_avg = get_window_avg(critic_losses, window_size)
                actor_loss_window_avg = get_window_avg(actor_losses, window_size)

                print(f"Average reward of last {n_logged} episodes: {reward_window_avg}")
                print(f"Average critic loss of last {n_logged} episodes: {critic_loss_window_avg}")
                print(f"Average actor loss of last {n_logged} episodes: {actor_loss_window_avg}")
                print("---")

            # Useful if you're doing Box2D tasks
            # if len(rewards) >= window_size and get_window_avg(rewards, window_size) >= 200:
            #     print("Solved!")
            #     return rewards, critic_losses, actor_losses

    print("Max episodes reached")
    return rewards, critic_losses, actor_losses

# Training

In [None]:
# Agent hyper-parameters, spelled out by keyword so each value is
# self-describing.
policy = DDPG(
    state_dim=state_dim,
    action_dim=action_dim,
    hidden_dim=128,
    max_action=max_action,
    buffer_size=1e6,
    batch_size=256,
    gamma=0.99,
    soft_update_c=0.995,
)

In [None]:
# First run: 2000 random warm-up episodes, then noisy epsilon-greedy training.
rewards, critic_losses, actor_losses = train(
    env, policy,
    train_noise_std=0.2,
    use_epsilon_greedy=True,
    warmup=2000,
    max_episodes=4000,
    max_steps=500,
)

In [None]:
# Second run: no warm-up, no epsilon-greedy, smaller action noise.
# Rebinds the three log lists returned by the previous run.
rewards, critic_losses, actor_losses = train(
    env, policy,
    train_noise_std=0.05,
    use_epsilon_greedy=False,
    warmup=0,
    max_episodes=2000,
    max_steps=500,
)

# Evaluation

In [None]:
# Put all four networks of the agent into evaluation mode.
policy.train(mode=False)

In [None]:
# In case you saved the model
# torch.load unpickles the file, so only load checkpoints you trust.
# NOTE(review): recent PyTorch releases may refuse to unpickle whole modules
# by default (weights_only) - confirm against the installed version.
try:
    # Load both files before touching the policy, so a missing critic
    # cannot leave it with a new actor and the old critic.
    loaded_actor = torch.load("actor.pth", map_location=device)
    loaded_critic = torch.load("critic.pth", map_location=device)
except FileNotFoundError:
    # Nothing was saved: keep the networks trained above so that
    # Restart & Run All still reaches the evaluation cell.
    print("No saved checkpoint found (actor.pth / critic.pth); using the trained policy")
else:
    policy.actor = loaded_actor
    policy.critic = loaded_critic
    # Put the freshly loaded networks in evaluation mode as well.
    policy.train(False)

In [None]:
# Roll out one recorded episode and sum the reward of its first 100 steps.
done = False
# Number of steps taken so far. Starting from 0 makes `n <= 100` cover
# exactly the first 100 steps (starting from 1 counted only 99).
n = 0
observation = env_rec.reset()
total_reward = 0
while not done:
    # select_action converts the observation to a tensor itself.
    action = policy.select_action(observation)
    observation, reward, done, _ = env_rec.step(action)
    n += 1

    if n <= 100:
        total_reward += reward

env_rec.play()
print(f"Reward over the first {min(n, 100)} steps: {total_reward}")