Skip to content

Commit

Permalink
Fixing casts to int by to_torch_as(...) calls in policies when using …
Browse files Browse the repository at this point in the history
…discrete actions (thu-ml#521)
  • Loading branch information
Kenneth-Schroeder committed Feb 6, 2022
1 parent c20d64c commit dd4804d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tianshou/policy/modelfree/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import torch

from tianshou.data import Batch, ReplayBuffer, to_torch_as
from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as
from tianshou.policy import BasePolicy
from tianshou.utils import RunningMeanStd

Expand Down Expand Up @@ -131,7 +131,7 @@ def learn( # type: ignore
result = self(minibatch)
dist = result.dist
act = to_torch_as(minibatch.act, result.act)
ret = to_torch_as(minibatch.returns, result.act)
ret = to_torch(minibatch.returns, torch.float, result.act.device)
log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
loss = -(log_prob * ret).mean()
loss.backward()
Expand Down

0 comments on commit dd4804d

Please sign in to comment.