# In [None]:
# After 500 time steps the game automatically terminates.

import time
import random
import heapq as hp
import gym
import numpy as np
from keras import backend as K
import keras
from collections import deque
from keras.models import Sequential, load_model, Model
from keras.layers.wrappers import TimeDistributed
from keras.layers import Dense, Dropout, Flatten, merge, Input, Lambda, merge, Activation, Embedding
from keras.optimizers import SGD, Adam, rmsprop


import tensorflow as tf
# Clear TensorFlow's default graph (useful when this notebook cell is re-run).
tf.reset_default_graph()

# Seed both RNGs the agent draws from: numpy (the epsilon test in
# `DQNAgent.act`) and the stdlib `random` module (the exploratory action
# itself, `random.randrange`).
# NOTE(review): TensorFlow/Keras weight initialisation is not seeded here, so
# runs are still not bit-for-bit reproducible.
np.random.seed(10)
random.seed(10)
EPISODES = 1000  # number of training episodes


class DQNAgent:
    """Quantile-regression DQN agent with a greedy prioritised replay store.

    The online network (``model``) outputs, for every action, ``num_atoms``
    quantile estimates of the return at the fractions ``tau``. A second copy,
    ``pev_model``, is the target network used for bootstrapping. It starts as
    an exact copy of ``model``, and the training loop refreshes it.

    Transitions live in ``memory`` (id -> transition). Their ordering lives in
    ``pqt``, a min-heap of ``(-|TD error|, id)``, so ``replay`` trains on the
    largest-error transitions first. Each transition is consumed once when it
    is replayed.
    """

    def __init__(self, state_size, action_size, batch_size, epsilon=1.0, epsilon_decay=0.98, gamma=0.99, num_atoms=8, min_rd=-10, max_rd=500):
        """
        Args:
            state_size: length of the flat observation vector.
            action_size: number of discrete actions (one output head each).
            batch_size: stored on the agent; ``replay`` takes its own count.
            epsilon: initial exploration probability used by ``act``.
            epsilon_decay: multiplicative decay applied after each ``replay``.
            gamma: discount factor.
            num_atoms: number of quantiles predicted per action.
            min_rd, max_rd: not used by this class (leftovers from a C51-style
                fixed support); kept for signature compatibility.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.batch_size = batch_size
        self.memory = {}     # transition id -> (state, action, reward, next_state, done)
        self.pqt = []        # min-heap of (-|TD error|, transition id)
        self._next_id = 0    # unique, increasing id; also breaks priority ties (FIFO)
        self.gamma = gamma   # discount rate
        self.epsilon = epsilon  # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = epsilon_decay
        self.learning_rate = 0.0001
        self.num_atoms = num_atoms
        self.sep = 1 / self.num_atoms  # not used within this class
        # Quantile midpoints tau_i = (2i + 1) / (2N), i = 0..N-1.
        self.tau = np.array([(2 * i + 1) / (2 * self.num_atoms) for i in range(self.num_atoms)],
                            dtype=np.float32)
        self.model = self._build_model()
        self.pev_model = self._build_model()
        # Start the target network from the same weights as the online network
        # rather than from an unrelated random initialisation.
        self.pev_model.set_weights(self.model.get_weights())

    def _EDM_loss(self, target, predicted):
        """Quantile Huber loss between target atoms and predicted atoms.

        The reshape/tile below assumes a batch of exactly one sample, which is
        how ``replay`` calls ``fit``.
        """
        # Row i repeats target atom i; each row of the second tile is the full
        # predicted vector, so element (i, j) pairs target i with prediction j.
        target_tile = tf.tile(tf.reshape(target, [self.num_atoms, 1]), [1, self.num_atoms])
        predicted_tile = tf.tile(predicted, [self.num_atoms, 1])
        huber = tf.losses.huber_loss(target_tile, predicted_tile, reduction=tf.losses.Reduction.NONE)
        # u = target - prediction. The quantile-regression weight is
        # |tau_j - 1{u < 0}|: tau_j when the prediction is below the target and
        # (1 - tau_j) when it is above. tau broadcasts along the prediction axis.
        u = target_tile - predicted_tile
        tau = self.tau
        weighted = tf.where(tf.less(u, 0.0), (1.0 - tau) * huber, tau * huber)
        return tf.reduce_mean(tf.reduce_sum(tf.reduce_mean(weighted, axis=1), axis=0))

    def _build_model(self):
        """Build and compile an MLP with one ``num_atoms``-wide linear head per action."""
        state_input = Input(shape=(self.state_size,))
        hidden = state_input
        for _ in range(3):
            hidden = Dense(32, activation='relu')(hidden)

        heads = [Dense(self.num_atoms, activation='linear')(hidden)
                 for _ in range(self.action_size)]

        model = Model(inputs=state_input, outputs=heads)
        # The same loss is applied to every head.
        model.compile(loss=self._EDM_loss, optimizer=Adam(lr=self.learning_rate))
        return model

    @staticmethod
    def _q_values(quantile_outputs):
        """Return the expected return (mean over atoms) of each action head.

        ``quantile_outputs`` is the per-head list returned by ``predict`` for a
        single state; element ``[0]`` of each head is that state's atom vector.
        """
        return np.array([np.mean(head[0]) for head in quantile_outputs])

    def _push_transition(self, priority, transition):
        """Store ``transition`` under ``priority`` (smaller values pop first)."""
        transition_id = self._next_id
        self._next_id += 1
        self.memory[transition_id] = transition
        hp.heappush(self.pqt, (priority, transition_id))

    def _pop_transition(self):
        """Remove and return the stored transition with the smallest priority."""
        _, transition_id = hp.heappop(self.pqt)
        return self.memory.pop(transition_id)

    def remember(self, eps, state, action, reward, next_state, done):
        """Store one transition, prioritised by its absolute TD error.

        Args:
            eps: episode index; accepted for compatibility, not used.
            state, next_state: observations, shaped (1, state_size) by the
                training loop.
            action: index of the action taken.
            reward: scalar reward.
            done: whether the episode ended on this step.
        """
        # Expected return of the taken action = mean of its predicted quantiles.
        predicted_q = np.mean(self.model.predict(state)[action][0])
        if done:
            # Terminal step: nothing to bootstrap from (same rule as ``replay``).
            td_target = reward
        else:
            # Greedy bootstrap from the target network, on the same (mean)
            # scale as ``predicted_q``.
            next_q = self._q_values(self.pev_model.predict(next_state))
            td_target = reward + self.gamma * np.max(next_q)
        priority = -abs(float(predicted_q - td_target))  # negated: heapq pops the smallest
        self._push_transition(priority, (state, action, reward, next_state, done))

    def act(self, state):
        """Epsilon-greedy action: random with probability ``epsilon``, else argmax of expected return."""
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        return np.argmax(self._q_values(self.model.predict(state)))

    def replay(self, batch_size):
        """Fit the online network on up to ``batch_size`` highest-priority transitions.

        Replayed transitions are removed from the store. Epsilon is decayed
        once per call.
        """
        for _ in range(min(batch_size, len(self.pqt))):
            state, action, reward, next_state, done = self._pop_transition()
            if done:
                target = np.ones(self.num_atoms) * reward
            else:
                # Distributional Bellman target: shift/scale the greedy next
                # action's quantiles from the target network.
                next_quantiles = self.pev_model.predict(next_state)
                best_next = int(np.argmax(self._q_values(next_quantiles)))
                target = reward + self.gamma * next_quantiles[best_next][0]

            # Only the taken action's head gets a new target; the others are
            # set to their current prediction.
            # NOTE(review): with a quantile loss, a self-target is not
            # exactly zero-loss, so the other heads still receive small updates.
            target_f = self.model.predict(state)
            target_f[action][0] = target
            self.model.fit(state, target_f, epochs=1, verbose=0)

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        """Load online-network weights from the file ``name``."""
        self.model.load_weights(name)

    def save(self, name):
        """Save online-network weights to the file ``name``."""
        self.model.save_weights(name)


if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    time_start = time.time()
    batch_size = 32
    agent = DQNAgent(state_size, action_size, batch_size)

    vally = 0   # count of episodes that reached the 500-step limit
    rslt = []   # last step index reached in each episode
    ep = []     # indices of the episodes that reached the 500-step limit

    try:
        for e in range(EPISODES):
            done = False
            state = env.reset()  # initial observation
            # Add a batch axis, e.g. [1, 12, 2, 3] -> [[1, 12, 2, 3]].
            state = np.reshape(state, [1, state_size])
            for t in range(500):
                env.render()
                action = agent.act(state)
                next_state, reward, done, _ = env.step(action)
                # Sparse reward: -1 only when the pole falls before the step limit.
                reward = 0 if t == 499 or not done else -1
                next_state = np.reshape(next_state, [1, state_size])
                agent.remember(e, state, action, reward, next_state, done)
                state = next_state

                if done and t != 499:
                    # Episode failed: refresh the target network and report.
                    agent.pev_model.set_weights(agent.model.get_weights())
                    print("episode: {}/{}, score: {}, e: {:.2}"
                          .format(e, EPISODES, t, agent.epsilon))
                    break
            rslt.append(t)
            if t >= 499:
                vally += 1
                ep.append(e)
                print("Time for which pole stand:", t, vally)
                # Stop once three consecutive episodes reached the step limit
                # (ep[-1] is always e at this point).
                if len(ep) > 2 and ep[-3:] == [e - 2, e - 1, e]:
                    print("Done after episode:", e - 1)
                    break

            if len(agent.pqt) > batch_size:
                agent.replay(batch_size)
    finally:
        # Release the environment (and its render window) even on error.
        env.close()
    end_time = time.time()
    print(end_time - time_start)