<br><br/>
[Back to table of contents](0_Table_of_Contents.ipynb)

# Chapter 5: Vector IR

Instead of going straight from Toy IR to RISC-V, we'll go through an intermediate dialect
with some simpler types. In this IR, a tensor is represented as two 1D vectors: one holding
its shape and one holding its data. We can still do pointwise addition on the contents of the
vectors.

In [None]:
import xdsl, riscemu

from typing import cast

from xdsl.dialects.builtin import TensorType, UnrankedTensorType
from xdsl.pattern_rewriter import (RewritePattern, op_type_rewrite_pattern, PatternRewriter, PatternRewriteWalker)

from compiler import (parse_toy, print_op, optimise_toy, lower_from_toy, 
                                  optimise_vir, lower_to_riscv, emulate_riscv)

import toy.dialect as td
import vector_ir.dialect as tvd

# Toy program used throughout this chapter: a <2, 3> constant `a`, a rank-1
# constant `b` reshaped to <2, 3> as `c`, and their pointwise sum `d`.
example = """
def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  var b<6> = [1, 2, 3, 4, 5, 6];
  var c<2, 3> = b;
  var d = a + c;
  print(d);
}
"""

# Parse the source into Toy IR, then pass it through `optimise_toy`;
# `toy_1` is the starting point for the lowering below.
toy_0 = parse_toy(example)
toy_1 = optimise_toy(toy_0)


class LowerTensorConstantOp(RewritePattern):
    """Split a Toy tensor constant into a shape vector and a data vector.

    The two vectors are recombined with `td.TensorMakeOp` so that users of the
    original constant still receive a tensor value.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: td.ConstantOp, rewriter: PatternRewriter):
        """
        Rewrite a constant tensor into explicit shape and data vectors:

        tensor = [[1, 2, 3], [4, 5, 6]]
        ->
        shape = [2, 3]
        data = [1, 2, 3, 4, 5, 6]
        tensor = Tensor(shape, data)
        """
        tensor_type = op.value.type

        # Toy constants are always ranked, so the attribute type must be a
        # TensorType rather than an unranked one.
        assert isinstance(
            tensor_type, TensorType), 'Toy constants always have rank information'
        tensor_type = cast(td.AnyTensorTypeI32, tensor_type)

        # Read shape then data, in that order, before building any new ops.
        dims, values = op.get_shape(), op.get_data()

        shape_op = tvd.VectorConstantOp.get(dims, 'tensor_shape')
        data_op = tvd.VectorConstantOp.get(values, 'tensor_data')
        make_op = td.TensorMakeOp.get(shape_op, data_op, tensor_type)

        # The last op's result replaces the matched constant's result.
        new_ops = [shape_op, data_op, make_op]
        rewriter.replace_matched_op(new_ops)

class LowerTensorAddOp(RewritePattern):
    """Lower a Toy tensor addition to a vector addition on the tensors' data."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: td.AddOp, rewriter: PatternRewriter):
        """
        Rewrite a tensor addition into an addition of the data vectors:

        result = lhs + rhs
        ->
        lhs_data = lhs.data
        rhs_data = rhs.data
        sum_data = lhs_data + rhs_data
        result = Tensor(lhs.shape, sum_data)

        That the operand shapes match has already been checked by the time
        this lowering runs, so the result simply reuses the shape of `lhs`.
        """
        typ = op.res.typ
        assert isinstance(typ, TensorType | UnrankedTensorType)
        typ = cast(td.AnyTensorTypeI32, typ)

        shape = td.TensorShapeOp.get(op.lhs)
        # These are ops extracting the data vectors, not the operands
        # themselves; name them accordingly.
        lhs_data = td.TensorDataOp.get(op.lhs)
        rhs_data = td.TensorDataOp.get(op.rhs)
        # Named `sum_data` rather than `sum` to avoid shadowing the builtin.
        sum_data = tvd.VectorAddOp.get(lhs_data.data, rhs_data.data)
        result = td.TensorMakeOp.get(shape, sum_data, typ)

        # The last op's result replaces the matched addition's result.
        rewriter.replace_matched_op(
            [shape, lhs_data, rhs_data, sum_data, result])


