In [1]:
#!pip install gymnasium

In [2]:
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import random
from tqdm import tqdm

import gymnasium as gym

In [20]:
# Agent
from keras.models import Model, Sequential
from keras.layers import Dense, Input
from keras.optimizers import Adam
from collections import deque

class DDQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

        self.tau = 0.01
        self.batch_size = 64
        self.memory = deque(maxlen=10000)
        self.gamma = 0.95
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.leaning_rate = 0.001
        self.model = self.build_model()
        self.target_model = self.build_model()

    def build_model(self):
        model = Sequential()
        model.add(Dense(128, input_dim = self.state_size, activation='relu'))
        model.add(Dense(64, activation='relu'))
        model.add(Dense(32, activation='relu'))
        model.add(Dense(self.action_size, activation='relu'))
        model.compile(loss='mse',optimizer='adam')
        return model

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(state, verbose=0)
        return np.argmax(act_values[0])

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def replay(self):
        minibatch = self.sample(self.batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                predicted_action = np.argmax(self.model.predict(next_state,verbose=0)[0])
                target = reward + self.gamma * self.target_model.predict(next_state, verbose=0)[0][predicted_action]
            target_f = self.model.predict(state, verbose=0)
            target_f[0][action] = target
            self.model.fit(state,target_f, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def update_target_model(self):
        weights = self.model.get_weights()
        target_weights = self.target_model.get_weights()
        for i in range(len(target_weights)):
            target_weights[i] = weights[i] * self.tau + target_weights[i] * (1- self.tau)
        self.target_model.set_weights(target_weights)

    def save(self, file):
        self.model.save_weights(file)

    def load(self, file):
        self.model.load_weights(file)
        self.target_model.load_weights(file)
            

In [22]:
# Train
env = gym.make("CartPole-v1")
state_size_, action_size_ = 4, 2
dqn_agent = DDQNAgent(state_size_, action_size_)
try:
    dqn_agent.load('cartpole.h5')
except:
    print("No previous model detected")
n_episodes = 100
n_steps = 200
capacity = 10000

streak_step = 0
total_step = 0

save_episodes = 10
for episode in tqdm(range(n_episodes)):
    cur_state_, _ = env.reset()
    for step in range(n_steps):
        streak_step += 1
        total_step += 1
        cur_state_ = np.reshape(cur_state_, [1, state_size_])
        action = dqn_agent.act(cur_state_)
        observation, reward, done, _, _ = env.step(action)
        observation = np.reshape(observation, [1, state_size_])
        dqn_agent.remember(cur_state_, action, reward, observation, done)
        if total_step % dqn_agent.batch_size == 0:
            dqn_agent.replay()
            dqn_agent.update_target_model()
        cur_state_ = observation
        if done:
            break
    if episode % save_episodes == 0:
        dqn_agent.save("cartpole.h5")
        print(f"Avg game: {streak_step / save_episodes}")
        streak_step = 0

  0%|                                                                                                                                                                | 0/100 [00:00<?, ?it/s]

Avg game: 2.1


 10%|███████████████                                                                                                                                        | 10/100 [00:28<04:32,  3.02s/it]

Avg game: 19.1


 19%|████████████████████████████▋                                                                                                                          | 19/100 [00:56<04:06,  3.05s/it]

Avg game: 21.0


 29%|███████████████████████████████████████████▊                                                                                                           | 29/100 [01:23<03:16,  2.76s/it]

Avg game: 18.3


 42%|███████████████████████████████████████████████████████████████▍                                                                                       | 42/100 [01:51<02:03,  2.13s/it]

Avg game: 19.2


 51%|█████████████████████████████████████████████████████████████████████████████                                                                          | 51/100 [02:28<02:35,  3.18s/it]

Avg game: 27.9


 60%|██████████████████████████████████████████████████████████████████████████████████████████▌                                                            | 60/100 [02:57<01:50,  2.76s/it]

Avg game: 18.4


 69%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                              | 69/100 [03:33<01:49,  3.53s/it]

Avg game: 27.3


 81%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                            | 81/100 [04:01<00:39,  2.08s/it]

Avg game: 18.7


 93%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍          | 93/100 [04:38<00:15,  2.27s/it]

Avg game: 21.0


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [04:57<00:00,  2.98s/it]


In [18]:
# Doesn't render in jupyter
obs = env.reset()
done = False
while not done:
    action = dqn_agent.act(obs)
    obs, rewards, done, _, _ = env.step(action)
    env.render()
env.close()