In [1]:
import gym
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from collections import deque

In [2]:
# keras model approach
from tensorflow.keras import Model,Sequential
from tensorflow.keras.layers import Conv2D,MaxPooling2D,Flatten,BatchNormalization,Dense, Input
from tensorflow.keras.activations import relu

In [3]:
# ADDING THE CODE SO THAT TENSORFLOW DOES NOT EAT THE WHOLE GPU MEMORY
physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
tf.keras.backend.set_floatx('float32')

In [4]:
env = gym.make('MountainCar-v0')
env.action_space

Discrete(3)

In [5]:
adam = tf.keras.optimizers.Adam(learning_rate = 0.001)

In [6]:
def model_keras():
    
    inputs = Input(shape=(2,))
    
    x = Dense(100,activation='relu',kernel_initializer="glorot_uniform")(inputs)
    x = BatchNormalization()(x)
    x = Dense(100,activation='relu',kernel_initializer="glorot_uniform")(x)
    x = BatchNormalization()(x)
    x = Dense(100,activation='relu',kernel_initializer="glorot_uniform")(x)
    x = BatchNormalization()(x)
    x = Dense(100,activation='relu',kernel_initializer="glorot_uniform")(x)
    x = BatchNormalization()(x)
    output = Dense(3,activation='linear',kernel_initializer="glorot_uniform")(x)
    model = Model(inputs=inputs, outputs=output, name="RL_Value_Function")
    
    print(model.summary())
    
    model.compile(optimizer=adam,loss='mean_squared_error',metrics=['mean_squared_error'])
    
    return model
# model 2 is the target model
model_1 = model_keras()
model_2 = model_keras()

