<br><br/>
[Back to table of contents](0_Table_of_Contents.ipynb)

# Chapter 5: Vector IR

Instead of going straight from Toy IR to RISC-V, we'll go through an intermediate dialect
with simpler types. In this IR, a tensor is represented as two 1D vectors: one holding its
shape and one holding its data. We can still do pointwise addition on the contents of the
data vectors.
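As a mental model, here is a plain-Python sketch of that representation. `VTensor` and `pointwise_add` are hypothetical names used only for illustration; they are not the dialect's actual types or operations. Elements are assumed to be stored in row-major order.

``` python
from dataclasses import dataclass
from math import prod


@dataclass(frozen=True)
class VTensor:
    """Hypothetical model of a Vector IR tensor: two flat 1D vectors."""
    shape: tuple[int, ...]   # e.g. (2, 3)
    data: tuple[float, ...]  # elements in row-major order

    def __post_init__(self) -> None:
        # The data vector must hold exactly prod(shape) elements.
        if prod(self.shape) != len(self.data):
            raise ValueError(f"shape {self.shape} does not match {len(self.data)} elements")


def pointwise_add(lhs: VTensor, rhs: VTensor) -> VTensor:
    """Add two same-shaped tensors by adding their data vectors element by element."""
    if lhs.shape != rhs.shape:
        raise ValueError(f"shape mismatch: {lhs.shape} vs {rhs.shape}")
    return VTensor(lhs.shape, tuple(x + y for x, y in zip(lhs.data, rhs.data)))


# Mirrors the example program: `a` is <2, 3>, and `c` reshapes the flat `b` to <2, 3>.
a = VTensor((2, 3), (1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
b = VTensor((6,), (1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
c = VTensor((2, 3), b.data)  # reshaping only swaps the shape vector; data is shared
d = pointwise_add(a, c)
print(d)
```

Reshaping `b` to `<2, 3>` only swaps the shape vector, which is why `a` and `c` can be added by adding their data vectors element by element once their shapes agree.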

In [None]:
import xdsl, riscemu

from compiler import (parse_toy, print_op, optimise_toy, lower_from_toy,
                      optimise_vir, lower_to_riscv, emulate_riscv)

example = """
def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  var b<6> = [1, 2, 3, 4, 5, 6];
  var c<2, 3> = b;
  var d = a + c;
  print(d);
}
"""

toy_0 = parse_toy(example)
toy_1 = optimise_toy(toy_0)
vir_0 = lower_from_toy(toy_1)

print_op(vir_0)

### Lowering Toy to Vector IR


TODO: lower `toy.add` to `toy.tensor.add`

### Optimising VIR

As we can see, the code expanded quite a bit. The vector constants are initialised
individually and then combined using the `toy.tensor.make` operation. One optimisation we
can implement is to use the original data vector directly, instead of extracting it back
out of the tensor:

``` python
# Before
shape = ...
data_0 = ...
tensor = Tensor(shape, data_0)
data_1 = tensor.data

# After
shape = ...
data_0 = ...
tensor = Tensor(shape, data_0)
data_1 = data_0
```

If the tensor has no remaining uses after this rewrite, we can delete it altogether. Let's see what that gives us.
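The rewrite and the clean-up are two separate steps. The sketch below models them on a hypothetical list-of-ops IR in plain Python, so it runs without xDSL; the `Op` class, op names such as `vector.constant`, and the `fold_data_access` and `remove_unused` helpers are illustrative assumptions, not xDSL APIs.

``` python
from dataclasses import dataclass
from typing import Optional

# Ops kept even if their result is unused (hypothetical: only print here).
SIDE_EFFECTING = {"print"}


@dataclass
class Op:
    """One operation in a hypothetical SSA-style IR: `result = name(operands...)`."""
    name: str
    operands: list[str]
    result: Optional[str] = None  # None for ops that produce no value, e.g. print


def fold_data_access(ops: list[Op]) -> list[Op]:
    """Rewrite uses of tensor.data(tensor.make(shape, data)) to use `data` directly."""
    defining = {op.result: op for op in ops if op.result is not None}
    replacements: dict[str, str] = {}
    for op in ops:
        if op.name != "tensor.data":
            continue
        producer = defining.get(op.operands[0])
        # Same guard as SimplifyRedundantDataAccess: operand must come from tensor.make.
        if producer is not None and producer.name == "tensor.make":
            replacements[op.result] = producer.operands[1]  # (shape, data) -> data

    def resolve(value: str) -> str:
        # Follow chains of replacements so every use points at the original value.
        while value in replacements:
            value = replacements[value]
        return value

    return [Op(op.name, [resolve(v) for v in op.operands], op.result) for op in ops]


def remove_unused(ops: list[Op]) -> list[Op]:
    """Drop side-effect-free ops whose result is unused, until nothing changes."""
    while True:
        used = {v for op in ops for v in op.operands}
        kept = [op for op in ops if op.name in SIDE_EFFECTING or op.result in used]
        if len(kept) == len(ops):
            return kept
        ops = kept


# The "Before" snippet from above, plus a print so data_1 has a use.
program = [
    Op("vector.constant", [], "shape"),
    Op("vector.constant", [], "data_0"),
    Op("tensor.make", ["shape", "data_0"], "tensor"),
    Op("tensor.data", ["tensor"], "data_1"),
    Op("print", ["data_1"]),
]

for op in remove_unused(fold_data_access(program)):
    print(op)
```

The fold on its own only rewires uses: in this sketch the `tensor.make`, the extracted `data_1` and the `shape` constant all survive until the dead-code pass removes them, which is why the cells below pair each pattern with `RemoveUnusedOperations`.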

In [None]:
from xdsl.ir import OpResult
from xdsl.pattern_rewriter import (op_type_rewrite_pattern, RewritePattern,
                                   PatternRewriter, PatternRewriteWalker)

from toy.rewrites import RemoveUnusedOperations

import toy.dialect as td

class SimplifyRedundantDataAccess(RewritePattern):

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: td.TensorDataOp,
                          rewriter: PatternRewriter):
        """
        Fold tensor(t_shape, t_data).data -> t_data
        """
        # Look at the input of the current tensor data access.
        tensor_data_input = op.tensor
        if not isinstance(tensor_data_input, OpResult):
            # Input was not produced by an operation, could be a function argument
            return

        tensor_make_op = tensor_data_input.op
        if not isinstance(tensor_make_op, td.TensorMakeOp):
            # Input was not built by a tensor.make operation, so there is no
            # original data vector to forward. No match.
            return

        # Replace uses of the extracted data with the data operand that was
        # passed to tensor.make.
        rewriter.replace_op(op, [], [tensor_make_op.data])

PatternRewriteWalker(SimplifyRedundantDataAccess()).rewrite_module(vir_0)
PatternRewriteWalker(RemoveUnusedOperations()).rewrite_module(vir_0)
print_op(vir_0)


In [None]:
from vector_ir.rewrites import SimplifyRedundantShapeAccess

PatternRewriteWalker(SimplifyRedundantShapeAccess()).rewrite_module(vir_0)
PatternRewriteWalker(RemoveUnusedOperations()).rewrite_module(vir_0)

print_op(vir_0)

In [None]:
from compiler import optimise_vir

print_op(optimise_vir(vir_0))
