In [None]:
import sys                                   # Pro pripojeni knihovny Open AI Gym
sys.path.append('/home/xbucha02/libraries')  # Pripojeni knihovny Open AI Gym
import gym
import pylab
import random
import numpy as np
from collections import deque
from keras.layers import Dense
from keras.optimizers import Adam
from keras.models import Sequential

# Number of training episodes the main loop runs.
EPISODES = 4000


class DQNAgent:
    """Deep Q-Network agent with a replay memory and a separate target network.

    The online network (``self.model``) is trained on mini-batches sampled
    from ``self.memory``; the target network (``self.target_model``) supplies
    the bootstrap value for the next state and is synchronised with the
    online network only when ``update_target_model`` is called.
    """

    def __init__(self, state_size, action_size):
        """Create the agent.

        Args:
            state_size: number of input features of the Q-network.
            action_size: number of outputs of the Q-network (one Q-value
                per selectable action).
        """
        # Flag intended to toggle rendering; it is not read inside this class.
        self.render = True

        # Sizes of the state and action spaces, used to build the networks.
        self.state_size = state_size
        self.action_size = action_size

        # DQN hyper-parameters.
        self.discount_factor = 0.99
        self.learning_rate = 0.001
        self.epsilon = 1.0
        self.epsilon_min = 0.005
        # Linear decay: one decrement per stored transition, so epsilon
        # reaches epsilon_min after roughly 50000 calls to replay_memory().
        self.epsilon_decay = (self.epsilon - self.epsilon_min) / 50000
        self.batch_size = 64
        # No training happens until this many transitions are stored.
        self.train_start = 1000
        # Replay memory; the oldest transitions are dropped beyond maxlen.
        self.memory = deque(maxlen=10000)

        # Build the online network and the target network.
        self.model = self.build_model()
        self.target_model = self.build_model()
        # Both networks must start from identical weights.
        self.update_target_model()

    def build_model(self):
        """Build and compile the Q-network.

        The network takes a state as input and outputs one Q-value per
        action (two hidden ReLU layers, linear output, MSE loss, Adam).

        Returns:
            The compiled Keras ``Sequential`` model.
        """
        model = Sequential()
        model.add(Dense(32, input_dim=self.state_size, activation='relu', kernel_initializer='he_uniform'))
        model.add(Dense(16, activation='relu', kernel_initializer='he_uniform'))
        model.add(Dense(self.action_size, activation='linear', kernel_initializer='he_uniform'))
        model.summary()
        model.compile(loss='mse', optimizer=Adam(lr=self.learning_rate))
        return model

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.set_weights(self.model.get_weights())

    def get_action(self, state):
        """Choose an action with an epsilon-greedy policy on the online network.

        Args:
            state: array accepted by ``self.model.predict``; the Q-values of
                its first row are used.

        Returns:
            A random action index with probability ``epsilon``, otherwise
            the index of the largest predicted Q-value.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        q_value = self.model.predict(state)
        return np.argmax(q_value[0])

    def replay_memory(self, state, action, reward, next_state, done):
        """Store one transition <s, a, r, s'> and decay epsilon by one step.

        Args:
            state: state before the step.
            action: action that was executed; the value 2 is stored as 1 so
                the stored action can index the network outputs
                (presumably because the caller passes the environment's
                action id rather than the network output index).
            reward: reward received for the step.
            next_state: state after the step.
            done: whether the episode ended with this step.
        """
        if action == 2:
            action = 1
        self.memory.append((state, action, reward, next_state, done))
        if self.epsilon > self.epsilon_min:
            self.epsilon -= self.epsilon_decay

    def train_replay(self):
        """Train the online network on one random mini-batch from memory.

        Does nothing until at least ``train_start`` transitions are stored.
        For every sampled transition the target for the taken action is
        ``reward`` if the episode ended, otherwise
        ``reward + discount_factor * max_a Q_target(next_state, a)``; the
        targets of the other actions are the online network's own
        predictions, so they contribute no error.
        """
        if len(self.memory) < self.train_start:
            return
        batch_size = min(self.batch_size, len(self.memory))
        mini_batch = random.sample(self.memory, batch_size)

        states, actions, rewards, next_states, dones = zip(*mini_batch)
        # Stack the per-transition states into (batch_size, state_size)
        # matrices so each network is evaluated once per batch instead of
        # once per sample.
        update_input = np.reshape(states, (batch_size, self.state_size))
        next_input = np.reshape(next_states, (batch_size, self.state_size))
        actions = np.asarray(actions, dtype=int)
        rewards = np.asarray(rewards, dtype=float)
        dones = np.asarray(dones, dtype=bool)

        # Start from the current predictions (copied, so the network's
        # output buffer is never modified in place).
        update_target = np.array(self.model.predict(update_input), dtype=float)
        # As in Q-learning, take the maximum Q-value at s', but read it
        # from the target network.
        next_q_max = np.amax(self.target_model.predict(next_input), axis=1)
        update_target[np.arange(batch_size), actions] = np.where(
            dones, rewards, rewards + self.discount_factor * next_q_max)

        # Update the online network with the whole mini-batch at once.
        self.model.fit(update_input, update_target, batch_size=batch_size, epochs=1, verbose=0)

    def load_model(self, name):
        """Load previously saved weights into the online network."""
        self.model.load_weights(name)

    def save_model(self, name):
        """Save the online network's weights to ``name``."""
        self.model.save_weights(name)


