In [14]:
import unittest
import numpy as np
import gym

In [15]:
class TestSpaces(unittest.TestCase):
    def test_box_space(self):
        # 定义 Box 空间
        low = np.array([-1.0, -2.0, -3.0])
        high = np.array([1.0, 2.0, 3.0])
        box_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

        # 测试包含样本
        sample = np.array([0.0, 1.0, -1.0], dtype=np.float32)
        self.assertTrue(box_space.contains(sample))

        # 测试超出范围的样本
        out_of_range_sample = np.array([2.0, 1.0, -1.0], dtype=np.float32)
        self.assertFalse(box_space.contains(out_of_range_sample))

        # 测试样本的数据类型
        wrong_dtype_sample = np.array([0.0, 1.0, -1.0], dtype=np.float64)
        self.assertFalse(box_space.contains(wrong_dtype_sample))

        # 测试随机样本
        random_sample = box_space.sample()
        self.assertTrue(box_space.contains(random_sample))

    def test_dict_space(self):
        # 定义 Dict 空间
        observation_space = gym.spaces.Dict({
            'position': gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
            'velocity': gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
            'is_active': gym.spaces.Discrete(2)
        })

        # 测试包含样本
        sample = {
            'position': np.array([0.5, -0.5], dtype=np.float32),
            'velocity': np.array([0.1, -0.1], dtype=np.float32),
            'is_active': 1
        }
        self.assertTrue(observation_space.contains(sample))

        # 测试超出范围的样本
        out_of_range_sample = {
            'position': np.array([2.0, -0.5], dtype=np.float32),
            'velocity': np.array([0.1, -0.1], dtype=np.float32),
            'is_active': 1
        }
        self.assertFalse(observation_space.contains(out_of_range_sample))

        # 测试样本的数据类型
        wrong_dtype_sample = {
            'position': np.array([0.5, -0.5], dtype=np.float64),
            'velocity': np.array([0.1, -0.1], dtype=np.float32),
            'is_active': 1
        }
        self.assertFalse(observation_space.contains(wrong_dtype_sample))

        # 测试随机样本
        random_sample = observation_space.sample()
        self.assertTrue(observation_space.contains(random_sample))

# if __name__ == '__main__':
#     unittest.main()

In [17]:
test_spaces = TestSpaces()
test_spaces.test_box_space()
test_spaces.test_dict_space()

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [18]:
import gym
import numpy as np

# 定义一个 Dict 空间，包含一个 2 维的连续位置、一个 2 维的连续速度和一个离散的状态
observation_space = gym.spaces.Dict({
    'position': gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
    'velocity': gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
    'is_active': gym.spaces.Discrete(2)
})

# 检查一个样本是否在空间范围内
sample = {
    'position': np.array([0.5, -0.5], dtype=np.float32),
    'velocity': np.array([0.1, -0.1], dtype=np.float32),
    'is_active': 1
}
print(observation_space.contains(sample))  # 输出: True

# 生成一个随机样本
random_sample = observation_space.sample()
print(random_sample)

True
OrderedDict([('is_active', 0), ('position', array([ 0.11500746, -0.6305465 ], dtype=float32)), ('velocity', array([-0.32620978, -0.6026009 ], dtype=float32))])
