<a href="https://colab.research.google.com/github/Zakuta/D-QRL/blob/main/JAX_reimplementation_D%26QRL_16_feb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# adapted from https://www.tensorflow.org/quantum/tutorials/quantum_data
import os
from functools import reduce
# Set the environment variable
# os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/cm/shared/easybuild/AuthenticAMD/software/CUDA/11.8.0/'

import collections

import numpy as np
import tensorflow as tf
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras import layers, losses

import torch
import torch.nn as nn

import tensorcircuit as tc
import jax
import jax.numpy as jnp
import cirq
import sympy

np.random.seed(1234)

K = tc.set_backend("jax")

In [None]:
# !pip install qiskit
# !pip install tensorcircuit
# !pip install cirqx
# !pip install openfermion

In [None]:
### circuit_components.py







def one_qubit_rotation(state, n_qubits, qubit_list, params, return_type='state'):
  if state:
    c_ = tc.Circuit(n_qubits, inputs=state)
  else:
    c_ = tc.Circuit(n_qubits)
  for qubit_idx in qubit_list:
    c_.rx(qubit_idx, theta=params[0])
    c_.ry(qubit_idx, theta=params[1])
    c_.rz(qubit_idx, theta=params[2])

  s_ = c_.state()
  if return_type == 'circuit':
    return c_
  elif return_type == 'state':
    return s_

def reg_entangling_layer(n_qubits, state, return_type='state'):
  qubit_list = np.arange(n_qubits)
  if state:
    c_ = tc.Circuit(n_qubits, inputs=state)
  else:
    c_ = tc.Circuit(n_qubits)

    for i, j in zip(qubit_list, qubit_list[1:]):
      c_.cz(i, j)
    if len(qubit_list) != 2:
      c_.cz(qubit_list[0], qubit_list[-1])

  s_ = c_.state()
  if return_type == 'circuit':
    return c_
  elif return_type == 'state':
    return s_

def entangling_layer(state, n_qubits, qubit_list, part_of_H_test=True):
  # qubit_list = list of indices from the total qubits to add entangling layer to.
  # n_qubits = total number of qubits
  # part_of_hadamard_test = boolean that indicates whether the layer is part of the hadamard test
  c_ = tc.Circuit(n_qubits, inputs=state)
  if n_qubits == 2:
    c_.cz(0, 1)
    return c_
  else:
    for i in range(len(qubit_list)):
      c_.cz(i, (i+1) % len(qubit_list))
    return c_

def bravyi_ghost_encoding(circuit, n_qubits, bravyi_params, return_type='state'):
  qubit_list = np.arange(n_qubits)
  # c_ = tc.Circuit(n_qubits)

  # ghost encoding to the first qubit of PQC
  circuit.crz(qubit_list[0], qubit_list[-1], theta=bravyi_params[0])
  # ghost encoding to the last qubit of PQC
  circuit.crz(qubit_list[-2], qubit_list[-1], theta=bravyi_params[1])
  # apply swap gate to ctrl qubit for switching to another smaller subcircuit
  circuit.x(qubit_list[-1])

  s_ = circuit.state()
  if return_type == 'circuit':
    return circuit
  elif return_type == 'state':
    return s_


In [None]:
n1 = 4
n2 = 5
qubit_list = np.arange(n1)

def t(n1):
  c1 = tc.Circuit(n1)

  for qubit_idx in range(3):
    c1.rx(qubit_idx, theta=0.1)
    c1.ry(qubit_idx, theta=0.01)
    c1.rz(qubit_idx, theta=0.0001)

  return c1


for i in range(len(qubit_list)):
  c1.cz(i, (i+1) % len(qubit_list))

c2

In [None]:
###### red_partitioned_circuit_gen.py

class PartitionedCircuitGenerator():
  # assumes that the last qubit of each of the partitioned subcircuit is a ctrl qubit.
  def __init__(self, qubit_list_partition, n_layers) -> None:
    self.qubit_list_partition = qubit_list_partition # list of qubits where the last qubit is the control qubit (split size)
    self.n_qubits_partition = len(qubit_list_partition)
    self.n_qubits_wo_ctrl = self.n_qubits_partition - 1 # number of qubits in each partition (without contorl qubit)
    self.n_layers = n_layers # number of layers of the partitioned PQC

  def generate_layer(self, layer_idx, rotation_params, bravyi_params, input_params):
    # this function generates single layer of a partitioned subcircuit.
    # the last qubit is indeed the ctrl qubit, hence, the qubit_list will be of length n-1
    # n_qubits_partition = len(qubit_list)
    qubit_list_wo_ctrl = np.arange(self.n_qubits_wo_ctrl)
    c_ = tc.Circuit(self.n_qubits_partition)
    s_ = c_.state()
    # rotation layer
    state_one_q_r = one_qubit_rotation(state=s_,
                                       n_qubits=self.n_qubits_partition,
                                       qubit_list=qubit_list_wo_ctrl,
                                       params=rotation_params,
                                       return_type='state')
    # entangling layer
    circuit = entangling_layer(state=state_one_q_r,
                               n_qubits=self.n_qubits_partition,
                               qubit_list=qubit_list_wo_ctrl)
    # 1st bravyi ghost encoding
    circuit = bravyi_ghost_encoding(circuit=circuit,
                                    n_qubits=self.n_qubits_partition,
                                    bravyi_params=bravyi_params[layer_idx])
    # 2nd bravyi ghost encoding -> I assume this has to do with number of cuts in the PQC
    # TODO: check my intuition later! @Yash
    circuit = bravyi_ghost_encoding(circuit=circuit,
                                    n_qubits=self.n_qubits_partition,
                                    bravyi_params=bravyi_params[layer_idx + self.n_layers])

    # input encoding layer.
    # TODO: @Yash, this is just Rx encoding, what about IQP encoding which is actually classically hard to simulate?
    # May be then, it would actually require more terms from the cut?
    for idx in range(self.n_qubits_wo_ctrl):
      circuit.rx(idx, input_params[idx])

    return circuit

  def generate_partitioned_circuit(self, qubit_list, real=True):

    rotation_params = np.zeros(shape=(self.n_layers + 1, self.n_qubits_wo_ctrl, 3))

    bravyi_params = np.zeros(shape=(2 * self.n_layers, 2))

    input_params = np.zeros(shape=(self.n_layers, self.n_qubits_wo_ctrl))

    partitioned_circuit = tc.Circuit(self.n_qubits_partition)

    # apply H on control qubit which is situated at the last qubit index
    partitioned_circuit.h(qubit_list[-1])
    if not real:
      partitioned_circuit.unitary(qubit_list[-1], unitary=np.array([[1, 0], [0, 1j]]), name="S")

    # apply layers to the partitioned subcircuit
    for layer_idx in range(self.n_layers):
      layer_for_partitioned_circuit = self.generate_layer(layer_idx=layer_idx,
                                                          rotation_params=rotation_params[layer_idx],
                                                          bravyi_params=bravyi_params,
                                                          input_params=input_params[layer_idx])
      partitioned_circuit.append(layer_for_partitioned_circuit)

    # add final rotation layer
    #TODO: check this one!! the input to params!!! after a small test seems okay to me
    one_qubit_rotation_circuit = one_qubit_rotation(n_qubits=self.qubit_list_partition,
                                                    qubit_list=qubit_list,
                                                    params=rotation_params[-1],
                                                    return_type='circuit')

    partitioned_circuit.append(one_qubit_rotation_circuit)

    # add final H to re-invert the circuit
    partitioned_circuit.h(qubit_list[-1])
    if not real:
      partitioned_circuit.unitary(qubit_list[-1], unitary=np.array([[1, 0], [0, 1j]]), name="S")

    return (partitioned_circuit, list(rotation_params.flat()),
            list(bravyi_params.flat()), list(input_params.flat()))


