53 changes: 14 additions & 39 deletions ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
@@ -1,16 +1,15 @@
from typing import Dict, Optional, Tuple, List
import torch
import numpy as np
from mlagents_envs.base_env import DecisionSteps

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.torch.utils import ModelUtils


@@ -50,35 +49,6 @@ def create_reward_signals(self, reward_signal_configs):
reward_signal, self.policy.behavior_spec, settings
)

def get_value_estimates(
self, decision_requests: DecisionSteps, idx: int, done: bool
) -> Dict[str, float]:
"""
Generates value estimates for bootstrapping.
:param decision_requests:
:param idx: Index in BrainInfo of agent.
:param done: Whether or not this is the last element of the episode,
in which case the value estimate will be 0.
:return: The value estimate dictionary with key being the name of the reward signal
and the value the corresponding value estimate.
"""
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)

value_estimates = self.policy.actor_critic.critic_pass(
np.expand_dims(vec_vis_obs.vector_observations[idx], 0),
np.expand_dims(vec_vis_obs.visual_observations[idx], 0),
)

value_estimates = {k: float(v) for k, v in value_estimates.items()}

# If we're done, reassign all of the value estimates that need terminal states.
if done:
for k in value_estimates:
if not self.reward_signals[k].ignore_done:
value_estimates[k] = 0.0

return value_estimates

def get_trajectory_value_estimates(
self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
@@ -93,18 +63,23 @@ def get_trajectory_value_estimates(
else:
visual_obs = []

memory = torch.zeros([1, len(vector_obs[0]), self.policy.m_size])
memory = torch.zeros([1, 1, self.policy.m_size])

next_obs = np.concatenate(next_obs, axis=-1)
next_obs = [ModelUtils.list_to_tensor(next_obs).unsqueeze(0)]
next_memory = torch.zeros([1, 1, self.policy.m_size])
vec_vis_obs = SplitObservations.from_observations(next_obs)
next_vec_obs = [
ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0)
]
next_vis_obs = [
ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0)
for _vis_ob in vec_vis_obs.visual_observations
]

value_estimates = self.policy.actor_critic.critic_pass(
vector_obs, visual_obs, memory
value_estimates, next_memory = self.policy.actor_critic.critic_pass(
vector_obs, visual_obs, memory, sequence_length=batch.num_experiences
)

next_value_estimate = self.policy.actor_critic.critic_pass(
next_obs, next_obs, next_memory
next_value_estimate, _ = self.policy.actor_critic.critic_pass(
next_vec_obs, next_vis_obs, next_memory, sequence_length=1
)

for name, estimate in value_estimates.items():
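Note: the reworked get_trajectory_value_estimates threads the recurrent state through two critic passes: the whole trajectory is evaluated from a zeroed memory, and the memory returned by that pass seeds a second pass over the next observation to produce the bootstrap value. A minimal sketch of that pattern, using a hypothetical `critic` callable rather than the actual actor_critic.critic_pass API:

```python
import torch

def bootstrap_values(critic, trajectory_obs: torch.Tensor, next_obs: torch.Tensor, m_size: int):
    # Hypothetical recurrent critic: critic(obs, memory, sequence_length)
    # is assumed to return (values, final_memory).
    initial_memory = torch.zeros([1, 1, m_size])
    # Evaluate the full trajectory in one pass, carrying the LSTM state forward.
    values, next_memory = critic(
        trajectory_obs, initial_memory, sequence_length=trajectory_obs.shape[0]
    )
    # Reuse the final memory to value the observation that follows the trajectory,
    # which becomes the bootstrap value for return computation.
    next_value, _ = critic(next_obs, next_memory, sequence_length=1)
    return values, next_value
```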
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -186,7 +186,7 @@ def evaluate(
run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
run_out["learning_rate"] = 0.0
if self.use_recurrent:
run_out["memories"] = memories.detach().cpu().numpy()
run_out["memory_out"] = memories.detach().cpu().numpy().squeeze(0)
return run_out

def get_action(
44 changes: 31 additions & 13 deletions ml-agents/mlagents/trainers/ppo/optimizer_torch.py
@@ -61,12 +61,15 @@ def ppo_value_loss(
old_values: Dict[str, torch.Tensor],
returns: Dict[str, torch.Tensor],
epsilon: float,
loss_masks: torch.Tensor,
) -> torch.Tensor:
"""
Creates training-specific Tensorflow ops for PPO models.
:param returns:
:param old_values:
:param values:
Evaluates value loss for PPO.
:param values: Value output of the current network.
:param old_values: Value stored with experiences in buffer.
:param returns: Computed returns.
:param epsilon: Clipping value for value estimate.
:param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
value_losses = []
for name, head in values.items():
@@ -77,18 +80,24 @@ def ppo_value_loss(
)
v_opt_a = (returns_tensor - head) ** 2
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
value_loss = torch.mean(torch.max(v_opt_a, v_opt_b))
value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
value_losses.append(value_loss)
value_loss = torch.mean(torch.stack(value_losses))
return value_loss

def ppo_policy_loss(self, advantages, log_probs, old_log_probs, masks):
def ppo_policy_loss(
self,
advantages: torch.Tensor,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
loss_masks: torch.Tensor,
) -> torch.Tensor:
"""
Creates training-specific Tensorflow ops for PPO models.
:param masks:
:param advantages:
Evaluate PPO policy loss.
:param advantages: Computed advantages.
:param log_probs: Current policy log probabilities.
:param old_log_probs: Past policy log probabilities.
:param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
advantage = advantages.unsqueeze(-1)

@@ -99,7 +108,9 @@ def ppo_policy_loss(self, advantages, log_probs, old_log_probs, masks):
p_opt_b = (
torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage
)
policy_loss = -torch.mean(torch.min(p_opt_a, p_opt_b))
policy_loss = -1 * ModelUtils.masked_mean(
torch.min(p_opt_a, p_opt_b).flatten(), loss_masks
)
return policy_loss

@timed
@@ -153,14 +164,21 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
memories=memories,
seq_len=self.policy.sequence_length,
)
value_loss = self.ppo_value_loss(values, old_values, returns, decay_eps)
loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
value_loss = self.ppo_value_loss(
values, old_values, returns, decay_eps, loss_masks
)
policy_loss = self.ppo_policy_loss(
ModelUtils.list_to_tensor(batch["advantages"]),
log_probs,
ModelUtils.list_to_tensor(batch["action_probs"]),
ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32),
loss_masks,
)
loss = (
policy_loss
+ 0.5 * value_loss
- decay_bet * ModelUtils.masked_mean(entropy.flatten(), loss_masks)
)
loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy)

# Set optimizer learning rate
ModelUtils.update_learning_rate(self.optimizer, decay_lr)
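Note: across the PPO update, plain torch.mean over the loss terms is replaced by ModelUtils.masked_mean with the boolean batch["masks"], so zero-padded timesteps that fill out LSTM sequences no longer dilute the value, policy, and entropy terms. As a rough sketch of the idea (the real ModelUtils.masked_mean may differ in broadcasting details and empty-mask handling):

```python
import torch

def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # Average only the entries whose mask is True; padded timesteps
    # contribute to neither the numerator nor the denominator.
    masks = masks.float()
    return (tensor * masks).sum() / torch.clamp(masks.sum(), min=1.0)

# Example: the padded third timestep is ignored, so the mean is 1.5 rather than ~34.3.
print(masked_mean(torch.tensor([1.0, 2.0, 100.0]), torch.tensor([True, True, False])))
```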
125 changes: 96 additions & 29 deletions ml-agents/mlagents/trainers/sac/optimizer_torch.py
@@ -1,7 +1,8 @@
import numpy as np
from typing import Dict, List, Mapping, cast, Tuple
from typing import Dict, List, Mapping, cast, Tuple, Optional
import torch
from torch import nn
import attr

from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import ActionType
@@ -56,10 +57,24 @@ def forward(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
actions: torch.Tensor = None,
actions: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
q1_out, _ = self.q1_network(vec_inputs, vis_inputs, actions=actions)
q2_out, _ = self.q2_network(vec_inputs, vis_inputs, actions=actions)
q1_out, _ = self.q1_network(
vec_inputs,
vis_inputs,
actions=actions,
memories=memories,
sequence_length=sequence_length,
)
q2_out, _ = self.q2_network(
vec_inputs,
vis_inputs,
actions=actions,
memories=memories,
sequence_length=sequence_length,
)
return q1_out, q2_out

def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
@@ -87,17 +102,28 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
for name in self.stream_names
}

# Critics should have 1/2 of the memory of the policy
critic_memory = policy_network_settings.memory
if critic_memory is not None:
critic_memory = attr.evolve(
critic_memory, memory_size=critic_memory.memory_size // 2
)
value_network_settings = attr.evolve(
policy_network_settings, memory=critic_memory
)

self.value_network = TorchSACOptimizer.PolicyValueNetwork(
self.stream_names,
self.policy.behavior_spec.observation_shapes,
policy_network_settings,
value_network_settings,
self.policy.behavior_spec.action_type,
self.act_size,
)

self.target_network = ValueNetwork(
self.stream_names,
self.policy.behavior_spec.observation_shapes,
policy_network_settings,
value_network_settings,
)
self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
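Note: the soft_update(..., 1.0) call above hard-copies the critic's weights into the freshly built target network. The ML-Agents helper itself is not shown in this diff, but a typical Polyak-style soft update looks roughly like this (an assumption, not the project's exact implementation):

```python
import torch
from torch import nn

def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
    # Blend source parameters into the target network in place.
    # tau=1.0 performs a hard copy; a small tau gives a slowly moving target.
    with torch.no_grad():
        for src, tgt in zip(source.parameters(), target.parameters()):
            tgt.copy_(tau * src + (1.0 - tau) * tgt)
```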

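Note: the value and target networks in this hunk are built from value_network_settings, a copy of the policy's network settings whose memory_size is halved with attr.evolve rather than mutated in place. A hedged sketch of that pattern with simplified stand-in settings classes (not the real ML-Agents NetworkSettings):

```python
from typing import Optional
import attr

@attr.s(auto_attribs=True)
class MemorySettings:
    memory_size: int = 256
    sequence_length: int = 64

@attr.s(auto_attribs=True)
class NetworkSettings:
    hidden_units: int = 128
    memory: Optional[MemorySettings] = None

policy_settings = NetworkSettings(memory=MemorySettings(memory_size=256))

# attr.evolve returns a modified copy, leaving the policy's own settings untouched.
critic_memory = policy_settings.memory
if critic_memory is not None:
    critic_memory = attr.evolve(critic_memory, memory_size=critic_memory.memory_size // 2)
value_settings = attr.evolve(policy_settings, memory=critic_memory)

assert policy_settings.memory.memory_size == 256
assert value_settings.memory.memory_size == 128
```

The halved critic memory lines up with the update code further down, which passes only the second half of each stored memory to the target network.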
@@ -168,11 +194,11 @@ def sac_q_loss(
* self.gammas[i]
* target_values[name]
)
_q1_loss = 0.5 * torch.mean(
loss_masks * torch.nn.functional.mse_loss(q_backup, q1_stream)
_q1_loss = 0.5 * ModelUtils.masked_mean(
torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks
)
_q2_loss = 0.5 * torch.mean(
loss_masks * torch.nn.functional.mse_loss(q_backup, q2_stream)
_q2_loss = 0.5 * ModelUtils.masked_mean(
torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks
)

q1_losses.append(_q1_loss)
@@ -232,9 +258,8 @@ def sac_value_loss(
v_backup = min_policy_qs[name] - torch.sum(
_ent_coef * log_probs, dim=1
)
# print(log_probs, v_backup, _ent_coef, loss_masks)
value_loss = 0.5 * torch.mean(
loss_masks * torch.nn.functional.mse_loss(values[name], v_backup)
value_loss = 0.5 * ModelUtils.masked_mean(
torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
)
value_losses.append(value_loss)
else:
@@ -253,9 +278,9 @@ def sac_value_loss(
v_backup = min_policy_qs[name] - torch.mean(
branched_ent_bonus, axis=0
)
value_loss = 0.5 * torch.mean(
loss_masks
* torch.nn.functional.mse_loss(values[name], v_backup.squeeze())
value_loss = 0.5 * ModelUtils.masked_mean(
torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
loss_masks,
)
value_losses.append(value_loss)
value_loss = torch.mean(torch.stack(value_losses))
@@ -275,7 +300,7 @@ def sac_policy_loss(
if not discrete:
mean_q1 = mean_q1.unsqueeze(1)
batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1)
policy_loss = torch.mean(loss_masks * batch_policy_loss)
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
else:
action_probs = log_probs.exp()
branched_per_action_ent = ModelUtils.break_into_branches(
@@ -322,9 +347,8 @@ def sac_entropy_loss(
target_current_diff = torch.squeeze(
target_current_diff_branched, axis=2
)
entropy_loss = -torch.mean(
loss_masks
* torch.mean(self._log_ent_coef * target_current_diff, axis=1)
entropy_loss = -1 * ModelUtils.masked_mean(
torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
)

return entropy_loss
@@ -369,12 +393,28 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
else:
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)

memories = [
memories_list = [
ModelUtils.list_to_tensor(batch["memory"][i])
for i in range(0, len(batch["memory"]), self.policy.sequence_length)
]
if len(memories) > 0:
memories = torch.stack(memories).unsqueeze(0)
# LSTM shouldn't have a sequence length < 1, but guard against indexing out of range if it does.
offset = 1 if self.policy.sequence_length > 1 else 0
next_memories_list = [
ModelUtils.list_to_tensor(
batch["memory"][i][self.policy.m_size // 2 :]
) # only pass value part of memory to target network
for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
]

if len(memories_list) > 0:
memories = torch.stack(memories_list).unsqueeze(0)
next_memories = torch.stack(next_memories_list).unsqueeze(0)
else:
memories = None
next_memories = None
# Q network memories are 0'ed out, since we don't have them during inference.
q_memories = torch.zeros_like(next_memories)

vis_obs: List[torch.Tensor] = []
next_vis_obs: List[torch.Tensor] = []
if self.policy.use_vis_obs:
@@ -415,19 +455,46 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
)
if self.policy.use_continuous_act:
squeezed_actions = actions.squeeze(-1)
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs, sampled_actions)
q1_out, q2_out = self.value_network(vec_obs, vis_obs, squeezed_actions)
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
sampled_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
squeezed_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_stream, q2_stream = q1_out, q2_out
else:
with torch.no_grad():
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs)
q1_out, q2_out = self.value_network(vec_obs, vis_obs)
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_stream = self._condense_q_streams(q1_out, actions)
q2_stream = self._condense_q_streams(q2_out, actions)

with torch.no_grad():
target_values, _ = self.target_network(next_vec_obs, next_vis_obs)
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32)
target_values, _ = self.target_network(
next_vec_obs,
next_vis_obs,
memories=next_memories,
sequence_length=self.policy.sequence_length,
)
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
use_discrete = not self.policy.use_continuous_act
dones = ModelUtils.list_to_tensor(batch["done"])

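Note: in the SAC update, one stored memory per sequence seeds the recurrent value network, the target network receives only the value half of the memory shifted by one step so it aligns with the "next" observations, and the Q networks get zeroed memories to mirror what is available at inference time. A small sketch of that indexing with illustrative sizes (the real numbers come from the trainer settings):

```python
import torch

# Illustrative sizes only.
num_experiences = 256
sequence_length = 64
m_size = 128  # stored per experience as [policy half | value half], per the PR's slicing

batch_memory = torch.randn(num_experiences, m_size)

# One stored memory per sequence seeds the recurrent value network.
memories = batch_memory[::sequence_length].unsqueeze(0)  # [1, 4, 128]

# The target network only consumes the value half of the memory, taken one
# step later so it lines up with the "next" observations of each sequence.
offset = 1 if sequence_length > 1 else 0
next_memories = batch_memory[offset::sequence_length, m_size // 2:].unsqueeze(0)  # [1, 4, 64]

# Q-network memories are zeroed out, mirroring what is available at inference.
q_memories = torch.zeros_like(next_memories)
```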