In [1]:
from functools import partial
import numpy as np
import tensorflow as tf
import jax
from jax.config import config

config.update("jax_enable_x64", True)
from jax import numpy as jnp
import optax
import tensorcircuit as tc

Please first ``pip install -U cirq`` to enable related functionality in translation module


In [2]:
# Start on tensorcircuit's TensorFlow backend with complex128 precision.
# The backend is switched to JAX further down, before the circuit is defined.
tc.set_backend("tensorflow")
# Last expression of the cell: its return value is what the cell echoes.
tc.set_dtype("complex128")

('complex128', 'float64')

In [3]:
# Raw MNIST as numpy arrays; the training images get a trailing channel axis
# and are divided by 255.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1) / 255.0


def filter_pair(x, y, a, b):
    """Select the samples labelled ``a`` or ``b`` and binarize their labels.

    Args:
        x: array of samples, indexed along its first axis.
        y: array of labels aligned with ``x``.
        a: label mapped to ``True`` in the returned labels.
        b: label mapped to ``False`` in the returned labels.

    Returns:
        A pair ``(x_kept, y_kept)`` where ``x_kept`` holds the selected samples
        and ``y_kept`` is a boolean array that is ``True`` where the label was ``a``.
    """
    mask = np.logical_or(y == a, y == b)
    selected_x = x[mask]
    selected_y = y[mask]
    return selected_x, selected_y == a


# Keep only the digits 1 and 5 (label is True for 1), shrink every image to
# 3x3, threshold at 0.5, then drop size-1 axes and keep the first 100 samples.
x_train, y_train = filter_pair(x_train, y_train, 1, 5)
x_train_small = tf.image.resize(x_train, size=(3, 3)).numpy()
x_train_bin = (x_train_small > 0.5).astype(np.float32)
x_train_bin = x_train_bin.squeeze()[:100]

2023-09-01 20:09:36.363369: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2023-09-01 20:09:36.367473: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 76286336 exceeds 10% of free system memory.


In [4]:
# TensorFlow tensors: each 3x3 image flattened to 9 features, everything float64

x_train_tf = tf.constant(x_train_bin.reshape(-1, 9), dtype=tf.float64)
y_train_tf = tf.constant(y_train[:100], dtype=tf.float64)

# JAX arrays with the same content: (100, 9) features and (100,) labels

x_train_jax = jnp.asarray(x_train_bin, dtype=jnp.float64).reshape(100, -1)
y_train_jax = jnp.asarray(y_train[:100], dtype=jnp.float64).reshape(100)

In [18]:
# Batched training stream built from the 100 TensorFlow samples:
# 200 repeats of the data, shuffled through a 100-element buffer, in batches
# of 32 (100 * 200 / 32 = 625 batches in total).
# The shuffle is seeded so that the batch order, and therefore the training
# curve printed later, is reproducible across runs, like the PRNGKey-seeded
# parameter initialisation.
mnist_data = (
    tf.data.Dataset.from_tensor_slices((x_train_tf, y_train_tf))
    .repeat(200)
    .shuffle(100, seed=42)
    .batch(32)
)

In [5]:
# Switch tensorcircuit to the JAX backend; the circuit, loss and jit/vvag
# calls below all go through tc.backend.
tc.set_backend("jax")

jax_backend

In [12]:
# Number of variational layers in the circuit. Each layer consumes two rows of
# the quantum weights (one for RX, one for RY), hence the 2 * nlayers rows below.
nlayers = 3

def qml_ys(x, weights, nlayers):
    """Run the 9-qubit variational circuit and return rescaled Z expectations.

    The circuit encodes ``x`` with one RX rotation per qubit, then applies
    ``nlayers`` layers, each made of a nearest-neighbour CNOT chain followed
    by an RX and an RY rotation on every qubit.

    Args:
        x: per-qubit encoding angles; entries ``x[0]`` .. ``x[8]`` are read.
        weights: rotation angles indexed as ``weights[2 * j, i]`` (RX) and
            ``weights[2 * j + 1, i]`` (RY) for layer ``j`` and qubit ``i``.
        nlayers: number of variational layers.

    Returns:
        A stacked tensor with one entry per qubit, ``(<Z_i> + 1) / 2``.
    """
    n = 9
    weights = tc.backend.cast(weights, "complex128")
    x = tc.backend.cast(x, "complex128")
    c = tc.Circuit(n)
    # Data encoding: one RX rotation per qubit.
    for i in range(n):
        c.rx(i, theta=x[i])
    # Variational layers: entangle neighbours, then rotate every qubit.
    for j in range(nlayers):
        for i in range(n - 1):
            c.cnot(i, i + 1)
        for i in range(n):
            c.rx(i, theta=weights[2 * j, i])
            c.ry(i, theta=weights[2 * j + 1, i])
    ypreds = []
    for i in range(n):
        # Taking the real part once is enough; the original applied it twice.
        ypred = tc.backend.real(c.expectation([tc.gates.z(), (i,)]))
        # Map the expectation value from [-1, 1] to [0, 1].
        ypreds.append((ypred + 1) / 2.0)
    return tc.backend.stack(ypreds)