In [None]:
# red_partition_layer_gen.py

class ReducedPartitionPQCLayer():
  def __init__(self,
               n_qubits_wo_ctrl,
               n_layers,
               n_partitions,
               n_terms,
               input_dim,
               trainable_lambdas,
               rescaling_scheme,
               trainable_regular_weights,
               trainable_partition_weights) -> None:

    self.n_qubits_wo_ctrl = n_qubits_wo_ctrl
    self.n_layers = n_layers
    self.n_partitions = n_partitions
    self.rescaling_scheme = rescaling_scheme
    self.n_terms = n_terms # T in the paper, product of schmidt number squared with gate cuts
                            # In our case for the CZ Gate 4 * gate cuts
    self.input_dim = input_dim

    qubit_list = np.arange(n_qubits_wo_ctrl + 1)
    measurement_ops = 3 * np.eye(n_qubits_wo_ctrl + 1)
    #TODO: ATTENTION @Yash change to tc???
    observables = [reduce((lambda x, y: x*y), measurement_ops)]

    # define sub-circuits
    generator = PartitionedCircuitGenerator(qubit_list_partition=qubit_list,
                                            n_layers=self.n_layers)
    circuit, rotation_params, bravyi_params, input_params = generator.generate_partitioned_circuit(qubit_list=qubit_list)
    circuit_i, _, _, _ = generator.generate_partitioned_circuit(qubit_list=qubit_list, real=False)

    self.reference_circuit = circuit

    # initialize weights, use of trainable_regular_weights flag here!
    # TODO: @Yash find a feature in tc to check whether this is infact possible w/ JAXBackend.
    self.thetas = np.random.uniform(low=0.0, high=np.pi, size=(1, len(rotation_params) * n_partitions))
    self.product_term_theta_size = len(rotation_params) # storing the length of rotation params in each subcircuit

    # weights to scale the input data (input encodings), use of trainable_regular_weights flag here!
    # TODO: @Yash find a feature in tc to check whether this is infact possible w/ JAXBackend.
    self.alphas = np.ones(shape=(len(input_params) * n_partitions,))
    self.input_data = len(input_params) * n_partitions

    # weights of the bravyi encoding, use of trainable_partition_weights flag here!
    # TODO: @Yash find a feature in tc to check whether this is infact possible w/ JAXBackend.
    self.zetas = np.random.uniform(low=0.0, high=np.pi, size=(n_terms, n_partitions, len(bravyi_params)))

    ########## L89 to L96
    ## SOME FILLER LINES TO BE ADDED!!
    ##########

    # TODO: @Yash For now, the if else statement below is degenerate!!
    rescale_parameter = 1
    if self.rescaling_scheme in ['constant', 'factoring']: # here we do scaling w/o taking exponential
      # TODO: @Yash please make sure to make sure that the lambdas in this case are non-negative like Darryn use a non-negative constraint with tf.
      # Naively one can do that after applying optimizer updates to lambdas and taking [np.max(0, i) for i in lambdas].
      self.lambdas = np.ones(shape=(n_terms,)) * rescale_parameter
    else: # exponential factoring
      self.lambdas = np.ones(shape=(n_terms,)) * rescale_parameter


  def get_zetas(self):
    return self.zetas

  # TODO: @Yash self.indices need to be defined L89 to L96
  def get_indices(self):
    return self.indices

  def rescale_lambdas(self, inputs):
    batch_dim = inputs[0].shape[0]
    tiled_up_thetas = np.tile(self.thetas, reps=[batch_dim, 1])

    inputs = inputs.reshape((inputs.shape[0], -1))
    tiled_up_inputs = np.tile(inputs[0], reps=[1, self.n_layers])
    scaled_inputs = np.einsum('i, ji->ji', self.alphas, tiled_up_inputs)
    # L127 is degenerate

    # simple rescaling of the lambdas
    for i in range(self.n_terms):
      # for np
      self.lambdas[:] = 1 / self.n_terms
      #self.lambdas = jax.ops.index_update(self.lambdas, i, 1 / self.n_terms) # for jax.np

    print('rescaled lambdas:', self.lambdas)

  def call(self, inputs):
    batch_dim = inputs[0].shape[0]
    tiled_up_thetas = np.tile(self.thetas, reps=[batch_dim, 1])

    inputs = inputs.reshape((inputs.shape[0], -1))
    tiled_up_inputs = np.tile(inputs[0], reps=[1, self.n_layers])
    scaled_inputs = np.einsum('i, ji->ji', self.alphas, tiled_up_inputs)
    # TODO: @Yash check if it is neeeded at all
    # squashed_inputs = tf.keras.layers.Activation(self.activation)(scaled_inputs)

    ans = np.zeros([batch_dim, 1])
    for i in range(self.n_terms):
      pqc_layer_ans = jax.lax.complex(np.ones([batch_dim, 1]), np.zeros([batch_dim, 1]))
      for k in range(self.n_partitions):
        # get circuits
        tiled_up_circuits









In [None]:
!pip install equinox
!pip install tensorcircuit
!pip install qiskit
!pip install tensorcircuit
!pip install cirq
!pip install openfermion
!pip install gymnax
!pip install brax
!pip install distrax
# !pip install purejaxrl


In [1]:
import jax
from jax import config

config.update("jax_debug_nans", True)
config.update("jax_enable_x64", True)

import jax.numpy as jnp
DTYPE=jnp.float64

import chex
import numpy as np
import optax
from flax import struct
from functools import partial
import tensorcircuit as tc
import equinox as eqx
import types
from typing import Union, Sequence, List, NamedTuple, Optional, Tuple, Any
import jax.tree_util as jtu
from gymnax.environments import environment, spaces
from brax import envs
from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper

K = tc.set_backend("jax")




In [2]:
# https://github.com/google-deepmind/optax/issues/577
# shamelessly taken from Patrick Kidger's issue from optax github. He's a legend, basically converted the optim instance of optax.optimizer to pytree!

def _make_cell(val):
    fn = lambda: val
    return fn.__closure__[0]  # pyright: ignore


def _adjust_function_closure(fn, closure):
    out = types.FunctionType(
        code=fn.__code__,
        globals=fn.__globals__,
        name=fn.__name__,
        argdefs=fn.__defaults__,
        closure=closure,
    )
    out.__module__ = fn.__module__
    out.__qualname__ = fn.__qualname__
    out.__doc__ = fn.__doc__
    out.__annotations__.update(fn.__annotations__)
    if fn.__kwdefaults__ is not None:
        out.__kwdefaults__ = fn.__kwdefaults__.copy()
    return out


# Not a pytree.
# Used so that two different local functions, with different identities, can still
# compare equal. This is needed as these leaves are compared statically when
# filter-jit'ing.
class _FunctionWithEquality:
    def __init__(self, fn: types.FunctionType):
        self.fn = fn

    def information(self):
        return self.fn.__qualname__, self.fn.__module__

    def __hash__(self):
        return hash(self.information())

    def __eq__(self, other):
        return type(self) == type(other) and self.information() == other.information()


class _Closure(eqx.Module):
    fn: _FunctionWithEquality
    contents: Optional[tuple[Any, ...]]

    def __init__(self, fn: types.FunctionType):
        self.fn = _FunctionWithEquality(fn)
        if fn.__closure__ is None:
            contents = None
        else:
            contents = tuple(
                closure_to_pytree(cell.cell_contents) for cell in fn.__closure__
            )
        self.contents = contents

    def __call__(self, *args, **kwargs):
        if self.contents is None:
            closure = None
        else:
            closure = tuple(_make_cell(contents) for contents in self.contents)
        fn = _adjust_function_closure(self.fn.fn, closure)
        return fn(*args, **kwargs)


def _fixup_closure(leaf):
    if isinstance(leaf, types.FunctionType):
        return _Closure(leaf)
    else:
        return leaf


def closure_to_pytree(tree):
    """Convert all function closures into pytree nodes.

    **Arguments:**

    - `tree`: Any pytree.

    **Returns:**

    A copy of `tree`, where all function closures have been replaced by a new object
    that is (a) callable like the original function, but (b) iterates over its
    `__closure__` as subnodes in the pytree.

    !!! Example

        ```python
        def some_fn():
            a = jnp.array(1.)

            @closure_to_pytree
            def f(x):
                return x + a

            print(jax.tree_util.tree_leaves(f))  # prints out `a`
        ```

    !!! Warning

        One annoying technical detail in the above example: we had to wrap the whole lot
        in a `some_fn`, so that we're in a local scope. Python treats functions at the
        global scope differently, and this conversion won't result in any global
        variable being treated as part of the pytree.

        In practice, the intended use case of this function is to fix Optax, which
        always uses local functions.
    """
    return jtu.tree_map(_fixup_closure, tree)


# EXAMPLE USAGE
# lr = jnp.array(1e-3)
# optim = optax.chain(
#     optax.adam(lr),
#     optax.scale_by_schedule(optax.piecewise_constant_schedule(1, {200: 0.1})),
# )
# optim = closure_to_pytree(optim)


In [3]:
# shamelessly taken from purejaxrl: https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/wrappers.py

class GymnaxWrapper(object):
    """Base class for Gymnax wrappers."""

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        return getattr(self._env, name)


class FlattenObservationWrapper(GymnaxWrapper):
    """Flatten the observations of the environment."""

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    def observation_space(self, params) -> spaces.Box:
        assert isinstance(
            self._env.observation_space(params), spaces.Box
        ), "Only Box spaces are supported for now."
        return spaces.Box(
            low=self._env.observation_space(params).low,
            high=self._env.observation_space(params).high,
            shape=(np.prod(self._env.observation_space(params).shape),),
            dtype=self._env.observation_space(params).dtype,
        )

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        obs, state = self._env.reset(key, params)
        obs = jnp.reshape(obs, (-1,))
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        obs, state, reward, done, info = self._env.step(key, state, action, params)
        obs = jnp.reshape(obs, (-1,))
        return obs, state, reward, done, info


@struct.dataclass
class LogEnvState:
    env_state: environment.EnvState
    episode_returns: float
    episode_lengths: int
    returned_episode_returns: float
    returned_episode_lengths: int
    timestep: int


class LogWrapper(GymnaxWrapper):
    """Log the episode returns and lengths."""

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        obs, env_state = self._env.reset(key, params)
        state = LogEnvState(env_state, 0, 0, 0, 0, 0)
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        obs, env_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )
        new_episode_return = state.episode_returns + reward
        new_episode_length = state.episode_lengths + 1
        state = LogEnvState(
            env_state=env_state,
            episode_returns=new_episode_return * (1 - done),
            episode_lengths=new_episode_length * (1 - done),
            returned_episode_returns=state.returned_episode_returns * (1 - done)
            + new_episode_return * done,
            returned_episode_lengths=state.returned_episode_lengths * (1 - done)
            + new_episode_length * done,
            timestep=state.timestep + 1,
        )
        info["returned_episode_returns"] = state.returned_episode_returns
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["timestep"] = state.timestep
        info["returned_episode"] = done
        return obs, state, reward, done, info

In [4]:
def generate_circuit(n_qubits, n_layers, rot_params, input_params, X):
  circuit = tc.Circuit(n_qubits)
  # params = np.random.normal(size=(n_layers + 1, n_qubits, 3))
  # inputs = np.random.normal(size=(n_layers, n_qubits))

  for l in range(n_layers):
    # variational part
    for qubit_idx in range(n_qubits):
      circuit.rx(qubit_idx, theta=rot_params[l, qubit_idx, 0])
      circuit.ry(qubit_idx, theta=rot_params[l, qubit_idx, 1])
      circuit.rz(qubit_idx, theta=rot_params[l, qubit_idx, 2])

    # entangling part
    for qubit_idx in range(n_qubits - 1):
      circuit.cnot(qubit_idx, qubit_idx + 1)
    if n_qubits != 2:
      circuit.cnot(n_qubits - 1, 0)

    # encoding part
    for qubit_idx in range(n_qubits):
      input = X[qubit_idx] * input_params[l, qubit_idx]
      circuit.rx(qubit_idx, theta=input)

  # last variational part
  for qubit_idx in range(n_qubits):
    circuit.rx(qubit_idx, theta=rot_params[n_layers, qubit_idx, 0])
    circuit.ry(qubit_idx, theta=rot_params[n_layers, qubit_idx, 1])
    circuit.rz(qubit_idx, theta=rot_params[n_layers, qubit_idx, 2])

  return circuit


