<a href="https://colab.research.google.com/github/Zakuta/D-QRL/blob/main/QRL_try_2_feb24_2024.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [33]:
# !pip install equinox
# !pip install tensorcircuit
# !pip install -U qiskit
# !pip install tensorcircuit
# !pip install cirq
# !pip install openfermion
# !pip install gymnax
# !pip install brax
# !pip install distrax
# !pip install flax



In [1]:
import jax
from jax import config

config.update("jax_debug_nans", True)
config.update("jax_enable_x64", True)

import jax.numpy as jnp
DTYPE=jnp.float64

import chex
import numpy as np
import optax
from flax import struct
from functools import partial
import tensorcircuit as tc

import tensorflow as tf
from sklearn.decomposition import PCA
import equinox as eqx
import types
from jaxtyping import Array, PRNGKeyArray
from typing import Union, Sequence, List, NamedTuple, Optional, Tuple, Any, Literal, TypeVar
import jax.tree_util as jtu
import gymnax
import distrax
from gymnax.environments import environment, spaces
from brax import envs
from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper

K = tc.set_backend("jax")



In [2]:
# shamelessly taken from purejaxrl: https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/wrappers.py

class GymnaxWrapper(object):
    """Base class for Gymnax wrappers."""

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        return getattr(self._env, name)


class FlattenObservationWrapper(GymnaxWrapper):
    """Flatten the observations of the environment."""

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    def observation_space(self, params) -> spaces.Box:
        assert isinstance(
            self._env.observation_space(params), spaces.Box
        ), "Only Box spaces are supported for now."
        return spaces.Box(
            low=self._env.observation_space(params).low,
            high=self._env.observation_space(params).high,
            shape=(np.prod(self._env.observation_space(params).shape),),
            dtype=self._env.observation_space(params).dtype,
        )

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        obs, state = self._env.reset(key, params)
        obs = jnp.reshape(obs, (-1,))
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        obs, state, reward, done, info = self._env.step(key, state, action, params)
        obs = jnp.reshape(obs, (-1,))
        return obs, state, reward, done, info


@struct.dataclass
class LogEnvState:
    env_state: environment.EnvState
    episode_returns: float
    episode_lengths: int
    returned_episode_returns: float
    returned_episode_lengths: int
    timestep: int


class LogWrapper(GymnaxWrapper):
    """Log the episode returns and lengths."""

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        obs, env_state = self._env.reset(key, params)
        state = LogEnvState(env_state, 0, 0, 0, 0, 0)
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        obs, env_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )
        new_episode_return = state.episode_returns + reward
        new_episode_length = state.episode_lengths + 1
        state = LogEnvState(
            env_state=env_state,
            episode_returns=new_episode_return * (1 - done),
            episode_lengths=new_episode_length * (1 - done),
            returned_episode_returns=state.returned_episode_returns * (1 - done)
            + new_episode_return * done,
            returned_episode_lengths=state.returned_episode_lengths * (1 - done)
            + new_episode_length * done,
            timestep=state.timestep + 1,
        )
        info["returned_episode_returns"] = state.returned_episode_returns
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["timestep"] = state.timestep
        info["returned_episode"] = done
        return obs, state, reward, done, info

In [3]:
def reuploading_circuit(n_qubits, n_layers, rot_params, input_params, X):
  circuit = tc.Circuit(n_qubits)

  for l in range(n_layers):
    # variational part
    for qubit_idx in range(n_qubits):
      circuit.rx(qubit_idx, theta=rot_params[l, qubit_idx, 0])
      circuit.ry(qubit_idx, theta=rot_params[l, qubit_idx, 1])
      circuit.rz(qubit_idx, theta=rot_params[l, qubit_idx, 2])

    # entangling part
    for qubit_idx in range(n_qubits - 1):
      circuit.cnot(qubit_idx, qubit_idx + 1)
    if n_qubits != 2:
      circuit.cnot(n_qubits - 1, 0)

    # encoding part
    for qubit_idx in range(n_qubits):
      input = X[qubit_idx] * input_params[l, qubit_idx]
      # print(input)
      circuit.rx(qubit_idx, theta=input)

  # last variational part
  for qubit_idx in range(n_qubits):
    circuit.rx(qubit_idx, theta=rot_params[n_layers, qubit_idx, 0])
    circuit.ry(qubit_idx, theta=rot_params[n_layers, qubit_idx, 1])
    circuit.rz(qubit_idx, theta=rot_params[n_layers, qubit_idx, 2])

  return circuit


# class PQCLayer(eqx.Module):
#   theta: Array
#   lmbd: Array
#   n_qubits: int = eqx.field(static=True)
#   n_layers: int = eqx.field(static=True)

#   def __init__(self, n_qubits: int, n_layers: int, params: Optional, key: PRNGKeyArray):

#     key = jax.random.PRNGKey(key)
#     tkey, lkey = jax.random.split(key, num=2)

#     if params is None:
#       self.theta = params['thetas']
#       self.lmbd = params['lmbds']
#     else:
#       self.theta = jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
#                                     minval=0.0, maxval=np.pi, dtype=DTYPE)
#       self.lmbd = jnp.ones(shape=(n_layers, n_qubits), dtype=DTYPE)

#     self.n_qubits = n_qubits
#     self.n_layers = n_layers

