Skip to content

Commit

Permalink
fixed end of batch size mismatch (#389)
Browse files Browse the repository at this point in the history
Co-authored-by: unknown <sidhant96@hotmail.com>
  • Loading branch information
sidhantls and unknown committed Nov 22, 2020
1 parent cff6b52 commit f0e2bee
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/reinforce_model.py
Expand Up @@ -215,7 +215,7 @@ def loss(self, states, actions, scaled_rewards) -> torch.Tensor:

# policy loss
log_prob = log_softmax(logits, dim=1)
log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions]
log_prob_actions = scaled_rewards * log_prob[range(len(log_prob)), actions]
loss = -log_prob_actions.mean()

return loss
Expand Down

0 comments on commit f0e2bee

Please sign in to comment.