In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import gym
from collections import deque

class n_step_replay_buffer(object):
    def __init__(self, capacity, n_step, gamma):
        self.capacity = capacity
        self.n_step = n_step
        self.gamma = gamma
        self.memory = deque(maxlen=self.capacity)
        self.n_step_buffer = deque(maxlen=self.n_step)

    def _get_n_step_info(self):
        reward, next_observation, done = self.n_step_buffer[-1][-3:]
        for _, _, rew, next_obs, do in reversed(list(self.n_step_buffer)[: -1]):
            reward = self.gamma * reward * (1 - do) + rew
            next_observation, done = (next_obs, do) if do else (next_observation, done)
        return reward, next_observation, done

    def store(self, observation, action, reward, next_observation, done):
        observation = np.expand_dims(observation, 0)
        next_observation = np.expand_dims(next_observation, 0)

        self.n_step_buffer.append([observation, action, reward, next_observation, done])
        if len(self.n_step_buffer) < self.n_step:
            return
        reward, next_observation, done = self._get_n_step_info()
        observation, action = self.n_step_buffer[0][: 2]
        self.memory.append([observation, action, reward, next_observation, done])

    def sample(self, batch_size):
        batch = random.sample(self.memory, batch_size)
        observation, action, reward, next_observation, done = zip(* batch)
        return np.concatenate(observation, 0), action, reward, np.concatenate(next_observation, 0), done

    def __len__(self):
        return len(self.memory)




https://github.com/deligentfool/dqn_zoo/blob/master/N_step%20DQN/n_step_dqn.py


In [2]:
class ddqn(nn.Module):
    def __init__(self, observation_dim, action_dim):
        super(ddqn, self).__init__()
        self.observation_dim = observation_dim
        self.action_dim = action_dim

        self.fc1 = nn.Linear(self.observation_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, self.action_dim)

    def forward(self, observation):
        x = F.relu(self.fc1(observation))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def actB(self, observation, epsilon = 0.3):
        A = np.ones(self.action_dim, dtype=float) * (epsilon/self.action_dim)
        temp = self.forward(observation)
        best_action = temp.max(1)[1].data[0].item()
        A[best_action] += 1.0 - epsilon
       
        return A
    
    def actPi(self, observation, epsilon= 0.1):
        A = np.ones(self.action_dim, dtype=float) * (epsilon/self.action_dim)
        temp = self.forward(observation)
        best_action = temp.max(1)[1].data[0].item()
        A[best_action] += 1.0 - epsilon
       
        return A

def train(buffer, target_model, eval_model, gamma, optimizer, batch_size, loss_fn, count, soft_update_freq, n_step):
    observation, action, reward, next_observation, done = buffer.sample(batch_size)

    observation = torch.FloatTensor(observation)
    action = torch.LongTensor(action)
    reward = torch.FloatTensor(reward)
    next_observation = torch.FloatTensor(next_observation)
    done = torch.FloatTensor(done)

    q_values = eval_model.forward(observation)
    next_q_values = target_model.forward(next_observation)
    argmax_actions = eval_model.forward(next_observation).max(1)[1].detach()
    next_q_value = next_q_values.gather(1, argmax_actions.unsqueeze(1)).squeeze(1)
    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    expected_q_value = reward + (gamma ** n_step) * (1 - done) * next_q_value

    #loss = loss_fn(q_value, expected_q_value.detach())
    loss = (expected_q_value.detach() - q_value).pow(2)
    loss = loss.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if count % soft_update_freq == 0:
        target_model.load_state_dict(eval_model.state_dict())




