In [1]:
import gymnasium as gym


#自定义一个Wrapper
class Pendulum(gym.Wrapper):

    def __init__(self):
        env = gym.make('Pendulum-v1')
        super().__init__(env)
        self.env = env

    def reset(self, seed=None, options=None):
        state, info = self.env.reset()
        return state, info

    def step(self, action):
        state, reward, done, truncated, info = self.env.step(action)
        return state, reward, done, truncated, info


Pendulum().reset()

(array([0.997827  , 0.06588829, 0.38405415], dtype=float32), {})

In [3]:
#测试一个环境
def test(env, wrap_action_in_list=False):
    print(env)

    state = env.reset()
    over = False
    step = 0

    while not over:
        action = env.action_space.sample()

        if wrap_action_in_list:
            action = [action]

        next_state, reward, over, _, _ = env.step(action)

        if step % 20 == 0:
            print(step, state, action, reward)

        if step > 200:
            break

        state = next_state
        step += 1


test(Pendulum())

<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>
0 (array([-0.4691058 , -0.883142  , -0.20973344], dtype=float32), {}) [-1.1045635] -4.245405942066593
20 [-0.4018474   0.91570663  1.414736  ] [-0.2861242] -4.13779416845475
40 [-0.63657784 -0.77121246 -3.7317877 ] [0.6696638] -6.504510714284118
60 [-0.83991814  0.5427131   4.2385383 ] [-1.2615361] -8.392370163169778
80 [-0.9989467   0.04588455 -5.5532975 ] [-1.5991619] -12.669777657641209
100 [-0.9768472  -0.21393828  5.9013104 ] [0.34193948] -12.04406833734451
120 [-0.9576539   0.28792185 -5.8473873 ] [-1.132296] -11.540334657338475
140 [-0.97199184 -0.23501462  6.2311244 ] [-0.5069268] -12.318249651150294
160 [-0.7329268 -0.6803075 -6.091003 ] [-1.8259742] -9.441779788987565
180 [ 0.99605465 -0.08874175 -1.3733063 ] [-0.77021796] -0.19708609086447199
200 [ 0.36763787  0.929969   -4.718802  ] [-1.8153288] -3.6564252435145272


In [4]:
#修改最大步数
class StepLimitWrapper(gym.Wrapper):

    def __init__(self, env):
        super().__init__(env)
        self.current_step = 0

    def reset(self, seed=None, options=None):
        self.current_step = 0
        return self.env.reset()

    def step(self, action):
        self.current_step += 1
        state, reward, done, truncated, info  = self.env.step(action)

        #修改done字段
        if self.current_step >= 100:
            done = True

        return state, reward, done, truncated, info


test(StepLimitWrapper(Pendulum()))

<StepLimitWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>
0 (array([-0.09608521,  0.99537313,  0.2529563 ], dtype=float32), {}) [1.6271371] -2.788035294230727
20 [ 0.31086937 -0.9504526  -1.1583184 ] [-1.3020552] -1.7101093517205692
40 [ 0.54534     0.83821493 -0.27485672] [-1.5568552] -0.9980177151085119
60 [ 0.62841   -0.7778823  3.09197  ] [-0.92350626] -1.751276482180738
80 [-0.9394501   0.34268573 -7.9643903 ] [1.7432866] -14.140440238265683


In [5]:
import numpy as np


#修改动作空间
class NormalizeActionWrapper(gym.Wrapper):

    def __init__(self, env):
        #获取动作空间
        action_space = env.action_space

        #动作空间必须是连续值
        assert isinstance(action_space, gym.spaces.Box)

        #重新定义动作空间,在正负一之间的连续值
        #这里其实只影响env.action_space.sample的返回结果
        #实际在计算时,还是正负2之间计算的
        env.action_space = gym.spaces.Box(low=-1,
                                          high=1,
                                          shape=action_space.shape,
                                          dtype=np.float32)

        super().__init__(env)

    def reset(self, seed=None, options=None):
        return self.env.reset()

    def step(self, action):
        #重新缩放动作的值域
        action = action * 2.0

        if action > 2.0:
            action = 2.0

        if action < -2.0:
            action = -2.0

        return self.env.step(action)


test(NormalizeActionWrapper(Pendulum()))

<NormalizeActionWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>
0 (array([0.8683225 , 0.49600008, 0.26862347], dtype=float32), {}) [0.15149884] -0.27665433265454015
20 [ 0.2755823 -0.9612775  4.7309628] [0.7567814] -3.908724785659091
40 [ 0.823339  -0.5675499 -1.5734091] [0.04376562] -0.6118140816421935
60 [ 0.6945427  0.7194515 -2.340134 ] [-0.47278035] -1.1933454085510684
80 [-0.94381416  0.3304767   7.1183777 ] [-0.54940784] -12.935150975053574
100 [ 0.8611308 -0.5083835 -1.381999 ] [0.16036637] -0.47551090241353305
120 [ 0.8850032  0.4655849 -3.7487524] [-0.34694064] -1.6403379580771387
140 [ 0.09859236  0.9951279  -4.6197014 ] [0.44980657] -4.301885600197689
160 [-0.85694766 -0.5154034   6.9798055 ] [0.49159285] -11.633329626914147
180 [-0.62242013 -0.7826833  -6.357145  ] [-0.81554466] -9.073367291695241
200 [0.2647891  0.96430635 2.7944098 ] [0.6231875] -2.479742752274073


In [6]:
from gymnasium.wrappers import TimeLimit


#修改状态
class StateStepWrapper(gym.Wrapper):

    def __init__(self, env):

        #状态空间必须是连续值
        assert isinstance(env.observation_space, gym.spaces.Box)

        #增加一个新状态字段
        low = np.concatenate([env.observation_space.low, [0.0]])
        high = np.concatenate([env.observation_space.high, [1.0]])

        env.observation_space = gym.spaces.Box(low=low,
                                               high=high,
                                               dtype=np.float32)

        super().__init__(env)

        self.step_current = 0

    def reset(self, seed=None, options=None):
        self.step_current = 0
        return np.concatenate([self.env.reset(), [0.0]])

    def step(self, action):
        self.step_current += 1
        state, reward, done, truncated, info = self.env.step(action)

        #根据step_max修改done
        if self.step_current >= 100:
            done = True

        return self.get_state(state), reward, done, truncated, info

    def get_state(self, state):
        #添加一个新的state字段
        state_step = self.step_current / 100

        return np.concatenate([state, [state_step]])


test(StateStepWrapper(Pendulum()))

<StateStepWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>


  gym.logger.warn(
  gym.logger.warn(


ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.

In [6]:
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

#使用Monitor Wrapper,会在训练的过程中输出rollout/ep_len_mean和rollout/ep_rew_mean,就是增加些日志
#gym升级到0.26以后失效了,可能是因为使用了自定义的wapper
env = DummyVecEnv([lambda: Monitor(Pendulum())])

A2C('MlpPolicy', env, verbose=1).learn(1000)

Using cpu device
------------------------------------
| time/                 |          |
|    fps                | 919      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -1.44    |
|    explained_variance | -0.00454 |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | -43.2    |
|    std                | 1.02     |
|    value_loss         | 976      |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 897      |
|    iterations         | 200      |
|    time_elapsed       | 1        |
|    total_timesteps    | 1000     |
| train/                |          |
|    entropy_loss       | -1.43    |
|    explained_variance | 4.43e-05 |
|    learning_rate      | 0.0007   |
|    n_updates          | 199      |
|    policy_loss     

<stable_baselines3.a2c.a2c.A2C at 0x7f94cc7fee80>

In [7]:
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack

#VecNormalize,他会对state和reward进行Normalize
env = DummyVecEnv([Pendulum])
env = VecNormalize(env)

test(env, wrap_action_in_list=True)

<stable_baselines3.common.vec_env.vec_normalize.VecNormalize object at 0x7f949b273e20>
0 [[-0.00487219  0.00638567  0.00554428]] [array([1.8279244], dtype=float32)] [-10.]
20 [[-0.03349784 -0.66757905 -2.173565  ]] [array([-1.3404763], dtype=float32)] [-0.16482905]
40 [[-1.4011567   0.07794451  1.3022798 ]] [array([1.6657785], dtype=float32)] [-0.16956142]
60 [[-1.3572015   0.41527337 -1.5391531 ]] [array([1.3103601], dtype=float32)] [-0.12919044]
80 [[-0.34073314 -0.9262497   1.2184559 ]] [array([0.99134326], dtype=float32)] [-0.07326685]
100 [[ 1.6391766  1.4822493 -1.3587774]] [array([-1.5195391], dtype=float32)] [-0.04477553]
120 [[-0.01510759 -1.150826    1.7960454 ]] [array([-0.46556672], dtype=float32)] [-0.07235025]
140 [[-0.5499776  -0.83480734 -2.1066453 ]] [array([-0.03180405], dtype=float32)] [-0.08347733]
160 [[ 2.0203962  -0.5923113  -0.62100804]] [array([1.302856], dtype=float32)] [-0.00734869]
180 [[ 1.8016039   0.8581704  -0.43550822]] [array([1.2153829], dtype=float32