In [1]:
import numpy as np
import pennylane as qml
from jax import numpy as jnp

In [2]:
dev = qml.device('lightning.qubit', wires=2)

@qml.qjit
@qml.qnode(dev)
def circuit1(x):
    """Apply RX(x) on wire 0 and RX(x**2) on wire 1; return <Z> on wire 0.

    There is no gate coupling the two wires, so only the wire-0 rotation
    affects the returned value, which is cos(x) analytically.
    """
    qml.RX(x, wires=0)
    qml.RX(x**2, wires=1)
    return qml.expval(qml.PauliZ(0))

circuit1_result = circuit1(0.5)
# <Z> after RX(x)|0> is cos(x): check the compiled circuit against it.
np.testing.assert_allclose(circuit1_result, np.cos(0.5))
circuit1_result

array(0.87758256)

In [5]:
# The saved output of this cell documents `qml.for_loop` (the construct used in
# the loop cell further down), so ask for that function's help so that the
# cell reproduces its recorded output on a fresh run.
help(qml.for_loop)

Help on function for_loop in module pennylane.qjit_compile:

for_loop(*args, **kwargs)
    A :func:`~.qjit` compatible for-loop decorator for PennyLane/Catalyst.
    
    This for-loop representation is a functional version of the traditional
    for-loop, similar to ``jax.cond.fori_loop``. That is, any variables that
    are modified across iterations need to be provided as inputs/outputs to
    the loop body function:
    
    - Input arguments contain the value of a variable at the start of an
      iteration.
    
    - output arguments contain the value at the end of the iteration. The
      outputs are then fed back as inputs to the next iteration.
    
    The final iteration values are also returned from the transformed
    function.
    
    This form of control flow can also be called from the Python interpreter without needing to use
    :func:`~.qjit`.
    
    The semantics of ``for_loop`` are given by the following Python pseudo-code:
    
    .. code-block:: python
    


In [2]:
# Same gates as the first `circuit1` cell, but compiled with target="mlir" so the
# generated MLIR can be printed (see the listing below).
# NOTE(review): this re-binds `dev` and `circuit1`, shadowing the earlier cell's
# definitions; re-running the earlier `circuit1(0.5)` call afterwards would use this one.
dev = qml.device('lightning.qubit', wires=2)

@qml.qjit(target="mlir")
@qml.qnode(dev)
def circuit1(x: float):
    """Apply RX(x) on wire 0 and RX(x**2) on wire 1; return <Z> on wire 0."""
    # The `float` annotation shows up as the `tensor<f64>` argument in the MLIR;
    # presumably it is also what lets `.mlir` be read without calling the
    # function first — confirm against the Catalyst qjit docs.
    qml.RX(x, wires=0)
    qml.RX(x**2, wires=1)
    return qml.expval(qml.PauliZ(0))

print(circuit1.mlir)

