In [1]:
import torch
import torch.nn as nn
from torch.distributions import Normal
import random
import numpy as np
from collections import deque  # 双向队列
import torch.optim as optim
import gym
from torch.distributions import Categorical

import torch.nn.functional as F # 这里面有onehot函数

In [2]:
from torch.utils.tensorboard import SummaryWriter

In [3]:
# TensorBoard writers: three runs holding the network graphs, plus the root run for scalar curves.
LOG_DIR = 'sac_2018_actions'
writer_policy_graph = SummaryWriter(LOG_DIR + '/policy_net')
writer_softq_graph = SummaryWriter(LOG_DIR + '/SoftQ_net')
writer_value_graph = SummaryWriter(LOG_DIR + '/value_net')
writer_scale = SummaryWriter(LOG_DIR)

In [4]:
class ValueNetwork(nn.Module):
    '''State-value network V(s): a two-hidden-layer ReLU MLP from a state to `output_dim` values.

    The agent instantiates it with output_dim=1 for both the online and the target V network.
    '''
    def __init__(self, input_dim, output_dim, hidden_size=256):
        '''
        Args:
            input_dim: size of the state vector (last dimension of the input).
            output_dim: number of outputs.
            hidden_size: width of both hidden layers (default 256, matching the
                `hidden_size` default of SoftQNetwork and PolicyNetwork).
        '''
        super(ValueNetwork, self).__init__()
        self.l1 = nn.Linear(input_dim, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, output_dim)

    def forward(self, state):
        '''Return V(state); the last dimension of the result has size `output_dim`.'''
        x = torch.relu(self.l1(state))
        x = torch.relu(self.l2(x))
        return self.l3(x)

class SoftQNetwork(nn.Module):
    '''Soft Q network for discrete actions: maps a state to one Q value per action.'''

    def __init__(self, num_inputs, num_actions, hidden_size=256):
        super().__init__()
        # Layers are created in the same order (l1, l2, l3) so parameter names and init are unchanged.
        self.l1 = nn.Linear(num_inputs, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, num_actions)

    def forward(self, state):
        '''Return Q(state, a) for every action a (last dimension has size num_actions).'''
        hidden = state
        for layer in (self.l1, self.l2):
            hidden = torch.relu(layer(hidden))
        return self.l3(hidden)
    
class PolicyNetwork(nn.Module):
    '''Categorical actor: maps a state to a probability for each discrete action.'''

    def __init__(self, num_inputs, num_actions, hidden_size=256):
        super().__init__()
        self.l1 = nn.Linear(num_inputs, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, num_actions)

    def forward(self, state):
        '''Return action probabilities for `state` (softmax over the last dimension).'''
        hidden = state
        for layer in (self.l1, self.l2):
            hidden = torch.relu(layer(hidden))
        logits = self.l3(hidden)
        return torch.softmax(logits, dim=-1)

    def sample(self, state):
        '''Draw an action from the policy at `state`.

        Returns:
            (action, action_logprob): the sampled action index tensor and its
            log-probability under the current policy.
        '''
        distribution = Categorical(self.forward(state))
        action = distribution.sample()
        return action, distribution.log_prob(action)

In [5]:
class BasicBuffer:
    '''Fixed-capacity FIFO replay buffer of (state, action, reward, next_state, done) transitions.'''

    def __init__(self, max_size):
        self.max_size = max_size
        # deque(maxlen=...) evicts the oldest transition automatically once full.
        self.buffer = deque(maxlen=max_size)

    def push(self, state, action, reward, next_state, done):
        '''Store one transition; the scalar reward is wrapped in a 1-element numpy array.'''
        transition = (state, action, np.array([reward]), next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        '''Uniformly sample `batch_size` distinct transitions (without replacement).

        Returns:
            tuple of five index-aligned lists:
            (states, actions, rewards, next_states, dones).

        Raises:
            ValueError: from random.sample when batch_size exceeds the stored count.
        '''
        picked = random.sample(self.buffer, batch_size)
        state_batch = [s for s, _, _, _, _ in picked]
        action_batch = [a for _, a, _, _, _ in picked]
        reward_batch = [r for _, _, r, _, _ in picked]
        next_state_batch = [ns for _, _, _, ns, _ in picked]
        done_batch = [d for _, _, _, _, d in picked]
        return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)

In [6]:
def update_target(model, target_model, tau):
    '''Polyak-average the parameters of `model` into `target_model`, in place.

    For every parameter pair: target <- (1 - tau) * target + tau * source.

    Args:
        model: source network whose parameters are blended in (left unchanged).
        target_model: network updated in place; parameters are paired positionally.
        tau: interpolation weight; 1.0 copies `model` exactly, 0.0 leaves the target as is.
    '''
    # This is bookkeeping, not part of any loss: run it under no_grad so no autograd
    # graph is recorded for the blend (the old `.data.copy_(expr)` still built one for `expr`).
    with torch.no_grad():
        for target_param, param in zip(target_model.parameters(), model.parameters()):
            target_param.copy_((1. - tau) * target_param + tau * param)