if __name__ == "__main__":
    # Train the DQN agent on MountainCar-v0.
    env = gym.make('MountainCar-v0')
    # Take the state size from the environment.
    state_size = env.observation_space.shape[0]
    # The agent only chooses between two actions (mapped below to the
    # environment actions 0 and 2), so the size is hard-coded instead of
    # being read from the environment.
    #action_size = env.action_space.n
    action_size = 2
    # Create the DQN agent.
    agent = DQNAgent(state_size, action_size)
#    agent.load_model("./save_model/MountainCar_DQN.h5")
    scores, episodes = [], []

    for e in range(EPISODES):
        done = False
        score = 0
        state = env.reset()
        state = np.reshape(state, [1, state_size])
        print(state)

        # Action actually sent to the environment. The agent's outputs 0/1
        # are translated to environment actions 0/2 so that environment
        # action 1 is never taken.
        fake_action = 0

        # Counter used to repeat the same action for 4 steps.
        # NOTE(review): the first decision is made on the 4th step, so the
        # first three steps of every episode use the default action 0.
        action_count = 0

        while not done:
            #if agent.render:
                #env.render()

            # Every 4th step, ask the agent for a new action.
            action_count = action_count + 1

            if action_count == 4:
                action = agent.get_action(state)
                action_count = 0

                if action == 0:
                    fake_action = 0
                elif action == 1:
                    fake_action = 2

            # Advance the environment one step with the current action.
            next_state, reward, done, info = env.step(fake_action)
            next_state = np.reshape(next_state, [1, state_size])
            # Optional penalty for the action that ended the episode.
            #reward = reward if not done else -100

            # Store <s, a, r, s'> in the replay memory.
            agent.replay_memory(state, fake_action, reward, next_state, done)
            # Train on a mini-batch at every time step.
            agent.train_replay()
            score += reward
            state = next_state

            if done:
                # The environment is reset at the top of the episode loop,
                # so no extra reset is needed here.
                # Copy the online network into the target network once per
                # episode.
                agent.update_target_model()

                # Record the episode score (used by the optional plot).
                scores.append(score)
                episodes.append(e)
                #pylab.plot(episodes, scores, 'b')
                #pylab.savefig("./save_graph/MountainCar_DQN.png")
                print("episode:", e, "  score:", score, "  memory length:", len(agent.memory),
                      "  epsilon:", agent.epsilon)

        # Save the online network's weights every 50 episodes.
        if e % 50 == 0:
            agent.save_model("./MountainCar_DQN.h5")

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_19 (Dense)             (None, 32)                96        
_________________________________________________________________
dense_20 (Dense)             (None, 16)                528       
_________________________________________________________________
dense_21 (Dense)             (None, 2)                 34        
Total params: 658
Trainable params: 658
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_22 (Dense)             (None, 32)                96        
_________________________________________________________________
dense_23 (Dense)             (None, 16)                528       
_________________________________________________________________
dense_24 (De

('episode:', 54, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.7811000000004222)
[[-0.53214119  0.        ]]
('episode:', 55, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.7771200000004299)
[[-0.40884064  0.        ]]
('episode:', 56, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.7731400000004376)
[[-0.5217874  0.       ]]
('episode:', 57, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.7691600000004453)
[[-0.57531798  0.        ]]
('episode:', 58, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.7651800000004529)
[[-0.59851152  0.        ]]
('episode:', 59, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.7612000000004606)
[[-0.49021063  0.        ]]
('episode:', 60, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.7572200000004683)
[[-0.5312937  0.       ]]
('episode:', 61, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.753240000000476)
[[-0.47410541  0.        

('episode:', 119, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.5224000000009212)
[[-0.56492163  0.        ]]
('episode:', 120, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.5184200000009289)
[[-0.42122326  0.        ]]
('episode:', 121, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.5144400000009366)
[[-0.43625846  0.        ]]
('episode:', 122, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.5104600000009443)
[[-0.53485142  0.        ]]
('episode:', 123, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.506480000000952)
[[-0.43803852  0.        ]]
('episode:', 124, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.5025000000009596)
[[-0.47603089  0.        ]]
('episode:', 125, '  score:', -179.0, '  memory length:', 10000, '  epsilon:', 0.4989379000009635)
[[-0.4174807  0.       ]]
('episode:', 126, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.49495790000096007)
[[-0.58206984 

('episode:', 184, '  score:', -137.0, '  memory length:', 10000, '  epsilon:', 0.2733913000007694)
[[-0.58546991  0.        ]]
('episode:', 185, '  score:', -160.0, '  memory length:', 10000, '  epsilon:', 0.27020730000076665)
[[-0.40841  0.     ]]
('episode:', 186, '  score:', -170.0, '  memory length:', 10000, '  epsilon:', 0.26682430000076374)
[[-0.54328018  0.        ]]
('episode:', 187, '  score:', -153.0, '  memory length:', 10000, '  epsilon:', 0.2637796000007611)
[[-0.53933161  0.        ]]
('episode:', 188, '  score:', -122.0, '  memory length:', 10000, '  epsilon:', 0.261351800000759)
[[-0.55871676  0.        ]]
('episode:', 189, '  score:', -137.0, '  memory length:', 10000, '  epsilon:', 0.2586255000007567)
[[-0.50622793  0.        ]]
('episode:', 190, '  score:', -150.0, '  memory length:', 10000, '  epsilon:', 0.2556405000007541)
[[-0.49228777  0.        ]]
('episode:', 191, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.2516605000007507)
[[-0.48213633  0.

('episode:', 249, '  score:', -170.0, '  memory length:', 10000, '  epsilon:', 0.0715456000008073)
[[-0.56195568  0.        ]]
('episode:', 250, '  score:', -158.0, '  memory length:', 10000, '  epsilon:', 0.06840140000080679)
[[-0.51031912  0.        ]]
('episode:', 251, '  score:', -168.0, '  memory length:', 10000, '  epsilon:', 0.06505820000080624)
[[-0.55849925  0.        ]]
('episode:', 252, '  score:', -158.0, '  memory length:', 10000, '  epsilon:', 0.06191400000080573)
[[-0.58906754  0.        ]]
('episode:', 253, '  score:', -149.0, '  memory length:', 10000, '  epsilon:', 0.05894890000080524)
[[-0.50164251  0.        ]]
('episode:', 254, '  score:', -154.0, '  memory length:', 10000, '  epsilon:', 0.05588430000080474)
[[-0.47912371  0.        ]]
('episode:', 255, '  score:', -167.0, '  memory length:', 10000, '  epsilon:', 0.0525610000008042)
[[-0.54053361  0.        ]]
('episode:', 256, '  score:', -157.0, '  memory length:', 10000, '  epsilon:', 0.04943670000080369)
[[-0.4

('episode:', 313, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.46604185  0.        ]]
('episode:', 314, '  score:', -97.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.56028426  0.        ]]
('episode:', 315, '  score:', -152.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.50993373  0.        ]]
('episode:', 316, '  score:', -157.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.5291352  0.       ]]
('episode:', 317, '  score:', -155.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.54571383  0.        ]]
('episode:', 318, '  score:', -156.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.52783762  0.        ]]
('episode:', 319, '  score:', -156.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.43624555  0.        ]]
('episode:', 320, '  score:', -88.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)


('episode:', 377, '  score:', -97.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.57783432  0.        ]]
('episode:', 378, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.56206404  0.        ]]
('episode:', 379, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.49217196  0.        ]]
('episode:', 380, '  score:', -173.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.51573059  0.        ]]
('episode:', 381, '  score:', -160.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.48549031  0.        ]]
('episode:', 382, '  score:', -161.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.40078454  0.        ]]
('episode:', 383, '  score:', -83.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.54670151  0.        ]]
('episode:', 384, '  score:', -156.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017

('episode:', 441, '  score:', -156.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.49199619  0.        ]]
('episode:', 442, '  score:', -148.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.59689812  0.        ]]
('episode:', 443, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.57981729  0.        ]]
('episode:', 444, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.46520012  0.        ]]
('episode:', 445, '  score:', -112.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.57470856  0.        ]]
('episode:', 446, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.59989785  0.        ]]
('episode:', 447, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.48450395  0.        ]]
('episode:', 448, '  score:', -178.0, '  memory length:', 10000, '  epsilon:', 0.0049801000008010

('episode:', 505, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.49666974  0.        ]]
('episode:', 506, '  score:', -165.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.55744294  0.        ]]
('episode:', 507, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.55273495  0.        ]]
('episode:', 508, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.50739382  0.        ]]
('episode:', 509, '  score:', -158.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.46668899  0.        ]]
('episode:', 510, '  score:', -118.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.56718423  0.        ]]
('episode:', 511, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.47755932  0.        ]]
('episode:', 512, '  score:', -160.0, '  memory length:', 10000, '  epsilon:', 0.0049801000008010

('episode:', 569, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.58230103  0.        ]]
('episode:', 570, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.47238704  0.        ]]
('episode:', 571, '  score:', -147.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.55369524  0.        ]]
('episode:', 572, '  score:', -151.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.52969814  0.        ]]
('episode:', 573, '  score:', -154.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.42894669  0.        ]]
('episode:', 574, '  score:', -151.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.56474974  0.        ]]
('episode:', 575, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.44417775  0.        ]]
('episode:', 576, '  score:', -157.0, '  memory length:', 10000, '  epsilon:', 0.0049801000008010

('episode:', 633, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.43320902  0.        ]]
('episode:', 634, '  score:', -87.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.5746954  0.       ]]
('episode:', 635, '  score:', -200.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.5022244  0.       ]]
('episode:', 636, '  score:', -162.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.48893867  0.        ]]
('episode:', 637, '  score:', -179.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.46544159  0.        ]]
('episode:', 638, '  score:', -97.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.44271793  0.        ]]
('episode:', 639, '  score:', -177.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[-0.43873391  0.        ]]
('episode:', 640, '  score:', -134.0, '  memory length:', 10000, '  epsilon:', 0.004980100000801017)
[[