In [2]:
import numpy as np
import gymnasium as gym
from stable_baselines3.common.env_checker import check_env


class GoLeftEnv(gym.Env):
    #支持的render模式,在jupyter中不支持human模式
    metadata = {'render.modes': ['console']}

    def __init__(self, init_pos=9, max_time=100):
        super().__init__()

        #初始位置
        self.pos = init_pos
        self.max_time = max_time
        self.time = 0

        #动作空间,这个环境中只有左,右两个动作
        self.action_space = gym.spaces.Discrete(2)

        #状态空间,一维数轴
        self.observation_space = gym.spaces.Box(low=0,
                                                high=10,
                                                shape=(1, ),
                                                dtype=np.float32)

    def reset(self, seed=None, options=None):
        #重置位置
        self.pos = 9

        #当前状态
        return np.array([self.pos], dtype=np.float32), {}

    def step(self, action):
        #执行动作
        if action == 0:
            self.pos -= 1

        if action == 1:
            self.pos += 1

        self.pos = np.clip(self.pos, 0, 10)
        self.time += 1

        #判断游戏结束
        done = self.pos == 0

        #判断时间截止
        truncate = self.time >= self.max_time

        #给予reward
        reward = 1 if self.pos == 0 else 0

        return np.array([self.pos], dtype=np.float32), reward, bool(done), bool(truncate), {}

    def render(self, mode='console'):
        if mode != 'console':
            raise NotImplementedError()
        print(self.pos)

    def close(self):
        pass


env = GoLeftEnv()

#检查环境是否合法
check_env(env, warn=True)

In [4]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

#包装环境
train_env = make_vec_env(lambda: env, n_envs=1)

#定义模型
model = PPO('MlpPolicy', train_env, verbose=0, device='cpu')

In [6]:
import gymnasium as gym


#测试一个环境
def test(model, env):
    state, _ = env.reset()
    done = False
    step = 0

    for i in range(100):
        action = model.predict(state)[0]

        next_state, reward, done, _, _ = env.step(action)

        if step % 1 == 0:
            print(step, state, action, reward)

        state = next_state
        step += 1

        if done:
            break


test(model, env)

0 [9.] 0 0
1 [8.] 0 0
2 [7.] 1 0
3 [8.] 0 0
4 [7.] 1 0
5 [8.] 1 0
6 [9.] 1 0
7 [10.] 1 0
8 [10.] 0 0
9 [9.] 1 0
10 [10.] 1 0
11 [10.] 0 0
12 [9.] 1 0
13 [10.] 1 0
14 [10.] 0 0
15 [9.] 1 0
16 [10.] 0 0
17 [9.] 0 0
18 [8.] 0 0
19 [7.] 1 0
20 [8.] 0 0
21 [7.] 1 0
22 [8.] 1 0
23 [9.] 0 0
24 [8.] 0 0
25 [7.] 0 0
26 [6.] 1 0
27 [7.] 0 0
28 [6.] 1 0
29 [7.] 1 0
30 [8.] 0 0
31 [7.] 1 0
32 [8.] 0 0
33 [7.] 1 0
34 [8.] 0 0
35 [7.] 1 0
36 [8.] 0 0
37 [7.] 1 0
38 [8.] 1 0
39 [9.] 0 0
40 [8.] 0 0
41 [7.] 0 0
42 [6.] 1 0
43 [7.] 1 0
44 [8.] 0 0
45 [7.] 1 0
46 [8.] 0 0
47 [7.] 1 0
48 [8.] 1 0
49 [9.] 0 0
50 [8.] 0 0
51 [7.] 0 0
52 [6.] 1 0
53 [7.] 0 0
54 [6.] 0 0
55 [5.] 1 0
56 [6.] 0 0
57 [5.] 0 0
58 [4.] 0 0
59 [3.] 0 0
60 [2.] 0 0
61 [1.] 0 1


In [7]:
model.learn(5000)

#测试
test(model, env)

0 [9.] 1 0
1 [10.] 1 0
2 [10.] 1 0
3 [10.] 1 0
4 [10.] 1 0
5 [10.] 1 0
6 [10.] 1 0
7 [10.] 1 0
8 [10.] 1 0
9 [10.] 1 0
10 [10.] 1 0
11 [10.] 1 0
12 [10.] 1 0
13 [10.] 1 0
14 [10.] 1 0
15 [10.] 1 0
16 [10.] 1 0
17 [10.] 1 0
18 [10.] 1 0
19 [10.] 1 0
20 [10.] 1 0
21 [10.] 1 0
22 [10.] 1 0
23 [10.] 1 0
24 [10.] 1 0
25 [10.] 1 0
26 [10.] 1 0
27 [10.] 1 0
28 [10.] 1 0
29 [10.] 1 0
30 [10.] 1 0
31 [10.] 1 0
32 [10.] 1 0
33 [10.] 1 0
34 [10.] 1 0
35 [10.] 1 0
36 [10.] 1 0
37 [10.] 1 0
38 [10.] 1 0
39 [10.] 1 0
40 [10.] 1 0
41 [10.] 1 0
42 [10.] 1 0
43 [10.] 1 0
44 [10.] 1 0
45 [10.] 1 0
46 [10.] 1 0
47 [10.] 1 0
48 [10.] 1 0
49 [10.] 1 0
50 [10.] 1 0
51 [10.] 1 0
52 [10.] 1 0
53 [10.] 1 0
54 [10.] 1 0
55 [10.] 1 0
56 [10.] 1 0
57 [10.] 1 0
58 [10.] 1 0
59 [10.] 1 0
60 [10.] 1 0
61 [10.] 1 0
62 [10.] 1 0
63 [10.] 1 0
64 [10.] 1 0
65 [10.] 1 0
66 [10.] 1 0
67 [10.] 1 0
68 [10.] 1 0
69 [10.] 1 0
70 [10.] 1 0
71 [10.] 1 0
72 [10.] 1 0
73 [10.] 1 0
74 [10.] 1 0
75 [10.] 1 0
76 [10.] 1 0
77 [10.] 1