<a href="https://colab.research.google.com/github/Zakuta/D-QRL/blob/main/Reuploading_PQC_optax_working_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
# !pip install equinox
# !pip install tensorcircuit
# !pip install qiskit
# !pip install tensorcircuit
# !pip install cirq
# !pip install openfermion
# !pip install gymnax
# !pip install brax
# !pip install distrax



In [1]:
import jax
from jax import config

config.update("jax_debug_nans", True)
config.update("jax_enable_x64", True)

import jax.numpy as jnp
DTYPE=jnp.float64

import chex
import numpy as np
import optax
from flax import struct
from functools import partial
import tensorcircuit as tc

import tensorflow as tf
from sklearn.decomposition import PCA
import equinox as eqx
import types
from jaxtyping import Array, PRNGKeyArray
from typing import Union, Sequence, List, NamedTuple, Optional, Tuple, Any, Literal, TypeVar
import jax.tree_util as jtu
from gymnax.environments import environment, spaces
from brax import envs
from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper

K = tc.set_backend("jax")



In [2]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

x_train, x_test = x_train[..., np.newaxis] / 255.0, x_test[..., np.newaxis] / 255.0  # normalize the data

def filter(x, y, a, b):
    keep = (y == a) | (y == b)
    x, y = x[keep], y[keep]
    y = y == a
    return x, y

# Filter out classes 0 and 1
x_train, y_train = filter(x_train, y_train, 0, 1)
x_test, y_test = filter(x_test, y_test, 0, 1)

def apply_pca(X, n_components):
    X_flat = np.array([x.flatten() for x in X])
    pca = PCA(n_components=n_components)
    X_pca = pca.fit_transform(X_flat)
    return X_pca

n_components = 4
x_train = apply_pca(x_train, n_components)
x_test = apply_pca(x_test, n_components)

x_train = jnp.array(x_train, dtype=DTYPE)
x_test = jnp.array(x_test, dtype=DTYPE)
y_train = jnp.array(y_train, dtype=DTYPE)
y_test = jnp.array(y_test, dtype=DTYPE)

x_test.shape

  x_train = jnp.array(x_train, dtype=DTYPE)
  x_test = jnp.array(x_test, dtype=DTYPE)
  y_train = jnp.array(y_train, dtype=DTYPE)
  y_test = jnp.array(y_test, dtype=DTYPE)


(2000, 4)

In [3]:
class PQCLayer(eqx.Module):
  theta: Array
  lmbd: Array
  n_qubits: int = eqx.field(static=True)
  n_layers: int = eqx.field(static=True)

  def __init__(self, n_qubits: int, n_layers: int, key: PRNGKeyArray):
    key = jax.random.PRNGKey(key)
    tkey, lkey = jax.random.split(key, num=2)
    self.n_qubits = n_qubits
    self.n_layers = n_layers
    # rotation_params
    self.theta = jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
                                    minval=0.0, maxval=np.pi, dtype=DTYPE)
    # input encoding params
    # self.lmbd = jnp.ones(shape=(n_layers, n_qubits))
    self.lmbd = jax.random.uniform(key=lkey, shape=(n_layers, n_qubits),
                                    minval=0.0, maxval=np.pi, dtype=DTYPE)

    self.params = {'thetas': self.theta, 'lmbds': self.lmbd}

  def __call__(self, inputs):
  # def __call__(self, X, n_qubits, depth):

    # circuit = generate_circuit(self.n_qubits, self.n_layers, self.theta, self.lmbd, inputs)
    circuit = tc.Circuit(self.n_qubits)

    for l in range(self.n_layers):
      # variational part
      for qubit_idx in range(self.n_qubits):
        circuit.rx(qubit_idx, theta=self.params['thetas'][l, qubit_idx, 0])
        circuit.ry(qubit_idx, theta=self.params['thetas'][l, qubit_idx, 1])
        circuit.rz(qubit_idx, theta=self.params['thetas'][l, qubit_idx, 2])

      # entangling part
      for qubit_idx in range(self.n_qubits - 1):
        circuit.cnot(qubit_idx, qubit_idx + 1)
      if self.n_qubits != 2:
        circuit.cnot(self.n_qubits - 1, 0)

      # encoding part
      for qubit_idx in range(self.n_qubits):
        linear_input = inputs[qubit_idx] * self.params['lmbds'][l, qubit_idx]
        circuit.rx(qubit_idx, theta=linear_input)

    # last variational part
    for qubit_idx in range(self.n_qubits):
      circuit.rx(qubit_idx, theta=self.params['thetas'][self.n_layers, qubit_idx, 0])
      circuit.ry(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 1])
      circuit.rz(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 2])

    return jnp.real(circuit.expectation_ps(z=[0,1,2,3]))

