In [1]:
# Adapted from Anirudh Topiwala's GitHub repository.

import random
import gym
import math
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

Using TensorFlow backend.


In [2]:
def preprocess_state(state):
    """Return a single environment observation as a batch containing one row.

    The network in this notebook is built with ``input_dim=4`` and is fed
    batches, so the 4-element observation is converted to an array of shape
    (1, 4). Inputs that do not hold exactly 4 values raise ``ValueError``.
    """
    as_row = np.asarray(state).reshape((1, 4))
    return as_row

def replay(model, gamma, memory, batch_size):
    """Fit ``model`` once on Q-learning targets from a random replay mini-batch.

    Args:
        model: object with Keras-style ``predict(x)`` and
            ``fit(x, y, batch_size=..., verbose=...)``.
        gamma: discount factor applied to the best next-state Q-value.
        memory: sequence of ``(state, action, reward, next_state, done)``
            tuples. In this notebook, ``state`` and ``next_state`` are the
            (1, 4) rows produced by ``preprocess_state``.
        batch_size: maximum number of transitions to sample; fewer are used
            when ``memory`` is smaller.

    For each sampled transition, the target for the taken action is
    ``reward`` on terminal steps, otherwise ``reward + gamma * max_a Q(next_state, a)``.
    The targets for all other actions are the model's current predictions,
    so they contribute no error.
    Does nothing when ``memory`` is empty, because fitting on zero samples would fail.
    """
    if len(memory) == 0:
        return

    minibatch = random.sample(memory, min(len(memory), batch_size))
    n = len(minibatch)

    # Stack the sampled transitions so the model is queried twice per call,
    # instead of twice per transition (per-sample Keras predict is very slow).
    states = np.vstack([transition[0] for transition in minibatch])
    next_states = np.vstack([transition[3] for transition in minibatch])
    actions = np.array([transition[1] for transition in minibatch])
    rewards = np.array([transition[2] for transition in minibatch], dtype=float)
    dones = np.array([transition[4] for transition in minibatch], dtype=bool)

    y_batch = model.predict(states)
    next_q_max = np.max(model.predict(next_states), axis=1)
    targets = np.where(dones, rewards, rewards + gamma * next_q_max)
    y_batch[np.arange(n), actions] = targets

    model.fit(states, y_batch, batch_size=n, verbose=0)
    

In [3]:
# All hyperparameters for the DQN agent.
n_episodes = 7000
# Training stops early once the mean score over the last 100 episodes reaches this.
n_win_ticks = 400
# Optional override of the environment's per-episode step limit. None keeps the
# environment's own limit; note CartPole-v1 still truncates episodes (the
# evaluation output below never exceeds 500 steps).
max_env_steps = None
gamma = 0.99         # discount factor for future rewards
epsilon = 1.0        # initial exploration rate
epsilon_min = 0.01
epsilon_decay = 0.995
alpha = 0.01         # Adam learning rate
alpha_decay = 0.01   # Adam learning-rate decay
batch_size = 64
monitor = False      # wrap the env in gym's Monitor to record episodes
quiet = False
render_training = True  # set False to skip env.render() while training (much faster)

# Model: small MLP mapping a 4-value state to one Q-value per action (2 actions).
model = Sequential()
model.add(Dense(24, input_dim=4, activation='relu'))
model.add(Dense(36, activation='relu'))
model.add(Dense(2, activation='linear'))
model.compile(loss='mse', optimizer=Adam(lr=alpha, decay=alpha_decay))


# Environment setting
memory = deque(maxlen=100000)
env = gym.make('CartPole-v1')
if max_env_steps is not None:
    # NOTE(review): relies on the TimeLimit wrapper's `_max_episode_steps`
    # attribute; confirm it exists in the installed gym version.
    env._max_episode_steps = max_env_steps
if monitor:
    # The original referenced `self.env`, which does not exist at notebook scope.
    env = gym.wrappers.Monitor(env, '../data/cartpole-1', force=True)

