In [2]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import random
from collections import namedtuple, deque

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device =  "cpu"
env = gym.make("LunarLanderContinuous-v2")
action_size = env.action_space.shape[0]
state_size = env.observation_space.shape[0]

ROLLOUT_LENGTH = 300
DISCOUNT = 0.99
GAE_LAMBDA = 0.95
OPTIMIZATION_EPOCHS = 10
MINI_BATCH_SIZE = 64
PPO_RATIO_CLIP = 0.1
GRADIENT_CLIP = 0.75
HIDDEN_LAYERS = 32

print("State size {} with action size {}".format(state_size, action_size))

State size 8 with action size 2


In [4]:
# Thank you for these helper routines to Shangtong Zhang 
# https://github.com/ShangtongZhang/DeepRL

def layer_init(layer, w_scale=1.0):
    nn.init.orthogonal_(layer.weight.data)
    layer.weight.data.mul_(w_scale)
    nn.init.constant_(layer.bias.data, 0)
    return layer

def to_np(t):
    return t.cpu().detach().numpy()

def tensor(x):
    if isinstance(x, torch.Tensor):
        return x
    x = torch.tensor(x, device=device, dtype=torch.float32)
    return x

def random_sample(indices, batch_size):
    indices = np.asarray(np.random.permutation(indices))
    batches = indices[:len(indices) // batch_size * batch_size].reshape(-1, batch_size)
    for batch in batches:
        yield batch
    r = len(indices) % batch_size
    if r:
        yield indices[-r:]


In [5]:
class SubNetwork(nn.Module):
    
    def __init__(self, input_size, hidden_units, output_size, seed):
        super(SubNetwork, self).__init__()
        dims = (input_size,) + hidden_units        
        self.layers = nn.ModuleList([layer_init(nn.Linear(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
        self.feature_dim = dims[-1]
        self.output_layer = layer_init(nn.Linear(self.feature_dim, output_size), 1e-3)
        
    def forward(self, x):
        for layer in self.layers:
            x = F.tanh(layer(x))
        x = self.output_layer(x)    
        return x    
            
class ActorAndCritic(nn.Module):
    
    def __init__(self, state_size, action_size, seed):
        super(ActorAndCritic, self).__init__()
        self.seed = random.seed(seed)
        self.actor = SubNetwork(state_size, (HIDDEN_LAYERS, HIDDEN_LAYERS), action_size, seed)
        self.critic = SubNetwork(state_size, (HIDDEN_LAYERS, HIDDEN_LAYERS), 1, seed)
        self.std = nn.Parameter(torch.zeros(action_size))
        #self.to(Config.DEVICE)
        
    def forward(self, obs, action=None):
        obs = tensor(obs)
        a = self.actor(obs)
        v = self.critic(obs)
        mean = F.tanh(a)
        dist = torch.distributions.Normal(mean, F.softplus(self.std))
        return (v, dist)

In [6]:
network = ActorAndCritic(state_size, action_size, 0).to(device)
network

ActorAndCritic(
  (actor): SubNetwork(
    (layers): ModuleList(
      (0): Linear(in_features=8, out_features=32, bias=True)
      (1): Linear(in_features=32, out_features=32, bias=True)
    )
    (output_layer): Linear(in_features=32, out_features=2, bias=True)
  )
  (critic): SubNetwork(
    (layers): ModuleList(
      (0): Linear(in_features=8, out_features=32, bias=True)
      (1): Linear(in_features=32, out_features=32, bias=True)
    )
    (output_layer): Linear(in_features=32, out_features=1, bias=True)
  )
)

In [7]:
def get_actions_from_states(network, states):
    if not isinstance(states, torch.Tensor):
        states = torch.tensor(states, dtype=torch.float32, device=device)
    value, distribution = network.forward(states)
    actions = distribution.sample()
    return to_np(actions), distribution, value

def get_log_probability_from_actions(distribution, action):
    if not isinstance(action, torch.Tensor):
        action = torch.Tensor(action)
    log_prob = distribution.log_prob(action)
    return log_prob.cpu().detach().numpy()

In [12]:
#Random walk
state = env.reset()
while True:
    action, distribution, value = get_actions_from_states(network, state)
    log_prob = get_log_probability_from_actions(distribution, action)
    env.render()
    next_state, reward, done, _ = env.step(action)
    state = next_state
    time.sleep(0.01)
    if done: break
env.close()

In [17]:
class Rollout:
    def __init__(self):
        self.actions = []
        self.states = []
        self.rewards = []
        self.dones = []
        self.log_probs = []
        self.advantages = []
        self.returns = []
        self.values = []

    def save_rollout_data(self, state, action, value, reward, done, log_prob):
        self.states.append(state)
        self.dones.append(done)
        self.values.append(value)
        self.actions.append(action)
        self.log_probs.append(log_prob)
        self.rewards.append(reward)

    def calculate_advantage_and_returns(self):
        pass


    def stack(self, data):
        return torch.cat(data[:ROLLOUT_LENGTH], dim=0)

    def sample(self):
        self.calculate_advantage_and_returns()
        states = self.stack(self.states)
        actions = self.stack(self.actions)
        returns = self.stack(self.returns)
        log_probs = self.stack(self.log_probs)
        advantages = self.stack(self.advantages)
        values = self.stack(self.values)
        return states, advantages, actions, returns, log_probs


    

In [25]:
class MasterAgent():

    def __init__(self, state_size, action_size, seed):
        self.network = ActorAndCritic(state_size, action_size, seed)
        self.avg_score = 0
        self.score_list = deque(maxlen=100)
        
    def rollout_init(self):
        self.rollout = Rollout()
    
    def optimize(self):
        states, advantages, actions, returns, log_probs = self.rollout.sample()
        for _ in range(OPTIMIZATION_EPOCHS):
            batch_seletion_iterator = random_sample(np.arange(states.size(0)), MINI_BATCH_SIZE)
            for batch in batch_seletion_iterator:
                state_t = states[batch]
                actions_t = actions[batch]
                log_pros_t = log_probs[batch]
                

    def step(self, state, action, value, reward, done, log_prob):
        state = torch.Tensor(state, device = device).float().reshape(1,len(state))
        value = torch.from_numpy(np.array(value)).float().to(device).reshape(1,1)
        done = torch.Tensor(np.asarray([int(done)]), device = device).reshape(1,1)
        log_prob = torch.Tensor(log_prob,device = device).float().reshape(1,len(log_prob))
        reward = torch.from_numpy(np.array(reward)).float().to(device).reshape(1,1)
        action = torch.from_numpy(np.array(action)).float().to(device).reshape(1, len(action))
        self.rollout.save_rollout_data(state, action, next_state, reward, done, log_prob)

    def train(self, iteration = 300):
        self.state = env.reset()
        for i in range(iteration):
            score = self.run(env)
            self.score_list.append(score)
            self.optimize()
            

    def run(self, env):
        self.rollout_init()
        score = 0
        for _ in range(ROLLOUT_LENGTH):
            action, distribution, value = get_actions_from_states(network, self.state)
            log_prob = get_log_probability_from_actions(distribution, action)
            next_state, reward, done, _ = env.step(action)
            score += reward
            self.step(self.state, action, value, reward, done, log_prob)
            self.state = next_state
            if done:
                state = env.reset()
        print("Terminating Rollout")
        return score


In [26]:
session = MasterAgent(state_size, action_size, 2)

In [27]:
session.train(1)

RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

In [257]:
states, next_states, actions, dones, rewards = session.rollout.sample()

In [279]:
data = list(random_sample(np.arange(300),64))[0]

In [281]:
states[tensor(data).long()].shape

torch.Size([64, 8])

tensor([[-0.1348,  1.6143, -0.6328,  0.2763,  0.2140,  0.1709,  0.0000,  0.0000],
        [-0.1606,  1.6388, -0.6496,  0.2393,  0.2455,  0.1232,  0.0000,  0.0000]])

In [282]:
tensor(data)

tensor([277.,  84., 129.,   1.,  46., 247., 187., 281., 248., 239.,  37., 230.,
        236.,  88., 262., 146.,  58.,  92., 251., 265., 284., 225., 206., 140.,
        290., 195., 266., 100., 214., 184., 159., 163., 116.,   6., 124.,  13.,
        193., 106.,  45.,  98.,  23., 145.,  56.,  26.,  44., 254., 175.,  94.,
        157., 295., 179., 180., 256., 237.,  33., 231., 299., 161., 286.,  83.,
        189., 139., 107., 165.])