In [2]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from DQN import qnet
from ReplayMemory import ReplayMemory

In [3]:
# Training hyperparameters.
REPLAY_MEMORY = 10_000  # passed to ReplayMemory(...) by the agent; presumably its capacity
BATCH_SIZE = 128        # transitions sampled per learning step
LR = 0.0001             # Adam learning rate
GAMMA = 0.99            # discount applied to the next-state return in the C51 target

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Categorical support: N_ATOM evenly spaced return values in [V_MIN, V_MAX].
N_ATOM = 51  # 51 atoms is where the name "C51" comes from
V_MIN = -10.
V_MAX = 10.
V_STEP = (V_MAX - V_MIN) / (N_ATOM - 1)  # spacing between adjacent atoms
# Derive the atom count from N_ATOM (previously a hard-coded 51) so the support
# cannot drift out of sync with V_STEP or the network's output size.
V_RANGE = np.linspace(V_MIN, V_MAX, N_ATOM)

In [5]:
# Epsilon-greedy exploration settings; the decay schedule that uses them is
# presumably defined in a later cell.
max_exploration_rate = 1                  # starting epsilon
exploration_rate = max_exploration_rate   # current epsilon, initialised to the maximum
min_exploration_rate = 0.01               # presumably the floor epsilon decays towards
exploration_decay_rate = 0.01             # presumably the per-step/episode decay rate

In [6]:
# Intended number of learning steps between copies of the online network into
# the target network.
TARGET_UPDATE_FREQUENCY = 20
# Alias under the shorter name the C51 class uses (its TARGET_UPDATE_FREQ), so
# an unqualified lookup of TARGET_UPDATE_FREQ resolves instead of raising NameError.
TARGET_UPDATE_FREQ = TARGET_UPDATE_FREQUENCY

