In [1]:
import gym
import numpy as np
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
from collections import deque
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
from copy import deepcopy
from typing import List, Tuple, Dict
from gym import Space
import torch.multiprocessing as mp
from gym.vector import SyncVectorEnv, AsyncVectorEnv
from copy import deepcopy


In [2]:
from deeprl.common.utils import to_torch

In [3]:
from deeprl.common.utils import net_gym_space_dims, discount_cumsum, to_torch, compute_td_deltas, compute_gae_and_v_targets, normalise_adv
from deeprl.algos.a2c.a2c import A2C
from deeprl.common.base import Network, Policy, CategoricalPolicy, GaussianPolicy

In [4]:
env_name = "CartPole-v1"
# Throwaway vector env: only used here to read the per-env observation/action space sizes.
envs = gym.vector.make(env_name)
obs_dim = net_gym_space_dims(envs.single_observation_space)
act_dim = net_gym_space_dims(envs.single_action_space)
hidden_units = 20


def mlp_layer_spec(in_dim, out_dim, hidden=hidden_units):
    """Return a list of (layer_class, kwargs) pairs for a two-hidden-layer ReLU MLP."""
    return [
        (nn.Linear, {"in_features": in_dim, "out_features": hidden}),
        (nn.ReLU, {}),
        (nn.Linear, {"in_features": hidden, "out_features": hidden}),
        (nn.ReLU, {}),
        (nn.Linear, {"in_features": hidden, "out_features": out_dim}),
    ]


# Policy head width follows net_gym_space_dims of the action space; the critic has one output.
policy_layers = mlp_layer_spec(obs_dim, act_dim)
critic_layers = mlp_layer_spec(obs_dim, 1)

a2c_args = {
    "gamma": 0.99,
    "env_name": env_name,
    "step_lim": 200,
    "policy": CategoricalPolicy(policy_layers),
    "policy_optimiser": optim.Adam,
    "policy_lr": 0.002,
    "critic": Network(critic_layers),
    "critic_lr": 0.002,
    "critic_optimiser": optim.Adam,
    "critic_criterion": nn.MSELoss(),
    "device": "cpu",
    "entropy_coef": 0.01,
    "batch_size": 250,
    "num_train_passes": 1,
    "lam": 0.95,
    "num_eval_episodes": 15,
    "num_workers": 4,
    "minibatch_size": 10,
    "norm_adv": True
}


In [5]:
# Close the probe vector env: it was only needed to read the observation/action
# space sizes for the layer specs; training builds its own envs via init_envs.
envs.close()

In [11]:
agent = A2C(a2c_args)
# Read run settings from the config dict instead of re-typing them, so the rollout
# buffers always match what the agent was configured with.
batch_size = a2c_args["batch_size"]
num_workers = a2c_args["num_workers"]
device = agent.device
env_name = a2c_args["env_name"]


# TODO: decide whether env construction belongs in the agent's __init__ or in a train function.
def init_envs(env_name, num_envs, asynchronous):
    """Create a gym vector environment with ``num_envs`` copies of ``env_name`` and reset it.

    The reset matters: ``collect_rollout`` starts from the vector env's
    ``observations`` buffer rather than from a value returned here.

    Args:
        env_name: registered gym environment id, e.g. "CartPole-v1".
        num_envs: number of parallel sub-environments.
        asynchronous: forwarded to ``gym.vector.make`` (async vs. sync vector env).

    Returns:
        The freshly reset vector environment.
    """
    vec_env = gym.vector.make(
        env_name,
        num_envs=num_envs,
        asynchronous=asynchronous,
    )
    vec_env.reset()
    return vec_env

