# Imports

In [None]:
import tensorflow as tf
from keras import Sequential

from keras.api.layers import *
import numpy as np
import random
import matplotlib.pyplot as plt

# Hyperparameters

In [None]:
# Discount factor applied to future rewards; raise it to favour long-term rewards.
gamma = 0.99

# Exploration settings: current exploration rate, its lower bound, and the
# multiplicative decay factor for the exploration rate.
epsilon, epsilon_min, epsilon_decay = 1, 0.1, 0.95

# Optimizer settings: Adam step size (chosen with a batch size of 32 in mind)
# and the number of transitions drawn from memory per replay.
learning_rate, batch_size = 0.001, 32

# Replay buffer; holds (state, action, reward, next_state, done) tuples.
memory = []

# Build the DQN

In [None]:

# Q-network: 7 input features -> two 32-unit ReLU hidden layers -> 3 linear
# outputs (one Q-value per action). Built layer by layer instead of from a list.
model = Sequential()
model.add(Input(shape=(7,)))
model.add(Dense(32, activation='relu'))
model.add(Dense(32, activation='relu'))
model.add(Dense(3, activation='linear'))

# Mean-squared error between predicted and target Q-values, optimised with Adam.
model.compile(
    loss='mse',
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
)

# Experience replay

In [None]:
def replay():
    """Run one prioritized experience-replay pass over the global ``memory``.

    Every stored transition is scored by its absolute TD error
    ``|reward + gamma * max_a Q(next_state, a) - Q(state, action)|`` (the
    bootstrap term is dropped for terminal transitions). ``batch_size``
    transitions are then sampled with replacement, with probability
    proportional to that score, and the global ``model`` is fitted on each
    sampled transition in turn.

    Reads the module-level ``memory``, ``model``, ``gamma`` and ``batch_size``.
    Does nothing when ``memory`` is empty. Falls back to uniform sampling when
    the TD errors cannot be normalised (all zero, or not finite).

    Raises:
        ValueError: if any stored state or next_state does not have 7
            features on its last axis.
    """
    if not memory:
        # Nothing to learn from yet; np.random.choice would reject an empty pool.
        return

    # Validate once, up front. The sampled batch is a subset of memory, so it
    # needs no second check.
    for state, _, _, next_state, _ in memory:
        if state.shape[-1] != 7 or next_state.shape[-1] != 7:
            raise ValueError(f"State or next_state has incorrect shape: "
                            f"{state.shape}, {next_state.shape}")

    # --- Priority computation -------------------------------------------
    # The model is not modified here, so scoring all transitions with two
    # batched predict calls is equivalent to predicting them one at a time.
    states, actions, reward_seq, next_states, dones = zip(*memory)
    action_idx = np.asarray(actions, dtype=np.intp)
    done_mask = np.asarray(dones, dtype=bool)

    q_current = model.predict(np.stack(states))
    q_next = model.predict(np.stack(next_states))

    # Terminal transitions get no bootstrap term (np.where also keeps a
    # non-finite Q-value of a terminal next_state out of the target).
    bootstrap = np.where(done_mask, 0.0, gamma * q_next.max(axis=1))
    targets = np.asarray(reward_seq, dtype=np.float64) + bootstrap
    td_errors = np.abs(targets - q_current[np.arange(len(memory)), action_idx])

    total_error = td_errors.sum()
    if np.isfinite(total_error) and total_error > 0:
        probabilities = td_errors / total_error
    else:
        # All-zero (or non-finite) errors cannot be normalised into a
        # distribution; p=None makes np.random.choice sample uniformly.
        probabilities = None

    batch_indices = np.random.choice(len(memory), batch_size, p=probabilities)

    # --- Training ---------------------------------------------------------
    # Kept sequential on purpose: each target is computed with the weights
    # produced by the previous fit call.
    for index in batch_indices:
        state, action, reward, next_state, done = memory[index]
        target = reward
        if not done:
            target += gamma * np.amax(model.predict(next_state[np.newaxis]))
        # Only the taken action's Q-value is moved towards the target; the
        # other outputs keep their current predictions (zero error).
        target_f = model.predict(state[np.newaxis])
        target_f[0, action] = target
        model.fit(state[np.newaxis], target_f, epochs=1, verbose=0)

# Train the model

In [None]:
from main import BrickBreakerEnv
from tqdm import tqdm 
tf.keras.utils.disable_interactive_logging()

env = BrickBreakerEnv()
episodes = 100

# Per-episode metrics, consumed by the visualization cell.
rewards = []
scores = []
times = []

with tqdm(total=episodes, desc="Training Progress") as pbar:
    for e in range(episodes):
        state = env.reset()
        total_reward = 0

        for _ in range(200):  # Max steps per episode
            # Epsilon-greedy action selection.
            if np.random.rand() <= epsilon:
                action = env.action_space.sample()  # Exploration
            else:
                q_values = model.predict(state[np.newaxis])
                action = np.argmax(q_values)  # Exploitation: pick best action

            next_state, reward, done, info = env.step(action)
            memory.append((state, action, reward, next_state, done))
            # Accumulate the episode return; without this every logged and
            # plotted reward stays at 0.
            total_reward += reward
            state = next_state

            if done:
                break

        replay()

        # Anneal exploration once per episode, never dropping below the floor.
        # Without this epsilon stays at its initial value and the learned
        # policy is never used for action selection.
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        rewards.append(total_reward)
        scores.append(info['score'])
        times.append(info['time'])
        # tqdm.write prints the line without corrupting the active progress bar.
        tqdm.write(f"Episode: {e}, Total Reward: {total_reward}, Time: {info['time']:.2f}s, Score: {info['score']}")

        # Update the progress bar
        pbar.update(1)




# Visualization

In [None]:
import matplotlib.pyplot as plt

# Shared x-axis for all panels: 1-based episode numbers.
episodes_index = list(range(1, len(rewards) + 1))

plt.figure(figsize=(12, 6))

# One row per panel: (series, legend label, line colour, title, y-axis label).
panels = [
    (rewards, "Rewards", "blue", "Rewards Over Episodes", "Total Reward"),
    (scores, "Scores", "orange", "Scores Over Episodes", "Score"),
    (times, "Elapsed Time (s)", "green", "Elapsed Time Over Episodes", "Time (s)"),
]

# Draw the panels left to right in a 1x3 grid, all formatted the same way.
for position, (series, label, color, title, ylabel) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.plot(episodes_index, series, label=label, color=color)
    plt.title(title)
    plt.xlabel("Episode")
    plt.ylabel(ylabel)
    plt.grid(True)
    plt.legend()

plt.tight_layout()
plt.show()

# Q-value visualization

In [None]:
# Q-values the trained network assigns to each of the three actions for the
# most recent `state` left over from training (a batch axis is added for predict).
q_values = model.predict(np.expand_dims(state, axis=0))[0]

# NOTE(review): the bar labels assume action indices 0/1/2 mean
# Left/Stay/Right; confirm against BrickBreakerEnv.
plt.bar(["Left", "Stay", "Right"], q_values)
plt.title(f"Q-values (State: {state})")
plt.show()

In [None]:
# Persist the trained network to disk in the native .keras format.
model.save(filepath="brick_breaker_dqn.keras")