In [1]:
import os, sys

# Locate the rl/python directory, assuming the notebook runs from a folder
# directly beneath it (so the parent of the working directory is rl/python).
root = os.path.abspath(os.path.join(os.getcwd(), '..'))
src  = os.path.join(root, 'dqn.py')        # expected location of dqn.py; handy for a manual sanity check
# Two levels above the notebook folder (bot/src/rl). normpath collapses the
# trailing '..' so the same directory always yields the same string.
pkg  = os.path.normpath(os.path.join(root, '..'))

# Put the *parent* of rl/python at the front of the import path so that
# `python.dqn` resolves. Any earlier copy of the entry is dropped first, which
# keeps the cell idempotent: re-running it no longer stacks duplicates.
sys.path[:] = [entry for entry in sys.path if entry != pkg]
sys.path.insert(0, pkg)

# Set ahead of the jax import in the following cell.
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump'
os.environ['JAX_CHECK_TRACER_LEAKS'] = '1'

In [2]:
from flax import nnx
import numpy as np
import jax
import jax.numpy as jnp
from jax import profiler
from python.dqn import DqnModel, selectAction, jittedTrain, getOptaxAdamWOptimizer
from copy import deepcopy

In [3]:
# Notebook-wide configuration. These names are read as globals by the jitted
# helper functions defined in the cells below.
batchSize = 128              # leading axis of every array built by createData
observationSize = 168        # width of a single observation vector
observationStackSize = 64    # number of entries in the past-observation stack
actionSpaceSize = 35         # width of the per-step action arrays and of the action mask
dropoutRate = 0.1            # forwarded to DqnModel
learningRate = 1e-3          # forwarded to getOptaxAdamWOptimizer

rngSeed = 0                  # seeds both model initialisation and the training rng key

@nnx.jit
def getModel(rngSeed):
  """Construct a DqnModel under nnx.jit.

  Args:
    rngSeed: seed converted to a JAX key that backs the model's nnx.Rngs.

  Returns:
    A DqnModel configured from the notebook-level size and dropout settings.
  """
  seedKey = jax.random.key(rngSeed)
  # Sizes and dropout come from the notebook-level configuration.
  modelConfig = dict(
      observationSize=observationSize,
      stackSize=observationStackSize,
      actionSpaceSize=actionSpaceSize,
      dropoutRate=dropoutRate,
  )
  return DqnModel(rngs=nnx.Rngs(seedKey), **modelConfig)

# Online network, initialised from the notebook seed.
model = getModel(rngSeed)
# Separate copy of the freshly built model, later passed to jittedTrain as targetModel.
targetModel = deepcopy(model)

@nnx.jit(static_argnums=(1,))
def getOptimizer(model, learningRate):
  """Wrap `model` in an nnx.Optimizer using the AdamW transform from python.dqn.

  Args:
    model: the module whose parameters the optimizer will manage.
    learningRate: forwarded to getOptaxAdamWOptimizer; marked static for jit
      (argument index 1).

  Returns:
    The nnx.Optimizer instance built for `model`.
  """
  return nnx.Optimizer(model, getOptaxAdamWOptimizer(learningRate))

# Non-jitted version of the same two steps performed inside getOptimizer,
# kept for reference only:
# optTx = getOptaxAdamWOptimizer(learningRate)
# optimizerState = nnx.Optimizer(model, optTx)

# Optimizer bound to the online model; learningRate is a static jit argument.
optimizerState = getOptimizer(model, learningRate)

