In [1]:
import gym
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

In [2]:
def create_dqn_model(input_shape, num_actions):
    """Build a convolutional Q-network that outputs one Q-value per action.

    The layer sizes (conv filters 32/64/64 with 8x8/4x4/3x3 kernels and
    strides 4/2/1, then a 512-unit hidden layer) match the DQN network of
    Mnih et al. (2015).

    Args:
        input_shape: shape of a single observation; forwarded to the first
            Conv2D layer as ``input_shape``.
        num_actions: number of discrete actions; the final linear layer has
            this many units.

    Returns:
        An uncompiled ``Sequential`` Keras model.

    NOTE(review): observations are fed in without any scaling layer; for raw
    pixel frames consider normalizing inputs (e.g. dividing by 255).
    """
    # Convolutional trunk that turns an observation into a flat feature vector.
    feature_extractor = [
        Conv2D(32, (8, 8), strides=(4, 4), activation='relu', input_shape=input_shape),
        Conv2D(64, (4, 4), strides=(2, 2), activation='relu'),
        Conv2D(64, (3, 3), activation='relu'),
        Flatten(),
    ]
    # Fully-connected head producing unbounded (linear) Q-value estimates.
    q_head = [
        Dense(512, activation='relu'),
        Dense(num_actions, activation='linear'),
    ]
    return Sequential(feature_extractor + q_head)

In [3]:
# Create the DemonAttack-v0 Gym environment and record the shape of one
# observation plus the number of discrete actions the agent can choose from.
env = gym.make('DemonAttack-v0')
state_shape, num_actions = env.observation_space.shape, env.action_space.n

In [4]:
# Adam optimizer (learning rate 1e-3). It is kept in its own variable because
# the training loop applies gradients with it directly.
optimizer = Adam(learning_rate=0.001)

# Instantiate the Q-network for this environment and compile it with an MSE loss.
model = create_dqn_model(state_shape, num_actions)
model.compile(loss='mse', optimizer=optimizer)

In [5]:
# Training hyperparameters.
# Episode count and the per-episode step cap.
num_episodes, max_steps_per_episode = 10, 100

# Epsilon-greedy exploration: start fully random, decay multiplicatively,
# never go below the floor value.
epsilon, epsilon_decay, min_epsilon = 1.0, 0.99, 0.01

# Experience-replay settings: minibatch size, intended buffer capacity,
# and the (initially empty) list of stored transitions.
batch_size = 32
replay_buffer_size = 100000
replay_buffer = []

In [6]:
# Main training loop: epsilon-greedy DQN with uniform experience replay.
# NOTE(review): Bellman targets are computed with the same online network that
# is being trained (no separate target network), which can destabilize learning.
gamma = 0.99  # discount factor for future rewards

for episode in range(num_episodes):
    state = env.reset()  # classic gym API: reset() returns only the observation
    total_reward = 0
    for step in range(max_steps_per_episode):
        # Choose an action using epsilon-greedy exploration.
        if tf.random.uniform(()) < epsilon:
            action = env.action_space.sample()
        else:
            # Call the model directly rather than model.predict(): predict() has
            # large per-call overhead and prints a progress bar on every step.
            q_values = model(tf.cast(state[None, ...], tf.float32), training=False)[0]
            action = tf.argmax(q_values).numpy()

        # Take a step in the environment (classic 4-tuple gym step API).
        next_state, reward, done, _ = env.step(action)

        # Store the transition; evict the oldest one once the buffer exceeds
        # replay_buffer_size so memory stays bounded.
        replay_buffer.append((state, action, reward, next_state, done))
        if len(replay_buffer) > replay_buffer_size:
            replay_buffer.pop(0)

        # Update the current state and total reward.
        state = next_state
        total_reward += reward

        # Train on a minibatch sampled uniformly (with replacement) once enough
        # transitions have been collected.
        if len(replay_buffer) >= batch_size:
            batch_indices = tf.random.uniform(
                (batch_size,), minval=0, maxval=len(replay_buffer), dtype=tf.int32
            ).numpy()
            batch = [replay_buffer[i] for i in batch_indices]
            states, actions, rewards, next_states, dones = zip(*batch)
            states = tf.cast(tf.stack(states), tf.float32)
            actions = tf.constant(actions, dtype=tf.int32)
            rewards = tf.constant(rewards, dtype=tf.float32)
            next_states = tf.cast(tf.stack(next_states), tf.float32)
            dones = tf.constant(dones, dtype=tf.float32)

            # Bellman target r + gamma * max_a' Q(s', a'); the bootstrap term is
            # zeroed for terminal transitions. It is computed outside the tape, so
            # no gradient flows through the target.
            next_q_values = model(next_states, training=False)
            target_q_values = rewards + (1.0 - dones) * gamma * tf.reduce_max(next_q_values, axis=1)

            # Regress Q(s, a) for the action actually taken toward the target.
            with tf.GradientTape() as tape:
                q_values = model(states, training=True)
                q_values = tf.reduce_sum(q_values * tf.one_hot(actions, num_actions), axis=1)
                loss = tf.reduce_mean(tf.square(target_q_values - q_values))
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        if done:
            break

    # Decay epsilon once per episode, never going below min_epsilon.
    epsilon = max(min_epsilon, epsilon * epsilon_decay)

    # Print the total reward for the episode.
    print(f'Episode {episode + 1}, Total Reward: {total_reward}')

Episode 1, Total Reward: 10.0
Episode 2, Total Reward: 0.0
Episode 3, Total Reward: 10.0
Episode 4, Total Reward: 30.0
Episode 5, Total Reward: 20.0
Episode 6, Total Reward: 10.0
Episode 7, Total Reward: 10.0
Episode 8, Total Reward: 20.0
Episode 9, Total Reward: 20.0
Episode 10, Total Reward: 10.0


In [7]:
# Play one greedy (no exploration) episode with the trained model, rendering
# each frame, then report the episode's length and total reward.
max_eval_steps = 10000  # safety cap so a policy that never ends the episode cannot hang the kernel
state = env.reset()
done = False
eval_reward = 0.0
eval_steps = 0
while not done and eval_steps < max_eval_steps:
    # Direct model call instead of model.predict(): avoids per-step overhead
    # and progress-bar output.
    q_values = model(tf.cast(state[None, ...], tf.float32), training=False)[0]
    action = tf.argmax(q_values).numpy()
    next_state, reward, done, _ = env.step(action)
    state = next_state
    eval_reward += reward
    eval_steps += 1
    env.render()
print(f'Evaluation episode: {eval_steps} steps, Total Reward: {eval_reward}')

In [8]:
env.close()