#   def __call__(self, inputs):
#     circuit = tc.Circuit(self.n_qubits)

#     for l in range(self.n_layers):
#       # variational part
#       for qubit_idx in range(self.n_qubits):
#         circuit.rx(qubit_idx, theta=self.theta[l, qubit_idx, 0])
#         circuit.ry(qubit_idx, theta=self.theta[l, qubit_idx, 1])
#         circuit.rz(qubit_idx, theta=self.theta[l, qubit_idx, 2])

#       # entangling part
#       for qubit_idx in range(self.n_qubits - 1):
#         circuit.cnot(qubit_idx, qubit_idx + 1)
#       if self.n_qubits != 2:
#         circuit.cnot(self.n_qubits - 1, 0)

#       # encoding part
#       for qubit_idx in range(self.n_qubits):
#         linear_input = inputs[qubit_idx] * self.lmbd[l, qubit_idx]
#         circuit.rx(qubit_idx, theta=linear_input)

#     # last variational part
#     for qubit_idx in range(self.n_qubits):
#       circuit.rx(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 0])
#       circuit.ry(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 1])
#       circuit.rz(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 2])

#     return jnp.real(circuit.expectation_ps(z=jnp.arange(len(self.n_qubits))))


# class PQCLayer(eqx.Module):
#   theta: jax.Array = eqx.field(converter=jnp.asarray)
#   lmbd: jax.Array = eqx.field(converter=jnp.asarray)
#   n_qubits: int = eqx.field(static=True)
#   n_layers: int = eqx.field(static=True)

#   def __init__(self, n_qubits: int, n_layers: int, key: int):
#     key = jax.random.PRNGKey(key)
#     tkey, lkey = jax.random.split(key, num=2)
#     self.n_qubits = n_qubits
#     self.n_layers = n_layers
#     # rotation_params
#     self.theta = jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
#                                     minval=0.0, maxval=np.pi, dtype=DTYPE)
#     # input encoding params
#     # self.lmbd = jnp.ones(shape=(n_layers, n_qubits))
#     self.lmbd = jax.random.uniform(key=lkey, shape=(n_layers, n_qubits),
#                                     minval=0.0, maxval=np.pi, dtype=DTYPE)

#   def __call__(self, inputs):
#   # def __call__(self, X, n_qubits, depth):

#     circuit = generate_circuit(self.n_qubits, self.n_layers, self.theta, self.lmbd, inputs)
#     # state = circuit.state()
#     # return state
#     return K.real(circuit.expectation_ps(z=[0,1,2,3]))

# class Alternating(eqx.Module):
#   w: jax.Array = eqx.field(converter=jnp.asarray)

#   def __init__(self, output_dim):
#     self.w = jnp.array([[(-1.) ** i for i in range(output_dim)]])

#   def __call__(self, inputs):
#     return jnp.matmul(inputs, self.w)


# class Actor(eqx.Module):
#   n_qubits: int
#   n_layers: int
#   beta: float
#   n_actions: Sequence[int]
#   key: int

#   def __call__(self, x):
#     re_uploading_pqc = PQCLayer(n_qubits=self.n_qubits,
#                                 n_layers=self.n_layers,
#                                 key=self.key)(x)

#     process = eqx.nn.Sequential([
#         Alternating(self.n_actions),
#         eqx.nn.Lambda(lambda x: x * self.beta),
#         jax.nn.softmax()
#     ])

#     policy = process(re_uploading_pqc)

#     return policy



class QuantumActor(eqx.Module):
  theta: jax.Array # trainable
  lmbd: jax.Array # trainable
  w: jax.Array # trainable
  n_qubits: int = eqx.field(static=True)
  n_layers: int = eqx.field(static=True)
  beta: float = eqx.field(static=True)
  n_actions: Sequence[int] = eqx.field(static=True)
  # key: int

  def __init__(self, n_qubits, n_layers, beta, n_actions, params: Optional, key = 42):

    key = jax.random.PRNGKey(key)
    key, _key = jax.random.split(key, num=2)

    if params is None:
      # rotation_params
      self.theta = params['thetas']
      # input encoding params
      self.lmbd = params['lmbds']
      # observable weights
      self.w = params['ws']
    else:
      self.theta = jax.random.uniform(key=key, shape=(n_layers + 1, n_qubits, 3),
                                    minval=0.0, maxval=np.pi, dtype=DTYPE)
      self.lmbd = jnp.ones(shape=(n_layers, n_qubits), dtype=DTYPE)
      self.w = jnp.array([[(-1.) ** i for i in range(n_actions)]])


    self.n_qubits = n_qubits
    self.n_layers = n_layers
    self.beta = beta
    self.n_actions = n_actions

  def quantum_policy_circuit(self, inputs):

    # this can be any PQC of the user's choice. hence, I made the decision to make a separate function within this class
    circuit = reuploading_circuit(self.n_qubits, self.n_layers, self.theta, self.lmbd, inputs)

    return K.real(circuit.expectation_ps(z=np.arange(self.n_qubits)))

  def alternating(self, inputs):
    return jnp.matmul(inputs, self.w)

  def get_params(self):
    return {'thetas': self.theta, 'lmbds': self.lmbd, 'ws': self.w}

  def __call__(self, x):

    pqc = self.quantum_policy_circuit(x)
    # print(pqc)
    alt = self.alternating(jnp.array([pqc], dtype=DTYPE))

    # process = eqx.nn.Sequential([
    #     alt,
    #     eqx.nn.Lambda(lambda x: x * self.beta),
    #     jax.nn.softmax()
    # ])
    # process = eqx.nn.Sequential([
    #     self.alternating,
    #     eqx.nn.Lambda(lambda x: x * self.beta),
    #     distrax.Softmax
    # ])


    # policy = process(jnp.array([pqc], dtype=DTYPE))

    actor_mean = eqx.nn.Lambda(lambda x: x * self.beta)(alt)
    # policy = jax.nn.softmax(actor_mean)
    policy = distrax.Softmax(actor_mean)

    return policy


