In [1]:
import gym
import random
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import collections
import copy

Common functions definition

In [8]:
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity) 

    def add(self, state, action, reward, next_state, done): 
        self.buffer.append((state, action, reward, next_state, done)) 

    def sample(self, batch_size): 
        transitions = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done 

    def size(self): 
        return len(self.buffer)

def moving_average(a, window_size):
    cumulative_sum = np.cumsum(np.insert(a, 0, 0)) 
    middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
    r = np.arange(1, window_size-1, 2)
    begin = np.cumsum(a[:window_size-1])[::2] / r
    end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
    return np.concatenate((begin, middle, end))

def train_on_policy_agent(env, agent, num_episodes):
    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes/10)):
                episode_return = 0
                transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
                state = env.reset()
                done = False
                while not done:
                    action = agent.take_action(state)
                    next_state, reward, done, _ = env.step(action)
                    transition_dict['states'].append(state)
                    transition_dict['actions'].append(action)
                    transition_dict['next_states'].append(next_state)
                    transition_dict['rewards'].append(reward)
                    transition_dict['dones'].append(done)
                    state = next_state
                    episode_return += reward
                return_list.append(episode_return)
                agent.update(transition_dict)
                if (i_episode+1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list

def train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size):
    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes/10)):
                episode_return = 0
                state = env.reset()
                done = False
                while not done:
                    action = agent.take_action(state)
                    next_state, reward, done, _ = env.step(action)
                    replay_buffer.add(state, action, reward, next_state, done)
                    state = next_state
                    episode_return += reward
                    if replay_buffer.size() > minimal_size:
                        b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                        transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}
                        agent.update(transition_dict)
                return_list.append(episode_return)
                if (i_episode+1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list


def compute_advantage(gamma, lmbda, td_delta):
    print(td_delta.shape)
    # torch.Size([len of trajectory, 1])
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)

