In [None]:
import os
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

In [None]:
# Seed NumPy's and PyTorch's global random number generators with one shared value.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Use the GPU when CUDA reports one as available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Create the environment and read off the sizes the policy network needs.
env = gym.make('LunarLander-v2', render_mode='rgb_array')
action_space = env.action_space.n  # number of discrete actions
observation_space = env.observation_space.shape[0]  # length of the observation vector

In [None]:
class LunarLanderSolver:
    """Monte-Carlo policy-gradient (REINFORCE-style) agent with a softmax policy network.

    The caller is expected to append one action log-probability to
    ``episode_actions`` and one reward to ``episode_rewards`` per environment
    step, and to call :meth:`backward` once the episode has finished.

    Tensors and the model are placed on the module-level ``device``.
    """

    def __init__(self, hidden_size, input_size, output_size, learning_rate, eps, gamma):
        """Build the policy network and its Adam optimizer.

        Args:
            hidden_size: Width of the single hidden layer.
            input_size: Length of the observation vector fed to the network.
            output_size: Number of discrete actions (size of the softmax output).
            learning_rate: Learning rate handed to Adam.
            eps: Adam's ``eps`` argument (not an exploration rate).
            gamma: Discount factor used when accumulating returns in :meth:`backward`.
        """
        # Policy network: observation -> hidden (ReLU) -> action probabilities.
        self.model = torch.nn.Sequential(
            nn.Linear(input_size, hidden_size, bias=False),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size, bias=False),
            nn.Softmax(dim=-1)
        ).to(device)

        # Optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate, eps=eps)

        self.lr = learning_rate  # learning rate
        self.eps = eps  # Adam epsilon
        self.gamma = gamma  # discount factor

        self.reset()

    def forward(self, x):
        """Return the action-probability vector the policy assigns to observation ``x``."""
        return self.model(x)

    def reset(self):
        """Clear the per-episode buffers of action log-probabilities and rewards."""
        self.episode_actions = torch.tensor([], requires_grad=True, device=device)
        self.episode_rewards = []

    def save_checkpoint(self, directory, episode):
        """Write the model weights to ``<directory>/checkpoint_<episode>.pth``.

        The directory is created if it does not exist yet. Only the model's
        ``state_dict`` is stored; the optimizer state is not.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(directory, exist_ok=True)
        filename = os.path.join(directory, 'checkpoint_{}.pth'.format(episode))
        torch.save(self.model.state_dict(), f=filename)
        print('保存当前模型至 \'{}\''.format(filename))

    def load_checkpoint(self, directory, filename):
        """Load model weights from ``directory/filename`` and return the episode number.

        Args:
            directory: Folder that contains the checkpoint.
            filename: Checkpoint file name of the form ``checkpoint_<episode>.pth``.

        Returns:
            The integer episode number encoded in ``filename``.

        Raises:
            ValueError: If the part of the file name after the last underscore
                is not an integer.
        """
        # map_location lets a checkpoint written on a GPU be restored on a
        # CPU-only machine (and vice versa).
        state_dict = torch.load(os.path.join(directory, filename), map_location=device)
        self.model.load_state_dict(state_dict)
        print('重新开始训练 checkpoint \'{}\'.'.format(filename))
        # Take the text between the last underscore and the extension rather
        # than relying on fixed character offsets.
        stem = os.path.splitext(os.path.basename(filename))[0]
        return int(stem.rsplit('_', 1)[-1])

    def backward(self):
        """Apply one policy-gradient update from the buffered episode, then clear the buffers.

        Discounted returns are accumulated backwards through the episode,
        standardised when there are at least two of them, and used to weight
        the stored log-probabilities in the loss ``-sum(log_prob * return)``.
        An empty episode is a no-op apart from clearing the buffers.
        """
        if not self.episode_rewards:
            self.reset()
            return

        future_reward = 0
        returns = []
        for r in reversed(self.episode_rewards):
            future_reward = r + self.gamma * future_reward  # discounted return
            returns.append(future_reward)
        returns = torch.tensor(returns[::-1], dtype=torch.float32, device=device)

        # The sample standard deviation of a single value is NaN, which would
        # propagate through the loss into every weight. Standardise only when
        # it is defined; a one-step episode keeps its raw return.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + np.finfo(np.float32).eps)

        loss = torch.sum(torch.mul(self.episode_actions, returns).mul(-1))  # policy-gradient loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.reset()

In [None]:
# Hyperparameters
batch_size = 10  # number of episodes averaged into one logged data point
gamma = 0.99  # discount factor
lr = 1e-4  # Adam learning rate
epsilon = 1e-4  # forwarded to the solver as Adam's eps

# Set to a file name such as 'checkpoint_6800.pth' to resume from a saved model.
load_filename = None
save_directory = "./weight"
checkpoint_interval = 50  # save the weights every this many episodes
batch_rewards = []
episode = 0
max_episode = 2500
reward_list = []

model = LunarLanderSolver(hidden_size=1024, input_size=observation_space, output_size=action_space, learning_rate=lr, eps=epsilon, gamma=gamma)

if load_filename is not None:
    episode = model.load_checkpoint(save_directory, load_filename)

# Seed the environment on the first reset only. Passing the same seed to every
# reset re-initialises the environment RNG each episode, so every episode
# would draw the same random sequence.
reset_seed = random_seed
while episode < max_episode:
    state = env.reset(seed=reset_seed)[0]
    reset_seed = None
    done = False
    while not done:
        action_probs = model.forward(torch.tensor(state, dtype=torch.float32, device=device))
        distribution = Categorical(action_probs)  # distribution over actions
        action = distribution.sample()  # pick an action
        state, reward, terminated, truncated, _ = env.step(action.item())
        # An episode ends on termination OR truncation (e.g. a time limit);
        # ignoring truncation would keep stepping a finished episode.
        done = terminated or truncated
        model.episode_actions = torch.cat([model.episode_actions, distribution.log_prob(action).reshape(1)])
        model.episode_rewards.append(reward)

    # Episode finished: record its total reward and update the policy.
    batch_rewards.append(np.sum(model.episode_rewards))
    model.backward()
    episode += 1

    if episode % batch_size == 0:
        batch_mean = np.array(batch_rewards).mean()
        print('Batch: {}, average reward: {}'.format(episode // batch_size, batch_mean))
        reward_list.append(batch_mean)
        batch_rewards = []

    if episode % checkpoint_interval == 0 and save_directory is not None:
        model.save_checkpoint(save_directory, episode)


In [None]:
import matplotlib.pyplot as plt

# reward_list holds one mean per batch of `batch_size` episodes, so convert the
# batch index to an episode count to match the 'Episode' axis label.
# NOTE(review): when training was resumed from a checkpoint, this counts from
# zero rather than from the resumed episode number.
episodes_seen = batch_size * np.arange(1, len(reward_list) + 1)

fig, ax = plt.subplots()
ax.plot(episodes_seen, reward_list)
ax.grid(True)
ax.set_xlim([0, max_episode])
# Uncomment to pin the reward range: ax.set_ylim([-300, 300])
ax.set_xlabel('Episode')
ax.set_ylabel('Batch Average Reward')
plt.show()