class PQCLayer(eqx.Module):
  theta: jax.Array = eqx.field(converter=jnp.asarray)
  lmbd: jax.Array = eqx.field(converter=jnp.asarray)
  n_qubits: int = eqx.field(static=True)
  n_layers: int = eqx.field(static=True)

  def __init__(self, n_qubits: int, n_layers: int, key: int):
    key = jax.random.PRNGKey(key)
    tkey, lkey = jax.random.split(key, num=2)
    self.n_qubits = n_qubits
    self.n_layers = n_layers
    # rotation_params
    self.theta = jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
                                    minval=0.0, maxval=np.pi)
    # input encoding params
    # self.lmbd = jnp.ones(shape=(n_layers, n_qubits))
    self.lmbd = jax.random.uniform(key=lkey, shape=(n_layers, n_qubits),
                                    minval=0.0, maxval=np.pi)

  def __call__(self, inputs):
  # def __call__(self, X, n_qubits, depth):

    circuit = generate_circuit(int(self.n_qubits), int(self.n_layers), self.theta, self.lmbd, inputs)
    # state = circuit.state()
    # return state
    return K.real(circuit.expectation_ps(z=np.arange(int(self.n_qubits))))

class Alternating(eqx.Module):
  w: jax.Array = eqx.field(converter=jnp.asarray)

  def __init__(self, output_dim):
    self.w = jnp.array([[(-1.) ** i for i in range(output_dim)]])

  def __call__(self, inputs):
    return jnp.matmul(inputs, self.w)


# class Actor(eqx.Module):
#   n_qubits: int
#   n_layers: int
#   beta: float
#   n_actions: Sequence[int]
#   key: int

#   def __call__(self, x):
#     re_uploading_pqc = PQCLayer(n_qubits=self.n_qubits,
#                                 n_layers=self.n_layers,
#                                 key=self.key)(x)

#     process = eqx.nn.Sequential([
#         Alternating(self.n_actions),
#         eqx.nn.Lambda(lambda x: x * self.beta),
#         jax.nn.softmax()
#     ])

#     policy = process(re_uploading_pqc)

#     return policy


# the final one which works :)
class Actor(eqx.Module):
  theta: jax.Array = eqx.field(converter=jnp.asarray)#trainable
  lmbd: jax.Array = eqx.field(converter=jnp.asarray)#trainable
  w: jax.Array = eqx.field(converter=jnp.asarray)#trainable
  n_qubits: int = eqx.field(static=True)
  n_layers: int = eqx.field(static=True)
  beta: float = eqx.field(static=True)
  n_actions: Sequence[int] = eqx.field(static=True)
  # key: int

  def __init__(self, n_qubits, n_layers, beta, n_actions, key):
    self.n_qubits = n_qubits
    self.n_layers = n_layers
    self.beta = beta
    self.n_actions = n_actions
    # self.key = key

    key = jax.random.PRNGKey(key)
    key, _key = jax.random.split(key, num=2)
    print(key, _key)
    # rotation_params
    self.theta = jax.random.uniform(key=_key,
                                    shape=(n_layers + 1, n_qubits, 3),
                                    minval=0.0, maxval=np.pi)
    # input encoding params
    self.lmbd = jnp.ones(shape=(n_layers, n_qubits))
    # observable weights
    self.w = jnp.array([[(-1.) ** i for i in range(n_actions)]])

  def re_uploadingpqc(self, inputs):

    # circuit = generate_circuit(self.n_qubits, self.n_layers, self.theta, self.lmbd, inputs)
    circuit = tc.Circuit(self.n_qubits)

    for l in range(self.n_layers):
      # variational part
      for qubit_idx in range(self.n_qubits):
        circuit.rx(qubit_idx, theta=self.theta[l, qubit_idx, 0])
        circuit.ry(qubit_idx, theta=self.theta[l, qubit_idx, 1])
        circuit.rz(qubit_idx, theta=self.theta[l, qubit_idx, 2])

      # entangling part
      for qubit_idx in range(self.n_qubits - 1):
        circuit.cnot(qubit_idx, qubit_idx + 1)
      if self.n_qubits != 2:
        circuit.cnot(self.n_qubits - 1, 0)

      # encoding part
      for qubit_idx in range(self.n_qubits):
        linear_input = inputs[qubit_idx] * self.lmbd[l, qubit_idx]
        circuit.rx(qubit_idx, theta=linear_input)

    # last variational part
    for qubit_idx in range(self.n_qubits):
      circuit.rx(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 0])
      circuit.ry(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 1])
      circuit.rz(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 2])

    return jnp.real(circuit.expectation_ps(z=jnp.arange(self.n_qubits)))


  def alternating(self, inputs):
    return jnp.matmul(inputs, self.w)

  def get_params(self):
    return {"theta": self.theta, "lmbd": self.lmbd, "w": self.w}

  def __call__(self, x):

    pqc = self.re_uploadingpqc(x)
    alt = self.alternating(pqc)

    # process = eqx.nn.Sequential([
    #     alt,
    #     eqx.nn.Lambda(lambda x: x * self.beta),
    #     jax.nn.softmax()
    # ])
    # policy = process(pqc)

    actor_mean = eqx.nn.Lambda(lambda x: x * self.beta)(alt)
    policy = distrax.Softmax(actor_mean)

    return policy


# class Actor(eqx.Module):
#   n_qubits: int
#   n_layers: int
#   beta: float
#   n_actions: Sequence[int]
#   pqc: eqx.Module
#   alt: eqx.Module
#   key: int

#   def __init__(self, n_qubits, n_layers, beta, n_actions, key):
#     self.n_qubits = n_qubits
#     self.n_layers = n_layers
#     self.beta = beta
#     self.n_actions = n_actions
#     self.key = key

#     self.pqc = PQCLayer(n_qubits=self.n_qubits,
#                         n_layers=self.n_layers,
#                         key=self.key)

#     self.alt = Alternating(self.n_actions)

#   def __call__(self, x):
#     re_uploading_pqc = self.pqc(x)

#     process = eqx.nn.Sequential([
#         self.alt,
#         eqx.nn.Lambda(lambda x: x * self.beta),
#         jax.nn.softmax()
#     ])

#     policy = process(re_uploading_pqc)

#     return policy


class Transition(NamedTuple):
  done: jnp.ndarray
  action: jnp.ndarray
  value: jnp.ndarray
  reward: jnp.ndarray
  log_prob: jnp.ndarray
  obs: jnp.ndarray
  info: jnp.ndarray