In [13]:
# Seeded initialisation of all trainable parameters: one subkey per entry,
# each entry drawn from a standard normal distribution.
key = jax.random.PRNGKey(42)
key, *subkeys = jax.random.split(key, num=4)
params = {
    name: jax.random.normal(subkey, shape=shape)
    for subkey, (name, shape) in zip(
        subkeys,
        [
            ("qweights", (nlayers * 2, 9)),  # circuit rotation angles
            ("cweights:w", (9,)),  # classical linear weights
            ("cweights:b", (1,)),  # classical bias
        ],
    )
}

In [14]:
def qml_hybrid_loss(x, y, params, nlayers):
    """Binary cross-entropy of the hybrid quantum-classical classifier on one sample.

    The circuit outputs from ``qml_ys`` are fed through a linear layer
    (``cweights:w``, ``cweights:b``) and a sigmoid; the result is scored
    against the label ``y``.

    Args:
        x: input features for a single sample, passed to ``qml_ys``.
        y: target label for the sample (0 or 1).
        params: dict with the entries ``"qweights"``, ``"cweights:w"`` and
            ``"cweights:b"``.
        nlayers: number of variational layers, passed to ``qml_ys``.

    Returns:
        The scalar cross-entropy loss for this sample.
    """
    weights = params["qweights"]
    w = params["cweights:w"]
    b = params["cweights:b"]
    features = tc.backend.reshape(qml_ys(x, weights, nlayers), [-1, 1])
    logit = (w @ features + b)[0]
    # log(sigmoid(z)) and log(1 - sigmoid(z)) == log(sigmoid(-z)) are evaluated
    # with log_sigmoid so that a saturated sigmoid cannot produce log(0) = -inf
    # (and NaN gradients); the value is otherwise the same as the naive form.
    return -y * jax.nn.log_sigmoid(logit) - (1 - y) * jax.nn.log_sigmoid(-logit)

In [15]:
# Batched, compiled loss-and-gradient function:
#   - vvag vectorizes qml_hybrid_loss over arguments 0 and 1 (x, y) and
#     differentiates with respect to argument 2 (params);
#   - jit treats argument 3 (nlayers) as static.
# A call returns a (losses, gradients) pair, which the training loop unpacks.
qml_hybrid_loss_vag = tc.backend.jit(
    tc.backend.vvag(qml_hybrid_loss, vectorized_argnums=(0, 1), argnums=2),
    static_argnums=3,
)

In [16]:
# Sanity check before training: evaluate the loss and gradients on all 100
# training samples with the freshly initialised parameters.
qml_hybrid_loss_vag(x_train_jax, y_train_jax, params, nlayers)

(Array([3.73282398, 0.02421603, 0.02899787, 0.02421603, 4.08996787,
        0.03069481, 0.02421603, 0.01688146, 4.08996787, 0.03069481,
        4.08996787, 0.02421603, 4.08996787, 0.02421603, 0.02899787,
        0.03354042, 0.02421603, 0.02421603, 0.01688146, 4.08996787,
        0.03354042, 0.02421603, 0.02421603, 0.03069481, 0.02421603,
        0.02421603, 0.03069481, 3.73798651, 0.02421603, 3.68810189,
        4.08996787, 0.03069481, 3.73282398, 0.03069481, 3.73282398,
        0.02421603, 3.49674264, 0.02421603, 4.08996787, 0.02899787,
        0.02421603, 0.02421603, 0.03069481, 0.03069481, 3.73282398,
        0.02533775, 0.03069481, 3.68810189, 3.73282398, 3.49896983,
        0.02899787, 0.03069481, 4.08996787, 3.41172721, 0.02421603,
        0.02421603, 3.73282398, 0.02421603, 3.73798651, 3.68810189,
        4.08996787, 0.03069481, 4.08996787, 0.02421603, 0.03069481,
        0.02421603, 3.68810189, 3.49896983, 3.49896983, 4.08996787,
        0.02421603, 0.02421603, 0.02421603, 0.02

In [19]:
# Adam training loop: at most 2000 steps, one batch from the tf.data pipeline
# per step; the mean batch loss is reported every 30 steps.
optimizer = optax.adam(learning_rate=5e-3)
opt_state = optimizer.init(params)
for i, (xs, ys) in zip(range(2000), mnist_data):
    # The tf tensors are handed to the JAX function as plain numpy arrays.
    xs, ys = xs.numpy(), ys.numpy()
    v, grads = qml_hybrid_loss_vag(xs, ys, params, nlayers)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if not i % 30:
        print(jnp.mean(v))

1.273615182667982
1.124644842058637
0.7292118873826958
0.6674623788129876
0.6532813834632301
0.7178720292192107
0.6414470539370705
0.6474652457593493
0.5847641858529047
0.6065594368458376
0.6120711071926281
0.5350079057453128
0.5833160269879363
0.5850338299436212
0.5602535448981196
0.5386627879688142
0.5621691259039326
0.5350945646936454
0.46571385470218807
0.45973075771167393
0.44349699787046953
