In [5]:
# Load IPython's autoreload extension and set it to mode 2, so that edits to
# imported project modules are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

from importlib.util import find_spec
# Fall back to the parent directory when the `qml_hep_lhc` package cannot be
# found on the current module search path.
# NOTE(review): '..' is resolved against the kernel's working directory;
# presumably the notebook sits one level below the repository root - confirm.
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
import argparse
import collections

import cirq
import numpy as np
import sympy as sp
import tensorflow as tf
from tensorflow import pad, constant

from qml_hep_lhc.data import ElectronPhoton
from qml_hep_lhc.encodings import AmplitudeMap
from qml_hep_lhc.layers.utils import get_count_of_qubits, get_num_in_symbols
from qml_hep_lhc.layers.utils import symbols_in_expr_map, resolve_formulas

In [9]:
# Ask the layer utilities how many qubits and how many input symbols the
# 'AmplitudeMap' encoding requires.
# NOTE(review): the second argument (4) is presumably the number of input
# features - confirm against the helper signatures.
n_qubits, n_inputs = (
    get_count_of_qubits('AmplitudeMap', 4),
    get_num_in_symbols('AmplitudeMap', 4),
)

(n_qubits, n_inputs)

(2, 4)

In [10]:
# Create one sympy symbol per input (x0, x1, ...) and hold them in an object
# array with a single row of `n_inputs` columns.
in_symbols = np.asarray(sp.symbols(f'x0:{n_inputs}')).reshape(1, n_inputs)

in_symbols

array([[x0, x1, x2, x3]], dtype=object)

In [11]:
# Place the qubits on a single-row grid.
qubits = cirq.GridQubit.rect(1, n_qubits)

# Start from an empty circuit and append the operations produced by the
# AmplitudeMap encoding for the symbolic inputs (first row of `in_symbols`).
circuit = cirq.Circuit()
encoder = AmplitudeMap()
circuit += encoder.build(qubits, in_symbols[0])

User must manually normalize the input.
  "AmplitudeMap currently does not normalize the input unless padding is needed.\nUser must manually normalize the input."


In [12]:
# Display the circuit assembled from the AmplitudeMap encoding.
circuit

In [13]:
# Flatten the circuit: `cirq.flatten` returns the rewritten circuit together
# with a map whose keys are the original expressions and whose values are the
# replacement symbols.
# NOTE(review): `circuit` is rebound here, so this cell is not idempotent -
# re-running it without rebuilding the circuit flattens an already-flattened
# circuit.
circuit, expr_map = cirq.flatten(circuit)

# Input symbols referenced by the mapped expressions.
raw_in_symbols = symbols_in_expr_map(expr_map)

# The mapped expressions themselves (the keys of `expr_map`), in map order.
data_expr = [*expr_map]

In [14]:
# Build a callable from the flattened expressions and the raw input symbols.
# NOTE(review): presumably the callable evaluates every expression in
# `data_expr` for a batch of numeric inputs - confirm against
# `resolve_formulas`.
input_resolver = resolve_formulas(data_expr, raw_in_symbols)

In [51]:
# Example input: a single row of four features, scaled to unit Euclidean norm
# (the encoding does not normalize its input, so it is done by hand here).
x = np.array([[1, 2, 3, 4]], dtype=np.float32)

# Euclidean norm of the whole array: square root of the sum of squares.
d = np.sqrt(np.square(x).sum())

x = x / d
x

array([[0.18257418, 0.36514837, 0.5477225 , 0.73029673]], dtype=float32)

In [52]:
# Apply the resolver to the normalized input.
# NOTE(review): the result is presumably one numeric value per flattened
# circuit expression, in the order of `data_expr` - confirm against
# `resolve_formulas`.
resolved_x = input_resolver(x)
resolved_x

<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[0.7322796, 0.5903344, 0.7048327]], dtype=float32)>

In [53]:
# NOTE(review): `resolver` is not referenced by any other cell visible here;
# parameters are resolved from a plain dict instead. Confirm whether this
# empty ParamResolver is still needed.
resolver = cirq.ParamResolver()

In [86]:
# Pair every flattened symbol (the values of `expr_map`) with the resolved
# number at the same position in the first row of `resolved_x`.
params = dict(
    (symbol, resolved_x[0][position].numpy())
    for position, symbol in enumerate(expr_map.values())
)

# Substitute the numbers into the flattened circuit to obtain a concrete one.
qc = cirq.resolve_parameters(circuit, params)
qc

In [87]:
# Append a measurement of every qubit under the key 'result'.
# `qc += ...` mutates the circuit in place, so without the guard re-running
# this cell would append a second measurement with the same key; the guard
# keeps the cell idempotent.
if not qc.has_measurements():
    qc += cirq.measure(*qubits, key='result')

In [91]:
# Sample the measured circuit and turn the outcome counts into relative
# frequencies.
# A fixed seed makes the sampled frequencies reproducible across kernel
# restarts (the simulation is stochastic).
s = cirq.Simulator(seed=42)
shots = 10000
samples = s.run(qc, repetitions=shots)

# histogram() counts how often each outcome was measured under key "result";
# dividing by the number of shots gives the empirical probability per outcome.
res = {
    outcome: count / shots
    for outcome, count in samples.histogram(key="result").items()
}

In [92]:
# Order the empirical frequencies by measured outcome and print them one per
# line as "<outcome> <frequency>".
od = collections.OrderedDict((outcome, res[outcome]) for outcome in sorted(res))

for outcome, frequency in od.items():
    print(outcome, frequency)

0 0.0334
1 0.1308
2 0.2986
3 0.5372


In [93]:
# Expected outcome probabilities: the squared amplitudes of the normalized
# input row, printed as "<index> <probability>" for comparison with the
# sampled frequencies.
probs = np.square(x[0])
for index in range(len(probs)):
    print(index, probs[index])

0 0.03333333
1 0.13333333
2 0.29999995
3 0.5333333
