In [1]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [10]:
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Flatten, Layer
import sympy
from tensorflow import string
import tensorflow as tf
import tensorflow_quantum as tfq
import cirq
import numpy as np
from qml_hep_lhc.encodings import AngleMap
from qml_hep_lhc.models.quantum.utils import one_qubit_unitary
from tensorflow import Variable, random_uniform_initializer, constant, shape, repeat, tile, concat, gather

In [15]:
from tensorflow.keras import Model, Input
import sympy
from tensorflow import string
import tensorflow_quantum as tfq
import cirq
from qml_hep_lhc.models.base_model import BaseModel
import tensorflow_quantum
from tensorflow.keras.layers import Layer, Flatten
import numpy as np
from qml_hep_lhc.encodings import AngleMap
from tensorflow import random_uniform_initializer, Variable, constant, shape, repeat, tile, concat, gather
from qml_hep_lhc.utils import _import_class
import re
from sympy.core import numbers as sympy_numbers
import numbers
import tensorflow as tf
from sympy.functions.elementary.trigonometric import TrigonometricFunction
from tensorflow_quantum.python import util


In [11]:
from qml_hep_lhc.data import ElectronPhoton, MNIST
import argparse

In [31]:
args = argparse.Namespace()

# args.graph_conv = True
# args.quantum = True
# args.pca = 64
# args.resize = [4,4]
# args.binary_data = [3,6]
# args.labels_to_categorical = True
args.percent_samples = 0.1
args.min_max = True
args.center_crop = 0.2
args.use_quantum = True
args.epochs = 3
args.validation_split = 0.2
args.num_workers = 4
# args.normalize = True
# args.min_max = True
# args.threshold = 0
# args.loss = "Hinge"
# args.hinge_labels = True

In [13]:
data = ElectronPhoton(args)
data.prepare_data()
data.setup()

Center cropping...
Min-max scaling...
Center cropping...
Min-max scaling...


In [14]:
data

Dataset :Electron Photon
╒════════╤═══════════════╤═══════════════╤═══════════╕
│ Data   │ Train size    │ Test size     │ Dims      │
╞════════╪═══════════════╪═══════════════╪═══════════╡
│ X      │ (10, 8, 8, 1) │ (10, 8, 8, 1) │ (8, 8, 1) │
├────────┼───────────────┼───────────────┼───────────┤
│ y      │ (10,)         │ (10,)         │ (1,)      │
╘════════╧═══════════════╧═══════════════╧═══════════╛

Train images stats
Min: -3.14
Max: 3.14
Mean: -1.53
Std: 2.19
Train labels stats
Min: 0.00
Max: 1.00

