<a href="https://colab.research.google.com/github/Zakuta/D-QRL/blob/main/QRL_try_2_feb23_2024.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# !pip install equinox
# !pip install tensorcircuit
# !pip install -U qiskit
# !pip install tensorcircuit
# !pip install cirq
# !pip install openfermion
# !pip install gymnax
# !pip install brax
# !pip install distrax

In [20]:
import jax
from jax import config

config.update("jax_debug_nans", True)
config.update("jax_enable_x64", True)

import jax.numpy as jnp
DTYPE=jnp.float64

import chex
import numpy as np
import optax
from flax import struct
from functools import partial
import tensorcircuit as tc

import tensorflow as tf
from sklearn.decomposition import PCA
import equinox as eqx
import types
from jaxtyping import Array, PRNGKeyArray
from typing import Union, Sequence, List, NamedTuple, Optional, Tuple, Any, Literal, TypeVar
import jax.tree_util as jtu
import gymnax
import distrax
from gymnax.environments import environment, spaces
from brax import envs
from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper

K = tc.set_backend("jax")

In [3]:
# shamelessly taken from purejaxrl: https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/wrappers.py

class GymnaxWrapper(object):
    """Base class for Gymnax wrappers.

    Stores the wrapped environment in ``_env`` and forwards every attribute
    that is not found on the wrapper itself to that environment.
    """

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        """Proxy access to regular attributes of the wrapped object.

        Raises:
            AttributeError: if ``_env`` itself is requested (meaning it has not
                been set yet) or the wrapped object lacks ``name``.
        """
        # ``__getattr__`` only runs after normal lookup failed. Being asked for
        # ``_env`` therefore means the wrapper is not initialised yet (e.g. it
        # was created through ``__new__`` while being copied or unpickled).
        # Delegating in that state would re-enter this method forever, so fail
        # with a regular AttributeError instead of a RecursionError.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)


class FlattenObservationWrapper(GymnaxWrapper):
    """Wrapper that presents every observation as a 1-D array."""

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    def observation_space(self, params) -> spaces.Box:
        """Return a Box space whose shape is the flattened inner shape.

        ``low``, ``high`` and ``dtype`` are taken from the wrapped space
        unchanged; only ``shape`` is replaced by ``(prod(inner_shape),)``.
        """
        inner_space = self._env.observation_space(params)
        assert isinstance(
            inner_space, spaces.Box
        ), "Only Box spaces are supported for now."
        flat_size = np.prod(inner_space.shape)
        return spaces.Box(
            low=inner_space.low,
            high=inner_space.high,
            shape=(flat_size,),
            dtype=inner_space.dtype,
        )

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        """Reset the wrapped environment and flatten the first observation."""
        raw_obs, state = self._env.reset(key, params)
        return jnp.reshape(raw_obs, (-1,)), state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        """Step the wrapped environment and flatten the new observation."""
        raw_obs, state, reward, done, info = self._env.step(
            key, state, action, params
        )
        return jnp.reshape(raw_obs, (-1,)), state, reward, done, info


@struct.dataclass
class LogEnvState:
    """State carried by LogWrapper: the wrapped env state plus episode statistics.

    Field order matters: LogWrapper.reset builds this object positionally.
    """

    # State of the wrapped environment.
    env_state: environment.EnvState
    # Reward accumulated so far in the episode in progress (zeroed on `done`).
    episode_returns: float
    # Number of steps taken so far in the episode in progress (zeroed on `done`).
    episode_lengths: int
    # Return of the most recently finished episode.
    returned_episode_returns: float
    # Length of the most recently finished episode.
    returned_episode_lengths: int
    # Number of `step` calls since `reset`; not reset when an episode ends.
    timestep: int


