Commit

add files
threewisemonkeys-as committed Sep 20, 2020
1 parent 3d233c4 commit 1c504cc
Showing 6 changed files with 281 additions and 1 deletion.
275 changes: 275 additions & 0 deletions examples/distributed2.py
@@ -0,0 +1,275 @@
import argparse
import copy
import os

import gym
import numpy as np
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from genrl.agents import DDPG
from genrl.core.buffers import ReplayBuffer

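# MASTER_ADDR / MASTER_PORT tell torch.distributed.rpc where the rank 0
# process listens so that all participating processes can rendezvous.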
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

# to call a function on an rref, we could do the following
# _remote_method(some_func, rref, *args)


def _call_method(method, rref, *args, **kwargs):
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)
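# For example, the learner below samples a batch from the remote buffer with:
#   batch = _remote_method(ReplayBuffer.sample, buffer_rref, batch_size)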


global_lock = mp.Lock()


class ParamServer:
    """Holds the most recent agent weights so actors can sync with the learner."""

    def __init__(self, init_params):
        self.params = init_params

    def store_params(self, new_params):
        with global_lock:
            self.params = new_params

    def get_params(self):
        with global_lock:
            return self.params


class DistributedReplayBuffer:
    """Thin wrapper around ReplayBuffer; not yet used by the trainer below."""

    def __init__(self, size):
        self.size = size
        self.len = 0
        self._buffer = ReplayBuffer(self.size)


class DistributedOffPolicyTrainer:
    """Distributed off-policy trainer.

    Trains an off-policy agent with a parameter server, a shared replay
    buffer, multiple actors and a single learner, each running in its own
    process and communicating over torch.distributed.rpc.
    """

    def __init__(
        self,
        agent,
        env,
        **kwargs,
    ):
        self.env = env
        self.agent = agent

    def train(
        self, n_actors, max_buffer_size, batch_size, max_updates, update_interval
    ):
        # Rank layout: 0 -> parameter server, 1 -> buffer,
        # 2 .. n_actors + 1 -> actors, world_size - 1 -> learner
        world_size = n_actors + 3
        completed = mp.Value("i", 0)

        # Start the parameter server process and collect an RRef to it
        param_server_rref_q = mp.Queue(1)
        param_server_p = mp.Process(
            target=run_param_server, args=(param_server_rref_q, world_size)
        )
        param_server_p.start()
        param_server_rref = param_server_rref_q.get()
        param_server_rref_q.close()

        # Start the replay buffer process and collect an RRef to it
        buffer_rref_q = mp.Queue(1)
        buffer_p = mp.Process(
            target=run_buffer, args=(max_buffer_size, buffer_rref_q, world_size)
        )
        buffer_p.start()
        buffer_rref = buffer_rref_q.get()
        buffer_rref_q.close()

        # Start the actor processes
        actor_ps = []
        for i in range(n_actors):
            a_p = mp.Process(
                target=run_actor,
                args=(
                    i,
                    copy.deepcopy(self.agent),
                    copy.deepcopy(self.env),
                    param_server_rref,
                    buffer_rref,
                    world_size,
                    completed,
                ),
            )
            a_p.start()
            actor_ps.append(a_p)

        # Start the learner process
        learner_p = mp.Process(
            target=run_learner,
            args=(
                max_updates,
                batch_size,
                self.agent,
                param_server_rref,
                buffer_rref,
                world_size,
                completed,
            ),
        )
        learner_p.start()

        learner_p.join()
        for a in actor_ps:
            a.join()
        buffer_p.join()
        param_server_p.join()


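# Each helper below runs in its own process and joins the RPC group with a
# fixed rank: parameter server (0), buffer (1), actors (2 .. n_actors + 1),
# learner (world_size - 1).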
def run_param_server(q, world_size):
    print("Running parameter server")
    rpc.init_rpc(name="param_server", rank=0, world_size=world_size)
    param_server = ParamServer(None)
    param_server_rref = rpc.RRef(param_server)
    q.put(param_server_rref)
    rpc.shutdown()
    print("param server shutting down")


def run_buffer(max_buffer_size, q, world_size):
    print("Running buffer server")
    rpc.init_rpc(name="buffer", rank=1, world_size=world_size)
    buffer = ReplayBuffer(max_buffer_size)
    buffer_rref = rpc.RRef(buffer)
    q.put(buffer_rref)
    rpc.shutdown()
    print("buffer shutting down")


def run_learner(max_updates, batch_size, agent, param_server_rref, buffer_rref, world_size, completed):
    print("Running learner")
    rpc.init_rpc(name="learner", rank=world_size - 1, world_size=world_size)
    i = 0
    while i < max_updates:
        batch = _remote_method(ReplayBuffer.sample, buffer_rref, batch_size)
        if batch is None:
            # The buffer does not hold enough transitions yet
            continue
        agent.update_params(batch)
        _remote_method(
            ParamServer.store_params, param_server_rref, agent.get_weights()
        )
        i += 1
        print(f"Completed update {i} of {max_updates}")
    completed.value = 1
    rpc.shutdown()
    print("learner shutting down")


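# Each actor repeatedly syncs its weights from the parameter server, steps its
# own copy of the environment and pushes transitions to the shared buffer.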
def run_actor(i, agent, env, param_server_rref, buffer_rref, world_size, completed):
    print(f"Running actor {i}")

    rpc.init_rpc(name=f"actor_{i}", rank=i + 2, world_size=world_size)

    state = env.reset().astype(np.float32)

    while completed.value != 1:
        # Sync local weights with the latest copy on the parameter server
        # (the server holds None until the learner publishes its first update)
        params = _remote_method(ParamServer.get_params, param_server_rref)
        if params is not None:
            agent.load_weights(params)

        action = agent.select_action(state).numpy()
        next_state, reward, done, _ = env.step(action)
        next_state = next_state.astype(np.float32)
        reward = np.array([reward]).astype(np.float32)
        done = np.array([done]).astype(np.bool)

        # Push the transition to the remote replay buffer
        _remote_method(
            ReplayBuffer.push, buffer_rref, [state, action, reward, next_state, done]
        )
        state = env.reset().astype(np.float32) if done else next_state.copy()

    rpc.shutdown()
    print("actor shutting down")

if __name__ == "__main__":
    # Quick local smoke test: one actor plus the parameter server, buffer and
    # learner processes, all on this machine.
    env = gym.make("Pendulum-v0")
    agent = DDPG("mlp", env)

    trainer = DistributedOffPolicyTrainer(agent, env)
    trainer.train(
        n_actors=1,
        max_buffer_size=100,
        batch_size=1,
        max_updates=100,
        update_interval=1,
    )


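# The commented-out block below sketches an alternative entry point where each
# role is launched separately via command-line flags; the data-loading /
# run_worker portion appears to be carried over from the PyTorch RPC
# parameter-server (MNIST) example and is not wired up to this trainer.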
# if __name__ == '__main__':
#     parser = argparse.ArgumentParser(
#         description="Parameter-Server RPC based training")
#     parser.add_argument(
#         "--world_size",
#         type=int,
#         default=4,
#         help="""Total number of participating processes. Should be the number
#         of actors + 3.""")
#     parser.add_argument(
#         "--run",
#         type=str,
#         default="param_server",
#         choices=["param_server", "buffer", "learner", "actor"],
#         help="Which program to run")
#     parser.add_argument(
#         "--master_addr",
#         type=str,
#         default="localhost",
#         help="""Address of master, will default to localhost if not provided.
#         Master must be able to accept network traffic on the address + port.""")
#     parser.add_argument(
#         "--master_port",
#         type=str,
#         default="29500",
#         help="""Port that master is listening on, will default to 29500 if not
#         provided. Master must be able to accept network traffic on the host and port.""")

#     args = parser.parse_args()

#     os.environ['MASTER_ADDR'] = args.master_addr
#     os.environ["MASTER_PORT"] = args.master_port

#     processes = []
#     world_size = args.world_size
#     if args.run == "param_server":
#         p = mp.Process(target=run_param_server, args=(world_size))
#         p.start()
#         processes.append(p)
#     elif args.run == "buffer":
#         p = mp.Process(target=run_buffer, args=(world_size))
#         p.start()
#         processes.append(p)
#     # Get data to train on
#     train_loader = torch.utils.data.DataLoader(
#         datasets.MNIST('../data', train=True, download=True,
#                        transform=transforms.Compose([
#                            transforms.ToTensor(),
#                            transforms.Normalize((0.1307,), (0.3081,))
#                        ])),
#         batch_size=32, shuffle=True,)
#     test_loader = torch.utils.data.DataLoader(
#         datasets.MNIST(
#             '../data',
#             train=False,
#             transform=transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize((0.1307,), (0.3081,))
#             ])),
#         batch_size=32,
#         shuffle=True,
#     )
#     # start training worker on this node
#     p = mp.Process(
#         target=run_worker,
#         args=(
#             args.rank,
#             world_size, args.num_gpus,
#             train_loader,
#             test_loader))
#     p.start()
#     processes.append(p)

#     for p in processes:
#         p.join()
2 changes: 1 addition & 1 deletion genrl/agents/deep/base/offpolicy.py
@@ -277,4 +277,4 @@ def load_weights(self, weights) -> None:
         Args:
             weights (:obj:`dict`): Dictionary of different neural net weights
         """
-        self.ac.load_state_dict(weights["weights"])
+        self.ac.load_state_dict(weights)
3 changes: 3 additions & 0 deletions genrl/core/buffers.py
@@ -57,6 +57,9 @@ def sample(
         :returns: (Tuple composing of `state`, `action`, `reward`,
             `next_state` and `done`)
         """
+        if batch_size > len(self.memory):
+            return None
+
         batch = random.sample(self.memory, batch_size)
         state, action, reward, next_state, done = map(np.stack, zip(*batch))
         return [
Empty file added genrl/distributed/actor.py
Empty file.
Empty file.
2 changes: 2 additions & 0 deletions genrl/distributed/transition_server.py
@@ -0,0 +1,2 @@
from genrl.core.buffers import ReplayBuffer