In [12]:
def collect_rollout(agent, batch_size, num_workers, device, envs):
    """Step ``envs`` for ``batch_size`` steps using ``agent`` and record every transition.

    Args:
        agent: object exposing ``choose_action(states)``; called once per step with
            the current observations of all workers.
        batch_size: number of environment steps to take (per worker).
        num_workers: number of parallel sub-environments in ``envs``.
        device: torch device the returned tensors are allocated on.
        envs: an already-reset gym vector environment (see ``init_envs``). The rollout
            continues from its current ``observations`` buffer, so consecutive calls
            pick up where the previous rollout stopped.

    Returns:
        Tuple ``(states, actions, rewards, dones, next_states)`` of float tensors, each
        with leading shape ``(batch_size, num_workers)`` followed by the per-env
        observation/action shape where applicable.
    """
    lead = (batch_size, num_workers)
    obs_shape = envs.single_observation_space.shape
    act_shape = envs.single_action_space.shape

    # Allocate directly on `device` instead of on the CPU followed by `.to(device)`.
    states_batch = torch.zeros(lead + obs_shape, dtype=torch.float, device=device)
    actions_batch = torch.zeros(lead + act_shape, dtype=torch.float, device=device)
    rewards_batch = torch.zeros(lead, dtype=torch.float, device=device)
    dones_batch = torch.zeros(lead, dtype=torch.float, device=device)
    next_states_batch = torch.zeros(lead + obs_shape, dtype=torch.float, device=device)

    # Snapshot the env's internal observation buffer so later steps cannot mutate it.
    states = deepcopy(envs.observations)

    for step in range(batch_size):
        actions = agent.choose_action(states)
        # NOTE(review): gym vector envs auto-reset finished sub-envs, so for a done
        # worker `next_states` is presumably the new episode's first observation;
        # downstream GAE must mask it with `dones` — confirm.
        next_states, rewards, dones, _infos = envs.step(actions)

        states_batch[step] = to_torch(states, device)
        actions_batch[step] = to_torch(actions, device)
        rewards_batch[step] = to_torch(rewards, device)
        dones_batch[step] = to_torch(dones, device)
        next_states_batch[step] = to_torch(next_states, device)

        states = next_states

    return states_batch, actions_batch, rewards_batch, dones_batch, next_states_batch


def process_rollout(agent, rollout):
    """Flatten a per-worker rollout into one training batch with advantages and value targets.

    Args:
        agent: object exposing ``critic``, ``gamma``, ``lam`` and ``device``; these are
            forwarded to ``compute_gae_and_v_targets``.
        rollout: tuple ``(states, actions, rewards, dones, next_states)`` as returned by
            ``collect_rollout``; every tensor is indexed as ``[step, worker, ...]``.

    Returns:
        Dict with keys ``states``, ``actions``, ``rewards``, ``dones``, ``next_states``,
        ``advantages`` and ``v_targets``. Each value concatenates the workers along
        dim 0 in worker-major order (all of worker 0's steps, then worker 1's, ...).

    Raises:
        ValueError: if the rollout contains no workers.
    """
    states_batch, actions_batch, rewards_batch, dones_batch, next_states_batch = rollout

    # Derive these from the inputs rather than notebook globals, so the function is
    # correct for whatever rollout/agent it is handed.
    num_workers = states_batch.shape[1]
    device = agent.device
    if num_workers == 0:
        raise ValueError("rollout contains no workers")

    worker_batches = []
    for j in range(num_workers):
        # Advantages are computed on each worker's own time series before the
        # trajectories are concatenated.
        batch = {
            "states": states_batch[:, j],
            "actions": actions_batch[:, j],
            "rewards": rewards_batch[:, j],
            "dones": dones_batch[:, j],
            "next_states": next_states_batch[:, j],
        }
        batch["advantages"], batch["v_targets"] = compute_gae_and_v_targets(
            agent.critic, batch, device, agent.gamma, agent.lam
        )
        worker_batches.append(batch)

    return {k: torch.cat([b[k] for b in worker_batches]) for k in worker_batches[0]}

Sanity checks for the output of `process_rollout` (`concat_batch` is its return value):

```python
assert len(concat_batch["states"]) == batch_size * num_workers
assert len(concat_batch["next_states"]) == batch_size * num_workers
assert concat_batch["states"].dtype == torch.float32
assert concat_batch["rewards"].shape == (batch_size * num_workers,)
assert concat_batch["dones"].shape == (batch_size * num_workers,)
assert concat_batch["actions"].shape == (batch_size * num_workers,)
assert not concat_batch["advantages"].requires_grad
```

### Shuffle a batch (including its `advantages` and `v_targets`) and split it into minibatches

