In [6]:
"""
  TensorFlow translation of the torch example found here (written by SeanNaren).
  https://github.com/SeanNaren/TorchQLearningExample
  Original keras example found here (written by Eder Santana).
  https://gist.github.com/EderSantana/c7222daa328f0e885093#file-qlearn-py-L164
  The agent plays a game of catch. Fruits drop from the sky and the agent can choose the actions
  left/stay/right to catch the fruit before it reaches the ground.
"""

# import tensorflow as tf
import tensorflow.compat.v1 as tf 
tf.disable_v2_behavior()
import numpy as np
import random
import math
import os

# Parameters
epsilon = 1  # The probability of choosing a random action (in training). This decays as iterations increase. (0 to 1)
epsilonMinimumValue = 0.001 # The minimum value we want epsilon to reach in training. (0 to 1)
nbActions = 3 # The number of actions. Since we only have left/stay/right that means 3 actions.
epoch = 1001 # The number of games we want the system to run for.
hiddenSize = 100 # Number of neurons in the hidden layers.
maxMemory = 500 # How large should the memory be (where it stores its past experiences).
batchSize = 50 # The mini-batch size for training. Samples are randomly taken from memory till mini-batch size.
gridSize = 10 # The size of the grid that the agent is going to play the game on.
nbStates = gridSize * gridSize # We eventually flatten to a 1d tensor to feed the network.
discount = 0.9 # The discount is used to force the network to choose states that lead to the reward quicker (0 to 1)  
learningRate = 0.2 # Learning Rate for Stochastic Gradient Descent (our optimizer).

# Create the base model.
X = tf.placeholder(tf.float32, [None, nbStates])
W1 = tf.Variable(tf.truncated_normal([nbStates, hiddenSize], stddev=1.0 / math.sqrt(float(nbStates))))
b1 = tf.Variable(tf.truncated_normal([hiddenSize], stddev=0.01))  
input_layer = tf.nn.relu(tf.matmul(X, W1) + b1)
W2 = tf.Variable(tf.truncated_normal([hiddenSize, hiddenSize],stddev=1.0 / math.sqrt(float(hiddenSize))))
b2 = tf.Variable(tf.truncated_normal([hiddenSize], stddev=0.01))
hidden_layer = tf.nn.relu(tf.matmul(input_layer, W2) + b2)
W3 = tf.Variable(tf.truncated_normal([hiddenSize, nbActions],stddev=1.0 / math.sqrt(float(hiddenSize))))
b3 = tf.Variable(tf.truncated_normal([nbActions], stddev=0.01))
output_layer = tf.matmul(hidden_layer, W3) + b3

# True labels
Y = tf.placeholder(tf.float32, [None, nbActions])

# Mean squared error cost function
cost = tf.reduce_sum(tf.square(Y-output_layer)) / (2*batchSize)

# Stochastic Gradient Decent Optimizer
optimizer = tf.train.GradientDescentOptimizer(learningRate).minimize(cost)


# Helper function: Chooses a random value between the two boundaries.
def randf(s, e):
    return (float(random.randrange(0, (e - s) * 9999)) / 10000) + s;


# The environment: Handles interactions and contains the state of the environment
class CatchEnvironment():
    def __init__(self, gridSize):
        self.gridSize = gridSize
        self.nbStates = self.gridSize * self.gridSize
        self.state = np.empty(3, dtype = np.uint8) 

#   Returns the state of the environment.
    def observe(self):
        canvas = self.drawState()
        canvas = np.reshape(canvas, (-1,self.nbStates))
        return canvas

    def drawState(self):
        canvas = np.zeros((self.gridSize, self.gridSize))
        canvas[self.state[0]-1, self.state[1]-1] = 1  # Draw the fruit.
        # Draw the basket. The basket takes the adjacent two places to the position of basket.
        canvas[self.gridSize-1, self.state[2] -1 - 1] = 1
        canvas[self.gridSize-1, self.state[2] -1] = 1
        canvas[self.gridSize-1, self.state[2] -1 + 1] = 1    
        return canvas        

#   Resets the environment. Randomly initialise the fruit position (always at the top to begin with) and bucket.
    def reset(self): 
        initialFruitColumn = random.randrange(1, self.gridSize + 1)
        initialBucketPosition = random.randrange(2, self.gridSize + 1 - 1)
        self.state = np.array([1, initialFruitColumn, initialBucketPosition]) 
        return self.getState()

    def getState(self):
        stateInfo = self.state
        fruit_row = stateInfo[0]
        fruit_col = stateInfo[1]
        basket = stateInfo[2]
        return fruit_row, fruit_col, basket