In [21]:
class QLinear(Layer):

    def __init__(self, input_dim, fm_class):
        super(QLinear, self).__init__()

        self.dim = np.prod(input_dim)
        self.fm_class = fm_class

        # Prepare qubits
        self.n_qubits = self.dim
        self.n_in_symbols = self.dim

        if self.fm_class == 'AmplitudeMap':
            self.n_qubits = int(np.ceil(np.log2(self.dim)))
            self.n_in_symbols = 2**self.n_qubits - 1

        self.qubits = cirq.GridQubit.rect(1, self.n_qubits)
        self.readout = cirq.GridQubit(-1, -1)
        self.observables = [cirq.Z(self.readout)]

        var_symbols = sympy.symbols(f'qnn0:{2*self.n_qubits}')
        self.var_symbols = np.asarray(var_symbols).reshape((self.n_qubits, 2))

        in_symbols = sympy.symbols(f'x0:{self.n_in_symbols}')
        self.in_symbols = np.asarray(in_symbols).reshape((self.n_in_symbols))

    def build(self, input_shape):
        self.data_circuit = cirq.Circuit()
        self.model_circuit = cirq.Circuit()

        # Prepare the readout qubit
        self.model_circuit.append(cirq.X(self.readout))
        self.model_circuit.append(cirq.H(self.readout))

        self.fm = _import_class(f"qml_hep_lhc.encodings.{self.fm_class}")()
        self.data_circuit += self.fm.build(self.qubits, self.in_symbols)

        for i, qubit in enumerate(self.qubits):
            self.model_circuit.append(
                cirq.XX(qubit, self.readout)**self.var_symbols[i, 0])
        for i, qubit in enumerate(self.qubits):
            self.model_circuit.append(
                cirq.ZZ(qubit, self.readout)**self.var_symbols[i, 1])

        # Finally, prepare the readout qubit.
        self.model_circuit.append(cirq.H(self.readout))

        self.var_symbols = list(self.var_symbols.flat)
        self.in_symbols = list(self.in_symbols.flat)

        var_init = random_uniform_initializer(minval=-np.pi / 2,
                                              maxval=np.pi / 2)
        self.theta = Variable(initial_value=var_init(
            shape=(1, len(self.var_symbols)), dtype="float32"),
                              trainable=True,
                              name="thetas")

        # Define explicit symbol order
        symbols = [str(symb) for symb in self.var_symbols + self.in_symbols]
        self.indices = constant([symbols.index(a) for a in sorted(symbols)])

        print('-' * 100)

        print(self.data_circuit)
        print(get_circuit_symbols(self.data_circuit))
        print(type(get_circuit_symbols(self.data_circuit)))
        _, expr_map = cirq.flatten(self.data_circuit)
        print(list(expr_map))

        self.base_input_symbols_list = get_circuit_symbols(self.data_circuit)
        self.expr_map = list(expr_map)
        # print(resolve_formulas(self.expr_map, self.base_input_symbols_list)())
        print(util.convert_to_tensor([self.data_circuit]))
        print('-' * 100)

        self.empty_circuit = tfq.convert_to_tensor([cirq.Circuit()])
        self.computation_layer = tfq.layers.ControlledPQC(
            self.model_circuit, self.observables)

    def call(self, input_tensor):
        print(input_tensor)
        print("=" * 100)
        print(
            resolve_formulas(self.expr_map,
                             self.base_input_symbols_list)(input_tensor))
        print("=" * 100)
        resolved_inputs = resolve_formulas(
            self.expr_map, self.base_input_symbols_list)(input_tensor)

        batch_dim = shape(input_tensor)[0]
        # rqc = cirq.resolve_parameters(self.data_circuit, { j: 1 for i,j in enumerate(input)})

        tiled_up_circuits = repeat(self.data_circuit,
                                   repeats=batch_dim,
                                   name="tiled_up_circuits")
        print(tiled_up_circuits)
        print(tiled_up_circuits.shape)
        tiled_up_thetas = tile(self.theta,
                               multiples=[batch_dim, 1],
                               name="tiled_up_thetas")
        joined_vars = concat([tiled_up_thetas, input_tensor], axis=-1)
        joined_vars = gather(joined_vars,
                             self.indices,
                             axis=-1,
                             name='joined_vars')
        out = self.computation_layer([tiled_up_circuits, joined_vars])
        return out


class QNN(BaseModel):
    """
    Quantum Neural Network.
    This implementation is based on https://www.tensorflow.org/quantum/tutorials/mnist
    """

    def __init__(self, data_config, args=None):
        super().__init__(args)
        self.args = vars(args) if args is not None else {}

        # Data config
        self.input_dim = data_config["input_dims"]
        self.n_qubits = np.prod(self.input_dim)
        fm_class = self.args.get("feature_map")
        if fm_class is None:
            fm_class = "AngleMap"
        self.qlinear = QLinear(self.input_dim, fm_class)

    def call(self, input_tensor):
        """
        The function takes in an input tensor and returns the expectation of the input tensor
        
        Args:
          input_tensor: The input tensor to the layer.
        
        Returns:
          The expectation of the input tensor.
        """
        x = Flatten()(input_tensor)
        out = self.qlinear(x)
        return out

    def build_graph(self):
        # x = Input(shape=(), dtype=string)
        x = Input(shape=self.input_dim)
        return Model(inputs=[x], outputs=self.call(x), name="QNN")

    @staticmethod
    def add_to_argparse(parser):
        return parser


def natural_key(symbol):
    '''Keys for human sorting
    Reference:
    http://nedbatchelder.com/blog/200712/human_sorting.html
    '''
    return [atoi(s) for s in re.split(r'(\d+)', symbol.name)]


