In [31]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
from collections import deque, namedtuple
from matplotlib import pyplot as plt
import numpy as np
import random
import gym
import pdb

from src.utils.OUNoise import OUNoise 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

In [33]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Define Networks

In [34]:
class Actor(nn.Module):
    def __init__(self, state_space, action_space):
        super(Actor, self).__init__()
        
        self.noise = OUNoise(action_space)
        
        self.head = nn.Sequential(
            nn.Linear(state_space, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, action_space),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.head(x)
    
    def act(self, state, add_noise=True):
        state = torch.from_numpy(state).float()
        
        action = self.forward(state).cpu().data.numpy()
        if add_noise:
            action += self.noise.noise()

        return np.clip(action, -1, 1)


In [47]:
# class Critic(nn.Module):
#     def __init__(self, state_space, action_space):
#         super(Critic, self).__init__()
        
#         self.head = nn.Sequential(
#             nn.Linear(state_space, 1024),
#             nn.ReLU(),
#         )
        
#         self.body = nn.Sequential(
#             nn.Linear(1024 + action_space, 512),
#             nn.ReLU(),
#             nn.Linear(512, 300),
#             nn.ReLU(),
#             nn.Linear(300, 1),
#             nn.Tanh()
#         )
        
#         self.single = nn.Sequential(
#             nn.Linear(state_space + action_space, 256),
#             nn.ReLU(),
#             nn.Linear(256, 256),
#             nn.ReLU(),
#             nn.Linear(256, 1),
#             nn.Tanh()
#         )
    
#     def forward(self, x, actions):
#         actions = torch.tensor(actions).float()
#         x = torch.tensor(x).float()
        
#         x = self.head(x)
#         x = self.body(torch.cat((x, actions), dim=1))
        
# #         x = self.single(torch.cat((x, actions), dim=1))
#         return x

class Critic(nn.Module):

    def __init__(self, obs_dim, action_dim):
        super(Critic, self).__init__()

        self.obs_dim = obs_dim
        self.action_dim = action_dim

        self.linear1 = nn.Linear(self.obs_dim, 1024)
        self.linear2 = nn.Linear(1024 + self.action_dim, 512)
        self.linear3 = nn.Linear(512, 300)
        self.linear4 = nn.Linear(300, 1)

    def forward(self, x, a):
        a = torch.tensor(a).float()
        x = torch.tensor(x).float()

        x = F.relu(self.linear1(x))
        xa_cat = torch.cat([x,a], 1)
        xa = F.relu(self.linear2(xa_cat))
        xa = F.relu(self.linear3(xa))
        qval = self.linear4(xa)

        return qval

## Create environment with Agents

In [48]:
import gym

# env = gym.make("MountainCarContinuous-v0")
env = gym.make("Pendulum-v0")

state_space = env.observation_space.shape[0]
action_space = env.action_space.shape[0]

print("State space: {}".format(state_space))
print("Action space: {}".format(action_space))

State space: 3
Action space: 1


In [49]:
actor = Actor(state_space, action_space).to(device)
critic = Critic(state_space, action_space).to(device)

actor_target = Actor(state_space, action_space).to(device)
critic_target = Critic(state_space, action_space).to(device)

### Replay Buffer

In [50]:
class ReplayBuffer:
    def __init__(self, buffer_size):
        self.buffer = deque(maxlen=int(buffer_size))
        self.Experience = namedtuple("experience", ["state", "next_state", "action", "reward", "done"])
    
    def add(self, state, next_state, action, reward, done):
        e = self.Experience(state, next_state, action, reward, done)
        self.buffer.append(e)
    
    def sample(self, batch_size):
        samples = random.sample(self.buffer, batch_size)
        
        states = torch.stack([ torch.tensor(exp.state) for exp in samples]).float().to(device)
        next_states = torch.stack([ torch.tensor(exp.next_state) for exp in samples]).float().to(device)
        actions = torch.stack([ torch.tensor(exp.action) for exp in samples]).float().to(device)
        rewards = torch.stack([ torch.tensor(exp.reward) for exp in samples]).float().to(device)
        dones = torch.stack([ torch.tensor(exp.done) for exp in samples]).float().to(device)
        
        return (states, next_states, actions, rewards, dones)
    
    def __len__(self):
        return len(self.buffer)

## Computing loss and updating Networks

In [51]:
actor_optimiser = optim.Adam(actor.parameters(), lr=1e-4)
critic_optimiser = optim.Adam(critic.parameters(), lr=1e-3)

In [52]:
def learn():
    states, next_states, actions, rewards, dones = mem.sample(batch_size)
    
    update_actor(states=states)
    
    update_critic(
        states=states,
        next_states=next_states,
        actions=actions,
        rewards=rewards,
        dones=dones
    )
    
    update_target_networks()

### Actor Update

<img src="./img/ddpg/actor_update.png" alt="Drawing" style="height: 50px;"/>

In [53]:
def update_actor(states):
    actions_pred = actor.act(states.numpy())
    loss = -critic(states, actions_pred).mean()
    
    actor_optimiser.zero_grad()
    loss.backward()
    actor_optimiser.step()

### Critic Update

Critic Loss:
<img src="./img/ddpg/critic_loss.png" alt="Drawing" style="height: 30px;"/>

Critic $y_i$:
<img src="./img/ddpg/critic_yi.png" alt="Drawing" style="height: 35px;"/>

In [54]:
def update_critic(states, next_states, actions, rewards, dones):
    next_actions = actor_target.act(next_states.numpy())
    
    y_i =  rewards + ( gamma * critic_target(next_states, next_actions).squeeze() * (1-dones ))
    expected_Q = critic(states, actions).squeeze()

    loss = F.mse_loss(y_i, expected_Q)
    
    critic_optimiser.zero_grad()
    loss.backward()
    critic_optimiser.step()

### Copy Weights Over

In [55]:
def update_target_networks():
    for target, local in zip(actor_target.parameters(), actor.parameters()):
        target.data.copy_(tau*local.data + (1.0-tau)*target.data)
        
    for target, local in zip(critic_target.parameters(), critic.parameters()):
        target.data.copy_(tau*local.data + (1.0-tau)*target.data)

## Runner

In [56]:
max_e = 1000
max_t = 500
buffer_size = 100000
batch_size = 32
learn_every = 1

gamma = 0.99
tau = 1e-2

In [57]:
mem = ReplayBuffer(buffer_size)

score_log = []
score_window = deque(maxlen=100)

In [58]:
for episode in range(max_e):
    state = env.reset()
    score = 0
    for t in range(max_t):
        action = actor.act(state, add_noise=False)
        next_state, reward, done, _ = env.step(action)
        
        score += reward

        mem.add(state, next_state, action, reward, done)
        
        if len(mem) > batch_size and t % learn_every == 0:
            learn()

        if done:
            break;
        
        state = next_state
    
    score_log.append(score)
    score_window.append(score)
    
#     print("\rEpsiode: {:.1f}\tWindow Score: {:.4f}\tScore: {:.4f}".format(episode, np.mean(score_window), score), end="")    
#     if (episode % 100 == 0):
#         print("\rEpsiode: {:.1f}\tWindow Score: {:.4f}\tScore: {:.4f}".format(episode, np.mean(score_window), score))
    print("\rEpsiode: {:.1f}\tWindow Score: {:.4f}\tScore: {:.4f}".format(episode, np.mean(score_window), score))





Epsiode: 0.0	Window Score: -1593.9523	Score: -1593.9523
Epsiode: 1.0	Window Score: -1676.7055	Score: -1759.4586
Epsiode: 2.0	Window Score: -1554.4905	Score: -1310.0606
Epsiode: 3.0	Window Score: -1452.5441	Score: -1146.7048
Epsiode: 4.0	Window Score: -1365.7773	Score: -1018.7101
Epsiode: 5.0	Window Score: -1446.2562	Score: -1848.6507
Epsiode: 6.0	Window Score: -1395.8280	Score: -1093.2586
Epsiode: 7.0	Window Score: -1356.8770	Score: -1084.2202
Epsiode: 8.0	Window Score: -1379.2193	Score: -1557.9578
Epsiode: 9.0	Window Score: -1344.8528	Score: -1035.5540
Epsiode: 10.0	Window Score: -1336.4333	Score: -1252.2390
Epsiode: 11.0	Window Score: -1311.3665	Score: -1035.6315
Epsiode: 12.0	Window Score: -1337.5126	Score: -1651.2650
Epsiode: 13.0	Window Score: -1320.6312	Score: -1101.1730
Epsiode: 14.0	Window Score: -1307.4878	Score: -1123.4803
Epsiode: 15.0	Window Score: -1316.9559	Score: -1458.9774
Epsiode: 16.0	Window Score: -1300.9168	Score: -1044.2921
Epsiode: 17.0	Window Score: -1286.3931	Sc

KeyboardInterrupt: 

In [None]:
plt.plot(score_log)