In [66]:
re_uploadingpqc = PQCLayer(n_qubits=4, n_layers=5, key=600)
t0 = re_uploadingpqc.theta
l0 = re_uploadingpqc.lmbd


@eqx.filter_value_and_grad
def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    loss = jnp.maximum(0, 1 - (2.0 * y - 1.0) * pred_y)
    return jnp.mean(loss)

# @eqx.filter_jit
# def make_step(model, x, y, opt_state):
#     loss, grads = compute_loss(model, x, y)
#     updates, opt_state = optim.update(grads, opt_state)
#     model = eqx.apply_updates(model, updates)
#     return loss, model, opt_state

# optim = optax.adam(1e-2)
# # optim = closure_to_pytree(optim)
# opt_state = optim.init(re_uploadingpqc)

# steps = 200

# for step in range(steps):
#   batch_idx = np.random.randint(0, len(x_train), 16)
#   x_train_batch = np.array([x_train[i] for i in batch_idx])
#   y_train_batch = np.array([y_train[i] for i in batch_idx])
#   y_train_batch = np.array([int(value) for value in y_train_batch])

#   loss, re_uploadingpqc, opt_state = make_step(re_uploadingpqc, x_train_batch, y_train_batch, opt_state)
#   loss = loss.item()
#   print(f"step={step}, loss={loss}")

In [67]:
compute_loss(re_uploadingpqc, x_train_batch, y_train_batch)

(Array(0.92853316, dtype=float64),
 PQCLayer(theta=f64[6,4,3], lmbd=f64[5,4], n_qubits=4, n_layers=5))

In [61]:
# has_theta = lambda x: hasattr(x, "theta")
# where_theta = lambda m: tuple(x.theta for x in jax.tree_leaves(m, is_leaf=has_theta) if has_theta(x))
# where_theta(params)

# param_spec = eqx.tree_at(where_theta, params, replace_fn=lambda _: "group2")

# has_lmbd = lambda x: hasattr(x, "lmbd")
# where_lmbd = lambda m: tuple(x.lmbd for x in jax.tree_leaves(m, is_leaf=has_theta) if has_lmbd(x))
# where_lmbd(param_spec)

# param_spec = eqx.tree_at(where_lmbd, param_spec, replace_fn=lambda _: "group1")

  where_theta = lambda m: tuple(x.theta for x in jax.tree_leaves(m, is_leaf=has_theta) if has_theta(x))
  where_lmbd = lambda m: tuple(x.lmbd for x in jax.tree_leaves(m, is_leaf=has_theta) if has_lmbd(x))


In [62]:
# optim = optax.multi_transform({"group2": optax.adam(1e-1),
#     "group1": optax.adam(1e-0),
#     },
#     param_spec
# )

In [64]:
eqx.filter(re_uploadingpqc, eqx.is_array)

PQCLayer(theta=f64[6,4,3], lmbd=f64[5,4], n_qubits=4, n_layers=5)

In [4]:
# https://github.com/google-deepmind/optax/issues/577
# shamelessly taken from Patrick Kidger's issue from optax github. He's a legend, basically converted the optim instance of optax.optimizer to pytree!

def _make_cell(val):
    fn = lambda: val
    return fn.__closure__[0]  # pyright: ignore


def _adjust_function_closure(fn, closure):
    out = types.FunctionType(
        code=fn.__code__,
        globals=fn.__globals__,
        name=fn.__name__,
        argdefs=fn.__defaults__,
        closure=closure,
    )
    out.__module__ = fn.__module__
    out.__qualname__ = fn.__qualname__
    out.__doc__ = fn.__doc__
    out.__annotations__.update(fn.__annotations__)
    if fn.__kwdefaults__ is not None:
        out.__kwdefaults__ = fn.__kwdefaults__.copy()
    return out


