In [1]:
from model import AwaleNet
import jax
from awale.env import AwaleJAX
import equinox as eqx
from jax import numpy as jnp
from jax import random
import optax
import copy
from typing import Tuple , List , Any , NamedTuple
import numpy as np
import flashbax as fbx


In [2]:
key1 , key2 = jax.random.split(jax.random.PRNGKey(2))

q_model = AwaleNet(key1)
q_target_model = copy.deepcopy(q_model)


In [3]:
q_target_model

AwaleNet(
  layers=[
    Linear(weight=f32[64,14], bias=f32[64], in_features=14, out_features=64),
    Linear(weight=f32[32,64], bias=f32[32], in_features=64, out_features=32)
  ],
  output_layer=Linear(
    weight=f32[12,32],
    bias=f32[12],
    in_features=32,
    out_features=12
  ),
  dropout_rate=0.3,
  config=ModelConfig(input_size=14, hidden_sizes=[64, 32], dropout_rate=0.3)
)

In [4]:
gamma = 0.99

optimizer = optax.adam(learning_rate=0.0005)
opt_state = optimizer.init(eqx.filter(q_model , eqx.is_array))

In [197]:
@eqx.filter_jit
def get_q_targets(experience_1 , experience_2 , key) :
    """
    Compute target Q-values considering only valid actions
    """
    # Extract values from experience dictionaries
    rewards = experience_1[ "reward" ]  # shape: (batch_size,)
    dones = experience_1[ "done" ]  # shape: (batch_size,)
    next_states_board = experience_2[ "board" ]  # shape: (batch_size, 12)
    next_states_score = experience_2[ "score" ]  # shape: (batch_size, 2)
    next_valid_actions = experience_2[ "valid_actions" ]  # shape: (batch_size, 12)

    # Generate random keys for batch processing
    batch_size = rewards.shape[ 0 ]
    keys = jax.random.split(key , batch_size)

    # Get Q-values for next states using target network
    next_q_values = jax.vmap(q_target_model)(
            next_states_board ,
            next_states_score ,
            next_valid_actions ,
            keys
    )  # shape: (batch_size, 12)
    print(next_q_values.shape)

    next_q_values = jax.vmap(select_action)(next_q_values , next_valid_actions , keys)
    print(next_q_values.shape)

    def single_calculation(reward , done , next_q) :
        return reward + gamma * (1.0 - done) * next_q

    # Compute targets using Bellman equation
    return jax.vmap(single_calculation)(rewards , dones , next_q_values)

In [217]:

@eqx.filter_value_and_grad
def compute_loss(experience_1 , targets , key) :
    """
    Compute DQN loss with gradient tracking
    Args:
        experience_1: States (batch_size, )
        targets: Target Q-values (batch_size,)
        key: PRNGKey 
    """
    # Extract values from experience dictionary
    boards = experience_1[ "board" ]  # shape: (batch_size, 12)
    actions = experience_1[ "action" ]  # shape: (batch_size,)
    scores = experience_1[ "score" ]  # shape: (batch_size, 2)
    valid_actions = experience_1[ "valid_actions" ]  # shape: (batch_size, 12)

    # Get batch size from the input
    batch_size = boards.shape[ 0 ]

    # Generate random keys for batch processing
    keys = jax.random.split(key , batch_size)

    # Get Q-values for current states
    predicted_q_all = jax.vmap(q_model)(
            boards ,
            scores ,
            valid_actions ,
            keys
    )  # shape: (batch_size, 12)

    # Make sure types match
    selected_q_values = jax.vmap(select_action)(predicted_q_all , valid_actions , keys)
    print(selected_q_values)

    # Return mean loss
    return jnp.mean((selected_q_values - targets) ** 2)

In [218]:
def optimize(states , actions , rewards , next_states , dones) :
    # compute target Q-values
    q_targets = get_q_targets(next_states , rewards , dones , actions)

    # compute Q-values
    q_values = jax.vmap(q_model)(states)

    # compute loss
    loss = jnp.mean((q_values - q_targets) ** 2)

    # compute gradients
    grads = jax.grad(loss)(opt_state)


In [219]:
@eqx.filter_jit
def calculate_epsilon(step , epsilon_start , epsilon_finish ,
                      total_timesteps , exploration_fraction) :
    finish_step = total_timesteps * exploration_fraction
    if step > finish_step :
        return epsilon_finish
    epsilon_range = epsilon_start - epsilon_finish
    return epsilon_finish + (((finish_step - step) / finish_step) * epsilon_range)

In [220]:
total_timesteps = 1000
batch_size = 32
buffer_size = 1000000
train_frequency = 4
seed = 1
target_network_update_frequency = 1000
learning_rate = 1e-4
epsilon_start = 1
epsilon_finish = 0.01
exploration_fraction = 0.1
learn_start_size = 100

In [221]:
@eqx.filter_jit
def select_action(
        probs: jnp.ndarray , valid_actions: jnp.ndarray , key: jax.random.PRNGKey
) :
    """Select an action between 0 and 11."""
    masked_probs = jnp.where(valid_actions , probs , 0.0)
    masked_probs = masked_probs / (jnp.sum(masked_probs) + 1e-8)
    return jax.random.categorical(key , jnp.log(masked_probs + 1e-8)).astype(int)


