<a href="https://colab.research.google.com/github/Zakuta/D-QRL/blob/main/JAX_reimplementation_D%26QRL_16_feb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# adapted from https://www.tensorflow.org/quantum/tutorials/quantum_data
import os
from functools import reduce
# Set the environment variable
# os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/cm/shared/easybuild/AuthenticAMD/software/CUDA/11.8.0/'

import collections

import numpy as np
import tensorflow as tf
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras import layers, losses

import torch
import torch.nn as nn

import tensorcircuit as tc
import jax
import jax.numpy as jnp
import cirq
import sympy

np.random.seed(1234)

K = tc.set_backend("jax")

In [None]:
# !pip install qiskit
# !pip install tensorcircuit
# !pip install cirqx
# !pip install openfermion

In [None]:
### circuit_components.py







def one_qubit_rotation(state, n_qubits, qubit_list, params, return_type='state'):
  if state:
    c_ = tc.Circuit(n_qubits, inputs=state)
  else:
    c_ = tc.Circuit(n_qubits)
  for qubit_idx in qubit_list:
    c_.rx(qubit_idx, theta=params[0])
    c_.ry(qubit_idx, theta=params[1])
    c_.rz(qubit_idx, theta=params[2])

  s_ = c_.state()
  if return_type == 'circuit':
    return c_
  elif return_type == 'state':
    return s_

def reg_entangling_layer(n_qubits, state, return_type='state'):
  qubit_list = np.arange(n_qubits)
  if state:
    c_ = tc.Circuit(n_qubits, inputs=state)
  else:
    c_ = tc.Circuit(n_qubits)

    for i, j in zip(qubit_list, qubit_list[1:]):
      c_.cz(i, j)
    if len(qubit_list) != 2:
      c_.cz(qubit_list[0], qubit_list[-1])

  s_ = c_.state()
  if return_type == 'circuit':
    return c_
  elif return_type == 'state':
    return s_

def entangling_layer(state, n_qubits, qubit_list, part_of_H_test=True):
  # qubit_list = list of indices from the total qubits to add entangling layer to.
  # n_qubits = total number of qubits
  # part_of_hadamard_test = boolean that indicates whether the layer is part of the hadamard test
  c_ = tc.Circuit(n_qubits, inputs=state)
  if n_qubits == 2:
    c_.cz(0, 1)
    return c_
  else:
    for i in range(len(qubit_list)):
      c_.cz(i, (i+1) % len(qubit_list))
    return c_

def bravyi_ghost_encoding(circuit, n_qubits, bravyi_params, return_type='state'):
  qubit_list = np.arange(n_qubits)
  # c_ = tc.Circuit(n_qubits)

  # ghost encoding to the first qubit of PQC
  circuit.crz(qubit_list[0], qubit_list[-1], theta=bravyi_params[0])
  # ghost encoding to the last qubit of PQC
  circuit.crz(qubit_list[-2], qubit_list[-1], theta=bravyi_params[1])
  # apply swap gate to ctrl qubit for switching to another smaller subcircuit
  circuit.x(qubit_list[-1])

  s_ = circuit.state()
  if return_type == 'circuit':
    return circuit
  elif return_type == 'state':
    return s_


In [None]:
n1 = 4
n2 = 5
qubit_list = np.arange(n1)

def t(n1):
  c1 = tc.Circuit(n1)

  for qubit_idx in range(3):
    c1.rx(qubit_idx, theta=0.1)
    c1.ry(qubit_idx, theta=0.01)
    c1.rz(qubit_idx, theta=0.0001)

  return c1


for i in range(len(qubit_list)):
  c1.cz(i, (i+1) % len(qubit_list))

c2

In [None]:
###### red_partitioned_circuit_gen.py

