### A2.5.2. Pattern Rewriting

> *Pattern rewriting in MLIR matches subgraphs of operations in the IR and replaces them with equivalent or optimized alternatives using a declarative or imperative pattern specification.*

**Explanation:**

Many MLIR transformations, including canonicalization, are built around a **pattern rewriting** framework. Instead of writing ad-hoc IR mutations, transformations are expressed as reusable patterns that the framework applies systematically.

**Three Mechanisms:**

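1. **C++ RewritePattern**: imperative patterns that subclass `mlir::RewritePattern` (or, more commonly, `mlir::OpRewritePattern<OpTy>` for a specific op type). The `matchAndRewrite` method checks whether the pattern applies and, if it does, performs the rewrite through a `PatternRewriter`.

2. **Declarative Rewrite Rules (DRR)**: TableGen-based rules of the form `def : Pat<(SourceOp ...), (TargetOp ...)>`, from which `mlir-tblgen` generates the equivalent C++ pattern.

3. **PDL (Pattern Description Language)**: an MLIR dialect for writing the patterns themselves as MLIR IR. PDL patterns are lowered to the `pdl_interp` dialect and executed by a bytecode interpreter, so a whole set of patterns can be compiled and optimized together.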
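As a rough illustration of the declarative idea in item 2, here is a minimal Python sketch. It assumes a toy rule format of our own (not TableGen syntax): a rule is plain data saying "`source_op(lhs, constant)` becomes `lhs`", and a hypothetical `compile_rule` helper turns that data into an imperative matcher, loosely analogous to how `mlir-tblgen` emits a C++ pattern from a `Pat<...>` record. `Rule`, `Op`, and `compile_rule` are illustrative names only.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class Rule:
    """Hypothetical declarative rule: source_op(lhs, constant) -> lhs."""
    source_op: str
    constant: int


@dataclass
class Op:
    name: str
    operands: tuple  # (lhs, rhs): SSA value names (str) or int constants
    result: str


def compile_rule(rule: Rule) -> Callable[[Op], Optional[str]]:
    """'Generate' an imperative matcher from the declarative rule.

    The returned function yields the replacement value when the rule
    matches `op`, or None when it does not.
    """
    def match_and_rewrite(op: Op) -> Optional[str]:
        if op.name != rule.source_op:
            return None
        lhs, rhs = op.operands
        if isinstance(rhs, int) and rhs == rule.constant:
            return lhs
        return None

    return match_and_rewrite


rules = [Rule("arith.addi", 0), Rule("arith.muli", 1)]
compiled = [compile_rule(r) for r in rules]

op = Op("arith.muli", ("%x", 1), "%y")
for pattern in compiled:
    replacement = pattern(op)
    if replacement is not None:
        print(f"{op.result} -> {replacement}")  # prints "%y -> %x"
        break
```

The point of the design is that the rules stay as data: adding another identity fold is one more `Rule(...)` entry, while the matching code is written once.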

**Pattern Application:**

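The **greedy pattern rewrite driver** (entry point `applyPatternsGreedily`, named `applyPatternsAndFoldGreedily` in older MLIR releases) repeatedly applies patterns until it reaches a fixed point or hits a configurable iteration limit:

1. Seed a worklist with the operations to be rewritten.
2. Pop an operation and try the patterns that apply to it in order of **benefit** (higher benefit = tried first).
3. If a pattern matches, apply it and add newly created or modified ops, plus the users of replaced values, to the worklist.
4. Repeat until the worklist is empty. Along the way the driver also folds ops and erases trivially dead ones.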
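The steps above can be sketched in Python with an explicit worklist. This is a simplified model under our own assumptions, not MLIR's actual driver: ops are plain records, patterns are grouped by the op name they are rooted on and tried in descending benefit order, and each rewrite only returns an existing value to forward (it never creates ops), with the driver performing the replacement and re-queuing the affected users. `Op`, `Pattern`, `greedy_worklist_rewrite`, and `fold_identity` are hypothetical names.

```python
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(eq=False)  # compare ops by identity, so list removal targets the exact op
class Op:
    name: str
    operands: list  # SSA value names (str) or integer constants (int)
    result: str


@dataclass
class Pattern:
    root: str  # op name the pattern is rooted on
    benefit: int
    rewrite: Callable[[Op], Optional[str]]  # replacement value, or None if no match


def greedy_worklist_rewrite(ops, patterns):
    # Group patterns by root op name, highest benefit first within each group.
    by_root = defaultdict(list)
    for p in sorted(patterns, key=lambda p: p.benefit, reverse=True):
        by_root[p.root].append(p)

    live = list(ops)
    worklist = deque(ops)  # step 1: seed the worklist with every op
    while worklist:  # step 4: stop once the worklist is empty
        op = worklist.popleft()
        if op not in live:  # already erased by an earlier rewrite
            continue
        for p in by_root[op.name]:  # step 2: try patterns by benefit
            new_value = p.rewrite(op)
            if new_value is None:
                continue
            # Step 3: erase op, forward new_value to its users, revisit the users.
            live.remove(op)
            for user in live:
                if op.result in user.operands:
                    user.operands = [new_value if v == op.result else v for v in user.operands]
                    worklist.append(user)
            break
    return live


def fold_identity(identity):
    """Hypothetical rewrite: op(x, identity) -> x."""
    def rewrite(op):
        rhs = op.operands[1]
        return op.operands[0] if isinstance(rhs, int) and rhs == identity else None
    return rewrite


patterns = [
    Pattern("arith.addi", 10, fold_identity(0)),
    Pattern("arith.muli", 10, fold_identity(1)),
]
ops = [
    Op("arith.addi", ["%x", 0], "%a"),
    Op("arith.muli", ["%a", 1], "%b"),
    Op("arith.addi", ["%b", "%y"], "%c"),
]
for op in greedy_worklist_rewrite(ops, patterns):
    print(f"{op.result} = {op.name}({', '.join(map(str, op.operands))})")
# prints: %c = arith.addi(%x, %y)
```

Re-queuing users is what lets one rewrite expose the next: once `%a` is replaced by `%x`, the `muli` becomes `muli(%x, 1)` and is revisited, without rescanning every op.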

**Example (C++ RewritePattern):**

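```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Rewrites `arith.addi %x, 0` to `%x`.
struct AddZeroElim : public OpRewritePattern<arith::AddIOp> {
  // Inherit the constructor so the pattern can be added with patterns.add<AddZeroElim>(ctx).
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::AddIOp op,
                                PatternRewriter &rewriter) const override {
    // m_Zero() matches an integer constant equal to zero.
    if (matchPattern(op.getRhs(), m_Zero())) {
      rewriter.replaceOp(op, op.getLhs());
      return success();
    }
    return failure();
  }
};
```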

```python
from dataclasses import dataclass


@dataclass
class Operation:
    """Toy SSA op: `result = name(operands...)`; operands are value names or int constants."""
    name: str
    operands: list
    result: str


class RewritePattern:
    """Mirrors mlir::RewritePattern: patterns with a higher `benefit` are tried first."""
    benefit: int = 1

    def match_and_rewrite(self, operation, operations_list):
        """Return True if the pattern matched and rewrote the IR, else False."""
        raise NotImplementedError


class IdentityElimination(RewritePattern):
    """Replace `result = op(lhs, identity)` with `lhs` (e.g. x + 0 -> x)."""
    op_name = ""
    identity = 0

    def match_and_rewrite(self, operation, operations_list):
        # Match: the right-hand operand must be the literal identity constant.
        if operation.name != self.op_name:
            return False
        rhs = operation.operands[1]
        if not (isinstance(rhs, int) and rhs == self.identity):
            return False
        # Rewrite: forward `lhs` to every user of the result, then erase the op
        # (the analogue of rewriter.replaceOp(op, op.getLhs()) in the C++ example).
        lhs = operation.operands[0]
        for other_op in operations_list:
            other_op.operands = [
                lhs if operand == operation.result else operand
                for operand in other_op.operands
            ]
        operations_list.remove(operation)
        print(f"  Eliminated: {operation.result} = {operation.name}({lhs}, {self.identity})")
        return True


class AddZeroElimination(IdentityElimination):
    benefit = 10
    op_name = "arith.addi"
    identity = 0


class MulOneElimination(IdentityElimination):
    benefit = 10
    op_name = "arith.muli"
    identity = 1


def greedy_rewrite(operations, patterns):
    """Sweep all ops repeatedly until a full sweep changes nothing (a fixed point).

    Simplification: MLIR's driver uses a worklist rather than full sweeps.
    """
    sorted_patterns = sorted(patterns, key=lambda p: p.benefit, reverse=True)
    changed = True
    iteration = 0
    while changed:
        changed = False
        iteration += 1
        print(f"Iteration {iteration}:")
        # Iterate over a snapshot because patterns erase ops from `operations`.
        for operation in list(operations):
            for pattern in sorted_patterns:
                if pattern.match_and_rewrite(operation, operations):
                    changed = True
                    break  # at most one pattern applies per op per sweep


def print_ops(operations):
    for op in operations:
        print(f"  {op.result} = {op.name}({', '.join(str(o) for o in op.operands)})")


operations = [
    Operation("arith.addi", ["%x", 0], "%a"),
    Operation("arith.muli", ["%a", 1], "%b"),
    Operation("arith.addi", ["%b", "%y"], "%c"),
]
patterns = [AddZeroElimination(), MulOneElimination()]

print("Before:")
print_ops(operations)
print()

greedy_rewrite(operations, patterns)

print("\nAfter:")
print_ops(operations)
```
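In the simulation above the two patterns target different ops, so benefit ordering never decides anything. The sketch below, built on two hypothetical patterns, shows a case where it does: both `MulOneToLhs` (`x * 1 -> x`) and `MulPow2ToShift` (`x * 2**k -> x << k`) match `arith.muli(%x, 1)`, and the higher-benefit pattern wins. The pattern classes and `apply_first_match` are illustrative assumptions, not MLIR APIs.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Op:
    name: str
    operands: list  # [lhs, rhs]; lhs is an SSA name, rhs may be an int constant
    result: str


class MulOneToLhs:
    """Hypothetical pattern: muli(x, 1) -> x (the op disappears entirely)."""
    benefit = 10

    def rewrite(self, op: Op) -> Optional[str]:
        if op.name == "arith.muli" and op.operands[1] == 1:
            return op.operands[0]
        return None


class MulPow2ToShift:
    """Hypothetical pattern: muli(x, 2**k) -> shli(x, k)."""
    benefit = 1

    def rewrite(self, op: Op) -> Optional[str]:
        rhs = op.operands[1]
        if op.name == "arith.muli" and isinstance(rhs, int) and rhs > 0 and rhs & (rhs - 1) == 0:
            k = rhs.bit_length() - 1  # exponent of the power of two
            return f"arith.shli({op.operands[0]}, {k})"
        return None


def apply_first_match(op, patterns):
    """Try patterns in descending benefit order; the first match wins."""
    for p in sorted(patterns, key=lambda p: p.benefit, reverse=True):
        replacement = p.rewrite(op)
        if replacement is not None:
            return type(p).__name__, replacement
    return None, None


op = Op("arith.muli", ["%x", 1], "%y")
print(apply_first_match(op, [MulPow2ToShift(), MulOneToLhs()]))
# prints ('MulOneToLhs', '%x')
```

If the benefits were swapped, the same op would instead become a pointless shift by 0, which is why more specific or more profitable patterns are usually given a higher benefit.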

**References:**

[📘 MLIR Project. *Pattern Rewriting.*](https://mlir.llvm.org/docs/PatternRewriter/)

[📘 MLIR Project. *Declarative Rewrite Rules (DRR).*](https://mlir.llvm.org/docs/DeclarativeRewrites/)

---

[⬅️ Previous: MLIR Dialects](./01_mlir_dialects.ipynb) | [Next: Lowering to LLVM ➡️](./03_lowering_to_llvm.ipynb)