In [None]:
from src import replay_memory_agent, deep_q_agent, epsi_greedy
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout
from keras.layers.normalization import BatchNormalization
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import gym
from keras.callbacks import Callback

class LossHistory(Callback):
    """Keras callback that records the loss reported at the end of each batch.

    ``losses`` is (re)created every time ``on_train_begin`` fires and then
    receives one entry per ``on_batch_end`` call, in call order.
    """

    def on_train_begin(self, logs=None):
        """Start an empty loss record for the training run that is beginning."""
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        """Append ``logs['loss']`` to ``losses`` (``None`` if it is missing).

        ``batch`` is accepted to satisfy the callback signature and is unused.
        """
        # Avoid a shared mutable default and tolerate an explicit ``None``.
        logs = logs or {}
        self.losses.append(logs.get('loss'))

def build_network(input_states,
                  output_states,
                  hidden_layers,
                  nuron_count,
                  activation_function,
                  dropout,
                  learning_rate=0.003):
    """
    Build and compile a fully connected network with optional dropout.

    Layout: one ``Dense -> Activation -> Dropout`` stage fed by the input,
    followed by ``hidden_layers - 1`` stages of
    ``Dense -> Activation -> BatchNormalization -> Dropout``, and a final
    linear ``Dense`` output layer. The model is compiled with mean squared
    error loss and the Adam optimizer.

    Args:
        input_states: Width of the input vector (``input_dim`` of the first
            layer).
        output_states: Number of units in the linear output layer.
        hidden_layers: Total number of hidden stages. Values of 1 or less
            still produce the single input-fed hidden stage.
        nuron_count: Number of units in every hidden ``Dense`` layer.
        activation_function: Activation applied after every hidden ``Dense``
            layer, in any form ``keras.layers.Activation`` accepts.
        dropout: Dropout rate applied after every hidden stage.
        learning_rate: Step size handed to Adam. Defaults to 0.003, the value
            that used to be hard-coded here.

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential()

    # Input-fed hidden stage. Note that, unlike the stages added in the loop
    # below, this one has no batch normalization.
    model.add(Dense(nuron_count, input_dim=input_states))
    model.add(Activation(activation_function))
    model.add(Dropout(dropout))

    # Remaining hidden stages.
    for _ in range(hidden_layers - 1):
        model.add(Dense(nuron_count))
        model.add(Activation(activation_function))
        model.add(BatchNormalization())
        model.add(Dropout(dropout))

    # Linear output layer (no activation).
    model.add(Dense(output_states))

    optimizer = Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999,
                     epsilon=1e-08, decay=0.0)
    model.compile(loss='mean_squared_error', optimizer=optimizer)
    return model

# ---- Schedule constants (all step counts are global environment steps) ----
NUM_EPISODES = 5000          # episodes to run
MAX_EPISODE_STEPS = 200      # hard cap on steps within one episode
WARMUP_STEPS = 5000          # no training until this many steps were taken
UPDATE_FLAG_PERIOD = 5000    # every this many steps train_q gets update=True
EPSILON_START = 0.20         # exploration rate before decay starts
EPSILON_MIN = 0.01           # floor for the exploration rate
EPSILON_DECAY = 0.0009       # amount subtracted once per episode
EPSILON_DECAY_AFTER = 10000  # decay only once gt exceeds this
CHECKPOINT_EVERY = 100       # episodes between progress prints / weight saves

# Two identically shaped networks; the second starts as a copy of the first.
q_nn = build_network(4, 2, 5, 20, "relu", 0.0)
# To resume from saved weights instead: q_nn.load_weights("model_1")
target_nn = build_network(4, 2, 5, 20, "relu", 0.0)
target_nn.set_weights(q_nn.get_weights())

# NOTE(review): arguments are presumably (state length, capacity) -- confirm
# against src.replay_memory_agent.
replay1 = replay_memory_agent(4, 5000)

# Collects the per-batch losses reported while the agent trains.
his = LossHistory()

dqn_controller = deep_q_agent(action_value_model=q_nn,
                              target_model=target_nn,
                              states_len=4,
                              replay_memory=replay1,
                              call=[his])

env = gym.make("CartPole-v0")

# Book keeping: one entry per episode holding that episode's summed reward.
avg_reward_episodes = []
# Global time step, counted across all episodes.
gt = 0
# Exploration rate; stays at EPSILON_START until gt passes EPSILON_DECAY_AFTER.
epsilon = EPSILON_START

for episodes in range(NUM_EPISODES):

    # Initial state
    state = env.reset()
    done = False

    # Clear the reward buffer
    rewards = []

    # Linear per-episode decay, applied only after the initial phase.
    if gt > EPSILON_DECAY_AFTER:
        epsilon = max(EPSILON_MIN, epsilon - EPSILON_DECAY)

    episode_time = 0

    while not done:
        gt += 1

        # The network expects a batch axis: shape (1, 4).
        state = np.asarray(state).reshape(1, 4)

        # Epsilon-greedy action selection.
        q_values = q_nn.predict_on_batch(state)

        if np.random.rand() <= epsilon:
            action = np.random.choice([0, 1])
        else:
            action = np.argmax(q_values)

        # Apply the action and observe the reward signal.
        state_new, reward, done, _ = env.step(action)
        rewards.append(reward)

        # Store the transition.
        # NOTE(review): `state` is (1, 4) here while `state_new` is the raw
        # observation from env.step -- confirm replay_memory_update expects
        # this mix of shapes.
        replay1.replay_memory_update(state, state_new, reward, action, done)

        # Train on every step once the warm-up period is over; every
        # UPDATE_FLAG_PERIOD-th step the update flag is passed as True.
        if gt > WARMUP_STEPS:
            update = gt % UPDATE_FLAG_PERIOD == 0
            dqn_controller.train_q(update)
            if update:
                print("Updated :",gt)

        state = state_new

        episode_time += 1
        if episode_time >= MAX_EPISODE_STEPS:
            break

    episode_return = sum(rewards)
    avg_reward_episodes.append(episode_return)
    if episodes % CHECKPOINT_EVERY == 0:
        print(episode_return, "Episode Count :" ,episodes)
        q_nn.save_weights("model"+str(episodes))

# Release the environment now that training is finished.
env.close()

np.save("sum_rewards", avg_reward_episodes)
plt.plot(avg_reward_episodes)
plt.show()

[2017-12-11 20:51:01,179] Making new env: CartPole-v0


9.0 Episode Count : 0
11.0 Episode Count : 100
13.0 Episode Count : 200
11.0 Episode Count : 300
10.0 Episode Count : 400
11.0 Episode Count : 500
9.0 Episode Count : 600
11.0 Episode Count : 700
15.0 Episode Count : 800
Updated : 10000
14.0 Episode Count : 900
9.0 Episode Count : 1000
10.0 Episode Count : 1100
9.0 Episode Count : 1200
Updated : 15000
10.0 Episode Count : 1300
9.0 Episode Count : 1400
11.0 Episode Count : 1500
10.0 Episode Count : 1600


In [3]:
# Echo the LossHistory callback instance that was handed to the agent.
his

<__main__.LossHistory at 0x11e869f28>

In [8]:
# Show the per-batch losses the callback has recorded.
his.losses

[3.6758557e+10,
 3.6758553e+10,
 3.6758479e+10,
 3.6758282e+10,
 3.6758086e+10,
 3.6757799e+10,
 3.6757426e+10,
 3.6757066e+10,
 3.6756586e+10,
 3.6756152e+10]

array([[ 1649.91485255,  1584.63693889]])

2358904.1307523968

0
1
2
