# New Device API Prototype

In [1]:
import pennylane as qml
from pennylane import numpy as np
from pennylane.devices.experimental import TestDevicePythonSim

Note: this only works with the new return types workflow:

In [2]:
# The prototype device only works with the new return-types workflow, so switch it on first.
qml.enable_return()

In [3]:
# Instance of the prototype device; it is patched with legacy attributes in the next cell.
dev = TestDevicePythonSim()

Add some attributes and properties so the device matches the existing required device interface.

We will have to adjust the existing workflow to accommodate these changes.

In [4]:
# Temporary shim: the current QNode/execution pipeline still looks up these old-style
# device attributes, so alias or stub them on the new device instance.
_legacy_interface = {
    "batch_execute": dev.execute,  # old name for (batched) execution
    "batch_transform": dev.preprocess,  # preprocessing stands in for the batch transform
    "expand_fn": lambda circuit, max_expansion: circuit,  # no-op: circuit returned unchanged
    "shots": None,  # no finite shots configured
    "_shot_vector": [],
    "shot_vector": None,
    "short_name": "testpython",
}
for _attr, _value in _legacy_interface.items():
    setattr(dev, _attr, _value)
del _legacy_interface, _attr, _value

### Let's try a device gradient!

In [33]:
@qml.qnode(dev, diff_method="device")
def circuit(a):
    """Two-qubit circuit: RX on wire 0, CNOT(0, 1), then RY and RZ on wire 1; returns <X> on wire 1.

    ``diff_method="device"`` asks the device itself to supply the gradient.
    """
    # Operations are queued as soon as they are constructed inside the QNode,
    # so there is no need to collect them in a (previously unused) list.
    qml.RX(a[0], wires=0)
    qml.CNOT(wires=(0, 1))
    qml.RY(a[1], wires=1)
    qml.RZ(a[2], wires=1)
    return qml.expval(qml.PauliX(1))

x = np.array([1.2, 2.3, 3.4])

In [34]:
# Forward pass: evaluate the QNode at `x`.
circuit(x)

[-0.26124053720169715]

In [35]:
# Gradient with respect to all three parameters, computed by the device.
qml.grad(circuit)(x)

array([0.67195027, 0.23341436, 0.06905029])

## Parameter Shift?

In [36]:
@qml.qnode(dev, diff_method=qml.gradients.param_shift)
def circuit(a):
    """Same two-qubit circuit as above, differentiated with the parameter-shift rule instead."""
    # Operations are queued on construction; no list needed.
    qml.RX(a[0], wires=0)
    qml.CNOT(wires=(0, 1))
    qml.RY(a[1], wires=1)
    qml.RZ(a[2], wires=1)
    return qml.expval(qml.PauliX(1))

x = np.array([1.2, 2.3, 3.4])

In [37]:
# Should agree with the device-gradient result above.
qml.grad(circuit)(x)

array([0.67195027, 0.23341436, 0.06905029])

### How about backprop with Jax jit?

In [40]:
import jax
from jax import numpy as jnp

In [41]:
x = jnp.array([1.2, 2.3, 3.4])

@jax.jit
@qml.qnode(dev, interface="jax", diff_method="backprop")
def circuit(a):
    """Same two-qubit circuit, jitted with JAX and differentiated by backpropagation."""
    # Operations are queued on construction; no list needed.
    qml.RX(a[0], wires=0)
    qml.CNOT(wires=(0, 1))
    qml.RY(a[1], wires=1)
    qml.RZ(a[2], wires=1)
    return qml.expval(qml.PauliX(1))

In [42]:
# Jacobian through the jitted QNode; matches the previous gradient results.
jax.jacobian(circuit)(x)

[DeviceArray([0.67195027, 0.23341436, 0.06905029], dtype=float64)]

No substitution of the device at the QNode level!

The device just dispatches to a different simulator.

In [13]:
# The QNode still holds the exact device object it was constructed with.
circuit.device is dev

True

## More complicated measurement processes?

The simulator can just use the `StateMeasurement.process_state` method.

In [14]:
@qml.qnode(dev, diff_method=None)
def circuit_mutual(x):
    """Entangle wires 0 and 1 with IsingXX(x) and return the mutual information between them."""
    qml.IsingXX(x, wires=[0, 1])
    return qml.mutual_info(wires0=[0], wires1=[1])

circuit_mutual(np.pi/2)

[1.3862943611198906]

## Device Tracking

In [15]:
@qml.qnode(dev, diff_method="device")
def circuit(a):
    ops =[qml.RX(a[0], wires=0),
    qml.CNOT(wires=(0,1)),
    qml.RY(a[1], wires=1),
    qml.RZ(a[2], wires=1)]
    return qml.expval(qml.PauliX(1))

x = qml.numpy.array([1.2, 2.3, 3.4])

In [16]:
def callback(totals=None, history=None, latest=None):
    """Tracker callback that prints the running totals each time it is invoked.

    ``history`` and ``latest`` are accepted to match the keyword arguments the
    tracker passes, but are not used here.
    """
    print("Totals: ", totals)