# class Actor(eqx.Module):
#   n_qubits: int
#   n_layers: int
#   beta: float
#   n_actions: Sequence[int]
#   pqc: eqx.Module
#   alt: eqx.Module
#   key: int

#   def __init__(self, n_qubits, n_layers, beta, n_actions, key):
#     self.n_qubits = n_qubits
#     self.n_layers = n_layers
#     self.beta = beta
#     self.n_actions = n_actions
#     self.key = key

#     self.pqc = PQCLayer(n_qubits=self.n_qubits,
#                         n_layers=self.n_layers,
#                         key=self.key)

#     self.alt = Alternating(self.n_actions)

#   def __call__(self, x):
#     re_uploading_pqc = self.pqc(x)

#     process = eqx.nn.Sequential([
#         self.alt,
#         eqx.nn.Lambda(lambda x: x * self.beta),
#         jax.nn.softmax()
#     ])

#     policy = process(re_uploading_pqc)

#     return policy


class TrainState(eqx.Module):
  model: eqx.Module
  optimizer: optax.GradientTransformation = eqx.field(static=True)
  opt_state: optax.OptState

  def __init__(self, model, optimizer, opt_state):
    self.model = model
    self.optimizer = optimizer
    self.opt_state = opt_state

  def apply_updates_to_model(self, new_params):
    # this function is specific and works for this example only. one can think of
    # generalizing it to work for updating any given attribute/params of the model.

    model_new = eqx.tree_at(where=lambda model: model.theta, pytree=self.model, replace=new_params['thetas'])
    model_new = eqx.tree_at(where=lambda model: model.lmbd, pytree=model_new, replace=new_params['lmbds'])
    model_new = eqx.tree_at(where=lambda model: model.w, pytree=model_new, replace=new_params['ws'])

    return model_new

  def apply_gradients(self, params, grads):
    grads = {'thetas': grads.theta, 'lmbds': grads.lmbd, 'ws': grads.w}
    # if type(grads) == dict: # this is if we actually apply value_and_grad functionality of JAX, might be useful if we apply PPO for instance
    #   grads = {'thetas': grads.theta, 'lmbds': grads.lmbd, 'ws': grads.w}
    # else: # for monte-carlo estimation of grads which REINFORCE applies
    #   grads = grads
    updates, opt_state = self.optimizer.update(grads, self.opt_state, params)
    new_params = optax.apply_updates(params, updates)
    model_new = self.apply_updates_to_model(new_params)
    # model = eqx.apply_updates(self.model, updates)
    new_train_state = self.__class__(model=model_new, optimizer=self.optimizer, opt_state=opt_state)
    return new_train_state

class Transition(NamedTuple):
  done: jnp.ndarray
  action: jnp.ndarray
  reward: jnp.ndarray
  log_prob: jnp.ndarray
  obs: jnp.ndarray
  info: jnp.ndarray