def get_circuit_symbols(circuit, to_str=True, sort_key=natural_key):
    """Returns a list of parameter symbols in a circuit
    
    Arguments:
        circuit: cirq.Circuit, quple.QuantumCircuit
            The circuit to find the associated parameter symbols
        to_str: boolean, default=True
            Whether to convert symbol to strings
        sort_key:
            Sort key for the list of symbols
    Returns:
        A list of symbols in the circuit
    """
    all_symbols = set()
    for moment in circuit:
        for op in moment:
            if cirq.is_parameterized(op):
                all_symbols |= symbols_in_op(op.gate)
    sorted_symbols = sorted(list(all_symbols), key=sort_key)
    if to_str:
        return [str(x) for x in sorted_symbols]
    return sorted_symbols


def symbols_in_op(op):
    """Returns the set of symbols associated with a parameterized gate operation.
    
    Arguments:
        op: cirq.Gate
            The parameterised gate operation to find the set of symbols associated with
    
    Returns:
        Set of symbols associated with the parameterized gate operation
    """
    if isinstance(op, cirq.EigenGate):
        return op.exponent.free_symbols

    if isinstance(op, cirq.FSimGate):
        ret = set()
        if isinstance(op.theta, sympy.Basic):
            ret |= op.theta.free_symbols
        if isinstance(op.phi, sympy.Basic):
            ret |= op.phi.free_symbols
        return ret

    if isinstance(op, cirq.PhasedXPowGate):
        ret = set()
        if isinstance(op.exponent, sympy.Basic):
            ret |= op.exponent.free_symbols
        if isinstance(op.phase_exponent, sympy.Basic):
            ret |= op.phase_exponent.free_symbols
        return ret

    raise ValueError("Attempted to scan for symbols in circuit with unsupported"
                     " ops inside. Expected op found in tfq.get_supported_gates"
                     " but found: ".format(str(op)))


def atoi(symbol):
    return int(symbol) if symbol.isdigit() else symbol


def stack(func, lambda_set, intermediate=None):
    if intermediate is None:
        return stack(func, lambda_set[1:], lambda_set[0])
    if len(lambda_set) > 0:
        new_lambda = lambda x: func(intermediate(x), lambda_set[0](x))
        return stack(func, lambda_set[1:], new_lambda)
    else:
        return intermediate


def resolve_formulas(formulas, symbols):
    lambda_set = [resolve_formula(f, symbols) for f in formulas]
    stacked_ops = stack(lambda x, y: tf.concat((x, y), 0), lambda_set)
    n_formula = tf.constant([len(formulas)])
    transposed_x = lambda x: tf.transpose(
        x, perm=tf.roll(tf.range(tf.rank(x)), shift=1, axis=0))
    resolved_x = lambda x: stacked_ops(transposed_x(x))
    reshaped_x = lambda x: tf.reshape(
        resolved_x(x),
        tf.concat(
            (n_formula, tf.strided_slice(tf.shape(x), begin=[0], end=[-1])),
            axis=0))
    transformed_x = lambda x: tf.transpose(
        reshaped_x(x), perm=tf.roll(tf.range(tf.rank(x)), shift=-1, axis=0))
    return transformed_x


def resolve_formula(formula, symbols):

    tf_ops_map = {
        sympy.sin: tf.sin,
        sympy.cos: tf.cos,
        sympy.tan: tf.tan,
        sympy.asin: tf.asin,
        sympy.acos: tf.acos,
        sympy.atan: tf.atan,
        sympy.atan2: tf.atan2,
        sympy.cosh: tf.cosh,
        sympy.tanh: tf.tanh,
        sympy.sinh: tf.sinh
    }

    # Input is a pass through type, no resolution needed: return early
    value = resolve_value(formula)
    if value is not NotImplemented:
        return lambda x: value

    # Handles 2 cases:
    # formula is a string and maps to a number in the dictionary
    # formula is a symbol and maps to a number in the dictionary
    # in both cases, return it directly.
    if formula in symbols:
        index = symbols.index(formula)
        return lambda x: x[index]

    # formula is a symbol (sympy.Symbol('a')) and its string maps to a number
    # in the dictionary ({'a': 1.0}).  Return it.
    if isinstance(formula, sympy.Symbol) and formula.name in symbols:
        index = symbols.index(formula.name)
        return lambda x: x[index]

    # the following resolves common sympy expressions
    if isinstance(formula, sympy.Add):
        addents = [resolve_formula(arg, symbols) for arg in formula.args]
        return stack(tf.add, addents)

    if isinstance(formula, sympy.Mul):
        factors = [resolve_formula(arg, symbols) for arg in formula.args]
        return stack(tf.multiply, factors)

    if isinstance(formula, sympy.Pow) and len(formula.args) == 2:
        base = resolve_formula(formula.args[0], symbols)
        exponent = resolve_formula(formula.args[1], symbols)
        return lambda x: tf.pow(base(x), exponent(x))

    if isinstance(formula, sympy.Pow):
        base = resolve_formula(formula.args[0], symbols)
        exponent = resolve_formula(formula.args[1], symbols)
        return lambda x: tf.pow(base(x), exponent(x))

    if isinstance(formula, TrigonometricFunction):
        ops = tf_ops_map.get(type(formula), None)
        if ops is None:
            raise ValueError("unsupported sympy operation: {}".format(
                type(formula)))
        arg = resolve_formula(formula.args[0], symbols)
        return lambda x: ops(arg(x))