# Lower a clone of `toy_1` so the original can still be printed below for a
# before/after comparison.
copy = toy_1.clone()

# Each lowering pattern gets its own walk over the module.
PatternRewriteWalker(LowerTensorConstantOp()).rewrite_module(copy)
PatternRewriteWalker(LowerTensorAddOp()).rewrite_module(copy)

# Print the Toy IR, a blank separator line, then the lowered IR.
print_op(toy_1)
print()
print_op(copy)

The IR now contains operations from three dialects: `builtin`, `toy`, and `vector`.

### Optimising VIR

As we can see, the code expanded quite a bit. The vector constants are initialised
individually, and then combined using the `toy.tensor.make` operation. One optimisation we
can implement is to use the original data vector instead of using the one from the tensor:

``` python
# Before
shape = ...
data_0 = ...
tensor = Tensor(shape, data_0)
data_1 = tensor.data

# After
shape = ...
data_0 = ...
tensor = Tensor(shape, data_0)
data_1 = data_0
```

If the tensor is no longer used after this rewrite, we can delete it altogether. Let's see what that gives us.

In [None]:
from typing import cast

from xdsl.ir import OpResult, Operation
from xdsl.dialects.builtin import DenseIntOrFPElementsAttr
from xdsl.pattern_rewriter import (op_type_rewrite_pattern, RewritePattern,
                                   PatternRewriter, PatternRewriteWalker)

from toy.rewrites import RemoveUnusedOperations

import toy.dialect as td
import vector_ir.dialect as tvd

class SimplifyRedundantDataAccess(RewritePattern):
    """Forward `Tensor(shape, data).data` directly to `data`.

    Only the `TensorDataOp` is replaced; the `TensorMakeOp` that built the
    tensor is left in place, so a separate pass must remove it if it ends up
    unused.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: td.TensorDataOp,
                          rewriter: PatternRewriter):
        """
        result = Tensor(shape, data_0).data
        ->
        result = data_0
        """
        # Look at the tensor whose data this TensorDataOp reads.
        tensor_data_input = op.tensor
        if not isinstance(tensor_data_input, OpResult):
            # Input was not produced by an operation, could be a function argument
            return

        tensor_make_op = tensor_data_input.op
        if not isinstance(tensor_make_op, td.TensorMakeOp):
            # The tensor was produced by some other op, so its data vector is
            # not directly available here: no match.
            return

        # Replace the access with the data vector the tensor was built from;
        # no new ops are inserted.
        rewriter.replace_op(op, [], [tensor_make_op.data])

# Forward tensor data accesses to the original data vectors, then run
# `RemoveUnusedOperations` so tensors that are no longer used can be dropped.
PatternRewriteWalker(SimplifyRedundantDataAccess()).rewrite_module(copy)
PatternRewriteWalker(RemoveUnusedOperations()).rewrite_module(copy)

print_op(copy)


In [None]:
from vector_ir.rewrites import SimplifyRedundantShapeAccess

# Same idea for shape accesses: apply `SimplifyRedundantShapeAccess`
# (presumably the shape counterpart of SimplifyRedundantDataAccess -- see
# vector_ir.rewrites), then clean up with `RemoveUnusedOperations`.
PatternRewriteWalker(SimplifyRedundantShapeAccess()).rewrite_module(copy)
PatternRewriteWalker(RemoveUnusedOperations()).rewrite_module(copy)

print_op(copy)

In [None]:
from compiler import optimise_vir

# Run the `optimise_vir` pipeline from `compiler` on the lowered module and
# print the result (presumably it bundles rewrites like the ones above --
# see compiler.py to confirm).
print_op(optimise_vir(copy))

<br><br/>
[Back to table of contents](0_Table_of_Contents.ipynb)