# Not a pytree.
# Used so that two different local functions, with different identities, can still
# compare equal. This is needed as these leaves are compared statically when
# filter-jit'ing.
class _FunctionWithEquality:
    def __init__(self, fn: types.FunctionType):
        self.fn = fn

    def information(self):
        return self.fn.__qualname__, self.fn.__module__

    def __hash__(self):
        return hash(self.information())

    def __eq__(self, other):
        return type(self) == type(other) and self.information() == other.information()


class _Closure(eqx.Module):
    fn: _FunctionWithEquality
    contents: Optional[tuple[Any, ...]]

    def __init__(self, fn: types.FunctionType):
        self.fn = _FunctionWithEquality(fn)
        if fn.__closure__ is None:
            contents = None
        else:
            contents = tuple(
                closure_to_pytree(cell.cell_contents) for cell in fn.__closure__
            )
        self.contents = contents

    def __call__(self, *args, **kwargs):
        if self.contents is None:
            closure = None
        else:
            closure = tuple(_make_cell(contents) for contents in self.contents)
        fn = _adjust_function_closure(self.fn.fn, closure)
        return fn(*args, **kwargs)


def _fixup_closure(leaf):
    if isinstance(leaf, types.FunctionType):
        return _Closure(leaf)
    else:
        return leaf


def closure_to_pytree(tree):
    """Convert all function closures into pytree nodes.

    **Arguments:**

    - `tree`: Any pytree.

    **Returns:**

    A copy of `tree`, where all function closures have been replaced by a new object
    that is (a) callable like the original function, but (b) iterates over its
    `__closure__` as subnodes in the pytree.

    !!! Example

        ```python
        def some_fn():
            a = jnp.array(1.)

            @closure_to_pytree
            def f(x):
                return x + a

            print(jax.tree_util.tree_leaves(f))  # prints out `a`
        ```

    !!! Warning

        One annoying technical detail in the above example: we had to wrap the whole lot
        in a `some_fn`, so that we're in a local scope. Python treats functions at the
        global scope differently, and this conversion won't result in any global
        variable being treated as part of the pytree.

        In practice, the intended use case of this function is to fix Optax, which
        always uses local functions.
    """
    return jtu.tree_map(_fixup_closure, tree)


# EXAMPLE USAGE
# lr = jnp.array(1e-3)
# optim = optax.chain(
#     optax.adam(lr),
#     optax.scale_by_schedule(optax.piecewise_constant_schedule(1, {200: 0.1})),
# )
# optim = closure_to_pytree(optim)


In [5]:
# The saviour of all the demons!
# https://github.com/patrick-kidger/equinox/issues/256
class PQCLayer(eqx.Module):
  theta: Array
  lmbd: Array
  n_qubits: int = eqx.field(static=True)
  n_layers: int = eqx.field(static=True)

  def __init__(self, n_qubits: int, n_layers: int, params):
    self.n_qubits = n_qubits
    self.n_layers = n_layers
    # rotation_params
    self.theta = params['thetas']
    self.lmbd = params['lmbds']

  def __call__(self, inputs):
    circuit = tc.Circuit(self.n_qubits)

    for l in range(self.n_layers):
      # variational part
      for qubit_idx in range(self.n_qubits):
        circuit.rx(qubit_idx, theta=self.theta[l, qubit_idx, 0])
        circuit.ry(qubit_idx, theta=self.theta[l, qubit_idx, 1])
        circuit.rz(qubit_idx, theta=self.theta[l, qubit_idx, 2])

      # entangling part
      for qubit_idx in range(self.n_qubits - 1):
        circuit.cnot(qubit_idx, qubit_idx + 1)
      if self.n_qubits != 2:
        circuit.cnot(self.n_qubits - 1, 0)

      # encoding part
      for qubit_idx in range(self.n_qubits):
        linear_input = inputs[qubit_idx] * self.lmbd[l, qubit_idx]
        circuit.rx(qubit_idx, theta=linear_input)

    # last variational part
    for qubit_idx in range(self.n_qubits):
      circuit.rx(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 0])
      circuit.ry(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 1])
      circuit.rz(qubit_idx, theta=self.theta[self.n_layers, qubit_idx, 2])

    return jnp.real(circuit.expectation_ps(z=[0,1,2,3]))