In [8]:
class C51(nn.Module):
    """Categorical DQN ("C51") agent.

    Instead of one Q-value per action, the networks predict a probability
    distribution over N_ATOM fixed return values (the support V_RANGE, evenly
    spaced in [VMIN, VMAX]). The expected Q-value of an action is the
    probability-weighted sum of the support.

    Hyperparameters are class attributes. To change them, override them on a
    subclass (or on C51) before constructing an agent; __init__ reads
    REPLAY_MEMORY and LR.

    Relies on the notebook-level ``device``, ``qnet`` (from DQN) and
    ``ReplayMemory`` (from ReplayMemory).
    """

    N_ATOM = 51
    REPLAY_MEMORY = 10000
    LR = 1e-4
    GAMMA = 0.99
    VMIN = -10.
    VMAX = 10.
    TARGET_UPDATE_FREQ = 20
    BATCH_SIZE = 128
    # Derived from the class's own VMIN/VMAX/N_ATOM rather than notebook globals.
    V_STEP = (VMAX - VMIN) / (N_ATOM - 1)
    V_RANGE = np.linspace(VMIN, VMAX, N_ATOM)

    def __init__(self, inputs, outputs):
        """Build the online/target networks, replay buffer and optimizer.

        Args:
            inputs: state size (stored as ``state_size``).
            outputs: number of actions; passed to ``qnet`` for both networks.
        """
        super().__init__()

        self.state_size = inputs
        self.action_size = outputs
        self.device = device

        # pred_net is trained; target_net is a periodically synced copy used to
        # build the bootstrap target. Both presumably output per-atom
        # probabilities shaped (batch, actions, atoms) -- confirm in DQN.qnet.
        self.target_net, self.pred_net = qnet(outputs), qnet(outputs)
        self.update_target()
        self.pred_net.to(self.device)
        self.target_net.to(self.device)

        self.memory_counter = 0
        self.learn_step = 0

        self.replay_buffer = ReplayMemory(self.REPLAY_MEMORY)
        self.optimizer = torch.optim.Adam(self.pred_net.parameters(), lr=self.LR)

        # Support as a float tensor, broadcast against network outputs via view(1, 1, -1).
        self.value_range = torch.as_tensor(self.V_RANGE, dtype=torch.float32, device=self.device)

    def store_transition(self, s, a, r, n_s, d):
        """Push one (state, action, reward, next_state, done) transition into the replay buffer."""
        self.memory_counter += 1
        self.replay_buffer.push(s, a, r, n_s, d)

    def pick_action(self, x, ep):
        """Choose actions epsilon-greedily for a batch of states.

        Args:
            x: batch of states (array-like); assumed to have a leading batch dimension.
            ep: exploration rate; random actions are returned with probability ``ep``.

        Returns:
            numpy int array of shape (batch,) with one action per state.
        """
        x = torch.as_tensor(x, dtype=torch.float32, device=self.device)

        if np.random.uniform() > ep:
            with torch.no_grad():  # acting only; no gradients needed
                dist = self.pred_net(x)
                # Expected Q per action: sum over atoms of p(atom) * atom value.
                q_values = (dist * self.value_range.view(1, 1, -1)).sum(dim=2)
                # argmax per state (over actions), not over the flattened batch.
                return q_values.argmax(dim=1).cpu().numpy()

        return np.random.randint(0, self.action_size, x.size(0))

    @staticmethod
    def _project_distribution(next_dist, rewards, dones, support, gamma):
        """Project the Bellman-shifted distribution r + gamma * (1 - done) * z onto ``support``.

        Args:
            next_dist: (batch, n_atom) probabilities of the chosen next action.
            rewards: (batch,) rewards.
            dones: (batch,) terminal flags (truthy = terminal).
            support: (n_atom,) evenly spaced atom values.
            gamma: discount factor.

        Returns:
            float64 array of shape (batch, n_atom); each row keeps the mass of
            the corresponding ``next_dist`` row.
        """
        next_dist = np.asarray(next_dist, dtype=np.float64)
        support = np.asarray(support, dtype=np.float64)
        rewards = np.asarray(rewards, dtype=np.float64).reshape(-1, 1)
        not_done = 1.0 - np.asarray(dones, dtype=np.float64).reshape(-1, 1)

        n_atom = support.shape[0]
        v_min, v_max = support[0], support[-1]
        delta_z = (v_max - v_min) / (n_atom - 1)

        shifted = np.clip(rewards + gamma * not_done * support[None, :], v_min, v_max)
        # Fractional atom index; clipped to guard against float round-off at the edges.
        pos = np.clip((shifted - v_min) / delta_z, 0, n_atom - 1)
        lower = np.floor(pos).astype(np.int64)
        upper = np.ceil(pos).astype(np.int64)
        # When pos lands exactly on an atom both fractional weights are 0;
        # give the whole mass to that atom instead of dropping it.
        on_atom = lower == upper

        projected = np.zeros_like(next_dist)
        rows = np.broadcast_to(np.arange(next_dist.shape[0])[:, None], lower.shape)
        # add.at accumulates correctly when several atoms map to the same index.
        np.add.at(projected, (rows, lower), next_dist * (upper - pos + on_atom))
        np.add.at(projected, (rows, upper), next_dist * (pos - lower))
        return projected

    def learn(self):
        """Run one C51 gradient step on a minibatch sampled from the replay buffer.

        Every TARGET_UPDATE_FREQ calls, the target network is first synced with
        the online network.

        Returns:
            The scalar training loss as a Python float.
        """
        self.learn_step += 1
        if self.learn_step % self.TARGET_UPDATE_FREQ == 0:
            self.update_target()

        minibatch = self.replay_buffer.sample(self.BATCH_SIZE)
        # Transitions are indexed as (state, action, reward, next_state, done),
        # the order store_transition pushes them in.
        states = torch.as_tensor(np.array([t[0] for t in minibatch]),
                                 dtype=torch.float32, device=self.device)
        # view(-1) accepts actions stored either as scalars or as 1-element arrays.
        actions = torch.as_tensor(np.array([t[1] for t in minibatch]),
                                  dtype=torch.long, device=self.device).view(-1)
        rewards = np.array([t[2] for t in minibatch], dtype=np.float64).reshape(-1)
        next_states = torch.as_tensor(np.array([t[3] for t in minibatch]),
                                      dtype=torch.float32, device=self.device)
        dones = np.array([t[4] for t in minibatch], dtype=np.float64).reshape(-1)
        batch_idx = torch.arange(states.size(0), device=self.device)

        # Predicted distribution of the action actually taken: (batch, atoms).
        q_eval = self.pred_net(states)[batch_idx, actions]

        with torch.no_grad():
            q_next = self.target_net(next_states)
            # Greedy next action under the target network's expected Q-values.
            best_next = (q_next * self.value_range.view(1, 1, -1)).sum(dim=2).argmax(dim=1)
            q_next = q_next[batch_idx, best_next]

        q_target = self._project_distribution(
            q_next.cpu().numpy(), rewards, dones, self.V_RANGE, self.GAMMA)
        q_target = torch.as_tensor(q_target, dtype=torch.float32, device=self.device)

        # Cross-entropy between the projected target and the prediction, per sample.
        per_sample = -(q_target * torch.log(q_eval + 1e-8)).sum(dim=1)
        # Uniform importance weights; presumably a hook for prioritized replay.
        weights = torch.ones_like(per_sample)
        loss = (weights * per_sample).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def update_target(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.pred_net.state_dict())