The OpenAI Gym (https://gym.openai.com) provides us with a lot of different examples and games in which to train a learning agent. The task is to develop one of such agents. We will create a neural network that, given the state of the game (actually, two consecutive states), it outputs a family of quality values (Q-values) for each next possible move. The move with higher Q-value is chosen and performed in the game. This theoretical formalism was taken from https://www.nervanasys.com/demystifying-deep-reinforcement-learning/

In [3]:

from keras.utils import plot_model
import threading
import time

In [10]:
# INITIALIZATION: libraries, parameters, network...

from keras.models import Sequential      # One layer after the other
from keras.layers import Dense, Flatten  # Dense layers are fully connected layers, Flatten layers flatten out multidimensional inputs
from collections import deque            # For storing moves 
import World
import numpy as np
import gym                                # To train our network


import random     # For sampling batches from the observations


    # Parameters
D = deque()                                # Register where the actions will be stored

observetime = 300                          # Number of timesteps we will be acting on the game and observing results
epsilon = 0.7                              # Probability of doing a random move
gamma = 0.9                                # Discounted future reward. How much we care about steps further in time
mb_size = 50     

model = Sequential()
model.add(Dense(200, input_shape=(5,) , init='uniform', activation='relu'))
    #model.add(Flatten())       # Flatten input so as to have no problems with processing
model.add(Dense(200, init='uniform', activation='relu'))
model.add(Dense(4, init='uniform', activation='linear'))    # Same number of outputs as possible actions

model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])
plot_model(model, to_file='model.png',show_shapes=True)
                          # Learning minibatch size



In [11]:
def do_action(action):
    r = -World.score
    if action == 0:
        World.try_move(0, -1)
    elif action == 1:
        World.try_move(0, 1)
    elif action == 2:
        World.try_move(-1, 0)
    elif action == 3:
        World.try_move(1, 0)
    else:
        return
    s2 = World.findState()
    r += World.score
    return r, s2


In [12]:
# FIRST STEP: Knowing what each action does (Observing)
def observe():
    state = World.findState()    
    # Game begins
       # (Formatting issues) Making the observation the first element of a batch of inputs 
    done = False
    for t in range(observetime):
        if np.random.rand() <= epsilon:
            action = np.random.randint(0, 3, size=1)[0]
        else:
            Q = model.predict(state)          # Q-values predictions
            action = np.argmax(Q,axis=1)[0]             # Move with highest Q-value is the chosen one
        reward,state_new = do_action(action)     # See state of the game, reward... after performing the action
                 # Update the input with the new state of the game
        D.append((state, action, reward, state_new, World.has_restarted()))
        #print((state, action, reward, state_new, done))# 'Remember' action and consequence
        state = state_new         # Update state
        if World.has_restarted():
            #env.reset()           # Restart game if it's finished
            state = World.findState()  
           # (Formatting issues) Making the observation the first element of a batch of inputs 
    print('Observing Finished')


In [17]:
# SECOND STEP: Learning from the observations (Experience replay)

                            # Sample some moves

def learn():
    inputs = np.zeros((mb_size,5))
    targets = np.zeros((mb_size, 4))
    for j in range(30):
        minibatch = random.sample(D, mb_size)  
        for i in range(0, mb_size):
            state = minibatch[i][0]
            action = minibatch[i][1]
            reward = minibatch[i][2]
            state_new = minibatch[i][3]
            done = minibatch[i][4]

    # Build Bellman equation for the Q function
            inputs[i:i+1] = state
            targets[i] = model.predict(state)
            Q_sa = model.predict(state_new)
            print("Target ",targets[i,action])
            if done:
                targets[i, action] = reward
            else:
                targets[i, action] = reward + gamma * np.max(Q_sa,axis=1)
            print("Target 1",targets[i,action])

    # Train network to output the Q function
    
        model.train_on_batch(inputs, targets)    
    print('Learning Finished')
    #print(inputs,targets)