In [9]:
n_layers = 5
n_qubits = 4
key = 42
key = jax.random.PRNGKey(key)
tkey, lkey = jax.random.split(key, num=2)
params = {'thetas': jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
                                    minval=0.0, maxval=np.pi, dtype=DTYPE),
          'lmbds': jnp.ones(shape=(n_layers, n_qubits), dtype=DTYPE)}


model = PQCLayer(n_qubits=n_qubits,
                 n_layers=n_layers,
                 params=params)


def map_nested_fn(fn):
  '''Recursively apply `fn` to the key-value pairs of a nested dict'''
  def map_fn(nested_dict):
    return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
            for k, v in nested_dict.items()}
  return map_fn

# gradients = jax.tree_util.tree_map(jnp.ones_like, params)  # dummy gradients

label_fn = map_nested_fn(lambda k, _: k)
optim = optax.multi_transform({'thetas': optax.adam(0.01), 'lmbds': optax.adam(0.001)},
                           label_fn)

# optim = closure_to_pytree(optim)
opt_state = optim.init(params)

@eqx.filter_value_and_grad
def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    loss = jnp.maximum(0, 1 - (2.0 * y - 1.0) * pred_y)
    return jnp.mean(loss)

def apply_updates_to_model(model, new_params):
  # this function is specific and works for this example only. one can think of
  # generalizing it to work for updating any given attribute/params of the model.

  model_new = eqx.tree_at(where=lambda model: model.theta, pytree=model, replace=new_params['thetas'])
  model_new = eqx.tree_at(where=lambda model: model.lmbd, pytree=model_new, replace=new_params['lmbds'])

  return model_new


@eqx.filter_jit
def make_step(model, x, y, opt_state, params):
    loss, grads = compute_loss(model, x, y)
    grads = {'thetas': grads.theta, 'lmbds': grads.lmbd}
    updates, opt_state = optim.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    model_new = apply_updates_to_model(model, new_params)
    # model_new = eqx.tree_at(where=lambda model: model.theta, pytree=model, replace=new_params['thetas'])
    # model_new = eqx.tree_at(where=lambda model: model.lmbd, pytree=model_new, replace=new_params['lmbds'])
    # model = eqx.apply_updates(model, updates)
    return loss, model_new, opt_state

for step in range(100):
  batch_idx = np.random.randint(0, len(x_train), 32)
  x_train_batch = np.array([x_train[i] for i in batch_idx])
  y_train_batch = np.array([y_train[i] for i in batch_idx])
  y_train_batch = np.array([int(value) for value in y_train_batch])

  loss, model, opt_state = make_step(model, x_train_batch, y_train_batch, opt_state, params)
  loss = loss.item()
  print(f"step={step}, loss={loss}")

# opt_init, opt_update = optax.adam(0.01)
# opt_state = opt_init(eqx.filter(model, eqx.is_array))
# flat_model, treedef_model = jtu.tree_flatten(model)
# flat_opt_state, treedef_opt_state = jtu.tree_flatten(opt_state)

  params = {'thetas': jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
  'lmbds': jnp.ones(shape=(n_layers, n_qubits), dtype=DTYPE)}


step=0, loss=0.9442745447158813
step=1, loss=0.9854795336723328
step=2, loss=0.9921755790710449
step=3, loss=1.0407124757766724
step=4, loss=0.9191699624061584
step=5, loss=0.9675570726394653
step=6, loss=0.9539644718170166
step=7, loss=0.979794442653656
step=8, loss=1.0050121545791626
step=9, loss=0.9828323721885681
step=10, loss=1.0165404081344604
step=11, loss=1.0894443988800049
step=12, loss=0.9302136898040771
step=13, loss=1.0786528587341309
step=14, loss=1.0308709144592285
step=15, loss=1.008302927017212
step=16, loss=0.9381150007247925
step=17, loss=1.0086443424224854
step=18, loss=0.9673950672149658
step=19, loss=0.9813898801803589
step=20, loss=1.079079270362854
step=21, loss=0.9439525604248047
step=22, loss=0.9594672918319702
step=23, loss=1.040024995803833
step=24, loss=1.0713002681732178
step=25, loss=0.9770899415016174
step=26, loss=0.9704107642173767
step=27, loss=0.9381617903709412
step=28, loss=0.9306100010871887
step=29, loss=1.0057768821716309
step=30, loss=0.99195367