class TrainState(eqx.Module):
    model: eqx.Module
    optimizer: optax.GradientTransformation = eqx.field(static=True)
    opt_state: optax.OptState

    def __init__(self, model, optimizer, opt_state = None):
        self.model = model
        self.optimizer = optimizer
        if opt_state is None:
            self.opt_state = self.optimizer.init(eqx.filter(model, eqx.is_array))
        else:
            self.opt_state = opt_state

    def apply_gradients(self, grads):

        updates, opt_state = self.optimizer.update(grads, self.opt_state, self.model)
        model = eqx.apply_updates(self.model, updates)
        new_train_state = self.__class__(model=model, optimizer=self.optimizer, opt_state=opt_state)
        return new_train_state

In [7]:
config = {"env_name": "CartPole-v1",
          "n_train_envs": 1,
          "n_qubits": 4,
          "n_layers": 5,
          "max_expisodes": 1200,
          "batch_size": 10,
          "agent_name": "CP_PG",
          "gamma": 1,
          "beta": 1,
          "lr_in": 0.1, # input encoding lmbd
          "lr_var": 0.01, # variational part theta
          "lr_out": 0.1, # observables Alternating class one
          }

def make_train(config):

  env, env_params = gymnax.make(config["env_name"])
  env = FlattenObservationWrapper(env)
  env = LogWrapper(env)

  def train(rng):

    # Initialize network
    rng, _rng = jax.random.split(rng)
    actor = Actor(config["n_qubits"], config["n_layers"], config["beta"],
                    env.action_space(env_params).n, _rng)

    # https://github.com/patrick-kidger/equinox/issues/79
    param_spec = eqx.filter(actor, eqx.is_inexact_array)
    param_spec = eqx.tree_at(lambda actor: actor.theta, param_spec, replace='group0')
    param_spec = eqx.tree_at(lambda actor: actor.lmbd, param_spec, replace='group1')
    param_spec = eqx.tree_at(lambda actor: actor.w, param_spec, replace='group2')

    #TODO: set the learning rates later
    optim = optax.multi_transform({"group0": optax.adam(1e-2),
        "group1": optax.adam(1e-1),
        "group2": optax.adam(1e-6),
        },
        param_spec
    )

    optim = closure_to_pytree(optim)

    optim_state = optim.init()

    # Initialize environment
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, config["n_train_envs"])
    obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)







ModuleNotFoundError: No module named 'gymnax'

In [130]:
actor = Actor(n_qubits=4, n_layers=5, beta=1.0, n_actions=2, key=42)
param_spec = eqx.filter(actor, eqx.is_inexact_array)
# # param_spec = jax.tree_map(lambda _: "NT", actor)
param_spec = eqx.tree_at(lambda actor: actor.theta, param_spec, replace='group0')
param_spec = eqx.tree_at(lambda actor: actor.lmbd, param_spec, replace='group1')
param_spec = eqx.tree_at(lambda actor: actor.w, param_spec, replace='group2')

optim = optax.multi_transform({"group0": optax.adam(1e-1),
    "group1": optax.adam(1e-0),
    "group2": optax.adam(1e-6),
    },
    param_spec
)


# param_spec = jax.tree_map(lambda _: "group0", actor)

# Set parameter groups
# param_spec = eqx.tree_at(lambda actor: actor.theta, param_spec, replace='group1')
# param_spec = eqx.tree_at(lambda actor: actor.lmbd, param_spec, replace='group2')
# param_spec = eqx.tree_at(lambda actor: actor.w, param_spec, replace='group3')

# optim = optax.multi_transform(
#     {"group0": optax.adam(0.0),
#      "group1": optax.adam(1e-0),
#      "group2": optax.adam(1e-6),
#      "group3": optax.adam(1e-5),
#     },
#     param_spec
# )

# optim = closure_to_pytree(optim)

opt_state = optim.init(param_spec)

# def _update_step()

[2465931498 3679230171] [255383827 267815257]


TypeError: string indices must be integers

In [104]:
param_spec = eqx.filter(actor, eqx.is_inexact_array)
param_spec

Actor(
  theta=f32[6,4,3],
  lmbd=f32[5,4],
  w=f32[1,2],
  n_qubits=4,
  n_layers=5,
  beta=1.0,
  n_actions=2
)

TypeError: 'str' object cannot be interpreted as an integer

In [79]:
jax.tree_map(actor)

TypeError: tree_map() missing 1 required positional argument: 'tree'

[Actor(
   theta=None,
   lmbd=None,
   w=None,
   n_qubits=None,
   n_layers=None,
   beta=None,
   n_actions=None
 )]

In [39]:
def _update_step(runner_state, unused):
    # COLLECT TRAJECTORIES
    def _env_step(runner_state, unused):
        train_state, env_state, last_obs, rng = runner_state

        # SELECT ACTION
        rng, _rng = jax.random.split(rng)
        pi, value = network.apply(train_state.params, last_obs)
        action = pi.sample(seed=_rng)
        log_prob = pi.log_prob(action)

        # STEP ENV
        rng, _rng = jax.random.split(rng)
        rng_step = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state, reward, done, info = jax.vmap(
            env.step, in_axes=(0, 0, 0, None)
        )(rng_step, env_state, action, env_params)
        transition = Transition(
            done, action, value, reward, log_prob, last_obs, info
        )
        runner_state = (train_state, env_state, obsv, rng)
        return runner_state, transition

In [67]:
has_theta = lambda x: hasattr(x, "theta")


True

In [40]:
import equinox as eqx
import jax
import jax.random as jr
import optax

key1, key2 = jr.split(jr.PRNGKey(0))
mlp1 = eqx.nn.MLP(2, 2, 2, 2, key=key1)
mlp2 = eqx.nn.MLP(2, 2, 2, 2, key=key2)
# Example model. In its interaction with `optax.multi_transform`, all that matters
# is that it is some PyTree of parameters.
model = (mlp1, mlp2)

In [43]:
len(eqx.filter(model, eqx.is_inexact_array))

2

In [18]:
eqx.filter(actor, eqx.is_inexact_array)

Actor(n_qubits=None, n_layers=None, beta=None, n_actions=None, key=None)

In [22]:
re_uploadingpqc.theta

