In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from collections import deque, namedtuple
from matplotlib import pyplot as plt
import numpy as np
import random
import gym
import pdb

from src.utils.OUNoise import OUNoise 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Define Networks

In [4]:
class Actor(nn.Module):
    def __init__(self, state_space, action_space):
        super(Actor, self).__init__()
        
        self.noise = OUNoise(action_space)
        
        self.head = nn.Sequential(
            nn.Linear(state_space, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, action_space),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.head(x)
    
    def act(self, state, add_noise=True):
        
        state = torch.from_numpy(state).float()
        
        action = self.forward(state).cpu().data.numpy()
        if add_noise:
            action += self.noise.noise()

        return np.clip(action, -1, 1)


In [5]:
# class Critic(nn.Module):
#     def __init__(self, state_space, action_space):
#         super(Critic, self).__init__()
        
#         self.head = nn.Sequential(
#             nn.Linear(state_space, 1024),
#             nn.ReLU(),
#         )
        
#         self.body = nn.Sequential(
#             nn.Linear(1024 + action_space, 512),
#             nn.ReLU(),
#             nn.Linear(512, 300),
#             nn.ReLU(),
#             nn.Linear(300, 1),
#         )
        
#         self.single = nn.Sequential(
#             nn.Linear(state_space + action_space, 256),
#             nn.ReLU(),
#             nn.Linear(256, 256),
#             nn.ReLU(),
#             nn.Linear(256, 1),
#         )
    
#     def forward(self, x, actions):
#         actions = torch.tensor(actions).float()
#         x = torch.tensor(x).float()
        
#         x = self.head(x)
#         x = self.body(torch.cat((x, actions), dim=1))
        
# #         x = self.single(torch.cat((x, actions), dim=1))
#         return x

In [6]:
class Critic(nn.Module):

    def __init__(self, obs_dim, action_dim):
        super(Critic, self).__init__()

        self.obs_dim = obs_dim
        self.action_dim = action_dim

        self.linear1 = nn.Linear(self.obs_dim, 1024)
        self.linear2 = nn.Linear(1024 + self.action_dim, 512)
        self.linear3 = nn.Linear(512, 300)
        self.linear4 = nn.Linear(300, 1)

    def forward(self, x, a):
        x = F.relu(self.linear1(x))
        xa_cat = torch.cat([x,a], 1)
        xa = F.relu(self.linear2(xa_cat))
        xa = F.relu(self.linear3(xa))
        qval = self.linear4(xa)

        return qval

# class Actor(nn.Module):

#     def __init__(self, obs_dim, action_dim):
#         super(Actor, self).__init__()

#         self.obs_dim = obs_dim
#         self.action_dim = action_dim

#         self.linear1 = nn.Linear(self.obs_dim, 512)
#         self.linear2 = nn.Linear(512, 128)
#         self.linear3 = nn.Linear(128, self.action_dim)

#     def forward(self, obs):
#         x = F.relu(self.linear1(obs))
#         x = F.relu(self.linear2(x))
#         x = torch.tanh(self.linear3(x))

#         return x
    
#     def act(self, obs, add_noise):
#         state = torch.FloatTensor(obs).unsqueeze(0)
#         action = self.forward(state)
#         action = action.squeeze(0).cpu().detach().numpy()

#         return action
    




## Create environment with Agents

In [7]:
import gym

# env = gym.make("MountainCarContinuous-v0")
env = gym.make("Pendulum-v0")

state_space = env.observation_space.shape[0]
action_space = env.action_space.shape[0]

print("State space: {}".format(state_space))
print("Action space: {}".format(action_space))

State space: 3
Action space: 1


  result = entry_point.load(False)


In [8]:
actor = Actor(state_space, action_space).to(device)
critic = Critic(state_space, action_space).to(device)

actor_target = Actor(state_space, action_space).to(device)
critic_target = Critic(state_space, action_space).to(device)

### Replay Buffer

In [9]:
class BasicBuffer:

  def __init__(self, max_size):
      self.max_size = max_size
      self.buffer = deque(maxlen=max_size)

  def push(self, state, action, reward, next_state, done):
      experience = (state, action, np.array([reward]), next_state, done)
      self.buffer.append(experience)

  def sample(self, batch_size):
      state_batch = []
      action_batch = []
      reward_batch = []
      next_state_batch = []
      done_batch = []

      batch = random.sample(self.buffer, batch_size)

      for experience in batch:
          state, action, reward, next_state, done = experience
          state_batch.append(state)
          action_batch.append(action)
          reward_batch.append(reward)
          next_state_batch.append(next_state)
          done_batch.append(done)

      return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)

  def __len__(self):
      return len(self.buffer)

## Computing loss and updating Networks

In [10]:
actor_optimiser = optim.Adam(actor.parameters(), lr=1e-3)
critic_optimiser = optim.Adam(critic.parameters(), lr=1e-3)