In [145]:
gradients = jax.tree_util.tree_map(jnp.ones_like, params)  # dummy gradients


In [196]:
n_layers = 5
n_qubits = 4
key = 42
key = jax.random.PRNGKey(key)
tkey, lkey = jax.random.split(key, num=2)
params = {'thetas': jax.random.uniform(key=tkey, shape=(n_layers + 1, n_qubits, 3),
                                    minval=0.0, maxval=np.pi, dtype=DTYPE),
          'lmbds': jnp.ones(shape=(n_layers, n_qubits), dtype=DTYPE)}


model = PQCLayer(n_qubits=n_qubits,
                 n_layers=n_layers,
                 params=params)


def map_nested_fn(fn):
  '''Recursively apply `fn` to the key-value pairs of a nested dict'''
  def map_fn(nested_dict):
    return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
            for k, v in nested_dict.items()}
  return map_fn

# gradients = jax.tree_util.tree_map(jnp.ones_like, params)  # dummy gradients

label_fn = map_nested_fn(lambda k, _: k)
optim = optax.multi_transform({'thetas': optax.adam(0.01), 'lmbds': optax.adam(0.001)},
                           label_fn)
opt_state = optim.init(params)
# updates, new_state = optim.update(gradients, opt_state, params)
# new_params = optax.apply_updates(params, updates)

@eqx.filter_value_and_grad
def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    loss = jnp.maximum(0, 1 - (2.0 * y - 1.0) * pred_y)
    return jnp.mean(loss)


loss, grads = compute_loss(model, x_train[:16], y_train[:16])

In [197]:
grads = {'thetas': grads.theta, 'lmbds': grads.lmbd}
updates, opt_state = optim.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)

In [198]:
new_params

