In [None]:
import torch 
import torch.nn as nn 
import torch.autograd as autograd 
import numpy as np 
import random 
from collections import deque
import gym

from collections import deque
import itertools

In [None]:
class DQN(nn.Module):
    """Small MLP Q-network: flattened observation -> one Q-value per discrete action.

    Args:
        env: Gym-style environment; ``np.prod(env.observation_space.shape)`` is the
            input size and ``env.action_space.n`` the default number of outputs.
        output_dim: Number of Q-values produced by the last layer. Defaults to
            ``env.action_space.n`` when omitted (previously this argument was ignored).
    """

    def __init__(self, env, output_dim=None):
        super().__init__()

        in_features = int(np.prod(env.observation_space.shape))
        if output_dim is None:
            output_dim = env.action_space.n

        self.net = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.Tanh(),
            nn.Linear(64, output_dim),
        )

    def forward(self, state):
        """Return the Q-values of shape ``(batch, output_dim)`` for a batch of states."""
        return self.net(state)

    def act(self, state):
        """Return the greedy action (an ``int``) for a single, un-batched observation."""
        # Put the observation on the same device as the weights so this also works
        # after the network has been moved to CUDA.
        device = next(self.parameters()).device
        state_t = torch.as_tensor(state, dtype=torch.float32, device=device)
        # Action selection needs no gradients; avoid building a graph every env step.
        with torch.no_grad():
            q_values = self.forward(state_t.unsqueeze(0))  # shape (1, output_dim)
        return q_values.argmax(dim=1)[0].item()  # index of the largest Q-value

In [None]:
class SumTree:
    """Array-backed binary sum tree used for proportional prioritized sampling.

    ``tree`` holds ``2 * capacity - 1`` nodes in heap order (children of node ``i``
    are ``2i + 1`` and ``2i + 2``). The last ``capacity`` entries are leaf
    priorities and every internal node stores the sum of its two children, so
    ``tree[0]`` is the total priority. ``data[j]`` is the payload of leaf
    ``j + capacity - 1``; once full, slots are overwritten in ring-buffer order.
    """

    def __init__(self, capacity):
        if capacity < 1:
            raise ValueError("capacity must be >= 1")
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.n_entries = 0
        self.write = 0  # next data slot to (over)write

    def _recompute_ancestors(self, idx):
        """Refresh every ancestor of node ``idx`` from its two children.

        Recomputing (rather than adding a delta) prevents floating-point drift
        between the internal sums and the leaves, and is a no-op for the root
        (which fixes the double-counting of capacity-1 trees).
        """
        while idx > 0:
            idx = (idx - 1) // 2
            left = 2 * idx + 1
            self.tree[idx] = self.tree[left] + self.tree[left + 1]

    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf whose cumulative-priority range contains ``s``."""
        n_nodes = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= n_nodes:
                return idx
            right = left + 1
            left_sum = self.tree[left]
            # Never step into an all-zero subtree: those leaves are unfilled slots
            # whose data is the placeholder 0, not a transition.
            if left_sum > 0 and (s <= left_sum or self.tree[right] <= 0):
                idx = left
            else:
                s -= left_sum
                idx = right

    def total(self):
        """Return the sum of all leaf priorities."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` with priority ``p`` in the next ring-buffer slot."""
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, idx, p):
        """Set the priority of tree leaf ``idx`` (a tree index, not a data index) to ``p``."""
        self.tree[idx] = p
        self._recompute_ancestors(idx)

    def get(self, s):
        """Return ``(tree_index, priority, data)`` for the leaf selected by cumulative value ``s``.

        ``s`` is clamped to ``[0, total()]`` so rounding in the caller cannot
        overshoot into empty slots.
        """
        s = min(max(s, 0.0), self.total())
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[data_idx])

