diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py
index f22209c7..513376d2 100644
--- a/genrl/distributed/actor.py
+++ b/genrl/distributed/actor.py
@@ -48,7 +48,7 @@ def act(
     experience_server = get_proxy(experience_server_name)
     learner = get_proxy(learner_name)
     print(f"{name}: Begining experience collection")
-    while not learner.is_done():
+    while not learner.is_completed():
         collect_experience(agent, parameter_server, experience_server)

     rpc.shutdown()
diff --git a/genrl/distributed/learner.py b/genrl/distributed/learner.py
index 90f9f178..c9181052 100644
--- a/genrl/distributed/learner.py
+++ b/genrl/distributed/learner.py
@@ -42,5 +42,6 @@ def learn(
     parameter_server = get_proxy(parameter_server_name)
     experience_server = get_proxy(experience_server_name)
     print(f"{name}: Beginning training")
-    trainer.train_wrapper(parameter_server, experience_server)
+    trainer.train(parameter_server, experience_server)
+    trainer.set_completed(True)
     rpc.shutdown()
diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py
index c254a0b0..a93b7498 100644
--- a/genrl/trainers/distributed.py
+++ b/genrl/trainers/distributed.py
@@ -12,14 +12,12 @@ def __init__(self, agent):

     def train(self, parameter_server, experience_server):
         raise NotImplementedError

-    def train_wrapper(self, parameter_server, experience_server):
-        self._completed_training_flag = False
-        self.train(parameter_server, experience_server)
-        self._completed_training_flag = True
-
-    def is_done(self):
+    def is_completed(self):
         return self._completed_training_flag

+    def set_completed(self, value=True):
+        self._completed_training_flag = value
+
     def evaluate(self, timestep, render: bool = False) -> None:
         """Evaluate performance of Agent
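
For context, here is a minimal sketch (not part of the diff) of how the reworked completion flag is intended to flow, run locally without RPC. The ExampleTrainer subclass, the None stand-ins for the agent and servers, and the flag initialization in __init__ are assumptions for illustration only:

class DistributedTrainer:
    """Approximates the base class in genrl/trainers/distributed.py after this change."""

    def __init__(self, agent):
        self.agent = agent
        # Assumed: the flag is initialized here, since train() no longer resets it.
        self._completed_training_flag = False

    def train(self, parameter_server, experience_server):
        raise NotImplementedError

    def is_completed(self):
        return self._completed_training_flag

    def set_completed(self, value=True):
        self._completed_training_flag = value


class ExampleTrainer(DistributedTrainer):
    def train(self, parameter_server, experience_server):
        pass  # hypothetical: sample from experience_server, update, push weights


trainer = ExampleTrainer(agent=None)

# Learner process (cf. learner.py): run training, then mark completion explicitly.
trainer.train(parameter_server=None, experience_server=None)
trainer.set_completed(True)

# Actor process (cf. actor.py): the collection loop polls the learner proxy and
# exits once it reports completion.
while not trainer.is_completed():
    pass  # collect_experience(agent, parameter_server, experience_server)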