<br/>
<div style="text-align: center">
<span style="float: left">
    <a href="Toy_Ch2.ipynb">&lt; Previous Chapter</a>
</span>
<span style="">
    <a href="table_of_contents.ipynb">Table Of Contents 🏠</a>
</span>
</div>

In [1]:
from pathlib import Path
from io import StringIO

from xdsl.ir import Dialect, MLContext
from xdsl.dialects.builtin import ModuleOp
from xdsl.printer import Printer

from toy.dialect import Toy
from toy.mlir_gen import MLIRGen
from toy.parser import Parser

# Shared xDSL context used by `parse` below; the Toy dialect is registered on it
# once so every program lowered in this notebook goes through the same context.
ctx = MLContext()
ctx.register_dialect(Toy)

def parse(program: str) -> ModuleOp:
    """Lower Toy source text to an xDSL ``ModuleOp``.

    The text is first parsed into a Toy AST, which is then converted to IR by
    an ``MLIRGen`` bound to the notebook-wide ``ctx``.

    Args:
        program: Toy source code.

    Returns:
        The generated top-level module operation.
    """
    generator = MLIRGen(ctx)
    # The source text is passed in directly; the path is presumably only a
    # label for the parser (e.g. for locations) -- no file is read here.
    toy_parser = Parser(Path('in_memory'), program)
    return generator.mlir_gen_module(toy_parser.parseModule())

def print_module(module: ModuleOp):
    """Print ``module`` using the printer's MLIR target (``Printer.Target.MLIR``).

    Args:
        module: The module operation to display.
    """
    mlir_printer = Printer(target=Printer.Target.MLIR)
    mlir_printer.print(module)

In [2]:
# Toy program whose only function applies `transpose` twice to its argument;
# the rewrites below reduce it to returning the argument unchanged.
tt = """
def transpose_transpose(x) {
  return transpose(transpose(x));
}
"""

In [3]:
# Lower the transpose example to IR and show it before any rewrite is applied.
module = parse(tt)
print_module(module)

"builtin.module"() ({
  "toy.func"() ({
  ^0(%0 : tensor<*xi32>):
    %1 = "toy.transpose"(%0) : (tensor<*xi32>) -> tensor<*xi32>
    %2 = "toy.transpose"(%1) : (tensor<*xi32>) -> tensor<*xi32>
    "toy.return"(%2) : (tensor<*xi32>) -> ()
  }) {"sym_name" = "transpose_transpose", "function_type" = (tensor<*xi32>) -> tensor<*xi32>, "sym_visibility" = "private"} : () -> ()
}) : () -> ()


In [4]:
from xdsl.pattern_rewriter import (GreedyRewritePatternApplier,
                                   PatternRewriteWalker)

from toy.rewrites import SimplifyRedundantTranspose

# Rewrite transpose(transpose(x)) so its user reads x directly. This pattern
# does not erase the inner transpose; it stays in the IR until the next cell.
PatternRewriteWalker(
    GreedyRewritePatternApplier([SimplifyRedundantTranspose()])
).rewrite_module(module)

print_module(module)

"builtin.module"() ({
  "toy.func"() ({
  ^0(%0 : tensor<*xi32>):
    %1 = "toy.transpose"(%0) : (tensor<*xi32>) -> tensor<*xi32>
    "toy.return"(%0) : (tensor<*xi32>) -> ()
  }) {"sym_name" = "transpose_transpose", "function_type" = (tensor<*xi32>) -> tensor<*xi32>, "sym_visibility" = "private"} : () -> ()
}) : () -> ()


In [5]:
from toy.rewrites import RemoveUnusedOperations

# Drop operations that are left without users -- here, the leftover inner
# transpose from the previous rewrite.
PatternRewriteWalker(
    GreedyRewritePatternApplier([RemoveUnusedOperations()])
).rewrite_module(module)

print_module(module)

"builtin.module"() ({
  "toy.func"() ({
  ^0(%0 : tensor<*xi32>):
    "toy.return"(%0) : (tensor<*xi32>) -> ()
  }) {"sym_name" = "transpose_transpose", "function_type" = (tensor<*xi32>) -> tensor<*xi32>, "sym_visibility" = "private"} : () -> ()
}) : () -> ()


In [6]:
# Toy program that copies a constant through several `var` declarations with an
# explicit `<2,1>` shape; each declaration lowers to a `toy.reshape`, giving a
# chain of reshapes for the following rewrites to simplify.
constants = """
def main() {
  var a<2,1> = [1, 2];
  var b<2,1> = a;
  var c<2,1> = b;
  print(c);
}
"""

In [7]:
# Lower the reshape example (rebinding `module`) and show the unoptimized IR.
module = parse(constants)
print_module(module)

"builtin.module"() ({
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
    %1 = "toy.reshape"(%0) : (tensor<2xi32>) -> tensor<2x1xi32>
    %2 = "toy.reshape"(%1) : (tensor<2x1xi32>) -> tensor<2x1xi32>
    %3 = "toy.reshape"(%2) : (tensor<2x1xi32>) -> tensor<2x1xi32>
    "toy.print"(%3) : (tensor<2x1xi32>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}) : () -> ()


In [8]:
from toy.rewrites import ReshapeReshapeOptPattern

# Pass 1: collapse reshape-of-reshape chains into a single reshape of the
# original value.
PatternRewriteWalker(
    GreedyRewritePatternApplier([ReshapeReshapeOptPattern()])
).rewrite_module(module)

# Pass 2: remove the intermediate reshapes that no longer have users.
PatternRewriteWalker(
    GreedyRewritePatternApplier([RemoveUnusedOperations()])
).rewrite_module(module)

print_module(module)

"builtin.module"() ({
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
    %1 = "toy.reshape"(%0) : (tensor<2xi32>) -> tensor<2x1xi32>
    "toy.print"(%1) : (tensor<2x1xi32>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}) : () -> ()


In [9]:
from toy.rewrites import FoldConstantReshapeOptPattern

# Pass 1: fold a reshape of a constant into a new constant that already has
# the target shape (here the dense value becomes tensor<2x1xi32>).
PatternRewriteWalker(
    GreedyRewritePatternApplier([FoldConstantReshapeOptPattern()])
).rewrite_module(module)

# Pass 2: remove the original constant and reshape, now left without users.
PatternRewriteWalker(
    GreedyRewritePatternApplier([RemoveUnusedOperations()])
).rewrite_module(module)

print_module(module)

"builtin.module"() ({
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1], [2]]> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
    "toy.print"(%0) : (tensor<2x1xi32>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}) : () -> ()


<br/>
<div style="text-align: center">
<span style="float: left">
    <a href="Toy_Ch2.ipynb">&lt; Previous Chapter</a>
</span>
<span style="">
    <a href="table_of_contents.ipynb">Table Of Contents 🏠</a>
</span>
</div>