{'lmbds': Array([[1.001, 1.001, 0.999, 0.999],
        [0.999, 1.001, 0.999, 1.001],
        [0.999, 0.999, 0.999, 1.001],
        [1.001, 1.001, 1.001, 0.999],
        [1.001, 0.999, 0.999, 0.999]], dtype=float64),
 'thetas': Array([[[0.85259273, 0.93874143, 0.89688467],
         [2.67673558, 0.77266714, 1.51313364],
         [1.05774572, 2.68004876, 1.01810876],
         [2.46738272, 2.01423164, 1.70877908]],
 
        [[2.67148045, 0.40126361, 0.21946644],
         [0.69720534, 0.3797383 , 1.58548773],
         [0.05340149, 1.71885032, 2.6753669 ],
         [3.07993527, 1.5659381 , 1.79930547]],
 
        [[0.80809018, 2.52548216, 3.07320156],
         [0.65879265, 1.52784092, 2.44956667],
         [2.24907099, 2.23584914, 2.14072183],
         [2.83810737, 2.65760658, 2.58974501]],
 
        [[0.66692702, 0.97194887, 2.8846463 ],
         [1.55491374, 1.60811205, 0.94195518],
         [1.649293  , 1.69987386, 1.80819792],
         [2.20201573, 0.96468078, 1.84464712]],
 
        [[

Array([[[0.84259273, 0.94874143, 0.90688467],
        [2.6667356 , 0.76266714, 1.52313363],
        [1.06774572, 2.69004876, 1.00810876],
        [2.45738273, 2.02423163, 1.71877908]],

       [[2.68148044, 0.39126363, 0.20946645],
        [0.70720533, 0.3897383 , 1.57548773],
        [0.0434015 , 1.72885032, 2.6853669 ],
        [3.06993527, 1.55593812, 1.80930546]],

       [[0.79809019, 2.51548217, 3.08320155],
        [0.66879265, 1.53784092, 2.43956667],
        [2.25907099, 2.24584914, 2.13072184],
        [2.84810735, 2.66760658, 2.59974501]],

       [[0.67692702, 0.98194887, 2.8746463 ],
        [1.54491374, 1.59811206, 0.95195518],
        [1.65929299, 1.70987385, 1.79819793],
        [2.21201573, 0.95468079, 1.83464712]],

       [[2.71668027, 1.03779524, 2.34367556],
        [2.83988658, 1.66392737, 1.57174783],
        [2.08979208, 2.66964127, 0.54374638],
        [0.03565717, 0.21567255, 1.88023176]],

       [[0.68780493, 0.15546439, 2.48979877],
        [1.87962935, 1.7

In [201]:
print(model.theta, model.lmbd)

[[[0.84259273 0.94874143 0.90688467]
  [2.6667356  0.76266714 1.52313363]
  [1.06774572 2.69004876 1.00810876]
  [2.45738273 2.02423163 1.71877908]]

 [[2.68148044 0.39126363 0.20946645]
  [0.70720533 0.3897383  1.57548773]
  [0.0434015  1.72885032 2.6853669 ]
  [3.06993527 1.55593812 1.80930546]]

 [[0.79809019 2.51548217 3.08320155]
  [0.66879265 1.53784092 2.43956667]
  [2.25907099 2.24584914 2.13072184]
  [2.84810735 2.66760658 2.59974501]]

 [[0.67692702 0.98194887 2.8746463 ]
  [1.54491374 1.59811206 0.95195518]
  [1.65929299 1.70987385 1.79819793]
  [2.21201573 0.95468079 1.83464712]]

 [[2.71668027 1.03779524 2.34367556]
  [2.83988658 1.66392737 1.57174783]
  [2.08979208 2.66964127 0.54374638]
  [0.03565717 0.21567255 1.88023176]]

 [[0.68780493 0.15546439 2.48979877]
  [1.87962935 1.77782177 2.84671802]
  [1.14221602 2.23558458 0.38788242]
  [2.82405449 2.84739362 2.15755409]]] [[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]


In [205]:
model_new.theta

Array([[[0.85259273, 0.93874143, 0.89688467],
        [2.67673558, 0.77266714, 1.51313364],
        [1.05774572, 2.68004876, 1.01810876],
        [2.46738272, 2.01423164, 1.70877908]],

       [[2.67148045, 0.40126361, 0.21946644],
        [0.69720534, 0.3797383 , 1.58548773],
        [0.05340149, 1.71885032, 2.6753669 ],
        [3.07993527, 1.5659381 , 1.79930547]],

       [[0.80809018, 2.52548216, 3.07320156],
        [0.65879265, 1.52784092, 2.44956667],
        [2.24907099, 2.23584914, 2.14072183],
        [2.83810737, 2.65760658, 2.58974501]],

       [[0.66692702, 0.97194887, 2.8846463 ],
        [1.55491374, 1.60811205, 0.94195518],
        [1.649293  , 1.69987386, 1.80819792],
        [2.20201573, 0.96468078, 1.84464712]],

       [[2.72668026, 1.02779524, 2.33367556],
        [2.84988657, 1.67392737, 1.56174784],
        [2.07979209, 2.65964128, 0.53374639],
        [0.04565717, 0.22567254, 1.87023176]],

       [[0.69780493, 0.16546439, 2.48708459],
        [1.86962935, 1.7

In [188]:
import jax

def update_theta_and_lmbd(model, params):
    def update_fn(subtree):
        if isinstance(subtree, dict):
            # Update 'theta' attribute if present
            if 'theta' in subtree:
                subtree['theta'] = params['thetas']
            # Update 'lmbd' attribute if present
            if 'lmbd' in subtree:
                subtree['lmbd'] = params['lmbds']
        return subtree

    return jax.tree_map(update_fn, model)

# Example usage:
# model_updated = eqx.tree_at(where=lambda model: model.theta, pytree=model, replace=params['thetas'])
# model_updated = eqx.tree_at(where=lambda model: model.lmbd, pytree=model_updated, replace=params['lmbd'])
model_updated = update_theta_and_lmbd(model, new_params)

Array([[1.001, 1.001, 0.999, 0.999],
       [0.999, 1.001, 0.999, 1.001],
       [0.999, 0.999, 0.999, 1.001],
       [1.001, 1.001, 1.001, 0.999],
       [1.001, 0.999, 0.999, 0.999]], dtype=float64)

In [156]:
updates

{'lmbds': Array([[ 0.001,  0.001, -0.001, -0.001],
        [-0.001,  0.001, -0.001,  0.001],
        [-0.001, -0.001, -0.001,  0.001],
        [ 0.001,  0.001,  0.001, -0.001],
        [ 0.001, -0.001, -0.001, -0.001]], dtype=float64),
 'thetas': Array([[[ 0.01      , -0.01      , -0.01      ],
         [ 0.00999999,  0.01      , -0.01      ],
         [-0.01      , -0.01      ,  0.01      ],
         [ 0.01      , -0.01      , -0.01      ]],
 
        [[-0.00999999,  0.00999998,  0.01      ],
         [-0.00999999, -0.01      ,  0.01      ],
         [ 0.01      , -0.01      , -0.01      ],
         [ 0.00999999,  0.00999999, -0.00999999]],
 
        [[ 0.00999999,  0.00999999, -0.01      ],
         [-0.01      , -0.01      ,  0.01      ],
         [-0.01      , -0.01      ,  0.00999999],
         [-0.00999998, -0.01      , -0.01      ]],
 
        [[-0.01      , -0.01      ,  0.01      ],
         [ 0.01      ,  0.00999999, -0.01      ],
         [-0.00999999, -0.00999999,  0.009999

In [124]:
import optax
import jax
import jax.numpy as jnp



In [77]:



# @eqx.filter_jit
# def make_step(flat_model, flat_opt_state, x, y):
#     model = jtu.tree_unflatten(treedef_model, flat_model)
#     opt_state = jtu.tree_unflatten(treedef_opt_state, flat_opt_state)
#     loss, grads = compute_loss(model, x, y)
#     updates, opt_state = opt_update(grads, opt_state)
#     model = eqx.apply_updates(model, updates)
#     flat_model = jtu.tree_leaves(model)
#     flat_opt_state = jtu.tree_leaves(opt_state)
#     return loss, flat_model, flat_opt_state

(2000, 4)

0.9963174666323078
0.9857290538749609
0.9751737059779213
0.9647824093088274
0.9547015993555735
0.9448562439507416
0.935205648239081
0.9258598069544338
0.9169494841364357
0.9085765941426701
0.9007932588974266
0.8935966540239557
0.8869450713482607
0.880784429170754
0.8750709038856439
0.8697800936516723
0.8649066040339191
0.8604563607789265
0.8564367507390658
0.8528446979983128
0.849653590634601
0.8468016393531482
0.8441903373571016
0.8417000211002499
0.8392147946308638
0.8366442105739261
0.8339308184209852
0.8310504618426687
0.828005415469505
0.824816503001018
0.8215147062129982
0.8181366654936185
0.8147214102472062
0.8113072717419417
0.8079284209458274
0.8046108311032294
0.8013690617893784
0.7982035903204305
0.7951025368397094
0.7920460999762872
0.7890159481975909
0.7860006750016085
0.783000384597292
0.7800245178242428
0.7770877110051126
0.7742036987572015
0.771378733801539
0.7686078784486453
0.7658750105492994
0.7631551672879043
0.7604168338465775
0.7576269349031112
0.7547511700378333


In [84]:
def accuracy(y_true, y_pred):
    # Computed with uniform distribution

    y_true = y_true > 0.0
    y_pred = y_pred >= 0.0
    result = y_true == y_pred

    return jnp.sum(result)/y_true.shape[0]

pred_ys = jax.vmap(model)(x_test)
# num_correct = jnp.sum((pred_ys > 0.5) == y_test)
# final_accuracy = (num_correct / x_test.shape[0]).item()
print(f"final_accuracy={accuracy(y_test, pred_ys)}")

final_accuracy=0.4915