# Track one forward pass and one gradient computation on `dev`.
with qml.Tracker(dev, callback=callback) as tracker:
    circuit(x)
    qml.grad(circuit)(x)

Totals:  {'batches': 1, 'batch_len': 1}
Totals:  {'batches': 1, 'batch_len': 1, 'executions': 1, 'results': -0.26124053720169715}
Totals:  {'batches': 1, 'batch_len': 1, 'executions': 1, 'results': -0.26124053720169715, 'gradients': 1}
Totals:  {'batches': 2, 'batch_len': 2, 'executions': 1, 'results': -0.26124053720169715, 'gradients': 1}
Totals:  {'batches': 2, 'batch_len': 2, 'executions': 2, 'results': -0.5224810744033943, 'gradients': 1}
Totals:  {'batches': 2, 'batch_len': 2, 'executions': 2, 'results': -0.5224810744033943, 'gradients': 2}


## Native execution of non-commuting observables?

This is easily handled at the simulator level.

Diagonalizing gates are applied when taking a measurement, not when executing the circuit.

In [65]:
@qml.qnode(dev, diff_method=None)
def circuit(a):
    """Return <X> and <Z> on wire 0 after RX(a); the two observables do not commute."""
    qml.RX(a, 0)
    return qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(0))

# Both expectation values come from a single execution (checked via the tracker below).
with qml.Tracker(dev) as tracker:
    print("Execution: ", circuit(1.2))
    
tracker.totals['executions']

Execution:  [(0.0, 0.36235775447667357)]


1

## Arbitrary wire labels?

In [18]:
@qml.qnode(dev)
def circuit(a):
    """Single-qubit circuit that uses the string wire label "a"."""
    qml.RX(a, "a")
    return qml.expval(qml.PauliZ("a"))

circuit(1.2)

[0.36235775447667357]

Preprocessing can map wire labels to consecutive integers starting from zero. Simulators can then
treat wire labels directly as indices!

In [32]:
# A script with arbitrary wire labels ("a" and 10) and no measurements.
qs = qml.tape.QuantumScript([qml.PauliX("a"), qml.PauliY(10)])
# preprocess returns a batch of scripts plus a post-processing function;
# the operations in the first script are shown on the relabelled wires 0 and 1.
qbatch, post_processing_fn = dev.preprocess(qs)
qbatch[0].circuit

[PauliX(wires=[0]), PauliY(wires=[1])]

# Preprocessing of Script

In [45]:
@qml.qnode(dev, diff_method=None)
def circuit(params):
    """Four-wire StronglyEntanglingLayers template; returns <Z> on wire 3."""
    qml.StronglyEntanglingLayers(params, wires=(0,1,2,3))
    return qml.expval(qml.PauliZ(3))

In [46]:
n_layers = 4
# n_wires must match the four wires used by `circuit` above.
shape = qml.StronglyEntanglingLayers.shape(n_layers=n_layers, n_wires=4)

# Seeded generator so the random parameters (and the printed result) are reproducible.
rng = np.random.default_rng(seed=42)
params = rng.random(shape)

In [47]:
# Evaluate the template circuit with the seeded random parameters.
circuit(params)

[0.246704388316073]

Preprocessing expands the circuit until it reaches supported operations:

In [48]:
# Run device preprocessing on the QNode's tape (presumably the tape from the most
# recent call above) and draw the expanded script the device will execute.
batched_qs, post_process_fn = dev.preprocess(circuit.tape)

print(qml.drawer.tape_text(batched_qs[0]))

0: ──Rot─╭●───────╭X──Rot─╭●────╭X──Rot──────╭●─╭X──Rot──────╭●─────────╭X─┤     
1: ──Rot─╰X─╭●────│───Rot─│──╭●─│──╭X────Rot─│──╰●─╭X────Rot─╰X───╭●────│──┤     
2: ──Rot────╰X─╭●─│───Rot─╰X─│──╰●─│─────Rot─│─────╰●───╭X────Rot─╰X─╭●─│──┤     
3: ──Rot───────╰X─╰●──Rot────╰X────╰●────Rot─╰X─────────╰●────Rot────╰X─╰●─┤  <Z>


## Unsupportable Quantum Script?

In [23]:
%xmode Minimal

@qml.qnode(dev, diff_method=None)
def circuit(theta, phi):
    qml.Beamsplitter(theta, phi, wires=(0,1))
    return qml.expval(qml.PauliX(0))

circuit(1.2, 2.3)


Exception reporting mode: Minimal


NotImplementedError: Beamsplitter(1.2, 2.3, wires=[0, 1]) not supported on device

The device can also perform additional forms of validation, such as enforcing a maximum number of qubits:

