In [1]:
import pennylane as qml
from pennylane import numpy as np

In [2]:
dev = qml.device(name = 'default.qubit', wires = 5, shots = 1000)

In [3]:
@qml.qnode(dev)

# RYRZ linear entanglement
def circuit(param, wires):
    for w in range(wires):
        qml.RY(param [0, w], wires = w)
        qml.RZ(param[1, w], wires = w)
        
    for w in range(wires-1):
        qml.CNOT(wires = [w, w+1])
        
    for w in range(wires):
        qml.RY(param [0, w+5], wires = w)
        qml.RZ(param[1, w+5], wires = w)
        
    for w in range(wires-1):
        qml.CNOT(wires = [w, w+1])
        
    return qml.expval(qml.PauliZ(wires = 0))

In [12]:
def psr(qnode, param, wires, i, j):
    shift_param_f = param.copy()
    shift_param_b = param.copy()
    shift = np.pi/2
    
    # Forward
    shift_param_f[i, j] = shift_param_f[i, j] + shift
    forward = qnode(shift_param_f, wires)
    
    # Backward
    shift_param_b[i, j] = shift_param_b[i, j] - shift
    backward = qnode(shift_param_b, wires)
    
    gradient = (forward - backward) / (2*np.sin(shift))
    
    return gradient

In [16]:
wires = 5
np.random.seed(100)
param = np.random.uniform(low = -np.pi/2, high = np.pi, size = (2, 10))
gradient = np.zeros_like(param)

# print(gradient)
# print(param)
# print(param[0])

# [[np.pi, np.pi, np.pi, np.pi, np.pi, np.pi, np.pi, np.pi, np.pi, np.pi, np.pi], [np.pi, np.pi, np.pi, np.pi, np.pi, np.pi, np.pi, np.pi, np.pi, np.pi, np.pi]]

print(circuit(param, wires))
print(circuit.draw())

for i in range(2):
    for j in range(len(param)):
        gradient[i, j] = psr(circuit, param, wires, i, j)
    
print(gradient)

# print(circuit(gradient, wires))
# print(circuit.draw())

0.42821777499282
 0: ──RY(0.99)────RZ(2.63)────╭C───RY(-0.998)───RZ(3.04)─────────────╭C────────────────────────────────┤ ⟨Z⟩ 
 1: ──RY(-0.259)──RZ(-0.585)──╰X──╭C────────────RY(1.59)───RZ(2.25)──╰X───────────╭C───────────────────┤     
 2: ──RY(0.43)────RZ(-0.697)──────╰X───────────╭C──────────RY(2.32)───RZ(-0.761)──╰X───────────╭C──────┤     
 3: ──RY(2.41)────RZ(-1.06)────────────────────╰X─────────╭C──────────RY(-0.927)───RZ(2.28)────╰X──╭C──┤     
 4: ──RY(-1.55)───RZ(-0.535)──────────────────────────────╰X──────────RY(1.14)─────RZ(-0.279)──────╰X──┤     

[[-0.36731758 -0.49353969  0.          0.          0.          0.
   0.          0.          0.          0.        ]
 [ 0.07351399  0.08660538  0.          0.          0.          0.
   0.          0.          0.          0.        ]]