In [3]:
class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class TRPO:
    """ TRPO算法 """
    def __init__(self, hidden_dim, state_space, action_space, lmbda,
                 kl_constraint, alpha, critic_lr, gamma, device):
        state_dim = state_space.shape[0]
        action_dim = action_space.n
        # 策略网络参数不需要优化器更新
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda  # GAE参数
        self.kl_constraint = kl_constraint  # KL距离最大限制
        self.alpha = alpha  # 线性搜索参数
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def hessian_matrix_vector_product(self, states, old_action_dists, vector):
        # 计算黑塞矩阵和一个向量的乘积
        new_action_dists = torch.distributions.Categorical(self.actor(states))
        kl = torch.mean(
            torch.distributions.kl.kl_divergence(old_action_dists,
                                                 new_action_dists))  # 计算平均KL距离
        kl_grad = torch.autograd.grad(kl,
                                      self.actor.parameters(),
                                      create_graph=True)
        kl_grad_vector = torch.cat([grad.view(-1) for grad in kl_grad])
        # KL距离的梯度先和向量进行点积运算
        kl_grad_vector_product = torch.dot(kl_grad_vector, vector)
        grad2 = torch.autograd.grad(kl_grad_vector_product,
                                    self.actor.parameters())
        grad2_vector = torch.cat([grad.view(-1) for grad in grad2])
        return grad2_vector

    def conjugate_gradient(self, grad, states, old_action_dists):  # 共轭梯度法求解方程
        x = torch.zeros_like(grad)
        r = grad.clone()
        p = grad.clone()
        rdotr = torch.dot(r, r)
        for i in range(10):  # 共轭梯度主循环
            Hp = self.hessian_matrix_vector_product(states, old_action_dists,
                                                    p)
            alpha = rdotr / torch.dot(p, Hp)
            x += alpha * p
            r -= alpha * Hp
            new_rdotr = torch.dot(r, r)
            if new_rdotr < 1e-10:
                break
            beta = new_rdotr / rdotr
            p = r + beta * p
            rdotr = new_rdotr
        return x

    def compute_surrogate_obj(self, states, actions, advantage, old_log_probs,
                              actor):  # 计算策略目标
        log_probs = torch.log(actor(states).gather(1, actions))
        ratio = torch.exp(log_probs - old_log_probs)
        return torch.mean(ratio * advantage)

    def line_search(self, states, actions, advantage, old_log_probs,
                    old_action_dists, max_vec):  # 线性搜索
        old_para = torch.nn.utils.convert_parameters.parameters_to_vector(
            self.actor.parameters())
        old_obj = self.compute_surrogate_obj(states, actions, advantage,
                                             old_log_probs, self.actor)
        for i in range(15):  # 线性搜索主循环
            coef = self.alpha**i
            new_para = old_para + coef * max_vec
            new_actor = copy.deepcopy(self.actor)
            torch.nn.utils.convert_parameters.vector_to_parameters(
                new_para, new_actor.parameters())
            new_action_dists = torch.distributions.Categorical(
                new_actor(states))
            kl_div = torch.mean(
                torch.distributions.kl.kl_divergence(old_action_dists,
                                                     new_action_dists))
            new_obj = self.compute_surrogate_obj(states, actions, advantage,
                                                 old_log_probs, new_actor)
            if new_obj > old_obj and kl_div < self.kl_constraint:
                return new_para
        return old_para

    def policy_learn(self, states, actions, old_action_dists, old_log_probs,
                     advantage):  # 更新策略函数
        surrogate_obj = self.compute_surrogate_obj(states, actions, advantage,
                                                   old_log_probs, self.actor)
        grads = torch.autograd.grad(surrogate_obj, self.actor.parameters())
        obj_grad = torch.cat([grad.view(-1) for grad in grads]).detach()
        # 用共轭梯度法计算x = H^(-1)g
        descent_direction = self.conjugate_gradient(obj_grad, states,
                                                    old_action_dists)

        Hd = self.hessian_matrix_vector_product(states, old_action_dists,
                                                descent_direction)
        max_coef = torch.sqrt(2 * self.kl_constraint /
                              (torch.dot(descent_direction, Hd) + 1e-8))
        new_para = self.line_search(states, actions, advantage, old_log_probs,
                                    old_action_dists,
                                    descent_direction * max_coef)  # 线性搜索
        torch.nn.utils.convert_parameters.vector_to_parameters(
            new_para, self.actor.parameters())  # 用线性搜索后的参数更新策略

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)
        advantage = compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
        old_action_dists = torch.distributions.Categorical(self.actor(states).detach())
        critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()  # 更新价值函数
        # 更新策略函数
        self.policy_learn(states, actions, old_action_dists, old_log_probs, advantage)

In [10]:
num_episodes = 500
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
critic_lr = 1e-2
kl_constraint = 0.0005
alpha = 0.5
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
agent = TRPO(hidden_dim, env.observation_space, env.action_space, lmbda,
             kl_constraint, alpha, critic_lr, gamma, device)
return_list = train_on_policy_agent(env, agent, num_episodes)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('TRPO on {}'.format(env_name))
plt.show()

mv_return = moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('TRPO on {}'.format(env_name))
plt.show()

Iteration 0:   4%|▍         | 2/50 [00:00<00:02, 16.75it/s]

torch.Size([20, 1])
torch.Size([10, 1])
torch.Size([17, 1])
torch.Size([14, 1])


Iteration 0:  16%|█▌        | 8/50 [00:00<00:02, 17.37it/s]

torch.Size([50, 1])
torch.Size([20, 1])
torch.Size([17, 1])
torch.Size([11, 1])


Iteration 0:  20%|██        | 10/50 [00:00<00:02, 14.11it/s, episode=10, return=27.400]

torch.Size([71, 1])
torch.Size([44, 1])
torch.Size([91, 1])


Iteration 0:  28%|██▊       | 14/50 [00:00<00:02, 14.78it/s, episode=10, return=27.400]

torch.Size([29, 1])
torch.Size([27, 1])
torch.Size([43, 1])
torch.Size([33, 1])


Iteration 0:  36%|███▌      | 18/50 [00:01<00:02, 15.61it/s, episode=10, return=27.400]

torch.Size([52, 1])
torch.Size([69, 1])
torch.Size([32, 1])
torch.Size([24, 1])


