In [63]:
import random
import gym
import numpy as np
from collections import deque
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

In [64]:
EPOCHS = 500
THRESHOLD = 1000

In [65]:
class DQN():
    def __init__(self, env_string,batch_size=64):
        self.memory = deque(maxlen=100000)
        self.env = gym.make(env_string)
        input_size = self.env.observation_space.shape[0]
        self.action = np.array([-2,0,2], dtype=np.float32)
        action_size = len(self.action)
        self.batch_size = batch_size
        self.gamma = 1.0
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        
        alpha=0.01
        alpha_decay=0.01
        
        # Init model
        self.model = Sequential()
        self.model.add(Dense(24, input_dim=input_size, activation='tanh'))
        self.model.add(Dense(48, activation='tanh'))
        self.model.add(Dense(action_size, activation='linear'))
        self.model.compile(loss='mse', optimizer=Adam(learning_rate=alpha))

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def choose_action(self, state, epsilon):
        if np.random.random() <= epsilon:
            return [np.random.choice(self.action)]
        else:
            return [self.action[np.argmax(self.model.predict(state, verbose=0)[0])]]

    def preprocess_state(self, state):
        return np.reshape(state, [1, 3])

    def replay(self, batch_size):
        x_batch, y_batch = [], []
        minibatch = random.sample(self.memory, min(len(self.memory), batch_size))
        for state, action, reward, next_state, done in minibatch:
            y_target = self.model.predict(state, verbose=0)
            action_index = np.where(self.action == action[0])[0][0]
            y_target[0][action_index] = reward if done else reward + self.gamma * np.max(self.model.predict(next_state, verbose=0)[0])
            x_batch.append(state[0])
            y_batch.append(y_target[0])
        
        self.model.fit(np.array(x_batch), np.array(y_batch), batch_size=len(x_batch), verbose=0)
        #epsilon = max(epsilon_min, epsilon_decay*epsilon) # decrease epsilon
       

    def train(self):
        scores = deque(maxlen=100)
        avg_scores = []
        

        for e in range(EPOCHS):
            state = self.env.reset()
            if e % 10 == 0:
                self.env.render()
            state = self.preprocess_state(state)
            done = False
            i = 0
            while not done:
                action = self.choose_action(state,self.epsilon)
                next_state, reward, done, _ = self.env.step(action)
                if e % 10 == 0:
                    self.env.render()
                next_state = self.preprocess_state(next_state)
                self.remember(state, action, reward, next_state, done)
                state = next_state
                self.epsilon = max(self.epsilon_min, self.epsilon_decay*self.epsilon) # decrease epsilon
                i += 1
                
            if e % 10 == 0:
                self.env.close()

            scores.append(i)
            mean_score = np.mean(scores)
            avg_scores.append(mean_score)
            print(f'Epoch number {e}')
            if mean_score >= THRESHOLD and e >= 10:
                print('Ran {} episodes. Solved after {} trials ✔'.format(e, e - 10))
                return avg_scores
            if (e + 1) % 10 == 0:
                print('[Episode {}] - Mean survival time over last 10 episodes was {} ticks.'.format(e, mean_score))

            self.replay(self.batch_size)
        
        print('Did not solve after {} episodes 😞'.format(e))
        return avg_scores

In [66]:
env_string = 'Pendulum-v0'
dqn = DQN(env_string)
avg_scores = dqn.train()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch number 0
Epoch number 1
Epoch number 2
Epoch number 3
Epoch number 4
Epoch number 5
Epoch number 6
Epoch number 7
Epoch number 8
Epoch number 9
[Episode 9] - Mean survival time over last 10 episodes was 200.0 ticks.
Epoch number 10
Epoch number 11


KeyboardInterrupt: 