In [None]:
@eqx.filter_jit
def train_cp(conf):

  conf['n_updates'] = (
          conf['total_timesteps'] // conf['n_steps'] // conf['n_envs']
      )

  conf['mini_batchsize'] = (
          conf['n_envs'] * conf['n_steps'] // conf['n_minibatches']
      )

  env, env_params = gymnax.make(conf['env_name'])
  env = FlattenObservationWrapper(env)
  env = LogWrapper(env)

  n_actions = env.action_space(env_params).n

  rng = jax.random.PRNGKey(conf['rng'])
  rng, _rng = jax.random.split(rng)
  params = {'thetas': jax.random.uniform(
      key=_rng, shape=(conf['n_layers'] + 1, conf['n_qubits'], 3),
      minval=0.0, maxval=np.pi, dtype=DTYPE
      ),
            'lmbds': jnp.ones(shape=(conf['n_layers'],
                                    conf['n_qubits']), dtype=DTYPE
                              ),
            'ws': jnp.array([[(-1.) ** i for i in range(n_actions)]], dtype=DTYPE)
            }

  def train(rng):
    actor = QuantumActor(
        n_qubits=conf['n_qubits'], n_layers=conf['n_layers'],
        beta=conf['beta'], n_actions=n_actions, params=params
        )

    state_bounds = jnp.array([2.4, 2.5, 0.21, 2.5], dtype=DTYPE)

    def map_nested_fn(fn):
      '''Recursively apply `fn` to the key-value pairs of a nested dict'''
      def map_fn(nested_dict):
        return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
                for k, v in nested_dict.items()}
      return map_fn

    label_fn = map_nested_fn(lambda k, _: k)
    optim = optax.multi_transform({'thetas': optax.amsgrad(conf['lr_theta']),
                                  'lmbds': optax.amsgrad(conf['lr_lmbd']),
                                  'ws': optax.amsgrad(conf['lr_w'])},
                              label_fn)

    opt_state = optim.init(params)

    train_state = TrainState(model=actor, optimizer=optim, opt_state=opt_state)

    rng, rng_reset = jax.random.split(rng)
    #vmappable
    obs, env_state = env.reset(rng_reset, env_params)

    def update_episode(runner_state, ununsed):
      def env_step(runner_state, ununsed):

        last_obs, env_state, train_state, rng = runner_state
        rng, rng_step, rng_net = jax.random.split(rng, 3)

        actor = train_state.model
        policy = actor(last_obs.reshape(-1))
        # policy = actor(last_obs.reshape(-1) / state_bounds)
        action = policy.sample(seed=rng_net)
        log_prob = policy.log_prob(action)

        #vmappable
        obs, env_state, reward, done, info = env.step(rng_step, env_state, action, env_params)

        transition = Transition(
            done, action, reward, log_prob, obs, info)

        runner_state = (obs, env_state, train_state, rng)

        return runner_state, transition

      runner_state, traj_batch = jax.lax.scan(env_step, runner_state, None, conf['n_steps'])

      return runner_state, traj_batch

      obs, env_state, train_state, rng = runner_state

      def calculate_returns(traj_batch):
          def _compute_discounted_sum(carry, transition):
              rewards_to_go = carry
              reward = transition.reward
              rewards_to_go = reward + conf['gamma'] * rewards_to_go
              baseline = 0
              return rewards_to_go, rewards_to_go

          init_carry = jnp.zeros_like(0, dtype=DTYPE)

          _, returns = jax.lax.scan(
              _compute_discounted_sum,
              init_carry,
              traj_batch,
              reverse=True,
          )
          return returns

      returns = calculate_returns(traj_batch)

      def update_epoch(update_state, ununsed):
        @eqx.filter_value_and_grad
        def reinforce_update(actor, batch_info):
          traj_batch, returns = batch_info
          policy =


      # UPDATE ACTOR
      def _update_epoch(update_state, unused):
        def _update_minbatch(train_state, batch_info):
          traj_batch, discounted_rewards = batch_info

          @eqx.filter_value_and_grad
          def _loss_fn(actor, traj_batch):
            #TODO: can I use vmap here?
            # RERUN ACTOR
            print(traj_batch.obs.shape)
            policy = actor(traj_batch.obs / state_bounds)
            log_prob = policy.log_prob(traj_batch.action)
            # for stability while training and less variability
            returns = (discounted_rewards - jnp.mean(discounted_rewards)) / (jnp.std(discounted_rewards) + 1e-8)
            loss = -jnp.mean(log_prob * returns)

            return loss

          loss, grads = _loss_fn(train_state.model, traj_batch)
          print(grads)
          train_state = train_state.apply_gradients(train_state.model.get_params(), grads)

          return train_state, loss

        train_state, traj_batch, discounted_rewards, rng = update_state
        rng, _rng = jax.random.split(rng)

        # Mini-batch Updates
        batch_size = conf['mini_batchsize'] * conf['n_minibatches']
        assert (batch_size == conf['n_steps'] * conf['n_envs']
        ), 'batch size must be equal to number of steps * number of envs'
        permutation = jax.random.permutation(_rng, batch_size)
        batch = (traj_batch, discounted_rewards)
        # print(batch)

        #TODO: @Yash in future. We only need this if we are using VMAP. for now commenting it out!!
        # batch = jax.tree_util.tree_map(
        #                 lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
        shuffled_batch = jax.tree_util.tree_map(
                        lambda x: jnp.take(x, permutation, axis=0), batch
                    )

        minibatches = jax.tree_util.tree_map(
            lambda x: jnp.reshape(
                x, [conf['n_minibatches'], -1] + list(x.shape[1:])
            ),
            shuffled_batch,
        )
        # print(train_state)
        # print(minibatches)

        train_state, loss = jax.lax.scan(
            _update_minbatch, train_state, minibatches
        )
        update_state = (train_state, traj_batch, discounted_rewards, rng)
        return update_state, loss

      # Updating training state and metrics
      update_state = (train_state, traj_batch, discounted_rewards, rng)
      update_state, loss_info = jax.lax.scan(
          _update_epoch, update_state, None, conf['update_epochs']
          )
      train_state = update_state[0]
      metric = traj_batch.info
      rng = update_state[-1]













# def rollout(rng_input, policy_params, env_params, steps_in_episode):
#     """Rollout a jitted gymnax episode with lax.scan."""
#     # Reset the environment
#     rng_reset, rng_episode = jax.random.split(rng_input)
#     obs, state = env.reset(rng_reset, env_params)

#     def policy_step(state_input, tmp):
#         """lax.scan compatible step transition in jax env."""
#         obs, state, policy_params, rng = state_input
#         rng, rng_step, rng_net = jax.random.split(rng, 3)
#         action = model.apply(policy_params, obs, rng_net)
#         next_obs, next_state, reward, done, _ = env.step(
#           rng_step, state, action, env_params
#         )
#         carry = [next_obs, next_state, policy_params, rng]
#         return carry, [obs, action, reward, next_obs, done]

