In [47]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [48]:
from collections import deque, namedtuple
from matplotlib import pyplot as plt
import numpy as np
import random
import gym
import pdb

from src.utils.OUNoise import OUNoise 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

## Define Networks

In [49]:
class Actor(nn.Module):
    """Deterministic policy network: maps a state vector to an action in [-1, 1].

    The network is a three-layer MLP whose final ``Tanh`` bounds every action
    component to [-1, 1]. Exploration noise is supplied by an ``OUNoise``
    instance stored on the module and is only applied inside :meth:`act`.

    Args:
        state_space: Number of input features (size of one state vector).
        action_space: Number of action components produced per state.
    """

    def __init__(self, state_space, action_space):
        super(Actor, self).__init__()

        # Exploration noise source; `act` only relies on its `.noise()` method.
        self.noise = OUNoise(action_space)

        self.head = nn.Sequential(
            nn.Linear(state_space, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_space),
            nn.Tanh()
        )

    def forward(self, x):
        """Return the differentiable, noise-free action tensor for states ``x``."""
        return self.head(x)

    def act(self, state, add_noise=True):
        """Select an action for interacting with the environment.

        The result is a NumPy array that is detached from autograd, so this
        method must not be used to build a loss; call the module itself
        (``actor(states)``) when gradients are needed.

        Args:
            state: State (or batch of states) as an ndarray, tensor or nested
                sequence of numbers; converted to a float32 tensor.
            add_noise: When True, add a sample from ``self.noise`` to the
                network output before clipping.

        Returns:
            np.ndarray: The action, clipped to [-1, 1].
        """
        # as_tensor accepts ndarrays, tensors and plain sequences alike,
        # whereas from_numpy only accepts ndarrays.
        state = torch.as_tensor(state, dtype=torch.float32)

        # Pure inference: skip building an autograd graph that would be
        # thrown away by the NumPy conversion anyway.
        with torch.no_grad():
            action = self.forward(state).cpu().numpy()
        if add_noise:
            # In-place add keeps the float32 dtype of the network output.
            action += self.noise.noise()

        return np.clip(action, -1, 1)