class LogWrapper(GymnaxWrapper):
    """Wrapper that tracks per-episode return and length statistics."""

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        """Reset the wrapped environment and start with all counters at zero."""
        obs, inner_state = self._env.reset(key, params)
        fresh_log = LogEnvState(
            env_state=inner_state,
            episode_returns=0,
            episode_lengths=0,
            returned_episode_returns=0,
            returned_episode_lengths=0,
            timestep=0,
        )
        return obs, fresh_log

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        """Step the wrapped environment and update the episode statistics.

        The running return/length are zeroed when ``done`` is set, and the
        "returned" statistics are overwritten with the totals of the episode
        that just ended. The statistics are also copied into ``info``.
        """
        obs, inner_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )

        # 1 while the episode continues, 0 on the step that ends it.
        still_running = 1 - done
        running_return = state.episode_returns + reward
        running_length = state.episode_lengths + 1

        last_return = (
            state.returned_episode_returns * still_running + running_return * done
        )
        last_length = (
            state.returned_episode_lengths * still_running + running_length * done
        )

        state = LogEnvState(
            env_state=inner_state,
            episode_returns=running_return * still_running,
            episode_lengths=running_length * still_running,
            returned_episode_returns=last_return,
            returned_episode_lengths=last_length,
            timestep=state.timestep + 1,
        )

        info.update(
            {
                "returned_episode_returns": state.returned_episode_returns,
                "returned_episode_lengths": state.returned_episode_lengths,
                "timestep": state.timestep,
                "returned_episode": done,
            }
        )
        return obs, state, reward, done, info

In [21]:
def reuploading_circuit(n_qubits, n_layers, rot_params, input_params, X):
  """Build a data re-uploading circuit on `n_qubits` qubits.

  Each of the `n_layers` layers applies, in order:
    1. a variational block: RX, RY, RZ on every qubit,
    2. an entangling block: CNOTs between neighbouring qubits, closed into a
       ring when there are more than two qubits,
    3. an encoding block: RX(X[q] * input_params[layer, q]) on every qubit.
  A final variational block follows the last layer.

  Args:
    n_qubits: number of qubits in the circuit.
    n_layers: number of (variational, entangling, encoding) layers.
    rot_params: rotation angles, indexed as [layer, qubit, axis] with
      layer in 0..n_layers (inclusive) and axis in 0..2 for RX/RY/RZ.
    input_params: input scaling factors, indexed as [layer, qubit].
    X: input features, indexed as [qubit].

  Returns:
    The constructed `tc.Circuit`.
  """
  circuit = tc.Circuit(n_qubits)

  def apply_rotations(layer):
    # variational part: one RX/RY/RZ triple per qubit
    for qubit_idx in range(n_qubits):
      circuit.rx(qubit_idx, theta=rot_params[layer, qubit_idx, 0])
      circuit.ry(qubit_idx, theta=rot_params[layer, qubit_idx, 1])
      circuit.rz(qubit_idx, theta=rot_params[layer, qubit_idx, 2])

  for layer in range(n_layers):
    apply_rotations(layer)

    # entangling part
    for qubit_idx in range(n_qubits - 1):
      circuit.cnot(qubit_idx, qubit_idx + 1)
    # Close the ring only when it adds a new pair: with two qubits the pair is
    # already coupled, and with one qubit it would be a CNOT of a qubit with
    # itself.
    if n_qubits > 2:
      circuit.cnot(n_qubits - 1, 0)

    # encoding part
    for qubit_idx in range(n_qubits):
      scaled_input = X[qubit_idx] * input_params[layer, qubit_idx]
      circuit.rx(qubit_idx, theta=scaled_input)

  # last variational part
  apply_rotations(n_layers)

  return circuit