Iteration 0:  40%|████      | 20/50 [00:01<00:01, 15.80it/s, episode=20, return=48.200]

torch.Size([82, 1])
torch.Size([53, 1])
torch.Size([152, 1])


Iteration 0:  48%|████▊     | 24/50 [00:01<00:01, 14.29it/s, episode=20, return=48.200]

torch.Size([33, 1])
torch.Size([65, 1])
torch.Size([63, 1])


Iteration 0:  52%|█████▏    | 26/50 [00:01<00:01, 14.27it/s, episode=20, return=48.200]

torch.Size([90, 1])
torch.Size([91, 1])


Iteration 0:  56%|█████▌    | 28/50 [00:01<00:01, 12.50it/s, episode=20, return=48.200]

torch.Size([184, 1])
torch.Size([109, 1])
torch.Size([95, 1])


Iteration 0:  64%|██████▍   | 32/50 [00:02<00:01, 12.64it/s, episode=30, return=93.500]

torch.Size([85, 1])
torch.Size([59, 1])
torch.Size([126, 1])


Iteration 0:  72%|███████▏  | 36/50 [00:02<00:01, 12.72it/s, episode=30, return=93.500]

torch.Size([70, 1])
torch.Size([65, 1])
torch.Size([61, 1])


Iteration 0:  76%|███████▌  | 38/50 [00:02<00:00, 13.07it/s, episode=30, return=93.500]

torch.Size([57, 1])
torch.Size([60, 1])
torch.Size([87, 1])


Iteration 0:  80%|████████  | 40/50 [00:02<00:00, 13.05it/s, episode=40, return=72.000]

torch.Size([50, 1])
torch.Size([164, 1])


Iteration 0:  88%|████████▊ | 44/50 [00:03<00:00, 11.98it/s, episode=40, return=72.000]

torch.Size([139, 1])
torch.Size([93, 1])
torch.Size([47, 1])


Iteration 0:  92%|█████████▏| 46/50 [00:03<00:00, 12.47it/s, episode=40, return=72.000]

torch.Size([81, 1])
torch.Size([68, 1])
torch.Size([97, 1])


Iteration 0: 100%|██████████| 50/50 [00:03<00:00, 13.47it/s, episode=50, return=90.400]


torch.Size([59, 1])
torch.Size([73, 1])
torch.Size([83, 1])


Iteration 1:   4%|▍         | 2/50 [00:00<00:03, 12.16it/s]

torch.Size([75, 1])
torch.Size([94, 1])
torch.Size([70, 1])


Iteration 1:  12%|█▏        | 6/50 [00:00<00:03, 13.95it/s]

torch.Size([50, 1])
torch.Size([56, 1])
torch.Size([97, 1])


Iteration 1:  16%|█▌        | 8/50 [00:00<00:03, 13.22it/s]

torch.Size([90, 1])
torch.Size([72, 1])
torch.Size([76, 1])


Iteration 1:  24%|██▍       | 12/50 [00:00<00:03, 12.64it/s, episode=60, return=76.200]

torch.Size([82, 1])
torch.Size([98, 1])
torch.Size([68, 1])


Iteration 1:  28%|██▊       | 14/50 [00:01<00:03, 11.22it/s, episode=60, return=76.200]

torch.Size([127, 1])
torch.Size([133, 1])
torch.Size([88, 1])


Iteration 1:  36%|███▌      | 18/50 [00:01<00:02, 11.75it/s, episode=60, return=76.200]

torch.Size([89, 1])
torch.Size([85, 1])
torch.Size([60, 1])


Iteration 1:  40%|████      | 20/50 [00:01<00:02, 12.37it/s, episode=70, return=89.900]

torch.Size([86, 1])
torch.Size([65, 1])
torch.Size([166, 1])


Iteration 1:  48%|████▊     | 24/50 [00:01<00:02, 11.79it/s, episode=70, return=89.900]

torch.Size([74, 1])
torch.Size([99, 1])
torch.Size([86, 1])


Iteration 1:  52%|█████▏    | 26/50 [00:02<00:01, 12.28it/s, episode=70, return=89.900]

torch.Size([72, 1])
torch.Size([86, 1])
torch.Size([97, 1])


