Skip to content

Commit

Permalink
added mp
Browse files Browse the repository at this point in the history
  • Loading branch information
threewisemonkeys-as committed Sep 6, 2020
1 parent 4cba727 commit 3d233c4
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 130 deletions.
76 changes: 76 additions & 0 deletions examples/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Example: distributed off-policy training of DDPG on Pendulum-v0.

Spawns multiple actor processes that feed experience into a shared
replay buffer while a single learner performs gradient updates.
"""

from genrl.agents import DDPG
from genrl.trainers.distributed import DistributedOffPolicyTrainer

import gym


def main():
    """Build the environment and agent, then run distributed training."""
    env = gym.make("Pendulum-v0")
    agent = DDPG("mlp", env)

    trainer = DistributedOffPolicyTrainer(agent, env)
    # Small settings so the example finishes quickly; scale these up
    # (actors, buffer size, batch size, update count) for real training.
    trainer.train(
        n_actors=2,
        max_buffer_size=100,
        batch_size=4,
        max_updates=10,
        update_interval=1,
    )


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion genrl/agents/deep/base/offpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def sample_from_buffer(self, beta: float = None):
*[states, actions, rewards, next_states, dones, indices, weights]
)
else:
raise NotImplementedError
batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones])
return batch

def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor:
Expand Down
3 changes: 3 additions & 0 deletions genrl/agents/deep/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ def get_hyperparams(self) -> Dict[str, Any]:
}
return hyperparams

def get_weights(self) -> Dict[str, Any]:
    """Return the current parameters of the agent's actor-critic network.

    Used to ship learner weights to remote actors in distributed training.

    Returns:
        Dict[str, Any]: ``state_dict`` of the actor-critic module
            (``self.ac``), mapping parameter names to tensors.
    """
    return self.ac.state_dict()

def get_logging_params(self) -> Dict[str, Any]:
"""Gets relevant parameters for logging
Expand Down
Loading

0 comments on commit 3d233c4

Please sign in to comment.