Skip to content

Commit

Permalink
Added n_episodes property to dataset
Browse files Browse the repository at this point in the history
- now the dataset can compute the number of episodes collected
- updated tests
  • Loading branch information
boris-il-forte committed Oct 31, 2023
1 parent 2ba0eb0 commit 4dac89b
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 1 deletion.
12 changes: 12 additions & 0 deletions mushroom_rl/core/_impl/list_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,15 @@ def policy_state(self):
@property
def policy_next_state(self):
return [step[7] for step in self._dataset]

@property
def n_episodes(self):
n_episodes = 0
for sample in self._dataset:
if sample[5] is True:
n_episodes += 1
if self._dataset[-1][5] is not True:
n_episodes += 1

return n_episodes

11 changes: 10 additions & 1 deletion mushroom_rl/core/_impl/numpy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,13 @@ def policy_next_state(self):

@property
def _is_stateful(self):
return self._policy_states is not None
return self._policy_states is not None

@property
def n_episodes(self):
n_episodes = self.last.sum()

if not self.last[-1]:
n_episodes += 1

return n_episodes
9 changes: 9 additions & 0 deletions mushroom_rl/core/_impl/torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,12 @@ def policy_next_state(self):
@property
def _is_stateful(self):
return self._policy_states is not None

@property
def n_episodes(self):
n_episodes = self.last.sum()

if not self.last[-1]:
n_episodes += 1

return n_episodes
4 changes: 4 additions & 0 deletions mushroom_rl/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ def episodes_length(self):

return lengths

@property
def n_episodes(self):
return self._data.n_episodes

@property
def undiscounted_return(self):
return self.compute_J()
Expand Down
6 changes: 6 additions & 0 deletions tests/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def test_dataset():
mdp = GridWorld(3, 3, (2, 2))
dataset = generate_dataset(mdp, 10)

assert dataset.n_episodes == 10

J = dataset.compute_J(mdp.info.gamma)
J_test = np.array([4.304672100000001, 2.287679245496101, 3.138105960900001, 0.13302794647291147,
7.290000000000001, 1.8530201888518416, 1.3508517176729928, 0.011790184577738602,
Expand Down Expand Up @@ -88,6 +90,10 @@ def test_dataset_creation():
assert vars(dataset).keys() == vars(new_list_dataset).keys()
assert vars(dataset).keys() == vars(new_torch_dataset).keys()

assert new_numpy_dataset.n_episodes == dataset.n_episodes
assert new_list_dataset.n_episodes == dataset.n_episodes
assert new_torch_dataset.n_episodes == dataset.n_episodes

for array_1, array_2 in zip(parsed, new_numpy_dataset.parse()):
assert np.array_equal(array_1, array_2)

Expand Down

0 comments on commit 4dac89b

Please sign in to comment.