In [1]:
import gym
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from collections import deque

In [2]:
# keras model approach
from tensorflow.keras import Model,Sequential
from tensorflow.keras.layers import Conv2D,MaxPooling2D,Flatten,BatchNormalization,Dense, Input
from tensorflow.keras.activations import relu

In [3]:
# ADDING THE CODE SO THAT TENSORFLOW DOES NOT EAT THE WHOLE GPU MEMORY
physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
tf.keras.backend.set_floatx('float32')

In [4]:
env = gym.make('CartPole-v1')
env.action_space

Discrete(2)

In [5]:
adam = tf.keras.optimizers.Adam(learning_rate = 0.001)

In [6]:
def model_keras():
    
    inputs = Input(shape=(4,))
    
    x = Dense(100,activation='relu',kernel_initializer="glorot_uniform")(inputs)
    x = BatchNormalization()(x)
    x = Dense(100,activation='relu',kernel_initializer="glorot_uniform")(x)
    x = BatchNormalization()(x)
    x = Dense(100,activation='relu',kernel_initializer="glorot_uniform")(x)
    x = BatchNormalization()(x)
    x = Dense(100,activation='relu',kernel_initializer="glorot_uniform")(x)
    x = BatchNormalization()(x)
    output = Dense(2,activation='linear',kernel_initializer="glorot_uniform")(x)
    model = Model(inputs=inputs, outputs=output, name="RL_Value_Function")
    
    print(model.summary())
    
    #model.compile(optimizer=adam,loss='mean_squared_error',metrics=['mean_squared_error'])
    
    return model

def model_keras():
    model = Sequential()
    model.add(Dense(30, input_dim=4, activation='relu'))
    model.add(Dense(30, activation='relu'))
    model.add(Dense(2, activation='linear'))
    model.summary()
    model.compile(loss='mse', optimizer=adam)
    return model
sample_model = model_keras()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 30)                150       
_________________________________________________________________
dense_1 (Dense)              (None, 30)                930       
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 62        
Total params: 1,142
Trainable params: 1,142
Non-trainable params: 0
_________________________________________________________________


In [7]:
def create_y_true(current_q,current_r,lamb):
    current_r = np.array([current_r,]*2)
    current_r = current_r.reshape(32,2)
    return tf.math.multiply(current_q,lamb) + current_r

def custom_loss(y_true,y_pred):
    return tf.keras.losses.mean_squared_error(y_true,y_pred)

In [8]:
z = np.array([[1,2],[4,2],[5,1]])
z[2][1]

1

In [9]:

warmup = 10 #will start training after these many episodes have passed
training_count = 0 # this is a counter which is updated each time batch training is done
# after certain number of batches we remove the old data (in the starting of the list)

# to balance exploration
epsilon = {
"epsilon" : 1.0,
"epsilon_decay": 0.999,
"epsilon_min":0.01,
}

In [10]:
import random
random.seed(2020)


#@tf.function
def batch_train(model,gamma,SARSA):
    
    #decaying the exploration
    if epsilon['epsilon'] > epsilon['epsilon_min']:
         epsilon['epsilon'] =  epsilon['epsilon'] * epsilon['epsilon_decay']
    
    curr_state = SARSA[0]
    action = SARSA[1]
    reward = SARSA[2]
    next_state = SARSA[3]
    next_action = SARSA[4]
    q_next = model.predict(next_state)[0]
    target = reward + q_next[next_action]*gamma

    
    done = SARSA[5]
    if done:
        target = reward
    
    with tf.GradientTape() as tape:
        # logits is the forward pass
        logits = model(curr_state, training=True)
        
        q_target = np.array(logits)
        q_target[0][action] = target
        
        loss_value = custom_loss(q_target,logits)

    #we retrieve the gradients
    grads = tape.gradient(loss_value, model.trainable_weights)
    
    #THIS IS ONE STEP OF GRAD DESCENT (Minimizes the loss)
    adam.apply_gradients(zip(grads, model.trainable_weights))

