In [1]:
import tensorflow as tf
import tensorflow.keras.layers as layers
import gym
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [2]:
def buildModel(state_dim=4, n_actions=2, hidden_units=32):
    """Build the Q-network: an MLP mapping one state vector to one Q-value per action.

    Args:
        state_dim: length of the flat observation vector (CartPole's observation has 4 entries).
        n_actions: number of discrete actions, i.e. the width of the output layer
            (CartPole has 2).
        hidden_units: width of each of the two ReLU hidden layers.

    Returns:
        An uncompiled ``tf.keras.Model`` that takes inputs of shape ``(batch, state_dim)``
        and returns Q-values of shape ``(batch, n_actions)``.
    """
    # The training loop feeds ``state[None, :]``, i.e. shape ``(1, state_dim)`` =
    # (batch, features), so the per-sample shape is ``(state_dim,)``. The former
    # ``shape=(1, 4)`` declared an extra axis that callers never provide, which Keras
    # reports as an incompatible input shape on every call.
    inputs = tf.keras.Input(shape=(state_dim,))
    x = layers.Dense(hidden_units, activation='relu')(inputs)
    x = layers.Dense(hidden_units, activation='relu')(x)
    # Linear output layer: Q-values are unbounded regression targets.
    q_values = layers.Dense(n_actions)(x)

    return tf.keras.Model(inputs=inputs, outputs=q_values)

In [3]:
# Objective and optimizer: squared error between predicted and target Q-values,
# minimized with Adam at a learning rate of 1e-2.
loss, opt = tf.keras.losses.MeanSquaredError(), tf.keras.optimizers.Adam(learning_rate=1e-2)

# Q-network built with buildModel's defaults.
model = buildModel()

# The environment the agent is trained on.
env = gym.make('CartPole-v1')

In [4]:
# gamma: discount factor applied to the bootstrapped next-state value.
# e: starting epsilon for epsilon-greedy exploration; the training loop rebinds
#    this global as it decays it, so re-running that loop continues from the decayed value.
gamma, e = 0.9, 0.9

In [7]:
# Online one-step Q-learning with epsilon-greedy exploration: one gradient update per
# environment step, using the same network for the prediction and the bootstrap target.
# NOTE: this cell updates the globals `model` and `e` in place, so re-running it resumes
# training from the current weights and the already-decayed epsilon.
for epoch in range(5000):

    # gym >= 0.26 returns (obs, info) from reset(); older gym returns obs alone.
    reset_out = env.reset()
    state = np.asarray(reset_out[0] if isinstance(reset_out, tuple) else reset_out)
    done = False
    totalReward = 0

    while not done:
        with tf.GradientTape() as tape:
            q = model(state[None, :])  # batch of one -> shape (1, n_actions)
            if np.random.rand() < e:
                a = np.random.randint(low=0, high=env.action_space.n)
            else:
                a = int(tf.argmax(q[0]).numpy())

            step_out = env.step(a)
            if len(step_out) == 5:
                # gym >= 0.26: termination and time-limit truncation are reported separately.
                next_state, reward, terminated, truncated, info = step_out
            else:
                # Older gym folds the TimeLimit cutoff into `done` and flags it in `info`.
                next_state, reward, done_flag, info = step_out
                truncated = bool(info.get('TimeLimit.truncated', False))
                terminated = done_flag and not truncated
            done = terminated or truncated
            next_state = np.asarray(next_state)

            # The bootstrap value is a fixed target, so keep this forward pass off the tape.
            with tape.stop_recording():
                q_next = model(next_state[None, :])
            # Only a true terminal state has zero future value. A time-limit truncation
            # still bootstraps from the next state.
            q_true = reward + gamma * (1. - float(terminated)) * tf.math.reduce_max(q_next, axis=1)
            q_pred = q[:, a]  # shape (1,): Q-value of the action actually taken
            loss_val = loss(tf.stop_gradient(q_true), q_pred)

        # Differentiate and update outside the tape context so these ops are not recorded.
        grads = tape.gradient(loss_val, model.trainable_weights)
        opt.apply_gradients(zip(grads, model.trainable_weights))

        state = next_state
        e = max(e * 0.999, 0.05)  # multiplicative per-step decay with a floor of 0.05
        totalReward += reward

    print(epoch, "   ", totalReward)

env.close()

85.0
1921     90.0
1922     104.0
1923     107.0
1924     115.0
1925     118.0
1926     108.0
1927     107.0
1928     107.0
1929     106.0
1930     107.0
1931     114.0
1932     105.0
1933     108.0
1934     114.0
1935     94.0
1936     100.0
1937     11.0
1938     78.0
1939     102.0
1940     100.0
1941     103.0
1942     107.0
1943     103.0
1944     110.0
1945     104.0
1946     84.0
1947     101.0
1948     9.0
1949     10.0
1950     11.0
1951     10.0
1952     37.0
1953     116.0
1954     17.0
1955     11.0
1956     16.0
1957     17.0
1958     101.0
1959     105.0
1960     144.0
1961     102.0
1962     107.0
1963     108.0
1964     104.0
1965     106.0
1966     15.0
1967     9.0
1968     10.0
1969     131.0
1970     10.0
1971     10.0
1972     73.0
1973     99.0
1974     102.0
1975     64.0
1976     89.0
1977     70.0
1978     99.0
1979     79.0
1980     96.0
1981     55.0
1982     93.0
1983     48.0
1984     63.0
1985     92.0
1986     65.0
1987     67.0
1988     103.0
1989     13