#     # Scan over episode step loop
#     _, scan_out = jax.lax.scan(
#       policy_step,
#       [obs, state, policy_params, rng_episode],
#       (),
#       steps_in_episode
#     )
#     # Return masked sum of rewards accumulated by agent in episode
#     obs, action, reward, next_obs, done = scan_out
#     return obs, action, reward, next_obs, done








In [10]:
#TODO: in the tfq, tutorial the state was normalized using state_bounds = np.array([2.4, 2.5, 0.21, 2.5])
# Need to double check whether it is necessary for our iteration of the code.
# - vmapping over n_envs is throwing a weird error which needs further investigation, workaround right now is the set n_envs in conf to 1 and commenting out all vmap comments


@eqx.filter_jit
def make_train(conf):

  conf['n_updates'] = (
          conf['total_timesteps'] // conf['n_steps'] // conf['n_envs']
      )

  conf['mini_batchsize'] = (
          conf['n_envs'] * conf['n_steps'] // conf['n_minibatches']
      )

  env, env_params = gymnax.make(conf['env_name'])
  env = FlattenObservationWrapper(env)
  env = LogWrapper(env)

  n_actions = env.action_space(env_params).n

  rng = jax.random.PRNGKey(conf['rng'])
  rng, _rng = jax.random.split(rng)
  params = {'thetas': jax.random.uniform(
      key=_rng, shape=(conf['n_layers'] + 1, conf['n_qubits'], 3),
      minval=0.0, maxval=np.pi, dtype=DTYPE
      ),
            'lmbds': jnp.ones(shape=(conf['n_layers'],
                                    conf['n_qubits']), dtype=DTYPE
                              ),
            'ws': jnp.array([[(-1.) ** i for i in range(n_actions)]], dtype=DTYPE)
            }

  def train(rng):
    actor = QuantumActor(
        n_qubits=conf['n_qubits'], n_layers=conf['n_layers'],
        beta=conf['beta'], n_actions=n_actions, params=params
        )

    state_bounds = jnp.array([2.4, 2.5, 0.21, 2.5], dtype=DTYPE)

    def map_nested_fn(fn):
      '''Recursively apply `fn` to the key-value pairs of a nested dict'''
      def map_fn(nested_dict):
        return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
                for k, v in nested_dict.items()}
      return map_fn

    label_fn = map_nested_fn(lambda k, _: k)
    optim = optax.multi_transform({'thetas': optax.amsgrad(conf['lr_theta']),
                                  'lmbds': optax.amsgrad(conf['lr_lmbd']),
                                  'ws': optax.amsgrad(conf['lr_w'])},
                              label_fn)

    opt_state = optim.init(params)

    train_state = TrainState(model=actor, optimizer=optim, opt_state=opt_state)

    rng, _rng = jax.random.split(rng)
    # reset_rng = jax.random.split(_rng, conf['n_envs'])
    # obs, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
    obs, env_state = env.reset(_rng, env_params)

    def _update_step(runner_state, unused):
      # COLLECT TRAJECTORIES
      def _env_step(runner_state, ununsed):
        train_state, env_state, last_obs, rng = runner_state
        # print(last_obs)
        rng, _rng = jax.random.split(rng)
        actor = train_state.model
        policy = actor(last_obs.reshape(-1) / state_bounds)
        action = policy.sample(seed=_rng)
        log_prob = policy.log_prob(action)
        # print(action, log_prob)

        rng, _rng = jax.random.split(rng)
        # rng_step = jax.random.split(_rng, conf['n_envs'])

        obs, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)

        # obs, env_state, reward, done, info = jax.vmap(
        #     env.step, in_axes=(0, 0, 0, None)
        #     )(rng_step, env_state, action, env_params)

        transition = Transition(
            done, action, reward, log_prob, last_obs, info)

        runner_state = (train_state, env_state, obs, rng)

        return runner_state, transition

      train_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, conf['n_steps'])

      train_state, env_state, last_obs, rng = runner_state

      def compute_rewards(traj_batch, gamma):
        returns = []
        discounted_sum = 0
        for r in traj_batch.reward[::-1]:
          discounted_sum = r + gamma * discounted_sum
          returns.insert(0, discounted_sum)
        # returns = (returns - jnp.mean(returns)) / (jnp.std(returns) + 1e-8)

        return jnp.array(returns, dtype=DTYPE)
      discounted_rewards = compute_rewards(traj_batch, conf['gamma'])
      # print(discounted_rewards.shape)

      # UPDATE ACTOR
      def _update_epoch(update_state, unused):
        def _update_minbatch(train_state, batch_info):
          traj_batch, discounted_rewards = batch_info

          @eqx.filter_value_and_grad
          def _loss_fn(actor, traj_batch):
            #TODO: can I use vmap here?
            # RERUN ACTOR
            print(traj_batch.obs.shape)
            policy = actor(traj_batch.obs / state_bounds)
            log_prob = policy.log_prob(traj_batch.action)
            # for stability while training and less variability
            returns = (discounted_rewards - jnp.mean(discounted_rewards)) / (jnp.std(discounted_rewards) + 1e-8)
            loss = -jnp.mean(log_prob * returns)

            return loss

          loss, grads = _loss_fn(train_state.model, traj_batch)
          print(grads)
          train_state = train_state.apply_gradients(train_state.model.get_params(), grads)

          return train_state, loss

        train_state, traj_batch, discounted_rewards, rng = update_state
        rng, _rng = jax.random.split(rng)

        # Mini-batch Updates
        batch_size = conf['mini_batchsize'] * conf['n_minibatches']
        assert (batch_size == conf['n_steps'] * conf['n_envs']
        ), 'batch size must be equal to number of steps * number of envs'
        permutation = jax.random.permutation(_rng, batch_size)
        batch = (traj_batch, discounted_rewards)
        # print(batch)

        #TODO: @Yash in future. We only need this if we are using VMAP. for now commenting it out!!
        # batch = jax.tree_util.tree_map(
        #                 lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
        shuffled_batch = jax.tree_util.tree_map(
                        lambda x: jnp.take(x, permutation, axis=0), batch
                    )

        minibatches = jax.tree_util.tree_map(
            lambda x: jnp.reshape(
                x, [conf['n_minibatches'], -1] + list(x.shape[1:])
            ),
            shuffled_batch,
        )
        # print(train_state)
        # print(minibatches)

        train_state, loss = jax.lax.scan(
            _update_minbatch, train_state, minibatches
        )
        update_state = (train_state, traj_batch, discounted_rewards, rng)
        return update_state, loss

      # Updating training state and metrics
      update_state = (train_state, traj_batch, discounted_rewards, rng)
      update_state, loss_info = jax.lax.scan(
          _update_epoch, update_state, None, conf['update_epochs']
          )
      train_state = update_state[0]
      metric = traj_batch.info
      rng = update_state[-1]

      # Debugging mode
      if conf['debug']:
        def callback(info):
          return_values = info['returned_episode_returns'][info['returned_episode']]
          timesteps = info['timestep'][info['returned_episode']] * conf['n_envs']
          for t in range(len(timesteps)):
            print(f"global step={timesteps[t]}, episodic return={return_values[t]}")
        jax.debug.callback(callback, metric)

      runner_state = (train_state, env_state, last_obs, rng)

      return runner_state, metric

    rng, _rng = jax.random.split(rng)
    runner_state = (train_state, env_state, obs, _rng)
    runner_state, metric = jax.lax.scan(
        _update_step, runner_state, None, conf['n_updates']
        )
    return {'runner_state': runner_state, 'metrics': metric}

  return train