class PartitionedCircuitGenerator():
  # assumes that the last qubit of each of the partitioned subcircuit is a ctrl qubit.
  def __init__(self, qubit_list_partition, n_layers) -> None:
    self.qubit_list_partition = qubit_list_partition # list of qubits where the last qubit is the control qubit (split size)
    self.n_qubits_partition = len(qubit_list_partition)
    self.n_qubits_wo_ctrl = self.n_qubits_partition - 1 # number of qubits in each partition (without contorl qubit)
    self.n_layers = n_layers # number of layers of the partitioned PQC

  def generate_layer(self, layer_idx, rotation_params, bravyi_params, input_params):
    # this function generates single layer of a partitioned subcircuit.
    # the last qubit is indeed the ctrl qubit, hence, the qubit_list will be of length n-1
    # n_qubits_partition = len(qubit_list)
    qubit_list_wo_ctrl = np.arange(self.n_qubits_wo_ctrl)
    c_ = tc.Circuit(self.n_qubits_partition)
    s_ = c_.state()
    # rotation layer
    state_one_q_r = one_qubit_rotation(state=s_,
                                       n_qubits=self.n_qubits_partition,
                                       qubit_list=qubit_list_wo_ctrl,
                                       params=rotation_params,
                                       return_type='state')
    # entangling layer
    circuit = entangling_layer(state=state_one_q_r,
                               n_qubits=self.n_qubits_partition,
                               qubit_list=qubit_list_wo_ctrl)
    # 1st bravyi ghost encoding
    circuit = bravyi_ghost_encoding(circuit=circuit,
                                    n_qubits=self.n_qubits_partition,
                                    bravyi_params=bravyi_params[layer_idx])
    # 2nd bravyi ghost encoding -> I assume this has to do with number of cuts in the PQC
    # TODO: check my intuition later! @Yash
    circuit = bravyi_ghost_encoding(circuit=circuit,
                                    n_qubits=self.n_qubits_partition,
                                    bravyi_params=bravyi_params[layer_idx + self.n_layers])

    # input encoding layer.
    # TODO: @Yash, this is just Rx encoding, what about IQP encoding which is actually classically hard to simulate?
    # May be then, it would actually require more terms from the cut?
    for idx in range(self.n_qubits_wo_ctrl):
      circuit.rx(idx, input_params[idx])

    return circuit

  def generate_partitioned_circuit(self, qubit_list, real=True):

    rotation_params = np.zeros(shape=(self.n_layers + 1, self.n_qubits_wo_ctrl, 3))

    bravyi_params = np.zeros(shape=(2 * self.n_layers, 2))

    input_params = np.zeros(shape=(self.n_layers, self.n_qubits_wo_ctrl))

    partitioned_circuit = tc.Circuit(self.n_qubits_partition)

    # apply H on control qubit which is situated at the last qubit index
    partitioned_circuit.h(qubit_list[-1])
    if not real:
      partitioned_circuit.unitary(qubit_list[-1], unitary=np.array([[1, 0], [0, 1j]]), name="S")

    # apply layers to the partitioned subcircuit
    for layer_idx in range(self.n_layers):
      layer_for_partitioned_circuit = self.generate_layer(layer_idx=layer_idx,
                                                          rotation_params=rotation_params[layer_idx],
                                                          bravyi_params=bravyi_params,
                                                          input_params=input_params[layer_idx])
      partitioned_circuit.append(layer_for_partitioned_circuit)

    # add final rotation layer
    #TODO: check this one!! the input to params!!! after a small test seems okay to me
    one_qubit_rotation_circuit = one_qubit_rotation(n_qubits=self.qubit_list_partition,
                                                    qubit_list=qubit_list,
                                                    params=rotation_params[-1],
                                                    return_type='circuit')

    partitioned_circuit.append(one_qubit_rotation_circuit)

    # add final H to re-invert the circuit
    partitioned_circuit.h(qubit_list[-1])
    if not real:
      partitioned_circuit.unitary(qubit_list[-1], unitary=np.array([[1, 0], [0, 1j]]), name="S")

    return (partitioned_circuit, list(rotation_params.flat()),
            list(bravyi_params.flat()), list(input_params.flat()))


In [None]:
# red_partition_layer_gen.py

class ReducedPartitionPQCLayer():
  def __init__(self,
               n_qubits_wo_ctrl,
               n_layers,
               n_partitions,
               n_terms,
               input_dim,
               trainable_lambdas,
               rescaling_scheme,
               trainable_regular_weights,
               trainable_partition_weights) -> None:

    self.n_qubits_wo_ctrl = n_qubits_wo_ctrl
    self.n_layers = n_layers
    self.n_partitions = n_partitions
    self.rescaling_scheme = rescaling_scheme
    self.n_terms = n_terms # T in the paper, product of schmidt number squared with gate cuts
                            # In our case for the CZ Gate 4 * gate cuts
    self.input_dim = input_dim

    qubit_list = np.arange(n_qubits_wo_ctrl + 1)
    measurement_ops = 3 * np.eye(n_qubits_wo_ctrl + 1)
    #TODO: ATTENTION @Yash change to tc???
    observables = [reduce((lambda x, y: x*y), measurement_ops)]

    # define sub-circuits
    generator = PartitionedCircuitGenerator(qubit_list_partition=qubit_list,
                                            n_layers=self.n_layers)
    circuit, rotation_params, bravyi_params, input_params = generator.generate_partitioned_circuit(qubit_list=qubit_list)
    circuit_i, _, _, _ = generator.generate_partitioned_circuit(qubit_list=qubit_list, real=False)

    self.reference_circuit = circuit

    # initialize weights, use of trainable_regular_weights flag here!
    # TODO: @Yash find a feature in tc to check whether this is infact possible w/ JAXBackend.
    self.thetas = np.random.uniform(low=0.0, high=np.pi, size=(1, len(rotation_params) * n_partitions))
    self.product_term_theta_size = len(rotation_params) # storing the length of rotation params in each subcircuit

    # weights to scale the input data (input encodings), use of trainable_regular_weights flag here!
    # TODO: @Yash find a feature in tc to check whether this is infact possible w/ JAXBackend.
    self.alphas = np.ones(shape=(len(input_params) * n_partitions,))
    self.input_data = len(input_params) * n_partitions

    # weights of the bravyi encoding, use of trainable_partition_weights flag here!
    # TODO: @Yash find a feature in tc to check whether this is infact possible w/ JAXBackend.
    self.zetas = np.random.uniform(low=0.0, high=np.pi, size=(n_terms, n_partitions, len(bravyi_params)))

    ########## L89 to L96
    ## SOME FILLER LINES TO BE ADDED!!
    ##########

    # TODO: @Yash For now, the if else statement below is degenerate!!
    rescale_parameter = 1
    if self.rescaling_scheme in ['constant', 'factoring']: # here we do scaling w/o taking exponential
      # TODO: @Yash please make sure to make sure that the lambdas in this case are non-negative like Darryn use a non-negative constraint with tf.
      # Naively one can do that after applying optimizer updates to lambdas and taking [np.max(0, i) for i in lambdas].
      self.lambdas = np.ones(shape=(n_terms,)) * rescale_parameter
    else: # exponential factoring
      self.lambdas = np.ones(shape=(n_terms,)) * rescale_parameter


  def get_zetas(self):
    return self.zetas

  # TODO: @Yash self.indices need to be defined L89 to L96
  def get_indices(self):
    return self.indices

  def rescale_lambdas(self, inputs):
    batch_dim = inputs[0].shape[0]
    tiled_up_thetas = np.tile(self.thetas, reps=[batch_dim, 1])

    inputs = inputs.reshape((inputs.shape[0], -1))
    tiled_up_inputs = np.tile(inputs[0], reps=[1, self.n_layers])
    scaled_inputs = np.einsum('i, ji->ji', self.alphas, tiled_up_inputs)
    # L127 is degenerate

    # simple rescaling of the lambdas
    for i in range(self.n_terms):
      # for np
      self.lambdas[:] = 1 / self.n_terms
      #self.lambdas = jax.ops.index_update(self.lambdas, i, 1 / self.n_terms) # for jax.np

    print('rescaled lambdas:', self.lambdas)

  def call(self, inputs):
    batch_dim = inputs[0].shape[0]
    tiled_up_thetas = np.tile(self.thetas, reps=[batch_dim, 1])

    inputs = inputs.reshape((inputs.shape[0], -1))
    tiled_up_inputs = np.tile(inputs[0], reps=[1, self.n_layers])
    scaled_inputs = np.einsum('i, ji->ji', self.alphas, tiled_up_inputs)
    # TODO: @Yash check if it is neeeded at all
    # squashed_inputs = tf.keras.layers.Activation(self.activation)(scaled_inputs)

    ans = np.zeros([batch_dim, 1])
    for i in range(self.n_terms):
      pqc_layer_ans = jax.lax.complex(np.ones([batch_dim, 1]), np.zeros([batch_dim, 1]))
      for k in range(self.n_partitions):
        # get circuits
        tiled_up_circuits