class PQCLayer(eqx.Module):
  """Parameterised re-uploading circuit layer.

  Attributes:
    theta: rotation angles, indexed as [layer, qubit, axis].
    lmbd: input scaling factors, indexed as [layer, qubit].
    n_qubits: number of qubits (static).
    n_layers: number of re-uploading layers (static).
  """
  theta: Array
  lmbd: Array
  n_qubits: int = eqx.field(static=True)
  n_layers: int = eqx.field(static=True)

  def __init__(self, n_qubits: int, n_layers: int, params: Optional[dict],
               key: Union[int, PRNGKeyArray]):
    """Create the layer from explicit parameters or a random initialisation.

    Args:
      n_qubits: number of qubits.
      n_layers: number of re-uploading layers.
      params: dict with entries 'thetas' and 'lmbds', or None to initialise
        `theta` uniformly in [0, pi) and `lmbd` to ones.
      key: integer seed or JAX PRNG key; only used when `params` is None.
    """
    if params is not None:
      self.theta = params['thetas']
      self.lmbd = params['lmbds']
    else:
      # An integer is treated as a seed; anything else is assumed to already
      # be a PRNG key (PRNGKey() would reject a key array).
      if isinstance(key, (int, np.integer)):
        key = jax.random.PRNGKey(key)
      tkey, _ = jax.random.split(key, num=2)
      self.theta = jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
                                      minval=0.0, maxval=np.pi, dtype=DTYPE)
      self.lmbd = jnp.ones(shape=(n_layers, n_qubits), dtype=DTYPE)

    self.n_qubits = n_qubits
    self.n_layers = n_layers

  def __call__(self, inputs):
    """Run the circuit on `inputs` (indexed per qubit).

    Returns:
      The real part of `expectation_ps` evaluated with Z on every qubit.
    """
    # Same gate sequence as the module-level builder; reuse it rather than
    # keeping a second copy of the loops in sync.
    circuit = reuploading_circuit(self.n_qubits, self.n_layers,
                                  self.theta, self.lmbd, inputs)
    # n_qubits is an int, so the qubit indices come from range(), not len().
    return jnp.real(circuit.expectation_ps(z=list(range(self.n_qubits))))


# class PQCLayer(eqx.Module):
#   theta: jax.Array = eqx.field(converter=jnp.asarray)
#   lmbd: jax.Array = eqx.field(converter=jnp.asarray)
#   n_qubits: int = eqx.field(static=True)
#   n_layers: int = eqx.field(static=True)

#   def __init__(self, n_qubits: int, n_layers: int, key: int):
#     key = jax.random.PRNGKey(key)
#     tkey, lkey = jax.random.split(key, num=2)
#     self.n_qubits = n_qubits
#     self.n_layers = n_layers
#     # rotation_params
#     self.theta = jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
#                                     minval=0.0, maxval=np.pi, dtype=DTYPE)
#     # input encoding params
#     # self.lmbd = jnp.ones(shape=(n_layers, n_qubits))
#     self.lmbd = jax.random.uniform(key=lkey, shape=(n_layers, n_qubits),
#                                     minval=0.0, maxval=np.pi, dtype=DTYPE)

#   def __call__(self, inputs):
#   # def __call__(self, X, n_qubits, depth):

#     circuit = generate_circuit(self.n_qubits, self.n_layers, self.theta, self.lmbd, inputs)
#     # state = circuit.state()
#     # return state
#     return K.real(circuit.expectation_ps(z=[0,1,2,3]))

# class Alternating(eqx.Module):
#   w: jax.Array = eqx.field(converter=jnp.asarray)

#   def __init__(self, output_dim):
#     self.w = jnp.array([[(-1.) ** i for i in range(output_dim)]])

#   def __call__(self, inputs):
#     return jnp.matmul(inputs, self.w)


# class Actor(eqx.Module):
#   n_qubits: int
#   n_layers: int
#   beta: float
#   n_actions: Sequence[int]
#   key: int

#   def __call__(self, x):
#     re_uploading_pqc = PQCLayer(n_qubits=self.n_qubits,
#                                 n_layers=self.n_layers,
#                                 key=self.key)(x)

