In [5]:
import tensorflow as tf
import numpy as np
import random

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

from collections import deque

In [6]:
class DQN:
  """Pair of identically-shaped Keras Q-networks: a trained "main" model and a
  "target" model that is only changed by explicit `update_target` calls.

  Args:
    state_dim: length of a single flat state vector fed to the networks.
    action_dim: number of discrete actions; each network outputs one value
      per action.
  """

  def __init__(self, state_dim, action_dim):
    self.state_dim = state_dim
    self.action_dim = action_dim

    self.main_model = self.create_model()
    self.target_model = self.create_model()
    # Start both networks from identical weights.
    self.target_model.set_weights(self.main_model.get_weights())

    # Counter advanced by the caller; reset to 0 every time the target is synced.
    self.target_main_delta = 0

  def create_model(self):
    """Build and compile a small fully-connected Q-network.

    Returns:
      A compiled Sequential model: state_dim inputs -> 16 relu -> 16 relu ->
      action_dim linear outputs.
    """
    model = Sequential()

    model.add(Dense(16, input_dim=self.state_dim, activation="relu"))
    model.add(Dense(16, activation="relu"))
    # Linear head: Q-values are unbounded regression targets.
    model.add(Dense(self.action_dim, activation="linear"))

    # "mean_squared_error" is the registered Keras loss name; the previous
    # "mean_squared_loss" is unknown to Keras and raises a ValueError once the
    # loss string is resolved.
    model.compile(optimizer="adam", loss="mean_squared_error")

    return model

  def query_main(self, states):
    """Return the main network's Q-value predictions for a batch of states.

    Args:
      states: sequence of state vectors; converted with np.array, so it must
        already carry a leading batch dimension.

    Returns:
      The array produced by `predict`, one row of action values per state.
    """
    # verbose=0: this is called every step, so suppress the progress bar.
    return self.main_model.predict(np.array(states), verbose=0)

  def query_target(self, states):
    """Return the target network's Q-value predictions for a batch of states.

    Args:
      states: sequence of state vectors with a leading batch dimension.

    Returns:
      The array produced by `predict`, one row of action values per state.
    """
    return self.target_model.predict(np.array(states), verbose=0)

  def update_target(self):
    """Copy the main network's weights into the target network and reset the counter."""
    self.target_model.set_weights(self.main_model.get_weights())
    self.target_main_delta = 0

  def fit_main(self, X, y, batch_size):
    """Run one silent, unshuffled `fit` call of the main network on (X, y).

    Args:
      X: batch of input states.
      y: matching batch of regression targets (one value per action).
      batch_size: mini-batch size passed to Keras `fit`.
    """
    self.main_model.fit(np.array(X), np.array(y), batch_size=batch_size, verbose=0, shuffle=False)

In [10]:
class Memory:
  """Fixed-capacity replay buffer of (state, action, reward, next_state, done) tuples.

  Backed by a `deque` with `maxlen`, so once full the oldest transition is
  dropped for every new one appended.

  Args:
    size: maximum number of transitions retained.
  """

  def __init__(self, size):
    self.size = size
    self.replay_buffer = deque(maxlen=self.size)

  def __len__(self):
    """Number of transitions currently stored; enables `len(memory)`."""
    return len(self.replay_buffer)

  def len(self):
    """Number of transitions currently stored (kept for existing callers)."""
    return len(self.replay_buffer)

  def add(self, state, action, reward, next_state, done):
    """Append one transition, evicting the oldest if the buffer is full."""
    self.replay_buffer.append((state, action, reward, next_state, done))

  def sample(self, batch_size):
    """Return `batch_size` distinct transitions chosen uniformly at random.

    Raises:
      ValueError: if `batch_size` exceeds the number of stored transitions
        (raised by `random.sample`).
    """
    return random.sample(self.replay_buffer, batch_size)

In [9]:
class Agent:
  """Epsilon-greedy DQN agent with experience replay and a target network.

  Args:
    state_dim: length of a single state vector.
    action_dim: number of discrete actions.
    memory_size: capacity of the replay buffer.
    epsilon: probability of choosing a uniformly random action.
    batch_size: number of transitions sampled per training step.
    discount: discount factor applied to the bootstrapped next-state value.
    target_update: number of completed episodes between target-network syncs.
  """

  def __init__(self, state_dim, action_dim, memory_size, epsilon, batch_size, discount, target_update):
    self.state_dim = state_dim
    self.action_dim = action_dim
    self.memory_size = memory_size
    self.epsilon = epsilon
    self.batch_size = batch_size
    self.discount = discount
    self.target_update = target_update

    self.dqn = DQN(self.state_dim, self.action_dim)
    self.memory = Memory(self.memory_size)

  def choose_action(self, state):
    """Pick an action index for a single state with an epsilon-greedy policy."""
    if np.random.random() > self.epsilon:
      # The network expects a batch: wrap the single state, then take row 0.
      q_values = self.dqn.query_main([state])[0]
      return np.argmax(q_values)
    return np.random.randint(0, self.action_dim)

  def train(self, episode_end):
    """Do one replay training step and sync the target network when due.

    Does nothing until the replay buffer holds at least `batch_size` transitions.

    Args:
      episode_end: True when the latest transition finished an episode; used
        to count episodes for the target-update schedule.
    """
    if self.memory.len() < self.batch_size:
      return

    batch = self.memory.sample(self.batch_size)

    states = [transition[0] for transition in batch]
    next_states = [transition[3] for transition in batch]

    # Start from the main network's own predictions so only the taken
    # action's entry gets a new target; the other entries contribute zero error.
    y = self.dqn.query_main(states)
    target_qs = self.dqn.query_target(next_states)

    for index, (_, action, reward, _, done) in enumerate(batch):
      if done:
        new_q = reward
      else:
        new_q = reward + self.discount * np.max(target_qs[index])
      y[index][action] = new_q

    self.dqn.fit_main(states, y, self.batch_size)

    # Only consider syncing when an episode finishes. Testing the counter on
    # every call would re-sync on every step while it sits at 0 (at start and
    # right after each sync, since update_target resets it).
    if episode_end:
      self.dqn.target_main_delta += 1
      if self.dqn.target_main_delta >= self.target_update:
        self.dqn.update_target()