Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ml-agents-envs/mlagents/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .brain import AllBrainInfo, BrainInfo, BrainParameters
from .exception import (
UnityEnvironmentException,
UnityCommunicationException,
UnityActionException,
UnityTimeOutException,
)
Expand Down Expand Up @@ -343,7 +344,7 @@ def reset(
self._generate_reset_input(train_mode, config, custom_reset_parameters)
)
if outputs is None:
raise KeyboardInterrupt
raise UnityCommunicationException("Communicator has stopped.")
rl_output = outputs.rl_output
s = self._get_state(rl_output)
self._global_done = s[1]
Expand Down Expand Up @@ -570,7 +571,7 @@ def step(
with hierarchical_timer("communicator.exchange"):
outputs = self.communicator.exchange(step_input)
if outputs is None:
raise KeyboardInterrupt
raise UnityCommunicationException("Communicator has stopped.")
rl_output = outputs.rl_output
state = self._get_state(rl_output)
self._global_done = state[1]
Expand Down
8 changes: 8 additions & 0 deletions ml-agents-envs/mlagents/envs/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ class UnityEnvironmentException(UnityException):
pass


class UnityCommunicationException(UnityException):
    """
    Signals a failure in the communicator link to a Unity environment,
    for example when the communicator has stopped or a worker's
    send/recv on its connection fails.
    """


class UnityActionException(UnityException):
"""
Related to errors with sending actions.
Expand Down
14 changes: 10 additions & 4 deletions ml-agents-envs/mlagents/envs/subprocess_env_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import cloudpickle

from mlagents.envs import UnityEnvironment
from mlagents.envs.exception import UnityCommunicationException
from multiprocessing import Process, Pipe, Queue
from multiprocessing.connection import Connection
from queue import Empty as EmptyQueueException
Expand Down Expand Up @@ -47,14 +48,14 @@ def send(self, name: str, payload=None):
cmd = EnvironmentCommand(name, payload)
self.conn.send(cmd)
except (BrokenPipeError, EOFError):
raise KeyboardInterrupt
raise UnityCommunicationException("UnityEnvironment worker: send failed.")

def recv(self) -> EnvironmentResponse:
try:
response: EnvironmentResponse = self.conn.recv()
return response
except (BrokenPipeError, EOFError):
raise KeyboardInterrupt
raise UnityCommunicationException("UnityEnvironment worker: recv failed.")

def close(self):
try:
Expand Down Expand Up @@ -115,8 +116,9 @@ def _send_response(cmd_name, payload):
_send_response("global_done", env.global_done)
elif cmd.name == "close":
break
except KeyboardInterrupt:
print("UnityEnvironment worker: keyboard interrupt")
except (KeyboardInterrupt, UnityCommunicationException):
print("UnityEnvironment worker: environment stopping.")
step_queue.put(EnvironmentResponse("env_close", worker_id, None))
finally:
step_queue.close()
env.close()
Expand Down Expand Up @@ -171,6 +173,10 @@ def step(self) -> List[StepInfo]:
try:
while True:
step = self.step_queue.get_nowait()
if step.name == "env_close":
raise UnityCommunicationException(
"At least one of the environments has closed."
)
self.env_workers[step.worker_id].waiting = False
if step.worker_id not in step_workers:
worker_steps.append(step)
Expand Down
9 changes: 6 additions & 3 deletions ml-agents/mlagents/trainers/trainer_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from mlagents.envs import BrainParameters
from mlagents.envs.env_manager import StepInfo
from mlagents.envs.env_manager import EnvManager
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.envs.exception import (
UnityEnvironmentException,
UnityCommunicationException,
)
from mlagents.envs.sampler_class import SamplerManager
from mlagents.envs.timers import hierarchical_timer, get_timer_tree, timed
from mlagents.trainers import Trainer, TrainerMetrics
Expand Down Expand Up @@ -302,15 +305,15 @@ def start_learning(
# Final save Tensorflow model
if global_step != 0 and self.train_model:
self._save_model()
except KeyboardInterrupt:
except (KeyboardInterrupt, UnityCommunicationException):
if self.train_model:
self._save_model_when_interrupted()
pass
env_manager.close()
if self.train_model:
self._write_training_metrics()
self._export_graph()
self._write_timing_tree()
env_manager.close()

def end_trainer_episodes(
self, env: BaseUnityEnvironment, lessons_incremented: Dict[str, bool]
Expand Down