10 changes: 5 additions & 5 deletions ml-agents/mlagents/trainers/model_saver/model_saver.py
@@ -34,23 +34,23 @@ def _register_optimizer(self, optimizer):
pass

@abc.abstractmethod
def save_checkpoint(self, brain_name: str, step: int) -> str:
def save_checkpoint(self, behavior_name: str, step: int) -> str:
"""
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param brain_name: Brain name of brain to be trained
:param behavior_name: Behavior name of behavior to be trained
"""
pass

@abc.abstractmethod
def export(self, output_filepath: str, brain_name: str) -> None:
def export(self, output_filepath: str, behavior_name: str) -> None:
"""
Saves the serialized model, given a path and brain name.
Saves the serialized model, given a path and behavior name.
This method will save the policy graph to the given filepath. The path
should be provided without an extension as multiple serialized model formats
may be generated as a result.
:param output_filepath: path (without suffix) for the model file(s)
:param brain_name: Brain name of brain to be trained.
:param behavior_name: Behavior name of behavior to be trained.
"""
pass

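A minimal sketch of a subclass conforming to the renamed interface, assuming the abstract base in model_saver.py is named BaseModelSaver (its class line sits outside the visible hunk); NoOpModelSaver is hypothetical and only illustrates the new signatures, and any other abstract methods on the base would also need implementations before instantiation:

```python
import os

from mlagents.trainers.model_saver.model_saver import BaseModelSaver


class NoOpModelSaver(BaseModelSaver):
    """Hypothetical saver that only computes paths; shows the renamed API."""

    def save_checkpoint(self, behavior_name: str, step: int) -> str:
        # Checkpoints are now keyed on behavior_name instead of brain_name.
        return os.path.join("results", f"{behavior_name}-{step}")

    def export(self, output_filepath: str, behavior_name: str) -> None:
        # A real saver would serialize the policy graph here.
        pass
```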
10 changes: 5 additions & 5 deletions ml-agents/mlagents/trainers/model_saver/tf_model_saver.py
@@ -55,8 +55,8 @@ def _register_policy(self, policy: TFPolicy) -> None:
with self.policy.graph.as_default():
self.tf_saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)

def save_checkpoint(self, brain_name: str, step: int) -> str:
checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
def save_checkpoint(self, behavior_name: str, step: int) -> str:
checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
# Save the TF checkpoint and graph definition
if self.graph:
with self.graph.as_default():
@@ -66,16 +66,16 @@ def save_checkpoint(self, brain_name: str, step: int) -> str:
self.graph, self.model_path, "raw_graph_def.pb", as_text=False
)
# also save the policy so we have optimized model files for each checkpoint
self.export(checkpoint_path, brain_name)
self.export(checkpoint_path, behavior_name)
return checkpoint_path

def export(self, output_filepath: str, brain_name: str) -> None:
def export(self, output_filepath: str, behavior_name: str) -> None:
# save model if there is only one worker or
# only on worker-0 if there are multiple workers
if self.policy and self.policy.rank is not None and self.policy.rank != 0:
return
export_policy_model(
self.model_path, output_filepath, brain_name, self.graph, self.sess
self.model_path, output_filepath, behavior_name, self.graph, self.sess
)

def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
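For clarity, a hedged sketch of the path arithmetic one checkpoint call now performs; the model path, behavior name, and step are hypothetical example values:

```python
import os

model_path = "results/3DBall"          # hypothetical saver model_path
behavior_name, step = "3DBall", 50000  # hypothetical behavior and step

checkpoint_path = os.path.join(model_path, f"{behavior_name}-{step}")
print(checkpoint_path)  # results/3DBall/3DBall-50000
# The TF checkpoint files and raw_graph_def.pb are written under model_path,
# and export() is then invoked with the same behavior_name so each checkpoint
# also gets optimized model files.
```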
8 changes: 4 additions & 4 deletions ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
@@ -45,19 +45,19 @@ def register(self, module: Union[TorchPolicy, TorchOptimizer]) -> None:
self.policy = module
self.exporter = ModelSerializer(self.policy)

def save_checkpoint(self, brain_name: str, step: int) -> str:
def save_checkpoint(self, behavior_name: str, step: int) -> str:
if not os.path.exists(self.model_path):
os.makedirs(self.model_path)
checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
state_dict = {
name: module.state_dict() for name, module in self.modules.items()
}
torch.save(state_dict, f"{checkpoint_path}.pt")
torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt"))
self.export(checkpoint_path, brain_name)
self.export(checkpoint_path, behavior_name)
return checkpoint_path

def export(self, output_filepath: str, brain_name: str) -> None:
def export(self, output_filepath: str, behavior_name: str) -> None:
if self.exporter is not None:
self.exporter.export_policy_model(output_filepath)

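The torch saver writes each state dict twice: once under a step-tagged name for history, once as a rolling checkpoint.pt that always holds the latest state. A self-contained sketch of that pattern (the free function and module dict here are illustrative, not the actual class):

```python
import os

import torch
from torch import nn


def save_checkpoint(model_path: str, behavior_name: str, step: int,
                    modules: dict) -> str:
    os.makedirs(model_path, exist_ok=True)
    checkpoint_path = os.path.join(model_path, f"{behavior_name}-{step}")
    state_dict = {name: m.state_dict() for name, m in modules.items()}
    torch.save(state_dict, f"{checkpoint_path}.pt")                    # history
    torch.save(state_dict, os.path.join(model_path, "checkpoint.pt"))  # latest
    return checkpoint_path


# Hypothetical usage with a stand-in module in place of the real policy.
save_checkpoint("results/3DBall", "3DBall", 1000, {"policy": nn.Linear(8, 2)})
```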
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -37,7 +37,7 @@ def __init__(
also use a CNN to encode visual input prior to the MLP. Supports discrete and
continuous action spaces, as well as recurrent networks.
:param seed: Random seed.
:param brain: Assigned BrainParameters object.
:param behavior_spec: Assigned BehaviorSpec object.
:param trainer_settings: Defined training parameters.
:param load: Whether a pre-trained model will be loaded or a new one created.
:param tanh_squash: Whether to use a tanh function on the continuous output,
@@ -214,7 +214,7 @@ def get_action(
"""
Decides actions given observations information, and takes them in environment.
:param worker_id:
:param decision_requests: A dictionary of brain names and BrainInfo from environment.
:param decision_requests: A dictionary mapping behavior names to DecisionSteps from the environment.
:return: an ActionInfo containing action, memories, values and an object
to be passed to add experiences
"""
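For reference, the structures named in the updated docstring come from the mlagents_envs API; a short sketch of where behavior names and DecisionSteps originate (this assumes a Unity editor instance is waiting to connect, and the printed behavior name is an example):

```python
from mlagents_envs.environment import UnityEnvironment

env = UnityEnvironment(file_name=None)  # attach to a running editor instance
env.reset()
behavior_name = list(env.behavior_specs)[0]    # e.g. "3DBall?team=0"
decision_steps, terminal_steps = env.get_steps(behavior_name)
print(behavior_name, len(decision_steps))      # agents requesting a decision
env.close()
```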
@@ -177,7 +177,7 @@ def _create_dc_critic(
name="old_probabilities",
)

# Break old log probs into separate branches
# Break old log_probs into separate branches
old_log_prob_branches = ModelUtils.break_into_branches(
self.all_old_log_probs, self.policy.act_size
)
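For context, break_into_branches splits a tensor whose columns concatenate every discrete branch back into per-branch slices. A numpy sketch of the equivalent operation, assuming act_size lists the number of actions per branch (the values below are made up):

```python
import numpy as np


def break_into_branches(concatenated: np.ndarray, action_size):
    # Column offsets of each branch inside the concatenated tensor.
    offsets = np.cumsum([0] + list(action_size))
    return [concatenated[:, offsets[i]:offsets[i + 1]]
            for i in range(len(action_size))]


old_log_probs = np.log(np.full((4, 5), 0.2))   # batch of 4, branches of 3 + 2
branches = break_into_branches(old_log_probs, [3, 2])
assert [b.shape for b in branches] == [(4, 3), (4, 2)]
```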
13 changes: 9 additions & 4 deletions ml-agents/mlagents/trainers/ppo/trainer.py
@@ -12,7 +12,7 @@
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.policy import Policy
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.ppo.optimizer import PPOOptimizer
from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, PPOSettings, FrameworkType
@@ -34,7 +34,7 @@ class PPOTrainer(RLTrainer):

def __init__(
self,
brain_name: str,
behavior_name: str,
reward_buff_cap: int,
trainer_settings: TrainerSettings,
training: bool,
@@ -44,7 +44,7 @@ def __init__(
):
"""
Responsible for collecting experiences and training PPO model.
:param brain_name: The name of the brain associated with trainer config
:param behavior_name: The name of the behavior associated with trainer config
:param reward_buff_cap: Max reward history to track in the reward buffer
:param trainer_settings: The parameters for the trainer.
:param training: Whether the trainer is set for training.
@@ -53,7 +53,12 @@
:param artifact_path: The directory within which to store artifacts from this trainer.
"""
super().__init__(
brain_name, trainer_settings, training, load, artifact_path, reward_buff_cap
behavior_name,
trainer_settings,
training,
load,
artifact_path,
reward_buff_cap,
)
self.hyperparameters: PPOSettings = cast(
PPOSettings, self.trainer_settings.hyperparameters
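A hedged sketch of constructing the trainer after the rename. The argument values are hypothetical, and the parameters hidden inside the collapsed hunk (load and seed) are assumed from the docstring, so treat the exact ordering as an assumption rather than the definitive signature:

```python
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.settings import TrainerSettings

trainer = PPOTrainer(
    behavior_name="3DBall",          # formerly brain_name
    reward_buff_cap=100,
    trainer_settings=TrainerSettings(),  # defaults; PPO hyperparameters assumed
    training=True,
    load=False,
    seed=0,
    artifact_path="results/3DBall",
)
```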
13 changes: 9 additions & 4 deletions ml-agents/mlagents/trainers/sac/trainer.py
@@ -14,7 +14,7 @@
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.policy import Policy
from mlagents.trainers.sac.optimizer import SACOptimizer
from mlagents.trainers.sac.optimizer_tf import SACOptimizer
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.trajectory import Trajectory, SplitObservations
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
@@ -41,7 +41,7 @@ class SACTrainer(RLTrainer):

def __init__(
self,
brain_name: str,
behavior_name: str,
reward_buff_cap: int,
trainer_settings: TrainerSettings,
training: bool,
@@ -51,7 +51,7 @@ def __init__(
):
"""
Responsible for collecting experiences and training SAC model.
:param brain_name: The name of the brain associated with trainer config
:param behavior_name: The name of the behavior associated with trainer config
:param reward_buff_cap: Max reward history to track in the reward buffer
:param trainer_settings: The parameters for the trainer.
:param training: Whether the trainer is set for training.
@@ -60,7 +60,12 @@
:param artifact_path: The directory within which to store artifacts from this trainer.
"""
super().__init__(
brain_name, trainer_settings, training, load, artifact_path, reward_buff_cap
behavior_name,
trainer_settings,
training,
load,
artifact_path,
reward_buff_cap,
)

self.seed = seed
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/tests/test_ppo.py
@@ -9,7 +9,7 @@

from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.ppo.trainer import PPOTrainer, discount_rewards
from mlagents.trainers.ppo.optimizer import PPOOptimizer
from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.tests import mock_brain as mb
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/tests/test_reward_signals.py
@@ -3,8 +3,8 @@
import os
import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.sac.optimizer import SACOptimizer
from mlagents.trainers.ppo.optimizer import PPOOptimizer
from mlagents.trainers.sac.optimizer_tf import SACOptimizer
from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
from mlagents.trainers.tests.test_simple_rl import PPO_CONFIG, SAC_CONFIG
from mlagents.trainers.settings import (
GAILSettings,
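The TF optimizer modules were renamed to optimizer_tf throughout, as this test's imports show. A hedged shim for downstream code that must run against releases on either side of the rename (module paths taken from this diff):

```python
try:
    # ml-agents after this change
    from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
    from mlagents.trainers.sac.optimizer_tf import SACOptimizer
except ImportError:
    # earlier releases kept the TF optimizers in .optimizer
    from mlagents.trainers.ppo.optimizer import PPOOptimizer
    from mlagents.trainers.sac.optimizer import SACOptimizer
```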
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/tests/test_sac.py
@@ -7,7 +7,7 @@

from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.sac.optimizer import SACOptimizer
from mlagents.trainers.sac.optimizer_tf import SACOptimizer
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.tests import mock_brain as mb
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/tests/test_saver.py
@@ -12,7 +12,7 @@
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.test_nn_policy import create_policy_mock
from mlagents.trainers.ppo.optimizer import PPOOptimizer
from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer


def test_register(tmp_path):
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/tf/models.py
@@ -510,8 +510,8 @@ def create_discrete_action_masking_layer(
:param action_masks: The mask for the logits. Must be of dimension [None x total_number_of_action]
:param action_size: A list containing the number of possible actions for each branch
:return: The action output dimension [batch_size, num_branches], the concatenated
normalized probs (after softmax)
and the concatenated normalized log probs
normalized probs (after softmax)
and the concatenated normalized log_probs
"""
branch_masks = ModelUtils.break_into_branches(action_masks, action_size)
raw_probs = [
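A numpy sketch of the masking behavior this docstring describes: logits for disallowed actions are zeroed after exponentiation and each branch is renormalized so its probs sum to one (the epsilon and example values are illustrative):

```python
import numpy as np


def masked_probs(logits: np.ndarray, mask: np.ndarray, eps: float = 1e-10):
    # Stable softmax, with masked actions zeroed before normalization.
    exp = np.exp(logits - logits.max(axis=1, keepdims=True)) * mask
    probs = exp / exp.sum(axis=1, keepdims=True)
    return probs, np.log(probs + eps)


logits = np.array([[1.0, 2.0, 0.5]])
mask = np.array([[1.0, 0.0, 1.0]])   # second action disallowed
probs, log_probs = masked_probs(logits, mask)
assert probs[0, 1] == 0.0 and abs(probs.sum() - 1.0) < 1e-6
```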
1 change: 0 additions & 1 deletion ml-agents/mlagents/trainers/torch/model_serialization.py
@@ -54,7 +54,6 @@ def export_policy_model(self, output_filepath: str) -> None:
Exports a Torch model for a Policy to .onnx format for Unity embedding.

:param output_filepath: file path to output the model (without file suffix)
:param brain_name: Brain name of brain to be trained
"""
if not os.path.exists(output_filepath):
os.makedirs(output_filepath)
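A generic sketch of the export step this method performs, using torch.onnx with a stand-in network rather than the actual ModelSerializer internals; the opset version, tensor names, and output path are assumptions, not the library's exact choices:

```python
import torch
from torch import nn

policy_net = nn.Linear(8, 2)   # stand-in for the wrapped policy network
dummy_obs = torch.zeros(1, 8)  # dummy input fixing the traced shapes

torch.onnx.export(
    policy_net, dummy_obs, "policy.onnx",
    opset_version=9, input_names=["obs"], output_names=["action"],
)
```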