Skip to content

Commit

Permalink
Fix reproducibility for all robotics environments
Browse files Browse the repository at this point in the history
- Fix remaining mujoco envs
- Fix mujoco_py envs
- Simplify reset
  • Loading branch information
amacati committed Feb 11, 2024
1 parent e89b53e commit 5e14ea1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
7 changes: 3 additions & 4 deletions gymnasium_robotics/envs/fetch/fetch_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def _viewer_setup(self):
setattr(self.viewer.cam, key, value)

def _reset_sim(self):
self.sim.reset() # Reset warm-start buffers, control buffers etc.
self.sim.set_state(self.initial_state)

# Randomize start position of object.
Expand Down Expand Up @@ -376,10 +377,8 @@ def _reset_sim(self):
self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
self.data.qacc_warmstart[:] = np.copy(self.initial_qacc_warmstart)
self.data.ctrl[:] = np.copy(self.initial_ctrl)
self.data.mocap_pos[:] = np.copy(self.initial_mocap_pos)
self.data.mocap_quat[:] = np.copy(self.initial_mocap_quat)
# Reset buffers for warm-start, control buffers etc.
self._mujoco.mj_resetData(self.model, self.data)
if self.model.na != 0:
self.data.act[:] = None

Expand Down
7 changes: 3 additions & 4 deletions gymnasium_robotics/envs/robot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,10 @@ def _initialize_simulation(self):
self.initial_time = self.data.time
self.initial_qpos = np.copy(self.data.qpos)
self.initial_qvel = np.copy(self.data.qvel)
self.initial_ctrl = np.copy(self.data.ctrl)
self.initial_qacc_warmstart = np.copy(self.data.qacc_warmstart)
self.initial_mocap_pos = np.copy(self.data.mocap_pos)
self.initial_mocap_quat = np.copy(self.data.mocap_quat)

def _reset_sim(self):
# Reset warm-start buffers, control buffers etc.
mujoco.mj_resetData(self.model, self.data)
self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
Expand Down Expand Up @@ -381,6 +379,7 @@ def _initialize_simulation(self):
self.initial_state = copy.deepcopy(self.sim.get_state())

def _reset_sim(self):
self.sim.reset() # Reset warm-start buffers, control buffers etc.
self.sim.set_state(self.initial_state)
self.sim.forward()
return super()._reset_sim()
Expand Down

0 comments on commit 5e14ea1

Please sign in to comment.