module @circuit1 {
  func.func public @jit_circuit1(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    %0 = call @circuit1(%arg0) : (tensor<f64>) -> tensor<f64>
    return %0 : tensor<f64>
  }
  func.func private @circuit1(%arg0: tensor<f64>) -> tensor<f64> attributes {diff_method = "finite-diff", llvm.linkage = #llvm.linkage<internal>, qnode} {
    "quantum.device"() {specs = ["kwargs", "{'shots': 0}"]} : () -> ()
    "quantum.device"() {specs = ["backend", "lightning.qubit"]} : () -> ()
    %0 = stablehlo.constant dense<2> : tensor<i64>
    %1 = "quantum.alloc"() {nqubits_attr = 2 : i64} : () -> !quantum.reg
    %2 = stablehlo.multiply %arg0, %arg0 : tensor<f64>
    %3 = stablehlo.constant dense<0> : tensor<i64>
    %4 = "tensor.extract"(%3) : (tensor<i64>) -> i64
    %5 = "quantum.extract"(%1, %4) : (!quantum.reg, i64) -> !quantum.bit
    %6 = "tensor.extract"(%arg0) : (tensor<f64>) -> f64
    %7 = "quantum.custom"(%6, %5) {gate_name = "RX", operand_segment_

In [3]:
dev = qml.device('lightning.qubit', wires=2)

@qml.qjit
@qml.qnode(dev)
def circuit2(n: int):
    """Apply n successive RY rotations to wire 0 and return <Z> on wire 0.

    The rotation angle starts at 0.0 and grows by pi/4 after every step,
    so step k (k = 0, ..., n-1) rotates by k * pi/4.
    """
    @qml.for_loop(0, n, 1)
    def rotate_step(_step, angle):
        # The returned angle is fed back in as `angle` on the next iteration.
        qml.RY(angle, wires=0)
        next_angle = angle + np.pi / 4
        return next_angle

    rotate_step(0.0)
    return qml.expval(qml.PauliZ(0))

print(circuit2(5))

3.3306690738754696e-16


In [4]:
dev = qml.device('lightning.qubit', wires=1)

@qml.qjit
@qml.qnode(dev)
def circuit(x):
    """Return x scaled by a factor picked by a qml.cond if/elif/else chain.

    x > 4.8 -> 8*x; else x > 2.7 -> 4*x; else x > 1.4 -> 2*x; otherwise x.
    No gates are applied: the QNode returns the classical branch result.
    """
    @qml.cond(x > 4.8)
    def cond_fn():
        return x * 8

    @cond_fn.else_if(x > 2.7)
    def cond_elif():
        return x * 4

    @cond_fn.else_if(x > 1.4)
    def cond_elif2():
        return x * 2

    @cond_fn.otherwise
    def cond_else():
        return x

    return cond_fn()

# Exercise every branch of the conditional, not only the first one.
assert circuit(5) == 40  # x > 4.8
assert circuit(3) == 12  # 2.7 < x <= 4.8
assert circuit(2) == 4   # 1.4 < x <= 2.7
assert circuit(1) == 1   # x <= 1.4 (otherwise branch)


In [7]:
# Measure Test
# A qubit flipped to |1> and then measured mid-circuit must still be in |1>,
# so the final probabilities on that wire are [0, 1].

dev = qml.device("lightning.qubit", wires=3)

@qml.qjit
@qml.qnode(dev)
def func():
    """Flip wire 1, measure it mid-circuit, and return its probabilities."""
    qml.PauliX(1)
    qml.measure(1)  # the mid-circuit measurement under test; its outcome is not needed
    return qml.probs(wires=[1])

measured_probs = func()
np.testing.assert_allclose(measured_probs, [0.0, 1.0])
measured_probs

array([0., 1.])

In [10]:
# Measure Test
# A mid-circuit measurement drives classical control: measure wire 0 and, when
# the outcome is 1, rotate wire 1 by pi so that measuring it also gives 1.

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x: float):
    """Measure wire 0 after RX(x), conditionally flip wire 1, return its measurement.

    RX(pi) takes |0> to |1> (up to phase), so circuit(pi) returns a truthy
    result and circuit(0.0) a falsy one.
    """
    qml.RX(x, wires=0)
    m1 = qml.measure(wires=0)
    # m1 * pi is pi when the wire-0 outcome is 1 and 0 otherwise.
    maybe_pi = m1 * jnp.pi
    qml.RX(maybe_pi, wires=1)
    m2 = qml.measure(wires=1)
    return m2

assert circuit(jnp.pi)
assert not circuit(0.0)


In [3]:
# Gradient Test
def f(x):
    """Quantum function: RX(x) on wire 0, returning <Y> on wire 0.

    Analytically <Y> = -sin(x), so d<Y>/dx = -cos(x).
    """
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliY(0))

@qml.qjit
def grad_fn(x: float):
    """Return d<Y>/dx of `f`, wrapped as a QNode on a 1-wire lightning.qubit device, at x."""
    g = qml.qnode(qml.device("lightning.qubit", wires=1))(f)
    h = qml.grad(g, argnum=0)
    return h(x)

grad_x = 0.526
grad_value = grad_fn(grad_x)
# Compare with the analytic derivative; atol leaves room for a numerically
# approximated (e.g. finite-difference) gradient.
np.testing.assert_allclose(grad_value, -np.cos(grad_x), atol=1e-5)
grad_value

array(-0.86482227)