In [3]:
if __name__ == '__main__':
    gamma = 0.99
    learning_rate = 1e-3
    batch_size = 64
    soft_update_freq = 200
    capacity = 10000
    exploration = 100
    decay = 0.99
    episode = 500
    n_step = 4
    render = False
    sigma = 0.5
    
    env = gym.make('CartPole-v0')
    env = env.unwrapped
    observation_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    target_net = ddqn(observation_dim, action_dim)
    eval_net = ddqn(observation_dim, action_dim)
    eval_net.load_state_dict(target_net.state_dict())
    optimizer = torch.optim.Adam(eval_net.parameters(), lr=learning_rate)
    buffer = n_step_replay_buffer(capacity, n_step, gamma)
    loss_fn = nn.MSELoss()
    count = 0

    weight_reward = None
    for i in range(episode):
        obs = env.reset()
        T = np.inf
        t = 0
        tau = 0
        stored_actions = {}
        stored_states = {}
        stored_rewards = {}
        stored_pho = {}
        stored_sigma={}
        reward_total = 0

        b_prob = eval_net.actB(torch.FloatTensor(np.expand_dims(obs, 0)))
        pi_prob = eval_net.actPi(torch.FloatTensor(np.expand_dims(obs, 0)))
        action = np.random.choice(list(range(action_dim)),p=b_prob)
        pho = pi_prob[action] / b_prob[action]
        
        stored_actions[0] = action
        stored_states[0] = obs
        stored_pho[0] = pho
        stored_rewards[0] = 0
        
        if render:
            env.render()
        while True:
            
            if t < T:
                next_obs, reward, done, info = env.step(action)
                reward_total += reward
                stored_states[(t+1) % (n_step+1)] = next_obs
                stored_rewards[(t+1) % (n_step+1)] = reward
                buffer.store(obs, action, reward, next_obs, done)
                obs = next_obs
                if done:
                    T= t+1
                else:
                    b_prob = eval_net.actB(torch.FloatTensor(np.expand_dims(obs, 0)))
                    pi_prob = eval_net.actPi(torch.FloatTensor(np.expand_dims(obs, 0)))
                    action = np.random.choice(list(range(action_dim)),p=b_prob)
                    pho = pi_prob[action] / b_prob[action]
                    stored_actions[(t+1)% (n_step+1)] = action
                    stored_pho[(t+1)% (n_step+1)] = pho
                    
            tau = t - n_step + 1
            if tau >= 0:
                if t + 1 < T:
                    q_actions = torch.FloatTensor(stored_states[(t+1)% (n_step+1)])
                    q = eval_net.forward(q_actions).detach()
                    G = q[stored_actions[(t+1)% (n_step+1)]]
                    
                for k in range(min(t+1, T), tau, -1):
                    if k == T:
                        G = stored_rewards[T% (n_step+1)]
                    else:
                        s_k = stored_states[k% (n_step+1)]
                        a_k = stored_actions[k% (n_step+1)]
                        r_k = stored_rewards[k% (n_step+1)]
                        pho_k = stored_pho[k% (n_step+1)]
                        
                        b_prob = eval_net.actB(torch.FloatTensor(np.expand_dims(s_k, 0)))
                        pi_prob = eval_net.actPi(torch.FloatTensor(np.expand_dims(s_k, 0)))
                        q_net = torch.FloatTensor(s_k)
                        q = eval_net.forward(q_net).detach()
                        
                        VBar = np.sum([(pi_prob[a]) * q[a] for a in range(action_dim)])
                        G = r_k + gamma * ((sigma * pho_k) + ((1-sigma) * pi_prob[a_k])) * (G - q[a_k])+ gamma * VBar

                s_tau = stored_states[tau% (n_step+1)]
                a_tau = stored_actions[tau% (n_step+1)]
                observationQ = torch.FloatTensor(s_tau)
                
                q_actions = torch.FloatTensor(s_tau)
                q = eval_net.forward(q_net).detach()
                
                q[a_tau] = q[a_tau] + learning_rate * (G - q[a_tau])
         
               
            if i > exploration:
                train(buffer, target_net, eval_net, gamma, optimizer, batch_size, loss_fn, count, soft_update_freq, n_step)
            
            if tau >= (T-1):
                if not weight_reward:
                    weight_reward = reward_total
                else:
                    weight_reward = 0.99 * weight_reward + 0.01 * reward_total
                print('episode: {}  reward: {}  weight_reward: {:.3f}'.format(i+1, reward_total, weight_reward))
                break
            else:
                t = t + 1