In [None]:
class PrioritizedBuffer:
    """Proportional prioritized experience replay backed by a ``SumTree``.

    Each stored transition is ``(state, action, np.array([reward]), next_state, done)``.
    A new transition gets the current maximum *leaf* priority (1.0 for an empty
    buffer), so it is likely to be replayed at least once. ``update_priority``
    later sets its priority to ``(|td_error| + eps) ** alpha``.

    The ``alpha`` / ``beta`` / ``eps`` arguments default to None. In that case the
    notebook globals ``ALPHA`` / ``BETA`` / ``EP`` are read at call time; the
    training loop re-anneals ``BETA`` every step.
    """

    def __init__(self, max_size):
        self.sum_tree = SumTree(max_size)
        # Number of stored transitions; capped at max_size because old ones are overwritten.
        self.current_length = 0

    def push(self, state, action, reward, next_state, done):
        """Store one transition with the current maximum leaf priority."""
        # Only the last `capacity` tree entries are leaves. Internal nodes (including
        # the root, which holds the total) must never be used as a priority.
        leaf_priorities = self.sum_tree.tree[-self.sum_tree.capacity:]
        max_priority = leaf_priorities.max()
        priority = max_priority if max_priority > 0 else 1.0

        experience = (state, action, np.array([reward]), next_state, done)
        self.sum_tree.add(priority, experience)
        self.current_length = min(self.current_length + 1, self.sum_tree.capacity)

    def sample(self, batch_size, beta=None, eps=None):
        """Draw ``batch_size`` transitions by stratified proportional sampling.

        Returns:
            ``((states, actions, rewards, next_states, dones), tree_indices, IS_weights)``.
            The five batch fields are lists. ``tree_indices`` are the SumTree leaf
            indices to pass back to ``update_priority``. ``IS_weights`` is a float
            array of importance-sampling weights, normalized to a maximum of 1.
        """
        if beta is None:
            beta = BETA
        if eps is None:
            eps = EP

        total = self.sum_tree.total()
        segment = total / batch_size

        batch_idx, batch, priorities = [], [], []
        for i in range(batch_size):
            # One uniform draw per equal-mass segment. Clamp so that rounding of
            # the upper bound never overshoots the total.
            s = min(random.uniform(segment * i, segment * (i + 1)), total)
            idx, p, data = self.sum_tree.get(s)
            priorities.append(p)
            batch.append(data)
            batch_idx.append(idx)

        sampling_probabilities = np.asarray(priorities, dtype=np.float64) / (total + eps)
        IS_weights = np.power(self.sum_tree.n_entries * sampling_probabilities, -beta)
        IS_weights /= IS_weights.max()  # normalize so weights only ever scale updates down

        # Transpose the list of 5-tuples into five parallel lists.
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = (
            list(column) for column in zip(*batch)
        )

        return (state_batch, action_batch, reward_batch, next_state_batch, done_batch), batch_idx, IS_weights

    def update_priority(self, idx, td_error, alpha=None, eps=None):
        """Set the priority of SumTree leaf ``idx`` to ``(|td_error| + eps) ** alpha``.

        ``td_error`` may be a Python/NumPy scalar or a size-1 array.
        """
        if alpha is None:
            alpha = ALPHA
        if eps is None:
            eps = EP
        # .item() turns a size-1 array into a plain float, so the tree stores a scalar.
        error = abs(np.asarray(td_error, dtype=np.float64).item())
        priority = (error + eps) ** alpha  # eps keeps every priority strictly positive
        self.sum_tree.update(idx, priority)

    def __len__(self):
        """Return the number of transitions currently stored (at most ``max_size``)."""
        return self.current_length

In [None]:
# Reproducibility: seed every RNG the notebook draws from.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 32
BUFFER_SIZE = 10000
REPLAY_BUFFER = PrioritizedBuffer(BUFFER_SIZE)
LEARNING_RATE = 5e-4
GAMMA = 0.99                      # discount factor

# Annealing hyper-parameters (linear schedules over the given number of steps).
EPSILON_START = 1.0
EPSILON_END = 0.02
EPSILON_DECAY = 20000

ALPHA = 0.7                       # priority exponent: 0 = uniform replay, 1 = fully proportional

BETA_START = 0.5
BETA_END = 1.0
BETA_ANNEAL = 20000
BETA = BETA_START                 # current IS exponent; re-annealed every step by the training loop

EP = 1e-6                         # keeps priorities > 0 and prevents division by zero

EPISODE_REWARD = 0.0
REWARD_BUFFER = deque([0.0], maxlen=100)   # returns of the last 100 episodes
MIN_REPLAY_SIZE = 100                      # burn-in: transitions collected before learning starts

env_id = "CartPole-v1"
env = gym.make(env_id)
env.action_space.seed(SEED)       # makes env.action_space.sample() reproducible

# Define the device first so both networks live where the training batches are sent.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

online_net = DQN(env, env.action_space.n).to(device)   # Double DQN: online network (trained)
target_net = DQN(env, env.action_space.n).to(device)   # Double DQN: target network (periodically synced)

target_net.load_state_dict(online_net.state_dict())    # start both networks with identical weights