In [13]:
def minibatch_split(batch, mb_size, shuffle=True):
    """Split ``batch`` into a list of equally-sized minibatch dictionaries.

    Every value in ``batch`` is indexed with the same row indices, so rows stay
    aligned across keys (states, actions, advantages, ...).

    Args:
        batch: dict of tensors/arrays that support integer-array indexing and all have
            the same length as ``batch["states"]``.
        mb_size: rows per minibatch; must be positive and divide the batch length.
        shuffle: if True, rows are permuted with ``np.random.shuffle`` (reproducible via
            ``np.random.seed``) before splitting.

    Returns:
        List of ``len(batch["states"]) // mb_size`` dicts with the same keys as ``batch``.

    Raises:
        ValueError: if ``mb_size`` is not positive or does not divide the batch length.
    """
    batch_size = len(batch["states"])

    # Reject non-positive sizes explicitly: 0 would otherwise hit ZeroDivisionError in the
    # modulo below, and a negative size would silently yield no minibatches.
    if mb_size <= 0:
        raise ValueError(f"Minibatch size must be positive, got {mb_size}.")
    if batch_size % mb_size != 0 or mb_size > batch_size:
        raise ValueError("Minibatch size does not divide batch size.")

    batch_idc = np.arange(batch_size)
    if shuffle:
        np.random.shuffle(batch_idc)

    return [
        {k: v[batch_idc[start:start + mb_size]] for k, v in batch.items()}
        for start in range(0, batch_size, mb_size)
    ]

In [14]:
def update_from_batch(agent, num_passes, batch):
    """Run ``num_passes`` shuffled sweeps of minibatch policy and critic updates on ``batch``.

    Args:
        agent: object exposing ``minibatch_size``, ``norm_adv``, ``update_policy`` and
            ``update_critic``. The update methods are assumed to return a sequence whose
            first element is the loss value — TODO confirm against the A2C class.
        num_passes: number of sweeps over ``batch``; must be at least 1.
        batch: dict of equally-sized tensors, e.g. the output of ``process_rollout``.

    Returns:
        ``(mean_policy_loss, mean_critic_loss)`` averaged over every minibatch update.

    Raises:
        ValueError: if ``num_passes`` < 1. Without this guard there is nothing to average,
            and the original code failed with an UnboundLocalError.
    """
    if num_passes < 1:
        raise ValueError(f"num_passes must be at least 1, got {num_passes}.")

    total_policy_loss = 0.
    total_critic_loss = 0.
    num_updates = 0
    for _ in range(num_passes):
        # minibatch_split shuffles by default, so each pass sees a different split.
        for minibatch in minibatch_split(batch, agent.minibatch_size):
            if agent.norm_adv:
                # Advantages are normalised per minibatch, not over the whole batch.
                minibatch["advantages"] = normalise_adv(minibatch["advantages"])

            policy_loss = agent.update_policy(minibatch)
            critic_loss = agent.update_critic(minibatch)

            total_policy_loss += policy_loss[0]
            total_critic_loss += critic_loss[0]
            num_updates += 1

    return total_policy_loss / num_updates, total_critic_loss / num_updates





In [15]:
num_passes = 4
num_epochs = 5
envs = init_envs(env_name, num_workers, False)
try:
    for epoch in range(num_epochs):
        rollout = collect_rollout(agent, batch_size, num_workers, device, envs)
        batch = process_rollout(agent, rollout)
        p_loss, c_loss = update_from_batch(agent, num_passes, batch)
        print(agent.run_eval())
finally:
    # Release the worker envs even if training fails part-way through.
    envs.close()

# NOTE(review): CartPole-v1 caps episodes at 500 steps, so 500.0 is the maximum possible
# return rather than an oversized reward. a2c_args sets step_lim=200, yet eval episodes
# reach 500 — presumably run_eval does not apply step_lim; confirm. The swings between
# ~500 and ~9 across epochs point to unstable training rather than reward scaling.

[300.0, 333.0, 228.0, 255.0, 259.0, 280.0, 184.0, 216.0, 260.0, 142.0]
[500.0, 500.0, 500.0, 500.0, 500.0, 500.0, 500.0, 500.0, 500.0, 500.0]
[9.0, 10.0, 10.0, 10.0, 9.0, 10.0, 10.0, 9.0, 9.0, 9.0]
[69.0, 57.0, 87.0, 79.0, 44.0, 77.0, 70.0, 72.0, 46.0, 83.0]
[9.0, 10.0, 10.0, 9.0, 9.0, 8.0, 9.0, 10.0, 10.0, 8.0]
