In [9]:
import torch
import torchvision.models as models
import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
import os

In [10]:
def compile_and_load_on_refbackend(module):
    backend = RefBackendLinalgOnTensorsBackend()
    compiled = backend.compile(module)
    return backend.load(compiled)

# Example on the Resnet18 Model

### 1. Manual Conversion (to get the IRs) 

Load the model in evaluation mode

In [11]:
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

Create a random example of input, for "torch-mlir" to infer the shapes

In [12]:
example_input = torch.ones(1, 3, 224, 224)

Compile the model using the random example

In [13]:
compiled = torch_mlir.compile(model, example_input, output_type="linalg-on-tensors")

Save the compiled model in an MLIR file. 
Note: the resulting IR is not executable yet!

In [14]:
output_file_path = "mlir_files/resnet18.mlir"
with open(output_file_path, 'w') as file:
    file.write(str(compiled))

Execute the conversion script (that fixes, wraps, bufferizes and lowers the model into an executable format). The script returns 0 when done.

You can inspect the resulting files of every step in the "mlir_files" folder. 

In [38]:
os.system(command="./convert.sh resnet18")

./mlir_files/resnet18.mlir:641:5: error: redefinition of symbol named 'nanoTime'
    func.func private @nanoTime() -> i64 attributes {llvm.emit_c_interface}
    ^
./mlir_files/resnet18.mlir:641:5: note: see current operation: 
"func.func"() <{function_type = () -> i64, sym_name = "nanoTime", sym_visibility = "private"}> ({
}) {llvm.emit_c_interface} : () -> ()
./mlir_files/resnet18.mlir:612:5: note: see existing symbol definition here
    func.func private @nanoTime() -> i64 attributes {llvm.emit_c_interface}
    ^
./mlir_files/resnet18.mlir:641:5: error: redefinition of symbol named 'nanoTime'
    func.func private @nanoTime() -> i64 attributes {llvm.emit_c_interface}
    ^
./mlir_files/resnet18.mlir:641:5: note: see current operation: 
"func.func"() <{function_type = () -> i64, sym_name = "nanoTime", sym_visibility = "private"}> ({
}) {llvm.emit_c_interface} : () -> ()
./mlir_files/resnet18.mlir:612:5: note: see existing symbol definition here
    func.func private @nanoTime() -> i64

256

You can now execute the model using the MLIR backend:

In [18]:
os.system(command="./execute.sh resnet18")

Unranked Memref base@ = 0x573d798c7200 rank = 2 offset = 0 sizes = [1, 1000] strides = [1000, 1] data = 
[[0.309578,   -0.361994,   -1.2193,   -0.964296,   0.111097,   0.710505,   -1.25162,   -0.810416,   -1.48079,   -0.484123,   -0.0440142,   -1.37726,   -1.3669,   -0.954988,   -1.2015,   -0.762648,   -0.10179,   -1.35949,   -1.08807,   -1.18677,   -0.833686,   1.17233,   0.224612,   -0.346416,   -0.939225,   0.0787166,   -0.234258,   0.0810409,   0.478479,   -1.23911,   -0.889535,   -0.717601,   -0.484779,   -1.43571,   -0.325226,   -0.916341,   -0.397071,   -0.798738,   0.635524,   -0.949957,   -0.767533,   -0.436079,   -0.401912,   0.256277,   -0.235594,   0.558754,   -0.507873,   -0.157723,   -1.08709,   -1.24056,   -0.717526,   0.931817,   0.440858,   -0.359523,   -0.356868,   -0.955342,   -0.898017,   -1.14174,   -0.754364,   0.52962,   -0.0386584,   -0.329822,   0.552272,   0.532207,   0.336979,   -0.135797,   0.347811,   -0.682906,   0.848631,   0.340697,   -1.11406,   1.26157

4.949861 GFLOPS


0

### Using the refbackend of torch-mlir (jit compilation)

In [19]:
jit_module = compile_and_load_on_refbackend(compiled)

In [27]:
logits = torch.from_numpy(jit_module.forward(example_input.numpy()))
logits

tensor([[-3.9137e-02,  1.1446e-01, -1.7968e+00, -1.2343e+00, -8.1900e-01,
          3.2396e-01, -2.1866e+00, -1.2877e+00, -1.9019e+00, -7.3148e-01,
          7.1643e-01, -1.6698e+00, -1.4515e+00, -1.2659e+00, -1.5797e+00,
         -1.0382e+00, -2.1478e-01, -2.0713e+00, -1.5538e+00, -1.2831e+00,
         -5.8318e-01,  1.6193e+00, -3.0488e-02, -4.8139e-01, -1.1298e+00,
         -3.6930e-01,  3.8818e-01,  5.7440e-02,  4.6316e-01, -2.7053e-01,
         -1.4319e+00, -7.5139e-01, -4.1541e-01, -1.8500e+00, -4.2063e-01,
         -1.1912e+00, -5.1930e-01, -1.9624e+00,  1.3662e+00, -1.1059e+00,
         -7.7725e-01, -2.0080e-02,  1.3349e-01,  1.3197e+00, -2.2508e-01,
          6.3489e-01, -1.1425e+00,  4.5811e-01, -8.9082e-01, -1.1984e+00,
         -1.0954e+00,  1.4283e+00,  4.6136e-01, -4.3548e-01, -3.3565e-01,
         -1.5134e+00, -9.2316e-01, -1.6104e+00, -1.0705e+00,  1.3485e+00,
          2.2440e-01, -8.4757e-01,  1.3267e+00,  9.8154e-01,  6.4895e-01,
         -2.2183e-01,  8.3419e-01, -1.

### Extra: Get PyTorch graph IR
https://pytorch.org/docs/stable/fx.html


In [28]:
from torch.fx import symbolic_trace

# Symbolic tracing frontend - captures the semantics of the model 
symbolic_traced : torch.fx.GraphModule = symbolic_trace(model)

# High level IR 
print(symbolic_traced.graph)

graph():
    %x : torch.Tensor [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
    %maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
    %layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
    %layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
    %layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
    %layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
    %add : [num_users=1] = call_function[target=ope

You can get the same MLIR rep using the pytorch graph IR:

In [29]:
compiled = torch_mlir.compile(symbolic_traced, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
print(compiled.operation.get_asm(large_elements_limit=10))

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d1)>
#map2 = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0, d1) -> (d1, d0)>
#map5 = affine_map<(d0, d1) -> (0, d1)>
#map6 = affine_map<(d0, d1) -> (d1)>
module attributes {torch.debug_module_name = "ResNet"} {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @forward(%arg0: tensor<1x3x224x224xf32>) -> tensor<1x1000xf32> {
    %false = arith.constant false
    %cst = arith.constant dense_resource<__elided__> : tensor<1000xf32>
    %cst_0 = arith.constant dense_resource<__elided__> : tensor<1000x512xf32>
    %cst_1 = arith.constant dense_resource<__elided__> : tensor<512xf32>
    %cst_2 = arith.constant dense_resource<__elided__> : tensor<512xf32>
    %cst_3 = arith.constant dense_resource<__elided__> : tensor<512xf32>
    %cst_4 = arith.constant dense_resource<__elided__> : ten