In [18]:
# THIRD STEP: Play!
def test():
    print("Evaluating")
    state = World.findState()
    done = False
    tot_reward = 0.0
    while not done:                 # Uncomment to see game running
        Q = model.predict(state)        
        action = np.argmax(Q,axis=1)   
        #print(Q,action)
        reward,observation = do_action(action[0]) 
        tot_reward += reward
        print('Game : ',action,"   reward ",reward)
        if World.has_restarted():
            break
    print('Game ended! Total reward: {}'.format(tot_reward))

In [19]:
def process():
    for i in range(30):
        World.restart_game()
        observe()
        learn()
        test()
        D.clear()
process()        

Observing Finished
Target  -0.149237841368
Target 1 -0.5
Target  1.09349322319
Target 1 2.95587468147
Target  -0.0917944163084
Target 1 -0.5
Target  3.50652766228
Target 1 3.19581484795
Target  2.9754011631
Target 1 2.71974611282
Target  -0.0793486312032
Target 1 -0.5
Target  3.15190505981
Target 1 2.47786092758
Target  3.50652766228
Target 1 3.65587472916
Target  2.9754011631
Target 1 2.71974611282
Target  3.50652766228
Target 1 3.19581484795
Target  3.50652766228
Target 1 2.79581475258
Target  -0.158337652683
Target 1 -0.5
Target  1.06716942787
Target 1 2.63671445847
Target  3.68523740768
Target 1 2.95587468147
Target  1.10764360428
Target 1 3.11671352386
Target  -0.103967480361
Target 1 -0.5
Target  3.68523740768
Target 1 3.85587477684
Target  2.9754011631
Target 1 3.85587477684
Target  -0.0917944163084
Target 1 -0.5
Target  3.68523740768
Target 1 4.25587463379
Target  1.08003878593
Target 1 2.79581475258
Target  3.68523740768
Target 1 2.95587468147
Target  2.05211162567
Target 1 1.

Target  2.38325834274
Target 1 1.83026480675
Target  2.51072454453
Target 1 2.84493255615
Target  1.88760602474
Target 1 1.44487333298
Target  1.38319253922
Target 1 2.84493255615
Target  2.38325834274
Target 1 2.23026490211
Target  2.38325834274
Target 1 2.23026490211
Target  2.51072454453
Target 1 1.94493246078
Target  1.15832185745
Target 1 -0.6
Target  2.25584983826
Target 1 2.11821103096
Target  1.23297595978
Target 1 -0.6
Target  2.38325834274
Target 1 1.83026480675
Target  2.51072454453
Target 1 2.84493255615
Target  2.51072454453
Target 1 1.94493246078
Target  -0.301080912352
Target 1 0.0
Target  2.1313457489
Target 1 2.00714492798
Target  1.23297595978
Target 1 1.83026480675
Target  -0.284569978714
Target 1 -0.5
Target  1.38319253922
Target 1 0.905042469501
Target  1.25739765167
Target 1 0.0
Target  -0.250688970089
Target 1 -0.5
Target  1.23297595978
Target 1 1.83026480675
Target  2.51072454453
Target 1 2.84493255615
Target  1.22782492638
Target 1 1.10264706612
Target  2.38325

Target  1.51871883869
Target 1 1.57579243183
Target  -0.286233067513
Target 1 -0.5
Target  1.705555439
Target 1 2.18033957481
Target  1.42953634262
Target 1 1.52609109879
Target  -0.39615598321
Target 1 -0.5
Target  1.61062443256
Target 1 1.62726187706
Target  1.29241335392
Target 1 1.22726178169
Target  -0.393433332443
Target 1 0.0
Target  1.51871883869
Target 1 1.17579233646
Target  1.30555081367
Target 1 1.28033959866
Target  1.705555439
Target 1 2.18033957481
Target  -0.393433332443
Target 1 0.0
Target  1.705555439
Target 1 2.58033967018
Target  1.26287639141
Target 1 1.16312253475
Target  1.61062443256
Target 1 1.22726178169
Target  1.3440964222
Target 1 1.47775089741
Target  1.705555439
Target 1 2.18033957481
Target  1.32717490196
Target 1 0.0
Target  0.921025454998
Target 1 2.18033957481
Target  1.42953634262
Target 1 1.12609100342
Target  1.705555439
Target 1 2.18033957481
Target  1.30555081367
Target 1 1.28033959866
Target  1.32717490196
Target 1 -0.5
Target  1.61062443256
Tar

