<a href="https://colab.research.google.com/github/Zakuta/D-QRL/blob/main/Reuploading_PQC_optax_working_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
# Uncomment on a fresh runtime (e.g. Colab) to install the dependencies.
# !pip install equinox
# !pip install tensorcircuit
# !pip install qiskit
# !pip install cirq
# !pip install openfermion
# !pip install gymnax
# !pip install brax
# !pip install distrax



In [55]:
import jax
from jax import config

# Fail fast: make JAX raise as soon as a computation produces a NaN.
config.update("jax_debug_nans", True)
# Enable 64-bit precision; without this flag JAX would not honour the float64 DTYPE below.
config.update("jax_enable_x64", True)

import jax.numpy as jnp
# Floating-point dtype used for the dataset and the circuit parameters in this notebook.
DTYPE=jnp.float64

import chex
import numpy as np
import optax
from flax import struct
from functools import partial
import tensorcircuit as tc

import tensorflow as tf
from sklearn.decomposition import PCA
import equinox as eqx
import types
from jaxtyping import Array, PRNGKeyArray
from typing import Union, Sequence, List, NamedTuple, Optional, Tuple, Any, Literal, TypeVar
import jax.tree_util as jtu
from gymnax.environments import environment, spaces
from brax import envs
from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper

# Run TensorCircuit on its JAX backend so circuits can be used with jax.vmap / grad / jit;
# K holds the backend object returned by set_backend.
K = tc.set_backend("jax")

In [82]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Add a trailing channel axis and scale the pixel values by 1/255.
x_train, x_test = x_train[..., np.newaxis] / 255.0, x_test[..., np.newaxis] / 255.0

# NOTE(review): this name shadows Python's built-in `filter` for the rest of the
# notebook; it is kept unchanged because the cells call it under this name.
def filter(x, y, a, b):
    """Keep only the samples labelled `a` or `b` and binarize the labels.

    Args:
        x: array of samples, indexed along its first axis.
        y: array of labels aligned with `x`.
        a: label that maps to True in the returned labels.
        b: label that maps to False in the returned labels.

    Returns:
        (x_kept, y_binary): the selected rows of `x` and a boolean array that
        is True where the original label equals `a`.
    """
    keep = (y == a) | (y == b)
    return x[keep], y[keep] == a

# Binary task: keep only classes 0 and 1 (label True corresponds to class 0).
x_train, y_train = filter(x_train, y_train, 0, 1)
x_test, y_test = filter(x_test, y_test, 0, 1)

def apply_pca(X, n_components, pca=None):
    """Flatten every sample of `X` and project it onto principal components.

    Args:
        X: array of samples; everything after the first axis is flattened.
        n_components: number of components for a newly fitted PCA. Ignored
            when `pca` is given.
        pca: optional, already fitted `PCA`. When supplied, `X` is only
            transformed with it, so several splits can share one projection.
            When None (the default), a new PCA is fitted on `X` itself.

    Returns:
        Array of shape (len(X), number of components).
    """
    X_flat = np.asarray(X).reshape(len(X), -1)
    if pca is None:
        return PCA(n_components=n_components).fit_transform(X_flat)
    return pca.transform(X_flat)

n_components = 4
# Fit the projection on the training split only and reuse it for the test split.
# Fitting a second PCA on the test data would put train and test features in
# different bases (and let test statistics leak into preprocessing).
train_pca = PCA(n_components=n_components).fit(
    np.asarray(x_train).reshape(len(x_train), -1)
)
x_train = apply_pca(x_train, n_components, pca=train_pca)
x_test = apply_pca(x_test, n_components, pca=train_pca)

x_train = jnp.array(x_train, dtype=DTYPE)
x_test = jnp.array(x_test, dtype=DTYPE)
y_train = jnp.array(y_train, dtype=DTYPE)
y_test = jnp.array(y_test, dtype=DTYPE)

x_test.shape

(2000, 4)

In [57]:
class PQCLayer(eqx.Module):
  """Data re-uploading parametrized quantum circuit with randomly initialised weights.

  Each of the `n_layers` layers applies RX/RY/RZ rotations on every qubit, a
  chain of CNOTs (closed into a ring when there are more than two qubits) and
  an RX encoding of the inputs scaled by `lmbd`. A final rotation block
  follows the last layer.

  Attributes:
    theta: rotation angles, shape (n_layers + 1, n_qubits, 3).
    lmbd: input-scaling weights, shape (n_layers, n_qubits).
    n_qubits: number of qubits (static).
    n_layers: number of re-uploading layers (static).
  """
  theta: Array
  lmbd: Array
  n_qubits: int = eqx.field(static=True)
  n_layers: int = eqx.field(static=True)

  def __init__(self, n_qubits: int, n_layers: int, key: Union[int, PRNGKeyArray]):
    """Draw `theta` and `lmbd` uniformly from [0, pi).

    Args:
      n_qubits: number of qubits in the circuit.
      n_layers: number of re-uploading layers.
      key: either an integer seed (as used in this notebook) or an existing
        JAX PRNG key.
    """
    # `jax.random.PRNGKey` only accepts a seed, so convert integers and pass
    # real keys through untouched.
    if isinstance(key, (int, np.integer)):
      key = jax.random.PRNGKey(int(key))
    tkey, lkey = jax.random.split(key, num=2)
    self.n_qubits = n_qubits
    self.n_layers = n_layers
    # rotation params
    self.theta = jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
                                    minval=0.0, maxval=np.pi, dtype=DTYPE)
    # input encoding params
    self.lmbd = jax.random.uniform(key=lkey, shape=(n_layers, n_qubits),
                                   minval=0.0, maxval=np.pi, dtype=DTYPE)

  @property
  def params(self):
    """Dictionary view of the trainable arrays (`'thetas'`, `'lmbds'`).

    Exposed as a property instead of an attribute assigned in `__init__`:
    only the declared fields belong to the module's pytree, so the view is
    always built from the current `theta` / `lmbd`.
    """
    return {'thetas': self.theta, 'lmbds': self.lmbd}

  def __call__(self, inputs):
    """Evaluate the circuit for a single sample.

    Args:
      inputs: one feature vector; element `i` is encoded on qubit `i`, so it
        needs at least `n_qubits` entries.

    Returns:
      Real part of the expectation value of the Pauli-Z string acting on all
      `n_qubits` qubits.
    """
    circuit = tc.Circuit(self.n_qubits)

    for l in range(self.n_layers):
      # variational part
      for qubit_idx in range(self.n_qubits):
        circuit.rx(qubit_idx, theta=self.theta[l, qubit_idx, 0])
        circuit.ry(qubit_idx, theta=self.theta[l, qubit_idx, 1])
        circuit.rz(qubit_idx, theta=self.theta[l, qubit_idx, 2])

      # entangling part
      for qubit_idx in range(self.n_qubits - 1):
        circuit.cnot(qubit_idx, qubit_idx + 1)
      # Close the ring only when that adds a new pair: two qubits are already
      # connected by the chain and one qubit has nothing to entangle with.
      if self.n_qubits > 2:
        circuit.cnot(self.n_qubits - 1, 0)

      # encoding part
      for qubit_idx in range(self.n_qubits):
        linear_input = inputs[qubit_idx] * self.lmbd[l, qubit_idx]
        circuit.rx(qubit_idx, theta=linear_input)

    # last variational part
    for qubit_idx in range(self.n_qubits):
      circuit.rx(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 0])
      circuit.ry(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 1])
      circuit.rz(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 2])

    # Measure Z on every qubit (equals z=[0, 1, 2, 3] for the 4-qubit case).
    return jnp.real(circuit.expectation_ps(z=list(range(self.n_qubits))))

In [66]:
# Build a 4-qubit, 5-layer re-uploading circuit from seed 600 and keep references
# to its initial parameter arrays (presumably to compare against trained values).
re_uploadingpqc = PQCLayer(n_qubits=4, n_layers=5, key=600)
t0 = re_uploadingpqc.theta
l0 = re_uploadingpqc.lmbd


@eqx.filter_value_and_grad
def compute_loss(model, x, y):
    """Mean hinge loss of `model` on a batch.

    Because of the decorator, calling this returns `(loss, grads)`, where
    `grads` has the structure of `model`.

    Args:
        model: callable applied to one sample at a time (vmapped over `x`).
        x: batch of inputs, batched along the first axis.
        y: batch of labels; expected to be 0/1 valued, they are mapped to -1/+1.
    """
    signed_targets = 2.0 * y - 1.0
    margins = signed_targets * jax.vmap(model)(x)
    return jnp.mean(jnp.maximum(0, 1 - margins))

# @eqx.filter_jit
# def make_step(model, x, y, opt_state):
#     loss, grads = compute_loss(model, x, y)
#     updates, opt_state = optim.update(grads, opt_state)
#     model = eqx.apply_updates(model, updates)
#     return loss, model, opt_state

# optim = optax.adam(1e-2)
# # optim = closure_to_pytree(optim)
# opt_state = optim.init(re_uploadingpqc)

# steps = 200

# for step in range(steps):
#   batch_idx = np.random.randint(0, len(x_train), 16)
#   x_train_batch = np.array([x_train[i] for i in batch_idx])
#   y_train_batch = np.array([y_train[i] for i in batch_idx])
#   y_train_batch = np.array([int(value) for value in y_train_batch])

#   loss, re_uploadingpqc, opt_state = make_step(re_uploadingpqc, x_train_batch, y_train_batch, opt_state)
#   loss = loss.item()
#   print(f"step={step}, loss={loss}")

In [67]:
# Draw a small, reproducible mini-batch so this cell also runs on a fresh kernel
# (the batch variables were previously only defined in commented-out code).
batch_idx = np.random.default_rng(0).integers(0, len(x_train), size=16)
x_train_batch = x_train[batch_idx]
y_train_batch = y_train[batch_idx]

# Returns (loss, grads) for the untrained circuit on that batch.
compute_loss(re_uploadingpqc, x_train_batch, y_train_batch)

(Array(0.92853316, dtype=float64),
 PQCLayer(theta=f64[6,4,3], lmbd=f64[5,4], n_qubits=4, n_layers=5))

In [61]:
# Selectors for building an `optax.multi_transform` parameter spec that puts the
# rotation angles (`theta`) and the input-scaling weights (`lmbd`) in separate groups.
has_theta = lambda x: hasattr(x, "theta")
has_lmbd = lambda x: hasattr(x, "lmbd")

# Each selector treats any object carrying the attribute as a leaf and returns the
# matching attributes as a tuple. `jtu.tree_leaves` replaces the deprecated
# `jax.tree_leaves` alias.
where_theta = lambda m: tuple(x.theta for x in jtu.tree_leaves(m, is_leaf=has_theta) if has_theta(x))
where_lmbd = lambda m: tuple(x.lmbd for x in jtu.tree_leaves(m, is_leaf=has_lmbd) if has_lmbd(x))

# Intended usage (kept disabled):
# where_theta(params)
# param_spec = eqx.tree_at(where_theta, params, replace_fn=lambda _: "group2")
# where_lmbd(param_spec)
# param_spec = eqx.tree_at(where_lmbd, param_spec, replace_fn=lambda _: "group1")


In [62]:
# optim = optax.multi_transform({"group2": optax.adam(1e-1),
#     "group1": optax.adam(1e-0),
#     },
#     param_spec
# )

In [64]:
# Show the array-valued part of the module (leaves failing eqx.is_array are replaced by None).
eqx.filter(re_uploadingpqc, eqx.is_array)

PQCLayer(theta=f64[6,4,3], lmbd=f64[5,4], n_qubits=4, n_layers=5)

In [72]:
# The saviour of all the demons!
# https://github.com/patrick-kidger/equinox/issues/256
class PQCLayer(eqx.Module):
  """Data re-uploading parametrized quantum circuit built from given parameters.

  Each of the `n_layers` layers applies RX/RY/RZ rotations on every qubit, a
  chain of CNOTs (closed into a ring when there are more than two qubits) and
  an RX encoding of the inputs scaled by `lmbd`. A final rotation block
  follows the last layer.

  Attributes:
    theta: rotation angles, indexed as [layer, qubit, axis] with
      shape (n_layers + 1, n_qubits, 3).
    lmbd: input-scaling weights, indexed as [layer, qubit] with
      shape (n_layers, n_qubits).
    n_qubits: number of qubits (static).
    n_layers: number of re-uploading layers (static).
  """
  theta: Array
  lmbd: Array
  n_qubits: int = eqx.field(static=True)
  n_layers: int = eqx.field(static=True)

  def __init__(self, n_qubits: int, n_layers: int, params):
    """Store the circuit size and the initial parameters.

    Args:
      n_qubits: number of qubits in the circuit.
      n_layers: number of re-uploading layers.
      params: mapping with the keys `'thetas'` and `'lmbds'` holding the
        initial `theta` and `lmbd` arrays.
    """
    self.n_qubits = n_qubits
    self.n_layers = n_layers
    # rotation params
    self.theta = params['thetas']
    # input encoding params
    self.lmbd = params['lmbds']

  def __call__(self, inputs):
    """Evaluate the circuit for a single sample.

    Args:
      inputs: one feature vector; element `i` is encoded on qubit `i`, so it
        needs at least `n_qubits` entries.

    Returns:
      Real part of the expectation value of the Pauli-Z string acting on all
      `n_qubits` qubits.
    """
    circuit = tc.Circuit(self.n_qubits)

    for l in range(self.n_layers):
      # variational part
      for qubit_idx in range(self.n_qubits):
        circuit.rx(qubit_idx, theta=self.theta[l, qubit_idx, 0])
        circuit.ry(qubit_idx, theta=self.theta[l, qubit_idx, 1])
        circuit.rz(qubit_idx, theta=self.theta[l, qubit_idx, 2])

      # entangling part
      for qubit_idx in range(self.n_qubits - 1):
        circuit.cnot(qubit_idx, qubit_idx + 1)
      # Close the ring only when that adds a new pair: two qubits are already
      # connected by the chain and one qubit has nothing to entangle with.
      if self.n_qubits > 2:
        circuit.cnot(self.n_qubits - 1, 0)

      # encoding part
      for qubit_idx in range(self.n_qubits):
        linear_input = inputs[qubit_idx] * self.lmbd[l, qubit_idx]
        circuit.rx(qubit_idx, theta=linear_input)

    # last variational part
    for qubit_idx in range(self.n_qubits):
      circuit.rx(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 0])
      circuit.ry(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 1])
      circuit.rz(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 2])

    # Measure Z on every qubit (equals z=[0, 1, 2, 3] for the 4-qubit case).
    return jnp.real(circuit.expectation_ps(z=list(range(self.n_qubits))))

In [73]:
# Circuit size.
n_layers, n_qubits = 5, 4

# Seed 42, split into one sub-key per parameter family. `lkey` is unused below
# because the input-scaling weights start at one.
key = jax.random.PRNGKey(42)
tkey, lkey = jax.random.split(key, num=2)

# Initial parameters: random rotation angles in [0, pi), unit input scaling.
params = dict(
    thetas=jax.random.uniform(tkey, (n_layers + 1, n_qubits, 3),
                              dtype=DTYPE, minval=0.0, maxval=np.pi),
    lmbds=jnp.ones((n_layers, n_qubits), dtype=DTYPE),
)

model = PQCLayer(n_qubits=n_qubits, n_layers=n_layers, params=params)