def resolve_value(val):
    if isinstance(val, numbers.Number) and not isinstance(val, sympy.Basic):
        return tf.constant(float(val), dtype=tf.float32)
    elif isinstance(val,
                    (sympy_numbers.IntegerConstant, sympy_numbers.Integer)):
        return tf.constant(float(val.p), dtype=tf.float32)
    elif isinstance(val,
                    (sympy_numbers.RationalConstant, sympy_numbers.Rational)):
        return tf.divide(tf.constant(val.p, dtype=tf.float32),
                         tf.constant(val.q, dtype=tf.float32))
    elif val == sympy.pi:
        return tf.constant(np.pi, dtype=tf.float32)
    else:
        return NotImplemented


In [33]:
model = QNN(data.config(), args)

use quantum


In [34]:
model.build_graph().summary()

----------------------------------------------------------------------------------------------------
(0, 0): ────Rx(sqrt(x0*x1))─────Rx(x0**2)───

(0, 1): ────Rx(sqrt(x1*x2))─────────────────

(0, 2): ────Rx(sqrt(x2*x3))─────────────────

(0, 3): ────Rx(sqrt(x3*x4))─────────────────

(0, 4): ────Rx(sqrt(x4*x5))─────────────────

(0, 5): ────Rx(sqrt(x5*x6))─────────────────

(0, 6): ────Rx(sqrt(x6*x7))─────────────────

(0, 7): ────Rx(sqrt(x7*x8))─────────────────

(0, 8): ────Rx(sqrt(x8*x9))─────────────────

(0, 9): ────Rx(sqrt(x10*x9))────────────────

(0, 10): ───Rx(sqrt(x10*x11))───────────────

(0, 11): ───Rx(sqrt(x11*x12))───────────────

(0, 12): ───Rx(sqrt(x12*x13))───────────────

(0, 13): ───Rx(sqrt(x13*x14))───────────────

(0, 14): ───Rx(sqrt(x14*x15))───────────────

(0, 15): ───Rx(sqrt(x15*x16))───────────────

(0, 16): ───Rx(sqrt(x16*x17))───────────────

(0, 17): ───Rx(sqrt(x17*x18))───────────────

(0, 18): ───Rx(sqrt(x18*x19))───────────────

(0, 19): ───Rx(sqrt(x19*x

ValueError: Arithmetic expression outside of simple scalar multiplication is currently not supported. See serializer.py for more information.

In [35]:
loss_fn = tf.keras.losses.MeanSquaredError
optimizer = tf.keras.optimizers.Adam

In [36]:
@tf.function
def custom_accuracy(y_true, y_pred):
    y_true = tf.squeeze(y_true)
    y_pred = tf.map_fn(lambda x: 1.0 if x >= 0 else -1.0, y_pred)
    return tf.keras.backend.mean(tf.keras.backend.equal(y_true, y_pred))

In [37]:
model.compile()

In [38]:
model.fit(data, callbacks=[])

Epoch 1/3


TypeError: Exception encountered when calling layer "qnn_3" (type QNN).

list indices must be integers or slices, not tuple

Call arguments received:
  • input_tensor=tf.Tensor(shape=(8, 8, 8, 1), dtype=float32)