#     process = eqx.nn.Sequential([
#         Alternating(self.n_actions),
#         eqx.nn.Lambda(lambda x: x * self.beta),
#         jax.nn.softmax()
#     ])

#     policy = process(re_uploading_pqc)

#     return policy



class QuantumActor(eqx.Module):
  """Softmax policy whose logits come from a re-uploading quantum circuit.

  The circuit expectation value is multiplied by the observable weights `w`
  and the inverse temperature `beta` to form the logits of a
  `distrax.Softmax` distribution over actions.
  """
  theta: jax.Array # trainable, rotation angles [layer, qubit, axis]
  lmbd: jax.Array # trainable, input scaling [layer, qubit]
  w: jax.Array # trainable, observable weights
  n_qubits: int = eqx.field(static=True)
  n_layers: int = eqx.field(static=True)
  beta: float = eqx.field(static=True)
  n_actions: int = eqx.field(static=True)

  def __init__(self, n_qubits, n_layers, beta, n_actions, params: Optional[dict] = None, key = 42):
    """Create the actor from explicit parameters or a default initialisation.

    Args:
      n_qubits: number of qubits of the policy circuit.
      n_layers: number of re-uploading layers.
      beta: factor applied to the logits before the softmax.
      n_actions: number of discrete actions.
      params: dict with entries 'thetas', 'lmbds' and 'ws', or None to
        initialise `theta` uniformly in [0, pi), `lmbd` to ones and `w` to
        alternating +1/-1.
      key: integer seed for the random initialisation; only used when
        `params` is None.
    """
    if params is not None:
      # rotation_params
      self.theta = params['thetas']
      # input encoding params
      self.lmbd = params['lmbds']
      # observable weights
      self.w = params['ws']
    else:
      key = jax.random.PRNGKey(key)
      key, _ = jax.random.split(key, num=2)
      self.theta = jax.random.uniform(key=key, shape=(n_layers + 1, n_qubits, 3),
                                      minval=0.0, maxval=np.pi, dtype=DTYPE)
      self.lmbd = jnp.ones(shape=(n_layers, n_qubits), dtype=DTYPE)
      self.w = jnp.array([[(-1.) ** i for i in range(n_actions)]], dtype=DTYPE)

    self.n_qubits = n_qubits
    self.n_layers = n_layers
    self.beta = beta
    self.n_actions = n_actions

  def quantum_policy_circuit(self, inputs):
    """Evaluate the policy circuit on `inputs` (indexed per qubit).

    Kept as a separate method so the circuit can be swapped for any other
    PQC. Returns the real part of `expectation_ps` with Z on every qubit.
    """
    circuit = reuploading_circuit(self.n_qubits, self.n_layers, self.theta, self.lmbd, inputs)

    return K.real(circuit.expectation_ps(z=np.arange(self.n_qubits)))

  def alternating(self, inputs):
    """Multiply the circuit output by the observable weights `w`."""
    # matmul rejects 0-d operands, and the expectation value is presumably a
    # scalar (TODO confirm against tensorcircuit); promote it to shape (1,)
    # so that (1,) @ (1, n_actions) yields one logit per action. Inputs that
    # already have a dimension pass through unchanged.
    return jnp.matmul(jnp.atleast_1d(inputs), self.w)

  def get_params(self):
    """Return the trainable arrays as a dict.

    NOTE(review): the keys here ('theta', 'lmbd', 'w') differ from the keys
    `__init__` reads from `params` ('thetas', 'lmbds', 'ws'); kept as-is to
    avoid breaking existing callers.
    """
    return {"theta": self.theta, "lmbd": self.lmbd, "w": self.w}

  def __call__(self, x):
    """Return the action distribution (`distrax.Softmax`) for observation `x`."""
    pqc = self.quantum_policy_circuit(x)
    logits = self.alternating(pqc) * self.beta
    return distrax.Softmax(logits)