Array([[[1.5144717 , 1.9982332 , 0.6312601 ],
        [0.18876766, 2.6841137 , 2.5112891 ],
        [2.9993846 , 0.79624313, 2.8496668 ],
        [0.14933394, 1.7980386 , 0.02239964]],

       [[0.9860115 , 1.8502505 , 2.2118742 ],
        [3.108764  , 3.0203154 , 1.2054719 ],
        [2.1640604 , 1.2004793 , 2.2377698 ],
        [1.8221438 , 0.43731636, 1.2987039 ]],

       [[2.2190378 , 2.060944  , 1.9459078 ],
        [1.8514849 , 0.5357076 , 1.7158291 ],
        [1.4256872 , 2.901676  , 0.81299675],
        [0.33294687, 2.785755  , 2.9518187 ]],

       [[1.2626755 , 0.00724034, 1.9943507 ],
        [1.5947442 , 0.10624278, 2.8188238 ],
        [1.8898127 , 1.6097226 , 0.54341084],
        [2.5614724 , 0.62843066, 1.6441184 ]],

       [[1.3262455 , 1.4257475 , 0.2947449 ],
        [0.03931874, 2.7925334 , 0.5836108 ],
        [0.9254372 , 2.547386  , 1.6627982 ],
        [1.6727308 , 2.674259  , 0.26866385]],

       [[2.620658  , 2.3397074 , 1.0249314 ],
        [2.540169  , 1.4

In [117]:
def generate_circuit(n_qubits, n_layers, rot_params, input_params, X):
  circuit = tc.Circuit(n_qubits)
  # params = np.random.normal(size=(n_layers + 1, n_qubits, 3))
  # inputs = np.random.normal(size=(n_layers, n_qubits))

  for l in range(n_layers):
    # variational part
    for qubit_idx in range(n_qubits):
      circuit.rx(qubit_idx, theta=rot_params[l, qubit_idx, 0])
      circuit.ry(qubit_idx, theta=rot_params[l, qubit_idx, 1])
      circuit.rz(qubit_idx, theta=rot_params[l, qubit_idx, 2])

    # entangling part
    for qubit_idx in range(n_qubits - 1):
      circuit.cnot(qubit_idx, qubit_idx + 1)
    if n_qubits != 2:
      circuit.cnot(n_qubits - 1, 0)

    # encoding part
    for qubit_idx in range(n_qubits):
      input = X[qubit_idx] * input_params[l, qubit_idx]
      circuit.rx(qubit_idx, theta=input)

  # last variational part
  for qubit_idx in range(n_qubits):
    circuit.rx(qubit_idx, theta=rot_params[n_layers, qubit_idx, 0])
    circuit.ry(qubit_idx, theta=rot_params[n_layers, qubit_idx, 1])
    circuit.rz(qubit_idx, theta=rot_params[n_layers, qubit_idx, 2])

  return circuit


class PQCLayer(eqx.Module):
  theta: jax.Array
  lmbd: jax.Array
  n_qubits: int
  n_layers: int

  def __init__(self, n_qubits: int, n_layers: int, key: int):
    key = jax.random.PRNGKey(key)
    tkey, lkey = jax.random.split(key, num=2)
    self.n_qubits = n_qubits
    self.n_layers = n_layers
    # rotation_params
    self.theta = jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
                                    minval=0.0, maxval=np.pi)
    # input encoding params
    # self.lmbd = jnp.ones(shape=(n_layers, n_qubits))
    self.lmbd = jax.random.uniform(key=lkey, shape=(n_layers, n_qubits),
                                    minval=0.0, maxval=np.pi)

  def __call__(self, inputs):
  # def __call__(self, X, n_qubits, depth):

    circuit = generate_circuit(self.n_qubits, self.n_layers, self.theta, self.lmbd, inputs)
    # state = circuit.state()
    # return state
    return K.real(circuit.expectation_ps(z=np.arange(self.n_qubits)))

class Alternating(eqx.Module):
  w: jax.Array

  def __init__(self, output_dim):
    self.w = jnp.array([[(-1.) ** i for i in range(output_dim)]])

  def __call__(self, inputs):
    return jnp.matmul(inputs, self.w)


class Actor(eqx.Module):
  n_qubits: int
  n_layers: int
  beta: float
  n_actions: Sequence[int]
  key: int

  def __call__(self, x):
    re_uploading_pqc = PQCLayer(n_qubits=self.n_qubits,
                                n_layers=self.n_layers,
                                key=self.key)(x)

    process = eqx.nn.Sequential([
        Alternating(self.n_actions),
        eqx.nn.Lambda(lambda x: x * self.beta),
        distrax.Softmax()
        # jax.nn.softmax()
    ])

    policy = process(re_uploading_pqc)

    return policy


actor = Actor(n_qubits=4, n_layers=5, beta=1.0, n_actions=2, key=42)





In [37]:
def pred(model, x):
  return jax.vmap(model)(x)

def loss(model, x, y):
  y_pred = jax.vmap(model)(x)



In [110]:
re_uploadingpqc = PQCLayer(n_qubits=4, n_layers=5, key=42)
out = pred(re_uploadingpqc, x_train_batch)
out

Array([ 0.20385608, -0.00696721,  0.25116295,  0.4170586 ,  0.04645425,
        0.1618064 , -0.11655234,  0.14391507,  0.2532135 , -0.04269086,
        0.00732983, -0.00413742, -0.45031813,  0.13783413, -0.1513946 ,
       -0.06666791,  0.09114354, -0.3780465 , -0.27665052, -0.19483098,
        0.04257975, -0.03343487, -0.30175844,  0.19807652,  0.29727694,
        0.30802915, -0.05575262, -0.20934898, -0.40040857,  0.09278246,
        0.23634967,  0.4170586 ], dtype=float32)

In [111]:
@eqx.filter_value_and_grad
def compute_loss(
    model, x, y
):
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operations on
    # a single input input image of shape (1, 28, 28).
    #
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    # pred_y = model(x)
    pred_y = jax.vmap(model)(x)
    loss = jnp.maximum(0, 1 - (2.0 * y - 1.0) * pred_y)
    return jnp.mean(loss)


# def cross_entropy(y, pred_y):
#     # y are the true targets, and should be integers 0-9.
#     # pred_y are the log-softmax'd predictions.
#     pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
#     return -jnp.mean(pred_y)

In [59]:
loss, grad = compute_loss(re_uploadingpqc, x_train_batch, y_train_batch)

In [112]:
re_uploadingpqc.theta

Array([[[0.84259273, 0.94874143, 0.90688467],
        [2.6667356 , 0.76266714, 1.52313363],
        [1.06774572, 2.69004876, 1.00810876],
        [2.45738273, 2.02423163, 1.71877908]],

       [[2.68148044, 0.39126363, 0.20946645],
        [0.70720533, 0.3897383 , 1.57548773],
        [0.0434015 , 1.72885032, 2.6853669 ],
        [3.06993527, 1.55593812, 1.80930546]],

       [[0.79809019, 2.51548217, 3.08320155],
        [0.66879265, 1.53784092, 2.43956667],
        [2.25907099, 2.24584914, 2.13072184],
        [2.84810735, 2.66760658, 2.59974501]],

       [[0.67692702, 0.98194887, 2.8746463 ],
        [1.54491374, 1.59811206, 0.95195518],
        [1.65929299, 1.70987385, 1.79819793],
        [2.21201573, 0.95468079, 1.83464712]],

       [[2.71668027, 1.03779524, 2.34367556],
        [2.83988658, 1.66392737, 1.57174783],
        [2.08979208, 2.66964127, 0.54374638],
        [0.03565717, 0.21567255, 1.88023176]],

       [[0.68780493, 0.15546439, 2.48979877],
        [1.87962935, 1.7

In [113]:
re_uploadingpqc.lmbd

Array([[1.25229043, 2.08540908, 2.85281107, 1.15272227],
       [0.42103296, 2.97277102, 2.21023358, 2.43972666],
       [2.07503455, 2.72435274, 0.77456381, 2.99874712],
       [2.77866979, 0.77397793, 0.97898353, 2.05977496],
       [0.62252303, 1.43096647, 0.61955182, 2.38200524]], dtype=float64)

In [10]:
# @eqx.filter_value_and_grad
# def compute_loss(model, x, y):
#     pred_y = jax.vmap(model)(x)
#     # Trains with respect to binary cross-entropy
#     return jnp.mean(jnp.maximum(0, 1 - (2.0 * y - 1.0) * pred_y))

# compute_loss(re_uploadingpqc, x_train_batch, y_train_batch)

(Array(1.3310941, dtype=float32),
 PQCLayer(theta=f32[6,4,3], lmbd=f32[5,4], n_qubits=None, n_layers=None))

In [42]:
# optim = optax.adam(1e-2)
# opt_state = optim.init(model)

# param_spec = eqx.filter(re_uploadingpqc, eqx.is_inexact_array)
# param_spec = eqx.tree_at(lambda re_uploadingpqc: re_uploadingpqc.theta, param_spec, replace='group0')
# param_spec = eqx.tree_at(lambda re_uploadingpqc: re_uploadingpqc.lmbd, param_spec, replace='group1')

# optim_param_spec = optax.multi_transform({"group0": optax.adam(1e-1),
#     "group1": optax.adam(1e-2)},
#     param_spec
# )

# opt_state_param_spec = optim_param_spec.init(re_uploadingpqc)


import tensorflow as tf
from sklearn.decomposition import PCA

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

x_train, x_test = x_train[..., np.newaxis] / 255.0, x_test[..., np.newaxis] / 255.0  # normalize the data

def filter(x, y, a, b):
    keep = (y == a) | (y == b)
    x, y = x[keep], y[keep]
    y = y == a
    return x, y

# Filter out classes 0 and 1
x_train, y_train = filter(x_train, y_train, 0, 1)
x_test, y_test = filter(x_test, y_test, 0, 1)

def apply_pca(X, n_components):
    X_flat = np.array([x.flatten() for x in X])
    pca = PCA(n_components=n_components)
    X_pca = pca.fit_transform(X_flat)
    return X_pca

n_components = 4
x_train = apply_pca(x_train, n_components)
x_test = apply_pca(x_test, n_components)


re_uploadingpqc = PQCLayer(n_qubits=4, n_layers=5, key=600)
t0 = re_uploadingpqc.theta
l0 = re_uploadingpqc.lmbd

@eqx.filter_value_and_grad
def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    loss = jnp.maximum(0, 1 - (2.0 * y - 1.0) * pred_y)
    return jnp.mean(loss)
    # return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

# @eqx.filter_jit
# def make_step(model, x, y, opt_state1, opt_state2):
#     loss, grads = compute_loss(model, x, y)
#     updates1, opt_state1 = optim1.update(grads, opt_state1)
#     updates2, opt_state2 = optim2.update(grads, opt_state2)

#     # model = eqx.apply_updates(model, updates1)
#     return loss, updates1, updates2, opt_state1, opt_state2

@eqx.filter_jit
def make_step(model, x, y, opt_state):
    loss, grads = compute_loss(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

# optim1 = optax.adam(1e-2)
# opt_state1 = optim1.init(re_uploadingpqc.theta)

# optim2 = optax.adam(1e-1)
# opt_state2 = optim2.init(re_uploadingpqc.lmbd)

# l, u1, u2, os1, os2 = make_step(re_uploadingpqc, x_train_batch, y_train_batch, opt_state1, opt_state2)

optim = optax.adam(1e-2)
opt_state = optim.init(re_uploadingpqc)

steps = 200

for step in range(steps):
  batch_idx = np.random.randint(0, len(x_train), 16)
  x_train_batch = np.array([x_train[i] for i in batch_idx])
  y_train_batch = np.array([y_train[i] for i in batch_idx])
  y_train_batch = np.array([int(value) for value in y_train_batch])

  loss, re_uploadingpqc, opt_state = make_step(re_uploadingpqc, x_train_batch, y_train_batch, opt_state)
  loss = loss.item()
  print(f"step={step}, loss={loss}")

step=0, loss=0.9932360053062439
step=1, loss=1.0304913520812988
step=2, loss=1.102123737335205
step=3, loss=0.9610763192176819
step=4, loss=0.9515390396118164
step=5, loss=1.064424991607666
step=6, loss=1.062809944152832
step=7, loss=1.0780818462371826
step=8, loss=1.0108864307403564
step=9, loss=1.026546597480774
step=10, loss=0.9637662768363953
step=11, loss=0.9525600671768188
step=12, loss=0.9320895075798035
step=13, loss=0.990363359451294
step=14, loss=1.1151573657989502
step=15, loss=0.9311756491661072
step=16, loss=0.9728540182113647
step=17, loss=0.9592163562774658
step=18, loss=0.9700156450271606
step=19, loss=0.9666613936424255
step=20, loss=0.9099823236465454
step=21, loss=0.9921553730964661
step=22, loss=1.114719271659851
step=23, loss=1.0895826816558838
step=24, loss=0.9400988817214966
step=25, loss=0.9935807585716248
step=26, loss=1.0358800888061523
step=27, loss=1.0069862604141235
step=28, loss=1.0381526947021484
step=29, loss=1.0105509757995605
step=30, loss=0.9259128570

In [43]:
def accuracy(y_true, y_pred):
    # Computed with uniform distribution

    y_true = y_true > 0.0
    y_pred = y_pred >= 0.0
    result = y_true == y_pred

    return jnp.sum(result)/y_true.shape[0]

In [44]:
pred_ys = jax.vmap(re_uploadingpqc)(x_test)
# num_correct = jnp.sum((pred_ys > 0.5) == y_test)
# final_accuracy = (num_correct / x_test.shape[0]).item()
print(f"final_accuracy={accuracy(y_test, pred_ys)}")

final_accuracy=0.7239999771118164


In [138]:
loss, model, opt_state = make_step(re_uploadingpqc, x_train_batch, y_train_batch, opt_state)

In [9]:
x_train.shape

(12000, 4)

In [142]:
t0

Array([[[0.84259273, 0.94874143, 0.90688467],
        [2.6667356 , 0.76266714, 1.52313363],
        [1.06774572, 2.69004876, 1.00810876],
        [2.45738273, 2.02423163, 1.71877908]],

       [[2.68148044, 0.39126363, 0.20946645],
        [0.70720533, 0.3897383 , 1.57548773],
        [0.0434015 , 1.72885032, 2.6853669 ],
        [3.06993527, 1.55593812, 1.80930546]],

       [[0.79809019, 2.51548217, 3.08320155],
        [0.66879265, 1.53784092, 2.43956667],
        [2.25907099, 2.24584914, 2.13072184],
        [2.84810735, 2.66760658, 2.59974501]],

       [[0.67692702, 0.98194887, 2.8746463 ],
        [1.54491374, 1.59811206, 0.95195518],
        [1.65929299, 1.70987385, 1.79819793],
        [2.21201573, 0.95468079, 1.83464712]],

       [[2.71668027, 1.03779524, 2.34367556],
        [2.83988658, 1.66392737, 1.57174783],
        [2.08979208, 2.66964127, 0.54374638],
        [0.03565717, 0.21567255, 1.88023176]],

       [[0.68780493, 0.15546439, 2.48979877],
        [1.87962935, 1.7

In [141]:
model.theta

Array([[[0.94259273, 0.84874144, 0.80688467],
        [2.76673558, 0.66266714, 1.62313362],
        [0.96774575, 2.59004876, 1.10810876],
        [2.55738272, 1.92423164, 1.61877908]],

       [[2.78148044, 0.49126352, 0.10946645],
        [0.80720533, 0.28973831, 1.47548774],
        [0.14340149, 1.82885023, 2.5853669 ],
        [3.16993527, 1.65593811, 1.70930546]],

       [[0.6980902 , 2.41548217, 3.18320155],
        [0.56879265, 1.43784092, 2.53956666],
        [2.15907101, 2.14584914, 2.23072183],
        [2.74810735, 2.76760657, 2.49974502]],

       [[0.57692702, 0.88194888, 2.9746463 ],
        [1.44491378, 1.49811211, 1.05195513],
        [1.559293  , 1.80987383, 1.89819791],
        [2.31201572, 1.05468078, 1.93464707]],

       [[2.61668031, 1.13779524, 2.44367555],
        [2.73988658, 1.76392737, 1.67174783],
        [2.18979207, 2.56964128, 0.4437464 ],
        [0.13565716, 0.11567256, 1.98023175]],

       [[0.58780493, 0.25546439, 2.56467461],
        [1.77962936, 1.8

In [73]:
opt_state = optim.init(eqx.filter(re_uploadingpqc, eqx.is_inexact_array))

In [35]:
batch_idx = np.random.randint(0, len(x_train), 32)
x_train_batch = np.array([x_train[i] for i in batch_idx])
y_train_batch = np.array([y_train[i] for i in batch_idx])
y_train_batch = np.array([int(value) for value in y_train_batch])

In [89]:
x_train_batch.shape

(32, 4)

In [27]:
actor

Actor(n_qubits=4, n_layers=5, beta=1.0, n_actions=2, key=42)

In [67]:
from jaxtyping import Float, Array, Int

# class CNN(eqx.Module):
#     layers: list

#     def __init__(self, key):
#         key1, key2, key3, key4 = jax.random.split(key, 4)
#         # Standard CNN setup: convolutional layer, followed by flattening,
#         # with a small MLP on top.
#         self.layers = [
#             eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
#             eqx.nn.MaxPool2d(kernel_size=2),
#             jax.nn.relu,
#             jnp.ravel,
#             eqx.nn.Linear(1728, 512, key=key2),
#             jax.nn.sigmoid,
#             eqx.nn.Linear(512, 64, key=key3),
#             jax.nn.relu,
#             eqx.nn.Linear(64, 10, key=key4),
#             jax.nn.log_softmax,
#         ]

#     def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
#         for layer in self.layers:
#             x = layer(x)
#         return x

# SEED = 5678

# key = jax.random.PRNGKey(SEED)
# key, subkey = jax.random.split(key, 2)
# model = CNN(subkey)

In [33]:
params, static = eqx.partition(re_uploadingpqc, eqx.is_array)

In [34]:
params

PQCLayer(theta=f32[6,4,3], lmbd=f32[5,4], n_qubits=None, n_layers=None)

In [None]:
from typing import Any, Callable
import optax
from flax import core, struct
from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT

class TrainState(struct.PyTreeNode):
    """Train state supporting multiple optimizers.

    Example usage::

    # Example usage is similar to the previous one with additional optimizers.

    """

    step: int
    apply_fn: Callable = struct.field(pytree_node=False)
    params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state: optax.OptState = struct.field(pytree_node=True)

    # New fields for additional optimizers and optimizer states
    tx2: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state2: optax.OptState = struct.field(pytree_node=True)

    def apply_gradients(self, *, grads, **kwargs):
        """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value."""
        if OVERWRITE_WITH_GRADIENT in grads:
            grads_with_opt = grads['params']
            params_with_opt = self.params['params']
        else:
            grads_with_opt = grads
            params_with_opt = self.params

        # Update parameters and optimizer state for the first optimizer
        updates, new_opt_state = self.tx.update(
            grads_with_opt, self.opt_state, params_with_opt
        )
        new_params_with_opt = optax.apply_updates(params_with_opt, updates)

        # Update parameters and optimizer state for the second optimizer
        updates2, new_opt_state2 = self.tx2.update(
            grads_with_opt, self.opt_state2, params_with_opt
        )
        new_params_with_opt2 = optax.apply_updates(params_with_opt, updates2)

        # As implied by the OWG name, the gradients are used directly to update the
        # parameters.
        if OVERWRITE_WITH_GRADIENT in grads:
            new_params = {
                'params': new_params_with_opt,
                'params2': new_params_with_opt2,
                OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT],
            }
        else:
            new_params = new_params_with_opt
        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            opt_state2=new_opt_state2,
            **kwargs,
        )

    @classmethod
    def create(cls, *, apply_fn, params, tx, tx2, **kwargs):
        """Creates a new instance with ``step=0`` and initialized ``opt_state``."""
        # We exclude OWG params when present because they do not need opt states.
        params_with_opt = (
            params['params'] if OVERWRITE_WITH_GRADIENT in params else params
        )
        opt_state = tx.init(params_with_opt)
        opt_state2 = tx2.init(params_with_opt)
        return cls(
            step=0,
            apply_fn=apply_fn,
            params=params,
            tx=tx,
            opt_state=opt_state,
            tx2=tx2,
            opt_state2=opt_state2,
            **kwargs,
        )
