Skip to content

Commit

Permalink
Added cuda testing
Browse files Browse the repository at this point in the history
  • Loading branch information
boris-il-forte committed Dec 4, 2023
1 parent c7e133e commit 9c24834
Showing 1 changed file with 8 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mushroom_rl.core import Agent, VectorCore, VectorizedEnvironment, MDPInfo
from mushroom_rl.rl_utils import Box
from mushroom_rl.policy import Policy
from mushroom_rl.utils import TorchUtils


class DummyPolicy(Policy):
Expand Down Expand Up @@ -94,3 +95,10 @@ def test_vectorized_env_():
run_exp(env_backend='torch', agent_backend='numpy')
run_exp(env_backend='numpy', agent_backend='torch')
run_exp(env_backend='numpy', agent_backend='numpy')

if torch.cuda.is_available():
print('Testing also cuda')
TorchUtils.set_default_device('cuda')
run_exp(env_backend='torch', agent_backend='torch')
run_exp(env_backend='torch', agent_backend='numpy')

0 comments on commit 9c24834

Please sign in to comment.