In [11]:
conf = {'n_layers': 5,
        'n_qubits': 4,
        'beta': 1.0,
        'n_envs': 1,
        'total_timesteps': 125000,
        'n_steps': 64,
        'gamma': 0.99,
        'n_minibatches': 4,
        'update_epochs': 4,
        'debug': True,
        'env_name': 'CartPole-v1',
        'lr_theta': 0.001,
        'lr_lmbd': 0.1,
        'lr_w': 0.1,
        'rng': 42}

out = make_train(conf)

  params = {'thetas': jax.random.uniform(
  'lmbds': jnp.ones(shape=(conf['n_layers'],
  'ws': jnp.array([[(-1.) ** i for i in range(n_actions)]], dtype=DTYPE)


In [12]:
out(jax.random.PRNGKey(42))

  self.theta = jax.random.uniform(key=key, shape=(n_layers + 1, n_qubits, 3),
  self.lmbd = jnp.ones(shape=(n_layers, n_qubits), dtype=DTYPE)
  state_bounds = jnp.array([2.4, 2.5, 0.21, 2.5], dtype=DTYPE)


(4,)


  alt = self.alternating(jnp.array([pqc], dtype=DTYPE))


(4,)


  alt = self.alternating(jnp.array([pqc], dtype=DTYPE))
  return jnp.array(returns, dtype=DTYPE)


(16, 4)


TypeError: div got incompatible shapes for broadcasting: (64,), (4,).

In [None]:
env, env_params = gymnax.make(conf['env_name'])
env = FlattenObservationWrapper(env)
env = LogWrapper(env)

In [38]:
env, env_params = gymnax.make(conf['env_name'])
env = FlattenObservationWrapper(env)
env = LogWrapper(env)

n_actions = env.action_space(env_params).n

In [39]:
rng, _rng = jax.random.split(jax.random.PRNGKey(42))
reset_rng = jax.random.split(_rng, 1)
obs, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

In [57]:
help(env.step)

Help on method step in module __main__:

step(key: jax.Array, state: gymnax.environments.environment.EnvState, action: Union[int, float], params: Optional[gymnax.environments.environment.EnvParams] = None) -> Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], gymnax.environments.environment.EnvState, float, bool, dict] method of __main__.LogWrapper instance



In [None]:
env_state.env_state.

<bound method EnvState.__repr__ of EnvState(x=Array([0.04827876], dtype=float64), x_dot=Array([0.03837702], dtype=float64), theta=Array([0.00256703], dtype=float64), theta_dot=Array([-0.03793145], dtype=float64), time=Array([0], dtype=int64, weak_type=True))>

In [None]:
obs

Array([[ 0.04827876,  0.03837702,  0.00256703, -0.03793145]], dtype=float64)

In [None]:
np.array([[(-1.) ** i for i in range(2)]])

array([[ 1., -1.]])

In [None]:
import cirq
from functools import reduce


In [None]:
qubits = cirq.GridQubit.rect(1, 4)
ops = [cirq.Z(q) for q in qubits]
observables = [reduce((lambda x, y: x * y), ops)]

In [None]:
observables

[Z(q(0, 0))*Z(q(0, 1))*Z(q(0, 2))*Z(q(0, 3))]

In [None]:
input = [tf.keras.Input(shape=(4,), dtype=tf.dtypes.float32, name='input')]

tf.gather(tf.shape(input[0]), 0)

import sympy

In [121]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
import gymnax
# from purejaxrl.wrappers import LogWrapper, FlattenObservationWrapper


class ActorCritic(nn.Module):
    action_dim: Sequence[int]
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)

        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        critic = activation(critic)
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )

        return pi, jnp.squeeze(critic, axis=-1)


class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray


def make_train(config):
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env, env_params = gymnax.make(config["ENV_NAME"])
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).n, activation=config["ACTIVATION"]
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        print(f'init_x {init_x.shape}')
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])

                print(rng_step.shape)
                print(action.shape)
                obsv, env_state, reward, done, info = jax.vmap(
                    env.step, in_axes=(0, 0, 0, None)
                )(rng_step, env_state, action, env_params)
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)
            print(traj_batch.reward.shape)
            print(advantages.shape)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                # Batching and Shuffling
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                # print(advantages.shape)
                # print(targets.shape)
                print(batch)
                # print(traj_batch.shape)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                print(batch)
                # print(traj_batch.reshape((batch_size,) + traj_batch.shape[2:]))
                # batch[]
                # print(traj_batch.reshape((batch_size,) + traj_batch.shape[2:]).shape)

                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                # Mini-batch Updates
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss
            # Updating Training State and Metrics:
            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]

            # Debugging mode
            if config.get("DEBUG"):
                def callback(info):
                    return_values = info["returned_episode_returns"][info["returned_episode"]]
                    timesteps = info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
                    for t in range(len(timesteps)):
                        print(f"global step={timesteps[t]}, episodic return={return_values[t]}")
                jax.debug.callback(callback, metric)

            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state, "metrics": metric}

    return train


