import logging
from typing import Any, Dict, List, Optional
import numpy as np
from mlagents.tf_utils import tf
from mlagents_envs.exception import UnityException
from mlagents.trainers.policy import Policy
from mlagents.trainers.action_info import ActionInfo
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
from mlagents.trainers import tensorflow_to_barracuda as tf2bc
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.env_manager import get_global_agent_id
from mlagents_envs.base_env import BatchedStepResult
logger = logging.getLogger("mlagents.trainers")

class UnityPolicyException(UnityException):
    """
    Related to errors with the Trainer.
    """

    pass

class TFPolicy(Policy):
    """
    Contains a learning model and the necessary
    functions to interact with it to perform evaluation and updating.
    """

    # Output nodes that may be present in the trained graph. The exact names
    # below are reconstructed from this version of ML-Agents and may need
    # adjusting if the model defines a different set.
    possible_output_nodes = [
        "action",
        "value_estimate",
        "action_probs",
        "recurrent_out",
        "memory_size",
        "version_number",
        "is_continuous_control",
        "action_output_shape",
    ]

    def __init__(self, seed, brain, trainer_parameters):
        """
        Initializes the policy.
        :param seed: Random seed to use for TensorFlow.
        :param brain: The corresponding Brain for this policy.
        :param trainer_parameters: The trainer parameters.
        """
        self.m_size = None
        self.model = None
        self.inference_dict = {}
        self.update_dict = {}
        self.sequence_length = 1
        self.seed = seed
        self.brain = brain
        self.use_recurrent = trainer_parameters["use_recurrent"]
        self.memory_dict: Dict[str, np.ndarray] = {}
        self.reward_signals: Dict[str, "RewardSignal"] = {}
        self.num_branches = len(self.brain.vector_action_space_size)
        self.previous_action_dict: Dict[str, np.array] = {}
        self.normalize = trainer_parameters.get("normalize", False)
        self.use_continuous_act = brain.vector_action_space_type == "continuous"
        if self.use_continuous_act:
            self.num_branches = self.brain.vector_action_space_size[0]
        self.model_path = trainer_parameters["model_path"]
        self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
        self.graph = tf.Graph()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # For multi-GPU training, set allow_soft_placement to True so that an
        # operation is placed onto an alternative device automatically,
        # preventing exceptions when a device doesn't support the operation
        # or does not exist.
        config.allow_soft_placement = True
        self.sess = tf.Session(config=config, graph=self.graph)
        self.saver = None
        if self.use_recurrent:
            self.m_size = trainer_parameters["memory_size"]
            self.sequence_length = trainer_parameters["sequence_length"]
            if self.m_size == 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is 0 even "
                    "though the trainer uses recurrent.".format(brain.brain_name)
                )
            elif self.m_size % 4 != 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is {1} "
                    "but it must be divisible by 4.".format(
                        brain.brain_name, self.m_size
                    )
                )
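
    # Illustrative only: a minimal trainer_parameters dict that satisfies the
    # lookups in __init__ above. The keys are the ones actually read; the
    # values are assumptions, not defaults from any particular trainer config:
    #     {
    #         "use_recurrent": True,
    #         "memory_size": 256,        # must be non-zero and divisible by 4
    #         "sequence_length": 64,
    #         "model_path": "./models/some-run-id/SomeBrain",
    #         "keep_checkpoints": 5,
    #         "normalize": False,
    #     }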

    def _initialize_graph(self):
        with self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
            init = tf.global_variables_initializer()

    def _load_graph(self):
        with self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
  "Loading Model for brain {}".format(self.brain.brain_name))
            ckpt = tf.train.get_checkpoint_state(self.model_path)
            if ckpt is None:
                logger.info(
                    "The model {0} could not be found. Make "
                    "sure you specified the right "
                    "--run-id".format(self.model_path)
                )
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)

    def evaluate(
        self, batched_step_result: BatchedStepResult, global_agent_ids: List[str]
    ) -> Dict[str, Any]:
        """
        Evaluates policy for the agent experiences provided.
        :param batched_step_result: BatchedStepResult input to network.
        :return: Output from policy based on self.inference_dict.
        """
        raise UnityPolicyException("The evaluate function was not implemented.")

    def get_action(
        self, batched_step_result: BatchedStepResult, worker_id: int = 0
    ) -> ActionInfo:
        """
        Decides actions given observations information, and takes them in environment.
        :param batched_step_result: The BatchedStepResult from the environment.
        :param worker_id: In parallel environment training, the unique id of the environment worker that
            the BatchedStepResult came from. Used to construct a globally unique id for each agent.
        :return: an ActionInfo containing action, memories, values and an object
            to be passed to add experiences
        """
        if batched_step_result.n_agents() == 0:
            return ActionInfo([], [], {}, [])

        agents_done = [
            agent
            for agent, done in zip(
                batched_step_result.agent_id, batched_step_result.done
            )
            if done
        ]

        # Clear the recurrent memories and previous actions of agents that
        # just finished an episode.
        self.remove_memories(agents_done)
        self.remove_previous_action(agents_done)

        global_agent_ids = [
            get_global_agent_id(worker_id, int(agent_id))
            for agent_id in batched_step_result.agent_id
        ]  # For a 1-D array, the iterator order is correct.

        run_out = self.evaluate(  # pylint: disable=assignment-from-no-return
            batched_step_result, global_agent_ids
        )

        self.save_memories(global_agent_ids, run_out.get("memory_out"))
        return ActionInfo(
            action=run_out.get("action"),
            value=run_out.get("value"),
            outputs=run_out,
            agent_ids=batched_step_result.agent_id,
        )
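
    # Hypothetical usage sketch (caller-side, not part of this class): a
    # training loop would do something like
    #     action_info = policy.get_action(batched_step_result, worker_id=0)
    #     env.set_actions(behavior_name, action_info.action)
    # where `env` and `behavior_name` are assumed to exist in the caller.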

    def update(self, mini_batch, num_sequences):
        """
        Performs update of the policy.
        :param num_sequences: Number of experience trajectories in batch.
        :param mini_batch: Batch of experiences.
        :return: Results of update.
        """
        raise UnityPolicyException("The update function was not implemented.")

    def _execute_model(self, feed_dict, out_dict):
        """
        Executes model.
        :param feed_dict: Input dictionary mapping nodes to input data.
        :param out_dict: Output dictionary mapping names to nodes.
        :return: Dictionary mapping names to output data.
        """
        network_out =, feed_dict=feed_dict)
        run_out = dict(zip(list(out_dict.keys()), network_out))
        return run_out
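
    # For example, with out_dict = {"action": <action node>, "memory_out":
    # <memory node>} (node names assumed for illustration), _execute_model
    # returns {"action": <np.ndarray>, "memory_out": <np.ndarray>}; the zip
    # relies on fetching results in the same order as out_dict's keys.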

    def fill_eval_dict(self, feed_dict, batched_step_result):
        vec_vis_obs = SplitObservations.from_observations(batched_step_result.obs)
        for i, _ in enumerate(vec_vis_obs.visual_observations):
            feed_dict[self.model.visual_in[i]] = vec_vis_obs.visual_observations[i]
        if self.use_vec_obs:
            feed_dict[self.model.vector_in] = vec_vis_obs.vector_observations
        if not self.use_continuous_act:
            # Default to all actions available for every agent.
            mask = np.ones(
                (
                    batched_step_result.n_agents(),
                    np.sum(self.brain.vector_action_space_size),
                ),
                dtype=np.float32,
            )
            if batched_step_result.action_mask is not None:
                mask = 1 - np.concatenate(batched_step_result.action_mask, axis=1)
            feed_dict[self.model.action_masks] = mask
        return feed_dict
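
    # Note on the mask above: BatchedStepResult.action_mask marks *disallowed*
    # actions (this polarity is inferred from the `1 - concatenate` flip, not
    # stated here), while the model's action_masks placeholder expects 1.0 for
    # available actions, hence the inversion.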

    def make_empty_memory(self, num_agents):
        """
        Creates empty memory for use with RNNs.
        :param num_agents: Number of agents.
        :return: Numpy array of zeros.
        """
        return np.zeros((num_agents, self.m_size), dtype=np.float32)

    def save_memories(
        self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
    ) -> None:
        if memory_matrix is None:
            return
        for index, agent_id in enumerate(agent_ids):
            self.memory_dict[agent_id] = memory_matrix[index, :]

    def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
        memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.memory_dict:
                memory_matrix[index, :] = self.memory_dict[agent_id]
        return memory_matrix
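
    # Agents not present in memory_dict (e.g. newly created agents) keep the
    # zero rows from np.zeros above, matching what make_empty_memory() returns.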

    def remove_memories(self, agent_ids):
        for agent_id in agent_ids:
            if agent_id in self.memory_dict:
                self.memory_dict.pop(agent_id)

    def make_empty_previous_action(self, num_agents):
        """
        Creates empty previous action for use with RNNs and discrete control.
        :param num_agents: Number of agents.
        :return: Numpy array of zeros.
        """
        return np.zeros((num_agents, self.num_branches), dtype=np.int)

    def save_previous_action(
        self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
    ) -> None:
        if action_matrix is None:
            return
        for index, agent_id in enumerate(agent_ids):
            self.previous_action_dict[agent_id] = action_matrix[index, :]

    def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
        action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.previous_action_dict:
                action_matrix[index, :] = self.previous_action_dict[agent_id]
        return action_matrix

    def remove_previous_action(self, agent_ids):
        for agent_id in agent_ids:
            if agent_id in self.previous_action_dict:
                self.previous_action_dict.pop(agent_id)

    def get_current_step(self):
        """
        Gets current model step.
        :return: current model step.
        """
        step =
        return step

    def increment_step(self, n_steps):
        """
        Increments model step.
        :param n_steps: Number of steps to increment the step count by.
        """
        out_dict = {
            "global_step": self.model.global_step,
            "increment_step": self.model.increment_step,
        }
        feed_dict = {self.model.steps_to_increment: n_steps}
        return, feed_dict=feed_dict)["global_step"]
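
    # Example: after `policy.increment_step(5)`, get_current_step() reports
    # the previous count plus 5; the increment op and the global_step fetch
    # run in the same session call, so the returned value is the new count.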

    def get_inference_vars(self):
        """
        :return: list of inference var names
        """
        return list(self.inference_dict.keys())

    def get_update_vars(self):
        """
        :return: list of update var names
        """
        return list(self.update_dict.keys())

    def save_model(self, steps):
        """
        Saves the model.
        :param steps: The number of steps the model was trained for.
        """
        with self.graph.as_default():
            last_checkpoint = self.model_path + "/model-" + str(steps) + ".cptk"
  , last_checkpoint)
            tf.train.write_graph(
                self.graph, self.model_path, "raw_graph_def.pb", as_text=False
            )

    def export_model(self):
        """
        Exports latest saved model to .nn format for Unity embedding.
        """
        with self.graph.as_default():
            target_nodes = ",".join(self._process_graph())
            graph_def = self.graph.as_graph_def()
            output_graph_def = graph_util.convert_variables_to_constants(
                self.sess, graph_def, target_nodes.replace(" ", "").split(",")
            )
            frozen_graph_def_path = self.model_path + "/frozen_graph_def.pb"
            with gfile.GFile(frozen_graph_def_path, "wb") as f:
                f.write(output_graph_def.SerializeToString())
            tf2bc.convert(frozen_graph_def_path, self.model_path + ".nn")
  "Exported " + self.model_path + ".nn file")
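
    # The export pipeline above is: checkpointed variables -> frozen GraphDef
    # (all variables folded into constants) -> Barracuda .nn file consumed by
    # the Unity runtime.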

    def _process_graph(self):
        """
        Gets the list of the output nodes present in the graph for inference.
        :return: list of node names
        """
        all_nodes = [ for x in self.graph.as_graph_def().node]
        nodes = [x for x in all_nodes if x in self.possible_output_nodes]"List of nodes to export for brain :" + self.brain.brain_name)
        for n in nodes:
  "\t" + n)
        return nodes

    def update_normalization(self, vector_obs: np.ndarray) -> None:
        """
        If this policy normalizes vector observations, this will update the norm values in the graph.
        :param vector_obs: The vector observations to add to the running estimate of the distribution.
        """
        if self.use_vec_obs and self.normalize:
                self.model.update_normalization,
                feed_dict={self.model.vector_in: vector_obs},
            )

    def get_batched_value_estimates(self, batch: AgentBuffer) -> Dict[str, np.ndarray]:
        feed_dict: Dict[tf.Tensor, Any] = {
            self.model.batch_size: batch.num_experiences,
            self.model.sequence_length: 1,  # We want to feed data in batch-wise, not time-wise.
        }
        if self.use_vec_obs:
            feed_dict[self.model.vector_in] = batch["vector_obs"]
        if self.model.vis_obs_size > 0:
            for i in range(len(self.model.visual_in)):
                _obs = batch["visual_obs%d" % i]
                feed_dict[self.model.visual_in[i]] = _obs
        if self.use_recurrent:
            feed_dict[self.model.memory_in] = batch["memory"]
        if not self.use_continuous_act and self.use_recurrent:
            feed_dict[self.model.prev_action] = batch["prev_action"]
        value_estimates =, feed_dict)
        value_estimates = {
            k: np.squeeze(v, axis=1) for k, v in value_estimates.items()
        }
        return value_estimates
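
    # The returned dict has one entry per value head, i.e. one per reward
    # signal (e.g. "extrinsic"; the key names come from the model, not from
    # this method), each an array of length batch.num_experiences.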

    def get_value_estimates(
        self, next_obs: List[np.ndarray], agent_id: str, done: bool
    ) -> Dict[str, float]:
        """
        Generates value estimates for bootstrapping.
        :param next_obs: The next observations to be used for bootstrapping.
        :param agent_id: The id of the agent whose memories and previous actions are used.
        :param done: Whether or not this is the last element of the episode, in which case the value estimate will be 0.
        :return: The value estimate dictionary with key being the name of the reward signal and the value the
            corresponding value estimate.
        """
        feed_dict: Dict[tf.Tensor, Any] = {
            self.model.batch_size: 1,
            self.model.sequence_length: 1,
        }
        vec_vis_obs = SplitObservations.from_observations(next_obs)
        for i in range(len(vec_vis_obs.visual_observations)):
            feed_dict[self.model.visual_in[i]] = [vec_vis_obs.visual_observations[i]]
        if self.use_vec_obs:
            feed_dict[self.model.vector_in] = [vec_vis_obs.vector_observations]
        if self.use_recurrent:
            feed_dict[self.model.memory_in] = self.retrieve_memories([agent_id])
        if not self.use_continuous_act and self.use_recurrent:
            feed_dict[self.model.prev_action] = self.retrieve_previous_action(
                [agent_id]
            )
        value_estimates =, feed_dict)
        value_estimates = {k: float(v) for k, v in value_estimates.items()}
        # If we're done, reassign all of the value estimates that need terminal states.
        if done:
            for k in value_estimates:
                if self.reward_signals[k].use_terminal_states:
                    value_estimates[k] = 0.0
        return value_estimates
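
    # Bootstrapping note: for truncated (not done) episodes these estimates
    # stand in for the return beyond the horizon; when done is True, heads
    # whose reward signal uses terminal states are zeroed above instead.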

    @property
    def vis_obs_size(self):
        return self.model.vis_obs_size

    @property
    def vec_obs_size(self):
        return self.model.vec_obs_size

    @property
    def use_vis_obs(self):
        return self.model.vis_obs_size > 0

    @property
    def use_vec_obs(self):
        return self.model.vec_obs_size > 0