Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions ml-agents/mlagents/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@


class UnityEnvironment(object):
SCALAR_ACTION_TYPES = (int, np.int32, np.int64, float, np.float32, np.float64)
SINGLE_BRAIN_ACTION_TYPES = SCALAR_ACTION_TYPES + (list, np.ndarray)
SINGLE_BRAIN_TEXT_TYPES = (str, list, np.ndarray)

def __init__(self, file_name=None, worker_id=0,
base_port=5005, seed=0,
docker_training=False, no_graphics=False):
Expand Down Expand Up @@ -270,7 +274,7 @@ def step(self, vector_action=None, memory=None, text_action=None, value=None) ->

# Check that environment is loaded, and episode is currently running.
if self._loaded and not self._global_done and self._global_done is not None:
if isinstance(vector_action, (int, np.int_, float, np.float_, list, np.ndarray)):
if isinstance(vector_action, self.SINGLE_BRAIN_ACTION_TYPES):
if self._num_external_brains == 1:
vector_action = {self._external_brain_names[0]: vector_action}
elif self._num_external_brains > 1:
Expand All @@ -282,7 +286,7 @@ def step(self, vector_action=None, memory=None, text_action=None, value=None) ->
"There are no external brains in the environment, "
"step cannot take a vector_action input")

if isinstance(memory, (int, np.int_, float, np.float_, list, np.ndarray)):
if isinstance(memory, self.SINGLE_BRAIN_ACTION_TYPES):
if self._num_external_brains == 1:
memory = {self._external_brain_names[0]: memory}
elif self._num_external_brains > 1:
Expand All @@ -294,7 +298,7 @@ def step(self, vector_action=None, memory=None, text_action=None, value=None) ->
"There are no external brains in the environment, "
"step cannot take a memory input")

if isinstance(text_action, (str, list, np.ndarray)):
if isinstance(text_action, self.SINGLE_BRAIN_TEXT_TYPES):
if self._num_external_brains == 1:
text_action = {self._external_brain_names[0]: text_action}
elif self._num_external_brains > 1:
Expand All @@ -306,7 +310,7 @@ def step(self, vector_action=None, memory=None, text_action=None, value=None) ->
"There are no external brains in the environment, "
"step cannot take a value input")

if isinstance(value, (int, np.int_, float, np.float_, list, np.ndarray)):
if isinstance(value, self.SINGLE_BRAIN_ACTION_TYPES):
if self._num_external_brains == 1:
value = {self._external_brain_names[0]: value}
elif self._num_external_brains > 1:
Expand Down Expand Up @@ -419,14 +423,14 @@ def _close(self):
if self.proc1 is not None:
self.proc1.kill()

@staticmethod
def _flatten(arr):
@classmethod
def _flatten(cls, arr):
"""
Converts arrays to list.
:param arr: numpy vector.
:return: flattened list.
"""
if isinstance(arr, (int, np.int_, float, np.float_)):
if isinstance(arr, cls.SCALAR_ACTION_TYPES):
arr = [float(arr)]
if isinstance(arr, np.ndarray):
arr = arr.tolist()
Expand Down