
Commit

Merge pull request #25 from WPI-MMR/agupta231-testing-module
Refactor Objects into a Testing Module
mahajanrevant committed Nov 5, 2020
2 parents 8906164 + d79c132 commit 383f67d
Showing 4 changed files with 69 additions and 27 deletions.
13 changes: 1 addition & 12 deletions gym_solo/core/test_obs_factory.py
@@ -1,22 +1,11 @@
 import unittest
 from gym_solo.core import obs
+from gym_solo.testing import CompliantObs
 
 from gym import spaces
 import numpy as np
 
 
-class CompliantObs(obs.Observation):
-  observation_space = spaces.Box(low=np.array([0., 0.]),
-                                 high=np.array([3., 3.]))
-  labels = ['1', '2']
-
-  def __init__(self, body_id):
-    pass
-
-  def compute(self):
-    return np.array([1, 2])
-
-
 class TestObservationFactory(unittest.TestCase):
   def test_empty(self):
     of = obs.ObservationFactory()
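For reference, a minimal sketch of how the relocated CompliantObs can be exercised after this change. The constructor argument, compute() output, and labels come straight from the diff; the test class and method names below are illustrative.

import unittest

import numpy as np

from gym_solo.testing import CompliantObs


class TestCompliantObsStub(unittest.TestCase):
  def test_fixed_observation(self):
    # CompliantObs ignores body_id, so any placeholder value works.
    stub = CompliantObs(0)
    np.testing.assert_array_equal(stub.compute(), np.array([1., 2.]))
    self.assertEqual(stub.labels, ['1', '2'])


if __name__ == '__main__':
  unittest.main()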
11 changes: 2 additions & 9 deletions gym_solo/core/test_rewards_factory.py
@@ -1,20 +1,13 @@
 import unittest
 from gym_solo.core import rewards
+from gym_solo.testing import ReflectiveReward
 
 from parameterized import parameterized
 from unittest import mock
 
 import numpy as np
 
 
-class TestReward(rewards.Reward):
-  def __init__(self, return_value):
-    self._return_value = return_value
-
-  def compute(self):
-    return self._return_value
-
-
 class TestRewardsFactory(unittest.TestCase):
  def test_empty(self):
    rf = rewards.RewardFactory()
@@ -33,7 +26,7 @@ def test_empty(self):
   def test_register_and_compute(self, name, rewards_dict, expected_reward):
     rf = rewards.RewardFactory()
     for weight, reward in rewards_dict.items():
-      rf.register_reward(weight, TestReward(reward))
+      rf.register_reward(weight, ReflectiveReward(reward))
     self.assertEqual(rf.get_reward(), expected_reward)


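As a quick usage sketch (the weight and value here are arbitrary), ReflectiveReward lets a test push a known value through the RewardFactory API visible in this hunk:

from gym_solo.core import rewards
from gym_solo.testing import ReflectiveReward

rf = rewards.RewardFactory()
# ReflectiveReward echoes back the value it was constructed with, so the
# factory's aggregated output is predictable in a test.
rf.register_reward(1, ReflectiveReward(2.5))
print(rf.get_reward())  # exactly how weighted rewards combine is up to RewardFactory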
7 changes: 1 addition & 6 deletions gym_solo/envs/test_solo8v2vanilla.py
@@ -2,7 +2,7 @@
 from gym_solo.envs import solo8v2vanilla as solo_env
 
 from gym_solo.core.test_obs_factory import CompliantObs
-from gym_solo.core import rewards
+from gym_solo.testing import SimpleReward
 
 from gym import error, spaces
 from parameterized import parameterized
@@ -14,11 +14,6 @@
 import pybullet as p
 
 
-class SimpleReward(rewards.Reward):
-  def compute(self):
-    return 1
-
-
 class TestSolo8v2VanillaEnv(unittest.TestCase):
   def setUp(self):
     self.env = solo_env.Solo8VanillaEnv(config=solo_env.Solo8VanillaConfig())
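SimpleReward, now imported from gym_solo.testing, is the trivial always-1 stub; a one-line sanity check, assuming nothing beyond what the diff shows:

from gym_solo.testing import SimpleReward

# SimpleReward ignores all simulation state and always reports a reward of 1,
# which keeps environment-level tests focused on wiring rather than reward math.
assert SimpleReward().compute() == 1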
65 changes: 65 additions & 0 deletions gym_solo/testing.py
@@ -0,0 +1,65 @@
+from gym_solo.core import obs
+from gym_solo.core import rewards
+
+from gym import spaces
+import numpy as np
+
+from gym_solo import solo_types
+
+
+class CompliantObs(obs.Observation):
+  """A simple observation which implements the Observation interface.
+
+  Note that the observation always returns [1, 2].
+
+  Attributes:
+    observation_space (spaces.Space): The observation space. Always returns a
+      box with a low of 0 and a high of 3.
+    labels (List[str]): The labels for the two observations. Always returns
+      ['1', '2']
+  """
+  observation_space = spaces.Box(low=np.array([0., 0.]),
+                                 high=np.array([3., 3.]))
+  labels = ['1', '2']
+
+  def __init__(self, body_id: int):
+    """Create a new CompliantObs. Only kept for API conformance."""
+    pass
+
+  def compute(self) -> solo_types.obs:
+    """Compute the observation for the state.
+
+    Returns:
+      solo_types.obs: [1., 2.]
+    """
+    return np.array([1., 2.])
+
+
+class SimpleReward(rewards.Reward):
+  """A reward which will always return 1."""
+  def compute(self) -> float:
+    """'Compute' the reward for the step. Always returns 1.
+
+    Returns:
+      float: 1.0
+    """
+    return 1
+
+
+class ReflectiveReward(rewards.Reward):
+  """A reward which returns a configurable fixed value."""
+  def __init__(self, return_value: float):
+    """Create a ReflectiveReward.
+
+    Args:
+      return_value (float): What value this reward should return.
+    """
+    self._return_value = return_value
+
+  def compute(self) -> float:
+    """Return the fixed reward.
+
+    Returns:
+      float: the configured reward.
+    """
+    return self._return_value