# class Actor(eqx.Module):
#   n_qubits: int
#   n_layers: int
#   beta: float
#   n_actions: Sequence[int]
#   pqc: eqx.Module
#   alt: eqx.Module
#   key: int

#   def __init__(self, n_qubits, n_layers, beta, n_actions, key):
#     self.n_qubits = n_qubits
#     self.n_layers = n_layers
#     self.beta = beta
#     self.n_actions = n_actions
#     self.key = key

#     self.pqc = PQCLayer(n_qubits=self.n_qubits,
#                         n_layers=self.n_layers,
#                         key=self.key)

#     self.alt = Alternating(self.n_actions)

#   def __call__(self, x):
#     re_uploading_pqc = self.pqc(x)

#     process = eqx.nn.Sequential([
#         self.alt,
#         eqx.nn.Lambda(lambda x: x * self.beta),
#         jax.nn.softmax()
#     ])

#     policy = process(re_uploading_pqc)

#     return policy


class TrainState(eqx.Module):
    """Bundle of a model, its optimizer and the optimizer state."""
    model: eqx.Module
    optimizer: optax.GradientTransformation = eqx.field(static=True)
    opt_state: optax.OptState

    def __init__(self, model, optimizer, opt_state):
        self.model = model
        self.optimizer = optimizer
        self.opt_state = opt_state

    def apply_updates_to_model(self, new_params):
        """Return a copy of the model with theta, lmbd and w replaced.

        This is specific to models exposing `theta`, `lmbd` and `w`; it could
        be generalised to update arbitrary attributes.

        Args:
            new_params: dict with entries 'thetas', 'lmbds' and 'ws'.
        """
        model_new = self.model
        replacements = (
            (lambda model: model.theta, 'thetas'),
            (lambda model: model.lmbd, 'lmbds'),
            (lambda model: model.w, 'ws'),
        )
        for where, name in replacements:
            model_new = eqx.tree_at(where=where, pytree=model_new, replace=new_params[name])
        return model_new

    def apply_gradients(self, params, grads):
        """Apply one optimizer step and return the resulting TrainState.

        Args:
            params: dict with entries 'thetas', 'lmbds' and 'ws' holding the
                current parameter values.
            grads: object exposing the gradients as `.theta`, `.lmbd`, `.w`.
        """
        # The gradient dict must mirror the structure of `params`; the 'ws'
        # entry was previously missing, so `w` could never be updated and the
        # tree no longer matched the three-entry parameter dict.
        grads_dict = {'thetas': grads.theta, 'lmbds': grads.lmbd, 'ws': grads.w}
        updates, opt_state = self.optimizer.update(grads_dict, self.opt_state, params)
        new_params = optax.apply_updates(params, updates)
        model_new = self.apply_updates_to_model(new_params)
        return type(self)(model=model_new, optimizer=self.optimizer, opt_state=opt_state)

class Transition(NamedTuple):
  # One step of collected experience; field order is relied on by positional
  # construction in the rollout step.
  done: jnp.ndarray      # episode-termination flag returned by env.step
  action: jnp.ndarray    # action sampled from the policy
  reward: jnp.ndarray    # reward returned by env.step
  log_prob: jnp.ndarray  # log-probability of `action` under the policy
  obs: jnp.ndarray       # observation the action was chosen from
  # NOTE(review): env.step returns `info` as a dict (LogWrapper adds keys to
  # it), so the ndarray annotation looks inaccurate — confirm.
  info: jnp.ndarray

In [19]:
# Circuit hyper-parameters: number of re-uploading layers, number of qubits,
# and the factor applied to the policy logits.
n_layers = 5
n_qubits = 4
beta = 1.0

# Number of environment copies, and the CartPole environment wrapped so that
# observations are flattened and episode statistics are logged.
n_envs = 2
env, env_params = gymnax.make('CartPole-v1')
env = FlattenObservationWrapper(env)
env = LogWrapper(env)

n_actions = env.action_space(env_params).n