In [4]:
# Each model-input tuple contains, per sample:
# - pastObservationStack:      (stackSize, observationSize)
# - pastObservationTimestamps: (stackSize, 1)
# - pastActions:               (stackSize, actionSpaceSize)
# - pastMask:                  (stackSize, 1)
# - currentObservation:        (observationSize)
# createData prepends a batch axis of size batchSize to every one of them.
@nnx.jit(static_argnums=(0,))
def createData(rngSeed):
  """Build an all-zero training batch plus an rng key for a jittedTrain smoke run.

  Args:
    rngSeed: seed for the returned JAX key; static under jit (argument index 0).

  Returns:
    A 7-tuple ``(pastInputs, currentInputs, actions, isTerminals, rewards,
    weights, rngKey)``. Both input tuples hold the same five zero arrays.
  """
  pastObservationStack = jnp.zeros((batchSize, observationStackSize, observationSize))
  pastObservationTimestamps = jnp.zeros((batchSize, observationStackSize, 1))
  pastActions = jnp.zeros((batchSize, observationStackSize, actionSpaceSize))
  pastMask = jnp.zeros((batchSize, observationStackSize, 1))
  currentObservation = jnp.zeros((batchSize, observationSize))
  # NOTE(review): shape is (batchSize, stackSize), i.e. one index per stack
  # slot -- confirm this is what jittedTrain expects for selectedActions.
  actions = jnp.zeros((batchSize, observationStackSize), dtype=jnp.int32)
  # jnp.bool_ instead of np.bool: the np.bool alias is absent from
  # NumPy 1.24-1.26 and would raise AttributeError there.
  isTerminals = jnp.zeros((batchSize,), dtype=jnp.bool_)
  rewards = jnp.zeros((batchSize,))
  weights = jnp.zeros((batchSize,))

  # The "past" and "current" model inputs are intentionally identical here.
  tuple1 = (pastObservationStack, pastObservationTimestamps, pastActions, pastMask, currentObservation)
  tuple2 = (pastObservationStack, pastObservationTimestamps, pastActions, pastMask, currentObservation)
  return tuple1, tuple2, actions, isTerminals, rewards, weights, jax.random.key(rngSeed)

tuple1, tuple2, actions, isTerminals, rewards, weights, rngKey = createData(rngSeed)

# Collect the keyword arguments up front so the training call stays short.
trainKwargs = dict(
  model=model,
  optimizerState=optimizerState,
  targetModel=targetModel,
  pastModelInputTuple=tuple1,
  selectedActions=actions,
  isTerminals=isTerminals,
  rewards=rewards,
  currentModelInputTuple=tuple2,
  weights=weights,
  gamma=0.99,
  rngKey=rngKey,
)
res = jittedTrain(**trainKwargs)

# Wait for the asynchronously dispatched computation to finish; as the last
# expression of the cell, the result is also what the notebook displays.
jax.block_until_ready(res)

E0601 19:10:10.397407  190570 hlo_lexer.cc:443] Failed to parse int literal: 30212747383567152780632


(Array(0., dtype=float32),
 Array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32))

In [5]:
@nnx.jit
def createData2():
  """Build all-zero, unbatched inputs for one selectAction call.

  Returns:
    A 6-tuple ``(pastObservationStack, pastObservationTimestamps, pastActions,
    pastMask, currentObservation, actionMask)`` with no batch axis.
  """
  def zerosPerStackSlot(width):
    # One row per stack slot, `width` columns.
    return jnp.zeros((observationStackSize, width))

  return (
      zerosPerStackSlot(observationSize),   # pastObservationStack
      zerosPerStackSlot(1),                 # pastObservationTimestamps
      zerosPerStackSlot(actionSpaceSize),   # pastActions
      zerosPerStackSlot(1),                 # pastMask
      jnp.zeros((observationSize,)),        # currentObservation
      jnp.zeros((actionSpaceSize,)),        # actionMask
  )

pastObservationStack, pastObservationTimestamps, pastActions, pastMask, currentObservation, actionMask = createData2()

# Local sketch of an action-selection function, not executed. The function
# actually called below is the one imported from python.dqn, which may differ.
# @nnx.jit
# def selectAction(model, pastObservationStack, pastObservationTimestamps, pastActions, pastMask, currentObservation, actionMask):
#   values = model(pastObservationStack, pastObservationTimestamps, pastActions, pastMask, currentObservation, deterministic=True)
#   values += actionMask
#   return jnp.argmax(values)

# Store the result under its own name. Assigning it back to `selectAction`
# would overwrite the imported function, so a second run of this cell (or any
# later call) would try to call the previous result instead.
selectedAction = selectAction(model, pastObservationStack, pastObservationTimestamps, pastActions, pastMask, currentObservation, actionMask)