# Flax basics


In [15]:
! pip install flax jax jaxlib optax gymnasium

  pid, fd = os.forkpty()


Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


## DQN in CartPole-v1

In [16]:
import gymnasium as gym

import jax
import jax.numpy as jnp
import optax
import flax.linen as nn

import random
import numpy as np
from collections import deque
from typing import Sequence
from flax.training.train_state import TrainState


In [17]:
# Define the Q-network using flax.linen

class q_network(nn.Module):
    hidden_dims: Sequence[int]
    n_actions: int
    
    @nn.compact
    def __call__(self, x):
        """ Network architecture:
            input layer(state_dim)
            ==> hidden layer(128) + relu 
            ==> hidden layer(128) + relu
            ==> output layer(n_actions)
        """
        for h in self.hidden_dims:
            x = nn.Dense(h)(x)
            x = nn.relu(x)
        return nn.Dense(self.n_actions)(x) 

In [18]:
# Experience replay buffer
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, ne_state, done):
        self.buffer.append((state, action, reward, ne_state, done))

    def sample(self, batch_size):
        samples = random.sample(self.buffer, batch_size)    
        states, actions, rewards, ne_states, dones = map(np.array, zip(*samples))
        return states, actions, rewards, ne_states, dones

    def __len__(self):
        return len(self.buffer)

In [19]:
# DQN utils

@jax.jit
def select_action(params, state, epsilon, rng, n_actions):
    if random.random() < epsilon:
        return random.randint(0, n_actions - 1)  # Explore
    q_values = q_net(params, state)
    return int(jnp.argmax(q_values))  # Exploit

@jax.jit
def train_step(state: TrainState, batch, gamma: float):
    def loss_fn(params):
        states, actions, rewards, ne_states, dones = batch
        q_values = q_net.apply(params, states)
        q_actions = jnp.take_along_axis(q_values, actions[..., None], axis=1).squeeze()  # get value
        
        next_q_values = q_net.apply(state.params, ne_states)
        max_next_q = jnp.max(next_q_values, axis=1)
        
        target = rewards + (1 - dones) * gamma * max_next_q
        
        loss = jnp.mean((q_actions - target) ** 2)  # MSE loss
        return loss
    
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

In [20]:
env = gym.make("CartPole-v1")
n_actions = env.action_space.n
state_dim = env.observation_space.shape[0]

q_net = q_network(hidden_dims=[128, 128], n_actions=n_actions)
rng = jax.random.PRNGKey(0)

# init model and optimizer
lr = 1e-3
input_dim = jnp.ones((state_dim,))
params = q_net.init(rng, input_dim)
tx = optax.adam(learning_rate=lr)
train_state = TrainState.create(
    apply_fn=q_net.apply,
    params=params,
    tx=tx
)

# hyperparameters
buffer = ReplayBuffer()
num_episodes = 1000
batch_size = 64
gamma = 0.99
epsilon = 1.0
epsilon_decay = 0.995
min_epsilon = 0.05

for episode in range(num_episodes):
    state, _ = env.reset()
    total_reward = 0
    
    for t in range(200):
        state_tensor = jnp.array(state, dtype=jnp.float32)
        actions = select_action(train_state.params, state_tensor, epsilon, rng, n_actions)

        next_state, reward, terminated, truncated, _ = env.step(actions)
        
        done = terminated or truncated
        total_reward += reward
        
        if len(buffer) > batch_size:
            states, actions, rewards, ne_states, dones = buffer.sample(batch_size)
            batch = (
                jnp.array(states, dtype=jnp.float32),
                jnp.array(actions, dtype=jnp.int32),
                jnp.array(rewards, dtype=jnp.float32),
                jnp.array(ne_states),
                jnp.array(dones, dtype=jnp.bool_)
            )
            
            train_state, loss = train_step(train_state, batch, gamma)
        
        if done:
            break
        
    
    epsilon = max(min_epsilon, epsilon * epsilon_decay)
    print(f"Episode {episode + 1}, Total Reward: {total_reward}, Epsilon: {epsilon:.2f}")



TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function select_action at /var/folders/gy/6wf0tc3n7276h9vd9dqcpkvw0000gn/T/ipykernel_48762/608396390.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument epsilon.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError