# Parsing a Torch model in MLIR syntax

Such models can already be generated in MLIR syntax by [Torch-MLIR](https://github.com/llvm/torch-mlir)!

For now, let's just parse the model and print it.

In the printed output, some tensor literals are used only by transpose operations. Let's optimize this.

In [None]:
import xdsl, riscemu
from torchxdsl.dialect import *

from xdsl.dialects.func import Func
from xdsl.dialects.builtin import Builtin
from xdsl.parser import Parser, Source

from compiler import print_op
from xdsl.ir import MLContext

# Register every dialect that appears in the Torch-MLIR output so the parser
# can resolve the operations and attributes it encounters.
context = MLContext()
context.register_dialect(Torch)
context.register_dialect(Func)
context.register_dialect(Builtin)

# Read with an explicit encoding so decoding does not depend on the platform's
# locale default, and close the file before parsing. f.name (the path) is
# passed to the parser as the source name.
with open('examples/alexnet.mlir', encoding='utf-8') as f:
    mlir_text = f.read()
parser = Parser(context, mlir_text, Source.MLIR, f.name)
module = parser.parse_module()

print_op(module)

In [None]:
# Import some things from the xdsl.pattern_rewriter module:
from xdsl.pattern_rewriter import (GreedyRewritePatternApplier,
                                   PatternRewriter, PatternRewriteWalker,
                                   RewritePattern, op_type_rewrite_pattern)

# Create our rewriter class:
class TransposedLiteralOptimizer(RewritePattern):
    """
    Rewrite pattern intended to fold a transpose of a tensor literal into the
    literal itself.

    NOTE(review): the actual rewrite is currently disabled (see the TODO in
    ``match_and_rewrite``), so applying this pattern leaves the module
    unchanged.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, transpose: TransposeOp, rewriter: PatternRewriter):
        """
        Called on each TransposeOp in the Torch-xDSL module.

        Matches transposes whose input tensor is produced by a
        VTensorLitteralOp. No rewrite is performed yet.
        """
        # Local import: OpResult is only needed by this pattern.
        from xdsl.ir import OpResult

        tensor = transpose.tensor
        # The input may be a block argument (e.g. a function parameter), which
        # has no defining op; only an op result can come from a literal.
        if not isinstance(tensor, OpResult):
            return
        if not isinstance(tensor.op, VTensorLitteralOp):
            return

        # TODO: re-enable the rewrite when migrating this to the main xDSL
        # repo. The previous implementation crashed after the xDSL update and
        # did not follow the recommended style. It:
        #   1. cloned the literal op,
        #   2. swapped dimensions dim1/dim2 of the clone's result type by
        #      mutating `res.typ.dimensions.data` in place,
        #   3. replaced the transpose with the clone
        #      (rewriter.replace_matched_op),
        #   4. erased the original literal once it had no remaining uses
        #      (rewriter.erase_op).
        # NOTE(review): step 2 only permuted the result *type*; the literal's
        # element data was never transposed. Mutating attribute data in place
        # may also affect the original op if the clone shares attribute
        # objects. A re-implementation should build a new type (and
        # transposed data) instead of mutating; confirm against the xDSL API.
            
# Walk a copy of the parsed module with the transpose pattern so the original
# `module` stays untouched, then print the result.
walker = PatternRewriteWalker(TransposedLiteralOptimizer())
optimized_module = module.clone()
walker.rewrite_module(optimized_module)
print_op(optimized_module)