# COURSE:   PGP [AI&ML]

## Learner :  Chaitanya Kumar Battula
## Module  : RNN
## Topic   : Cartpole Demo

In [1]:
import random
import gym
import numpy as np

from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from keras import backend as K
import tensorflow as tf

In [4]:
env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
print("state_size:", state_size)
print("action_size:", action_size)

state_size: 4
action_size: 2


In [3]:
EPISODES =2


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95    
        self.epsilon = 1.0  
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.99
        self.learning_rate = 0.001
        self.model = self._build_model()
        self.target_model = self._build_model()
        self.update_target_model()

        
    def _huber_loss(self, y_true, y_pred, clip_delta=1.0):
        error = y_true - y_pred
        cond  = K.abs(error) <= clip_delta
        squared_loss = 0.5 * K.square(error)
        quadratic_loss = 0.5 * K.square(clip_delta) + clip_delta * (K.abs(error) - clip_delta)
        return K.mean(tf.where(cond, squared_loss, quadratic_loss))

    
    
    def _build_model(self):
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss=self._huber_loss,
                      optimizer=Adam(lr=self.learning_rate))
        return model

    
    def update_target_model(self):
        self.target_model.set_weights(self.model.get_weights())

        
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

        
    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(state)
        return np.argmax(act_values[0]) 

    
    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        print("minibatch")
        print(minibatch)
        for state, action, reward, next_state, done in minibatch:
            target = self.model.predict(state)
            print("target")
            print(target)
            if done:
                target[0][action] = reward
            else:
                t = self.target_model.predict(next_state)[0]
                target[0][action] = reward + self.gamma * np.amax(t)
            self.model.fit(state, target, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)


if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    done = False
    batch_size = 32
   
    for e in range(EPISODES):
        print("Episode No. " + str(e))
        state = env.reset()
        state = np.reshape(state, [1, state_size])
        for time in range(500):
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            reward = reward if not done else -10
            next_state = np.reshape(next_state, [1, state_size])
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            print("STEP No : " + str(time))
            if done:
                agent.update_target_model()
                print("episode: {}/{}, score: {}, e: {:.2}"
                      .format(e, EPISODES, time, agent.epsilon))
                break
            if len(agent.memory) > batch_size:
                print("replaying")
                agent.replay(batch_size)
        

