In [16]:
import pennylane as qml
from scipy.stats import unitary_group
from pennylane import numpy as np

In [3]:
# PennyLane "default.qubit" device with three wires (0, 1, 2); the circuits
# below act on wire pairs 0-1 and 1-2.
dev = qml.device("default.qubit", wires=3)

# Fixed seed so U1/U2 -- and therefore every decomposition angle and Jacobian
# printed below -- are reproducible on Restart & Run All. Each matrix gets its
# own seed so the two unitaries differ.
SEED = 1234
U1 = unitary_group.rvs(4, random_state=SEED)      # random 4x4 unitary for wires 0-1
U2 = unitary_group.rvs(4, random_state=SEED + 1)  # random 4x4 unitary for wires 1-2

In [5]:
# Queue the circuit onto a QuantumTape: U1 on wires 0-1, then U2 on wires 1-2,
# followed by a probability measurement.
# NOTE(review): qml.probs() is called without wires -- presumably it measures all
# wires the tape uses; confirm for the installed PennyLane version.
with qml.tape.QuantumTape() as tape:
    qml.QubitUnitary(U1, wires=[0, 1])
    qml.QubitUnitary(U2, wires=[1, 2])
    qml.probs()

In [6]:
# Show the recorded operations: the two QubitUnitary gates built above.
tape.operations