def policy(q_vals,eps):
    # lets implement a policy which decays
    if np.random.rand() <= eps:  
        return random.randrange(2)
    else:
        action = np.argmax(q_vals[0])
        return action

In [19]:
for i in tqdm(range(10)):
    observation = env.reset()
    observation = np.expand_dims(observation, axis=0)
    total_reward = 0
    for j in range(400):
        #storing the current state
        state_1 = observation
        
        # this is the current q values
        q_state = sample_model(observation,training = False)
        
        action = policy(q_state,epsilon["epsilon"]) # current action
        
        observation,reward,done,info = env.step(action)
        
        # calculating the total reward
        total_reward = total_reward + reward
        
        
        #if done and j<195:
        #    reward = -1000
        observation = np.expand_dims(observation, axis=0)
        state_2 = observation
        state_reward = reward
        
        action_2 = policy(sample_model(observation,training = False),epsilon["epsilon"])
        
        SARSA = (state_1,action,state_reward,state_2,action_2,done)

        batch_train(sample_model,0.99,SARSA)
        
        
        if done:
            break
        if i >1150:
            env.render()
    print(total_reward)



  0%|          | 0/10 [00:00<?, ?it/s][A[A

 10%|█         | 1/10 [00:04<00:38,  4.24s/it][A[A

139.0




 20%|██        | 2/10 [00:08<00:34,  4.26s/it][A[A

141.0




 30%|███       | 3/10 [00:12<00:29,  4.25s/it][A[A

136.0




 40%|████      | 4/10 [00:16<00:25,  4.24s/it][A[A

133.0




 50%|█████     | 5/10 [00:21<00:21,  4.30s/it][A[A

145.0




 60%|██████    | 6/10 [00:26<00:17,  4.49s/it][A[A

129.0




 70%|███████   | 7/10 [00:30<00:13,  4.42s/it][A[A

130.0




 80%|████████  | 8/10 [00:34<00:08,  4.34s/it][A[A

136.0




 90%|█████████ | 9/10 [00:38<00:04,  4.16s/it][A[A

123.0




100%|██████████| 10/10 [00:42<00:00,  4.25s/it][A[A

133.0





In [14]:
env.close()

In [20]:
#lets test the nn
for i in tqdm(range(500)):
    observation = env.reset()
    observation = np.expand_dims(observation, axis=0)
    total_reward = 0
    done =False
    while not done:
        env.render()
        nn_out = sample_model.predict(observation)
        print(nn_out)
        action = policy(nn_out,0)
        print(action)
        observation,reward,done,info = env.step(action)
        observation = np.expand_dims(observation, axis=0)
        total_reward = total_reward + reward
        
env.close()



  0%|          | 0/500 [00:00<?, ?it/s][A[A

[[68.29171 71.2475 ]]
1
[[71.219475 72.037384]]
1
[[72.028305 69.42623 ]]
0
[[70.999466 71.502426]]
1
[[71.70239 68.75582]]
0
[[70.776886 70.807625]]
1
[[71.32012 67.89142]]
0
[[70.54251  69.934616]]
0
[[68.367096 70.20744 ]]
1
[[70.453766 69.29199 ]]
0
[[68.541794 70.018974]]
1
[[70.338356 68.5389  ]]
0
[[68.731476 69.79105 ]]
1
[[70.016106 67.60812 ]]
0
[[68.75289 69.24564]]
1
[[69.33293 66.39349]]
0
[[68.5649  68.13679]]
0
[[66.22346 67.96844]]
1
[[68.31078  67.238365]]
0
[[66.52211  67.776726]]
1
[[67.98934 66.23084]]
0
[[66.76691  67.522736]]
1
[[67.41022  64.990776]]
0
[[66.42806 66.56243]]
1
[[66.61086 63.51705]]
0
[[65.90555 65.17064]]
0
[[64.23615 65.55938]]
1
[[65.561874 64.023575]]
0
[[64.32151 65.25816]]
1
[[64.88499  62.616646]]
0
[[63.94938 64.15396]]
1
[[64.013306 61.00159 ]]
0
[[63.38495 62.64316]]
0
[[61.81084  63.228535]]
1
[[62.927784 61.345356]]
0
[[61.82258 62.79212]]
1
[[62.09292 59.77046]]
0
[[61.295803 61.339687]]
1
[[61.141605 58.035824]]
0
[[60.67048  59.72237



  0%|          | 1/500 [00:05<48:51,  5.87s/it][A[A


1
[[-0.6132529 -2.2028565]]
0
[[67.52624 70.01689]]
1
[[70.0721   69.979675]]
0
[[67.622375 69.88743 ]]
1
[[69.99712 69.50328]]
0
[[67.731125 69.73344 ]]
1
[[69.8989   68.938385]]
0
[[67.85418 69.55068]]
1
[[69.7755   68.271965]]
0
[[67.993576 69.334305]]
1
[[69.62653 67.4901 ]]
0
[[68.102   69.01302]]
1
[[69.051994 66.437416]]
0
[[67.94651 68.11369]]
1
[[68.33197 65.18113]]
0
[[67.640335 66.94645 ]]
0
[[65.57897 67.18477]]
1
[[67.34917 65.99751]]
0
[[65.856636 66.962776]]
1
[[66.94524 64.90834]]
0
[[65.841965 66.44848 ]]
1
[[66.20616 63.53735]]
0
[[65.36281 65.15886]]
0
[[63.450275 65.10783 ]]
1
[[65.070015 64.13484 ]]
0
[[63.54347 64.85349]]
1
[[64.6897   63.004646]]
0
[[63.54637 64.48137]]
1
[[63.93001  61.587944]]
0
[[63.065086 63.184254]]
1
[[63.058704 60.008038]]
0
[[62.48743  61.701294]]
0
[[60.990887 62.476173]]
1
[[61.982197 60.41382 ]]
0
[[60.913315 61.90226 ]]
1
[[61.157368 58.88318 ]]
0
[[60.383137 60.4883  ]]
1
[[60.220745 57.19866 ]]
0
[[59.758602 58.91572 ]]
0
[[58.4304



  0%|          | 2/500 [00:10<46:02,  5.55s/it][A[A

[[ 0.9951669  -0.31081414]]
0
[[-0.10017887 -0.06711507]]
1
[[-0.26048377 -2.1429253 ]]
0
[[67.38099 69.69462]]
1
[[69.6271 69.1015]]
0
[[67.43058 69.52997]]
1
[[69.50809 68.59569]]
0
[[67.50208 69.34173]]
1
[[69.3727   67.993004]]
0
[[67.596115 69.12482 ]]
1
[[69.21681 67.2788 ]]
0
[[67.67967  68.828354]]
1
[[68.853226 66.35148 ]]
0
[[67.522804 68.00906 ]]
1
[[68.179276 65.19217 ]]
0
[[67.309586 66.94098 ]]
0
[[65.05331 67.0125 ]]
1
[[67.16101  66.109344]]
0
[[65.30002 66.81299]]
1
[[66.83705 65.15436]]
0
[[65.536   66.56247]]
1
[[66.2657   63.976517]]
0
[[65.28276  65.581665]]
1
[[65.497604 62.583412]]
0
[[64.771904 64.26355 ]]
0
[[63.059864 64.64059 ]]
1
[[64.44666  63.202034]]
0
[[63.146786 64.361015]]
1
[[63.89057  61.946674]]
0
[[62.876682 63.492805]]
1
[[63.0847   60.465412]]
0
[[62.350758 62.108593]]
0
[[60.605206 62.408623]]
1
[[62.007    60.994205]]
0
[[60.670795 62.116596]]
1
[[61.396675 59.682503]]
0
[[60.406906 61.231968]]
1
[[60.56935  58.178677]]
0
[[59.862373 59.835743]

KeyboardInterrupt: 

In [17]:
env.close()

In [None]:
# random action
for i in tqdm(range(50)):
    observation = env.reset()
    observation = np.expand_dims(observation, axis=0)
    for j in range(1000):
        env.render()
        observation,reward,done,info = env.step(env.action_space.sample())
        print(reward,done)
        if done:
            break