Model: "RL_Value_Function"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 2)]               0         
_________________________________________________________________
dense (Dense)                (None, 100)               300       
_________________________________________________________________
batch_normalization (BatchNo (None, 100)               400       
_________________________________________________________________
dense_1 (Dense)              (None, 100)               10100     
_________________________________________________________________
batch_normalization_1 (Batch (None, 100)               400       
_________________________________________________________________
dense_2 (Dense)              (None, 100)               10100     
_________________________________________________________________
batch_normalization_2 (Batch (None, 100)         

In [7]:
def custom_loss(y_true,y_pred):
    return tf.keras.losses.mean_squared_error(y_true,y_pred)

In [8]:
# lets build some memory into the model to perform decorrelated batch updates
# this is TD learning

# so apparantly the NN has to be adjusted only according to the action taken by it
# for example if action 2 is taken then only the weights for action 2 should be changed
# for this reason we should only update the q_2 vector with regarding the chosen action
# rewards will also be added to that action only
# q_2 is what we thought the value of the state will be after doing action
# we also add the reward and make this the target for the NN

replay_batch = deque(maxlen = 3000)

warmup = 10 #will start training after these many episodes have passed

# to balance exploration
epsilon = {
"epsilon" : 1.0,
"epsilon_decay": 0.999,
"epsilon_min":0.01,
}

In [9]:
import random
random.seed(2020)

@tf.function
def train_step(model,new_state,new_rew,old_q,lamb): 
    
    with tf.GradientTape() as tape:
        # logits is the forward pass
        logits = model(new_state, training=True)
        
        loss_value = custom_loss(create_y_true(logits,new_rew,lamb),old_q)
    
    #we retrieve the gradients
    grads = tape.gradient(loss_value, model.trainable_weights)
    
    #THIS IS ONE STEP OF GRAD DESCENT (Minimizes the loss)
    adam.apply_gradients(zip(grads, model.trainable_weights))
    #model.fit(x = inputs,y = outputs,batch_size = 1,epochs = 1,verbose = 0)

#@tf.function
def batch_train(model_1,model_2,gamma,batch_size,epsilon):
    
    #decaying the exploration
    if epsilon['epsilon'] > epsilon['epsilon_min']:
         epsilon['epsilon'] =  epsilon['epsilon'] * epsilon['epsilon_decay']
   
    batch = random.sample(replay_batch,batch_size)
    
    batch_reward = []
    batch_action = []
    batch_done = []
    
    batch_current_state = np.zeros((batch_size, 2))
    batch_next_state = np.zeros((batch_size, 2))

    for i in range(batch_size):
        batch_reward.append(batch[i][2])
        batch_action.append(batch[i][1])
        batch_current_state[i] = batch[i][0]
        batch_next_state[i] = batch[i][3]
        batch_done.append(batch[i][4])
    #lets calculate the next state value as the current value will be calculated in 
    # in gradient tape
    
    next_q = model_2.predict(batch_next_state)
    
    max_q = []
    for i in next_q:
        max_q.append(max(i))
    max_q = np.array(max_q,dtype = 'float32')
    
    target = batch_reward + gamma*max_q # this is the Q learning Target
    
    logits = model_1(batch_current_state)
    
    q_target = np.array(logits)

    # replacing the q values of logits for which the action is taken
    # as only those have to be updated
    
    # VERY IMPORTANT NOTE, IF THE EPISODE ENDS THE DONE VALUE BECOMES TRUE
    # IT IS VERY IMPORTANT THAT THE NN UPDATES TOWARDS THIS TRUE VALUE RATHER THAN
    # ITS OWN THINKING VALUE (r + gamma*max(action)) THAT WE USE FOR ALL
    # NON TERMINAL REWARDS 
    # THIS MAKES OR BREAKS THE NETWORK VERY VERY IMPORTANT
    for i in range(batch_size):
        q_target[i][batch_action[i]] = target[i]
        if batch_done[i]:
            q_target[i][batch_action[i]] = batch_reward[i]
    
    model_1.fit(batch_current_state, q_target, batch_size=batch_size,
                       epochs=1, verbose=0)
    
    '''with tf.GradientTape() as tape:
        # logits is the forward pass
        logits = model(batch_current_state, training=True)
        
        q_target = np.array(logits)
        
        # replacing the q values of logits for which the action is taken
        # as only those have to be updated
        q_target[:,batch_action] = target
        # calculating the loss
        loss_value = custom_loss(q_target,logits)
    
    #we retrieve the gradients
    grads = tape.gradient(loss_value, model.trainable_weights)
    
    #THIS IS ONE STEP OF GRAD DESCENT (Minimizes the loss)
    adam.apply_gradients(zip(grads, model.trainable_weights))'''

def policy(q_vals,eps):
    # lets implement a policy which decays
    if np.random.rand() <= eps:  
        return random.randrange(3)
    else:
        action = np.argmax(q_vals[0])
        return action
def update_target_network():
    model_2.set_weights(model_1.get_weights())

In [None]:
global_steps = 0
# to have same networks in the starting
update_target_network()

for i in tqdm(range(500)):
    observation = env.reset()
    observation = np.expand_dims(observation, axis=0)
    done = False
    total_reward = 0
    for j in range(1000):        
        #storing the current state
        state_1 = observation
        
        # this is the current q values
        q_state = model_1(observation)
        
        action = policy(q_state,epsilon["epsilon"]) # current action
        observation,reward,done,info = env.step(action)
        
        # calculating the total reward
        total_reward = total_reward + reward

        observation = np.expand_dims(observation, axis=0)
        state_2 = observation
        state_reward = reward

        if j<990 and done:
            state_reward = 1000
        if j>998 and done == False:
            reward = -100
        
        replay_batch.append((state_1,action,state_reward,state_2,done))
        
        if i>warmup:
            batch_train(model_1,model_2,0.99,64,epsilon)
            global_steps = global_steps + 1
            
        if done:
            update_target_network()
            break
            
    print(total_reward)

  0%|          | 0/500 [00:00<?, ?it/s]



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



  0%|          | 1/500 [00:01<10:24,  1.25s/it]

-200.0


  0%|          | 2/500 [00:02<09:54,  1.19s/it]

-200.0


  1%|          | 3/500 [00:03<09:34,  1.15s/it]

-200.0


  1%|          | 4/500 [00:04<09:31,  1.15s/it]

-200.0


  1%|          | 5/500 [00:05<09:24,  1.14s/it]

-200.0


  1%|          | 6/500 [00:06<09:14,  1.12s/it]

-200.0


  1%|▏         | 7/500 [00:07<09:14,  1.12s/it]

-200.0


  2%|▏         | 8/500 [00:08<09:11,  1.12s/it]

-200.0


  2%|▏         | 9/500 [00:10<09:07,  1.11s/it]

-200.0


  2%|▏         | 10/500 [00:11<09:00,  1.10s/it]

-200.0


  2%|▏         | 11/500 [00:12<09:07,  1.12s/it]

-200.0


  2%|▏         | 12/500 [00:27<42:37,  5.24s/it]

-200.0


  3%|▎         | 13/500 [00:41<1:03:58,  7.88s/it]

-200.0


  3%|▎         | 14/500 [00:54<1:16:16,  9.42s/it]

-200.0


  3%|▎         | 15/500 [01:07<1:26:44, 10.73s/it]

-200.0


  3%|▎         | 16/500 [01:21<1:33:20, 11.57s/it]

-200.0


  3%|▎         | 17/500 [01:35<1:40:08, 12.44s/it]

-200.0


  4%|▎         | 18/500 [01:50<1:44:03, 12.95s/it]

-200.0


  4%|▍         | 19/500 [02:02<1:43:00, 12.85s/it]

-200.0


  4%|▍         | 20/500 [02:17<1:46:16, 13.28s/it]

-200.0


  4%|▍         | 21/500 [02:29<1:44:35, 13.10s/it]

-200.0


  4%|▍         | 22/500 [02:45<1:49:59, 13.81s/it]

-200.0


  5%|▍         | 23/500 [02:58<1:47:32, 13.53s/it]

-200.0


  5%|▍         | 24/500 [03:13<1:52:28, 14.18s/it]

-200.0


  5%|▌         | 25/500 [03:26<1:47:47, 13.62s/it]

-200.0


  5%|▌         | 26/500 [03:38<1:44:57, 13.29s/it]

-200.0


  5%|▌         | 27/500 [03:54<1:50:35, 14.03s/it]

-200.0


  6%|▌         | 28/500 [04:06<1:46:12, 13.50s/it]

-200.0


  6%|▌         | 29/500 [04:18<1:42:44, 13.09s/it]

-200.0


  6%|▌         | 30/500 [04:30<1:40:26, 12.82s/it]

-200.0


  6%|▌         | 31/500 [04:47<1:49:50, 14.05s/it]

-200.0


  6%|▋         | 32/500 [05:00<1:45:21, 13.51s/it]

-200.0


  7%|▋         | 33/500 [05:12<1:42:38, 13.19s/it]

-200.0


  7%|▋         | 34/500 [05:25<1:40:54, 12.99s/it]

-200.0


  7%|▋         | 35/500 [05:43<1:53:02, 14.59s/it]

-200.0


  7%|▋         | 36/500 [05:56<1:48:21, 14.01s/it]

-200.0


  7%|▋         | 37/500 [06:08<1:45:37, 13.69s/it]

-200.0


  8%|▊         | 38/500 [06:21<1:41:59, 13.25s/it]

-200.0


  8%|▊         | 39/500 [06:33<1:39:02, 12.89s/it]

-200.0


  8%|▊         | 40/500 [06:45<1:37:00, 12.65s/it]

-200.0


  8%|▊         | 41/500 [07:04<1:51:22, 14.56s/it]

-200.0


  8%|▊         | 42/500 [07:17<1:47:24, 14.07s/it]

-200.0


  9%|▊         | 43/500 [07:30<1:45:59, 13.92s/it]

-200.0


  9%|▉         | 44/500 [07:44<1:44:44, 13.78s/it]

-200.0


  9%|▉         | 45/500 [07:56<1:41:28, 13.38s/it]

-200.0


  9%|▉         | 46/500 [08:10<1:41:50, 13.46s/it]

-200.0


  9%|▉         | 47/500 [08:22<1:39:33, 13.19s/it]

-200.0


 10%|▉         | 48/500 [08:44<1:57:36, 15.61s/it]

-200.0


 10%|▉         | 49/500 [08:56<1:49:55, 14.62s/it]

-200.0


 10%|█         | 50/500 [09:08<1:44:46, 13.97s/it]

-200.0


 10%|█         | 51/500 [09:21<1:40:50, 13.48s/it]

-200.0


 10%|█         | 52/500 [09:33<1:37:50, 13.10s/it]

-200.0


 11%|█         | 53/500 [09:45<1:35:30, 12.82s/it]

-200.0


 11%|█         | 54/500 [09:57<1:33:56, 12.64s/it]

-200.0


 11%|█         | 55/500 [10:10<1:32:55, 12.53s/it]

-200.0


 11%|█         | 56/500 [10:33<1:55:55, 15.66s/it]

-200.0


 11%|█▏        | 57/500 [10:45<1:48:46, 14.73s/it]

-200.0


 12%|█▏        | 58/500 [10:58<1:43:10, 14.01s/it]

-200.0


 12%|█▏        | 59/500 [11:10<1:40:09, 13.63s/it]

-200.0


 12%|█▏        | 60/500 [11:24<1:39:26, 13.56s/it]

-200.0


 12%|█▏        | 61/500 [11:36<1:37:08, 13.28s/it]

-200.0


 12%|█▏        | 62/500 [11:49<1:36:34, 13.23s/it]

-200.0


 13%|█▎        | 63/500 [12:03<1:36:26, 13.24s/it]

-200.0


 13%|█▎        | 64/500 [12:19<1:42:00, 14.04s/it]

-200.0


 13%|█▎        | 65/500 [12:37<1:52:03, 15.46s/it]

-200.0


 13%|█▎        | 66/500 [12:58<2:02:15, 16.90s/it]

-200.0


 13%|█▎        | 67/500 [13:26<2:26:47, 20.34s/it]

-200.0


 14%|█▎        | 68/500 [13:39<2:10:50, 18.17s/it]

-200.0


 14%|█▍        | 69/500 [13:52<1:59:20, 16.61s/it]

-200.0


 14%|█▍        | 70/500 [14:05<1:51:09, 15.51s/it]

-200.0


 14%|█▍        | 71/500 [14:18<1:45:07, 14.70s/it]

-200.0


 14%|█▍        | 72/500 [14:30<1:40:26, 14.08s/it]

-200.0


 15%|█▍        | 73/500 [14:43<1:36:42, 13.59s/it]

-200.0


 15%|█▍        | 74/500 [14:56<1:34:23, 13.30s/it]

-200.0


 15%|█▌        | 75/500 [15:08<1:32:33, 13.07s/it]

-200.0


 15%|█▌        | 76/500 [15:21<1:31:55, 13.01s/it]

-200.0


 15%|█▌        | 77/500 [15:33<1:30:39, 12.86s/it]

-200.0


 16%|█▌        | 78/500 [15:46<1:29:16, 12.69s/it]

-200.0


 16%|█▌        | 79/500 [16:14<2:02:08, 17.41s/it]

-200.0


 16%|█▌        | 80/500 [16:27<1:51:17, 15.90s/it]

-200.0


 16%|█▌        | 81/500 [16:39<1:43:57, 14.89s/it]

-200.0


 16%|█▋        | 82/500 [16:52<1:39:10, 14.24s/it]

-200.0


 17%|█▋        | 83/500 [17:04<1:35:25, 13.73s/it]

-200.0


 17%|█▋        | 84/500 [17:17<1:32:38, 13.36s/it]

-200.0


 17%|█▋        | 85/500 [17:29<1:30:19, 13.06s/it]

-200.0


 17%|█▋        | 86/500 [17:42<1:28:48, 12.87s/it]

-200.0


 17%|█▋        | 87/500 [17:54<1:27:42, 12.74s/it]

-200.0


 18%|█▊        | 88/500 [18:07<1:27:05, 12.68s/it]

-200.0


 18%|█▊        | 89/500 [18:19<1:26:45, 12.66s/it]

-200.0


 18%|█▊        | 90/500 [18:32<1:25:57, 12.58s/it]

-200.0


 18%|█▊        | 91/500 [18:44<1:25:19, 12.52s/it]

-200.0


 18%|█▊        | 92/500 [18:57<1:25:14, 12.54s/it]

-200.0


 19%|█▊        | 93/500 [19:09<1:24:51, 12.51s/it]

-200.0


 19%|█▉        | 94/500 [19:21<1:24:20, 12.47s/it]

-200.0


 19%|█▉        | 95/500 [19:53<2:03:51, 18.35s/it]

-200.0


 19%|█▉        | 96/500 [20:06<1:51:58, 16.63s/it]

-200.0


 19%|█▉        | 97/500 [20:19<1:43:19, 15.38s/it]

-200.0


 20%|█▉        | 98/500 [20:31<1:37:11, 14.51s/it]

-200.0


 20%|█▉        | 99/500 [20:43<1:32:42, 13.87s/it]

-200.0


 20%|██        | 100/500 [20:56<1:30:08, 13.52s/it]

-200.0


 20%|██        | 101/500 [21:09<1:28:24, 13.29s/it]

-200.0


 20%|██        | 102/500 [21:21<1:26:26, 13.03s/it]

-200.0


 21%|██        | 103/500 [21:34<1:25:48, 12.97s/it]

-200.0


 21%|██        | 104/500 [21:46<1:24:26, 12.79s/it]

-200.0


 21%|██        | 105/500 [21:59<1:23:34, 12.69s/it]

-200.0


 21%|██        | 106/500 [22:11<1:22:53, 12.62s/it]

-200.0


 21%|██▏       | 107/500 [22:24<1:22:56, 12.66s/it]

-200.0


 22%|██▏       | 108/500 [22:37<1:22:10, 12.58s/it]

-200.0


 22%|██▏       | 109/500 [22:49<1:20:51, 12.41s/it]

-200.0


 22%|██▏       | 110/500 [23:01<1:21:19, 12.51s/it]

-200.0


 22%|██▏       | 111/500 [23:13<1:20:09, 12.36s/it]

-200.0


 22%|██▏       | 112/500 [23:26<1:19:51, 12.35s/it]

-200.0


 23%|██▎       | 113/500 [23:38<1:19:42, 12.36s/it]

-200.0


 23%|██▎       | 114/500 [23:50<1:19:02, 12.29s/it]

-200.0


 23%|██▎       | 115/500 [24:28<2:08:30, 20.03s/it]

-200.0


 23%|██▎       | 116/500 [24:41<1:53:20, 17.71s/it]

-200.0


 23%|██▎       | 117/500 [24:53<1:43:32, 16.22s/it]

-200.0


 24%|██▎       | 118/500 [25:06<1:37:05, 15.25s/it]

-200.0


 24%|██▍       | 119/500 [25:19<1:32:00, 14.49s/it]

-200.0


 24%|██▍       | 120/500 [25:32<1:28:11, 13.92s/it]

-200.0


 24%|██▍       | 121/500 [25:45<1:27:56, 13.92s/it]

-200.0


 24%|██▍       | 122/500 [25:59<1:26:17, 13.70s/it]

-200.0


 25%|██▍       | 123/500 [26:12<1:24:44, 13.49s/it]

-200.0


 25%|██▍       | 124/500 [26:24<1:22:04, 13.10s/it]

-200.0


 25%|██▌       | 125/500 [26:36<1:19:47, 12.77s/it]

-200.0


 25%|██▌       | 126/500 [26:48<1:18:01, 12.52s/it]

-200.0


 25%|██▌       | 127/500 [27:00<1:16:56, 12.38s/it]

-200.0


 26%|██▌       | 128/500 [27:12<1:16:07, 12.28s/it]

-200.0


 26%|██▌       | 129/500 [27:24<1:15:37, 12.23s/it]

-200.0


 26%|██▌       | 130/500 [27:36<1:15:00, 12.16s/it]

-200.0


 26%|██▌       | 131/500 [27:48<1:14:33, 12.12s/it]

-200.0


 26%|██▋       | 132/500 [28:01<1:15:20, 12.28s/it]

-200.0


 27%|██▋       | 133/500 [28:13<1:14:41, 12.21s/it]

-200.0


 27%|██▋       | 134/500 [28:25<1:14:17, 12.18s/it]

-200.0


 27%|██▋       | 135/500 [28:37<1:14:21, 12.22s/it]

-200.0


 27%|██▋       | 136/500 [28:50<1:15:26, 12.44s/it]

-200.0


 27%|██▋       | 137/500 [29:02<1:14:29, 12.31s/it]

-200.0


 28%|██▊       | 138/500 [29:46<2:10:48, 21.68s/it]

-200.0


 28%|██▊       | 139/500 [29:59<1:54:58, 19.11s/it]

-200.0


 28%|██▊       | 140/500 [30:13<1:45:22, 17.56s/it]

-200.0


 28%|██▊       | 141/500 [30:25<1:36:29, 16.13s/it]

-200.0


 28%|██▊       | 142/500 [30:38<1:29:13, 14.95s/it]

-200.0


 29%|██▊       | 143/500 [30:50<1:24:07, 14.14s/it]

-200.0


 29%|██▉       | 144/500 [31:02<1:20:45, 13.61s/it]

-200.0


 29%|██▉       | 145/500 [31:15<1:18:48, 13.32s/it]

-200.0


 29%|██▉       | 146/500 [31:27<1:16:45, 13.01s/it]

-200.0


 29%|██▉       | 147/500 [31:39<1:15:08, 12.77s/it]

-200.0


 30%|██▉       | 148/500 [31:52<1:13:55, 12.60s/it]

-200.0


 30%|██▉       | 149/500 [32:04<1:13:09, 12.51s/it]

-200.0


 30%|███       | 150/500 [32:16<1:12:34, 12.44s/it]

-200.0


 30%|███       | 151/500 [32:29<1:12:20, 12.44s/it]

-200.0


 30%|███       | 152/500 [32:41<1:11:51, 12.39s/it]

-200.0


 31%|███       | 153/500 [32:53<1:11:43, 12.40s/it]

-200.0


 31%|███       | 154/500 [33:06<1:12:17, 12.54s/it]

-200.0


 31%|███       | 155/500 [33:19<1:12:39, 12.64s/it]

-200.0


 31%|███       | 156/500 [33:32<1:12:14, 12.60s/it]

-200.0


 31%|███▏      | 157/500 [33:45<1:13:10, 12.80s/it]

-200.0


 32%|███▏      | 158/500 [33:57<1:12:38, 12.75s/it]

-200.0


 32%|███▏      | 159/500 [34:10<1:12:35, 12.77s/it]

-200.0


 32%|███▏      | 160/500 [34:24<1:14:10, 13.09s/it]

-200.0


 32%|███▏      | 161/500 [34:38<1:15:00, 13.27s/it]

-200.0


 32%|███▏      | 162/500 [34:52<1:15:40, 13.43s/it]

-200.0


 33%|███▎      | 163/500 [35:04<1:14:17, 13.23s/it]

-200.0


 33%|███▎      | 164/500 [35:17<1:13:16, 13.09s/it]

-200.0


 33%|███▎      | 165/500 [35:30<1:13:18, 13.13s/it]

-200.0


 33%|███▎      | 166/500 [35:43<1:12:22, 13.00s/it]

-200.0


 33%|███▎      | 167/500 [35:56<1:11:43, 12.92s/it]

-200.0


 34%|███▎      | 168/500 [38:15<4:40:43, 50.73s/it]

-200.0


 34%|███▍      | 169/500 [38:28<3:37:56, 39.51s/it]

-200.0


 34%|███▍      | 170/500 [38:41<2:53:58, 31.63s/it]

-200.0


 34%|███▍      | 171/500 [38:55<2:24:03, 26.27s/it]

-200.0


 34%|███▍      | 172/500 [39:08<2:02:16, 22.37s/it]

-200.0


 35%|███▍      | 173/500 [39:22<1:48:12, 19.86s/it]

-200.0


 35%|███▍      | 174/500 [39:35<1:36:44, 17.81s/it]

-200.0


 35%|███▌      | 175/500 [39:48<1:28:28, 16.33s/it]

-200.0


 35%|███▌      | 176/500 [40:01<1:22:40, 15.31s/it]

-200.0


 35%|███▌      | 177/500 [40:14<1:18:55, 14.66s/it]

-200.0


 36%|███▌      | 178/500 [40:28<1:16:35, 14.27s/it]

-200.0


 36%|███▌      | 179/500 [40:41<1:14:41, 13.96s/it]

-200.0


 36%|███▌      | 180/500 [40:54<1:13:16, 13.74s/it]

-200.0


 36%|███▌      | 181/500 [41:10<1:16:49, 14.45s/it]

-200.0


 36%|███▋      | 182/500 [41:23<1:14:08, 13.99s/it]

-200.0


 37%|███▋      | 183/500 [41:36<1:12:15, 13.68s/it]

-200.0


 37%|███▋      | 184/500 [41:49<1:11:00, 13.48s/it]

-200.0


 37%|███▋      | 185/500 [42:03<1:10:47, 13.48s/it]

-200.0


 37%|███▋      | 186/500 [43:13<2:40:09, 30.60s/it]

-200.0


 37%|███▋      | 187/500 [43:26<2:11:16, 25.16s/it]

-200.0


 38%|███▊      | 188/500 [43:38<1:50:48, 21.31s/it]

-200.0


 38%|███▊      | 189/500 [43:50<1:36:36, 18.64s/it]

-200.0


 38%|███▊      | 190/500 [44:03<1:27:22, 16.91s/it]

-200.0


 38%|███▊      | 191/500 [44:16<1:20:03, 15.54s/it]

-200.0


 38%|███▊      | 192/500 [44:28<1:15:11, 14.65s/it]

-200.0


 39%|███▊      | 193/500 [44:41<1:11:25, 13.96s/it]

-200.0


 39%|███▉      | 194/500 [44:53<1:09:13, 13.57s/it]

-200.0


 39%|███▉      | 195/500 [45:06<1:07:17, 13.24s/it]

-200.0


 39%|███▉      | 196/500 [45:19<1:06:32, 13.13s/it]

-200.0


 39%|███▉      | 197/500 [45:31<1:05:27, 12.96s/it]

-200.0


 40%|███▉      | 198/500 [45:44<1:04:43, 12.86s/it]

-200.0


 40%|███▉      | 199/500 [45:56<1:04:10, 12.79s/it]

-200.0


 40%|████      | 200/500 [46:10<1:04:30, 12.90s/it]

-200.0


 40%|████      | 201/500 [46:22<1:04:03, 12.86s/it]

-200.0


 40%|████      | 202/500 [46:35<1:03:34, 12.80s/it]

-200.0


 41%|████      | 203/500 [46:48<1:03:05, 12.74s/it]

-200.0


 41%|████      | 204/500 [54:37<12:18:23, 149.67s/it]

-200.0


 41%|████      | 205/500 [54:51<8:55:36, 108.94s/it] 

-200.0


 41%|████      | 206/500 [55:04<6:33:38, 80.34s/it] 

-200.0


 41%|████▏     | 207/500 [55:21<4:59:24, 61.31s/it]

-200.0


 42%|████▏     | 208/500 [55:34<3:48:17, 46.91s/it]

-200.0


 42%|████▏     | 209/500 [55:48<2:59:22, 36.99s/it]

-200.0


 42%|████▏     | 210/500 [56:02<2:25:02, 30.01s/it]

-200.0


 42%|████▏     | 211/500 [56:15<2:00:34, 25.03s/it]

-200.0


 42%|████▏     | 212/500 [56:30<1:45:11, 21.92s/it]

-200.0


 43%|████▎     | 213/500 [56:43<1:32:15, 19.29s/it]

-200.0


 43%|████▎     | 214/500 [56:56<1:23:16, 17.47s/it]

-200.0


 43%|████▎     | 215/500 [57:10<1:16:55, 16.19s/it]

-200.0


 43%|████▎     | 216/500 [57:23<1:13:07, 15.45s/it]

-200.0


 43%|████▎     | 217/500 [57:37<1:09:39, 14.77s/it]

-200.0


 44%|████▎     | 218/500 [57:51<1:08:32, 14.58s/it]

-200.0


 44%|████▍     | 219/500 [58:04<1:06:27, 14.19s/it]

-200.0


 44%|████▍     | 220/500 [58:17<1:04:36, 13.84s/it]

-200.0


 44%|████▍     | 221/500 [58:31<1:03:49, 13.73s/it]

-200.0


 44%|████▍     | 222/500 [58:44<1:02:51, 13.57s/it]

-200.0


 45%|████▍     | 223/500 [58:58<1:03:17, 13.71s/it]

-200.0


 45%|████▍     | 224/500 [59:11<1:01:55, 13.46s/it]

-200.0


 45%|████▌     | 225/500 [59:24<1:01:08, 13.34s/it]

-200.0


 45%|████▌     | 226/500 [59:37<1:00:30, 13.25s/it]

-200.0


 45%|████▌     | 227/500 [59:50<1:00:05, 13.21s/it]

-200.0


 46%|████▌     | 228/500 [1:00:03<1:00:28, 13.34s/it]

-200.0


 46%|████▌     | 229/500 [1:00:17<59:48, 13.24s/it]  

-200.0


 46%|████▌     | 230/500 [1:02:01<3:03:23, 40.75s/it]

-200.0


 46%|████▌     | 231/500 [1:02:14<2:24:11, 32.16s/it]

-200.0


 46%|████▋     | 232/500 [1:02:26<1:56:42, 26.13s/it]

-200.0


 47%|████▋     | 233/500 [1:02:38<1:37:25, 21.89s/it]

-200.0


 47%|████▋     | 234/500 [1:02:50<1:23:58, 18.94s/it]

-200.0


 47%|████▋     | 235/500 [1:03:02<1:14:33, 16.88s/it]

-200.0


 47%|████▋     | 236/500 [1:03:14<1:07:56, 15.44s/it]

-200.0


 47%|████▋     | 237/500 [1:03:26<1:03:22, 14.46s/it]

-200.0


 48%|████▊     | 238/500 [1:03:39<1:00:55, 13.95s/it]

-200.0


 48%|████▊     | 239/500 [1:03:51<58:33, 13.46s/it]  

-200.0


 48%|████▊     | 240/500 [1:04:03<56:42, 13.08s/it]

-200.0


 48%|████▊     | 241/500 [1:04:16<56:00, 12.97s/it]

-200.0


 48%|████▊     | 242/500 [1:04:28<55:06, 12.82s/it]

-200.0


 49%|████▊     | 243/500 [1:04:41<54:27, 12.71s/it]

-200.0


 49%|████▉     | 244/500 [1:04:54<54:37, 12.80s/it]

-200.0


 49%|████▉     | 245/500 [1:05:06<53:31, 12.59s/it]

-200.0


 49%|████▉     | 246/500 [1:05:18<52:44, 12.46s/it]

-200.0


 49%|████▉     | 247/500 [1:05:30<52:04, 12.35s/it]

-200.0


 50%|████▉     | 248/500 [1:22:42<22:16:52, 318.30s/it]

-200.0


 50%|████▉     | 249/500 [1:22:56<15:48:43, 226.79s/it]

-200.0


 50%|█████     | 250/500 [1:23:09<11:17:54, 162.70s/it]

-200.0


 50%|█████     | 251/500 [1:23:22<8:08:53, 117.81s/it] 

-200.0


 50%|█████     | 252/500 [1:23:37<5:59:38, 87.01s/it] 

-200.0


 51%|█████     | 253/500 [1:23:51<4:27:35, 65.00s/it]

-200.0


 51%|█████     | 254/500 [1:24:04<3:22:16, 49.33s/it]

-200.0


 51%|█████     | 255/500 [1:24:16<2:36:44, 38.39s/it]

-200.0


 51%|█████     | 256/500 [1:24:29<2:05:14, 30.80s/it]

-200.0


 51%|█████▏    | 257/500 [1:24:43<1:44:06, 25.71s/it]

-200.0


 52%|█████▏    | 258/500 [1:24:58<1:30:33, 22.45s/it]

-200.0


 52%|█████▏    | 259/500 [1:25:12<1:19:45, 19.86s/it]

-200.0


 52%|█████▏    | 260/500 [1:25:26<1:12:20, 18.08s/it]

-200.0


 52%|█████▏    | 261/500 [1:25:39<1:05:42, 16.50s/it]

-200.0


 52%|█████▏    | 262/500 [1:25:52<1:01:08, 15.42s/it]

-200.0


 53%|█████▎    | 263/500 [1:26:05<57:58, 14.68s/it]  

-200.0


 53%|█████▎    | 264/500 [1:26:18<56:21, 14.33s/it]

-200.0


 53%|█████▎    | 265/500 [1:26:31<54:18, 13.87s/it]

-200.0


 53%|█████▎    | 266/500 [1:26:44<52:45, 13.53s/it]

-200.0


 53%|█████▎    | 267/500 [1:26:56<51:38, 13.30s/it]

-200.0


 54%|█████▎    | 268/500 [1:27:09<51:10, 13.23s/it]

-200.0


 54%|█████▍    | 269/500 [1:27:23<50:49, 13.20s/it]

-200.0


 54%|█████▍    | 270/500 [1:27:36<50:40, 13.22s/it]

-200.0


 54%|█████▍    | 271/500 [1:27:49<50:02, 13.11s/it]

-200.0


 54%|█████▍    | 272/500 [1:28:02<49:53, 13.13s/it]

-200.0


 55%|█████▍    | 273/500 [1:28:15<49:48, 13.16s/it]

-200.0


 55%|█████▍    | 274/500 [1:28:29<50:03, 13.29s/it]

-200.0


 55%|█████▌    | 275/500 [1:28:41<49:06, 13.09s/it]

-200.0


 55%|█████▌    | 276/500 [1:28:54<48:42, 13.05s/it]

-200.0


 55%|█████▌    | 277/500 [1:29:09<50:29, 13.59s/it]

-200.0


 56%|█████▌    | 278/500 [1:29:23<50:36, 13.68s/it]

-200.0


 56%|█████▌    | 279/500 [1:29:37<50:30, 13.71s/it]

-200.0


 56%|█████▌    | 280/500 [1:29:51<50:41, 13.83s/it]

-200.0


 56%|█████▌    | 281/500 [1:30:05<50:44, 13.90s/it]

-200.0


 56%|█████▋    | 282/500 [1:30:20<51:22, 14.14s/it]

-200.0


 57%|█████▋    | 283/500 [1:30:33<50:41, 14.02s/it]

-200.0


 57%|█████▋    | 284/500 [1:30:47<49:59, 13.89s/it]

-200.0


 57%|█████▋    | 285/500 [1:31:00<48:29, 13.53s/it]

-200.0


 57%|█████▋    | 286/500 [1:31:13<47:33, 13.33s/it]

-200.0


 57%|█████▋    | 287/500 [1:31:26<47:09, 13.28s/it]

-200.0


 58%|█████▊    | 288/500 [1:31:39<47:22, 13.41s/it]

-200.0


 58%|█████▊    | 289/500 [1:31:52<46:11, 13.14s/it]

-200.0


 58%|█████▊    | 290/500 [1:32:05<45:30, 13.00s/it]

-200.0


 58%|█████▊    | 291/500 [1:32:17<45:05, 12.94s/it]

-200.0


 58%|█████▊    | 292/500 [1:32:30<45:01, 12.99s/it]

-200.0


 59%|█████▊    | 293/500 [1:32:43<44:37, 12.94s/it]

-200.0


 59%|█████▉    | 294/500 [1:32:56<44:07, 12.85s/it]

-200.0


 59%|█████▉    | 295/500 [1:33:09<44:16, 12.96s/it]

-200.0


 59%|█████▉    | 296/500 [1:33:22<44:26, 13.07s/it]

-200.0


 59%|█████▉    | 297/500 [1:33:35<43:57, 12.99s/it]

-200.0


 60%|█████▉    | 298/500 [1:33:48<43:29, 12.92s/it]

-200.0


 60%|█████▉    | 299/500 [1:34:02<44:07, 13.17s/it]

-200.0


 60%|██████    | 300/500 [1:34:15<43:28, 13.04s/it]

-200.0


 60%|██████    | 301/500 [1:34:27<43:02, 12.98s/it]

-200.0


 60%|██████    | 302/500 [1:34:40<42:41, 12.94s/it]

-200.0


 61%|██████    | 303/500 [2:15:01<40:14:15, 735.31s/it]

-200.0


 61%|██████    | 304/500 [2:15:15<28:15:12, 518.94s/it]

-200.0


 61%|██████    | 305/500 [2:15:29<19:53:53, 367.35s/it]

-200.0


 61%|██████    | 306/500 [2:15:44<14:06:23, 261.77s/it]

-200.0


 61%|██████▏   | 307/500 [2:15:58<10:02:24, 187.28s/it]

-200.0


 62%|██████▏   | 308/500 [2:16:11<7:12:16, 135.08s/it] 

-200.0


 62%|██████▏   | 309/500 [2:16:24<5:13:38, 98.52s/it] 

-200.0


 62%|██████▏   | 310/500 [2:16:37<3:50:52, 72.91s/it]

-200.0


 62%|██████▏   | 311/500 [2:16:51<2:53:25, 55.05s/it]

-200.0


 62%|██████▏   | 312/500 [2:17:04<2:13:09, 42.50s/it]

-200.0


 63%|██████▎   | 313/500 [2:17:18<1:45:57, 34.00s/it]

-200.0


 63%|██████▎   | 314/500 [2:17:31<1:26:06, 27.78s/it]

-200.0


 63%|██████▎   | 315/500 [2:17:44<1:12:07, 23.39s/it]

-200.0


 63%|██████▎   | 316/500 [2:17:58<1:02:14, 20.29s/it]

-200.0


 63%|██████▎   | 317/500 [2:18:13<57:12, 18.76s/it]  

-200.0


 64%|██████▎   | 318/500 [2:18:26<51:46, 17.07s/it]

-200.0


 64%|██████▍   | 319/500 [2:18:39<47:46, 15.83s/it]

-200.0


 64%|██████▍   | 320/500 [2:18:52<44:52, 14.96s/it]

-200.0


 64%|██████▍   | 321/500 [2:19:05<42:43, 14.32s/it]

-200.0


 64%|██████▍   | 322/500 [2:19:18<41:31, 14.00s/it]

-200.0


 65%|██████▍   | 323/500 [2:19:31<40:21, 13.68s/it]

-200.0


 65%|██████▍   | 324/500 [2:19:44<39:47, 13.57s/it]

-200.0


 65%|██████▌   | 325/500 [2:19:57<39:18, 13.48s/it]

-200.0


 65%|██████▌   | 326/500 [2:20:10<38:30, 13.28s/it]

-200.0


 65%|██████▌   | 327/500 [2:20:23<38:21, 13.30s/it]

-200.0


 66%|██████▌   | 328/500 [2:20:37<38:13, 13.33s/it]

-200.0


 66%|██████▌   | 329/500 [2:20:50<37:54, 13.30s/it]

-200.0


 66%|██████▌   | 330/500 [2:21:04<37:49, 13.35s/it]

-200.0


 66%|██████▌   | 331/500 [2:21:18<38:14, 13.58s/it]

-200.0


 66%|██████▋   | 332/500 [2:21:38<43:18, 15.47s/it]

-200.0


 67%|██████▋   | 333/500 [2:21:56<45:37, 16.39s/it]

-200.0


 67%|██████▋   | 334/500 [2:22:11<44:16, 16.00s/it]

-200.0


 67%|██████▋   | 335/500 [2:22:26<42:45, 15.55s/it]

-200.0


 67%|██████▋   | 336/500 [2:22:40<41:25, 15.15s/it]

-200.0


 67%|██████▋   | 337/500 [2:22:53<39:18, 14.47s/it]

-200.0


 68%|██████▊   | 338/500 [2:23:05<37:18, 13.82s/it]

-200.0


 68%|██████▊   | 339/500 [2:23:17<35:54, 13.38s/it]

-200.0


 68%|██████▊   | 340/500 [2:23:30<34:51, 13.07s/it]

-200.0


 68%|██████▊   | 341/500 [2:23:42<34:04, 12.86s/it]

-200.0


 68%|██████▊   | 342/500 [2:23:54<33:23, 12.68s/it]

-200.0


 69%|██████▊   | 343/500 [2:24:07<32:54, 12.58s/it]

-200.0


 69%|██████▉   | 344/500 [2:24:19<32:27, 12.48s/it]

-200.0


 69%|██████▉   | 345/500 [2:24:32<32:26, 12.56s/it]

-200.0


 69%|██████▉   | 346/500 [2:24:44<32:04, 12.50s/it]

-200.0


 69%|██████▉   | 347/500 [2:24:56<31:42, 12.44s/it]

-200.0


 70%|██████▉   | 348/500 [2:25:09<31:24, 12.40s/it]

-200.0


 70%|██████▉   | 349/500 [2:25:21<31:06, 12.36s/it]

-200.0


 70%|███████   | 350/500 [2:25:33<30:54, 12.36s/it]

-200.0


 70%|███████   | 351/500 [2:25:46<30:35, 12.32s/it]

-200.0


 70%|███████   | 352/500 [2:25:58<30:23, 12.32s/it]

-200.0


 71%|███████   | 353/500 [2:26:10<30:09, 12.31s/it]

-200.0


 71%|███████   | 354/500 [2:26:23<29:59, 12.33s/it]

-200.0


 71%|███████   | 355/500 [2:26:35<29:45, 12.31s/it]

-200.0


 71%|███████   | 356/500 [2:26:47<29:35, 12.33s/it]

-200.0


In [None]:
env.close()

In [None]:
#lets test the nn
for i in tqdm(range(500)):
    observation = env.reset()
    observation = np.expand_dims(observation, axis=0)
    total_reward = 0
    for j in range(1000):
        env.render()
        nn_out = model_2.predict(observation)
        action = policy(nn_out,0)
        print(nn_out[0])
        print(action)
        observation,reward,done,info = env.step(action)
        observation = np.expand_dims(observation, axis=0)
        total_reward = total_reward + reward
        if done:
            print("episode ended")
            break
env.close()

In [None]:
env.close()

In [None]:
# random action
for i in tqdm(range(50)):
    observation = env.reset()
    observation = np.expand_dims(observation, axis=0)
    for j in range(1000):
        env.render()
        observation,reward,done,info = env.step(env.action_space.sample())
        print(reward,done)
        if done:
            break