# Seed the PRNG and split off a key for the rotation-angle initialisation.
rng = 42
rng = jax.random.PRNGKey(rng)
rng, _rng = jax.random.split(rng)
# Explicit parameter dict: uniform rotation angles in [0, pi), unit input
# scalings, and alternating +1/-1 observable weights.
params = {'thetas': jax.random.uniform(key=_rng, shape=(n_layers + 1, n_qubits, 3),
                                       minval=0.0, maxval=np.pi, dtype=DTYPE),
          'lmbds': jnp.ones(shape=(n_layers, n_qubits), dtype=DTYPE),
          'ws': jnp.array([[(-1.) ** i for i in range(n_actions)]], dtype=DTYPE)}

# Build the actor, passing the explicit parameter dict.
actor = QuantumActor(n_qubits=n_qubits, n_layers=n_layers, beta=beta, n_actions=n_actions, params=params)

def map_nested_fn(fn):
  '''Build a mapper that applies `fn(key, value)` to every leaf of a nested dict.

  Values that are themselves dicts are descended into; every other value is
  replaced by `fn(key, value)`. The returned dict has the same keys.
  '''
  def map_fn(nested_dict):
    mapped = {}
    for name, value in nested_dict.items():
      if isinstance(value, dict):
        mapped[name] = map_fn(value)
      else:
        mapped[name] = fn(name, value)
    return mapped
  return map_fn

# Label every leaf of the parameter dict with its own key so that
# multi_transform can route it to the optimizer registered under that name.
label_fn = map_nested_fn(lambda k, _: k)
# One AMSGrad optimizer per parameter group, each with its own learning rate.
optim = optax.multi_transform({'thetas': optax.amsgrad(0.001),
                               'lmbds': optax.amsgrad(0.1),
                               'ws': optax.amsgrad(0.1)},
                           label_fn)

# optim = closure_to_pytree(optim)
# Optimizer state is initialised from the explicit parameter dict.
opt_state = optim.init(params)

train_state = TrainState(model=actor, optimizer=optim, opt_state=opt_state)

# Reset all n_envs environment copies, each with its own PRNG key; the
# environment parameters are shared (in_axes=None).
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, n_envs)
obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

def _update_step(runner_state, unused):
  # NOTE(review): only the nested rollout step is visible in this function;
  # there is no return statement or call to `_env_step` here — the definition
  # looks unfinished, confirm before relying on it.
  # COLLECT TRAJECTORIES
  def _env_step(runner_state, ununsed):
    # runner_state is the tuple (train_state, env_state, last_obs, rng).
    train_state, env_state, last_obs, rng = runner_state

    # Sample an action from the policy and record its log-probability.
    rng, _rng = jax.random.split(rng)
    # NOTE(review): this calls the module-level `actor`, not the model stored
    # in `train_state` — presumably `train_state.model` was intended; confirm.
    # NOTE(review): env.step below is vmapped over n_envs, so `last_obs` is
    # presumably batched, while the actor indexes its input per qubit —
    # verify whether the policy call needs a vmap as well.
    policy = actor(last_obs)
    action = policy.sample(seed=_rng)
    log_prob = policy.log_prob(action)
    # action_probs = actor(last_obs)
    # action = jax.random.choice(key=_rng, a=n_actions, p=action_probs)

    # Step every environment copy with its own PRNG key; env_params is shared.
    rng, _rng = jax.random.split(rng)
    rng_step = jax.random.split(_rng, n_envs)
    obs, env_state, reward, done, info = jax.vmap(
        env.step, in_axes=(0, 0, 0, None)
        )(rng_step, env_state, action, env_params)

    # Store the observation the action was chosen from (last_obs), not the
    # observation returned by the step.
    transition = Transition(
        done, action, reward, log_prob, last_obs, info)

    runner_state = (train_state, env_state, obs, rng)

    return runner_state, transition