In [7]:
class SAC_Agent:
    '''Soft Actor-Critic (2018 formulation with a separate V network) for discrete actions.

    Holds a categorical policy, two soft Q networks that output one value per action,
    a state-value network with a Polyak-averaged target copy, their Adam optimizers
    and a replay buffer. Because the action set is discrete, every expectation over
    the policy is computed exactly by summing over all actions instead of sampling one.
    '''

    def __init__(self, env, gamma, tau, v_lr, q_lr, policy_lr, buffer_maxlen, alpha=1.0):
        '''Build the networks, optimizers and replay buffer.

        Args:
            env: environment exposing `observation_space.shape[0]` (state size) and
                `action_space.n` (number of discrete actions).
            gamma: discount factor.
            tau: Polyak weight used when updating the target value network.
            v_lr: learning rate of the value network.
            q_lr: learning rate of both Q networks.
            policy_lr: learning rate of the policy network.
            buffer_maxlen: replay-buffer capacity.
            alpha: entropy temperature weighting the log-probability terms (default 1.0).
        '''
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.env = env
        self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n

        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.update_step = 0
        # The policy and the target V network are updated once every `delay_step` calls to update().
        self.delay_step = 2

        # Networks
        self.policy_net = PolicyNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.q_net1 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.value_net = ValueNetwork(self.obs_dim, 1).to(self.device)
        self.target_value_net = ValueNetwork(self.obs_dim, 1).to(self.device)

        # Start the target value network as an exact copy of the online one.
        update_target(self.value_net, self.target_value_net, tau=1.)

        # Optimizers
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=v_lr)

        # Both the Q and the V regressions use mean-squared error.
        self.loss_fn = nn.MSELoss()

        # Experience replay
        self.replay_buffer = BasicBuffer(buffer_maxlen)

        # Latest scalar losses keyed by name; read by the training loop for logging.
        self.summuries = {}

    def get_action(self, state):
        '''Sample an action from the policy (exploration: low-probability actions can still be chosen).

        Returns:
            int: the sampled action index.
        '''
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device)
            action, _ = self.policy_net.sample(state)
        return action.cpu().numpy().item()

    def test_get_action(self, state):
        '''Greedy action for evaluation: index of the highest-probability action.

        Returns:
            int: argmax over the (flattened) action probabilities.
        '''
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device)
            action_probs = self.policy_net(state)
        # Move to CPU first: .numpy() fails on CUDA tensors.
        return int(np.argmax(action_probs.cpu().numpy()))

    def _policy_probs(self, states):
        '''Return (probs, log_probs) of the policy for a batch of states.

        Probabilities are clamped before the log so an action whose softmax
        probability underflows to 0 yields a finite log-prob instead of -inf.
        '''
        probs = self.policy_net(states)
        log_probs = torch.log(probs.clamp(min=1e-8))
        return probs, log_probs

    def update(self, batch_size):
        '''One training step on `batch_size` transitions sampled from the replay buffer.

        V, Q1 and Q2 are updated on every call; the policy and the target V network
        only when `update_step` is a multiple of `delay_step`. The latest losses are
        stored in `self.summuries`.
        '''
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
        # Stack through numpy first: building a tensor from a list of ndarrays is very slow.
        states = torch.as_tensor(np.array(states, dtype=np.float32), device=self.device)            # [B, obs_dim]
        next_states = torch.as_tensor(np.array(next_states, dtype=np.float32), device=self.device)  # [B, obs_dim]
        actions = torch.as_tensor(np.array(actions, dtype=np.int64), device=self.device).view(-1, 1)    # [B, 1]
        rewards = torch.as_tensor(np.array(rewards, dtype=np.float32), device=self.device).view(-1, 1)  # [B, 1]
        dones = torch.as_tensor(np.array(dones, dtype=np.float32), device=self.device).view(-1, 1)      # [B, 1]

        # Regression targets: computed without gradient.
        with torch.no_grad():
            probs, log_probs = self._policy_probs(states)                   # [B, A], [B, A]
            min_q = torch.min(self.q_net1(states), self.q_net2(states))     # [B, A]
            # V(s) = sum_a pi(a|s) * (min_i Q_i(s, a) - alpha * log pi(a|s))
            v_target = (probs * (min_q - self.alpha * log_probs)).sum(-1, keepdim=True)  # [B, 1]
            # Q(s, a) = r + gamma * (1 - done) * V_target(s')
            next_v = self.target_value_net(next_states)                     # [B, 1]
            expected_q = rewards + (1 - dones) * self.gamma * next_v        # [B, 1]

        v_loss = self.loss_fn(self.value_net(states), v_target)
        # gather picks Q(s, a) for the action that was actually taken.
        curr_q1 = self.q_net1(states).gather(1, actions)                    # [B, 1]
        curr_q2 = self.q_net2(states).gather(1, actions)                    # [B, 1]
        q1_loss = self.loss_fn(curr_q1, expected_q)
        q2_loss = self.loss_fn(curr_q2, expected_q)
        self.summuries['value_loss'] = v_loss.item()
        self.summuries['q1_loss'] = q1_loss.item()
        self.summuries['q2_loss'] = q2_loss.item()

        self.value_optimizer.zero_grad()
        self.q1_optimizer.zero_grad()
        self.q2_optimizer.zero_grad()
        v_loss.backward()
        q1_loss.backward()
        q2_loss.backward()
        self.value_optimizer.step()
        self.q1_optimizer.step()
        self.q2_optimizer.step()

        # Delayed update of the policy and of the target value network.
        if self.update_step % self.delay_step == 0:
            probs, log_probs = self._policy_probs(states)
            # Q values are fixed weights for the policy objective; this also keeps
            # policy gradients from accumulating in the Q networks.
            with torch.no_grad():
                min_q = torch.min(self.q_net1(states), self.q_net2(states))
            # J_pi = E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min_i Q_i(s, a)) ]
            policy_loss = (probs * (self.alpha * log_probs - min_q)).sum(-1).mean()
            self.summuries['policy_loss'] = policy_loss.item()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            update_target(self.value_net, self.target_value_net, tau=self.tau)

        self.update_step += 1