if __name__ == "__main__":
  # conf = {'n_layers': 5,
  #       'n_qubits': 4,
  #       'beta': 1.0,
  #       'n_envs': 1,
  #       'total_timesteps': 125000,
  #       'n_steps': 64,
  #       'gamma': 0.99,
  #       'n_minibatches': 4,
  #       'update_epochs': 4,
  #       'debug': True,
  #       'env_name': 'CartPole-v1',
  #       'lr_theta': 0.001,
  #       'lr_lmbd': 0.1,
  #       'lr_w': 0.1,
  #       'rng': 42}


    config = {
        "LR": 2.5e-4,
        "NUM_ENVS": 1,
        "NUM_STEPS": 64,
        "TOTAL_TIMESTEPS": 125000,
        "UPDATE_EPOCHS": 4,
        "NUM_MINIBATCHES": 4,
        "GAMMA": 0.99,
        "GAE_LAMBDA": 0.95,
        "CLIP_EPS": 0.2,
        "ENT_COEF": 0.01,
        "VF_COEF": 0.5,
        "MAX_GRAD_NORM": 0.5,
        "ACTIVATION": "tanh",
        "ENV_NAME": "CartPole-v1",
        "ANNEAL_LR": True,
        "DEBUG": True,
    }
    rng = jax.random.PRNGKey(30)
    train_jit = jax.jit(make_train(config))
    out = train_jit(rng)