Iteration 1:  56%|█████▌    | 28/50 [00:02<00:01, 11.51it/s, episode=70, return=89.900]

torch.Size([147, 1])
torch.Size([62, 1])
torch.Size([118, 1])


Iteration 1:  64%|██████▍   | 32/50 [00:02<00:01, 11.51it/s, episode=80, return=100.700]

torch.Size([87, 1])
torch.Size([152, 1])
torch.Size([95, 1])


Iteration 1:  68%|██████▊   | 34/50 [00:02<00:01, 11.59it/s, episode=80, return=100.700]

torch.Size([127, 1])
torch.Size([152, 1])
torch.Size([87, 1])


Iteration 1:  76%|███████▌  | 38/50 [00:03<00:01, 10.26it/s, episode=80, return=100.700]

torch.Size([89, 1])
torch.Size([152, 1])
torch.Size([87, 1])


Iteration 1:  80%|████████  | 40/50 [00:03<00:00, 10.68it/s, episode=90, return=112.500]

torch.Size([97, 1])
torch.Size([72, 1])
torch.Size([85, 1])


Iteration 1:  88%|████████▊ | 44/50 [00:03<00:00, 11.29it/s, episode=90, return=112.500]

torch.Size([50, 1])
torch.Size([134, 1])
torch.Size([86, 1])


Iteration 1:  96%|█████████▌| 48/50 [00:04<00:00, 11.88it/s, episode=90, return=112.500]

torch.Size([93, 1])
torch.Size([59, 1])
torch.Size([89, 1])


Iteration 1: 100%|██████████| 50/50 [00:04<00:00, 11.76it/s, episode=100, return=80.100]


torch.Size([54, 1])
torch.Size([79, 1])


Iteration 2:   0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([67, 1])


Iteration 2:   4%|▍         | 2/50 [00:00<00:04, 10.32it/s]

torch.Size([122, 1])
torch.Size([161, 1])


Iteration 2:   8%|▊         | 4/50 [00:00<00:04,  9.22it/s]

torch.Size([126, 1])
torch.Size([81, 1])


Iteration 2:  12%|█▏        | 6/50 [00:00<00:04, 10.61it/s]

torch.Size([60, 1])
torch.Size([200, 1])


Iteration 2:  16%|█▌        | 8/50 [00:00<00:04,  9.50it/s]

torch.Size([94, 1])


Iteration 2:  18%|█▊        | 9/50 [00:00<00:04,  8.81it/s]

torch.Size([147, 1])


Iteration 2:  18%|█▊        | 9/50 [00:01<00:04,  8.81it/s, episode=110, return=113.500]

torch.Size([77, 1])


Iteration 2:  22%|██▏       | 11/50 [00:01<00:04,  9.17it/s, episode=110, return=113.500]

torch.Size([127, 1])
torch.Size([101, 1])
torch.Size([111, 1])


Iteration 2:  26%|██▌       | 13/50 [00:01<00:03,  9.61it/s, episode=110, return=113.500]

torch.Size([97, 1])


Iteration 2:  30%|███       | 15/50 [00:01<00:03, 10.07it/s, episode=110, return=113.500]

torch.Size([94, 1])
torch.Size([64, 1])
torch.Size([149, 1])


Iteration 2:  34%|███▍      | 17/50 [00:01<00:03, 10.11it/s, episode=110, return=113.500]

torch.Size([106, 1])
torch.Size([76, 1])


Iteration 2:  38%|███▊      | 19/50 [00:01<00:02, 10.57it/s, episode=110, return=113.500]

torch.Size([118, 1])


Iteration 2:  42%|████▏     | 21/50 [00:02<00:02,  9.88it/s, episode=120, return=104.300]

torch.Size([142, 1])
torch.Size([100, 1])
torch.Size([97, 1])


Iteration 2:  46%|████▌     | 23/50 [00:02<00:02, 10.22it/s, episode=120, return=104.300]

torch.Size([91, 1])


Iteration 2:  50%|█████     | 25/50 [00:02<00:02, 10.20it/s, episode=120, return=104.300]

torch.Size([150, 1])
torch.Size([115, 1])