#   Returns the award that the agent has gained for being in the current environment state.
    def getReward(self):
        fruitRow, fruitColumn, basket = self.getState()
        if (fruitRow == self.gridSize - 1):  # If the fruit has reached the bottom.
            if (abs(fruitColumn - basket) <= 1): # Check if the basket caught the fruit.
                return 1
            else:
                return -1
        else:
            return 0

    def isGameOver(self):
        if (self.state[0] == self.gridSize - 1): 
            return True 
        else: 
            return False 

    def updateState(self, action):
        if (action == 1):
            action = -1
        elif (action == 2):
            action = 0
        else:
            action = 1
        fruitRow, fruitColumn, basket = self.getState()
        newBasket = min(max(2, basket + action), self.gridSize - 1) # The min/max prevents the basket from moving out of the grid.
        fruitRow = fruitRow + 1  # The fruit is falling by 1 every action.
        self.state = np.array([fruitRow, fruitColumn, newBasket])

#   Action can be 1 (move left) or 2 (move right)
    def act(self, action):
        self.updateState(action)
        reward = self.getReward()
        gameOver = self.isGameOver()
        return self.observe(), reward, gameOver, self.getState()   # For purpose of the visual, I also return the state.


# The memory: Handles the internal memory that we add experiences that occur based on agent's actions,
# and creates batches of experiences based on the mini-batch size for training.
class ReplayMemory:
    def __init__(self, gridSize, maxMemory, discount):
        self.maxMemory = maxMemory
        self.gridSize = gridSize
        self.nbStates = self.gridSize * self.gridSize
        self.discount = discount
        canvas = np.zeros((self.gridSize, self.gridSize))
        canvas = np.reshape(canvas, (-1,self.nbStates))
        self.inputState = np.empty((self.maxMemory, 100), dtype = np.float32)
        self.actions = np.zeros(self.maxMemory, dtype = np.uint8)
        self.nextState = np.empty((self.maxMemory, 100), dtype = np.float32)
        self.gameOver = np.empty(self.maxMemory, dtype = np.bool)
        self.rewards = np.empty(self.maxMemory, dtype = np.int8) 
        self.count = 0
        self.current = 0

#   Appends the experience to the memory.
    def remember(self, currentState, action, reward, nextState, gameOver):
        self.actions[self.current] = action
        self.rewards[self.current] = reward
        self.inputState[self.current, ...] = currentState
        self.nextState[self.current, ...] = nextState
        self.gameOver[self.current] = gameOver
        self.count = max(self.count, self.current + 1)
        self.current = (self.current + 1) % self.maxMemory

    def getBatch(self, model, batchSize, nbActions, nbStates, sess, X):
    
        # We check to see if we have enough memory inputs to make an entire batch, if not we create the biggest
        # batch we can (at the beginning of training we will not have enough experience to fill a batch).
        memoryLength = self.count
        chosenBatchSize = min(batchSize, memoryLength)

        inputs = np.zeros((chosenBatchSize, nbStates))
        targets = np.zeros((chosenBatchSize, nbActions))

        # Fill the inputs and targets up.
        for i in range(chosenBatchSize):
            if memoryLength == 1:
                memoryLength = 2
            # Choose a random memory experience to add to the batch.
            randomIndex = random.randrange(1, memoryLength)
            current_inputState = np.reshape(self.inputState[randomIndex], (1, 100))

            target = sess.run(model, feed_dict={X: current_inputState})

            current_nextState =  np.reshape(self.nextState[randomIndex], (1, 100))
            current_outputs = sess.run(model, feed_dict={X: current_nextState})      

            # Gives us Q_sa, the max q for the next state.
            nextStateMaxQ = np.amax(current_outputs)
            if (self.gameOver[randomIndex] == True):
                target[0, [self.actions[randomIndex]-1]] = self.rewards[randomIndex]
            else:
                # reward + discount(gamma) * max_a' Q(s',a')
                # We are setting the Q-value for the action to  r + gamma*max a' Q(s', a'). The rest stay the same
                # to give an error of 0 for those outputs.
                target[0, [self.actions[randomIndex]-1]] = self.rewards[randomIndex] + self.discount * nextStateMaxQ

            # Update the inputs and targets.
            inputs[i] = current_inputState
            targets[i] = target

        return inputs, targets

    