Target 1 0.575787603855
Target  -0.308220356703
Target 1 0.0
Target  1.28710865974
Target 1 1.77578759193
Target  -0.308220356703
Target 1 0.0
Target  -0.308220356703
Target 1 0.0
Target  0.878645539284
Target 1 1.07578754425
Target  1.36452245712
Target 1 1.97578763962
Target  1.14205741882
Target 1 0.948603332043
Target  0.878645539284
Target 1 1.07578754425
Target  1.07621657848
Target 1 1.97578763962
Target  1.36452245712
Target 1 1.97578763962
Target  1.36452245712
Target 1 1.07578754425
Target  1.36452245712
Target 1 1.97578763962
Target  1.14205741882
Target 1 0.948603332043
Target  1.28710865974
Target 1 1.02954351902
Target  0.873531877995
Target 1 1.1229621172
Target  -0.315054893494
Target 1 -0.5
Target  0.873648166656
Target 1 0.948603332043
Target  0.757360756397
Target 1 0.596997916698
Target  1.28710865974
Target 1 1.42954361439
Target  -0.334378510714
Target 1 -0.5
Target  1.28710865974
Target 1 1.42954361439
Target  -0.321351438761
Target 1 -0.5
Target  0.873531877995


Target  1.60566866398
Target 1 1.57111465931
Target  1.68798220158
Target 1 2.14510178566
Target  1.68798220158
Target 1 2.14510178566
Target  0.735469222069
Target 1 1.17111456394
Target  -0.267652958632
Target 1 0.2
Target  0.745245695114
Target 1 1.1153409481
Target  1.68798220158
Target 1 2.14510178566
Target  0.694826483727
Target 1 -0.5
Target  1.60566866398
Target 1 1.57111465931
Target  1.78707373142
Target 1 2.23195457458
Target  1.70217168331
Target 1 1.65578830242
Target  -0.266618996859
Target 1 -0.5
Target  -0.277078807354
Target 1 -0.5
Target  0.686303377151
Target 1 0.0
Target  1.45948433876
Target 1 1.47107708454
Target  1.78707373142
Target 1 2.23195457458
Target  0.701724290848
Target 1 1.40836632252
Target  -0.304416567087
Target 1 0.6
Target  1.78707373142
Target 1 2.23195457458
Target  1.78707373142
Target 1 1.33195447922
Target  1.61754250526
Target 1 1.58224880695
Target  1.78707373142
Target 1 2.23195457458
Target  -0.286944240332
Target 1 0.2
Target  1.78707373

Target  2.22267961502
Target 1 2.11199522018
Target  1.9371137619
Target 1 1.86957740784
Target  -0.284714311361
Target 1 0.0
Target  2.32037639618
Target 1 2.70041155815
Target  2.32037639618
Target 1 2.70041155815
Target  0.842253923416
Target 1 1.88833856583
Target  1.43057739735
Target 1 0.983845055103
Target  0.842253923416
Target 1 1.88833856583
Target  0.889151096344
Target 1 1.62447392941
Target  0.826444983482
Target 1 0.0
Target  2.32037639618
Target 1 2.70041155815
Target  -0.296373724937
Target 1 -0.5
Target  2.22267961502
Target 1 2.11199522018
Target  2.22267961502
Target 1 2.11199522018
Target  2.22267961502
Target 1 1.71199524403
Target  0.875469326973
Target 1 1.71199524403
Target  0.858990252018
Target 1 1.80041146278
Target  -0.338549911976
Target 1 -0.5
Target  0.826444983482
Target 1 0.0
Target  1.50286853313
Target 1 1.04549396038
Target  2.211810112
Target 1 2.1022207737
Target  0.951647222042
Target 1 1.79062902927
Target  0.699268817902
Target 1 1.15258157253
T