init_x (4,)
(1, 2)
(1,)
(1, 2)
(1,)
(64, 1)
(64, 1)
(Transition(done=Traced<ShapedArray(bool[64,1])>with<DynamicJaxprTrace(level=3/0)>, action=Traced<ShapedArray(int64[64,1])>with<DynamicJaxprTrace(level=3/0)>, value=Traced<ShapedArray(float64[64,1])>with<DynamicJaxprTrace(level=3/0)>, reward=Traced<ShapedArray(float64[64,1], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>, log_prob=Traced<ShapedArray(float64[64,1])>with<DynamicJaxprTrace(level=3/0)>, obs=Traced<ShapedArray(float64[64,1,4])>with<DynamicJaxprTrace(level=3/0)>, info={'discount': Traced<ShapedArray(float64[64,1], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>, 'returned_episode': Traced<ShapedArray(bool[64,1])>with<DynamicJaxprTrace(level=3/0)>, 'returned_episode_lengths': Traced<ShapedArray(int64[64,1], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>, 'returned_episode_returns': Traced<ShapedArray(float64[64,1])>with<DynamicJaxprTrace(level=3/0)>, 'timestep': Traced<ShapedArray(int64[64,1], weak_type=True)>with

KeyboardInterrupt: 

In [16]:
from flax import linen as nn


class MLP(nn.Module):
    """Simple ReLU MLP."""

    num_hidden_units: int
    num_hidden_layers: int
    num_output_units: int

    @nn.compact
    def __call__(self, x, rng):
        for l in range(self.num_hidden_layers):
            x = nn.Dense(features=self.num_hidden_units)(x)
            x = nn.relu(x)
        x = nn.Dense(features=self.num_output_units)(x)
        return x


model = MLP(48, 1, 1)
policy_params = model.init(jax.random.PRNGKey(0), jnp.zeros(3), None)
print(type(policy_params))

<class 'dict'>


In [45]:
rewards = [1,2,3,4,5,6,7,10]
gamma = 0.99


returns = []
discounted_sum = 0
for r in rewards[::-1]:
  discounted_sum = r + gamma * discounted_sum
  returns.insert(0, discounted_sum)

returns

[36.2214308742769,
 35.577202903309995,
 33.916366569,
 31.2286531,
 27.50369,
 22.730999999999998,
 16.9,
 10.0]

In [35]:
jnp.zeros_like(1000)

Array(0, dtype=int32, weak_type=True)

In [18]:
returns

[23.472496000000003,
 24.969440000000002,
 25.521600000000003,
 25.024,
 23.36,
 20.4,
 16.0,
 10.0]

In [24]:
#     def policy_step(state_input, tmp):
#         """lax.scan compatible step transition in jax env."""
#         obs, state, policy_params, rng = state_input
#         rng, rng_step, rng_net = jax.random.split(rng, 3)
#         action = model.apply(policy_params, obs, rng_net)
#         next_obs, next_state, reward, done, _ = env.step(
#           rng_step, state, action, env_params
#         )
#         carry = [next_obs, next_state, policy_params, rng]
#         return carry, [obs, action, reward, next_obs, done]

def _calculate_gae(traj_batch, last_val):
    def _get_advantages(gae_and_next_value, transition):
        gae, next_value = gae_and_next_value
        done, value, reward = (
            transition.done,
            transition.value,
            transition.reward,
        )
        delta = reward + config["GAMMA"] * next_value * (1 - done) - value
        gae = (
            delta
            + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
        )
        return (gae, value), gae

    _, advantages = jax.lax.scan(
        _get_advantages,
        (jnp.zeros_like(last_val), last_val),
        traj_batch,
        reverse=True,
        unroll=16,
    )
    return advantages, advantages + traj_batch.value


def _calculate_returns(traj_batch):
  def _discounted_sum(return_and_next_sum, transition):
    discounted_sum = return_and_next_sum
    done, reward =





In [43]:
class Trans(NamedTuple):
  reward: jnp.ndarray

def calculate_returns(traj_batch):
    def _compute_discounted_sum(carry, transition):
        rewards_to_go = carry
        reward = transition.reward
        rewards_to_go = reward + conf['gamma'] * rewards_to_go
        baseline = 0
        return rewards_to_go, rewards_to_go

    init_carry = jnp.zeros_like(0, dtype=DTYPE)

    _, returns = jax.lax.scan(
        _compute_discounted_sum,
        init_carry,
        traj_batch,
        reverse=True,
    )
    return returns

traj_batch = Trans(reward=jnp.array([1,2,3,4,5,6,7,10], dtype=DTYPE))
calculate_returns(traj_batch)

  traj_batch = Trans(reward=jnp.array([1,2,3,4,5,6,7,10], dtype=DTYPE))
  init_carry = jnp.zeros_like(0, dtype=DTYPE)


Array([36.22143 , 35.5772  , 33.916367, 31.228653, 27.503689, 22.730999,
       16.9     , 10.      ], dtype=float32)

In [41]:
def rcal(rewards_history, ununsed):
  discounted_sum, reward = rewards_history

  new_discounted_sum = reward + 0.99 * discounted_sum

  carry = (new_discounted_sum, reward)

  return carry, new_discounted_sum

_, ds = jax.lax.scan(
        rcal,
        (0.0, [1,2,3,4,5,6,7,10]),
        None,
        reverse=True,
        length=8)

TypeError: unsupported operand type(s) for +: 'list' and 'DynamicJaxprTracer'

In [None]:
    def _discount_rewards(carry, reward):
        running_add, discounted_rewards = carry
        running_add = reward + config["GAMMA"] * running_add
        discounted_rewards = jax.ops.index_update(discounted_rewards, jax.ops.index[i], running_add)
        return (running_add, discounted_rewards), (running_add, discounted_rewards)

#     def policy_step(state_input, tmp):
#         """lax.scan compatible step transition in jax env."""
#         obs, state, policy_params, rng = state_input
#         rng, rng_step, rng_net = jax.random.split(rng, 3)
#         action = model.apply(policy_params, obs, rng_net)
#         next_obs, next_state, reward, done, _ = env.step(
#           rng_step, state, action, env_params
#         )
#         carry = [next_obs, next_state, policy_params, rng]
#         return carry, [obs, action, reward, next_obs, done]

In [None]:
def _calculate_discounted_rewards(traj_batch, last_val):
    def _discount_rewards(carry, reward):
        running_add, discounted_rewards = carry
        running_add = reward + config["GAMMA"] * running_add
        discounted_rewards = jax.ops.index_update(discounted_rewards, jax.ops.index[i], running_add)
        return (running_add, discounted_rewards), (running_add, discounted_rewards)

    init_carry = (last_val, jnp.zeros_like(traj_batch.reward))

    _, (_, discounted_rewards) = jax.lax.scan(
        _discount_rewards,
        init_carry,
        traj_batch.reward,
        reverse=True
    )

    return discounted_rewards, discounted_rewards + last_val