def main(_):
    print("Training new model")

    # Define Environment
    env = CatchEnvironment(gridSize)

    # Define Replay Memory
    memory = ReplayMemory(gridSize, maxMemory, discount)

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
  
    winCount = 0
    with tf.Session() as sess:   
        tf.initialize_all_variables().run() 

        for i in range(epoch):
            # Initialize the environment.
            err = 0
            env.reset()

            isGameOver = False

            # The initial state of the environment.
            currentState = env.observe()

            while (isGameOver != True):
                action = -9999  # action initilization
                # Decides if we should choose a random action, or an action from the policy network.
                global epsilon
                if (randf(0, 1) <= epsilon):
                    action = random.randrange(1, nbActions+1)
                else:          
                    # Forward the current state through the network.
                    q = sess.run(output_layer, feed_dict={X: currentState})          
                    # Find the max index (the chosen action).
                    index = q.argmax()
                    action = index + 1     

                # Decay the epsilon by multiplying by 0.999, not allowing it to go below a certain threshold.
                if (epsilon > epsilonMinimumValue):
                    epsilon = epsilon * 0.999

                nextState, reward, gameOver, stateInfo = env.act(action)

                if (reward == 1):
                    winCount = winCount + 1

                memory.remember(currentState, action, reward, nextState, gameOver)

                # Update the current state and if the game is over.
                currentState = nextState
                isGameOver = gameOver

                # We get a batch of training data to train the model.
                inputs, targets = memory.getBatch(output_layer, batchSize, nbActions, nbStates, sess, X)

                # Train the network which returns the error.
                _, loss = sess.run([optimizer, cost], feed_dict={X: inputs, Y: targets})  
                err = err + loss

            print("Epoch " + str(i) + ": err = " + str(err) + ": Win count = " + str(winCount) + " Win ratio = " + str(float(winCount)/float(i+1)*100))
        # Save the variables to disk.
        save_path = saver.save(sess, os.getcwd()+"/model.ckpt")
        print("Model saved in file: %s" % save_path)

if __name__ == '__main__':
    tf.app.run()

Training new model
Epoch 0: err = nan: Win count = 0 Win ratio = 0.0
Epoch 1: err = nan: Win count = 0 Win ratio = 0.0
Epoch 2: err = nan: Win count = 1 Win ratio = 33.33333333333333
Epoch 3: err = nan: Win count = 1 Win ratio = 25.0
Epoch 4: err = nan: Win count = 2 Win ratio = 40.0
Epoch 5: err = nan: Win count = 2 Win ratio = 33.33333333333333
Epoch 6: err = nan: Win count = 2 Win ratio = 28.57142857142857
Epoch 7: err = nan: Win count = 2 Win ratio = 25.0
Epoch 8: err = nan: Win count = 2 Win ratio = 22.22222222222222
Epoch 9: err = nan: Win count = 2 Win ratio = 20.0
Epoch 10: err = nan: Win count = 2 Win ratio = 18.181818181818183
Epoch 11: err = nan: Win count = 2 Win ratio = 16.666666666666664
Epoch 12: err = nan: Win count = 2 Win ratio = 15.384615384615385
Epoch 13: err = nan: Win count = 2 Win ratio = 14.285714285714285
Epoch 14: err = nan: Win count = 2 Win ratio = 13.333333333333334
Epoch 15: err = nan: Win count = 3 Win ratio = 18.75
Epoch 16: err = nan: Win count = 3 Win

Epoch 126: err = nan: Win count = 36 Win ratio = 28.346456692913385
Epoch 127: err = nan: Win count = 36 Win ratio = 28.125
Epoch 128: err = nan: Win count = 36 Win ratio = 27.906976744186046
Epoch 129: err = nan: Win count = 37 Win ratio = 28.46153846153846
Epoch 130: err = nan: Win count = 37 Win ratio = 28.24427480916031
Epoch 131: err = nan: Win count = 38 Win ratio = 28.78787878787879
Epoch 132: err = nan: Win count = 39 Win ratio = 29.32330827067669
Epoch 133: err = nan: Win count = 39 Win ratio = 29.1044776119403
Epoch 134: err = nan: Win count = 39 Win ratio = 28.888888888888886
Epoch 135: err = nan: Win count = 40 Win ratio = 29.411764705882355
Epoch 136: err = nan: Win count = 40 Win ratio = 29.1970802919708
Epoch 137: err = nan: Win count = 40 Win ratio = 28.985507246376812
Epoch 138: err = nan: Win count = 40 Win ratio = 28.776978417266186
Epoch 139: err = nan: Win count = 40 Win ratio = 28.57142857142857
Epoch 140: err = nan: Win count = 40 Win ratio = 28.368794326241137
E

