In [1]:
import os
import gym
import ptan
import numpy as np
import collections

import torch
import torch.nn.utils as nn_utils
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp

In [2]:
# --- Optimisation / A2C hyper-parameters ---
GAMMA = 0.99             # discount factor, passed as `gamma` to the experience source
LEARNING_RATE = 0.001    # Adam learning rate
ENTROPY_BETA = 0.01      # presumably the entropy-bonus weight for the training loop (not used in this view)
BATCH_SIZE = 128         # not used in this view; presumably the trainer's batch size
REWARD_STEPS = 4         # n-step unroll length given to ExperienceSourceFirstLast
CLIP_GRAD = 0.1          # presumably the gradient-norm clip value (not used in this view)

# --- Parallel data collection ---
PROCESS_COUNT = 4        # intended number of data_func worker processes
NUM_ENVS = 8             # environments created inside each worker
MICRO_BATCH_SIZE = 32    # transitions a worker gathers before putting one item on the queue

# --- Environment / run identification ---
ENV_NAME = 'LunarLander-v0'
NAME = 'Lunar_Lander'
# NOTE(review): 18 matches the usual Pong target; confirm this is the intended bound for LunarLander.
REWARD_BOUND = 18

In [3]:
def make_env():
    """Create one environment instance for ENV_NAME.

    The network in this notebook is sized from ``observation_space.shape[0]``,
    i.e. it expects a flat vector observation, so the env is returned as-is.
    ptan's ``wrap_dqn`` was previously applied here. It is an Atari-specific
    wrapper chain (no-op reset using ALE action meanings, 84x84 frame
    processing, frame stacking) and is not applicable to LunarLander.
    """
    return gym.make(ENV_NAME)

# Queue message sent by data_func: mean reward of the episodes that just finished.
TotalReward = collections.namedtuple('TotalReward', ['reward'])

In [4]:
def data_func(net, device, train_queue):
    """Worker loop: play NUM_ENVS environments with ``net`` and feed results to ``train_queue``.

    The experience source is iterated with no stop condition, so this function
    runs until its process is terminated; it is used as an ``mp.Process`` target.

    Two kinds of items are put on ``train_queue``:
      * ``TotalReward(reward=...)``: the mean of the episode rewards reported by
        ``pop_total_rewards()``, whenever at least one episode has finished;
      * the result of ``common.unpack_batch`` for every MICRO_BATCH_SIZE transitions.

    Args:
        net: network whose first output, ``net(x)[0]``, is used as policy logits.
        device: device passed to the agent and to ``unpack_batch``.
        train_queue: multiprocessing queue shared with the training loop.
    """
    envs = [make_env() for _ in range(NUM_ENVS)]
    # Only the first network output drives action selection; apply_softmax
    # converts it into action probabilities.
    agent = ptan.agent.PolicyAgent(lambda x: net(x)[0],
                                   device=device,
                                   apply_softmax=True)
    exp_source = ptan.experience.ExperienceSourceFirstLast(
        envs, agent, gamma=GAMMA, steps_count=REWARD_STEPS)
    micro_batch = []

    for exp in exp_source:
        new_rewards = exp_source.pop_total_rewards()
        if new_rewards:
            train_queue.put(TotalReward(reward=np.mean(new_rewards)))

        micro_batch.append(exp)
        # The config constant is MICRO_BATCH_SIZE (MICRO_BATCH was undefined -> NameError).
        if len(micro_batch) < MICRO_BATCH_SIZE:
            continue

        # NOTE(review): `common` is not imported in this notebook view; presumably a
        # project helper module (e.g. `from lib import common`) — confirm and import it.
        # The experience source already unrolled REWARD_STEPS rewards, so
        # GAMMA**REWARD_STEPS is presumably the discount unpack_batch applies to
        # the bootstrapped last-state value (helper not visible — confirm).
        data = common.unpack_batch(micro_batch, net, device=device,
                                   last_val_gamma=GAMMA ** REWARD_STEPS)
        train_queue.put(data)
        micro_batch.clear()

In [5]:
# Start child processes with 'spawn' (PyTorch requires spawn/forkserver to use
# CUDA in subprocesses). force=True keeps this cell re-runnable; a second plain
# call raises RuntimeError("context has already been set").
mp.set_start_method('spawn', force=True)
# Spawned children inherit this environment variable; presumably set so each
# worker uses a single OpenMP thread instead of oversubscribing the CPU.
os.environ['OMP_NUM_THREADS'] = '1'
# is_available must be *called*: the bare function object is always truthy,
# which previously selected 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Probe one env instance for the observation size and the number of discrete actions.
env = make_env()
# NOTE(review): A2CNet is not defined or imported in this notebook view;
# presumably it comes from a project module — import it explicitly before this cell.
net = A2CNet(env.observation_space.shape[0],
             env.action_space.n).to(device)

# Move parameters to shared memory: the same `net` object is handed to the
# worker processes, which then act with the weights this process optimizes.
net.share_memory()
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE,
                       eps=1e-3)

# The config constant is PROCESS_COUNT (PROCESSES_COUNT was undefined -> NameError).
# A bounded queue makes workers' put() block while the trainer is behind.
train_queue = mp.Queue(maxsize=PROCESS_COUNT)
data_proc_list = []
data_proc_list = []

In [None]:
# Launch PROCESS_COUNT data-collection workers, all feeding the shared train_queue.
# The config constant is PROCESS_COUNT (PROCESSES_COUNT was undefined -> NameError).
# NOTE(review): under the 'spawn' start method the child re-imports the target by
# name; a function defined in a notebook's __main__ usually cannot be resolved there
# (AttributeError when unpickling). Consider moving data_func into an importable .py module.
for _ in range(PROCESS_COUNT):
    data_proc = mp.Process(target=data_func,
                           args=(net, device, train_queue))
    data_proc.start()
    data_proc_list.append(data_proc)