[QubitUnitary(array([[ 0.00779013-0.01666915j, -0.36930122+0.24534717j,
          0.1458749 -0.41092026j, -0.66516149+0.41292623j],
        [-0.64012655-0.34188968j,  0.17840924-0.03056921j,
          0.24723446+0.50490685j, -0.17768434+0.30488946j],
        [ 0.20979937+0.64549479j,  0.23467605+0.22889363j,
         -0.1245952 +0.52465844j, -0.36961346+0.06671577j],
        [-0.10605003-0.03293913j,  0.81436049-0.0029827j ,
          0.04153679-0.44865743j, -0.28862948-0.19532212j]]), wires=[0, 1]),
 QubitUnitary(array([[-0.44423474-0.32846843j,  0.10411459-0.60787909j,
         -0.4150412 -0.13308091j, -0.31221463-0.16419333j],
        [ 0.0322502 +0.01649398j, -0.63866283+0.09709474j,
         -0.10444914-0.40949807j,  0.1697382 -0.61152344j],
        [ 0.20954259+0.68929193j,  0.31847303-0.3059779j ,
          0.03139871+0.13318071j, -0.16169231-0.49097211j],
        [ 0.35371033+0.2220572j ,  0.06340435-0.05707805j,
         -0.74646364-0.23558658j,  0.31826407+0.32295599j]]), wir

In [7]:
# Apply the unitary_to_rot transform to the tape; the result (printed in the next
# cell) replaces each QubitUnitary with Rot/RZ/RY rotations and CNOTs.
# NOTE(review): `.tape_fn` is a version-specific entry point -- confirm it exists
# in the installed PennyLane release.
expanded_tape = qml.transforms.unitary_to_rot.tape_fn(tape)

In [9]:
# Inspect the decomposed gate list: each Rot carries 3 angles, RZ/RY carry 1,
# and CNOT carries none.
expanded_tape.operations

[Rot(array(-1.38913399), array(1.01970713), array(3.42668141), wires=[0]),
 Rot(array(-4.23939011), array(2.67411729), array(-0.03489373), wires=[1]),
 CNOT(wires=[1, 0]),
 RZ(0.15033000679074382, wires=[0]),
 RY(-0.6531853319509663, wires=[1]),
 CNOT(wires=[0, 1]),
 RY(-1.8309629027753909, wires=[1]),
 CNOT(wires=[1, 0]),
 Rot(array(1.13116896), array(1.75394267), array(-3.68634873), wires=[1]),
 Rot(array(1.39085061), array(1.08226276), array(-3.92630864), wires=[0]),
 Rot(array(-2.29633497), array(1.18777498), array(-0.68979513), wires=[1]),
 Rot(array(-0.03439678), array(2.20200691), array(5.4201391), wires=[2]),
 CNOT(wires=[2, 1]),
 RZ(0.1088442502568816, wires=[1]),
 RY(-0.6112932324074063, wires=[2]),
 CNOT(wires=[1, 2]),
 RY(-0.8366700004621568, wires=[2]),
 CNOT(wires=[2, 1]),
 Rot(array(-1.34977076), array(1.35615617), array(-0.00550825), wires=[2]),
 Rot(array(0.1673565), array(1.97030036), array(-5.58440811), wires=[1])]

In [24]:
def tw0_qubit_template(params, wires=(0, 1)):
    """Apply the 15-parameter two-qubit gate sequence emitted by ``unitary_to_rot``.

    The gate order mirrors the decomposition printed for each ``QubitUnitary``
    in ``expanded_tape.operations``: a ``Rot`` on each wire, three CNOTs
    interleaved with ``RZ``/``RY`` rotations, then a final ``Rot`` on each wire.

    NOTE(review): the name keeps the original ``tw0`` spelling (digit zero) so
    existing references still resolve; ``two_qubit_template`` was presumably
    intended.

    Args:
        params: sequence of 15 angles consumed in gate order --
            ``Rot`` on a: 0-2, ``Rot`` on b: 3-5, ``RZ`` on a: 6, ``RY`` on b: 7,
            ``RY`` on b: 8, ``Rot`` on b: 9-11, ``Rot`` on a: 12-14.
        wires: the two wires ``(a, b)`` to act on; defaults to ``(0, 1)``.

    Raises:
        ValueError: if ``params`` does not hold exactly 15 angles or ``wires``
            does not name exactly two wires.
    """
    if len(params) != 15:
        raise ValueError(f"expected 15 parameters, got {len(params)}")
    if len(wires) != 2:
        raise ValueError(f"expected 2 wires, got {len(wires)}")
    a, b = wires
    qml.Rot(*params[0:3], wires=a)
    qml.Rot(*params[3:6], wires=b)
    qml.CNOT(wires=[b, a])
    qml.RZ(params[6], wires=a)
    qml.RY(params[7], wires=b)
    qml.CNOT(wires=[a, b])
    qml.RY(params[8], wires=b)
    qml.CNOT(wires=[b, a])
    qml.Rot(*params[9:12], wires=b)
    qml.Rot(*params[12:15], wires=a)

@qml.qnode(dev)
def my_qnode(params):
    """Rebuild the gates of ``expanded_tape`` with angles taken from ``params``.

    Walks ``expanded_tape.operations`` in order and re-creates each gate with the
    same class and wires. Each gate takes as many consecutive entries of
    ``params`` as it has parameters (``op.num_params``): 3 for ``Rot``, 1 for
    ``RZ``/``RY``, 0 for ``CNOT``, and so on for any other gate type. This lets
    the whole decomposition be differentiated with respect to one flat vector.

    Args:
        params: flat array of gate parameters in tape order; its length must
            equal the total parameter count of ``expanded_tape``.

    Returns:
        The ``qml.probs()`` measurement of the rebuilt circuit.

    Raises:
        ValueError: if ``len(params)`` differs from the tape's parameter count.
    """
    # NOTE(review): reads the notebook-global ``expanded_tape``; re-run the cell
    # that builds it before re-running this one.
    operations = expanded_tape.operations
    n_expected = sum(op.num_params for op in operations)
    if len(params) != n_expected:
        raise ValueError(
            f"expected {n_expected} parameters for expanded_tape, got {len(params)}"
        )

    idx = 0
    for op in operations:
        # Slice by the gate's own parameter count instead of assuming "Rot" -> 3
        # and everything else -> 1; an empty slice covers parameter-free gates.
        n = op.num_params
        type(op)(*params[idx:idx + n], wires=op.wires)
        idx += n
    return qml.probs()

In [28]:
# Gather every gate angle of the expanded tape into a single trainable array,
# in tape order -- the same order in which my_qnode consumes them.
# requires_grad=True marks it as the argument qml.jacobian differentiates.
params = np.array(qml.math.stack(expanded_tape.data), requires_grad=True)

In [30]:
# Jacobian of my_qnode's probability output with respect to every entry of
# `params` (one column per gate angle, 30 here).
qml.jacobian(my_qnode)(params)

array([[-1.04083409e-17, -4.47429986e-02,  5.76977201e-02,
        -1.12757026e-17, -5.03054114e-02,  6.35442354e-03,
        -1.14327984e-01,  9.56227013e-02, -1.29894292e-02,
        -9.52183783e-02,  8.24783560e-02,  6.19254186e-03,
         9.28033332e-02,  4.47875538e-02,  2.77555756e-17,
         6.19254186e-03,  5.04578869e-02,  1.09738615e-01,
         2.60208521e-17,  7.55596442e-02, -1.19645542e-02,
        -1.41725362e-01, -1.11721891e-01,  6.59518086e-03,
         7.81320397e-02, -5.32048076e-02,  1.38777878e-17,
        -1.22211632e-02,  2.90169683e-02,  0.00000000e+00],
       [ 2.42861287e-17,  3.31358148e-02,  5.97539647e-02,
         2.77555756e-17,  2.45334084e-02, -2.85562126e-03,
        -6.72784735e-02, -1.82282893e-03, -3.01946228e-02,
        -4.24534161e-03, -9.02159565e-03, -2.53587822e-04,
        -1.14301489e-02,  1.76247182e-02,  8.67361738e-18,
        -2.53587822e-04, -9.90921640e-03,  1.29519394e-03,
         0.00000000e+00,  1.23379603e-01, -6.03451326e-