scores = deque(maxlen=100)
solved = 0
for e in range(n_episodes):
    state = preprocess_state(env.reset())
    done = False
    i = 0  # steps survived this episode; used as the episode's score
    # Play one complete episode with an epsilon-greedy policy, storing every
    # transition in replay memory for later training.
    while not done:
        if render_training:
            env.render()

        # Exploration rate: a log-decaying schedule in the episode index,
        # bounded above by `epsilon` (decayed per episode below) and below by
        # `epsilon_min`.
        eps = max(epsilon_min, min(epsilon, 1.0 - math.log10((e + 1) * epsilon_decay)))
        action = env.action_space.sample() if (np.random.random() <= eps) else np.argmax(model.predict(state))

        next_state, reward, done, _ = env.step(action)
        next_state = preprocess_state(next_state)
        memory.append((state, action, reward, next_state, done))
        state = next_state
        i += 1

    scores.append(i)
    mean_score = np.mean(scores)
    if mean_score >= n_win_ticks and e >= 100:
        if not quiet: print('Ran {} episodes. Solved after {} trials ✔'.format(e, e - 100))
        solved = 1
        break

    if e % 100 == 0 and not quiet:
        print('[Episode {}] - Mean survival time over last 100 episodes was {} streaks.'.format(e, mean_score))

    # One training step on a random mini-batch of stored transitions.
    replay(model, gamma, memory, batch_size)

    # Shrink the exploration rate each episode to shift toward exploitation.
    if epsilon > epsilon_min:
        epsilon *= epsilon_decay

# When the loop runs to completion, `e` is n_episodes - 1, so report
# n_episodes as the number of episodes actually run.
if not quiet and solved == 0: print('Did not solve after {} episodes'.format(n_episodes))
env.close()


[Episode 0] - Mean survival time over last 100 episodes was 14.0 streaks.
[Episode 100] - Mean survival time over last 100 episodes was 9.91 streaks.
[Episode 200] - Mean survival time over last 100 episodes was 30.34 streaks.
[Episode 300] - Mean survival time over last 100 episodes was 30.61 streaks.
[Episode 400] - Mean survival time over last 100 episodes was 38.68 streaks.
[Episode 500] - Mean survival time over last 100 episodes was 66.07 streaks.
[Episode 600] - Mean survival time over last 100 episodes was 127.08 streaks.
[Episode 700] - Mean survival time over last 100 episodes was 139.45 streaks.
[Episode 800] - Mean survival time over last 100 episodes was 135.4 streaks.
[Episode 900] - Mean survival time over last 100 episodes was 202.19 streaks.
[Episode 1000] - Mean survival time over last 100 episodes was 225.87 streaks.
[Episode 1100] - Mean survival time over last 100 episodes was 311.41 streaks.
[Episode 1200] - Mean survival time over last 100 episodes was 310.9 stre

In [4]:
# Evaluate the trained model greedily (always the highest-Q action, no
# exploration) while rendering, and record how many steps each episode lasted.
n_eval_episodes = 100
max_eval_steps = 1000
# NOTE: `time` shadows the name of the stdlib `time` module; it is kept
# because a later cell prints this list.
time = []
for i_episode in range(n_eval_episodes):
    observation = env.reset()
    for t in range(max_eval_steps):
        env.render()
        state = preprocess_state(observation)
        action = np.argmax(model.predict(state)[0])
        observation, reward, done, info = env.step(action)
        if done:
            print("Episode finished after {} timesteps".format(t+1))
            break
    else:
        # The step cap was reached without `done`. Still record the length so
        # this episode is not silently dropped from the mean.
        print("Episode reached the {}-step evaluation cap".format(t+1))
    time.append(t + 1)
print(np.mean(time))
env.close()

Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 453 timesteps
Episode finished after 500 timesteps
Episode finished after 468 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 124 timesteps
Episode finished after 500 timesteps
Episode finished after 98 timesteps
Episode finished after 129 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 106 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Episode finished after 500 timesteps
Ep

In [5]:
print(time)  # print-function form works on both Python 2 and Python 3

[500, 500, 500, 500, 500, 500, 500, 500, 453, 500, 468, 500, 500, 124, 500, 98, 129, 500, 500, 500, 500, 500, 500, 106, 500, 500, 500, 500, 500, 500, 500, 500, 500, 102, 500, 500, 500, 500, 500, 500, 500, 500, 486, 500, 500, 500, 500, 116, 500, 500, 500, 500, 498, 418, 500, 469, 500, 254, 500, 116, 500, 500, 453, 411, 500, 500, 110, 500, 480, 500, 500, 95, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 495, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 100]