@eqx.filter_jit
def get_valid_actions(positions: jnp.ndarray) :
    possible_actions = jnp.zeros(12)
    possible_actions = possible_actions.at[ positions ].set(1)
    return possible_actions


In [222]:
episode_end_steps = [ ]
episode_rewards = [ ]

In [223]:
env = AwaleJAX( )
state = env.reset(jax.random.PRNGKey(0))

In [224]:
buffer = fbx.make_flat_buffer(max_length=1000 , min_length=64 , sample_batch_size=32)
fake_timestep = { "board" : state.board , "score" : state.score ,
                  "valid_actions" : get_valid_actions(state.action_space) , "action" : jnp.int8(0) , "reward" : jnp.float32(0) , "done" : False , }
memory = buffer.init(fake_timestep)



In [225]:
for step in range(total_timesteps) :
    key = jax.random.PRNGKey(step)
    # Select random action with p(epsilon), else argmax(q).
    epsilon = calculate_epsilon(step , epsilon_start , epsilon_finish ,
                                total_timesteps , exploration_fraction)
    if jax.random.uniform(key) < epsilon :
        # Random action
        action = jax.random.choice(key , state.action_space)

    else :
        # valid actions
        valid_actions = get_valid_actions(state.action_space)
        # Get Q-values
        q_values = q_model(board=state.board , scores=state.score , valid_actions=valid_actions , key=key)
        # Select action with highest Q-value
        action = select_action(q_values , valid_actions , key)[ 0 ]

    # Take action
    next_state , reward , done = env.step(state , action)

    # Add transition to replay memory
    memory = buffer.add(memory , { "board" : state.board , "score" : state.score ,
                                   "valid_actions" : get_valid_actions(state.action_space) , "action" : jnp.int8(action) , "reward" : reward , "done" : done })
    if done :
        episode_end_steps.append(step)
        episode_rewards.append(reward)
        next_state = env.reset(jax.random.PRNGKey(step))

    state = next_state
    if step > learn_start_size and step % train_frequency == 0 and buffer.can_sample(memory) :
        # Sample replay experiences.
        batch = buffer.sample(memory , key)

        # Compute target Q-values
        q_targets = get_q_targets(batch.experience.first , batch.experience.second , key)
        # Compute Q-values

        # Compute loss
        loss = compute_loss(batch.experience.first , q_targets , key)

        print(loss)




[[ 0]
 [ 8]
 [ 7]
 [ 8]
 [ 2]
 [ 8]
 [ 3]
 [ 4]
 [ 0]
 [ 6]
 [ 8]
 [ 2]
 [ 4]
 [ 4]
 [ 5]
 [ 2]
 [ 7]
 [ 4]
 [ 4]
 [ 3]
 [ 7]
 [ 0]
 [ 8]
 [ 3]
 [ 7]
 [ 0]
 [ 0]
 [ 7]
 [ 3]
 [ 7]
 [11]
 [11]]
(Array(172.45575, dtype=float32), {'action': None, 'board': None, 'done': None, 'reward': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32), 'score': None, 'valid_actions': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [

In [109]:
rng_key = jax.random.PRNGKey(0)  # Source of randomness.
batch = buffer.sample(memory , rng_key)  # Sample

In [171]:
batch.experience.first[ "action" ].shape

(32,)

In [53]:
batch.experience.second

{'action': Array([7], dtype=int8),
 'board': Array([[ 0,  2, 11,  3,  1,  1,  2,  3,  0,  1,  0,  0]], dtype=int8),
 'done': Array([False], dtype=bool),
 'reward': Array([2.], dtype=float32),
 'score': Array([[ 5, 16]], dtype=int8),
 'valid_actions': Array([[0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0.]], dtype=float32)}

In [83]:
key = jax.random.PRNGKey(0)

In [85]:
key = jax.random.split(key , 32)

In [86]:
key

Array([[3864235129, 1482072793],
       [1178783573,  955665731],
       [ 792409384, 1741258570],
       [ 519273896,  249394534],
       [1878013843, 4159128055],
       [2070707689, 1605263765],
       [3500152212, 1876574459],
       [1421029721,  277885049],
       [ 400790508, 3220228019],
       [1938920955, 1953859146],
       [2558934974, 3730924574],
       [ 848489424,  330715923],
       [2882287740,  887450378],
       [ 209891383, 3284945018],
       [3009137967, 3214895957],
       [3377263742, 2375760382],
       [2612127288, 3632685311],
       [2954523166, 4063401906],
       [1111675924, 3044665702],
       [2321405076, 2615623995],
       [ 402066021,  283142392],
       [3446067388, 3041618983],
       [3224072842, 1864786479],
       [1060626807, 2589178291],
       [3653168630, 1463922400],
       [ 446060362, 2568410976],
       [3782165314, 3812911679],
       [2339880062, 3541939123],
       [3198715237, 2029634584],
       [2795290120, 1759568708],
       [22

In [40]:
train = jnp.array([ True ] * 32)

In [41]:
train

Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True], dtype=bool)