In [154]:
#Imports
import gym
import tensorflow as tf
import numpy as np
import random
from collections import deque 
env = gym.make('CartPole-v1')

In [155]:
#Global varaible
epsilon = 1

In [167]:
#Constants
REPLAY_MEMORY_SIZE = 10000
MIN_REPLAY_MEMORY_SIZE = 250
MIN_EPSILON = 0.01
EPSILON_DECAY = 0.98
BATCH_SIZE = 32
DISCOUNT = 0.9995
UPDATE_PREDICITON_EVERY = 2

In [157]:
#Model takes in actions and outputs Q values for each action which represnt the predicted score for the action
def create_model():
    model = tf.keras.Sequential([tf.keras.layers.Dense(64,input_shape = (4,)),
                                 tf.keras.layers.Dense(32, activation = 'relu'),
                                 tf.keras.layers.Dense(2, activation = 'linear')])
    model.compile(tf.keras.optimizers.Adam(0.001), loss = 'mse', metrics='accuracy')
    return model

In [158]:
#Create new models
updated_model = create_model()

#Prediction model is created to keep some consitancy in future Q values predictions
#This only gets updated after every couple runthroughs of the enviorment
prediction_model = create_model()
prediction_model.set_weights(updated_model.get_weights())

In [203]:
#Returns a action for the Cart Pole enviorment
#If epsilon is bigger there is a larger chance of performing random actions (exploring)
#Epsilon decays over training iterations and actions become more greedy (exploitation)
def get_action(obs):
    if random.uniform(0, 1) <= epsilon:
        return random.randint(0,1)
    else:
        q_vals = updated_model.predict(obs)
        print(q_vals)
        return np.argmax(q_vals)

In [201]:
#Using expirence and past rewards this trains the model to output higher Q values for actions that have better rewards given the state(input)
def train_model():
    #Test to see if it has enough expirence
    if len(replay_memory) >= MIN_REPLAY_MEMORY_SIZE:
        #Get a ranodm batch of samples from memory
        batch = random.sample(replay_memory, BATCH_SIZE)
        
        #X[0] is the intial observations. Getting the Q outputs for the intial enviorment condtions batch
        X = np.array([x[0] for x in batch]).reshape(BATCH_SIZE,4)
        current_qs_list = updated_model.predict(X)
        
        #X[3] is the new observations. Getting the Q outputs for the new enviorment condtions batch
        future_qs_list = prediction_model.predict(np.array([x[3] for x in batch]).reshape(BATCH_SIZE,4))
        
        y = []
        #Loop calculates what the Q output for each sample should look like    
        for i, (obs, action, reward, new_obs, done) in enumerate(batch):
            #This logic allows Q values to predict future Q values
            if not done:                        
                max_future_q = np.max(future_qs_list[i]) 
                new_q = reward + DISCOUNT * max_future_q #Q value for action should be determined from what the best future Q value is when taking the action
            else:
                new_q = reward #If the action caused the enviorment to terminate then the Q value should be the reward it got

            #changing the expected Q value for the action taken in the sample
            current_qs = current_qs_list[i]
            current_qs[action] = new_q

            y.append(current_qs)
        #Fitting the model to the batch of samples
        #input: State
        #output: Q values with an updated Q value for the action taken in the sample
        updated_model.fit(X, np.array(y), batch_size=BATCH_SIZE, shuffle=False)  

In [205]:
#Trainer

#Stores expirence with enviorment
replay_memory = deque(maxlen = REPLAY_MEMORY_SIZE)

#Initialization
obs, info = env.reset(return_info=True)
prediction_update_couter = 0
atempt = 0

#Allows training to run many times
while atempt < 1000: #can set any number of enviorment runthroughs
    #Retrive the action that should be performed and perfoms it
    action = get_action(np.array(obs).reshape(-1, 4))
    new_obs, reward, done, info = env.step(action)
    
    #Whenever an action is perfromed information of the affect of that action is stored
    replay_memory.append([obs, action, reward, new_obs, done]);
    
    train_model()
    
    #Reset enviorment, decay epsilon, and updated counters when done
    if done:
        obs, info = env.reset(return_info=True)
        
        #epsilon decay
        if epsilon > MIN_EPSILON:
            epsilon *= EPSILON_DECAY
            epsilon = max(MIN_EPSILON, epsilon)
            
        prediction_update_couter += 1
        atempt += 1
    #If not done update the currant state
    else:
        obs = new_obs
    #If we have runthough the enviorment enough times save the model weights and update the prediction model    
    if(prediction_update_couter >= UPDATE_PREDICITON_EVERY):
        updated_model.save_weights('model_weights')
        prediction_model.set_weights(updated_model.get_weights())
        prediction_update_couter = 0


