# Advanced pattern matching and rewriting

In the previous notebook, we saw how to define rewrite rules that look for
instances of a given pattern in a circuit, and replace them with a given replacement.

In this notebook, we will generalise that approach by defining custom pattern matchers and replacers.

_Note: This is currently unreleased and requires building the tket-py crate from source on the branch `lm/pyinterface`. Activate a Python virtual environment of your choice, then run `pip install maturin` followed by `maturin develop` from the `./tket-py` directory._

In [1]:
from pytket import Circuit, OpType
from pytket.circuit.display import render_circuit_jupyter

from tket._tket.ops import (
    TketOp,
)  # Note that we are importing the "wrong" TketOp, see issue https://github.com/CQCL/tket2/issues/1027
from tket.ops import TketOp as PyTketOp # This would be the "correct" TketOp
from tket.circuit import Tk2Circuit

# TODO: remove this
def matches_op(op: TketOp, op2: PyTketOp) -> bool:
    """Compare a Rust-side `TketOp` with a Python-side `TketOp` by converting the latter."""
    return op == op2._to_rs()

### Matchers and replacers

The two core concepts that allow custom rewriting are:

- A `CircuitMatcher`: an object that defines which patterns to match. Given the operations matched so far and a new operation, it must decide whether the new operation should be added to the match.
- A `CircuitReplacer`: an object that returns all possible replacements for a matched subcircuit. Given a matched pattern, it returns a list of circuits, each of which can replace the match.
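To make this division of labour concrete, here is a toy, pure-Python model of how a matcher is driven. It is not tket's implementation: gates are plain tuples, there is no circuit graph, and the hypothetical helper `run_matcher` encodes our reading of the outcomes used in the examples later in this notebook (`proceed` adds the operation and stores new match info, `skip` ignores the operation, `stop` abandons the match, `complete` adds the operation and finishes the match).

```python
from typing import Any, Optional

# Toy gate representation (hypothetical, not tket's): (name, qubits, angle or None)
Gate = tuple[str, tuple[int, ...], Optional[float]]


def run_matcher(matcher, gates: list[Gate], start: int = 0) -> Optional[tuple[list[int], Any]]:
    """Feed gates[start:] to `matcher` one at a time.

    Returns (indices of matched gates, final match info), or None if no match.
    """
    matched: list[int] = []
    info: Any = None  # nothing matched yet
    for i in range(start, len(gates)):
        outcome = matcher(gates[i], info)
        if "stop" in outcome:
            return None  # abandon this match
        if "skip" in outcome:
            continue  # ignore this gate, keep looking
        matched.append(i)  # "proceed" and "complete" both add the gate
        if "complete" in outcome:
            return matched, outcome["complete"]
        info = outcome["proceed"]  # state handed over with the next gate
    return None  # ran out of gates before the match completed


def two_cx_matcher(gate: Gate, info: Any) -> dict:
    """Toy matcher for two CXs on the same control and target."""
    name, qubits, _ = gate
    if name != "CX":
        return {"stop": True}
    if info is None:
        return {"proceed": qubits}  # remember (control, target)
    return {"complete": True} if qubits == info else {"stop": True}


gates = [("CX", (0, 1), None), ("CX", (0, 1), None), ("CX", (1, 0), None)]
print(run_matcher(two_cx_matcher, gates))     # ([0, 1], True)
print(run_matcher(two_cx_matcher, gates, 1))  # None: CX(1, 0) has control and target swapped
```

The real protocol passes richer arguments (`op`, `op_args` and a `MatchContext`), as the examples below show.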

The API that these two objects must implement is given by Python protocols,
which can be imported from `tket.protocol`:


In [2]:
from tket.protocol import CircuitMatcher, CircuitReplacer

### tket-provided matchers and replacers

TKET provides built-in matchers and replacers for common patterns.

Currently there are only two (these are dummy examples, to be expanded!):
- one matcher: the `RotationMatcher`, which matches any rotation gate, and
- one replacer: `ReplaceWithIdentity`, which replaces a circuit with the identity circuit with the same number of qubits.

In [3]:
from tket.matcher import RotationMatcher, ReplaceWithIdentity

print("RotationMatcher is a CircuitMatcher:", isinstance(RotationMatcher(), CircuitMatcher))
print("ReplaceWithIdentity is a CircuitReplacer:", isinstance(ReplaceWithIdentity(), CircuitReplacer))

RotationMatcher is a CircuitMatcher: True
ReplaceWithIdentity is a CircuitReplacer: True
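A class does not need to inherit from a protocol to pass such a check. Assuming tket's protocols are `typing.Protocol` classes marked `@runtime_checkable` (the `isinstance` calls suggest so, but this is an assumption), the check is structural, as this standalone sketch with a hypothetical `ToyMatcher` protocol shows. Note that a runtime check only verifies that the required methods exist, not their signatures or return types.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class ToyMatcher(Protocol):
    """Hypothetical protocol standing in for a matcher protocol such as CircuitMatcher."""

    def match_tket_op(self, op, op_args, context): ...


class AlwaysStop:
    # No inheritance from ToyMatcher: having a method with the right name is enough.
    def match_tket_op(self, op, op_args, context):
        return {"stop": True}


class OnlyReplaces:
    # Has a replacer-style method but no `match_tket_op`.
    def replace_match(self, circuit, match_info):
        return []


print(isinstance(AlwaysStop(), ToyMatcher))    # True
print(isinstance(OnlyReplaces(), ToyMatcher))  # False
```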


Matchers and replacers are combined into a rewriter (similar to the rewriters created in the previous notebook) using the `MatchReplaceRewriter` class. Creating a rewriter that removes all rotation gates from a circuit is as simple as:

In [4]:
from tket.matcher import MatchReplaceRewriter

rewriter = MatchReplaceRewriter(RotationMatcher(), ReplaceWithIdentity())

To run the rewriter repeatedly on a circuit, we use the Badger optimiser:

In [5]:
from tket.optimiser import BadgerOptimiser

circ = Circuit(2).CX(0, 1).Rz(0.2, 0).Rx(0.1, 1).CX(0, 1)

opt = BadgerOptimiser(rewriter)
no_rot_circ = opt.optimise(circ)

render_circuit_jupyter(no_rot_circ)

### Implementing custom matchers and replacers in Python

Implementing custom matchers and replacers in Python is just as easy: we only need to define
a class that implements the respective protocol. For example, we can
define a matcher that finds two successive CX (CNOT) gates acting on the same control and target qubits as follows:

In [6]:
from tket.matcher import MatchOutcome, MatchContext, CircuitUnit
from typing import Any

def succeeds_previous_op(op_args: list[CircuitUnit]) -> bool:
    """Whether this current op is in the future of previously matched ops."""
    return all(arg.linear_pos in ["after", None] for arg in op_args)

class TwoCXMatcher:
    def match_tket_op(
        self, op: TketOp, op_args: list[CircuitUnit], context: MatchContext
    ) -> MatchOutcome:
        # We are only interested in CXs
        if not matches_op(op, PyTketOp.CX):
            return { "stop": True }
        
        # `match_info` holds what was returned with "proceed" for the previous CX:
        # its (control, target) qubits, or None if no CX has been matched yet.
        prev_matched_cx = context["match_info"]

        if prev_matched_cx is not None and not succeeds_previous_op(op_args):
            # The second CX we match should come AFTER the first and on the same
            # qubits. If the current op is not after the previous CX on all
            # qubits, we are not interested.
            return { "skip": True }
        
        if prev_matched_cx is None:
            # This is the first CX we matched, so we proceed to match the second.
            match op_args:
                case [CircuitUnit(linear_index=ctrl), CircuitUnit(linear_index=tgt)]:
                    return { "proceed": (ctrl, tgt) }
                case _:
                    raise ValueError(f"Unexpected op args: {op_args}")
        else:
            match op_args:
                case [CircuitUnit(linear_index=snd_ctrl), CircuitUnit(linear_index=snd_tgt)]:
                    if prev_matched_cx == (snd_ctrl, snd_tgt):
                        # We have successfully matched two CXs, so we are done!
                        return { "complete": True }
                    else:
                        # Control and target are swapped (or the qubits differ), so we are not interested.
                        return { "stop": True }
                case _:
                    raise ValueError(f"Unexpected op args: {op_args}")
        

assert isinstance(TwoCXMatcher(), CircuitMatcher)

As expected, if we combine this matcher with the `ReplaceWithIdentity` replacer and apply it to `no_rot_circ` (the two-CX circuit obtained above),
we get the empty circuit.

In [7]:
cancel_cx = MatchReplaceRewriter(TwoCXMatcher(), ReplaceWithIdentity())

opt = BadgerOptimiser(cancel_cx)
empty_circ = opt.optimise(no_rot_circ)

assert empty_circ == Circuit(2)

Observe that, thanks to the `succeeds_previous_op` check and the comparison with the qubits of the first CX, this matcher will not match two CXs that are not on the same qubits. Neither will it match two CXs on the same qubits but with control and target swapped:

In [8]:
circ = Circuit(3).CX(0, 1).CX(1, 2).CX(2, 1)

opt = BadgerOptimiser(cancel_cx)
same_circ = opt.optimise(circ)

assert same_circ == circ

### A fully fledged example

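To complete this notebook, we will implement a slightly more interesting optimisation, which combines CX cancellation as above with ZZPhase flipping. The block CX(c, t), Rz(θ, t), CX(c, t) implements a ZZPhase(θ) gate on qubits c and t; since ZZPhase is symmetric in its two qubits, the same gate is obtained with control and target swapped, i.e. CX(t, c), Rz(θ, c), CX(t, c). Flipping a ZZPhase means applying this rewrite.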
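Assuming pytket's angle conventions, $R_Z(\alpha) = e^{-i\pi\alpha Z/2}$ and $\mathrm{ZZPhase}(\alpha) = e^{-i\pi\alpha\, Z\otimes Z/2}$ (angles in half-turns), the identity behind the flip follows from two facts: conjugating $Z$ on the target of a CX gives $Z_c Z_t$, and for a unitary $U$ we have $U e^{A} U^\dagger = e^{U A U^\dagger}$, with $\mathrm{CX}^\dagger = \mathrm{CX}$:

```latex
\begin{aligned}
\mathrm{CX}_{0,1}\, R_{Z,1}(\alpha)\, \mathrm{CX}_{0,1}
  &= e^{-\frac{i\pi\alpha}{2}\, \mathrm{CX}_{0,1} Z_1 \mathrm{CX}_{0,1}}
   = e^{-\frac{i\pi\alpha}{2}\, Z_0 Z_1}
   = \mathrm{ZZPhase}(\alpha), \\
\mathrm{CX}_{1,0}\, R_{Z,0}(\alpha)\, \mathrm{CX}_{1,0}
  &= e^{-\frac{i\pi\alpha}{2}\, \mathrm{CX}_{1,0} Z_0 \mathrm{CX}_{1,0}}
   = e^{-\frac{i\pi\alpha}{2}\, Z_1 Z_0}
   = \mathrm{ZZPhase}(\alpha).
\end{aligned}
```

Since $Z_0 Z_1 = Z_1 Z_0$, both blocks implement the same gate.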

The `flip_zzphase` rewriter is composed of the `ZZPhaseMatcher` matcher and the `FlippedZZPhase` replacer:

In [9]:
from typing import Literal

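# The matching progress stored in `match_info`, one state per matched op:
MatchState = (
    None  # nothing matched yet
    | tuple[Literal["matched_first_cx"], int, int]  # (tag, control, target)
    | tuple[Literal["matched_rotation"], int, int, float]  # (tag, control, target, angle)
)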

class ZZPhaseMatcher:
    def match_tket_op(
        self, op: TketOp, op_args: list[CircuitUnit], context: MatchContext
    ) -> MatchOutcome:
        state: MatchState = context["match_info"]

        match state:
            case None:
                # We are looking for a CX
                if matches_op(op, PyTketOp.CX):
                    # This is the first matched op, so the relative position of
                    # the qubits with respect to the already matched subcircuit
                    # is not yet known.
                    [ctrl_qubit, tgt_qubit] = [arg.linear_index for arg in op_args]
                    assert all(arg.linear_pos is None for arg in op_args)

                    return {
                        "proceed": ("matched_first_cx", ctrl_qubit, tgt_qubit)
                    }
                else:
                    return { "stop": True }
            case ("matched_first_cx", ctrl_qubit, tgt_qubit):
                # must come after the first CX
                if not succeeds_previous_op(op_args):
                    return { "skip": True }

                # We are looking for a rotation
                if matches_op(op, PyTketOp.Rz) and op_args[0].linear_index == tgt_qubit \
                                               and op_args[1].constant_float is not None:
                    rot_angle = op_args[1].constant_float
                    return {
                        "proceed": ("matched_rotation", ctrl_qubit, tgt_qubit, rot_angle)
                    }
                else:
                    return { "skip": True }
            case ("matched_rotation", ctrl_qubit, tgt_qubit, rot_angle):
                # must come after the first CX and the rotation
                if not succeeds_previous_op(op_args):
                    return { "skip": True }

                # We are looking for a second CX
                if matches_op(op, PyTketOp.CX) and [arg.linear_index for arg in op_args] == [ctrl_qubit, tgt_qubit]:
                    return { "complete": rot_angle }
                else:
                    return { "skip": True }

class FlippedZZPhase:
    def replace_match(self, circuit: Tk2Circuit, match_info: float) -> list[Tk2Circuit]:
        assert circuit.to_tket1().n_qubits == 2

        # Same ZZPhase as the unflipped block Circuit(2).CX(0, 1).Rz(match_info, 1).CX(0, 1),
        # but with the roles of the two qubits swapped.
        flipped_circ = Circuit(2).CX(1, 0).Rz(match_info, 0).CX(1, 0)
        return [Tk2Circuit(flipped_circ)]

assert isinstance(ZZPhaseMatcher(), CircuitMatcher)
assert isinstance(FlippedZZPhase(), CircuitReplacer)

flip_zzphase = MatchReplaceRewriter(ZZPhaseMatcher(), FlippedZZPhase())
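As a sanity check independent of tket, the following pure-Python sketch builds the 4×4 unitaries of the unflipped and flipped blocks and confirms that they are equal, and equal to ZZPhase. It assumes pytket's convention $R_Z(\alpha) = e^{-i\pi\alpha Z/2}$; all helper functions are our own and not part of any library.

```python
import cmath

Matrix = list[list[complex]]


def bit(index: int, qubit: int) -> int:
    """Value of `qubit` in basis state `index` (qubit 0 is the most significant bit)."""
    return (index >> (1 - qubit)) & 1


def cx(ctrl: int, tgt: int) -> Matrix:
    """Two-qubit CX: flips `tgt` when `ctrl` is 1 (a permutation matrix)."""
    m = [[0j] * 4 for _ in range(4)]
    for i in range(4):
        j = i ^ (1 << (1 - tgt)) if bit(i, ctrl) else i
        m[j][i] = 1 + 0j
    return m


def diagonal_phase(alpha: float, eigenvalue) -> Matrix:
    """Diagonal matrix with entries exp(-i*pi*alpha*s/2), where s = eigenvalue(i) is +1 or -1."""
    m = [[0j] * 4 for _ in range(4)]
    for i in range(4):
        m[i][i] = cmath.exp(-0.5j * cmath.pi * alpha * eigenvalue(i))
    return m


def rz(alpha: float, qubit: int) -> Matrix:
    # Z eigenvalue on `qubit`: +1 for |0>, -1 for |1>
    return diagonal_phase(alpha, lambda i: 1 - 2 * bit(i, qubit))


def zzphase(alpha: float) -> Matrix:
    # Eigenvalue of Z (x) Z: product of the two single-qubit Z eigenvalues
    return diagonal_phase(alpha, lambda i: (1 - 2 * bit(i, 0)) * (1 - 2 * bit(i, 1)))


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(4)) for j in range(4)] for i in range(4)]


def circuit_unitary(*gates: Matrix) -> Matrix:
    """Unitary of `gates` applied in order (the first gate acts first)."""
    u = [[1 + 0j if i == j else 0j for j in range(4)] for i in range(4)]
    for g in gates:
        u = matmul(g, u)
    return u


def close(a: Matrix, b: Matrix, tol: float = 1e-12) -> bool:
    return all(abs(a[i][j] - b[i][j]) < tol for i in range(4) for j in range(4))


alpha = 0.111
original = circuit_unitary(cx(0, 1), rz(alpha, 1), cx(0, 1))
flipped = circuit_unitary(cx(1, 0), rz(alpha, 0), cx(1, 0))
print(close(original, flipped))         # True
print(close(original, zzphase(alpha)))  # True
```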

We would like to combine `flip_zzphase` with `cancel_cx` to optimise the following circuit:

In [10]:
circ = Circuit(3).CX(0, 1).CX(0, 1).Rz(0.111, 1).CX(0, 1).CX(2, 0).CX(0, 2).Rz(0.5, 2).CX(0, 2)

assert len(flip_zzphase.get_rewrites(Tk2Circuit(circ))) == 2

render_circuit_jupyter(circ)

Note that the order in which the rewrites are performed matters here: if we apply all ZZPhase flips first, the first two CX gates will no longer cancel out.

The Badger optimiser does not apply rewrites greedily: it explores many sequences of rewrites, looking for one that minimises the CX count. In this case, the sequence it finds cancels the first two CX gates, then flips the ZZPhase at the end of the circuit, before cancelling the two resulting CXs in the middle of the circuit:
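To see why the order matters, here is a toy, pure-Python model. It uses neither tket's data structures nor Badger's algorithm: circuits are tuples of gates, "adjacent" simply means consecutive in the tuple (stricter than adjacency in a circuit graph), and the hypothetical `exhaustive_search` enumerates every reachable circuit, which is only feasible for tiny examples. It is compared against a fixed order, `flips_then_cancels`, that flips every ZZPhase block first and then cancels CX pairs.

```python
from collections.abc import Iterator

# Toy gates (hypothetical, not tket's): ("CX", control, target) or ("Rz", qubit, angle).
Circ = tuple


def cancel_cx_rewrites(circ: Circ) -> Iterator[Circ]:
    """Every circuit obtained by removing one pair of identical consecutive CXs."""
    for i in range(len(circ) - 1):
        if circ[i][0] == "CX" and circ[i] == circ[i + 1]:
            yield circ[:i] + circ[i + 2:]


def flip_zzphase_rewrites(circ: Circ) -> Iterator[Circ]:
    """Every circuit obtained by flipping one CX(c, t) Rz(t) CX(c, t) block."""
    for i in range(len(circ) - 2):
        a, r, b = circ[i:i + 3]
        if a[0] == "CX" and a == b and r[0] == "Rz" and r[1] == a[2]:
            _, c, t = a
            flipped = (("CX", t, c), ("Rz", c, r[2]), ("CX", t, c))
            yield circ[:i] + flipped + circ[i + 3:]


def cx_count(circ: Circ) -> int:
    return sum(1 for gate in circ if gate[0] == "CX")


def exhaustive_search(circ: Circ) -> Circ:
    """Visit every circuit reachable by any sequence of rewrites; keep the one with fewest CXs."""
    best, seen, stack = circ, {circ}, [circ]
    while stack:
        current = stack.pop()
        if cx_count(current) < cx_count(best):
            best = current
        for nxt in (*cancel_cx_rewrites(current), *flip_zzphase_rewrites(current)):
            if nxt not in seen:  # flips are reversible, so guard against cycles
                seen.add(nxt)
                stack.append(nxt)
    return best


def flips_then_cancels(circ: Circ) -> Circ:
    """Fixed order: flip every ZZPhase block once (left to right), then cancel CX pairs."""
    out, i = (), 0
    while i < len(circ):
        flipped = next(flip_zzphase_rewrites(circ[i:i + 3]), None)
        if flipped is not None:
            out, i = out + flipped, i + 3
        else:
            out, i = out + (circ[i],), i + 1
    while (smaller := next(cancel_cx_rewrites(out), None)) is not None:
        out = smaller
    return out


# Same gate sequence as the circuit above, in the toy representation
toy_circ = (
    ("CX", 0, 1), ("CX", 0, 1), ("Rz", 1, 0.111), ("CX", 0, 1),
    ("CX", 2, 0), ("CX", 0, 2), ("Rz", 2, 0.5), ("CX", 0, 2),
)
print(cx_count(flips_then_cancels(toy_circ)))  # 4
print(cx_count(exhaustive_search(toy_circ)))   # 2
```

Flipping first leaves the first CX next to a CX with swapped qubits, so only one pair cancels; searching over orders reaches the same 2-CX count as the optimiser run below.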

In [11]:
opt = BadgerOptimiser([cancel_cx, flip_zzphase])
opt_circ = opt.optimise(circ)

assert opt_circ.n_gates_of_type(OpType.CX) == 2

render_circuit_jupyter(opt_circ)