In [1]:
import torch
import numpy

from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

from torch_mlir.passmanager import PassManager
from air.backend import linalg_on_tensors as backend

In [2]:
# Static shape shared by all three operands of the model.
SIZE = [128, 128]


class MMult_Mult(torch.nn.Module):
    """Computes ``a * (b @ c)``: a matrix product followed by an elementwise multiply.

    The ``annotate_args`` list has one entry per ``forward`` argument (``None``
    for ``self``); each operand is declared as a ``SIZE``-shaped int32 tensor so
    the importer can emit statically shaped MLIR tensors.
    """

    @export
    @annotate_args([
        None,
        (SIZE, torch.int32, True),
        (SIZE, torch.int32, True),
        (SIZE, torch.int32, True),
    ])
    def forward(self, a, b, c):
        product = torch.mm(b, c)
        return a * product


# Instantiate the eager model and compile it to TorchScript for the importer.
program = MMult_Mult()
scripted = torch.jit.script(program)

In [3]:
# Gather the @export / @annotate_args metadata declared on MMult_Mult.
class_annotator = ClassAnnotator()
extract_annotations(program, scripted, class_annotator)

# Import the scripted module into an MLIR module, applying those annotations.
mb = ModuleBuilder()
mb.import_module(scripted._c, class_annotator)

# Lower the imported module to the linalg-on-tensors form; the pass manager
# rewrites mb.module in place, so the print below shows the lowered IR.
LOWERING_PIPELINE = ",".join([
    "torchscript-module-to-torch-backend-pipeline",
    "torch-backend-to-linalg-on-tensors-backend-pipeline",
])
pm = PassManager.parse(LOWERING_PIPELINE, mb.module.context)
pm.run(mb.module)
print(mb.module)

#map = affine_map<(d0, d1) -> (d0, d1)>
module attributes {torch.debug_module_name = "MMult_Mult"} {
  func @forward(%arg0: tensor<128x128xi32>, %arg1: tensor<128x128xi32>, %arg2: tensor<128x128xi32>) -> tensor<?x?xi32> {
    %c0_i32 = arith.constant 0 : i32
    %0 = linalg.init_tensor [128, 128] : tensor<128x128xi32>
    %1 = linalg.fill(%c0_i32, %0) : i32, tensor<128x128xi32> -> tensor<128x128xi32> 
    %2 = linalg.matmul ins(%arg1, %arg2 : tensor<128x128xi32>, tensor<128x128xi32>) outs(%1 : tensor<128x128xi32>) -> tensor<128x128xi32>
    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %2 : tensor<128x128xi32>, tensor<128x128xi32>) outs(%0 : tensor<128x128xi32>) {
    ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):
      %5 = arith.muli %arg3, %arg4 : i32
      linalg.yield %5 : i32
    } -> tensor<128x128xi32>
    %4 = tensor.cast %3 : tensor<128x128xi32> to tensor<?x?xi32>
    return %4 : tensor<?x?xi32>
  }
}



In [4]:
# Build the AIR backend for linalg-on-tensors input and compile the lowered
# module (mb.module) into an artifact that airbackend.load() accepts.
airbackend = backend.LinalgOnTensorsAirBackend()
compiled = airbackend.compile(mb.module, verbose=False)

In [5]:
# Load the compiled artifact; the returned object exposes `forward`, which takes
# NumPy arrays. Loading touches the device, producing the libmetal log lines.
jit_module = airbackend.load(compiled)

metal: info:      Registered shmem provider linux_shm.
metal: info:      Registered shmem provider ion.reserved.
metal: info:      Registered shmem provider ion.ion_system_contig_heap.
metal: info:      Registered shmem provider ion.ion_system_heap.
metal: info:      device xilinx-aiengine in use by driver uio_dmem_genirq
metal: info:      metal_uio_dev_open: No IRQ for device f70a0000.aie-npi.


In [6]:
# Random int32 operands drawn from a seeded, local generator so that
# Restart & Run All reproduces the same inputs (and outputs) every time.
# Values lie in [0, MAX_VAL). With MAX_VAL = 100 and SIZE = [128, 128], the
# largest possible output element is 99 * (128 * 99 * 99) = 124,198,272, well
# inside the int32 range, so the device result can be compared exactly.
SEED = 0
MAX_VAL = 100
rng = torch.Generator().manual_seed(SEED)
a = torch.randint(MAX_VAL, SIZE, dtype=torch.int32, generator=rng)
b = torch.randint(MAX_VAL, SIZE, dtype=torch.int32, generator=rng)
c = torch.randint(MAX_VAL, SIZE, dtype=torch.int32, generator=rng)

# Run the model on the device; the compiled module is called with NumPy arrays.
o = jit_module.forward(a.numpy(), b.numpy(), c.numpy())

# Convert the device output back to a tensor and show inputs and result.
d = torch.tensor(o)
print(f"input:\n{a}\n{b}\n{c}\noutput:\n{d}")

input:
tensor([[39, 59, 61,  ..., 14, 59, 68],
        [69, 56,  7,  ..., 97, 77, 59],
        [33, 80, 10,  ..., 70, 14, 17],
        ...,
        [84, 21, 87,  ..., 59, 33,  9],
        [98, 67, 59,  ..., 65, 99, 44],
        [83,  7, 93,  ..., 86, 59, 50]], dtype=torch.int32)
tensor([[89, 12, 52,  ..., 61, 26, 69],
        [31, 15, 75,  ..., 33, 28, 96],
        [ 0, 21,  9,  ...,  3,  8, 35],
        ...,
        [ 5, 10, 59,  ..., 25, 25, 77],
        [50,  8, 97,  ..., 38, 53, 52],
        [35, 64, 24,  ..., 97, 98, 81]], dtype=torch.int32)
tensor([[64, 43, 88,  ..., 62, 64, 99],
        [20, 98, 80,  ..., 74,  5, 14],
        [98, 62,  6,  ..., 65, 78, 19],
        ...,
        [17, 37, 80,  ..., 88, 79, 63],
        [41, 56, 57,  ..., 84,  7, 80],
        [14, 52, 19,  ..., 66, 26, 64]], dtype=torch.int32)
output:
tensor([[14096355, 19112637, 19885573,  ...,  4829972, 20945649, 26494704],
        [19625325, 14867888,  1855812,  ..., 27839582, 21603351, 16864560],
        [10450

In [7]:
# Check the device result against a CPU reference computed with PyTorch.
expected = a * torch.mm(b, c)
if torch.equal(expected, d):
    print("PASS!")
elif expected.shape != d.shape:
    print(f"failed. shape mismatch: expected {tuple(expected.shape)}, got {tuple(d.shape)}")
else:
    # Report how much of the output is wrong instead of a bare verdict.
    n_bad = int(torch.count_nonzero(expected != d))
    print(f"failed. {n_bad} of {expected.numel()} elements differ.")

PASS!