In [8]:
def train(env, agent, max_episode, max_steps, batch_size, render=True, writer=None):
    '''Interact with `env`, store transitions and train `agent` online, logging scalars.

    Args:
        env: environment with the old Gym API (reset() -> obs, step() -> 4-tuple).
        agent: object exposing get_action(), update(), replay_buffer and summuries.
        max_episode: number of episodes to run.
        max_steps: per-episode step cap.
        batch_size: minibatch size; updates start once the buffer holds more than this.
        render: call env.render() on every step.
        writer: object with add_scalar(tag, value, step); defaults to the module-level
            `writer_scale`.

    Returns:
        list: the total reward of each episode, in order.
    '''
    if writer is None:
        writer = writer_scale
    global_step = 0
    episode_rewards = []

    for episode in range(max_episode):
        state = env.reset()
        episode_reward = 0
        episode_step = 0

        for step in range(max_steps):
            if render:
                env.render()
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            agent.replay_buffer.push(state, action, reward, next_state, done)
            episode_reward += reward  # total reward collected in this episode
            global_step += 1          # steps across all episodes
            episode_step += 1         # steps taken in this episode
            writer.add_scalar('Main/every_step_reward', reward, global_step)

            if len(agent.replay_buffer.buffer) > batch_size:
                agent.update(batch_size)
                writer.add_scalar('Loss/q1_loss', agent.summuries['q1_loss'], global_step)
                writer.add_scalar('Loss/q2_loss', agent.summuries['q2_loss'], global_step)
                # The policy is updated with a delay, so its loss may not exist yet.
                if 'policy_loss' in agent.summuries:
                    writer.add_scalar('Loss/policy_loss', agent.summuries['policy_loss'], global_step)

            if done or step == max_steps - 1:
                writer.add_scalar('Episode/episode_steps', episode_step, episode)
                writer.add_scalar('Episode/episode_rewards', episode_reward, episode)
                print('Episode is {}, episod_reward is {}'.format(episode, episode_reward))
                break
            state = next_state

        episode_rewards.append(episode_reward)

    return episode_rewards

In [9]:
# The agent reads `env.action_space.n`, so it needs a discrete action space;
# continuous-action tasks such as 'Ant-v2' (Box actions, no `.n`) are not supported here.
ENV_ID = 'CartPole-v0'
env = gym.make(ENV_ID)

In [10]:
# SAC (2018) hyper-parameters
gamma = 0.99
tau = 0.1  # Polyak weight for the target value network
value_lr = q_lr = policy_lr = 3e-4
buffer_maxlen = 100000

# Discrete-action SAC agent for `env`
agent = SAC_Agent(env, gamma, tau, value_lr, q_lr, policy_lr, buffer_maxlen)

In [11]:
# One observation with a leading batch dimension of 1, used as the example input for add_graph.
test_state_raw = env.reset()
test_state = torch.FloatTensor(test_state_raw)[None].to(agent.device)

