In [1]:
#!pip install gymnasium

In [2]:
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import random
from tqdm import tqdm

import gymnasium as gym

In [9]:
# Agent
from keras.models import Model, Sequential
from keras.layers import Dense, Input
from keras.optimizers import Adam
from collections import deque

class DDQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

        self.tau = 0.01
        self.batch_size = 64
        self.memory = deque(maxlen=10000)
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.model = self.build_model()
        self.target_model = self.build_model()
        self.double_dqn = True

    def build_model(self):
        model = Sequential()
        model.add(Dense(24, input_dim = self.state_size, activation='relu'))
        model.add(Dense(24, activation='relu'))
        # model.add(Dense(2, activation='linear'))
        
        # model.add(Dense(128, input_dim = self.state_size, activation='relu'))
        # model.add(Dense(64, activation='relu'))
        # model.add(Dense(32, activation='relu'))
        model.add(Dense(self.action_size, activation='relu'))
        model.compile(loss='mse',optimizer=Adam(lr=self.learning_rate))
        return model

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if self.epsilon > self.epison_min:
            self.epsilon *= self.epsilon_decay
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(state, verbose=0)
        return np.argmax(act_values[0])

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def replay(self):
        minibatch = self.sample(self.batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = self.model.predict(state, verbose=0)
            if not done:
                if self.double_dqn:
                    predicted_action = np.argmax(self.model.predict(next_state,verbose=0)[0])
                    target_q = self.target_model.predict(next_state, verbose=0)[0][predicted_action]
                    target[0][action] = reward + self.gamma * target_q
                else:
                    target_q = self.target_model.predict(next_state)[0]
                    target[0][action] = reward + self.gamma * max(target_q)
            else:
                target[0][action] = reward
            self.model.fit(state,target, epochs=1, verbose=0)
        # if self.epsilon > self.epsilon_min:
        #     self.epsilon *= self.epsilon_decay

    def update_target_model(self):
        #TOODO: Try training without a target network
        weights = self.model.get_weights()
        target_weights = self.target_model.get_weights()
        for i in range(len(target_weights)):
            target_weights[i] = weights[i] * self.tau + target_weights[i] * (1- self.tau)
        self.target_model.set_weights(target_weights)

    def save(self, file):
        self.model.save_weights(file)

    def load(self, file):
        self.model.load_weights(file)
        self.target_model.load_weights(file)
            

In [10]:
# Train
env = gym.make("CartPole-v1")
state_size_, action_size_ = 4, 2
dqn_agent = DDQNAgent(state_size_, action_size_)
try:
    dqn_agent.load('cartpole.h5')
except:
    print("No previous model detected")
n_episodes = 600
n_steps = 200
capacity = 10000

streak_step = 0
total_step = 0

save_episodes = 10
for episode in tqdm(range(n_episodes)):
    cur_state_, _ = env.reset()
    for step in range(n_steps):
        streak_step += 1
        total_step += 1
        cur_state_ = np.reshape(cur_state_, [1, state_size_])
        action = dqn_agent.act(cur_state_)
        observation, reward, done, _, _ = env.step(action)
        observation = np.reshape(observation, [1, state_size_])
        if done:
            reward = -5
        dqn_agent.remember(cur_state_, action, reward, observation, done)
        if total_step % dqn_agent.batch_size == 0:
            dqn_agent.replay()
            dqn_agent.update_target_model()
        cur_state_ = observation
        if done:
            break
    if episode % save_episodes == 0:
        dqn_agent.save("cartpole.h5")
        print(f"Avg game: {streak_step / save_episodes}")
        streak_step = 0

  0%|                                                      | 1/600 [00:23<3:57:43, 23.81s/it]

Avg game: 7.9


  2%|▉                                                    | 11/600 [01:51<1:47:26, 10.94s/it]

Avg game: 24.7


  4%|█▊                                                   | 21/600 [02:59<1:00:32,  6.27s/it]

Avg game: 24.2


  5%|██▋                                                  | 31/600 [04:31<1:09:45,  7.36s/it]

Avg game: 25.7


  7%|███▌                                                 | 41/600 [06:00<1:18:18,  8.41s/it]

Avg game: 19.9


  8%|████▋                                                  | 51/600 [07:34<56:10,  6.14s/it]

Avg game: 29.4


 10%|█████▍                                               | 62/600 [09:30<1:21:25,  9.08s/it]

Avg game: 28.5


 12%|██████▎                                              | 72/600 [11:24<1:24:58,  9.66s/it]

Avg game: 32.0


 14%|███████▏                                             | 81/600 [12:56<1:16:47,  8.88s/it]

Avg game: 29.9


 15%|████████                                             | 91/600 [14:47<1:22:09,  9.68s/it]

Avg game: 29.7


 17%|████████▊                                           | 101/600 [16:44<1:25:00, 10.22s/it]

Avg game: 34.7


 17%|████████▊                                           | 101/600 [17:07<1:24:35, 10.17s/it]


KeyboardInterrupt: 

In [5]:
# Doesn't render in jupyter
obs = env.reset()
done = False
while not done:
    action = dqn_agent.act(obs)
    obs, rewards, done, _, _ = env.step(action)
    env.render()
env.close()

  gym.logger.warn(


InvalidArgumentError: Graph execution error:

Detected at node 'sequential/dense/MatMul' defined at (most recent call last):
    File "<frozen runpy>", line 198, in _run_module_as_main
    File "<frozen runpy>", line 88, in _run_code
    File "/home/gabe2max/.local/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/home/gabe2max/.local/lib/python3.11/site-packages/traitlets/config/application.py", line 1043, in launch_instance
      app.start()
    File "/home/gabe2max/.local/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 736, in start
      self.io_loop.start()
    File "/home/gabe2max/.local/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 195, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
      self._run_once()
    File "/usr/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
      handle._run()
    File "/usr/lib/python3.11/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 516, in dispatch_queue
      await self.process_one()
    File "/home/gabe2max/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 505, in process_one
      await dispatch(*args)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 412, in dispatch_shell
      await result
    File "/home/gabe2max/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 740, in execute_request
      reply_content = await reply_content
    File "/home/gabe2max/.local/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "/home/gabe2max/.local/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 546, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3024, in run_cell
      result = self._run_cell(
    File "/home/gabe2max/.local/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3079, in _run_cell
      result = runner(coro)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3284, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/gabe2max/.local/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3466, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/gabe2max/.local/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_51545/1007868698.py", line 5, in <module>
      action = dqn_agent.act(obs)
    File "/tmp/ipykernel_51545/3859367989.py", line 38, in act
      act_values = self.model.predict(state, verbose=0)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/engine/training.py", line 2554, in predict
      tmp_batch_outputs = self.predict_function(iterator)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/engine/training.py", line 2341, in predict_function
      return step_function(self, iterator)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/engine/training.py", line 2327, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/engine/training.py", line 2315, in run_step
      outputs = model.predict_step(data)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/engine/training.py", line 2283, in predict_step
      return self(x, training=False)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/engine/training.py", line 569, in __call__
      return super().__call__(*args, **kwargs)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/engine/base_layer.py", line 1150, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/engine/sequential.py", line 405, in call
      return super().call(inputs, training=training, mask=mask)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/engine/functional.py", line 512, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/engine/functional.py", line 669, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/engine/base_layer.py", line 1150, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/home/gabe2max/.local/lib/python3.11/site-packages/keras/src/layers/core/dense.py", line 241, in call
      outputs = tf.matmul(a=inputs, b=self.kernel)
Node: 'sequential/dense/MatMul'
In[0] ndims must be >= 2: 1
	 [[{{node sequential/dense/MatMul}}]] [Op:__inference_predict_function_618523]