In [2]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from DQN import qnet
from ReplayMemory import ReplayMemory

In [3]:
# Training hyperparameters (tune here).
# NOTE(review): the meanings below are inferred from the names -- confirm
# against the training loop.
BATCH_SIZE = 128        # presumably transitions per optimisation step
GAMMA = 0.99            # presumably the reward discount factor
LR = 1e-4               # optimiser learning rate
REPLAY_MEMORY = 10_000  # replay-buffer capacity

In [4]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Fixed support of the value distribution: N_ATOM evenly spaced points
# ("atoms") between V_MIN and V_MAX, V_STEP apart.
N_ATOM = 51
V_MIN = -10.
V_MAX = 10.
V_STEP = (V_MAX - V_MIN) / (N_ATOM - 1)
# Built from N_ATOM (not a literal 51) so V_RANGE and V_STEP cannot drift apart.
V_RANGE = np.linspace(V_MIN, V_MAX, N_ATOM)  # 51 atoms: this is why it is called C51

In [5]:
# Exploration schedule parameters.
# NOTE(review): presumably an epsilon-greedy schedule whose current value is
# passed to C51.pick_action as `ep` -- confirm against the training loop.
max_exploration_rate = 1
min_exploration_rate = 0.01
exploration_decay_rate = 0.01
exploration_rate = max_exploration_rate  # start at the maximum rate

In [6]:
# Module-level counterpart of C51.TARGET_UPDATE_FREQ (same value, 20).
# NOTE(review): presumably the number of learn steps between target-network
# syncs -- confirm which of the two constants is meant to be authoritative.
TARGET_UPDATE_FREQUENCY = 20

In [8]:
class C51(nn.Module):
    """Skeleton of a categorical-DQN (C51) agent.

    The agent owns two networks built by ``qnet``: ``pred_net`` (the one that
    is optimised) and ``target_net`` (a periodically refreshed copy). Each
    network emits ``action_size * N_ATOM`` numbers, read as one distribution
    over the ``N_ATOM`` support atoms per action.

    All hyperparameters are taken from the class attributes below, so
    subclassing or overriding them on the class actually takes effect.

    NOTE: ``learn`` is still a stub -- only the periodic target-network sync
    is implemented.
    """

    N_ATOM = 51               # number of support atoms per action
    REPLAY_MEMORY = 10000     # replay-buffer capacity
    LR = 1e-4                 # Adam learning rate
    VMIN = -10.               # smallest value on the support
    VMAX = 10.                # largest value on the support
    TARGET_UPDATE_FREQ = 20   # learn() calls between target-network syncs

    # Support of the value distribution, built from the class's own bounds so
    # the class does not depend on same-valued module-level globals.
    V_RANGE = np.linspace(VMIN, VMAX, N_ATOM)

    def __init__(self, inputs, outputs):
        """Build both networks, the replay buffer and the optimiser.

        Args:
            inputs: size of a state vector, forwarded to ``qnet``.
            outputs: number of discrete actions.
        """
        super(C51, self).__init__()

        self.state_size = inputs
        self.action_size = outputs

        # One output per (action, atom) pair. target_net is constructed first
        # to keep the original construction (and RNG consumption) order.
        self.target_net = qnet(inputs, outputs * self.N_ATOM)
        self.pred_net = qnet(inputs, outputs * self.N_ATOM)
        self.update_target()

        self.pred_net.to(device)    # outputs are treated as probability values
        self.target_net.to(device)  # outputs are treated as probability values

        self.memory_counter = 0
        self.learn_step = 0

        self.replay_buffer = ReplayMemory(self.REPLAY_MEMORY)

        self.optimizer = torch.optim.Adam(self.pred_net.parameters(), lr=self.LR)

        # Atom values as a float32 tensor on the same device as the networks.
        self.value_range = torch.tensor(self.V_RANGE, dtype=torch.float32).to(device)

    def store_transition(self, s, a, r, n_s, d):
        """Push one transition into the replay buffer and count it.

        Args:
            s, a, r, n_s, d: forwarded unchanged to ``ReplayMemory.push``
                (presumably state, action, reward, next state, done flag).
        """
        self.memory_counter += 1
        self.replay_buffer.push(s, a, r, n_s, d)

    def pick_action(self, x, ep):
        """Choose an action with an epsilon-greedy rule.

        Args:
            x: a single state (array-like or tensor) accepted by ``pred_net``.
            ep: exploration threshold; a uniform draw in [0, 1) that is not
                greater than ``ep`` triggers a uniformly random action.

        Returns:
            int: index of the chosen action, in ``[0, action_size)``.
        """
        if np.random.uniform() > ep:
            # Keep the state on the same device as the support tensor.
            state = torch.as_tensor(
                x, dtype=torch.float32, device=self.value_range.device
            )
            with torch.no_grad():  # inference only, no graph needed
                # (action_size, N_ATOM): one distribution per action.
                # NOTE(review): assumes pred_net already outputs probabilities
                # for a single state -- apply a softmax here if it emits logits.
                dist = self.pred_net(state).view(self.action_size, -1)
                # Expected return per action: sum_i p_i * z_i.
                action_value = torch.sum(dist * self.value_range, dim=1)
            return int(torch.argmax(action_value).item())

        # Explore: uniformly random action.
        return int(np.random.randint(self.action_size))

    def learn(self):
        """Advance the learn-step counter and sync the target net when due.

        The target network is refreshed every ``TARGET_UPDATE_FREQ`` calls.
        TODO: the distributional Bellman update itself is not implemented yet.
        """
        self.learn_step += 1
        # Class attributes are not visible as bare names inside a method;
        # they must be reached through self (or the class).
        if self.learn_step % self.TARGET_UPDATE_FREQ == 0:
            self.update_target()

    def update_target(self):
        """Copy the prediction network's weights into the target network."""
        self.target_net.load_state_dict(self.pred_net.state_dict())