In [1]:
import gymnasium as gym

#Environment definition
class MyWrapper(gym.Wrapper):
    """Pass-through wrapper around Gymnasium's ``CartPole-v1`` environment.

    The wrapper builds the underlying environment itself, so callers only
    need ``MyWrapper()``. ``reset`` and ``step`` delegate to the wrapped
    environment and return its results unchanged.

    Args:
        render_mode: Render mode handed to ``gym.make``. Defaults to
            ``"human"`` (the original behaviour); pass ``None`` to run
            without opening a window, e.g. while training.
    """

    def __init__(self, render_mode="human"):
        env = gym.make('CartPole-v1', render_mode=render_mode)
        # gym.Wrapper.__init__ stores the wrapped environment as self.env.
        super().__init__(env)

    def reset(self, seed=None, options=None):
        """Reset the wrapped environment.

        Args:
            seed: Optional RNG seed, forwarded to the wrapped environment so
                that seeded resets are actually reproducible.
            options: Optional reset options, forwarded unchanged.

        Returns:
            The ``(state, info)`` pair produced by the wrapped environment.
        """
        state, info = self.env.reset(seed=seed, options=options)
        return state, info

    def step(self, action):
        """Advance the wrapped environment by one step.

        Returns:
            The ``(state, reward, done, truncated, info)`` tuple produced by
            the wrapped environment, unchanged.
        """
        state, reward, done, truncated, info = self.env.step(action)
        return state, reward, done, truncated, info

# Build the wrapped CartPole environment that the later cells share.
env = MyWrapper()

# Reset once; as the cell's last expression this displays the (state, info) pair.
env.reset()

(array([ 0.01381764, -0.01208821,  0.0242144 , -0.04279112], dtype=float32),
 {})

In [2]:
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


#Custom feature-extraction layer
class CustomCNN(BaseFeaturesExtractor):
    """Feature extractor built from 1x1 convolutions followed by a linear layer.

    Each observation is flattened and placed on the channel axis of a
    ``[batch, channels, 1, 1]`` tensor, passed through two 1x1 convolutions,
    flattened again, and projected by a final linear layer.

    Args:
        observation_space: Observation space of the environment.
        hidden_dim: Width of every hidden layer and of the returned feature
            vector (also passed to the base class as the feature dimension).
    """

    def __init__(self, observation_space: gym.spaces.Box, hidden_dim: int):
        super().__init__(observation_space, hidden_dim)

        # forward() moves *all* elements of an observation onto the channel
        # axis, so the first convolution must accept the flattened size of the
        # space. For a 1-D space this equals observation_space.shape[0].
        in_channels = gym.spaces.flatdim(observation_space)

        self.sequential = torch.nn.Sequential(

            #[b, in_channels, 1, 1] -> [b, h, 1, 1]
            torch.nn.Conv2d(in_channels=in_channels,
                            out_channels=hidden_dim,
                            kernel_size=1,
                            stride=1,
                            padding=0),
            torch.nn.ReLU(),

            #[b, h, 1, 1] -> [b, h, 1, 1]
            torch.nn.Conv2d(hidden_dim,
                            hidden_dim,
                            kernel_size=1,
                            stride=1,
                            padding=0),
            torch.nn.ReLU(),

            #[b, h, 1, 1] -> [b, h]
            torch.nn.Flatten(),

            #[b, h] -> [b, h]
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
        )

    def forward(self, state):
        """Map a batch of observations to a ``[batch, hidden_dim]`` feature tensor."""
        b = state.shape[0]
        # Keep the batch axis, fold everything else into channels with 1x1 spatial size.
        state = state.reshape(b, -1, 1, 1)
        return self.sequential(state)


# Build a PPO agent whose policy uses CustomCNN as its feature extractor.
# Only `hidden_dim` is listed in the extractor kwargs; `observation_space` is
# not passed here (per the stable-baselines3 custom-policy docs the library
# supplies it when it instantiates the extractor).
model = PPO(
    policy='CnnPolicy',
    env=env,
    policy_kwargs=dict(
        features_extractor_class=CustomCNN,
        features_extractor_kwargs=dict(hidden_dim=8),
    ),
    verbose=0,
)
model

  from pandas.core.computation.check import NUMEXPR_INSTALLED


<stable_baselines3.ppo.ppo.PPO at 0x7bc6103277c0>

In [3]:
from stable_baselines3.common.evaluation import evaluate_policy

#Test: evaluate the model before training, over 10 episodes with
#non-deterministic action selection. The result is displayed as a pair
#(presumably mean and standard deviation of episode reward — confirm in SB3 docs).
evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)



(25.1, 12.809761902549164)

In [4]:
#Train the agent for 20,000 timesteps with a progress bar, then save it to disk.
model.learn(progress_bar=True, total_timesteps=20_000)

model.save('models/自定义特征抽取层')

Output()

In [5]:
# Reload the saved agent from disk, rebinding `model` to the loaded copy.
model = PPO.load('models/自定义特征抽取层')

# Re-run the same 10-episode, non-deterministic evaluation on the loaded model.
evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)

(167.1, 88.20141722217393)