diff --git a/mushroom_rl/core/_impl/list_dataset.py b/mushroom_rl/core/_impl/list_dataset.py index f88e3c72..135054c3 100644 --- a/mushroom_rl/core/_impl/list_dataset.py +++ b/mushroom_rl/core/_impl/list_dataset.py @@ -102,3 +102,15 @@ def policy_state(self): @property def policy_next_state(self): return [step[7] for step in self._dataset] + + @property + def n_episodes(self): + n_episodes = 0 + for sample in self._dataset: + if sample[5] is True: + n_episodes += 1 + if self._dataset[-1][5] is not True: + n_episodes += 1 + + return n_episodes + diff --git a/mushroom_rl/core/_impl/numpy_dataset.py b/mushroom_rl/core/_impl/numpy_dataset.py index d8986d2c..d0de9f6f 100644 --- a/mushroom_rl/core/_impl/numpy_dataset.py +++ b/mushroom_rl/core/_impl/numpy_dataset.py @@ -193,4 +193,13 @@ def policy_next_state(self): @property def _is_stateful(self): - return self._policy_states is not None \ No newline at end of file + return self._policy_states is not None + + @property + def n_episodes(self): + n_episodes = self.last.sum() + + if not self.last[-1]: + n_episodes += 1 + + return n_episodes \ No newline at end of file diff --git a/mushroom_rl/core/_impl/torch_dataset.py b/mushroom_rl/core/_impl/torch_dataset.py index 7e7309e8..92f4b728 100644 --- a/mushroom_rl/core/_impl/torch_dataset.py +++ b/mushroom_rl/core/_impl/torch_dataset.py @@ -194,3 +194,12 @@ def policy_next_state(self): @property def _is_stateful(self): return self._policy_states is not None + + @property + def n_episodes(self): + n_episodes = self.last.sum() + + if not self.last[-1]: + n_episodes += 1 + + return n_episodes \ No newline at end of file diff --git a/mushroom_rl/core/dataset.py b/mushroom_rl/core/dataset.py index 1a31841e..95295b0a 100644 --- a/mushroom_rl/core/dataset.py +++ b/mushroom_rl/core/dataset.py @@ -254,6 +254,10 @@ def episodes_length(self): return lengths + @property + def n_episodes(self): + return self._data.n_episodes + @property def undiscounted_return(self): return self.compute_J() diff --git a/tests/core/test_dataset.py b/tests/core/test_dataset.py index 47420553..0469526e 100644 --- a/tests/core/test_dataset.py +++ b/tests/core/test_dataset.py @@ -24,6 +24,8 @@ def test_dataset(): mdp = GridWorld(3, 3, (2, 2)) dataset = generate_dataset(mdp, 10) + assert dataset.n_episodes == 10 + J = dataset.compute_J(mdp.info.gamma) J_test = np.array([4.304672100000001, 2.287679245496101, 3.138105960900001, 0.13302794647291147, 7.290000000000001, 1.8530201888518416, 1.3508517176729928, 0.011790184577738602, @@ -88,6 +90,10 @@ def test_dataset_creation(): assert vars(dataset).keys() == vars(new_list_dataset).keys() assert vars(dataset).keys() == vars(new_torch_dataset).keys() + assert new_numpy_dataset.n_episodes == dataset.n_episodes + assert new_list_dataset.n_episodes == dataset.n_episodes + assert new_torch_dataset.n_episodes == dataset.n_episodes + for array_1, array_2 in zip(parsed, new_numpy_dataset.parse()): assert np.array_equal(array_1, array_2)