In [12]:
# Write each network's computation graph to its own TensorBoard run, traced on `test_state`.
for graph_writer, network in ((writer_policy_graph, agent.policy_net),
                              (writer_softq_graph, agent.q_net1),
                              (writer_value_graph, agent.value_net)):
    graph_writer.add_graph(network, test_state)

In [13]:
# train
# Train for 10000 episodes, each capped at 200 steps, with minibatches of 64 transitions.
episode_rewards = train(env, agent, max_episode=10000, max_steps=200, batch_size=64, render=True)

Episode is 0, episod_reward is 19.0
Episode is 1, episod_reward is 26.0
Episode is 2, episod_reward is 20.0
Episode is 3, episod_reward is 35.0
Episode is 4, episod_reward is 12.0
Episode is 5, episod_reward is 18.0
Episode is 6, episod_reward is 38.0
Episode is 7, episod_reward is 26.0
Episode is 8, episod_reward is 31.0
Episode is 9, episod_reward is 17.0
Episode is 10, episod_reward is 13.0
Episode is 11, episod_reward is 11.0
Episode is 12, episod_reward is 22.0
Episode is 13, episod_reward is 12.0
Episode is 14, episod_reward is 10.0
Episode is 15, episod_reward is 31.0
Episode is 16, episod_reward is 20.0
Episode is 17, episod_reward is 18.0
Episode is 18, episod_reward is 13.0
Episode is 19, episod_reward is 79.0
Episode is 20, episod_reward is 19.0
Episode is 21, episod_reward is 16.0
Episode is 22, episod_reward is 27.0
Episode is 23, episod_reward is 31.0
Episode is 24, episod_reward is 21.0
Episode is 25, episod_reward is 32.0
Episode is 26, episod_reward is 10.0
Episode is 

Episode is 219, episod_reward is 12.0
Episode is 220, episod_reward is 11.0
Episode is 221, episod_reward is 23.0
Episode is 222, episod_reward is 18.0
Episode is 223, episod_reward is 16.0
Episode is 224, episod_reward is 15.0
Episode is 225, episod_reward is 20.0
Episode is 226, episod_reward is 24.0
Episode is 227, episod_reward is 21.0
Episode is 228, episod_reward is 13.0
Episode is 229, episod_reward is 11.0
Episode is 230, episod_reward is 13.0
Episode is 231, episod_reward is 19.0
Episode is 232, episod_reward is 17.0
Episode is 233, episod_reward is 20.0
Episode is 234, episod_reward is 15.0
Episode is 235, episod_reward is 35.0
Episode is 236, episod_reward is 47.0
Episode is 237, episod_reward is 9.0
Episode is 238, episod_reward is 27.0
Episode is 239, episod_reward is 17.0
Episode is 240, episod_reward is 31.0
Episode is 241, episod_reward is 24.0
Episode is 242, episod_reward is 15.0
Episode is 243, episod_reward is 28.0
Episode is 244, episod_reward is 14.0
Episode is 24

KeyboardInterrupt: 

In [None]:
# CartPole's action space is discrete: it has a size `n` but no `.low` bound
# (`.low` exists only on continuous Box spaces), so inspect the number of actions instead.
env.action_space.n

In [None]:
# Pickle the whole policy module (not just its state_dict); loading it back requires PolicyNetwork to be defined.
torch.save(agent.policy_net, f='policy_2018.pt')

# 测试

In [None]:
import inspect

# Reload the policy saved above. The file is a pickled nn.Module, so loading runs the
# unpickler on arbitrary objects: only load checkpoints you created yourself.
# map_location puts the weights on the agent's device even if they were saved on another one.
_policy_load_kwargs = {'map_location': agent.device}
# PyTorch >= 2.6 defaults to weights_only=True, which rejects full-module pickles;
# releases without the argument unpickle full modules already.
if 'weights_only' in inspect.signature(torch.load).parameters:
    _policy_load_kwargs['weights_only'] = False
agent.policy_net = torch.load('policy_2018.pt', **_policy_load_kwargs)

In [None]:
# Greedy evaluation of the trained policy; prints the return of every test episode.
N_TEST_EPISODES = 100
MAX_TEST_STEPS = 20000
with torch.no_grad():
    for i in range(N_TEST_EPISODES):
        obs = env.reset()
        test_return = 0.0
        for j in range(MAX_TEST_STEPS):
            env.render()
            action = agent.test_get_action(obs)
            next_obs, reward, done, _ = env.step(action)
            test_return += reward
            if done:
                break
            obs = next_obs
        print('Test episode {}: return {}'.format(i, test_return))