diff --git a/ml-agents-envs/mlagents/envs/environment.py b/ml-agents-envs/mlagents/envs/environment.py index b6acf890cd..d642d792b6 100644 --- a/ml-agents-envs/mlagents/envs/environment.py +++ b/ml-agents-envs/mlagents/envs/environment.py @@ -11,6 +11,7 @@ from .brain import AllBrainInfo, BrainInfo, BrainParameters from .exception import ( UnityEnvironmentException, + UnityCommunicationException, UnityActionException, UnityTimeOutException, ) @@ -343,7 +344,7 @@ def reset( self._generate_reset_input(train_mode, config, custom_reset_parameters) ) if outputs is None: - raise KeyboardInterrupt + raise UnityCommunicationException("Communicator has stopped.") rl_output = outputs.rl_output s = self._get_state(rl_output) self._global_done = s[1] @@ -570,7 +571,7 @@ def step( with hierarchical_timer("communicator.exchange"): outputs = self.communicator.exchange(step_input) if outputs is None: - raise KeyboardInterrupt + raise UnityCommunicationException("Communicator has stopped.") rl_output = outputs.rl_output state = self._get_state(rl_output) self._global_done = state[1] diff --git a/ml-agents-envs/mlagents/envs/exception.py b/ml-agents-envs/mlagents/envs/exception.py index 7824740c47..f1c0bed80c 100644 --- a/ml-agents-envs/mlagents/envs/exception.py +++ b/ml-agents-envs/mlagents/envs/exception.py @@ -19,6 +19,14 @@ class UnityEnvironmentException(UnityException): pass +class UnityCommunicationException(UnityException): + """ + Related to errors with the communicator. + """ + + pass + + class UnityActionException(UnityException): """ Related to errors with sending actions. diff --git a/ml-agents-envs/mlagents/envs/subprocess_env_manager.py b/ml-agents-envs/mlagents/envs/subprocess_env_manager.py index 679e548956..babb20382c 100644 --- a/ml-agents-envs/mlagents/envs/subprocess_env_manager.py +++ b/ml-agents-envs/mlagents/envs/subprocess_env_manager.py @@ -2,6 +2,7 @@ import cloudpickle from mlagents.envs import UnityEnvironment +from mlagents.envs.exception import UnityCommunicationException from multiprocessing import Process, Pipe, Queue from multiprocessing.connection import Connection from queue import Empty as EmptyQueueException @@ -47,14 +48,14 @@ def send(self, name: str, payload=None): cmd = EnvironmentCommand(name, payload) self.conn.send(cmd) except (BrokenPipeError, EOFError): - raise KeyboardInterrupt + raise UnityCommunicationException("UnityEnvironment worker: send failed.") def recv(self) -> EnvironmentResponse: try: response: EnvironmentResponse = self.conn.recv() return response except (BrokenPipeError, EOFError): - raise KeyboardInterrupt + raise UnityCommunicationException("UnityEnvironment worker: recv failed.") def close(self): try: @@ -115,8 +116,9 @@ def _send_response(cmd_name, payload): _send_response("global_done", env.global_done) elif cmd.name == "close": break - except KeyboardInterrupt: - print("UnityEnvironment worker: keyboard interrupt") + except (KeyboardInterrupt, UnityCommunicationException): + print("UnityEnvironment worker: environment stopping.") + step_queue.put(EnvironmentResponse("env_close", worker_id, None)) finally: step_queue.close() env.close() @@ -171,6 +173,10 @@ def step(self) -> List[StepInfo]: try: while True: step = self.step_queue.get_nowait() + if step.name == "env_close": + raise UnityCommunicationException( + "At least one of the environments has closed." + ) self.env_workers[step.worker_id].waiting = False if step.worker_id not in step_workers: worker_steps.append(step) diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py index 607fc4ede1..cfa7911ae4 100644 --- a/ml-agents/mlagents/trainers/trainer_controller.py +++ b/ml-agents/mlagents/trainers/trainer_controller.py @@ -14,7 +14,10 @@ from mlagents.envs import BrainParameters from mlagents.envs.env_manager import StepInfo from mlagents.envs.env_manager import EnvManager -from mlagents.envs.exception import UnityEnvironmentException +from mlagents.envs.exception import ( + UnityEnvironmentException, + UnityCommunicationException, +) from mlagents.envs.sampler_class import SamplerManager from mlagents.envs.timers import hierarchical_timer, get_timer_tree, timed from mlagents.trainers import Trainer, TrainerMetrics @@ -302,15 +305,15 @@ def start_learning( # Final save Tensorflow model if global_step != 0 and self.train_model: self._save_model() - except KeyboardInterrupt: + except (KeyboardInterrupt, UnityCommunicationException): if self.train_model: self._save_model_when_interrupted() pass - env_manager.close() if self.train_model: self._write_training_metrics() self._export_graph() self._write_timing_tree() + env_manager.close() def end_trainer_episodes( self, env: BaseUnityEnvironment, lessons_incremented: Dict[str, bool]