# Chapter 4: Toy Optimisations

As we saw in the previous chapter, the IR generated from the input program has many
opportunities for optimisation. In this chapter, we'll implement three optimisations:

1. Removing redundant reshapes
2. Reshaping constants at compile time
3. Eliminating operations whose results are not used

Let's look at our example input again:

In [None]:
import xdsl, riscemu

from compiler import parse_toy, print_op

example = """
def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  var b<6> = [1, 2, 3, 4, 5, 6];
  var c<2, 3> = b;
  var d = a + c;
  print(d);
}
"""

toy = parse_toy(example)
print_op(toy)
print()

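## Redundant Reshapes

Reshaping a value that was itself produced by a reshape is equivalent to reshaping the
original value directly to the final shape, so the inner reshape is redundant. We can
express this rule in xDSL as a `RewritePattern` that matches a `td.ReshapeOp` whose input
is the result of another `td.ReshapeOp`.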
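As a framework-free illustration of the rule `Reshape(Reshape(x)) = Reshape(x)`, here is a
sketch over a hypothetical mini-IR. `Arg`, `Reshape`, `fold_reshape_reshape` and `simplify`
are illustrative names only, not part of xDSL or the `toy` package. Like the xDSL pattern,
the rewrite removes one level of nesting per application and does nothing when the input
is not produced by a reshape; it assumes well-formed IR, where every reshape preserves the
element count.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class Arg:
    """Hypothetical value not produced by an operation (e.g. a function argument)."""
    name: str


@dataclass(frozen=True)
class Reshape:
    """Hypothetical reshape op: the value `arg` viewed with the target `shape`."""
    arg: "Value"
    shape: Tuple[int, ...]


Value = Union[Arg, Reshape]


def fold_reshape_reshape(op: Reshape) -> Optional[Reshape]:
    """Reshape(Reshape(x)) = Reshape(x); returns None when the pattern does not match."""
    inner = op.arg
    if not isinstance(inner, Reshape):
        # Input is not produced by a reshape, so there is nothing to fold.
        return None
    # Only the outer target shape matters: reshape the inner op's input directly.
    return Reshape(inner.arg, op.shape)


def simplify(op: Reshape) -> Reshape:
    """Apply the rewrite until it no longer matches, collapsing whole chains."""
    while (folded := fold_reshape_reshape(op)) is not None:
        op = folded
    return op


x = Arg("x")
chain = Reshape(Reshape(Reshape(x, (6,)), (3, 2)), (2, 3))
print(simplify(chain))  # Reshape(arg=Arg(name='x'), shape=(2, 3))
```

Each application strictly shortens the chain, so `simplify` always terminates.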

In [None]:
from typing import cast
from xdsl.ir import OpResult, Operation
from xdsl.dialects.builtin import DenseIntOrFPElementsAttr
from xdsl.pattern_rewriter import (op_type_rewrite_pattern, RewritePattern,
                                   PatternRewriter, PatternRewriteWalker)


import toy.dialect as td

class ReshapeReshapeOptPattern(RewritePattern):

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: td.ReshapeOp, rewriter: PatternRewriter):
        """
        Reshape(Reshape(x)) = Reshape(x)
        """
        # Look at the input of the current reshape.
        reshape_input = op.arg
        if not isinstance(reshape_input, OpResult):
            # Input was not produced by an operation, could be a function argument
            return

        reshape_input_op = reshape_input.op
        if not isinstance(reshape_input_op, td.ReshapeOp):
            # Input defined by another reshape? If not, no match.
            return

        t = cast(td.TensorTypeI32, op.res.typ)
        new_op = td.ReshapeOp.from_input_and_type(reshape_input_op.arg, t)
        rewriter.replace_matched_op(new_op)

# Use `PatternRewriteWalker` to rewrite all matched operations
PatternRewriteWalker(ReshapeReshapeOptPattern()).rewrite_module(toy)
print_op(toy)

The output looks very similar to the IR we printed before, but it is subtly different.
Importantly, the reshape that assigns to %4 now takes %2 as its input instead of %3. As a
result, %3 is no longer used, and because it is an operation with no observable side
effects, we can remove it and avoid doing the work altogether.

In [None]:
from toy.rewrites import RemoveUnusedOperations

PatternRewriteWalker(RemoveUnusedOperations()).rewrite_module(toy)

print_op(toy)
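`RemoveUnusedOperations` comes from the accompanying `toy.rewrites` module. To illustrate
the idea behind it (this is not its actual implementation), the sketch below uses a
hypothetical `Op` record in a straight-line, SSA-style block and repeatedly drops
operations that have no side effects and whose result nobody uses. Repeating until nothing
changes matters: removing one dead operation can leave the producers of its operands
unused as well.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Op:
    """Hypothetical operation in a straight-line, SSA-style block."""
    result: Optional[str]                       # SSA name it defines, if any
    operands: List[str] = field(default_factory=list)
    has_side_effects: bool = False              # e.g. print must never be removed


def remove_unused_operations(block: List[Op]) -> List[Op]:
    """Repeatedly drop side-effect-free ops whose result is never used."""
    while True:
        used = {name for op in block for name in op.operands}
        kept = [op for op in block
                if op.has_side_effects or (op.result is not None and op.result in used)]
        if len(kept) == len(block):
            # Fixed point: no further operation became dead.
            return kept
        # Removing an op may leave the producers of its operands unused, so go again.
        block = kept


block = [
    Op("x"),                                    # constant
    Op("y", ["x"]),                             # reshape of x, only used by z
    Op("z", ["y"]),                             # reshape of y, never used
    Op("w", ["x"]),                             # reshape of x, printed below
    Op(None, ["w"], has_side_effects=True),     # print(w)
]
print([op.result for op in remove_unused_operations(block)])  # ['x', 'w', None]
```

Here `z` is removed in the first round, which makes `y` dead in the second; the side-effecting print and everything it depends on survive.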

## Fold Constant Reshaping

One more opportunity for optimisation is to reshape constants at compile time instead of
at runtime. We can do this with another custom `RewritePattern`:

In [None]:
class FoldConstantReshapeOptPattern(RewritePattern):

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: td.ReshapeOp, rewriter: PatternRewriter):
        """
        Reshaping a constant can be done at compile time
        """
        # Look at the input of the current reshape.
        reshape_input = op.arg
        if not isinstance(reshape_input, OpResult):
            # Input was not produced by an operation, could be a function argument
            return

        reshape_input_op = reshape_input.op
        if not isinstance(reshape_input_op, td.ConstantOp):
            # Input defined by a constant? If not, no match.
            return

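        # Reuse the constant's elements unchanged, but give them the reshape's
        # result type (and therefore its shape).
        new_value = DenseIntOrFPElementsAttr.create_dense_int(
            type=op.res.typ, data=reshape_input_op.value.data.data)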
        new_op = td.ConstantOp.from_value(new_value)
        rewriter.replace_matched_op(new_op)


PatternRewriteWalker(FoldConstantReshapeOptPattern()).rewrite_module(toy)
print_op(toy)
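Folding works because a reshape does not move any elements: the pattern builds the new
constant from the original constant's data and only swaps in the reshape's result type.
The sketch below illustrates this without xDSL, assuming (as the
`reshape_input_op.value.data.data` argument suggests) that a dense constant stores its
elements as a flat, row-major sequence. `Constant`, `fold_constant_reshape` and
`to_nested` are hypothetical names used for illustration only.

```python
from dataclasses import dataclass
from math import prod
from typing import Any, Tuple


@dataclass(frozen=True)
class Constant:
    """Hypothetical dense constant: a shape plus its elements, flattened row-major."""
    shape: Tuple[int, ...]
    data: Tuple[int, ...]


def fold_constant_reshape(const: Constant, new_shape: Tuple[int, ...]) -> Constant:
    """Evaluate reshape(const) at compile time: same flat data, new shape."""
    if prod(new_shape) != len(const.data):
        raise ValueError(f"cannot reshape {len(const.data)} elements to {new_shape}")
    return Constant(new_shape, const.data)


def to_nested(const: Constant) -> Any:
    """Rebuild the nested-list view of a constant (for display only)."""
    def build(dims: Tuple[int, ...], flat: Tuple[int, ...]) -> Any:
        if not dims:
            return flat[0]                      # scalar element
        if dims[0] == 0:
            return []                           # empty dimension
        step = len(flat) // dims[0]             # elements per outer slice
        return [build(dims[1:], flat[i * step:(i + 1) * step]) for i in range(dims[0])]
    return build(const.shape, const.data)


b = Constant((6,), (1, 2, 3, 4, 5, 6))
c = fold_constant_reshape(b, (2, 3))
print(to_nested(c))  # [[1, 2, 3], [4, 5, 6]]
```

The original constant `b` is left untouched, which is why it still has to be removed once nothing refers to it.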

In [None]:
# Remove now unused original constants
PatternRewriteWalker(RemoveUnusedOperations()).rewrite_module(toy)
print_op(toy)

Now that we've applied all the optimisations we can at this level of abstraction, let's
go one level lower, towards RISC-V.