In [11]:
def learn():

    state_batch, action_batch, reward_batch, next_state_batch, masks = mem.sample(batch_size)
    
    state_batch = torch.FloatTensor(state_batch)
    action_batch = torch.FloatTensor(action_batch)
    reward_batch = torch.FloatTensor(reward_batch)
    next_state_batch = torch.FloatTensor(next_state_batch)
    masks = torch.FloatTensor(masks)

    curr_Q = critic.forward(state_batch, action_batch)

    next_actions = actor_target.forward(next_state_batch)
    next_Q = critic_target.forward(next_state_batch, next_actions.detach())
    expected_Q = reward_batch + gamma * next_Q

    # update critic
    q_loss = F.mse_loss(curr_Q, expected_Q.detach())

    critic_optimiser.zero_grad()
    q_loss.backward() 
    critic_optimiser.step()

    # update actor
    policy_loss = -critic.forward(state_batch, actor.forward(state_batch)).mean()

    actor_optimiser.zero_grad()
    policy_loss.backward()
    actor_optimiser.step()

    # update target networks 
    for target_param, param in zip(actor_target.parameters(), actor.parameters()):
        target_param.data.copy_(param.data * tau + target_param.data * (1.0 - tau))

    for target_param, param in zip(critic_target.parameters(), critic.parameters()):
        target_param.data.copy_(param.data * tau + target_param.data * (1.0 - tau))

# def learn():
# #     states, next_states, actions, rewards, dones = mem.sample(batch_size)
#     state_batch, action_batch, reward_batch, next_state_batch, masks = mem.sample(batch_size)
#     update_actor(states=state_batch)
    
#     update_critic(
#         states=state_batch,
#         next_states=next_state_batch,
#         actions=action_batch,
#         rewards=reward_batch,
#         dones=masks
#     )
    
#     update_target_networks()

### Actor Update

<img src="./img/ddpg/actor_update.png" alt="Drawing" style="height: 50px;"/>

In [12]:
def update_actor(states):
#     actions_pred = actor.act(states.numpy())
#     loss = -critic(states, actions_pred).mean()

    actions_pred = actor.act(np.array(states))
    loss = -critic(np.array(states), actions_pred).mean()
    
    actor_optimiser.zero_grad()
    loss.backward()
    actor_optimiser.step()

### Critic Update

Critic Loss:
<img src="./img/ddpg/critic_loss.png" alt="Drawing" style="height: 30px;"/>

Critic $y_i$:
<img src="./img/ddpg/critic_yi.png" alt="Drawing" style="height: 35px;"/>

In [13]:
def update_critic(states, next_states, actions, rewards, dones):
#     next_actions = actor_target.act(next_states.numpy())
    next_actions = actor_target.act(np.array(next_states))
    
    y_i =  rewards + ( gamma * critic_target(np.array(next_states), next_actions).squeeze() * (1-dones ))
    expected_Q = critic(np.array(states), actions).squeeze()

    loss = F.mse_loss(y_i, expected_Q)
    
    critic_optimiser.zero_grad()
    loss.backward()
    critic_optimiser.step()

### Copy Weights Over

In [14]:
def update_target_networks():
    for target, local in zip(actor_target.parameters(), actor.parameters()):
        target.data.copy_(tau*local.data + (1.0-tau)*target.data)
        
    for target, local in zip(critic_target.parameters(), critic.parameters()):
        target.data.copy_(tau*local.data + (1.0-tau)*target.data)

## Runner

In [15]:
max_e = 1000
max_t = 500
buffer_size = 100000
batch_size = 32
learn_every = 1

gamma = 0.99
tau = 1e-2

In [16]:
mem = BasicBuffer(buffer_size)

score_log = []
score_window = deque(maxlen=100)

In [18]:
for episode in range(max_e):
    state = env.reset()
    score = 0
    for t in range(max_t):
        action = actor.act(state, add_noise=False)
        next_state, reward, done, _ = env.step(action)
        mem.push(state, action, reward, next_state, done)
        score += reward


        
        if len(mem) > batch_size and t % learn_every == 0:
            learn()

        if done:
            break;
        
        state = next_state
    
    score_log.append(score)
    score_window.append(score)
    
#     print("\rEpsiode: {:.1f}\tWindow Score: {:.4f}\tScore: {:.4f}".format(episode, np.mean(score_window), score), end="")    
#     if (episode % 100 == 0):
#         print("\rEpsiode: {:.1f}\tWindow Score: {:.4f}\tScore: {:.4f}".format(episode, np.mean(score_window), score))
    print("\rEpsiode: {:.1f}\tWindow Score: {:.4f}\tScore: {:.4f}".format(episode, np.mean(score_window), score))



Epsiode: 0.0	Window Score: -1167.3611	Score: -1222.7608
Epsiode: 1.0	Window Score: -1327.0833	Score: -1646.5275
Epsiode: 2.0	Window Score: -1338.3269	Score: -1372.0579
Epsiode: 3.0	Window Score: -1415.6801	Score: -1725.0928
Epsiode: 4.0	Window Score: -1378.5930	Score: -1193.1575
Epsiode: 5.0	Window Score: -1350.1367	Score: -1179.3987
Epsiode: 6.0	Window Score: -1324.6458	Score: -1146.2098
Epsiode: 7.0	Window Score: -1276.9889	Score: -895.7332
Epsiode: 8.0	Window Score: -1248.7883	Score: -994.9829
Epsiode: 9.0	Window Score: -1228.2605	Score: -1022.9833
Epsiode: 10.0	Window Score: -1188.1958	Score: -747.4842
Epsiode: 11.0	Window Score: -1145.8864	Score: -638.1735


KeyboardInterrupt: 

In [None]:
plt.plot(score_log)