In [50]:
class Critic(nn.Module):
    """Action-value network: estimates Q(state, action) as one scalar per row.

    Two architectures are defined. ``single`` (used by :meth:`forward`)
    concatenates state and action and feeds them through one MLP. ``head`` and
    ``body`` form an alternative two-stage network (state encoder followed by
    a joint state/action MLP) that is kept available but is not used by
    :meth:`forward`.

    The output layers are linear. A final ``Tanh`` would clamp every Q-value
    to [-1, 1], which cannot represent discounted returns whose magnitude
    exceeds 1.

    Args:
        state_space: Number of features in one state vector.
        action_space: Number of components in one action vector.
    """

    def __init__(self, state_space, action_space):
        super(Critic, self).__init__()

        # Alternative architecture, stage 1: encode the state alone.
        self.head = nn.Sequential(
            nn.Linear(state_space, 64),
            nn.ReLU(),
        )

        # Alternative architecture, stage 2: combine encoded state and action.
        self.body = nn.Sequential(
            nn.Linear(64 + action_space, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # Architecture used by `forward`: one MLP over [state, action].
        self.single = nn.Sequential(
            nn.Linear(state_space + action_space, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x, actions):
        """Return Q-values for a batch of state/action pairs.

        Args:
            x: Batch of states, shape ``(batch, state_space)``; tensor,
                ndarray or nested sequence.
            actions: Batch of actions, shape ``(batch, action_space)``;
                tensor, ndarray or nested sequence.

        Returns:
            torch.Tensor: Q-values of shape ``(batch, 1)``.
        """
        # as_tensor keeps the autograd history of tensor inputs, so gradients
        # can flow from Q back into whatever produced `actions`;
        # torch.tensor(...) would copy and detach them.
        actions = torch.as_tensor(actions, dtype=torch.float32)
        x = torch.as_tensor(x, dtype=torch.float32)

        return self.single(torch.cat((x, actions), dim=1))

## Create environment with Agents

In [51]:
# NOTE: gym is already imported in the imports cell at the top of the notebook;
# this repeated import is redundant but harmless.
import gym

# env = gym.make("MountainCarContinuous-v0")
env = gym.make("Pendulum-v0")

# Length of one observation vector and of one action vector; these sizes are
# passed to the Actor and Critic constructors.
state_space = env.observation_space.shape[0]
action_space = env.action_space.shape[0]

print("State space: {}".format(state_space))
print("Action space: {}".format(action_space))

State space: 3
Action space: 1


In [52]:
# Local (trained) networks.
actor = Actor(state_space, action_space)
critic = Critic(state_space, action_space)

# Target networks: slowly-moving copies used to compute the critic's targets.
actor_target = Actor(state_space, action_space)
critic_target = Critic(state_space, action_space)

# Start the targets as exact copies of the local networks. Without this they
# begin from unrelated random weights, and the first soft updates (small tau)
# would bootstrap the critic from an arbitrary policy/value pair.
actor_target.load_state_dict(actor.state_dict())
critic_target.load_state_dict(critic.state_dict())

### Replay Buffer

In [53]:
class ReplayBuffer:
    """Fixed-capacity store of transitions with uniform random sampling.

    Once ``buffer_size`` transitions are stored, adding a new one discards the
    oldest (the storage is a ``deque`` with ``maxlen``).
    """

    def __init__(self, buffer_size):
        """Create an empty buffer holding at most ``int(buffer_size)`` items."""
        capacity = int(buffer_size)
        self.buffer = deque(maxlen=capacity)
        self.Experience = namedtuple("experience", ["state", "next_state", "action", "reward", "done"])

    def add(self, state, next_state, action, reward, done):
        """Append one transition; note that ``next_state`` precedes ``action``."""
        self.buffer.append(self.Experience(state, next_state, action, reward, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            tuple: ``(states, next_states, actions, rewards, dones)``, each a
            float tensor whose first dimension is ``batch_size``.
        """
        batch = random.sample(self.buffer, batch_size)

        def column(field):
            # Gather one field across the batch and stack it into a float tensor.
            return torch.stack([torch.tensor(getattr(exp, field)) for exp in batch]).float()

        return tuple(column(name) for name in ("state", "next_state", "action", "reward", "done"))

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

## Computing loss and updating Networks

In [54]:
# One Adam optimiser per trained network; the critic uses a 10x larger
# learning rate than the actor. The target networks get no optimiser: they
# are only changed by update_target_networks().
actor_optimiser = optim.Adam(actor.parameters(), lr=1e-4)
critic_optimiser = optim.Adam(critic.parameters(), lr=1e-3)

In [38]:
def learn():
    """Run one training step on a freshly sampled minibatch.

    Samples ``batch_size`` transitions from the global buffer ``mem``, then
    updates the actor, the critic, and finally the target networks, in that
    order.
    """
    batch = mem.sample(batch_size)
    states, next_states, actions, rewards, dones = batch

    # Actor first, then critic, matching the original ordering.
    update_actor(states=states)
    update_critic(states=states, next_states=next_states, actions=actions, rewards=rewards, dones=dones)

    update_target_networks()

### Actor Update

<img src="./img/ddpg/actor_update.png" alt="Drawing" style="height: 50px;"/>

In [39]:
def update_actor(states):
    """Take one gradient step on the actor to increase the critic's Q estimate.

    The loss is ``-mean(Q(s, actor(s)))``. The actions are taken from the
    actor's differentiable forward pass. ``actor.act`` must not be used here:
    it returns a NumPy array with exploration noise added, which has no
    gradient path back to the actor's weights, so the optimiser step would be
    a no-op.

    NOTE: gradients reach the actor only if ``critic.forward`` keeps
    ``actions_pred`` attached to the autograd graph, i.e. it must not re-wrap
    its tensor inputs with ``torch.tensor(...)``.

    Args:
        states: Float tensor of shape ``(batch, state_space)``.
    """
    actions_pred = actor(states)
    loss = -critic(states, actions_pred).mean()

    actor_optimiser.zero_grad()
    loss.backward()
    actor_optimiser.step()

### Critic Update

Critic Loss:
<img src="./img/ddpg/critic_loss.png" alt="Drawing" style="height: 30px;"/>

Critic $y_i$:
<img src="./img/ddpg/critic_yi.png" alt="Drawing" style="height: 35px;"/>

In [40]:
def update_critic(states, next_states, actions, rewards, dones):
    """Take one gradient step on the critic towards the one-step TD target.

    Target: ``y_i = r + gamma * Q_target(s', actor_target(s')) * (1 - done)``.

    Args:
        states: Float tensor ``(batch, state_space)``.
        next_states: Float tensor ``(batch, state_space)``.
        actions: Float tensor ``(batch, action_space)`` of the actions taken.
        rewards: Float tensor with one reward per transition.
        dones: Float tensor with 1.0 where the transition ended the episode.
    """
    # The target is a constant for this step, so no graph is built through the
    # target networks. Target actions come from the plain forward pass:
    # `actor_target.act` would add exploration noise to them by default.
    with torch.no_grad():
        next_actions = actor_target(next_states)
        next_Q = critic_target(next_states, next_actions).reshape(-1)
        # Flatten every term to (batch,) so a stray trailing dimension cannot
        # silently broadcast the target into a (batch, batch) matrix.
        y_i = rewards.reshape(-1) + gamma * next_Q * (1 - dones.reshape(-1))

    expected_Q = critic(states, actions).reshape(-1)

    loss = F.mse_loss(expected_Q, y_i)

    critic_optimiser.zero_grad()
    loss.backward()
    critic_optimiser.step()

### Copy Weights Over

In [41]:
def update_target_networks():
    """Soft-update both target networks towards their local counterparts.

    Every target parameter becomes ``tau * local + (1 - tau) * target``, using
    the global ``tau``.
    """
    keep = 1.0 - tau
    network_pairs = ((actor_target, actor), (critic_target, critic))

    for target_net, local_net in network_pairs:
        for target, local in zip(target_net.parameters(), local_net.parameters()):
            blended = tau*local.data + keep*target.data
            target.data.copy_(blended)

## Runner

In [42]:
# Number of training episodes run by the training loop.
max_e = 50
# Maximum number of environment steps per episode.
max_t = 200
# Capacity of the replay buffer (oldest transitions are dropped beyond this).
buffer_size = 50000
# Number of transitions sampled for each learning step.
batch_size = 32
# A learning step runs on every step t where t % learn_every == 0.
learn_every = 1

# Discount factor applied to the target critic's next-state value.
gamma = 0.99
# Soft-update rate: weight of the local network when blending into the target.
tau = 1e-2

In [58]:
# Replay buffer shared by the training loop (writes) and learn() (samples).
mem = ReplayBuffer(buffer_size)

# score_log keeps every episode's total reward (plotted at the end);
# score_window keeps only the most recent 100 for the running mean.
score_log = []
score_window = deque(maxlen=100)

In [None]:
# Training loop: act in the environment, store each transition, and learn from
# replayed minibatches once the buffer holds more than one batch.
for episode in range(max_e):
    state = env.reset()
    score = 0

    for t in range(max_t):
        action = actor.act(state)
        next_state, reward, done, _ = env.step(action)
        score += reward

        mem.add(state, next_state, action, reward, done)

        buffer_ready = len(mem) > batch_size
        if buffer_ready and t % learn_every == 0:
            learn()

        if done:
            break

        state = next_state

    score_log.append(score)
    score_window.append(score)

    # Overwrite the same console line every episode; every 100th episode is
    # additionally printed with a newline so it stays visible.
    progress = "\rEpsiode: {:.1f}\tWindow Score: {:.4f}\tScore: {:.4f}".format(episode, np.mean(score_window), score)
    print(progress, end="")

    if episode % 100 == 0:
        print(progress)



  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


Epsiode: 0.0	Window Score: nan	Score: -1133.2210
Epsiode: 5.0	Window Score: -1380.9968	Score: -1162.8464

In [None]:
# Total reward per training episode. Explicit axes and labels so the figure
# can be read on its own; the trailing `;` suppresses the repr output.
fig, ax = plt.subplots()
ax.plot(score_log)
ax.set(title="Score per episode", xlabel="Episode", ylabel="Score");