Epoch 249: err = nan: Win count = 75 Win ratio = 30.0
Epoch 250: err = nan: Win count = 75 Win ratio = 29.880478087649404
Epoch 251: err = nan: Win count = 75 Win ratio = 29.761904761904763
Epoch 252: err = nan: Win count = 75 Win ratio = 29.64426877470356
Epoch 253: err = nan: Win count = 75 Win ratio = 29.527559055118108
Epoch 254: err = nan: Win count = 75 Win ratio = 29.411764705882355
Epoch 255: err = nan: Win count = 75 Win ratio = 29.296875
Epoch 256: err = nan: Win count = 75 Win ratio = 29.18287937743191
Epoch 257: err = nan: Win count = 76 Win ratio = 29.457364341085274
Epoch 258: err = nan: Win count = 76 Win ratio = 29.343629343629345
Epoch 259: err = nan: Win count = 76 Win ratio = 29.230769230769234
Epoch 260: err = nan: Win count = 76 Win ratio = 29.118773946360154
Epoch 261: err = nan: Win count = 76 Win ratio = 29.00763358778626
Epoch 262: err = nan: Win count = 77 Win ratio = 29.277566539923956
Epoch 263: err = nan: Win count = 77 Win ratio = 29.166666666666668
Epoch 

Epoch 372: err = nan: Win count = 103 Win ratio = 27.61394101876676
Epoch 373: err = nan: Win count = 104 Win ratio = 27.807486631016044
Epoch 374: err = nan: Win count = 104 Win ratio = 27.73333333333333
Epoch 375: err = nan: Win count = 104 Win ratio = 27.659574468085108
Epoch 376: err = nan: Win count = 104 Win ratio = 27.586206896551722
Epoch 377: err = nan: Win count = 105 Win ratio = 27.77777777777778
Epoch 378: err = nan: Win count = 105 Win ratio = 27.70448548812665
Epoch 379: err = nan: Win count = 106 Win ratio = 27.89473684210526
Epoch 380: err = nan: Win count = 106 Win ratio = 27.821522309711288
Epoch 381: err = nan: Win count = 106 Win ratio = 27.748691099476442
Epoch 382: err = nan: Win count = 106 Win ratio = 27.676240208877285
Epoch 383: err = nan: Win count = 107 Win ratio = 27.864583333333332
Epoch 384: err = nan: Win count = 107 Win ratio = 27.79220779220779
Epoch 385: err = nan: Win count = 107 Win ratio = 27.72020725388601
Epoch 386: err = nan: Win count = 108 Win

Epoch 492: err = nan: Win count = 133 Win ratio = 26.97768762677485
Epoch 493: err = nan: Win count = 134 Win ratio = 27.125506072874494
Epoch 494: err = nan: Win count = 134 Win ratio = 27.070707070707073
Epoch 495: err = nan: Win count = 135 Win ratio = 27.21774193548387
Epoch 496: err = nan: Win count = 135 Win ratio = 27.16297786720322
Epoch 497: err = nan: Win count = 135 Win ratio = 27.10843373493976
Epoch 498: err = nan: Win count = 135 Win ratio = 27.054108216432866
Epoch 499: err = nan: Win count = 135 Win ratio = 27.0
Epoch 500: err = nan: Win count = 136 Win ratio = 27.14570858283433
Epoch 501: err = nan: Win count = 136 Win ratio = 27.091633466135455
Epoch 502: err = nan: Win count = 136 Win ratio = 27.037773359840955
Epoch 503: err = nan: Win count = 137 Win ratio = 27.18253968253968
Epoch 504: err = nan: Win count = 137 Win ratio = 27.12871287128713
Epoch 505: err = nan: Win count = 137 Win ratio = 27.07509881422925
Epoch 506: err = nan: Win count = 137 Win ratio = 27.021

