From ace181e2a0d83be8940103bbb4d7e917afb30119 Mon Sep 17 00:00:00 2001
From: Rodrigo de Lazcano Perez-Vicente
Date: Fri, 10 Feb 2023 16:34:54 -0500
Subject: [PATCH] add info success

---
 gymnasium_robotics/envs/maze/point_maze.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/gymnasium_robotics/envs/maze/point_maze.py b/gymnasium_robotics/envs/maze/point_maze.py
index 6408bbdc..da426b05 100644
--- a/gymnasium_robotics/envs/maze/point_maze.py
+++ b/gymnasium_robotics/envs/maze/point_maze.py
@@ -369,6 +369,9 @@ def reset(
 
         obs, info = self.point_env.reset(seed=seed)
         obs_dict = self._get_obs(obs)
+        info["success"] = bool(
+            np.linalg.norm(obs_dict["achieved_goal"] - self.goal) <= 0.45
+        )
 
         return obs_dict, info
 
@@ -376,11 +379,14 @@ def step(self, action):
         obs, _, _, _, info = self.point_env.step(action)
         obs_dict = self._get_obs(obs)
 
+        info["success"] = bool(
+            np.linalg.norm(obs_dict["achieved_goal"] - self.goal) <= 0.45
+        )
+        reward = self.compute_reward(obs_dict["achieved_goal"], self.goal, info)
+
         terminated = self.compute_terminated(obs_dict["achieved_goal"], self.goal, info)
         truncated = self.compute_truncated(obs_dict["achieved_goal"], self.goal, info)
 
-        reward = self.compute_reward(obs_dict["achieved_goal"], self.goal, info)
-
         return obs_dict, reward, terminated, truncated, info
 
     def update_target_site_pos(self):