In [None]:
# Pin PennyLane to the version this notebook was run with so re-runs are reproducible.
%pip install pennylane==0.37.0

Collecting pennylane
  Downloading PennyLane-0.37.0-py3-none-any.whl (1.8 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/1.8 MB[0m [31m8.3 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.8/1.8 MB[0m [31m30.2 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m23.6 MB/s[0m eta [36m0:00:00[0m
Collecting rustworkx (from pennylane)
  Downloading rustworkx-0.15.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m42.4 MB/s[0m eta [36m0:00:00[0m
Collecting appdirs (from pennylane)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)
Collecting semantic-version>=2.7 (from pennylane)
  Downloading semantic_versi

In [None]:
# Pin JAX and jaxlib to the versions this notebook was run with (0.4.30).
# NOTE(review): in the recorded run this replaced the preinstalled
# `jaxlib 0.4.26+cuda12.cudnn89` with the PyPI jaxlib wheel, which is presumably
# a CPU-only build. Confirm before relying on GPU JAX in this runtime.
%pip install jax==0.4.30 jaxlib==0.4.30

Collecting jax
  Downloading jax-0.4.30-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jaxlib<=0.4.30,>=0.4.27 (from jax)
  Downloading jaxlib-0.4.30-cp310-cp310-manylinux2014_x86_64.whl (79.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.6/79.6 MB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.26+cuda12.cudnn89
    Uninstalling jaxlib-0.4.26+cuda12.cudnn89:
      Successfully uninstalled jaxlib-0.4.26+cuda12.cudnn89
  Attempting uninstall: jax
    Found existing installation: jax 0.4.26
    Uninstalling jax-0.4.26:
      Successfully uninstalled jax-0.4.26
Successfully installed jax-0.4.30 jaxlib-0.4.30


In [None]:
# Pin jaxopt (gradient-descent optimizer used below) to the version this notebook was run with.
%pip install jaxopt==0.8.3

Collecting jaxopt
  Downloading jaxopt-0.8.3-py3-none-any.whl (172 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.3/172.3 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxopt
Successfully installed jaxopt-0.8.3


In [None]:
import jax
import jaxopt
import pennylane as qml
from jax import numpy as np  # NOTE: `np` is jax.numpy in this notebook, not NumPy

In [None]:
# Single-wire PennyLane "lightning.qubit" simulator; `circuit` below acts on wire 0.
dev1 = qml.device(
    "lightning.qubit",
    wires=1,
)

In [None]:
@qml.qnode(dev1)
def circuit(params):
    """Apply RX then RY to the single qubit and measure <Z>.

    Args:
        params: indexable holding two rotation angles; ``params[0]`` feeds the
            RX gate and ``params[1]`` feeds the RY gate.

    Returns:
        Expectation value of Pauli-Z on wire 0 (always in [-1, 1]).
    """
    angle_x, angle_y = params[0], params[1]
    qml.RX(angle_x, wires=0)
    qml.RY(angle_y, wires=0)
    return qml.expval(qml.PauliZ(0))

In [None]:
# Smoke test: evaluate the QNode once at a sample pair of angles [RX, RY].
# NOTE: `params` is reassigned later by the optimization cell.
params = np.asarray([0.54, 0.12])
print(circuit(params))

0.85154057


# Calculating Quantum Gradients

In [None]:
# jax.grad differentiates with respect to the first positional argument by
# default (argnums=0). The result maps `params` to the vector of partial
# derivatives of the circuit output with respect to each angle.
dcircuit = jax.grad(circuit)

In [None]:
print(dcircuit(params))   # Gradient of the circuit output w.r.t. both angles, evaluated at the sample `params`

[-0.5104387  -0.10267819]


# Optimization

In [None]:
def cost(x):
    """Objective for the optimizer: the circuit's Pauli-Z expectation value.

    For this RX-then-RY circuit, <Z> = cos(x[0]) * cos(x[1]). Its minimum of -1
    (e.g. at x = [0, pi]) is the target the optimizer should reach.

    Args:
        x: the two rotation angles [RX angle, RY angle] passed to ``circuit``.

    Returns:
        The value of ``circuit(x)``.
    """
    expval = circuit(x)
    return expval

In [None]:
# Start near, but not exactly at, [0, 0]. The cost cos(x0)*cos(x1) has zero
# gradient at the origin, so small non-zero angles give gradient descent a
# direction to move in.
init_params = np.asarray([0.011, 0.012])
print(cost(init_params))

0.9998675


In [None]:
# Optimizer hyperparameters.
stepsize = 0.4      # fixed gradient-descent step size
steps = 100         # number of optimizer updates
print_every = 5     # report the cost every `print_every` steps

# Plain gradient descent on the circuit parameters (jaxopt's acceleration disabled).
opt = jaxopt.GradientDescent(cost, stepsize=stepsize, acceleration=False)

# Start from the initial angles chosen above and build the optimizer's starting state.
params = init_params
opt_state = opt.init_state(params)

for i in range(steps):
    # One gradient-descent update of the circuit parameters.
    params, opt_state = opt.update(params, opt_state)

    # Progress report every `print_every` steps (not every step);
    # re-evaluates the cost at the updated parameters.
    if (i + 1) % print_every == 0:
        print("Cost after step {:5d}: {: .7f}".format(i + 1, cost(params)))

# Final optimized rotation angles.
print("Optimized rotation angles: {}".format(params))

Cost after step     5:  0.9961779
Cost after step    10:  0.8974943
Cost after step    15:  0.1440490
Cost after step    20: -0.1536721
Cost after step    25: -0.9152496
Cost after step    30: -0.9994046
Cost after step    35: -0.9999964
Cost after step    40: -1.0000000
Cost after step    45: -1.0000000
Cost after step    50: -1.0000000
Cost after step    55: -1.0000000
Cost after step    60: -1.0000000
Cost after step    65: -1.0000000
Cost after step    70: -1.0000000
Cost after step    75: -1.0000000
Cost after step    80: -1.0000000
Cost after step    85: -1.0000000
Cost after step    90: -1.0000000
Cost after step    95: -1.0000000
Cost after step   100: -1.0000000
Optimized rotation angles: [7.1526556e-18 3.1415925e+00]
