***Gradients: Adjoint differentiation***  500 points

In the last two problems, we dove into the parameter-shift rule. It's an elegant, simple, and extremely useful way to differentiate quantum circuits because it is hardware compatible. Other differentiation methods exist that are very efficient but, in general, are not hardware compatible. In this challenge, we'll describe how to perform one of them: adjoint differentiation.


![title](gradients_adjDiff1.jpg)

![title](gradients_adjDiff2.jpg)

![title](gradients_adjDiff3.jpg)

where how you define each bra and ket is left up to you! If you build four PennyLane circuits, one to prepare each bra and ket, the derivative can be calculated by adding two inner products together and multiplying the result by a coefficient...
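For reference, here is one way to write this out. The notation is ours and may differ from the figures above.

Take the circuit to be $U_n \cdots U_1$ acting on $|0\rangle$, and measure a Hermitian observable $M$. Assume the differentiated gate is $U_k(\theta_k) = e^{i\theta_k c P_k}$ for a Pauli $P_k$ and a real coefficient $c$ (for `RX`, `RY` and `RZ`, $c = -1/2$), so that $\partial_{\theta_k} U_k = i c P_k U_k$. Then:

$$
\begin{aligned}
|\psi\rangle &= U_n \cdots U_1 |0\rangle, \qquad
|\phi\rangle = U_n \cdots U_{k+1} P_k U_k \cdots U_1 |0\rangle, \\
\frac{\partial \langle M \rangle}{\partial \theta_k}
&= \langle \psi | M \, U_n \cdots U_{k+1} (i c P_k) U_k \cdots U_1 | 0 \rangle + \text{c.c.} \\
&= i c \left( \langle \psi | M | \phi \rangle - \langle \phi | M | \psi \rangle \right)
 = -2c \, \operatorname{Im} \langle \psi | M | \phi \rangle .
\end{aligned}
$$

Here $|\psi\rangle$ is the circuit's output state, and $|\phi\rangle$ is the same circuit with $P_k$ inserted right after $U_k$. Since $P_k$ commutes with $U_k$, it could equally be inserted right before it.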

For the keen readers: do you actually need four circuits 🧐?
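Here is one way to see the answer, sketched in plain NumPy rather than PennyLane. In the solution below, the four circuits prepare these states:

- `circuit_bra2` prepares $|\psi\rangle$, the full circuit.
- `circuit_ket1` prepares $M|\psi\rangle$.
- `circuit_bra1` prepares $|\phi\rangle$, the circuit with the generator's Pauli inserted at the differentiated gate.
- `circuit_ket2` prepares $M|\phi\rangle$.

Because $M$ is Hermitian, $\langle\phi|M|\psi\rangle$ is the complex conjugate of $\langle\psi|M|\phi\rangle$. Two states are therefore enough, and the derivative reduces to $-2c\,\mathrm{Im}\langle\psi|M|\phi\rangle$.

The sketch rests on a few assumptions:

- The helper names (`embed`, `gate`, `two_state_derivative`, `expval`, `shifted`) are hypothetical.
- The gate convention $R_P(\theta) = e^{i\theta c P}$ with $c = -1/2$ is assumed; it matches `RX`, `RY` and `RZ`.
- Wire 0 is taken as the leftmost tensor factor.

```python
import numpy as np

# Rotation axis for each op_dict key (0: RX -> X, 1: RY -> Y, 2: RZ -> Z)
I2 = np.eye(2, dtype=complex)
PAULI = {
    0: np.array([[0, 1], [1, 0]], dtype=complex),
    1: np.array([[0, -1j], [1j, 0]], dtype=complex),
    2: np.array([[1, 0], [0, -1]], dtype=complex),
}
COEFF = -0.5  # assumption: R_P(theta) = exp(i * theta * COEFF * P)


def embed(single, wire, n_wires):
    """Kronecker-embed a 2x2 matrix on `wire` (wire 0 is the leftmost factor)."""
    out = np.ones((1, 1), dtype=complex)
    for w in range(n_wires):
        out = np.kron(out, single if w == wire else I2)
    return out


def gate(kind, theta, wire, n_wires):
    """Rotation on `wire`; exp(i*a*P) = cos(a) I + i sin(a) P because P @ P = I."""
    a = theta * COEFF
    return embed(np.cos(a) * I2 + 1j * np.sin(a) * PAULI[kind], wire, n_wires)


def two_state_derivative(op_order, params, diff_idx, wires, measured_wire, n_wires=2):
    psi = np.zeros(2**n_wires, dtype=complex)
    psi[0] = 1.0
    phi = None  # becomes |phi> once gate diff_idx has been applied
    for i, (kind, theta, wire) in enumerate(zip(op_order, params, wires)):
        U = gate(kind, theta, wire, n_wires)
        psi = U @ psi
        if phi is not None:
            phi = U @ phi
        if i == diff_idx:
            # branch off: insert the generator's Pauli right after gate i
            phi = embed(PAULI[kind], wire, n_wires) @ psi
    M = embed(PAULI[2], measured_wire, n_wires)  # Pauli Z on measured_wire
    z = np.vdot(psi, M @ phi)  # <psi|M|phi>; <phi|M|psi> is just conj(z)
    return float(-2 * COEFF * z.imag)  # i*c*(z - conj(z)) = -2c Im(z)


def expval(op_order, params, wires, measured_wire, n_wires=2):
    """<psi|Z|psi>, used only to cross-check with a finite difference."""
    psi = np.zeros(2**n_wires, dtype=complex)
    psi[0] = 1.0
    for kind, theta, wire in zip(op_order, params, wires):
        psi = gate(kind, theta, wire, n_wires) @ psi
    return float(np.vdot(psi, embed(PAULI[2], measured_wire, n_wires) @ psi).real)


# The challenge's test case
op_order, params = [1, 0, 2, 1, 0, 1], [1.23, 4.56, 7.89, 1.23, 4.56, 7.89]
wires, diff_idx, measured_wire = [1, 0, 1, 1, 1, 0], 0, 1


def shifted(s):
    """Copy of params with theta_{diff_idx} shifted by s."""
    return [p + s if i == diff_idx else p for i, p in enumerate(params)]


d = two_state_derivative(op_order, params, diff_idx, wires, measured_wire)
h = 1e-6
fd = (expval(op_order, shifted(h), wires, measured_wire)
      - expval(op_order, shifted(-h), wires, measured_wire)) / (2 * h)
print(d, fd)
```

On the challenge's test case, `d` should agree with both the expected `-0.2840528` (within `check`'s tolerance) and the central finite difference `fd`. In PennyLane terms, this suggests two QNodes returning `qml.state()` would suffice.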

***Challenge code***

In the code below, you are given a few functions:

![title](gradients_adjDiff4.jpg)

***Inputs***

As input to this problem, you are given:

![title](gradients_adjDiff5.jpg)

***Outputs***

This code will output, as a float, the derivative of the circuit's expectation value with respect to the given parameter.

If your solution matches the correct one within the tolerance specified in `check` (here, a relative tolerance of 1e-4), the output will be "Correct!". Otherwise, you will receive a "Wrong Answer" prompt.
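`check` relies on `np.allclose`. In the notebook, `np` is `pennylane.numpy`, which wraps NumPy; the sketch below uses plain NumPy. The relative tolerance is measured against the second argument (the expected output), and NumPy adds a small default absolute tolerance, `atol=1e-8`. For the test case's expected value, the test works out as follows:

```python
import numpy as np

expected = -0.2840528
# np.allclose(a, b, rtol=1e-4) tests |a - b| <= atol + rtol * |b|, with atol = 1e-8 by default
close = np.allclose(-0.28406, expected, rtol=1e-4)   # |a - b| = 7.2e-6, within ~2.84e-5
too_far = np.allclose(-0.2841, expected, rtol=1e-4)  # |a - b| = 4.72e-5, outside ~2.84e-5
print(close, too_far)  # True False
```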

Good luck!

***Code***

```python
import functools
import json
import math
import pandas as pd
import pennylane as qml
import pennylane.numpy as np
import scipy
```

```python
def generator_info(operator):
    """Provides the generator of a given operator.

    Args:
        operator (qml.ops): A PennyLane operator

    Returns:
        (qml.ops): The generator of the operator.
        (float): The coefficient of the generator.
    """
    gen = qml.generator(operator, format="observable")
    return gen.ops[0], gen.coeffs[0]
```
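`generator_info` returns the generator's Pauli operator and its coefficient. The NumPy-only check below (it makes no PennyLane calls) shows how that pair relates to the gate for `RX`. It assumes two things:

- PennyLane's convention $U(\theta) = e^{i\theta\,\mathrm{coeff}\,P}$.
- `qml.RX` is reported with `PauliX` and a coefficient of `-0.5`.

If your PennyLane version reports something different, adjust `coeff`. The helper names `exp_i_pauli` and `rx` are hypothetical.

```python
import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=complex)


def exp_i_pauli(pauli, alpha):
    """exp(i * alpha * P) for a Pauli matrix P, using P @ P = I."""
    return np.cos(alpha) * np.eye(2) + 1j * np.sin(alpha) * pauli


def rx(theta):
    """Textbook RX(theta) = exp(-i * theta * X / 2)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])


theta, coeff = 0.7, -0.5  # coeff: the value assumed for RX's generator

# 1) The convention U(theta) = exp(i * theta * coeff * P) reproduces RX
same_gate = np.allclose(exp_i_pauli(X, theta * coeff), rx(theta))

# 2) dU/dtheta = i * coeff * P @ U, checked against a central finite difference
h = 1e-6
numeric = (rx(theta + h) - rx(theta - h)) / (2 * h)
same_derivative = np.allclose(numeric, 1j * coeff * X @ rx(theta), atol=1e-8)
print(same_gate, same_derivative)
```

The second check is the identity that lets the derivative be written with the generator inserted into the circuit.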

```python
def derivative(op_order, params, diff_idx, wires, measured_wire):
    """A function that calculates the derivative of a circuit w.r.t. one parameter.

    NOTE: you cannot use qml.grad in this function.

    Args:
        op_order (list(int)):
            This is a list of integers that defines the circuit in question.
            The entries of this list correspond to dictionary keys to op_dict.
            For example, [1,0,2] means that the circuit in question contains
            an RY gate, an RX gate, and an RZ gate in that order.

        params (np.array(float)):
            The parameters that define the gates in the circuit. In this case,
            they're all rotation angles.

        diff_idx (int):
            The index of the gate whose parameter the derivative is taken
            with respect to. For instance, if diff_idx = 2, the derivative
            is taken with respect to the parameter of the third gate.

        wires (list(int)):
            A list of wires that each gate in the circuit will be applied to.

        measured_wire (int):
            The expectation value to be differentiated is that of the Pauli Z
            operator on measured_wire.

    Returns:
        float: The derivative evaluated at the given parameters.
    """
    op_dict = {0: qml.RX, 1: qml.RY, 2: qml.RZ}
    dev = qml.device("default.qubit", wires=2)

    obs = qml.PauliZ(measured_wire)
    # Gate being differentiated and its generator: U(theta) = exp(i * theta * coeff * gen)
    operator = op_dict[op_order[diff_idx]](params[diff_idx], wires[diff_idx])
    gen, coeff = generator_info(operator)

    @qml.qnode(dev)
    def circuit_bra1():
        # |phi>: the full circuit with gen inserted next to the differentiated gate
        for i in range(len(op_order)):
            if i == diff_idx:
                qml.apply(gen)
                qml.apply(operator)
            else:
                op_dict[op_order[i]](params[i], wires[i])
        return qml.state()

    @qml.qnode(dev)
    def circuit_ket1():
        # M|psi>: the full circuit followed by Pauli Z on measured_wire
        for i in range(len(op_order)):
            op_dict[op_order[i]](params[i], wires[i])
        qml.PauliZ(measured_wire)
        return qml.state()

    @qml.qnode(dev)
    def circuit_bra2():
        # |psi>: the full circuit
        for i in range(len(op_order)):
            op_dict[op_order[i]](params[i], wires[i])
        return qml.state()

    @qml.qnode(dev)
    def circuit_ket2():
        # M|phi>: gen inserted after the differentiated gate, then Pauli Z.
        # gen commutes with operator, so this |phi> equals the one in circuit_bra1.
        for i in range(len(op_order)):
            if i == diff_idx:
                qml.apply(operator)
                qml.apply(gen)
            else:
                op_dict[op_order[i]](params[i], wires[i])
        qml.PauliZ(measured_wire)
        return qml.state()

    bra1 = circuit_bra1()
    ket1 = circuit_ket1()
    bra2 = circuit_bra2()
    ket2 = circuit_ket2()

    # d<M>/dtheta = i * coeff * (<psi|M|phi> - <phi|M|psi>), which is real
    return (1j * coeff * (np.vdot(bra2, ket2) - np.vdot(bra1, ket1))).real
```
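The `derivative` function above handles one parameter at a time. The adjoint method proper, which is the idea behind simulator-only options such as PennyLane's `diff_method="adjoint"`, gets every parameter's derivative from one forward sweep and one backward sweep. It works like this:

1. Keep $|\psi\rangle$ and $|\lambda\rangle = M|\psi\rangle$.
2. Walk the gates in reverse.
3. At each gate, read off $2\,\mathrm{Re}\langle\lambda|\,icP_i\,|\psi\rangle$.
4. Un-apply $U_i$ from both states before moving on.

Below is a minimal NumPy sketch of that sweep. It assumes two wires, `RX`/`RY`/`RZ` gates with $U = e^{i\theta c P}$ and $c = -1/2$, and Pauli Z measured on one wire. The helper names are hypothetical, and finite differences are used only as a cross-check.

```python
import numpy as np

I2 = np.eye(2, dtype=complex)
PAULI = {
    0: np.array([[0, 1], [1, 0]], dtype=complex),     # RX axis
    1: np.array([[0, -1j], [1j, 0]], dtype=complex),  # RY axis
    2: np.array([[1, 0], [0, -1]], dtype=complex),    # RZ axis
}
COEFF = -0.5  # assumption: R_P(theta) = exp(i * theta * COEFF * P)


def embed(single, wire, n_wires):
    """Kronecker-embed a 2x2 matrix on `wire` (wire 0 is the leftmost factor)."""
    out = np.ones((1, 1), dtype=complex)
    for w in range(n_wires):
        out = np.kron(out, single if w == wire else I2)
    return out


def gate(kind, theta, wire, n_wires):
    a = theta * COEFF  # exp(i*a*P) = cos(a) I + i sin(a) P
    return embed(np.cos(a) * I2 + 1j * np.sin(a) * PAULI[kind], wire, n_wires)


def adjoint_gradient(op_order, params, wires, measured_wire, n_wires=2):
    """Every d<Z>/d(theta_i) from one forward and one backward sweep."""
    gates = [gate(k, t, w, n_wires) for k, t, w in zip(op_order, params, wires)]
    psi = np.zeros(2**n_wires, dtype=complex)
    psi[0] = 1.0
    for U in gates:  # forward sweep: psi = U_n ... U_1 |0>
        psi = U @ psi
    lam = embed(PAULI[2], measured_wire, n_wires) @ psi  # lambda = M|psi>
    grads = np.zeros(len(gates))
    for i in reversed(range(len(gates))):
        # dU_i/dtheta_i = i*c*P_i U_i, and psi already includes U_i
        mu = 1j * COEFF * (embed(PAULI[op_order[i]], wires[i], n_wires) @ psi)
        grads[i] = 2 * np.vdot(lam, mu).real
        # un-apply gate i from both states before moving to gate i - 1
        psi = gates[i].conj().T @ psi
        lam = gates[i].conj().T @ lam
    return grads


def expval(op_order, params, wires, measured_wire, n_wires=2):
    psi = np.zeros(2**n_wires, dtype=complex)
    psi[0] = 1.0
    for k, t, w in zip(op_order, params, wires):
        psi = gate(k, t, w, n_wires) @ psi
    return np.vdot(psi, embed(PAULI[2], measured_wire, n_wires) @ psi).real


op_order = [1, 0, 2, 1, 0, 1]
params = np.array([1.23, 4.56, 7.89, 1.23, 4.56, 7.89])
wires, measured_wire = [1, 0, 1, 1, 1, 0], 1

grads = adjoint_gradient(op_order, params, wires, measured_wire)
h = 1e-6
fd = np.array([
    (expval(op_order, params + h * e, wires, measured_wire)
     - expval(op_order, params - h * e, wires, measured_wire)) / (2 * h)
    for e in np.eye(len(params))
])
print(np.allclose(grads, fd, atol=1e-6))
```

Only two state vectors are stored, however many gates there are, which is why this approach scales well on a simulator. It needs direct access to the state vector, however, so it is not hardware compatible. Here `grads[0]` corresponds to the test case's `diff_idx = 0`.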

```python
# These functions are responsible for testing the solution.

def run(test_case_input: str) -> str:
    op_order, params, diff_idx, wires, measured_wire = json.loads(test_case_input)
    params = np.array(params, requires_grad=True)
    der = derivative(op_order, params, diff_idx, wires, measured_wire)
    return str(der)

def check(solution_output: str, expected_output: str) -> None:
    solution_output = json.loads(solution_output)
    expected_output = json.loads(expected_output)
    assert np.allclose(
        solution_output, expected_output, rtol=1e-4
    ), "Your derivative isn't quite right!"
```

```python
test_cases = [['[[1,0,2,1,0,1], [1.23, 4.56, 7.89, 1.23, 4.56, 7.89], 0, [1, 0, 1, 1, 1, 0], 1]', '-0.2840528']]
```

```python
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)

    except Exception as exc:
        print(f"Runtime Error. {exc}")

    else:
        try:
            # check() returns None on success and raises AssertionError on a mismatch
            check(output, expected_output)

        except AssertionError:
            print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")

        else:
            print("Correct!")
```

```
Running test case 0 with input '[[1,0,2,1,0,1], [1.23, 4.56, 7.89, 1.23, 4.56, 7.89], 0, [1, 0, 1, 1, 1, 0], 1]'...
Correct!
```
