# Parsing a torch model in MLIR Syntax

Torch models in MLIR syntax can already be generated by [Torch-MLIR](https://github.com/llvm/torch-mlir)!

For now, let's just parse such a model and print it.

One can see that some tensor literals are only used by transpose operations. Let's optimize this by folding each such transpose into the literal it reads.

In [17]:
from torchxdsl.dialect import *

from xdsl.dialects.func import Func
from xdsl.dialects.builtin import Builtin
from xdsl.parser import Parser, Source

from compiler import print_op
from xdsl.ir import MLContext

# Register the Torch, Func and Builtin dialects on the context that is
# handed to the parser below.
context = MLContext()
context.register_dialect(Torch)
context.register_dialect(Func)
context.register_dialect(Builtin)

# Read the MLIR text with an explicit encoding so the result does not depend
# on the platform's default locale; f.name is passed along as the file name
# the parser is given for this input.
with open('examples/alexnet.mlir', encoding='utf-8') as f:
    parser = Parser(context, f.read(), Source.MLIR, f.name)
    module = parser.parse_module()

print_op(module)

Traceback (most recent call last):
  File "/home/papychacal/.local/lib/python3.10/site-packages/xdsl/parser.py", line 321, in backtracking
    try:
  File "/home/papychacal/.local/lib/python3.10/site-packages/xdsl/parser.py", line 1278, in try_parse_builtin_named_attr
    'dense': self._parse_builtin_dense_attr,
  File "/home/papychacal/.local/lib/python3.10/site-packages/xdsl/parser.py", line 1276, in not_implemented
    self.tokenizer.consume_peeked(name)
NotImplementedError


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/papychacal/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3442, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_6163/3355953327.py", line 17, in <module>
    module = parser.parse_module()
  File "/home/papychacal/.local/lib/python3.10/site-packages/xdsl/parser.py", line 622, in parse_module
  File "/home/papychacal/.local/lib/python3.10/site-packages/xdsl/parser.py", line 1067, in raise_error
    at_position = self.tokenizer.next_token(peek=True)
xdsl.utils.exceptions.ParseError: examples/alexnet.mlir:11:44
    %7 = "torch.vtensor.literal"() {value = dense_resource<__elided__> : tensor<1000x4096xf32>} : () -> !torch.vtensor<[1000,4096],f32>
                                            ^^^^^^^^^^^^^^
                                            Unexpected exception: 

examples/alexnet.mlir:11:44
    %7 = "torch.vtensor.literal"() {value = dense_resource<__elided__> : ten

In [18]:
# Import some things from the xdsl.pattern_rewriter module:
from xdsl.pattern_rewriter import (GreedyRewritePatternApplier,
                                   PatternRewriter, PatternRewriteWalker,
                                   RewritePattern, op_type_rewrite_pattern)

# Create our rewriter class:
class TransposedLiteralOptimizer(RewritePattern):
    """Fold a transpose of a tensor literal into the literal itself.

    When a ``TransposeOp`` takes its tensor operand directly from a
    ``VTensorLitteralOp``, the transpose is replaced by a clone of that
    literal whose result type has the ``dim1`` and ``dim2`` entries of its
    dimension list swapped. Only the result type is edited here; the
    literal's ``value`` attribute is not touched by this pattern.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, transpose: TransposeOp, rewriter: PatternRewriter):
        """
        This method will be called on each TransposeOp in our Torch-xDSL module.

        Transposes whose tensor operand is not produced by a
        ``VTensorLitteralOp`` are left untouched.
        """
        # Only a transpose fed directly by a tensor literal can be folded.
        if not isinstance(transpose.tensor.op, VTensorLitteralOp):
            return

        transposed_literal = transpose.tensor.op.clone()

        # NOTE(review): the swap below edits the clone's result type in place;
        # confirm that clone() does not share that type object with the
        # original literal, otherwise the original's shape changes as well.
        dimensions = transposed_literal.res.typ.dimensions.data
        # Assumes dim1/dim2 are produced by ops exposing `value.value.data`
        # (e.g. integer constants) -- TODO confirm for non-constant operands.
        first = transpose.dim1.op.value.value.data
        second = transpose.dim2.op.value.value.data
        dimensions[first], dimensions[second] = dimensions[second], dimensions[first]

        rewriter.replace_matched_op(transposed_literal)

        # Drop the original literal once nothing refers to it any more.
        if not transpose.tensor.uses:
            rewriter.erase_op(transpose.tensor.op)
            
# Rewrite a clone so that `module` itself is not modified by the pattern.
optimized_module = module.clone()

# Build the walker around our pattern, then run it over the cloned module.
literal_transpose_walker = PatternRewriteWalker(TransposedLiteralOptimizer())
literal_transpose_walker.rewrite_module(optimized_module)

print_op(optimized_module)

"builtin.module"() ({
  "func.func"() ({
  ^0(%0 : #torch.vtensor<[1 : i64, 3 : i64, 224 : i64, 224 : i64], f32>):
    %1 = "torch.constant.int"() {"value" = 0 : i64} : () -> #torch.int
    %2 = "torch.constant.int"() {"value" = 1 : i64} : () -> #torch.int
    %3 = "torch.constant.float"() {"value" = 1.0 : f64} : () -> #torch.float
    %4 = "torch.constant.int"() {"value" = -1 : i64} : () -> #torch.int
    %5 = "torch.constant.bool"() {"value" = true} : () -> #torch.bool
    %6 = "torch.constant.bool"() {"value" = false} : () -> #torch.bool
    %7 = "torch.constant.none"() : () -> #torch.none
    %8 = "torch.vtensor.literal"() {"value" = dense<[1.0]> : tensor<1000xf32>} : () -> #torch.vtensor<[1000 : i64], f32>
    %9 = "torch.vtensor.literal"() {"value" = dense<[1.0]> : tensor<4096xf32>} : () -> #torch.vtensor<[4096 : i64], f32>
    %10 = "torch.vtensor.literal"() {"value" = dense<[1.0]> : tensor<4096xf32>} : () -> #torch.vtensor<[4096 : i64], f32>
    %11 = "torch.vtensor.literal"() 