# Adam (learning rate 0.01), initialised on the array leaves of the model.
opt_init, opt_update = optax.adam(0.01)
opt_state = opt_init(eqx.filter(model, eqx.is_array))

# Leaves + tree structure of the model and the optimizer state; the training
# step passes the leaves around and rebuilds the objects from the treedefs.
flat_model, treedef_model = jtu.tree_leaves(model), jtu.tree_structure(model)
flat_opt_state, treedef_opt_state = jtu.tree_leaves(opt_state), jtu.tree_structure(opt_state)

[2465931498 3679230171] [255383827 267815257]


In [77]:
@eqx.filter_value_and_grad
def compute_loss(model, x, y):
    """Batch-averaged hinge loss; the decorator makes the call return `(loss, grads)`.

    Args:
        model: callable evaluated per sample via `jax.vmap`.
        x: inputs, batched along the first axis.
        y: labels aligned with `x`; expected to be 0/1 valued and rescaled to -1/+1 here.
    """
    predictions = jax.vmap(model)(x)
    targets_pm1 = 2.0 * y - 1.0
    per_sample = jnp.maximum(0, 1 - targets_pm1 * predictions)
    return jnp.mean(per_sample)

@eqx.filter_jit
def make_step(flat_model, flat_opt_state, x, y):
    """Perform one optimizer update on flattened model / optimizer state.

    Args:
        flat_model: leaves of the model, matching the global `treedef_model`.
        flat_opt_state: leaves of the optimizer state, matching the global
            `treedef_opt_state`.
        x: batch of inputs.
        y: batch of labels.

    Returns:
        (loss, new_flat_model, new_flat_opt_state), where `loss` is the value
        computed before the update is applied.
    """
    # Rebuild the structured objects from their leaves.
    current_model = jtu.tree_unflatten(treedef_model, flat_model)
    current_opt_state = jtu.tree_unflatten(treedef_opt_state, flat_opt_state)

    loss, grads = compute_loss(current_model, x, y)
    updates, next_opt_state = opt_update(grads, current_opt_state)
    next_model = eqx.apply_updates(current_model, updates)

    # Hand back leaves only, mirroring the inputs.
    return loss, jtu.tree_leaves(next_model), jtu.tree_leaves(next_opt_state)

(2000, 4)

In [83]:
# Run 100 full-batch updates (every step sees all of x_train / y_train) and print
# the loss, which make_step evaluates before applying that step's update.
# The trained parameters live in `flat_model`; the `model` object itself is not
# updated here and can be rebuilt with jtu.tree_unflatten(treedef_model, flat_model).
for step in range(100):
  loss, flat_model, flat_opt_state = make_step(flat_model, flat_opt_state, x_train, y_train)
  print(loss)

0.9963174666323078
0.9857290538749609
0.9751737059779213
0.9647824093088274
0.9547015993555735
0.9448562439507416
0.935205648239081
0.9258598069544338
0.9169494841364357
0.9085765941426701
0.9007932588974266
0.8935966540239557
0.8869450713482607
0.880784429170754
0.8750709038856439
0.8697800936516723
0.8649066040339191
0.8604563607789265
0.8564367507390658
0.8528446979983128
0.849653590634601
0.8468016393531482
0.8441903373571016
0.8417000211002499
0.8392147946308638
0.8366442105739261
0.8339308184209852
0.8310504618426687
0.828005415469505
0.824816503001018
0.8215147062129982
0.8181366654936185
0.8147214102472062
0.8113072717419417
0.8079284209458274
0.8046108311032294
0.8013690617893784
0.7982035903204305
0.7951025368397094
0.7920460999762872
0.7890159481975909
0.7860006750016085
0.783000384597292
0.7800245178242428
0.7770877110051126
0.7742036987572015
0.771378733801539
0.7686078784486453
0.7658750105492994
0.7631551672879043
0.7604168338465775
0.7576269349031112
0.7547511700378333
