Skip to content

Commit

Permalink
Revert "Unroll with AMP for off policy training (#1653)"
Browse files Browse the repository at this point in the history
This reverts commit c9550f2.
  • Loading branch information
Le Horizon committed Jun 20, 2024
1 parent d9c32a9 commit 89025e1
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 13 deletions.
11 changes: 4 additions & 7 deletions alf/algorithms/rl_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,8 @@ def summarize_play(
"""Generate summaries for play or evaluate.
Args:
experience: experience of one step rollout in the environment during
play or evaluation, this is in contrast to the experience input to
experience: experience of one step rollout in the environment during
play or evaluation, this is in contrast to the experience input to
``summarize_rollout`` where a sequence of multiple rollout steps
are collected.
custom_summary: when specified it is a function that will be called every
Expand Down Expand Up @@ -602,9 +602,7 @@ def _process_unroll_step(self, policy_step, action, time_step,
transformed_time_step, policy_state,
experience_list, original_reward_list):
self.observe_for_metrics(time_step.cpu())
exp = make_experience(time_step.cpu(),
alf.layers.to_float32(policy_step),
alf.layers.to_float32(policy_state))
exp = make_experience(time_step.cpu(), policy_step, policy_state)

store_exp_time = 0
if not self.on_policy:
Expand Down Expand Up @@ -780,8 +778,7 @@ def _train_iter_off_policy(self):
or self.get_step_metrics()[1].result() <
self._config.num_env_steps)):
unrolled = True
with (torch.set_grad_enabled(config.unroll_with_grad),
torch.cuda.amp.autocast(self._config.enable_amp)):
with torch.set_grad_enabled(config.unroll_with_grad):
with record_time("time/unroll"):
self.eval()
# The period of performing unroll may not be an integer
Expand Down
5 changes: 0 additions & 5 deletions alf/environments/fast_parallel_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,6 @@ def _to_tensor(self, stacked):

def _step(self, action):
def _to_numpy(x):
# When AMP is enabled, the action.dtype can be torch.float16. We
# need to convert it to torch.float32 to match the dtype from
# action_spec
if x.dtype == torch.float16:
x = x.float()
x = x.cpu().numpy()
# parallel_environment.cpp requires the arrays to be contiguous. If
# x is already contiguous, ascontiguousarray() will simply return x.
Expand Down
2 changes: 1 addition & 1 deletion alf/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3691,7 +3691,7 @@ def to_float32(nested):
Args:
nested (nested Tensor): a nest of tensors
Returns:
nested Tensor: a nest of tensors/distributions with dtype torch.float32
nested Tensor: a nest of tensors with dtype torch.float32
"""

def _to_float32(x):
Expand Down

0 comments on commit 89025e1

Please sign in to comment.