In [None]:
# general imports
import matplotlib.pyplot as plt

# magic word for producing visualizations in notebook
%matplotlib widget

from braket.circuits import Circuit
from braket.devices import LocalSimulator

import pennylane as qml
from pennylane import numpy as np

In [None]:
# Circuit size: number of qubits, and how many times the encoding + training
# layer pair is repeated inside the circuit.
wires = 4
layers = 2
# PennyLane "default.qubit" device sized to the qubit count above; the QNode
# defined later in this notebook is bound to it.
device = qml.device("default.qubit", wires=wires)

In [None]:
def rotations(params, wire):
    """Apply an RZ-RY-RZ rotation sequence to a single qubit.

    Args:
        params: indexable holding at least three angles; ``params[0]`` and
            ``params[2]`` drive the two RZ gates, ``params[1]`` the RY gate.
        wire: the wire the three gates act on.
    """
    gate_sequence = (qml.RZ, qml.RY, qml.RZ)
    for position, gate in enumerate(gate_sequence):
        gate(params[position], wires=wire)

def entangle():
    """Entangle the register with a closed ring of CNOT gates.

    For every wire ``ii`` a CNOT is applied with ``ii`` as control and
    ``(ii + 1) % wires`` as target, so the last wire wraps around to wire 0.
    The number of wires is read from the module-level ``wires``.

    Nothing is applied when there are fewer than two wires: a CNOT needs two
    distinct wires, and with a single wire the ring would collapse into
    ``CNOT(wires=[0, 0])`` (control equal to target).
    """
    if wires < 2:
        return
    for ii in range(wires):
        qml.CNOT(wires=[ii, (ii + 1) % wires])

def training_layer(params):
    """Apply one trainable layer: per-qubit rotations, then entanglement.

    Args:
        params: 2-D array indexed as ``params[qubit, :]``; each row is passed
            to ``rotations`` as the angles for that qubit.
    """
    for qubit in range(wires):
        qubit_angles = params[qubit, :]
        rotations(qubit_angles, qubit)
    # Entangle only after every qubit has received its rotations.
    entangle()

def encoding_layer(params):
    """Apply one RX rotation per qubit.

    Args:
        params: indexable with one angle per wire; ``params[q]`` is the RX
            angle applied to wire ``q``.
    """
    for qubit in range(wires):
        angle = params[qubit]
        qml.RX(angle, wires=qubit)

In [None]:
@qml.qnode(device, diff_method="adjoint")
def circuit(enc_params, rot_params):
    """Layered circuit: alternating encoding and training layers.

    Args:
        enc_params: indexed by layer; ``enc_params[k]`` is passed to
            ``encoding_layer`` for layer ``k``.
        rot_params: indexed by layer; ``rot_params[k]`` is passed to
            ``training_layer`` for layer ``k``.

    Returns:
        Expectation value of PauliZ on wire 0 tensored with PauliZ on the
        last wire (``wires - 1``).
    """
    for layer in range(layers):
        encoding_layer(enc_params[layer])
        training_layer(rot_params[layer])
    first_qubit, last_qubit = 0, wires - 1
    return qml.expval(qml.PauliZ(first_qubit) @ qml.PauliZ(last_qubit))

In [None]:
# Encoding angles: 0.1 for every qubit, repeated once per layer,
# giving an array of shape (layers, wires).
init_enc_params = np.tile(
    np.array([0.1] * wires, requires_grad=True),
    (layers, 1),
)
# Rotation angles: (0.1, 0.2, 0.3) for every qubit, repeated once per layer,
# giving an array of shape (layers, wires, 3).
init_rot_params = np.tile(
    np.array([[0.1, 0.2, 0.3]] * wires, requires_grad=True),
    (layers, 1, 1),
)
print(init_enc_params)
print(init_rot_params)

In [None]:
# Draw the circuit with matplotlib, using the initial parameter values as the
# arguments the QNode is called with; the figure and axes are kept in
# `fig` / `ax`.
print("Drawing of circuit:\n")
fig, ax = qml.draw_mpl(circuit)(init_enc_params, init_rot_params)

In [None]:
# Gradient-descent optimizer with a step size of 0.1, used by the training loop.
opt = qml.GradientDescentOptimizer(stepsize=0.1)

In [None]:
iterations = 20

# Start the optimisation from the initial parameter values.
enc_params = init_enc_params
rot_params = init_rot_params
# Defined up front so the final report also works when `iterations` is 0.
params = [enc_params, rot_params]

costs = []
for i in range(iterations):
    # step_and_cost returns the updated arguments together with the cost
    # evaluated before the update was applied.
    params, cost = opt.step_and_cost(circuit, enc_params, rot_params)
    enc_params, rot_params = params
    costs.append(cost)

# Evaluate the circuit once on the final parameters so the curve ends with
# their cost, and reuse that value for the report instead of re-running it.
final_cost = circuit(enc_params, rot_params)
costs.append(final_cost)

# Visualize results
cost_fig, cost_ax = plt.subplots()
cost_ax.plot(costs, "-o")
cost_ax.set_xlabel("Iterations")
cost_ax.set_ylabel("Cost")

print("Minimized circuit output:", final_cost)
print("Optimized parameters:", params)