In [1]:
import gymnasium as gym


#自定义一个Wrapper
class Pendulum(gym.Wrapper):

    def __init__(self):
        env = gym.make('Pendulum-v1',render_mode="human")
        super().__init__(env)
        self.env = env

    def reset(self, seed=None):
        state, info = self.env.reset()
        return state, info

    def step(self, action):
        state, reward, done, truncated, info = self.env.step(action)
        return state, reward, done, truncated, info

print(Pendulum().reset())

[-0.9555506   0.29482716  0.889765  ]


In [2]:
#测试一个环境
def test(env, wrap_action_in_list=False):
    print(env)
    try:
        state = env.reset()
    except:
        state, _ = env.reset()
    over = False
    step = 0

    while not over:
        action = env.action_space.sample()

        if wrap_action_in_list:
            action = [action]
        try:
            next_state, reward, over, _ = env.step(action)
        except:
            next_state, reward, over, _, _ = env.step(action)
        if step % 20 == 0:
            print(step, state, action, reward)

        if step > 200:
            break

        state = next_state
        step += 1


test(Pendulum())

<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>
0 [ 0.8690126  -0.49478996  0.25852925] [0.38116938] -0.2712292164795527
20 [-0.84894484 -0.5284814   7.14058   ] [-0.36051726] -9.539666244964971
40 [ 0.47867265 -0.87799346 -4.017311  ] [-0.1569533] -3.9156727454337834
60 [ 0.50958234 -0.8604219   0.9194922 ] [1.3635069] -1.0492464685213325
80 [-0.5145037  -0.85748816  6.737064  ] [0.5846952] -7.069341633539481
100 [ 0.8777114   0.47918963 -1.3289554 ] [0.04540587] -0.2966280328836356
120 [0.62123334 0.78362566 3.0971074 ] [-1.5702893] -2.3434320261914268
140 [-0.09331109  0.995637   -4.415158  ] [1.4396266] -3.41896588222713
160 [0.22346951 0.97471094 4.3902297 ] [0.6518642] -5.304709052746031
180 [ 0.3853862  -0.92275536  3.92823   ] [-0.12849551] -2.063710473386545
200 [ 0.7604966 -0.6493419 -3.059499 ] [1.1502334] -1.9060460989891752


In [3]:
#修改最大步数
class StepLimitWrapper(gym.Wrapper):

    def __init__(self, env):
        super().__init__(env)
        self.current_step = 0

    def reset(self):
        self.current_step = 0
        return self.env.reset()

    def step(self, action):
        self.current_step += 1
        state, reward, done, _, info = self.env.step(action)

        #修改done字段
        if self.current_step >= 100:
            done = True

        return state, reward, done, info


test(StepLimitWrapper(Pendulum()))

<StepLimitWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>
0 [-0.17543022 -0.9844919  -0.9861332 ] [-0.37350124] -3.14988032699306
20 [-0.19588898  0.9806261   1.3804784 ] [1.7438195] -3.319296344256698
40 [-0.5027321 -0.8644423 -2.720407 ] [0.20915733] -5.139832699340967
60 [-0.8646059   0.50245064  4.2143946 ] [0.23335248] -8.61523780681135
80 [-0.99498296  0.10004468 -4.2221828 ] [1.9384729] -11.036434684496214


In [4]:
import numpy as np


#修改动作空间
class NormalizeActionWrapper(gym.Wrapper):

    def __init__(self, env):
        #获取动作空间
        action_space = env.action_space

        #动作空间必须是连续值
        assert isinstance(action_space, gym.spaces.Box)

        #重新定义动作空间,在正负一之间的连续值
        #这里其实只影响env.action_space.sample的返回结果
        #实际在计算时,还是正负2之间计算的
        env.action_space = gym.spaces.Box(low=-1,
                                          high=1,
                                          shape=action_space.shape,
                                          dtype=np.float32)

        super().__init__(env)

    def reset(self):
        return self.env.reset()

    def step(self, action):
        #重新缩放动作的值域
        action = action * 2.0

        if action > 2.0:
            action = 2.0

        if action < -2.0:
            action = -2.0

        return self.env.step(action)


test(NormalizeActionWrapper(Pendulum()))

<NormalizeActionWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>
0 [-0.54957944  0.83544147 -0.09078036] [-0.16019736] -4.763413029567849
20 [-0.9533661   0.30181634  1.9025816 ] [-0.49939758] -9.000762938304984
40 [-0.47326204 -0.8809217   2.1182206 ] [-0.24065462] -4.170255549783669
60 [-0.42249337 -0.906366   -1.3803447 ] [-0.5824233] -4.989840889349022
80 [-0.9957207  -0.09241352 -3.4013486 ] [0.37694225] -10.529846907911079
100 [-0.81958973  0.5729508  -4.77027   ] [0.1187015] -7.218297211810192
120 [-0.1382315   0.99039996 -6.0278034 ] [-0.92172736] -5.145344530552183
140 [ 0.4695109   0.88292664 -4.2409635 ] [0.88921374] -1.93999787628309
160 [-0.02428142 -0.99970514 -3.6422586 ] [-0.14378788] -5.2680628703789365
180 [-0.07344457 -0.9972993  -3.2208362 ] [0.8033563] -4.74368605110887
200 [-0.4509822  -0.89253294 -2.6284041 ] [-0.42209548] -6.056971459022459


In [6]:
from gym.wrappers import TimeLimit


#修改状态
class StateStepWrapper(gym.Wrapper):

    def __init__(self, env):

        #状态空间必须是连续值
        assert isinstance(env.observation_space, gym.spaces.Box)

        #增加一个新状态字段
        low = np.concatenate([env.observation_space.low, [0.0]])
        high = np.concatenate([env.observation_space.high, [1.0]])

        env.observation_space = gym.spaces.Box(low=low,
                                               high=high,
                                               dtype=np.float32)

        super().__init__(env)

        self.step_current = 0

    def reset(self):
        self.step_current = 0
        return np.concatenate([self.env.reset()[0], [0.0]])

    def step(self, action):
        self.step_current += 1
        state, reward, done, _, info = self.env.step(action)

        #根据step_max修改done
        if self.step_current >= 100:
            done = True

        return self.get_state(state), reward, done, info

    def get_state(self, state):
        #添加一个新的state字段
        state_step = self.step_current / 100

        return np.concatenate([state, [state_step]])


test(StateStepWrapper(Pendulum()))

<StateStepWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>


ValueError: zero-dimensional arrays cannot be concatenated

In [None]:
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

#使用Monitor Wrapper,会在训练的过程中输出rollout/ep_len_mean和rollout/ep_rew_mean,就是增加些日志
#gym升级到0.26以后失效了,可能是因为使用了自定义的wapper
env = DummyVecEnv([lambda: Monitor(Pendulum().env)])

A2C('MlpPolicy', env, verbose=1).learn(1000)

In [None]:
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack

#VecNormalize,他会对state和reward进行Normalize
#Pendulum = lambda: gym.make('Pendulum-v1',render_mode="human")
env = DummyVecEnv([Pendulum])
env = VecNormalize(env)

#state = env.reset()
#action = env.action_space.sample()
#print(env.step([action]))

test(env, wrap_action_in_list=True)