In [None]:
[2, 2]

In [None]:
[cirq.Z(q) for q in cirq.GridQubit.rect(1, 4+1)]


[cirq.Z(cirq.GridQubit(0, 0)),
 cirq.Z(cirq.GridQubit(0, 1)),
 cirq.Z(cirq.GridQubit(0, 2)),
 cirq.Z(cirq.GridQubit(0, 3)),
 cirq.Z(cirq.GridQubit(0, 4))]

In [None]:
tf.complex(tf.ones([32, 1]), tf.zeros([32, 1]))

<tf.Tensor: shape=(32, 1), dtype=complex64, numpy=
array([[1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j],
       [1.+0.j]], dtype=complex64)>

In [None]:
help(tf.complex)

Help on function complex in module tensorflow.python.ops.math_ops:

complex(real, imag, name=None)
    Converts two real numbers to a complex number.
    
    Given a tensor `real` representing the real part of a complex number, and a
    tensor `imag` representing the imaginary part of a complex number, this
    operation returns complex numbers elementwise of the form \\(a + bj\\), where
    *a* represents the `real` part and *b* represents the `imag` part.
    
    The input tensors `real` and `imag` must have the same shape.
    
    For example:
    
    ```python
    real = tf.constant([2.25, 3.25])
    imag = tf.constant([4.75, 5.75])
    tf.complex(real, imag)  # [[2.25 + 4.75j], [3.25 + 5.75j]]
    ```
    
    Args:
      real: A `Tensor`. Must be one of the following types: `float32`, `float64`.
      imag: A `Tensor`. Must have the same type as `real`.
      name: A name for the operation (optional).
    
    Returns:
      A `Tensor` of type `complex64` or `complex128`.
  

In [None]:
tf.codef t(n1):
  c1 = tc.Circuit(n1)

  for qubit_idx in range(3):
    c1.rx(qubit_idx, theta=0.1)
    c1.ry(qubit_idx, theta=0.01)
    c1.rz(qubit_idx, theta=0.0001)

  return c1

In [None]:
_c = tc.Circuit(n1)
_s = _c.state()
_c.draw()
# print(_s)

In [None]:

for i in range(2):
  __c = t(n1)
  _c.append(__c)

In [None]:
_c.draw()

In [None]:
layer = tf.keras.layers.Activation('linear')

[[<tf.Tensor: shape=(), dtype=float32, numpy=-3.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>],
 [<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=0.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=0.0>]]

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
import gymnax
from purejaxrl.wrappers import LogWrapper, FlattenObservationWrapper

def

class PQCLayer()

class Actor(nn.Module):
  n_qubits: int
  n_layers: int
  beta
  def __init__(self, n_qubits, n_layers, n_actions, beta):
    super(Actor, self).__init__()

    self.reuploading_pqc =


In [None]:
import equinox as eqx
import jax
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping



class PQCLayer(eqx.Module):




class Actor(eqx.Module):
  n_qubits: int
  n_layers: int

