In [40]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
import numpy as np

In [41]:
class DqnModel(nnx.Module):
  def __init__(self, inSize: int, outSize: int, rngs: nnx.Rngs):
    intermediateSize = 512
    key = rngs.params()
    self.linear1 = nnx.Linear(inSize, intermediateSize, rngs=rngs)
    self.linear2 = nnx.Linear(intermediateSize, outSize, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    x = jax.nn.relu(x)
    x = self.linear2(x)
    return x

In [42]:
model = DqnModel(75, 36, nnx.Rngs(0))
targetModel = DqnModel(75, 36, nnx.Rngs(1))
adam = optax.adam(1e-5)
optimizer = nnx.Optimizer(model, adam)

In [43]:
# Function to compute weighted loss and TD error for a SINGLE transition
def computeWeightedLossAndTdErrorSingle(model, targetModel, transition, weight):
  observation, selectedAction, isTerminal, reward, nextObservation = transition

  gamma = 0.99

  # --- DDQN Target Calculation ---
  # Define the calculation for the non-terminal case using DDQN logic
  def ddqnTargetCalculation(_):
    # 1. Select the best action for nextObservation using the *main* model's Q-values.
    #    We don't need gradients for this action selection step itself.
    qValuesNextStateMain = model(nextObservation)
    bestNextAction = jnp.argmax(qValuesNextStateMain)

    # 2. Evaluate the Q-value of that *selected* action using the *target* model.
    qValuesNextStateTarget = targetModel(nextObservation)
    # Index the target Q-values with the action chosen by the main model
    targetQValueForBestAction = qValuesNextStateTarget[bestNextAction]

    # 3. Calculate the final target value (reward + discounted future value).
    #    Stop gradients from flowing back through the target network calculation.
    return reward + gamma * jax.lax.stop_gradient(targetQValueForBestAction)

  targetValue = jax.lax.cond(
      isTerminal,
      lambda _: reward,
      ddqnTargetCalculation,
      None
  )

  # Prediction from main model
  values = model(observation)
  pred = values[selectedAction]

  # Calculate Huber loss (unweighted)
  unweightedLoss = optax.losses.huber_loss(pred, targetValue)

  # Calculate TD Error (unweighted)
  tdError = targetValue - pred

  # Apply Importance Sampling weight to the loss
  weightedLoss = weight * unweightedLoss

  return weightedLoss, (tdError, jnp.min(values), jnp.mean(values), jnp.max(values))

def computeWeightedLossAndTdErrorBatch(model, targetModel, transitions, weights):
  batched = jax.vmap(computeWeightedLossAndTdErrorSingle, in_axes=( None, None, (0, 0, 0, 0, 0), 0 ), out_axes=(0, 0))
  weightedLosses, (tdErrors, minValues, meanValues, maxValues) = batched(model, targetModel, transitions, weights)
  return jnp.mean(weightedLosses), (jnp.mean(tdErrors), jnp.mean(minValues), jnp.mean(meanValues), jnp.mean(maxValues))

@nnx.jit
def train(model, optimizerState, targetModel, observation, selectedAction, isTerminal, reward, nextObservation, weight):
  # print(f'observation: {type(observation)}: {observation}')
  # print(f'selectedAction: {type(selectedAction)}: {selectedAction}')
  # print(f'isTerminal: {type(isTerminal)}: {isTerminal}')
  # print(f'reward: {type(reward)}: {reward}')
  # print(f'nextObservation: {type(nextObservation)}: {nextObservation}')
  # print(f'weight: {type(weight)}: {weight}')

  (loss, auxOutput), gradients = nnx.value_and_grad(computeWeightedLossAndTdErrorBatch, has_aux=True)(model, targetModel, (observation, selectedAction, isTerminal, reward, nextObservation), weight)
  optimizerState.update(gradients)
  return auxOutput

In [44]:
observation = np.array([[1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.]])
selectedAction = [0]
isTerminal = [False]
reward = [0.0]
nextObservation = np.array([[1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.]])
weight = [1.0]

In [45]:
npSelectedAction = np.array(selectedAction)
npIsTerminal = np.array(isTerminal)
npReward = np.array(reward)
npWeight = np.array(weight)
print(f'npSelectedAction: {type(npSelectedAction)}: {npSelectedAction}')
print(f'npIsTerminal: {type(npIsTerminal)}: {npIsTerminal}')
print(f'npReward: {type(npReward)}: {npReward}')
print(f'npWeight: {type(npWeight)}: {npWeight}')

npSelectedAction: <class 'numpy.ndarray'>: [0]
npIsTerminal: <class 'numpy.ndarray'>: [False]
npReward: <class 'numpy.ndarray'>: [0.]
npWeight: <class 'numpy.ndarray'>: [1.]


In [46]:
train(model, optimizer, targetModel, observation, npSelectedAction, npIsTerminal, npReward, nextObservation, npWeight)

(Array(-0.14731956, dtype=float32),
 Array(-0.9599196, dtype=float32),
 Array(-0.06241585, dtype=float32),
 Array(1.0386813, dtype=float32))