[[32.834576 31.321314]]
[[32.630146 32.06336 ]]
[[32.541485 32.99511 ]]
[[32.745472 32.518078]]
[[32.615128 33.44111 ]]
[[32.772274 32.95506 ]]
[[32.781914 32.27622 ]]
[[32.478485 33.00816 ]]
[[32.4665   32.334923]]
[[32.121956 33.077076]]
[[32.06719 32.41165]]
[[31.87528  31.559795]]
[[31.381647 32.130344]]
[[31.183292 31.295855]]
[[30.859108 30.280024]]
[[30.246588 30.691944]]
[[29.938208 29.702293]]
[[29.335451 30.137825]]
[[29.035051 29.170794]]
[[28.58251 28.05202]]
[[27.911858 28.306866]]
[[27.480936 27.225027]]
[[26.72829  27.362177]]
[[26.423943 26.449802]]
[[25.704405 25.342916]]
[[24.459152 24.294687]]
[[23.927319 24.34731 ]]
[[23.289167 23.301033]]
[[22.535538 22.08544 ]]
[[21.932821 21.99657 ]]
[[21.241282 20.813759]]
[[20.69464 20.75327]]
[[20.060602 19.597403]]
[[19.568848 19.560888]]
[[19.237206 19.716124]]
[[18.80228  18.766607]]
[[18.496078 18.92852 ]]
[[18.085821 17.984478]]
[[17.80645  18.151806]]
[[17.42191  17.211805]]
[[17.171457 17.38347 ]]
[[16.814243 16.446175]

[[18.040085 17.99084 ]]
[[18.115932 18.566462]]
[[18.169447 18.152378]]
[[18.120771 18.58724 ]]
[[18.268686 18.299032]]
[[18.117445 17.632483]]
[[18.08855  18.076984]]
[[18.041601 18.51292 ]]
[[18.189085 18.223083]]
[[18.038626 17.556273]]
[[18.010618 18.000519]]
[[17.980515 18.45553 ]]
[[18.113228 18.14613 ]]
[[17.963978 17.479082]]
[[17.937296 17.92309 ]]
[[17.94509  18.422234]]
[[18.04296  18.068224]]
[[17.895456 17.400934]]
[[17.870676 17.844696]]
[[17.945297 18.422447]]
[[17.980654 17.989319]]
[[17.835588 17.321762]]
[[17.813456 17.765245]]
[[17.932167 18.392889]]
[[17.92941  17.909271]]
[[17.999456 18.487032]]
[[18.019701 18.04089 ]]
[[17.865559 17.367624]]
[[17.834784 17.805876]]
[[17.945183 18.428709]]
[[17.934534 17.940731]]
[[17.785158 17.268847]]
[[17.759207 17.708336]]
[[17.87455 18.33228]]
[[17.868908 17.845253]]
[[17.972078 18.463226]]
[[17.954641 17.970798]]
[[17.798807 17.29482 ]]
[[17.766724 17.730562]]
[[17.87617 18.35106]]
[[17.864927 17.860916]]
[[17.919333 18.42325

[[19.103287 18.870033]]
[[19.145683 19.440891]]
[[19.070969 18.904736]]
[[19.1104  19.47796]]
[[19.031937 18.944   ]]
[[19.067183 19.519386]]
[[18.899042 19.425543]]
[[18.923311 19.03465 ]]
[[18.696552 18.32413 ]]
[[18.597303 18.72921 ]]
[[18.391663 18.032179]]
[[18.31149  18.449255]]
[[18.12483  17.763523]]
[[18.062202 18.1907  ]]
[[17.89293  17.514402]]
[[17.846752 17.950054]]
[[17.693752 17.28159 ]]
[[17.663332 17.724318]]
[[17.525917 17.06231 ]]
[[17.510977 17.510914]]
[[15.191888 15.669942]]
[[17.506315 18.041393]]
[[17.963171 18.148054]]
[[17.784992 17.467613]]
[[17.730236 17.899569]]
[[17.56865  17.227697]]
[[17.54227  17.676805]]
[[17.40888 17.02391]]
[[17.387754 17.4794  ]]
[[17.268375 16.837029]]
[[17.265778 17.299587]]
[[17.16225  16.660486]]
[[17.16044 17.12214]]
[[17.296154 17.766422]]
[[17.32001  17.301779]]
[[17.435883 17.934544]]
[[17.451077 17.460363]]
[[17.336784 16.802198]]
[[17.317274 17.245249]]
[[17.438627 17.872341]]
[[17.4636   17.395542]]
[[17.569807 18.015636]

[[17.740982 17.191353]]
[[17.67919  17.608023]]
[[17.755322 18.208565]]
[[17.78416  17.713692]]
[[17.849178 18.306427]]
[[17.883135 17.80662 ]]
[[17.94135  18.391724]]
[[17.981728 17.89367 ]]
[[18.03652 18.47555]]
[[18.086544 17.98008 ]]
[[18.131111 18.56067 ]]
[[18.18138  18.070591]]
[[18.219498 18.649406]]
[[18.273754 18.158741]]
[[18.300926 18.73439 ]]
[[18.352922 18.244034]]
[[18.36873  18.816797]]
[[18.421705 18.32503 ]]
[[18.421432 18.890825]]
[[18.472767 18.39901 ]]
[[18.462711 18.9644  ]]
[[18.510641 18.473494]]
[[18.488024 19.034828]]
[[18.548677 18.54712 ]]
[[18.529373 19.107895]]
[[18.597113 18.62265 ]]
[[18.534933 17.95582 ]]
[[18.376326 18.3337  ]]
[[18.359291 18.902664]]
[[18.448717 18.430357]]
[[18.413404 18.995026]]


[[18.4934  18.51937]]
[[18.441614 17.860561]]
[[18.279547 18.248468]]
[[18.260712 18.826996]]
[[18.35837  18.359108]]
[[18.321266 17.702784]]
[[18.160578 18.094532]]
[[18.142227 18.6747  ]]
[[18.244766 18.202902]]
[[18.211313 18.77497 ]]
[[18.307024 18.297945]]
[[18.261675 18.86767 ]]
[[18.350897 18.388468]]
[[18.137068 18.117636]]
[[16.310686 17.14786 ]]
[[18.445139 19.168856]]
[[18.519457 18.677717]]
[[18.468573 18.007566]]
[[18.282064 18.396713]]
[[18.250473 17.736822]]
[[18.084276 18.135967]]
[[18.068022 17.48192 ]]
[[17.910122 17.88471 ]]
[[17.895311 18.476667]]
[[18.011364 18.005054]]
[[17.974022 18.588533]]
[[18.089457 18.114662]]
[[18.071945 17.453451]]
[[17.928337 17.205158]]


[[17.797766 17.620731]]
[[17.80827  18.224483]]
[[17.970442 17.753498]]
[[17.97919  18.354061]]
[[18.1461   17.880732]]
[[18.154263 18.477005]]
[[18.321224 17.997423]]
[[18.316921 18.585241]]
[[18.488476 18.104359]]
[[18.478083 18.68834 ]]
[[18.647001 18.204689]]
[[18.635006 18.78859 ]]
[[18.805561 18.303062]]
[[18.791304 18.883133]]
[[18.959677 18.389948]]
[[18.937632 18.963976]]
[[19.103071 18.4677  ]]
[[19.077606 19.03777 ]]
[[19.181274 19.79117 ]]
[[19.462698 19.468834]]
[[19.596344 18.952694]]
[[19.53221  19.504778]]
[[19.600433 20.241938]]
[[19.844883 19.902506]]
[[19.940178 19.371004]]
[[19.827559 19.910162]]
[[19.91593  19.378773]]
[[19.807245 19.922546]]
[[19.906612 19.398413]]
[[19.799257 19.946505]]
[[19.900074 19.425365]]


[[19.795527 19.976852]]
[[19.900412 19.458878]]
[[19.7935   20.012983]]
[[19.899073 19.498466]]
[[19.781256 20.051771]]
[[19.88724  19.537903]]
[[19.759796 20.089348]]
[[19.856804 19.57739 ]]
[[19.723368 20.130651]]
[[19.829615 19.624727]]
[[19.698156 20.181635]]
[[19.811117 19.683414]]
[[19.676497 20.241991]]
[[19.79171  19.746252]]
[[19.651794 20.30563 ]]
[[19.766726 19.812315]]
[[19.753736 19.137285]]
[[19.489859 19.519886]]
[[19.503036 18.85807 ]]
[[19.257185 19.248304]]
[[19.15046 19.82672]]
[[19.310385 19.360306]]
[[19.340252 18.709044]]
[[19.102928 19.105104]]
[[19.15874  18.464018]]
[[18.935116 18.864649]]
[[19.032665 18.992764]]
[[18.921822 19.569921]]
[[19.104778 19.11297 ]]
[[19.152756 18.468521]]


[[18.90033 18.85813]]
[[18.778101 19.431463]]
[[18.950634 18.971766]]
[[18.991896 18.328835]]
[[18.722675 18.715607]]
[[18.595268 19.290916]]
[[18.794943 18.845304]]
[[18.86608  18.213163]]
[[18.618372 18.603987]]
[[18.506857 19.181864]]
[[18.730764 18.743872]]
[[18.828371 18.12143 ]]
[[18.587889 18.514503]]
[[18.48968  19.097443]]
[[18.633331 19.242464]]
[[18.875904 18.804747]]
[[18.752857 19.369354]]
[[18.98121  18.925138]]
[[18.835005 19.4793  ]]
[[19.043753 19.026327]]
[[18.877304 19.571268]]
[[19.066698 19.113043]]
[[19.124422 18.47102 ]]
[[18.822056 18.831633]]
[[18.906475 18.20316 ]]
[[18.61136  18.569283]]
[[18.452906 19.123379]]
[[18.673986 18.69044 ]]
[[18.765406 18.073023]]
[[18.484432 18.451817]]


[[18.342102 19.021603]]
[[18.57923  18.595114]]
[[18.679903 17.979614]]
[[18.408318 18.362803]]
[[18.276152 18.93516 ]]
[[18.524302 18.509018]]
[[18.380415 19.076822]]
[[18.619469 18.646591]]
[[18.729916 18.030619]]
[[18.462416 18.415182]]
[[18.333214 18.988184]]
[[18.588388 18.55998 ]]
[[18.44546 19.12773]]
[[18.689512 18.693974]]
[[18.794088 18.066654]]
[[18.51672  18.442348]]
[[18.626703 18.563969]]
[[18.4707   19.119282]]
[[18.6974   18.671791]]
[[18.525595 19.224194]]
[[18.740871 18.77262 ]]
[[18.826729 18.13529 ]]
[[18.53038  18.504965]]
[[18.376328 19.064192]]
[[18.616293 18.622227]]
[[18.723486 17.994587]]
[[18.44599  18.372425]]
[[18.307827 18.938135]]
[[18.555977 18.498941]]
[[18.4037   19.059101]]


[[18.644115 18.615726]]
[[18.476917 19.16981 ]]
[[18.698996 18.721268]]
[[18.789513 18.087399]]
[[18.487064 18.455408]]
[[18.324936 19.011887]]
[[18.558886 18.568087]]
[[18.661133 17.937439]]
[[18.367   18.30494]]
[[18.21697  18.861118]]
[[18.467606 18.420576]]
[[18.299995 18.970432]]
[[18.540695 18.530844]]
[[18.36511 19.08023]]
[[18.58514  18.630556]]
[[18.677208 17.996912]]
[[18.378223 18.36591 ]]
[[18.217108 18.924473]]
[[18.464561 18.484211]]
[[18.582302 17.857857]]
[[18.306284 18.233803]]
[[18.17187  18.798542]]
[[18.434734 18.362473]]
[[18.286745 18.92283 ]]
[[18.533484 18.481361]]
[[18.368162 19.03543 ]]
[[18.605255 18.590368]]
[[18.427471 19.141027]]
[[18.650213 18.692478]]
[[18.739311 18.05778 ]]
[[18.434801 18.428474]]


[[18.512833 18.548365]]
[[18.62157  17.922539]]
[[18.3275   18.294647]]
[[18.169897 18.85311 ]]
[[18.428522 18.417751]]
[[18.262505 18.971745]]
[[18.509686 18.528646]]
[[18.62189  17.898298]]
[[34.686554 32.104225]]
[[34.527206 32.819683]]
[[34.48276  33.719112]]
[[34.533623 34.794033]]
[[35.066727 34.637173]]
[[35.018684 35.671513]]
[[35.48938  35.494072]]
[[35.770004 35.10434 ]]
[[35.49082  35.919136]]
[[35.73007 35.52409]]
[[35.358578 36.333145]]
[[35.57284  35.935905]]
[[35.63308  35.348133]]
[[35.053486 35.997585]]
[[35.08889  35.419895]]
[[34.989964 34.658607]]
[[34.27618 35.15607]]
[[34.159153 34.409348]]
[[33.885067 33.489204]]
[[33.0734  33.82812]]
[[32.792934 32.935345]]
[[32.264465 31.896173]]


[[31.380459 32.036697]]
[[30.862612 31.037827]]
[[30.182743 29.89963 ]]
[[29.26348 29.87256]]
[[28.630571 28.769854]]
[[27.888847 27.497873]]
[[26.88096  27.333307]]
[[26.16488  26.089119]]
[[25.235142 25.967182]]
[[24.59315 24.76449]]
[[23.80859  23.377956]]
[[22.82955  23.124998]]
[[22.144564 21.781826]]
[[21.272413 21.576078]]
[[20.677511 20.268421]]
[[19.900995 20.101738]]
[[19.39007 18.82481]]
[[18.708252 18.690386]]
[[18.201447 18.755316]]
[[17.914713 17.692707]]
[[17.448624 17.765835]]
[[17.204    16.711237]]
[[16.792158 16.795767]]
[[16.546568 15.731115]]
[[16.1501   15.809478]]
[[15.914221 16.077429]]
[[15.853518 15.195587]]
[[15.653147 15.458989]]
[[15.569993 15.894645]]
[[15.631516 15.163956]]
[[15.568052 15.585152]]


[[15.6398    14.8380785]]
[[15.600287 15.248869]]
[[15.702885 15.836869]]
[[15.906715 15.237149]]
[[16.002728 15.798946]]
[[16.234604 16.536541]]
[[16.563612 16.087156]]
[[16.766027 16.792803]]
[[17.080666 16.317533]]
[[17.265566 16.995275]]
[[17.578783 17.847467]]
[[17.99165 17.51432]]
[[18.257204 18.326866]]
[[18.63461  17.959364]]
[[18.863157 18.742693]]
[[19.174236 19.689735]]
[[19.585098 19.451637]]
[[19.846516 20.36599 ]]
[[20.20908  20.095156]]
[[20.416685 20.979334]]
[[20.726873 20.680035]]
[[20.845285 21.486818]]
[[21.159767 21.22312 ]]
[[21.23551  19.995869]]
[[21.090847 20.465319]]
[[21.087328 21.126335]]
[[21.196579 20.608559]]
[[21.188494 21.266443]]
[[21.29616  20.746767]]
[[21.280226 21.400036]]


[[21.378675 20.874588]]
[[21.351967 21.525583]]
[[21.440826 21.00049 ]]
[[21.404774 21.650215]]
[[21.416872 20.4093  ]]
[[21.243475 20.87802 ]]
[[21.204714 21.533127]]
[[21.27781  21.011251]]
[[21.219303 21.660055]]
[[21.272131 21.135271]]
[[21.19684  21.783533]]
[[21.240925 21.25958 ]]
[[21.14268  20.551289]]
[[20.925606 21.020737]]
[[20.829922 20.320156]]
[[20.61021  20.792864]]
[[20.519217 20.097754]]
[[20.306623 20.57908 ]]
[[20.229181 19.89483 ]]
[[20.029991 20.387125]]
[[19.963987 19.709242]]
[[19.770853 20.203794]]
[[19.7116   19.530703]]
[[19.519262 20.02584 ]]
[[19.465927 19.358025]]
[[19.275047 19.856413]]
[[19.232132 19.195614]]
[[19.051105 19.701176]]
[[19.009186 19.045712]]
[[18.834251 18.207842]]


[[18.535343 18.540064]]
[[18.355621 17.704481]]
[[18.056557 18.042015]]
[[17.908928 18.572624]]
[[17.89997 17.93491]]
[[17.762402 17.11197 ]]
[[17.514618 17.466492]]
[[17.379667 18.000578]]
[[17.37646  17.359959]]
[[17.267054 17.900925]]
[[17.29576  17.268896]]
[[17.212097 17.815996]]
[[17.272112 17.193186]]
[[17.209978 17.74486 ]]
[[17.286268 17.125353]]
[[17.233984 17.676008]]
[[17.322437 17.056952]]
[[17.271757 17.60351 ]]
[[17.362825 16.982737]]
[[17.319485 17.52615 ]]
[[17.412203 16.895765]]
[[17.371141 17.428684]]
[[17.471138 16.794102]]
[[17.434925 17.321665]]
[[17.495087 18.021507]]
[[17.679007 17.546877]]
[[17.72113  18.229351]]
[[17.88234 17.73556]]
[[17.912321 18.405434]]
[[18.063025 17.903774]]
[[18.088688 18.56864 ]]


[[18.2284   18.060461]]
[[18.24736  18.718508]]
[[18.384296 18.206577]]
[[18.398396 18.857351]]
[[18.531033 18.342287]]
[[18.545704 18.99237 ]]
[[18.682102 18.47845 ]]
[[18.691303 19.123419]]
[[18.815722 18.607119]]
[[18.8108   19.246546]]
[[18.912498 18.723356]]
[[18.88637  19.352804]]
[[18.980635 18.828152]]
[[18.940569 19.453049]]
[[19.023394 18.929508]]
[[18.970833 19.552927]]
[[19.044964 19.030523]]
[[18.979961 19.65396 ]]
[[19.038857 19.132925]]
[[18.96133  18.428238]]
[[18.75388  18.872402]]
[[18.685398 18.175585]]
[[18.487066 18.625492]]
[[18.43781  17.940609]]
[[18.258276 18.404787]]
[[18.237803 17.736116]]
[[18.07522  18.204748]]
[[18.074911 17.543993]]
[[17.920698 18.016308]]
[[17.93386  17.358213]]
[[17.793756 17.830608]]


[[17.829645 17.180593]]
[[17.709883 17.657803]]
[[17.725792 18.316833]]
[[17.91281  17.854307]]
[[17.921375 18.506794]]
[[18.099277 18.03713 ]]
[[18.057674 18.66657 ]]
[[18.184717 18.174553]]
[[18.134066 18.79599 ]]
[[18.255154 18.298353]]
[[18.243647 17.617622]]
[[18.072289 18.059885]]
[[18.050726 18.692692]]
[[18.191456 18.205942]]
[[18.19765  17.534729]]
[[18.05071 17.98722]]
[[18.04674  18.626831]]
[[18.204216 18.150131]]
[[18.18289 18.78273]]
[[18.32009  18.297117]]
[[18.278051 18.91941 ]]
[[18.402218 18.429014]]
[[18.392859 17.756649]]
[[18.222458 18.199255]]
[[18.186495 18.827185]]
[[18.313807 18.34485 ]]
[[18.299654 17.67643 ]]
[[18.114538 18.114576]]
[[18.121798 17.458088]]
[[17.954605 17.908886]]
[[17.924278 18.545122]]


[[18.072256 18.075739]]
[[18.087885 17.42071 ]]
[[17.92012  17.867796]]
[[17.881119 18.49491 ]]
[[18.031559 18.02076 ]]
[[17.986193 18.642292]]
[[18.130867 18.16377 ]]
[[18.146683 17.502579]]
[[17.934618 17.929564]]
[[17.871386 18.545113]]
[[18.003647 18.064392]]
[[18.01062  17.401476]]
[[17.838526 17.843866]]
[[17.883614 17.195032]]
[[17.75424  17.652164]]
[[17.766037 18.292294]]
[[17.973377 17.832039]]
[[17.976854 18.465055]]
[[18.141535 17.989439]]
[[18.076494 18.599867]]
[[18.203678 18.106567]]
[[18.155102 18.719736]]
[[18.288055 18.224634]]
[[18.259819 18.844034]]
[[18.398708 18.35212 ]]
[[18.379396 18.976772]]
[[18.523825 18.48548 ]]
[[18.508684 19.111425]]
[[18.64696  18.616722]]
[[18.623848 19.235392]]
[[18.750845 18.732918]]


[[18.716043 19.344519]]
[[18.824837 18.835386]]
[[18.793684 18.141897]]
[[18.61559  18.565619]]
[[18.5771   19.176094]]
[[18.684092 18.671364]]
[[18.631285 19.27481 ]]
[[18.735703 18.773396]]
[[18.700115 18.087324]]
[[18.513096 18.50743 ]]
[[18.456182 19.108484]]
[[18.563116 18.610254]]
[[18.535212 17.930637]]
[[18.344576 18.351297]]
[[18.332329 17.685492]]
[[18.153255 18.113329]]
[[18.104671 18.723272]]
[[18.224554 18.243975]]
[[18.21399 17.58626]]
[[18.037008 18.021688]]
[[17.993467 18.640284]]
[[18.13557  18.178993]]
[[18.113796 17.52771 ]]
[[17.9156   17.962042]]
[[17.915623 17.319279]]
[[17.744343 17.76581 ]]
[[17.77963 17.13338]]
[[17.645851 17.591091]]
[[17.656748 18.233751]]
[[17.8411   17.785173]]
[[17.839535 18.421314]]


[[18.014904 17.966166]]
[[18.004343 18.595089]]
[[18.17274  18.134346]]
[[18.14279 18.74936]]
[[18.29707  18.279013]]
[[18.25347  18.887001]]
[[18.396595 18.41484 ]]
[[18.412096 17.762186]]
[[18.236664 18.188875]]
[[18.202925 18.799725]]
[[18.364506 18.332172]]
[[18.313265 18.932701]]
[[18.464956 18.459694]]
[[18.396893 19.053402]]
[[18.52635  18.574228]]
[[18.519526 17.910515]]
[[18.308031 18.316082]]
[[18.321478 17.660671]]
[[18.12637  18.069557]]
[[18.066986 18.66096 ]]
[[18.225584 18.19226 ]]
[[18.157135 18.781052]]
[[18.310484 18.314863]]
[[18.335196 17.66899 ]]
[[18.146078 18.08296 ]]
[[18.093138 18.681046]]
[[18.263287 18.223125]]
[[16.784492 17.806244]]
[[18.528605 19.292475]]
[[18.65315  18.816313]]


[[18.645    18.161049]]
[[18.412903 18.561493]]
[[18.426006 17.920399]]
[[18.208864 18.329529]]
[[18.230665 17.69479 ]]
[[18.017185 18.107174]]
[[18.050192 17.478786]]
[[17.901379 17.279718]]
[[17.724884 17.710566]]
[[17.685062 18.325012]]
[[17.877504 17.888521]]
[[17.933783 17.267027]]
[[17.714579 17.681305]]
[[17.646711 18.287342]]
[[17.81757  17.841581]]
[[17.852436 17.210117]]
[[17.677763 17.642656]]
[[17.649708 18.26179 ]]
[[17.843023 17.816193]]
[[17.815296 18.43133 ]]
[[18.003044 17.975637]]
[[17.980118 18.588062]]
[[18.170794 18.127954]]
[[18.14615  18.736359]]
[[18.331419 18.271357]]
[[18.301485 18.874092]]
[[18.482477 18.402376]]
[[18.442358 18.996643]]
[[18.611816 18.518814]]
[[18.557325 19.102379]]


[[18.709173 18.617699]]
[[18.631983 19.192253]]
[[18.764584 18.702316]]
[[18.659666 19.267168]]
[[18.776745 18.777153]]
[[18.76281  18.109089]]
[[18.529562 18.49933 ]]
[[18.427864 19.07416 ]]
[[18.560932 18.602352]]
[[18.525064 17.938131]]
[[18.27196 18.32981]]
[[18.273514 17.681206]]
[[18.059677 18.091179]]
[[18.088264 17.455158]]
[[17.906582 17.880764]]
[[17.873716 18.494783]]
[[18.06984  18.054295]]
[[18.0017   18.652409]]
[[18.15974  18.195086]]
[[18.1879   17.554926]]
[[18.016733 17.98231 ]]
[[17.98565 18.59419]]
[[18.170107 18.143023]]
[[18.137146 18.748734]]
[[18.312016 18.293327]]
[[18.278776 18.899586]]
[[18.454233 18.44417 ]]
[[18.408861 19.044094]]
[[18.581633 18.59109 ]]
[[18.627106 17.960499]]
[[18.452154 18.380901]]


[[18.40642 18.97891]]
[[18.582829 18.52436 ]]
[[18.518444 19.107904]]
[[18.674107 18.64324 ]]
[[18.585714 19.214739]]
[[18.721222 18.741602]]
[[18.726845 18.088604]]
[[18.503527 18.476027]]
[[18.416426 19.045176]]
[[18.553865 18.575071]]
[[18.559443 17.925468]]
[[18.330069 18.308922]]
[[18.239668 18.876953]]
[[18.39706  18.420544]]
[[18.426609 17.787262]]
[[18.212759 18.179466]]
[[18.097782 18.742735]]
[[18.245628 18.291525]]
[[18.18579 16.86346]]
[[17.88566 17.09589]]
[[17.744234 17.519764]]
[[17.746853 18.125452]]
[[18.009369 17.718468]]
[[17.924154 18.84823 ]]
[[18.493345 18.848455]]
[[18.714766 18.419964]]
[[18.655098 18.989824]]
[[18.864351 18.556145]]
[[18.792236 19.120897]]


[[18.998186 18.688002]]
[[18.908615 19.244352]]
[[19.096674 18.804798]]
[[18.991598 19.35229 ]]
[[19.161966 18.905344]]
[[19.033426 19.444317]]
[[19.197443 18.997017]]
[[19.060314 19.531748]]
[[19.214996 19.077713]]
[[19.068533 19.605988]]
[[19.214636 19.149303]]
[[19.054796 19.671455]]
[[19.015566 19.729013]]
[[19.141617 19.270653]]
[[19.127995 18.632168]]
[[18.826126 18.978437]]
[[18.84137  18.355614]]
[[18.570793 18.7193  ]]
[[18.613512 18.112648]]
[[18.364471 18.491987]]
[[18.43405  17.902908]]
[[18.21021  18.295326]]
[[18.301535 17.715637]]
[[18.100138 18.115105]]
[[18.20982  17.540993]]
[[18.024076 17.946411]]
[[17.96821 18.52902]]
[[18.208319 18.131794]]
[[18.123043 18.701817]]
[[18.339128 18.29557 ]]
[[18.234413 18.857212]]


[[18.432049 18.442144]]
[[18.496918 17.846037]]
[[18.255795 18.223766]]
[[18.13762  18.778112]]
[[18.33172  18.360907]]
[[18.395313 17.763117]]
[[18.231611 17.544018]]
[[17.962276 17.908665]]
[[17.834778 18.462286]]
[[18.030224 18.048594]]
[[18.099167 17.455986]]
[[17.874521 17.843256]]
[[17.799374 18.416325]]
[[18.040028 18.018663]]
[[17.972881 18.595007]]
[[18.21947 18.19784]]
[[18.146463 18.76892 ]]
[[18.3915  18.36672]]
[[18.32088  18.933262]]
[[18.56089 18.52598]]
[[18.477648 19.08557 ]]
[[18.697956 18.670042]]
[[18.5491   19.200249]]
[[18.729372 18.770277]]
[[18.774351 18.156927]]
[[18.518187 18.514606]]
[[18.361626 19.04338 ]]
[[18.511028 18.605795]]
[[18.53714  17.990084]]
[[18.285704 18.353825]]


[[18.347822 17.752928]]
[[18.136953 18.134516]]
[[18.067911 18.699196]]
[[18.293028 18.29173 ]]
[[18.231216 18.855637]]
[[18.456324 18.444239]]
[[18.389029 19.004671]]
[[18.610811 18.592192]]
[[18.53441 19.14572]]
[[18.74744  18.730957]]
[[18.658173 19.276997]]
[[18.858097 18.857176]]
[[18.756254 19.39686 ]]
[[18.946274 18.975983]]
[[18.949324 18.355955]]
[[18.67011 18.70021]]
[[18.69396  18.088032]]
[[18.432253 18.441755]]
[[18.477642 17.839611]]
[[18.246788 18.208035]]
[[18.159317 18.759247]]
[[18.360842 18.348888]]
[[18.268116 18.897564]]
[[18.4669  18.48687]]
[[18.542461 17.900438]]
[[18.33497  18.275105]]
[[18.266827 18.832315]]
[[18.491306 18.431541]]
[[18.406858 18.978994]]
[[18.617624 18.57177 ]]
[[18.51917 19.11096]]


[[18.724487 18.699936]]
[[18.622072 19.234377]]
[[18.813854 18.817915]]
[[18.876705 18.226025]]
[[18.527323 19.10793 ]]
[[18.72889 18.70098]]
[[18.601238 19.222334]]
[[18.78632  18.811188]]
[[18.838385 18.22038 ]]
[[18.575674 18.561705]]
[[18.411703 19.074043]]


In [199]:
def get_best_action(obs):
    q_vals = updated_model.predict(obs, verbose = 0)
    return np.argmax(q_vals)

In [207]:
#Run the currant model with best actions
obs, info = env.reset(return_info=True)
done = False
score = 0
while not done:
    obs, reward, done, info = env.step(get_best_action(np.array(obs).reshape(-1, 4)))
    score += reward
    env.render()
print(score)

500.0


In [170]:
len(replay_memory)

4837

In [173]:
epsilon

0.025817827831345935