Episode No. 0
STEP No : 0
STEP No : 1
STEP No : 2
STEP No : 3
STEP No : 4
STEP No : 5
STEP No : 6
STEP No : 7
STEP No : 8
STEP No : 9
STEP No : 10
STEP No : 11
STEP No : 12
STEP No : 13
STEP No : 14
STEP No : 15
STEP No : 16
STEP No : 17
STEP No : 18
STEP No : 19
STEP No : 20
STEP No : 21
STEP No : 22
STEP No : 23
STEP No : 24
STEP No : 25
STEP No : 26
STEP No : 27
episode: 0/2, score: 27, e: 1.0
Episode No. 1
STEP No : 0
STEP No : 1
STEP No : 2
STEP No : 3
STEP No : 4
replaying
minibatch
[(array([[-0.02122683, -0.7301147 ,  0.00187974,  0.9374491 ]]), 0, 1.0, array([[-0.03582912, -0.92526195,  0.02062873,  1.2307221 ]]), False), (array([[ 0.04252223,  0.17686589,  0.01936232, -0.26590154]]), 0, 1.0, array([[ 0.04605955, -0.01852698,  0.01404429,  0.03282499]]), False), (array([[ 0.00353612,  0.4296414 , -0.0230236 , -0.57668526]]), 0, 1.0, array([[ 0.01212895,  0.23484961, -0.0345573 , -0.29134335]]), False), (array([[-0.07674722, -1.31628047,  0.07583908,  1.83624861]]), 1, 1.0, arra

target
[[0.15911376 0.12777862]]
target
[[0.17677237 0.09972414]]
target
[[0.11595101 0.0989605 ]]
target
[[0.19965708 0.11961284]]
target
[[0.163133   0.11178268]]
target
[[0.0930938  0.08825486]]
target
[[0.17516552 0.14919724]]
target
[[0.16743353 0.15146828]]
target
[[0.29050043 0.23086482]]
target
[[0.09749966 0.10681043]]
target
[[0.27425268 0.15517688]]
target
[[0.11666493 0.11920737]]
target
[[0.27052873 0.15937638]]
target
[[0.1076186  0.12691031]]
target
[[0.2948472  0.26866475]]
target
[[0.4218673  0.21114507]]
target
[[0.31454363 0.18582793]]
target
[[0.3801945  0.22142051]]
target
[[0.14370306 0.15586199]]
target
[[0.23102413 0.23412693]]
target
[[0.4441931 0.2812815]]
target
[[0.22673675 0.24364422]]
target
[[0.44936463 0.4329713 ]]
target
[[0.1804685  0.18869375]]
target
[[0.3697276  0.37267718]]
target
[[0.29322374 0.30195847]]
target
[[0.19166139 0.21846314]]
target
[[0.47350365 0.37228975]]
target
[[0.4176002  0.41742212]]
target
[[0.58306485 0.43807992]]
STEP No : 6


target
[[0.3934991  0.43942034]]
target
[[0.41511106 0.4490021 ]]
target
[[0.5017066 0.5593414]]
target
[[1.1157973 1.1881676]]
target
[[0.6750045 0.6952904]]
target
[[0.8643204 0.9256721]]
target
[[0.52379805 0.56380016]]
target
[[0.6310373 0.6753574]]
target
[[0.68811184 0.7260354 ]]
target
[[0.95036626 0.98672974]]
target
[[0.7676636  0.88089716]]
target
[[1.0647137 1.1305596]]
target
[[0.7797326 0.8137274]]
target
[[0.6506032 0.7577406]]
target
[[1.269779  1.2980139]]
target
[[0.70031023 0.7404223 ]]
target
[[1.0991108 1.2050012]]
target
[[0.59219813 0.6413597 ]]
target
[[1.0465187 1.066282 ]]
target
[[0.51419586 0.58523315]]
target
[[0.5184092 0.5944902]]
target
[[0.7190834 0.8124057]]
target
[[0.5855674  0.64031553]]
target
[[1.3315746 1.3777685]]
target
[[0.5274961  0.61703384]]
target
[[1.364589  1.3897601]]
target
[[1.115987  1.2290448]]
target
[[0.5815961 0.679132 ]]
target
[[0.96755177 1.0896598 ]]
target
[[1.6697716 1.6746407]]
target
[[0.5412082 0.6317649]]
STEP No : 8
rep

target
[[1.2113056 1.2514752]]
target
[[0.58604276 0.67559654]]
target
[[0.988531  1.0454088]]
target
[[1.4898918 1.5121598]]
target
[[1.0077924 1.0358244]]
target
[[1.2349145 1.2477844]]
target
[[0.75266904 0.78271204]]
target
[[0.8811772 0.9470372]]
target
[[0.6862541 0.7239191]]
target
[[0.61012805 0.67516166]]
target
[[0.6264371 0.7044146]]
target
[[0.85833824 0.84256107]]
target
[[0.99151343 0.96942747]]
target
[[1.1292012 1.1162579]]
target
[[1.0848136 1.0351384]]
target
[[0.68058985 0.69913006]]
target
[[0.6394992  0.67452925]]
target
[[0.646414  0.6895277]]
target
[[0.8372661  0.89192575]]
target
[[1.0623707 1.0129219]]
target
[[0.753086  0.7518673]]
target
[[1.7951151 1.6663384]]
target
[[0.8224801  0.79762065]]
target
[[1.5008395 1.3902091]]
target
[[0.7178459 0.7844213]]
target
[[0.83154064 0.8828573 ]]
target
[[0.6640327 0.6947376]]
target
[[0.6772598  0.73614275]]
target
[[0.79597515 0.7704399 ]]
target
[[0.8236131 0.7887673]]
STEP No : 10
replaying
minibatch
[(array([[ 0.