Iteration 2:  54%|█████▍    | 27/50 [00:02<00:02,  9.64it/s, episode=120, return=104.300]

torch.Size([154, 1])
torch.Size([94, 1])
torch.Size([103, 1])


Iteration 2:  58%|█████▊    | 29/50 [00:02<00:02, 10.16it/s, episode=120, return=104.300]

torch.Size([110, 1])


Iteration 2:  62%|██████▏   | 31/50 [00:03<00:01,  9.73it/s, episode=130, return=115.600]

torch.Size([116, 1])


Iteration 2:  64%|██████▍   | 32/50 [00:03<00:01,  9.48it/s, episode=130, return=115.600]

torch.Size([164, 1])


Iteration 2:  64%|██████▍   | 32/50 [00:03<00:01,  9.47it/s, episode=130, return=115.600]

torch.Size([137, 1])





KeyboardInterrupt: 

: 

In [9]:
return_list = train_on_policy_agent(env, agent, 10)

Iteration 0: 100%|██████████| 1/1 [00:00<00:00,  5.76it/s]


torch.Size([200, 1])


Iteration 1: 100%|██████████| 1/1 [00:00<00:00,  6.67it/s]


torch.Size([200, 1])


Iteration 2: 100%|██████████| 1/1 [00:00<00:00,  6.66it/s]


torch.Size([200, 1])


Iteration 3: 100%|██████████| 1/1 [00:00<00:00,  7.05it/s]


torch.Size([200, 1])


Iteration 4: 100%|██████████| 1/1 [00:00<00:00,  7.41it/s]


torch.Size([200, 1])


Iteration 5: 100%|██████████| 1/1 [00:00<00:00,  7.20it/s]


torch.Size([200, 1])


Iteration 6: 100%|██████████| 1/1 [00:00<00:00,  6.61it/s]


torch.Size([200, 1])


Iteration 7: 100%|██████████| 1/1 [00:00<00:00,  7.28it/s]


torch.Size([200, 1])


Iteration 8: 100%|██████████| 1/1 [00:00<00:00,  7.04it/s]


torch.Size([200, 1])


Iteration 9: 100%|██████████| 1/1 [00:00<00:00,  6.58it/s]

torch.Size([200, 1])





In [5]:
class PolicyNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        mu = 2.0 * torch.tanh(self.fc_mu(x))
        std = F.softplus(self.fc_std(x))
        return mu, std  # 高斯分布的均值和标准差