episode: 1  reward: 13.0  weight_reward: 13.000
episode: 2  reward: 9.0  weight_reward: 12.960
episode: 3  reward: 10.0  weight_reward: 12.930
episode: 4  reward: 12.0  weight_reward: 12.921
episode: 5  reward: 11.0  weight_reward: 12.902
episode: 6  reward: 11.0  weight_reward: 12.883
episode: 7  reward: 10.0  weight_reward: 12.854
episode: 8  reward: 12.0  weight_reward: 12.845
episode: 9  reward: 11.0  weight_reward: 12.827
episode: 10  reward: 12.0  weight_reward: 12.819
episode: 11  reward: 9.0  weight_reward: 12.781
episode: 12  reward: 9.0  weight_reward: 12.743
episode: 13  reward: 12.0  weight_reward: 12.735
episode: 14  reward: 13.0  weight_reward: 12.738
episode: 15  reward: 12.0  weight_reward: 12.731
episode: 16  reward: 10.0  weight_reward: 12.703
episode: 17  reward: 12.0  weight_reward: 12.696
episode: 18  reward: 16.0  weight_reward: 12.729
episode: 19  reward: 11.0  weight_reward: 12.712
episode: 20  reward: 9.0  weight_reward: 12.675
episode: 21  reward: 10.0  weight

episode: 168  reward: 94.0  weight_reward: 23.328
episode: 169  reward: 90.0  weight_reward: 23.995
episode: 170  reward: 96.0  weight_reward: 24.715
episode: 171  reward: 102.0  weight_reward: 25.488
episode: 172  reward: 90.0  weight_reward: 26.133
episode: 173  reward: 117.0  weight_reward: 27.042
episode: 174  reward: 103.0  weight_reward: 27.801
episode: 175  reward: 41.0  weight_reward: 27.933
episode: 176  reward: 95.0  weight_reward: 28.604
episode: 177  reward: 106.0  weight_reward: 29.378
episode: 178  reward: 125.0  weight_reward: 30.334
episode: 179  reward: 26.0  weight_reward: 30.291
episode: 180  reward: 27.0  weight_reward: 30.258
episode: 181  reward: 98.0  weight_reward: 30.935
episode: 182  reward: 114.0  weight_reward: 31.766
episode: 183  reward: 80.0  weight_reward: 32.248
episode: 184  reward: 128.0  weight_reward: 33.206
episode: 185  reward: 94.0  weight_reward: 33.814
episode: 186  reward: 97.0  weight_reward: 34.446
episode: 187  reward: 108.0  weight_reward:

episode: 329  reward: 408.0  weight_reward: 129.745
episode: 330  reward: 482.0  weight_reward: 133.268
episode: 331  reward: 552.0  weight_reward: 137.455
episode: 332  reward: 1650.0  weight_reward: 152.581
episode: 333  reward: 691.0  weight_reward: 157.965
episode: 334  reward: 228.0  weight_reward: 158.665
episode: 335  reward: 64.0  weight_reward: 157.718
episode: 336  reward: 2978.0  weight_reward: 185.921
episode: 337  reward: 177.0  weight_reward: 185.832
episode: 338  reward: 216.0  weight_reward: 186.134
episode: 339  reward: 627.0  weight_reward: 190.542
episode: 340  reward: 356.0  weight_reward: 192.197
episode: 341  reward: 299.0  weight_reward: 193.265
episode: 342  reward: 528.0  weight_reward: 196.612
episode: 343  reward: 1621.0  weight_reward: 210.856
episode: 344  reward: 1866.0  weight_reward: 227.408
episode: 345  reward: 253.0  weight_reward: 227.664
episode: 346  reward: 113.0  weight_reward: 226.517
episode: 347  reward: 225.0  weight_reward: 226.502
episode: 

episode: 487  reward: 287.0  weight_reward: 370.141
episode: 488  reward: 761.0  weight_reward: 374.050
episode: 489  reward: 209.0  weight_reward: 372.399
episode: 490  reward: 154.0  weight_reward: 370.215
episode: 491  reward: 139.0  weight_reward: 367.903
episode: 492  reward: 166.0  weight_reward: 365.884
episode: 493  reward: 150.0  weight_reward: 363.725
episode: 494  reward: 133.0  weight_reward: 361.418
episode: 495  reward: 142.0  weight_reward: 359.224
episode: 496  reward: 269.0  weight_reward: 358.322
episode: 497  reward: 841.0  weight_reward: 363.148
episode: 498  reward: 503.0  weight_reward: 364.547
episode: 499  reward: 206.0  weight_reward: 362.961
episode: 500  reward: 414.0  weight_reward: 363.472
