Skip to content

Commit

Permalink
Merge pull request #31 from WPI-MMR/FixingExamples
Browse files Browse the repository at this point in the history
Adding reward instance to the examples to make them work
  • Loading branch information
goobta committed Dec 7, 2020
2 parents 472e59c + 94e5fbc commit 0ee3f1b
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 31 deletions.
3 changes: 3 additions & 0 deletions examples/solo8_vanilla/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@

import gym_solo
from gym_solo.envs import solo8v2vanilla
from gym_solo.core import rewards


if __name__ == '__main__':
config = solo8v2vanilla.Solo8VanillaConfig()
env = gym.make('solo8vanilla-v0', use_gui=True, realtime=True, config=config)

env.reward_factory.register_reward(1,rewards.UprightReward(env.robot))

try:
print("""\n
=========================
Expand Down
2 changes: 2 additions & 0 deletions examples/solo8_vanilla/observation_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
import gym_solo
from gym_solo.envs import solo8v2vanilla
from gym_solo.core import obs
from gym_solo.core import rewards


if __name__ == '__main__':
config = solo8v2vanilla.Solo8VanillaConfig()
env = gym.make('solo8vanilla-v0', use_gui=True, realtime=True, config=config)

env.obs_factory.register_observation(obs.TorsoIMU(env.robot))
env.reward_factory.register_reward(1,rewards.UprightReward(env.robot))

try:
print("""\n
Expand Down
2 changes: 1 addition & 1 deletion gym_solo/core/obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def get_obs(self) -> Tuple[solo_types.obs, List[str]]:
i-th observation.
"""
if not self._observations:
return np.empty(shape=(0,)), []
raise ValueError('Need to register at least one observation instance')

all_obs = []
all_labels = []
Expand Down
4 changes: 2 additions & 2 deletions gym_solo/core/termination.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def is_terminated(self) -> bool:
termination conditions and value of _use_or attribute
"""

if not self._use_or:
raise ValueError('No termination condition other than OR is defined')
if not self._terminations:
raise ValueError('Need to register at least one termination instance')

for termination in self._terminations:
if termination.is_terminated():
Expand Down
10 changes: 4 additions & 6 deletions gym_solo/core/test_obs_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ def test_empty(self):
self.assertFalse(of._observations)
self.assertIsNone(of._obs_space)

observations, labels = of.get_obs()
self.assertEqual(observations.size, 0)
self.assertFalse(labels)
with self.assertRaises(ValueError):
observations, labels = of.get_obs()

def test_register_happy(self):
of = obs.ObservationFactory(self.client)
Expand Down Expand Up @@ -70,10 +69,9 @@ def test_register_mismatch(self):

def test_get_obs_no_observations(self):
of = obs.ObservationFactory(self.client)
observations, labels = of.get_obs()

np.testing.assert_array_equal(observations, np.empty(shape=(0,)))
self.assertFalse(labels)
with self.assertRaises(ValueError):
observations, labels = of.get_obs()

def test_get_obs_single_observation(self):
of = obs.ObservationFactory(self.client)
Expand Down
15 changes: 1 addition & 14 deletions gym_solo/core/test_termination_factory.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
import unittest
from gym_solo.core import termination

# TODO: Move this to gym_solo.testing
class DummyTermination(termination.Termination):
    """Termination stub for tests whose verdict is supplied by the caller.

    Besides reporting a fixed answer, it counts how many times ``reset`` has
    been invoked so tests can check that resets are propagated.
    """

    def __init__(self, body_id: int, termination_var: bool):
        """Store the arguments and run the initial reset.

        Args:
          body_id: Identifier kept on the instance as ``body_id``; this class
            never reads it back.
          termination_var: Value that ``is_terminated`` reports.
        """
        self.reset_counter = 0
        self.termination_var = termination_var
        self.body_id = body_id
        # Construction performs one reset, so a fresh instance already has
        # reset_counter == 1.
        self.reset()

    def reset(self):
        """Record one more reset by incrementing ``reset_counter``."""
        self.reset_counter = self.reset_counter + 1

    def is_terminated(self) -> bool:
        """Return the current value of ``termination_var``."""
        return self.termination_var
from gym_solo.testing import DummyTermination


class TestTerminationFactory(unittest.TestCase):
Expand Down
11 changes: 7 additions & 4 deletions gym_solo/envs/solo8v2vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, use_gui: bool = False, realtime: bool = False,
self._config.motor_torque_limit,
shape=(joint_cnt,))

self.reset()
self.reset(init_call=True)