In [50]:
@qml.qnode(dev)
def circuit():
    """Apply PauliX on 50 wires, more than the device's maximum qubit count."""
    # Plain loop: the ops are queued on construction, so a list comprehension
    # whose result is discarded is only used for side effects.
    for i in range(50):
        qml.PauliX(i)
    return qml.expval(qml.PauliX(0))

# The device rejects the oversized circuit with NotImplementedError. Catch it so the
# notebook still runs top to bottom, and display the device's message.
try:
    circuit()
except NotImplementedError as err:
    print(f"{type(err).__name__}: {err}")

Exception reporting mode: Minimal


NotImplementedError: Requested execution with 50 qubits. We support at most 30.

# What does this look like internally?

### Separation of driver from interface

The device is just the interface. Implementation details, such as simulators or hardware drivers, can be handled at an additional level of abstraction:

In [51]:
from pennylane.devices.experimental import PlainNumpySimulator, JaxSimulator

In [52]:
# The simulator backend used on its own, independently of any device.
jax_sim = JaxSimulator()

In [53]:
# List the simulator's public methods (names not starting with an underscore).
[obj for obj in dir(jax_sim) if obj[0] != "_"]

['apply_matrix',
 'apply_matrix_einsum',
 'apply_matrix_tensordot',
 'apply_operation',
 'create_state_vector_state',
 'create_zeroes_state',
 'execute',
 'measure']

This separation improves the documentation and ease of development for the simulator.

In [54]:
# One-qubit all-zeros state vector.
state = jax_sim.create_zeroes_state(1)
print(state)
# Flip the qubit with PauliX on wire 0.
state = jax_sim.apply_operation(state, qml.PauliX(0))
print(state)
# Measure <Z> on the flipped state.
output = jax_sim.measure(state, qml.expval(qml.PauliZ(0)))
print(output)

[1.+0.j 0.+0.j]
[0.+0.j 1.+0.j]
-1.0


## Required device Interface

In [55]:
# A new instance, without the legacy attributes patched onto `dev`, shows the
# public interface the prototype device defines.
fresh_dev = TestDevicePythonSim()

[obj for obj in dir(fresh_dev) if obj[0] != "_"]

['capabilities',
 'execute',
 'execute_and_gradients',
 'gradient',
 'preprocess',
 'register_execute',
 'register_fn',
 'register_gradient',
 'registrations',
 'tracker',
 'vjp']

In [56]:
# The prototype declares no capabilities (an empty dict).
fresh_dev.capabilities()

{}

## Execution Config

Getting the existing workflow to support this will require additional work.

In [60]:
from pennylane.runtime import ExecutionConfig

In [63]:
# Fields that are not passed keep their defaults, as shown in the repr.
config = ExecutionConfig(shots=100, interface="jax")
config

ExecutionConfig(shots=100, grad=None, preproc=None, postproc=None, interface='jax', diff_method=<DiffType.DEVICE: 2>, cache_size=10000, max_expansion=10, max_diff=1, grad_args={}, expansion_strategy=<ExpansionStrategy.GRADIENT: 2>)

In [64]:
# Build the script in this cell instead of reusing `qs`. That name is bound to a
# measurement-less script earlier and rebound further down, so its value here
# depended on execution order and broke Restart & Run All.
rx_script = qml.tape.QuantumScript([qml.RX(1.2, wires=0)], [qml.expval(qml.PauliZ(0))])
dev.execute(rx_script, config)

0.36235775447667357

## Registrations

In [57]:
# Registry of function type -> {order: implementation}; initially only a first-order gradient.
fresh_dev.registrations

{<FnType.GRADIENT: 4>: {1: <function pennylane.devices.experimental.custom_device_3_numpydev.python_device.gradient(self, qscript: pennylane.tape.qscript.QuantumScript, order: int = 1)>}}

In [58]:
@TestDevicePythonSim.register_gradient(order=2)
def hessian(self, qscript, order: int = 2):
    """Toy second-order gradient, registered on the device class; it only prints a message."""
    print("look! I'm computing the hessian!")

# Created after registration, so this instance can dispatch order=2 to `hessian`.
fresh_dev2 = TestDevicePythonSim()
qs = qml.tape.QuantumScript([qml.RX(1.2, wires=0)], [qml.expval(qml.PauliZ(0))])

fresh_dev2.gradient(qs, order=2)

# No third-order gradient is registered, so the device raises ValueError. Catch it
# so the notebook keeps running end to end, and display the message.
try:
    fresh_dev2.gradient(qs, order=3)
except ValueError as err:
    print(f"{type(err).__name__}: {err}")

look! I'm computing the hessian!


ValueError: Device does not support 3 order derivatives

In [59]:
# The gradient registry now holds both the built-in order-1 gradient and `hessian` for order 2.
fresh_dev2.registrations[qml.devices.experimental.FnType.GRADIENT]

{1: <function pennylane.devices.experimental.custom_device_3_numpydev.python_device.gradient(self, qscript: pennylane.tape.qscript.QuantumScript, order: int = 1)>,
 2: <function __main__.hessian(self, qscript, order: int = 2)>}