Epoch 611: err = nan: Win count = 164 Win ratio = 26.797385620915033
Epoch 612: err = nan: Win count = 164 Win ratio = 26.753670473083197
Epoch 613: err = nan: Win count = 164 Win ratio = 26.710097719869708
Epoch 614: err = nan: Win count = 164 Win ratio = 26.666666666666668
Epoch 615: err = nan: Win count = 164 Win ratio = 26.623376623376622
Epoch 616: err = nan: Win count = 164 Win ratio = 26.580226904376016
Epoch 617: err = nan: Win count = 164 Win ratio = 26.537216828478964
Epoch 618: err = nan: Win count = 164 Win ratio = 26.494345718901453
Epoch 619: err = nan: Win count = 165 Win ratio = 26.61290322580645
Epoch 620: err = nan: Win count = 165 Win ratio = 26.570048309178745
Epoch 621: err = nan: Win count = 165 Win ratio = 26.527331189710612
Epoch 622: err = nan: Win count = 165 Win ratio = 26.484751203852326
Epoch 623: err = nan: Win count = 165 Win ratio = 26.442307692307693
Epoch 624: err = nan: Win count = 166 Win ratio = 26.56
Epoch 625: err = nan: Win count = 166 Win ratio 

Epoch 731: err = nan: Win count = 205 Win ratio = 28.005464480874316
Epoch 732: err = nan: Win count = 205 Win ratio = 27.967257844474762
Epoch 733: err = nan: Win count = 205 Win ratio = 27.9291553133515
Epoch 734: err = nan: Win count = 206 Win ratio = 28.027210884353742
Epoch 735: err = nan: Win count = 206 Win ratio = 27.98913043478261
Epoch 736: err = nan: Win count = 206 Win ratio = 27.951153324287652
Epoch 737: err = nan: Win count = 206 Win ratio = 27.91327913279133
Epoch 738: err = nan: Win count = 206 Win ratio = 27.875507442489848
Epoch 739: err = nan: Win count = 206 Win ratio = 27.837837837837835
Epoch 740: err = nan: Win count = 206 Win ratio = 27.800269905533064
Epoch 741: err = nan: Win count = 206 Win ratio = 27.762803234501348
Epoch 742: err = nan: Win count = 206 Win ratio = 27.725437415881558
Epoch 743: err = nan: Win count = 207 Win ratio = 27.82258064516129
Epoch 744: err = nan: Win count = 207 Win ratio = 27.78523489932886
Epoch 745: err = nan: Win count = 207 Wi

Epoch 851: err = nan: Win count = 232 Win ratio = 27.230046948356808
Epoch 852: err = nan: Win count = 232 Win ratio = 27.198124267291913
Epoch 853: err = nan: Win count = 233 Win ratio = 27.28337236533958
Epoch 854: err = nan: Win count = 233 Win ratio = 27.251461988304094
Epoch 855: err = nan: Win count = 233 Win ratio = 27.2196261682243
Epoch 856: err = nan: Win count = 233 Win ratio = 27.18786464410735
Epoch 857: err = nan: Win count = 233 Win ratio = 27.156177156177158
Epoch 858: err = nan: Win count = 233 Win ratio = 27.124563445867288
Epoch 859: err = nan: Win count = 233 Win ratio = 27.093023255813954
Epoch 860: err = nan: Win count = 234 Win ratio = 27.177700348432055
Epoch 861: err = nan: Win count = 234 Win ratio = 27.1461716937355
Epoch 862: err = nan: Win count = 235 Win ratio = 27.230590961761298
Epoch 863: err = nan: Win count = 236 Win ratio = 27.314814814814813
Epoch 864: err = nan: Win count = 237 Win ratio = 27.398843930635834
Epoch 865: err = nan: Win count = 237 Wi

Epoch 971: err = nan: Win count = 262 Win ratio = 26.954732510288064
Epoch 972: err = nan: Win count = 262 Win ratio = 26.92702980472765
Epoch 973: err = nan: Win count = 263 Win ratio = 27.002053388090346
Epoch 974: err = nan: Win count = 263 Win ratio = 26.974358974358974
Epoch 975: err = nan: Win count = 263 Win ratio = 26.946721311475407
Epoch 976: err = nan: Win count = 263 Win ratio = 26.919140225179124
Epoch 977: err = nan: Win count = 263 Win ratio = 26.89161554192229
Epoch 978: err = nan: Win count = 263 Win ratio = 26.86414708886619
Epoch 979: err = nan: Win count = 263 Win ratio = 26.836734693877553
Epoch 980: err = nan: Win count = 263 Win ratio = 26.809378185524974
Epoch 981: err = nan: Win count = 263 Win ratio = 26.782077393075355
Epoch 982: err = nan: Win count = 263 Win ratio = 26.754832146490337
Epoch 983: err = nan: Win count = 263 Win ratio = 26.727642276422763
Epoch 984: err = nan: Win count = 264 Win ratio = 26.802030456852794
Epoch 985: err = nan: Win count = 264

SystemExit: 