class TRPOContinuous:
    """ 处理连续动作的TRPO算法 """
    def __init__(self, hidden_dim, state_space, action_space, lmbda,
                 kl_constraint, alpha, critic_lr, gamma, device):
        state_dim = state_space.shape[0]
        action_dim = action_space.shape[0]
        self.actor = PolicyNetContinuous(state_dim, hidden_dim,
                                         action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.kl_constraint = kl_constraint
        self.alpha = alpha
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        mu, std = self.actor(state)
        action_dist = torch.distributions.Normal(mu, std)
        action = action_dist.sample()
        return [action.item()]

    def hessian_matrix_vector_product(self,
                                      states,
                                      old_action_dists,
                                      vector,
                                      damping=0.1):
        mu, std = self.actor(states)
        new_action_dists = torch.distributions.Normal(mu, std)
        kl = torch.mean(
            torch.distributions.kl.kl_divergence(old_action_dists,
                                                 new_action_dists))
        kl_grad = torch.autograd.grad(kl,
                                      self.actor.parameters(),
                                      create_graph=True)
        kl_grad_vector = torch.cat([grad.view(-1) for grad in kl_grad])
        kl_grad_vector_product = torch.dot(kl_grad_vector, vector)
        grad2 = torch.autograd.grad(kl_grad_vector_product,
                                    self.actor.parameters())
        grad2_vector = torch.cat(
            [grad.contiguous().view(-1) for grad in grad2])
        return grad2_vector + damping * vector

    def conjugate_gradient(self, grad, states, old_action_dists):
        x = torch.zeros_like(grad)
        r = grad.clone()
        p = grad.clone()
        rdotr = torch.dot(r, r)
        for i in range(10):
            Hp = self.hessian_matrix_vector_product(states, old_action_dists,
                                                    p)
            alpha = rdotr / torch.dot(p, Hp)
            x += alpha * p
            r -= alpha * Hp
            new_rdotr = torch.dot(r, r)
            if new_rdotr < 1e-10:
                break
            beta = new_rdotr / rdotr
            p = r + beta * p
            rdotr = new_rdotr
        return x

    def compute_surrogate_obj(self, states, actions, advantage, old_log_probs,
                              actor):
        mu, std = actor(states)
        action_dists = torch.distributions.Normal(mu, std)
        log_probs = action_dists.log_prob(actions)
        ratio = torch.exp(log_probs - old_log_probs)
        return torch.mean(ratio * advantage)

    def line_search(self, states, actions, advantage, old_log_probs,
                    old_action_dists, max_vec):
        old_para = torch.nn.utils.convert_parameters.parameters_to_vector(
            self.actor.parameters())
        old_obj = self.compute_surrogate_obj(states, actions, advantage,
                                             old_log_probs, self.actor)
        for i in range(15):
            coef = self.alpha**i
            new_para = old_para + coef * max_vec
            new_actor = copy.deepcopy(self.actor)
            torch.nn.utils.convert_parameters.vector_to_parameters(
                new_para, new_actor.parameters())
            mu, std = new_actor(states)
            new_action_dists = torch.distributions.Normal(mu, std)
            kl_div = torch.mean(
                torch.distributions.kl.kl_divergence(old_action_dists,
                                                     new_action_dists))
            new_obj = self.compute_surrogate_obj(states, actions, advantage,
                                                 old_log_probs, new_actor)
            if new_obj > old_obj and kl_div < self.kl_constraint:
                return new_para
        return old_para

    def policy_learn(self, states, actions, old_action_dists, old_log_probs,
                     advantage):
        surrogate_obj = self.compute_surrogate_obj(states, actions, advantage,
                                                   old_log_probs, self.actor)
        grads = torch.autograd.grad(surrogate_obj, self.actor.parameters())
        obj_grad = torch.cat([grad.view(-1) for grad in grads]).detach()
        descent_direction = self.conjugate_gradient(obj_grad, states,
                                                    old_action_dists)
        Hd = self.hessian_matrix_vector_product(states, old_action_dists,
                                                descent_direction)
        max_coef = torch.sqrt(2 * self.kl_constraint /
                              (torch.dot(descent_direction, Hd) + 1e-8))
        new_para = self.line_search(states, actions, advantage, old_log_probs,
                                    old_action_dists,
                                    descent_direction * max_coef)
        torch.nn.utils.convert_parameters.vector_to_parameters(
            new_para, self.actor.parameters())

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        rewards = (rewards + 8.0) / 8.0  # 对奖励进行修改,方便训练
        td_target = rewards + self.gamma * self.critic(next_states) * (1 -
                                                                       dones)
        td_delta = td_target - self.critic(states)
        advantage = compute_advantage(self.gamma, self.lmbda,
                                      td_delta.cpu()).to(self.device)
        mu, std = self.actor(states)
        old_action_dists = torch.distributions.Normal(mu.detach(),
                                                      std.detach())
        old_log_probs = old_action_dists.log_prob(actions)
        critic_loss = torch.mean(
            F.mse_loss(self.critic(states), td_target.detach()))
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        self.policy_learn(states, actions, old_action_dists, old_log_probs,
                          advantage)

In [6]:
num_episodes = 2000
hidden_dim = 128
gamma = 0.9
lmbda = 0.9
critic_lr = 1e-2
kl_constraint = 0.00005
alpha = 0.5
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'Pendulum-v0'
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
agent = TRPOContinuous(hidden_dim, env.observation_space, env.action_space,
                       lmbda, kl_constraint, alpha, critic_lr, gamma, device)
return_list = train_on_policy_agent(env, agent, num_episodes)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('TRPO on {}'.format(env_name))
plt.show()

mv_return = moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('TRPO on {}'.format(env_name))
plt.show()

DeprecatedEnv: Env Pendulum-v0 not found (valid versions include ['Pendulum-v1'])