
# Rainbow DQN for Chrome Dino Game

This notebook walks through implementing Rainbow DQN to train an agent to play the Chrome Dino Game.
Rainbow DQN combines several independent improvements to DQN into a single agent.


In [None]:

# Install necessary libraries.
# The code imports `nnx` from Flax, so `flax` is installed here.
# `chromedino` is not available on PyPI (pip reports no matching distribution),
# so the Dino environment has to be installed separately.
!pip install gym flax optax numpy pandas


## Environment Setup: Chrome Dino Game

The **Chrome Dino Game** involves controlling a dinosaur to avoid obstacles by jumping or ducking.
The goal is to maximize the score by surviving as long as possible.

### Observation Space
- 2D representation of the game state (e.g., pixel data or feature vectors).

### Action Space
- `0`: Do Nothing
- `1`: Jump
- `2`: Duck

### Reward
- Positive reward for surviving each timestep.
- Negative reward for collisions.
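
The action and reward conventions above can be written down in a few lines. This is a minimal sketch: the names `DinoAction` and `step_reward`, and the reward magnitudes (+1 per surviving step, -10 on collision), are illustrative assumptions, and the actual environment may use different values.

```python
from enum import IntEnum


class DinoAction(IntEnum):
    """Discrete actions, numbered as in the action space above."""
    NOOP = 0
    JUMP = 1
    DUCK = 2


def step_reward(collided: bool, survive_reward: float = 1.0,
                collision_penalty: float = -10.0) -> float:
    """Per-step reward; the default magnitudes are assumed, not taken from the environment."""
    return collision_penalty if collided else survive_reward
```

Because `DinoAction` is an `IntEnum`, its members can be passed wherever an integer action is expected.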



## Rainbow DQN Model

Rainbow DQN combines six extensions to DQN:
1. **Distributional Q-Learning**
2. **Dueling Network Architecture**
3. **Prioritized Experience Replay**
4. **Double Q-Learning**
5. **Noisy Networks**
6. **Multi-step Learning**
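
As an illustration of multi-step learning (item 6), the sketch below folds up to n consecutive rewards into one discounted return and stops at the end of an episode. The function name `n_step_return` is hypothetical and is not used elsewhere in this notebook.

```python
def n_step_return(rewards, dones, gamma=0.99):
    """Discounted sum of the given rewards, truncated at the first terminal step.

    Returns (ret, steps, terminated): the return, the number of steps actually
    used, and whether a terminal step was reached inside the window.
    """
    ret, discount = 0.0, 1.0
    for k, (reward, done) in enumerate(zip(rewards, dones)):
        ret += discount * reward
        discount *= gamma
        if done:
            return ret, k + 1, True
    return ret, len(rewards), False
```

When the window does not reach a terminal step, the bootstrapped target adds `gamma ** steps` times the value estimate of the state reached after those steps.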

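Below is the implementation of the network using the Flax `nnx` API. It covers the distributional output and the dueling architecture; it uses plain `nnx.Linear` layers, so noisy networks are not included.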


In [None]:

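import math

import jax
import jax.numpy as jnp
from flax import nnx

class RainbowDQN(nnx.Module):
    """Dueling network with a categorical (distributional) output head."""

    def __init__(self, state_shape, action_dim: int, atoms: int,
                 v_min: float, v_max: float, *, rngs: nnx.Rngs):
        self.action_dim = action_dim
        self.atoms = atoms
        self.v_min = v_min
        self.v_max = v_max
        state_dim = math.prod(state_shape)  # observations are flattened
        # nnx.Linear needs explicit input/output sizes and an RNG stream.
        self.feature_layer = nnx.Linear(state_dim, 128, rngs=rngs)
        self.value_layer = nnx.Sequential(
            nnx.Linear(128, 128, rngs=rngs), nnx.relu,
            nnx.Linear(128, atoms, rngs=rngs)
        )
        self.advantage_layer = nnx.Sequential(
            nnx.Linear(128, 128, rngs=rngs), nnx.relu,
            nnx.Linear(128, action_dim * atoms, rngs=rngs)
        )

    @property
    def support(self) -> jnp.ndarray:
        # Fixed atom locations of the return distribution.
        return jnp.linspace(self.v_min, self.v_max, self.atoms)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Input: (batch, *state_shape). Output: (batch, action_dim, atoms) probabilities.
        x = x.reshape(x.shape[0], -1)
        features = nnx.relu(self.feature_layer(x))
        value = self.value_layer(features).reshape(-1, 1, self.atoms)
        advantage = self.advantage_layer(features).reshape(-1, self.action_dim, self.atoms)
        # Dueling aggregation: subtract the mean advantage over actions.
        q_atoms = value + (advantage - advantage.mean(axis=1, keepdims=True))
        return nnx.softmax(q_atoms, axis=-1)

    def q_values(self, x: jnp.ndarray) -> jnp.ndarray:
        # Expected return per action: sum over atoms of probability * atom value.
        q_atoms = self(x)
        return (q_atoms * self.support).sum(axis=-1)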
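
Training a distributional head also needs a target distribution on the same fixed support. The sketch below shows the usual categorical projection in NumPy, under the assumption that the support is evenly spaced between `v_min` and `v_max` as in the model above. The name `project_distribution` and its argument layout are hypothetical; nothing in this notebook defines or calls it.

```python
import numpy as np


def project_distribution(next_probs, rewards, dones, support, gamma=0.99):
    """Project the Bellman-shifted distribution back onto the fixed support.

    next_probs: (batch, atoms) probabilities for the selected next action.
    rewards, dones: (batch,) arrays. support: (atoms,) evenly spaced values.
    """
    v_min, v_max = support[0], support[-1]
    atoms = support.shape[0]
    delta_z = (v_max - v_min) / (atoms - 1)

    # Shift every atom by the Bellman update, then clip to the support range.
    tz = rewards[:, None] + gamma * (1.0 - dones[:, None]) * support[None, :]
    tz = np.clip(tz, v_min, v_max)

    # Fractional position of each shifted atom on the support grid.
    b = np.clip((tz - v_min) / delta_z, 0, atoms - 1)
    lower = np.floor(b).astype(np.int64)
    upper = np.ceil(b).astype(np.int64)

    # Split each atom's mass between its two neighbours. When b is an integer
    # (lower == upper), w_upper is 0 and all mass stays on that single atom.
    w_upper = b - lower
    w_lower = 1.0 - w_upper

    target = np.zeros_like(next_probs)
    rows = np.arange(next_probs.shape[0])[:, None]
    np.add.at(target, (rows, lower), next_probs * w_lower)
    np.add.at(target, (rows, upper), next_probs * w_upper)
    return target
```

Each row of the result still sums to 1, so it can serve as the target in a cross-entropy loss against the predicted distribution of the action that was taken.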



## Prioritized Replay Buffer

The replay buffer stores transitions and prioritizes sampling based on their importance.

### Key Features
- **Storage**: Saves states, actions, rewards, next states, and done flags.
- **Prioritization**: Samples transitions based on TD error magnitudes.
- **Updates**: Updates priorities based on new TD errors.
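
A small worked example may help make the prioritization concrete. It uses the same exponents as the buffer defaults below (`alpha=0.6`, `beta=0.4`); the three TD errors are made-up values chosen only for illustration.

```python
import numpy as np

alpha, beta = 0.6, 0.4
td_errors = np.array([0.1, 1.0, 2.0])  # illustrative values, not real training data

# Priority is |TD error| plus a small constant so no transition has zero probability.
priorities = np.abs(td_errors) + 1e-6
scaled = priorities ** alpha
probs = scaled / scaled.sum()

# Importance-sampling weights correct the bias introduced by non-uniform sampling.
weights = (len(td_errors) * probs) ** (-beta)
weights /= weights.max()
```

Transitions with larger TD errors get a higher sampling probability and a smaller importance-sampling weight, so the rarely sampled transition ends up with weight 1 after normalization.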


In [None]:

import numpy as np

class PrioritizedReplayBuffer:
    def __init__(self, capacity, state_shape, action_dim, alpha=0.6):
        self.capacity = capacity
        self.ptr, self.size = 0, 0
        self.alpha = alpha
        self.states = np.zeros((capacity, *state_shape), dtype=np.float32)
        self.actions = np.zeros(capacity, dtype=np.int32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.next_states = np.zeros((capacity, *state_shape), dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.float32)
        self.priorities = np.zeros(capacity, dtype=np.float32)

    def store(self, state, action, reward, next_state, done):
        max_priority = self.priorities.max() if self.size > 0 else 1.0
        self.states[self.ptr] = state
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.next_states[self.ptr] = next_state
        self.dones[self.ptr] = done
        self.priorities[self.ptr] = max_priority
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size, beta=0.4):
        # Sampling probability is proportional to priority ** alpha.
        priorities = self.priorities[:self.size] ** self.alpha
        probabilities = priorities / priorities.sum()
        indices = np.random.choice(self.size, batch_size, p=probabilities)
        # Importance-sampling weights, normalized so the largest weight is 1.
        weights = (self.size * probabilities[indices]) ** (-beta)
        weights /= weights.max()
        return (
            self.states[indices], self.actions[indices], self.rewards[indices],
            self.next_states[indices], self.dones[indices], indices, weights
        )

    def update_priorities(self, indices, priorities, eps=1e-6):
        # Store |TD error| + eps so that no transition ends up with zero probability.
        self.priorities[indices] = np.abs(np.asarray(priorities)) + eps





## Training Procedure

The training loop integrates the environment, Rainbow DQN model, replay buffer, and optimization process.

### Highlights
- **Epsilon-Greedy Exploration**: Takes a random action with probability 0.1 and the greedy action otherwise.
- **Experience Storage**: Stores transitions in the replay buffer.
- **Batch Updates**: Samples a batch of transitions for training.
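
The exploration rule can be isolated into two small helpers. This is a sketch under assumptions: the training loop in this notebook uses a fixed epsilon of 0.1, while the linear decay schedule, its default values, and the names `epsilon_by_step` and `select_action` are illustrative additions.

```python
import numpy as np


def epsilon_by_step(step, eps_start=1.0, eps_end=0.1, decay_steps=10_000):
    """Linearly anneal epsilon from eps_start to eps_end over decay_steps."""
    frac = min(step / decay_steps, 1.0)
    return eps_start + frac * (eps_end - eps_start)


def select_action(q_values, epsilon, rng):
    """Pick a random action with probability epsilon, otherwise the greedy one."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))


rng = np.random.default_rng(0)
greedy_action = select_action(np.array([0.1, 0.9, 0.3]), epsilon=0.0, rng=rng)
```

With `epsilon=0.0` the helper always returns the index of the largest Q-value, which makes it easy to switch exploration off during evaluation.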

Below is the implementation of the training loop.


In [None]:
import gym
import optax

def train_dino():
    # Assumes an environment registered as "ChromeDino-v0" and the pre-0.26 gym API
    # (reset() returns the observation, step() returns four values).
    env = gym.make("ChromeDino-v0")
    state_shape = env.observation_space.shape
    action_dim = env.action_space.n

    replay_buffer = PrioritizedReplayBuffer(100_000, state_shape, action_dim)
    model = RainbowDQN(state_shape, action_dim=action_dim, atoms=51,
                       v_min=-10.0, v_max=10.0, rngs=nnx.Rngs(0))
    params = nnx.state(model, nnx.Param)
    opt_state = optax.adam(1e-4).init(params)

    episodes_total = 1000
    batch_size = 32
    gamma = 0.99
    epsilon = 0.1
    rng = jax.random.PRNGKey(0)
    rewards, episodes = [], []

    for episode in range(episodes_total):
        state = env.reset()
        episode_reward = 0.0
        done = False

        while not done:
            # Epsilon-greedy: random action with probability epsilon, else greedy.
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                q_values = model.q_values(jnp.asarray(state, dtype=jnp.float32)[None])
                action = int(q_values.argmax())
            next_state, reward, done, _ = env.step(action)
            replay_buffer.store(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward

            if replay_buffer.size > batch_size:
                # train_step is not defined in this notebook; it is assumed to return
                # the updated parameters, optimizer state, and loss.
                rng, step_rng = jax.random.split(rng)
                params, opt_state, loss = train_step(
                    step_rng, params, opt_state, replay_buffer, model, batch_size, gamma)
                nnx.update(model, params)  # copy the new parameters into the model

        # Record one entry per episode, not one per timestep.
        rewards.append(episode_reward)
        episodes.append(episode + 1)
        print(f"Episode {episode + 1}, Reward: {episode_reward}")

    # Return the results from the function
    return episodes, rewards

# Call the function and get the results
episodes, rewards = train_dino()






## Saving the Results for Further Enhancements

We save the episode numbers and rewards to a CSV file for later inspection.

In [None]:
import pandas as pd
results = {'Episode': episodes, 'Reward': rewards}
df = pd.DataFrame(results)
df.to_csv('dino_results.csv', index=False)
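
Episode rewards tend to be noisy, so a smoothed curve is often easier to inspect than the raw values. The helper below is a minimal sketch; the name `moving_average` and the default window of 10 are illustrative choices, not part of the original notebook.

```python
import numpy as np


def moving_average(values, window=10):
    """Mean over a sliding window; returns len(values) - window + 1 points."""
    values = np.asarray(values, dtype=np.float64)
    if window < 1 or window > len(values):
        raise ValueError("window must be between 1 and len(values)")
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")
```

It can be applied to the saved rewards, for example `moving_average(df['Reward'].to_numpy(), window=10)`.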