<a href="https://colab.research.google.com/github/FLjv77/Quantum_ML_Course/blob/main/codes/Catalyst.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Installations

In [1]:
# Install Catalyst into the kernel's own environment. `%pip` (rather than a bare
# `pip` relying on automagic) makes the magic explicit, `-q` keeps the output
# short, and the pin matches the version this notebook was last run against.
%pip install -q pennylane-catalyst==0.11.0

Collecting pennylane-catalyst
  Downloading pennylane_catalyst-0.11.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (10 kB)
Collecting pennylane>=0.41.0 (from pennylane-catalyst)
  Downloading PennyLane-0.41.1-py3-none-any.whl.metadata (10 kB)
Collecting pennylane-lightning>=0.41.0 (from pennylane-catalyst)
  Downloading pennylane_lightning-0.41.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (12 kB)
Collecting jax==0.4.28 (from pennylane-catalyst)
  Downloading jax-0.4.28-py3-none-any.whl.metadata (23 kB)
Collecting jaxlib==0.4.28 (from pennylane-catalyst)
  Downloading jaxlib-0.4.28-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting scipy-openblas32>=0.3.26 (from pennylane-catalyst)
  Downloading scipy_openblas32-0.3.29.265.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.1/56.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting diastatic-malt>=2.15.2 (from pennyl

In [2]:
# PennyLane is already pulled in as a dependency of pennylane-catalyst; this cell
# only makes the requirement explicit. `%pip` targets the running kernel's
# environment, whereas `!pip` runs in a subshell that may point elsewhere.
%pip install -q "pennylane>=0.41.0"



### Imports

In [6]:
from catalyst import qjit, measure, cond, for_loop, while_loop, grad
import pennylane as qml
from jax import numpy as jnp

### Examples

In [4]:
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(theta):
    """Two-qubit QNode returning the Pauli-Z expectation value of wire 1.

    Wire 0 is put through a Hadamard, wire 1 is rotated about X by ``theta``,
    and a CNOT (control 0, target 1) couples the two.

    Args:
        theta: rotation angle passed to the RX gate on wire 1.

    Returns:
        The expectation value of PauliZ measured on wire 1.
    """
    control, target = 0, 1

    qml.Hadamard(wires=control)
    qml.RX(theta, wires=target)
    qml.CNOT(wires=[control, target])

    observable = qml.PauliZ(wires=target)
    return qml.expval(observable)

In [7]:
# Wrap the already-defined QNode with Catalyst's qjit and keep the compiled
# callable under its own name, leaving `circuit` itself untouched.
jitted_circuit = qjit(circuit)
# Evaluate the compiled circuit at theta = 0.7; as the last expression of the
# cell, its result is what the notebook displays.
jitted_circuit(0.7)

Array(0., dtype=float64)

In [8]:
# NOTE: this rebinds the name `circuit` used by an earlier cell; `jitted_circuit`
# still refers to the earlier QNode.
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=5))
def circuit(arg0, arg1, arg2):
    """qjit-compiled QNode whose wire indices are computed from its arguments.

    Args:
        arg0: rotation angle used for both the RX and the RY gate.
        arg1: wire index; the CNOT control is ``arg1`` and the RX gate and the
            returned probabilities use wire ``arg1 + 1``.
        arg2: wire index used for the RY gate and as the CNOT target.

    Returns:
        The probabilities measured on wire ``arg1 + 1``.
    """
    # Both the RX gate and the final measurement act on the same shifted wire.
    shifted_wire = arg1 + 1

    qml.RX(arg0, wires=[shifted_wire])
    qml.RY(arg0, wires=[arg2])
    qml.CNOT(wires=[arg1, arg2])

    return qml.probs(wires=[shifted_wire])

In [9]:
# Call the compiled circuit with arg0 = pi/3 (rotation angle) and arg1 = 1,
# arg2 = 2 (wire indices). With these values arg1 + 1 == arg2, so both rotations
# and the returned probabilities all refer to wire 2.
circuit(jnp.pi / 3, 1, 2)

Array([0.625, 0.375], dtype=float64)

In [16]:
dev = qml.device("lightning.qubit", wires=2)  # any Catalyst-supported backend


# NOTE(review): `theta` carries a type hint, which may make Catalyst compile this
# function ahead of time at decoration rather than on the first call -- confirm
# against the qjit documentation.
@qjit
@qml.qnode(dev)
def mid_measure(theta: float):
    """Measure wire 0 mid-circuit and use the outcome to drive a gate on wire 1.

    Wire 0 is put through a Hadamard and measured with Catalyst's ``measure``.
    The outcome scales an RX angle on wire 1: the angle is 0 when the outcome
    is false and pi when it is true.

    Args:
        theta: accepted but not used anywhere in the body.

    Returns:
        A pair ``(expectation, bit)``: the PauliZ expectation value on wire 1
        and the mid-circuit measurement outcome from wire 0.
    """
    qml.Hadamard(wires=0)

    # Mid-circuit measurement: `bit` is a classical value usable in later gates.
    bit = measure(0)

    # Rotation angle is either 0 or pi depending on the measured bit.
    angle = bit * jnp.pi
    qml.RX(angle, wires=1)

    z_expectation = qml.expval(qml.PauliZ(1))
    return z_expectation, bit

# --- call the compiled function -----------------------------------------
# Every call performs its own mid-circuit measurement, so the printed
# (expectation, outcome) pair can differ from one iteration to the next.
# The argument 0.0 is a placeholder: mid_measure does not use it.
for trial in range(10):
    expectation, outcome = mid_measure(0.0)
    print(expectation, outcome)


-1.0 True
1.0 False
-1.0 True
1.0 False
1.0 False
1.0 False
-1.0 True
1.0 False
1.0 False
1.0 False