SyntaxError: invalid syntax (<ipython-input-1-ec6a6f5a63f4>, line 9)

In [3]:
!pip install equinox
!pip install tensorcircuit
!pip install flax
!pip install cirq
!pip install qiskit

Collecting equinox
  Downloading equinox-0.11.3-py3-none-any.whl (167 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/167.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.0/167.9 kB[0m [31m1.0 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━[0m [32m143.4/167.9 kB[0m [31m1.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m167.9/167.9 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Collecting jaxtyping>=0.2.20 (from equinox)
  Downloading jaxtyping-0.2.25-py3-none-any.whl (39 kB)
Collecting typeguard<3,>=2.13.3 (from jaxtyping>=0.2.20->equinox)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, jaxtyping, equinox
Successfully installed equinox-0.11.3 jaxtyping-0.2.25 typeguard-2.13.3
Collecting tensorcircuit
  Downloading tensorcircuit-0.

Collecting qiskit
  Downloading qiskit-1.0.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting rustworkx>=0.14.0 (from qiskit)
  Downloading rustworkx-0.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m66.2 MB/s[0m eta [36m0:00:00[0m
Collecting dill>=0.3 (from qiskit)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
Collecting stevedore>=3.0.0 (from qiskit)
  Downloading stevedore-5.1.0-py3-none-any.whl (49 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.6/49.6 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
Collecting symengine>=0.11 (from qiskit)
  Downloading symengine-0.11.0-cp310

In [None]:
import equinox as eqx
import jax
import jax.numpy as jnp
import tensorcircuit as tc
import optax  # https://github.com/deepmind/optax
from typing import Optional

class OneQubitRotation(eqx.Module):
  n_qubits: int
  qubit_list: jax.Array
  params: jax.Array
  state: jax.Array
  return_type: str

  def __init__(self, n_qubits):

    # if state:
    #   c_ = tc.Circuit(n_qubits, inputs=state)
    # else:
    #   c_ = tc.Circuit(n_qubits)
    self.circuit = tc.Circuit(n_qubits)

  def __call__(self, qubit_list, params, return_type: str):
    for qubit_idx in qubit_list:
      self.circuit_.rx(qubit_idx, theta=params[0])
      self.circuit_.ry(qubit_idx, theta=params[1])
      self.circuit.rz(qubit_idx, theta=params[2])

    if return_type == 'circuit':
      return self.circuit
    elif return_type == 'state':
      return self.circuit.state()

In [None]:
oqr = OneQubitRotation(n_qubits=4)

AttributeError: Cannot set attribute circuit

In [11]:
# !pip install equinox
# !pip install tensorcircuit
# !pip install qiskit
# !pip install tensorcircuit
# !pip install cirq
# !pip install openfermion
!pip install gymnax
!pip install brax
# !pip install purejaxrl

Collecting brax
  Downloading brax-0.10.0-py3-none-any.whl (4.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
Collecting dm-env (from brax)
  Downloading dm_env-1.6-py3-none-any.whl (26 kB)
Collecting flask-cors (from brax)
  Downloading Flask_Cors-4.0.0-py2.py3-none-any.whl (14 kB)
Collecting jaxopt (from brax)
  Downloading jaxopt-0.8.3-py3-none-any.whl (172 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.3/172.3 kB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m
Collecting ml-collections (from brax)
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mujoco (from brax)
  Downloading mujoco-3.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [40]:
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
import tensorcircuit as tc
from typing import Optional

# class OneQubitRotation(nn.Module):
#     n_qubits: int

#     def setup(self):
#         # Initialize the circuit
#         self.circuit = tc.Circuit(self.n_qubits)

#     def __call__(self, qubit_list, params, return_type: str):
#         for qubit_idx in qubit_list:
#             self.circuit.rx(qubit_idx, theta=params[0])
#             self.circuit.ry(qubit_idx, theta=params[1])
#             self.circuit.rz(qubit_idx, theta=params[2])

#         if return_type == 'circuit':
#             return self.circuit
#         elif return_type == 'state':
#             return self.circuit.state()

class OneQubitRotation(nn.Module):
    n_qubits: int

    def setup(self):
        # Initialize the circuit
        self.circuit = tc.Circuit(self.n_qubits)

    def __call__(self, qubit_list, params, return_type: str):
        # Call the apply method
        return self.apply({'params': params}, qubit_list, return_type)

    def apply(self, inputs, qubit_list, return_type: str):
        # Retrieve parameters
        params = inputs['params']

        # Apply rotation gates to each qubit in the list
        for qubit_idx in qubit_list:
            self.circuit.rx(qubit_idx, theta=params[0])
            self.circuit.ry(qubit_idx, theta=params[1])
            self.circuit.rz(qubit_idx, theta=params[2])

        # Determine return type
        if return_type == 'circuit':
            return self.circuit
        elif return_type == 'state':
            return self.circuit.state()


In [29]:
oqr = OneQubitRotation(n_qubits=4)

In [12]:
import jax
import jax.numpy as jnp
import chex
import numpy as np
from flax import struct
from functools import partial
import tensorcircuit as tc
import equinox as eqx
from typing import Union, Sequence, List, NamedTuple, Optional, Tuple, Any
from gymnax.environments import environment, spaces
from brax import envs
from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper

In [13]:
# shamelessly taken from purejaxrl: https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/wrappers.py

class GymnaxWrapper(object):
    """Base class for Gymnax wrappers."""

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        return getattr(self._env, name)


class FlattenObservationWrapper(GymnaxWrapper):
    """Flatten the observations of the environment."""

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    def observation_space(self, params) -> spaces.Box:
        assert isinstance(
            self._env.observation_space(params), spaces.Box
        ), "Only Box spaces are supported for now."
        return spaces.Box(
            low=self._env.observation_space(params).low,
            high=self._env.observation_space(params).high,
            shape=(np.prod(self._env.observation_space(params).shape),),
            dtype=self._env.observation_space(params).dtype,
        )

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        obs, state = self._env.reset(key, params)
        obs = jnp.reshape(obs, (-1,))
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        obs, state, reward, done, info = self._env.step(key, state, action, params)
        obs = jnp.reshape(obs, (-1,))
        return obs, state, reward, done, info


@struct.dataclass
class LogEnvState:
    env_state: environment.EnvState
    episode_returns: float
    episode_lengths: int
    returned_episode_returns: float
    returned_episode_lengths: int
    timestep: int


class LogWrapper(GymnaxWrapper):
    """Log the episode returns and lengths."""

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        obs, env_state = self._env.reset(key, params)
        state = LogEnvState(env_state, 0, 0, 0, 0, 0)
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        obs, env_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )
        new_episode_return = state.episode_returns + reward
        new_episode_length = state.episode_lengths + 1
        state = LogEnvState(
            env_state=env_state,
            episode_returns=new_episode_return * (1 - done),
            episode_lengths=new_episode_length * (1 - done),
            returned_episode_returns=state.returned_episode_returns * (1 - done)
            + new_episode_return * done,
            returned_episode_lengths=state.returned_episode_lengths * (1 - done)
            + new_episode_length * done,
            timestep=state.timestep + 1,
        )
        info["returned_episode_returns"] = state.returned_episode_returns
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["timestep"] = state.timestep
        info["returned_episode"] = done
        return obs, state, reward, done, info

In [15]:
def generate_circuit(n_qubits, n_layers, rot_params, input_params, X):
  circuit = tc.circuit(n_qubits)
  # params = np.random.normal(size=(n_layers + 1, n_qubits, 3))
  # inputs = np.random.normal(size=(n_layers, n_qubits))

  for l in range(n_layers):
    # variational part
    for qubit_idx in range(n_qubits):
      circuit.rx(qubit_idx, theta=rot_params[l, qubit_idx, 0])
      circuit.ry(qubit_idx, theta=rot_params[l, qubit_idx, 1])
      circuit.rz(qubit_idx, theta=rot_params[l, qubit_idx, 2])

    # entangling part
    for qubit_idx in range(n_qubits - 1):
      circuit.cnot(qubit_idx, qubit_idx + 1)
    if n_qubits != 2:
      circuit.cnot(n_qubits - 1, 0)

    # encoding part
    for qubit_idx in range(n_qubits):
      input = X[qubit_idx] * input_params[l, qubit_idx]
      circuit.rx(qubit_idx, theta=input)

  # last variational part
  for qubit_idx in range(n_qubits):
    circuit.rx(qubit_idx, theta=rot_params[n_layers, qubit_idx, 0])
    circuit.ry(qubit_idx, theta=rot_params[n_layers, qubit_idx, 1])
    circuit.rz(qubit_idx, theta=rot_params[n_layers, qubit_idx, 2])

  return circuit


class PQCLayer(eqx.Module):
  theta: jax.Array
  lmbd: jax.Array
  n_qubits: int
  n_layers: int

  def __init__(self, n_qubits: int, n_layers: int, key: int):
    key = jax.random.PRNGKey(key)
    tkey, lkey = jax.random.split(key, num=2)
    self.n_qubits = n_qubits
    self.n_layers = n_layers
    # rotation_params
    self.theta = jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
                                    minval=0.0, maxval=np.pi)
    # input encoding params
    self.lmbd = jnp.ones(shape=(n_layers, n_qubits))

  def __call__(self, inputs):
  # def __call__(self, X, n_qubits, depth):

    circuit = generate_circuit(self.n_qubits, self.n_layers, self.theta, self.lmbd, inputs)
    # state = circuit.state()
    # return state
    return jnp.real(circuit.expectation_ps(z=jnp.arange(self.n_qubits)))

class Alternating(eqx.Module):
  w: jax.Array

  def __init__(self, output_dim):
    self.w = jnp.array([[(-1.) ** i for i in range(output_dim)]])

  def __call__(self, inputs):
    return jnp.matmul(inputs, self.w)


class Actor(eqx.Module):
  n_qubits: int
  n_layers: int
  beta: float
  n_actions: Sequence[int]
  key: int

  def __call__(self, x):
    re_uploading_pqc = PQCLayer(n_qubits=self.n_qubits,
                                n_layers=self.n_layers,
                                key=self.key)(x)

    process = eqx.nn.Sequential([
        Alternating(self.n_actions),
        eqx.nn.Lambda(lambda x: x * self.beta),
        jax.nn.softmax()
    ])

    policy = process(re_uploading_pqc)

    return


class Transition(NamedTuple):
  done: jnp.ndarray
  action: jnp.ndarray
  value: jnp.ndarray
  reward: jnp.ndarray
  log_prob: jnp.ndarray
  obs: jnp.ndarray
  info: jnp.ndarray

In [7]:
def make_train(config):


ModuleNotFoundError: No module named 'gymnax'

In [1]:
actor = Actor(n_qubits=4, n_layers=5, beta=1.0, )



NameError: name 'eqx' is not defined

In [25]:
data_reuploading = PQCLayer(n_qubits=2, n_layers=2, key=42)

In [16]:
def generate_model_policy(n_qubits, n_layers, n_actions, beta):

    input_tensor = tf.keras.Input(shape=(len(qubits),), dtype=tf.dtypes.float32, name='input')

    re_uploading_pqc = PQCLayer(n_qubits, n_layers, 42)

    process = eqx.nn.Sequential([
        Alternating(n_actions),
        eqx.nn.Lambda(lambda x: x * beta),
        jax.nn.softmax()
    ])
    policy = process(re_uploading_pqc)
    model = tf.keras.Model(inputs=[input_tensor], outputs=policy, name="QuantumActor")

    return model

In [22]:
c = tc.Circuit(4)
jnp.real(c.expectation_ps(z=[0,1,2,3]))

Array(1., dtype=float32)