optimizer = torch.optim.Adam(online_net.parameters(), lr=LEARNING_RATE)

In [None]:
state = env.reset()

# Runs until the kernel is interrupted; progress is printed every 1000 steps.
for step in itertools.count():

    # Linear annealing of the exploration rate and of the importance-sampling exponent.
    epsilon = np.interp(step, [0, EPSILON_DECAY], [EPSILON_START, EPSILON_END])
    BETA = np.interp(step, [0, BETA_ANNEAL], [BETA_START, BETA_END])

    # Epsilon-greedy action selection.
    if random.random() <= epsilon:
        action = env.action_space.sample()
    else:
        action = online_net.act(state)

    next_state, reward, done, info = env.step(action)
    # A TimeLimit cut-off is not a real terminal state, so keep bootstrapping from it.
    terminal = done and not info.get('TimeLimit.truncated', False)
    REPLAY_BUFFER.push(state, action, reward, next_state, terminal)

    state = next_state
    EPISODE_REWARD += reward

    if done:
        state = env.reset()
        REWARD_BUFFER.append(EPISODE_REWARD)
        EPISODE_REWARD = 0.0

    # Burn-in: learn only once the buffer holds enough transitions to sample a batch from.
    if len(REPLAY_BUFFER) < max(MIN_REPLAY_SIZE, BATCH_SIZE):
        continue

    # Sample a prioritized batch and convert it to tensors on the training device.
    transitions, idxs, IS_weights = REPLAY_BUFFER.sample(BATCH_SIZE)
    states, actions, rewards, next_states, dones = transitions

    states = torch.as_tensor(np.array(states), dtype=torch.float32, device=device).reshape(BATCH_SIZE, -1)
    actions = torch.as_tensor(np.array(actions), dtype=torch.int64, device=device).reshape(BATCH_SIZE, 1)
    rewards = torch.as_tensor(np.array(rewards), dtype=torch.float32, device=device).reshape(BATCH_SIZE, 1)
    next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=device).reshape(BATCH_SIZE, -1)
    dones = torch.as_tensor(np.array(dones), dtype=torch.float32, device=device).reshape(BATCH_SIZE, 1)
    IS_weights = torch.as_tensor(np.asarray(IS_weights), dtype=torch.float32, device=device).reshape(BATCH_SIZE, 1)

    # Double DQN target: the online net picks the next action, the target net evaluates it.
    # Targets are constants, so no gradients should flow into either network here.
    with torch.no_grad():
        best_next_actions = online_net(next_states).argmax(dim=1, keepdim=True)
        target_q_vals = target_net(next_states).gather(dim=1, index=best_next_actions)
        targets = rewards + GAMMA * (1 - dones) * target_q_vals

    # Importance-weighted squared TD error.
    q_values = online_net(states)
    action_q_values = q_values.gather(dim=1, index=actions)
    td_errors = action_q_values - targets
    loss = (IS_weights * td_errors.pow(2)).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Refresh priorities with |TD error| (one float per sampled leaf); .cpu() makes this CUDA-safe.
    abs_td_errors = td_errors.detach().abs().cpu().numpy().reshape(-1)
    for idx, error in zip(idxs, abs_td_errors):
        REPLAY_BUFFER.update_priority(idx, error)

    if step % 1000 == 0:
        target_net.load_state_dict(online_net.state_dict())

        print()
        print('Step', step)
        print('Avg Reward', np.mean(REWARD_BUFFER))   # mean return over (at most) the last 100 episodes
        print('Loss', loss.item())
        print('BETA', BETA)
        print('ALPHA', ALPHA)


Step 1000
Avg Reward 18.846153846153847
Loss tensor(0.0037, grad_fn=<MeanBackward0>)
BETA 0.525
ALPHA 0.7

Step 2000
Avg Reward 20.418367346938776
Loss tensor(0.0171, grad_fn=<MeanBackward0>)
BETA 0.55
ALPHA 0.7

Step 3000
Avg Reward 23.71
Loss tensor(0.0997, grad_fn=<MeanBackward0>)
BETA 0.575
ALPHA 0.7

Step 4000
Avg Reward 26.17
Loss tensor(0.0984, grad_fn=<MeanBackward0>)
BETA 0.6
ALPHA 0.7

Step 5000
Avg Reward 30.41
Loss tensor(0.1591, grad_fn=<MeanBackward0>)
BETA 0.625
ALPHA 0.7


KeyboardInterrupt: ignored