Skip to content

Commit

Permalink
bug fix for scalar observation spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
ahalev committed Dec 28, 2022
1 parent ded4830 commit d07d1c6
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/pymgrid/utils/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
class ModuleSpace(Space):
def __init__(self, unnormalized_low, unnormalized_high, shape=None, dtype=np.float64, seed=None):

self._unnormalized = Box(low=unnormalized_low.astype(np.float64),
high=unnormalized_high.astype(np.float64),
low = np.float64(unnormalized_low) if np.isscalar(unnormalized_low) else unnormalized_low.astype(np.float64)
high = np.float64(unnormalized_high) if np.isscalar(unnormalized_high) else unnormalized_high.astype(np.float64)

self._unnormalized = Box(low=low,
high=high,
shape=shape,
dtype=dtype)

Expand Down

0 comments on commit d07d1c6

Please sign in to comment.