def step(self, action: List[float]) -> Tuple[solo_types.obs, float, bool,
Dict[Any, Any]]:
Expand Down Expand Up @@ -88,7 +88,7 @@ def step(self, action: List[float]) -> Tuple[solo_types.obs, float, bool,

return obs_values, reward, done, {'labels': obs_labels}

def reset(self) -> solo_types.obs:
def reset(self, init_call: bool = False) -> solo_types.obs:
"""Reset the state of the environment and return an initial observation.
Returns:
Expand All @@ -106,8 +106,11 @@ def reset(self) -> solo_types.obs:
positionGains=self._zero_gains, velocityGains=self._zero_gains)
self.client.stepSimulation()

obs_values, _ = self.obs_factory.get_obs()
return obs_values
if init_call:
return np.empty(shape=(0,)), []
else:
obs_values, _ = self.obs_factory.get_obs()
return obs_values

@property
def observation_space(self):
Expand Down
20 changes: 17 additions & 3 deletions gym_solo/envs/test_solo8v2vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from gym_solo.core import obs as solo_obs
from gym_solo.testing import CompliantObs
from gym_solo.testing import SimpleReward
from gym_solo.testing import DummyTermination

from gym import error, spaces
from parameterized import parameterized
Expand Down Expand Up @@ -36,7 +37,10 @@ def test_realtime(self, mock_time):
env = solo_env.Solo8VanillaEnv(config=solo_env.Solo8VanillaConfig(),
realtime=True)
env.reward_factory.register_reward(1, SimpleReward())


env.obs_factory.register_observation(CompliantObs(None))
env.termination_factory.register_termination(DummyTermination(0, True))

env.step(env.action_space.sample())
self.assertTrue(mock_time.called)

Expand Down Expand Up @@ -91,7 +95,10 @@ def test_action_space(self):

def test_actions(self):
no_op = np.zeros(self.env.action_space.shape[0])


self.env.obs_factory.register_observation(CompliantObs(None))
self.env.termination_factory.register_termination(DummyTermination(0, True))

# Let the robot stabilize first
for i in range(1000):
self.env.step(no_op)
Expand All @@ -116,6 +123,9 @@ def test_actions(self):
self.assert_array_not_almost_equal(orientation, new_or)

def test_reset(self):
self.env.obs_factory.register_observation(CompliantObs(None))
self.env.termination_factory.register_termination(DummyTermination(0, True))

base_pos, base_or = p.getBasePositionAndOrientation(self.env.robot)

action = np.array([5.] * self.env.action_space.shape[0])
Expand All @@ -138,6 +148,9 @@ def test_step_no_rewards(self):
env.step(np.zeros(self.env.action_space.shape[0]))

def test_step_simple_reward(self):
self.env.obs_factory.register_observation(CompliantObs(None))
self.env.termination_factory.register_termination(DummyTermination(0, True))

obs, reward, done, info = self.env.step(self.env.action_space.sample())
self.assertEqual(reward, 1)

Expand All @@ -151,7 +164,7 @@ def test_disjoint_environments(self):
env1.obs_factory.register_observation(solo_obs.TorsoIMU(env1.robot))
env1.obs_factory.register_observation(solo_obs.MotorEncoder(env1.robot))
env1.reward_factory.register_reward(1, SimpleReward())

env1.termination_factory.register_termination(DummyTermination(0, True))
home_position = env1.reset()

for i in range(1000):
Expand All @@ -161,6 +174,7 @@ def test_disjoint_environments(self):
env2.obs_factory.register_observation(solo_obs.TorsoIMU(env2.robot))
env2.obs_factory.register_observation(solo_obs.MotorEncoder(env2.robot))
env2.reward_factory.register_reward(1, SimpleReward())
env2.termination_factory.register_termination(DummyTermination(0, True))

np.testing.assert_array_almost_equal(home_position, env2.reset())

Expand Down
17 changes: 16 additions & 1 deletion gym_solo/testing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from gym_solo.core import obs
from gym_solo.core import rewards
from gym_solo.core import termination

from gym import spaces
import numpy as np
Expand Down Expand Up @@ -62,4 +63,18 @@ def compute(self) -> float:
Returns:
float: the configured reward.
"""
return self._return_value
return self._return_value


class DummyTermination(termination.Termination):
    """Termination stub for tests whose verdict is supplied by the caller.

    Besides reporting a fixed answer, it counts how many times ``reset`` has
    been invoked so tests can check that resets are propagated.
    """

    def __init__(self, body_id: int, termination_var: bool):
        """Store the arguments and run the initial reset.

        Args:
          body_id: Identifier kept on the instance as ``body_id``; this class
            never reads it back.
          termination_var: Value that ``is_terminated`` reports.
        """
        self.reset_counter = 0
        self.termination_var = termination_var
        self.body_id = body_id
        # Construction performs one reset, so a fresh instance already has
        # reset_counter == 1.
        self.reset()

    def reset(self):
        """Record one more reset by incrementing ``reset_counter``."""
        self.reset_counter = self.reset_counter + 1

    def is_terminated(self) -> bool:
        """Return the current value of ``termination_var``."""
        return self.termination_var

0 comments on commit 0ee3f1b

Please sign in to comment.