In [1]:
import torch
import numpy

from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

from torch_mlir.passmanager import PassManager
from air.backend import linalg_on_tensors as backend

In [2]:
# Shape shared by all three operands: it is used in the annotate_args entries
# on MMult_Mult.forward and when the random input tensors are generated.
SIZE = [128,128]

class MMult_Mult(torch.nn.Module):
    """Matrix multiply followed by an elementwise multiply.

    ``forward(a, b, c)`` returns ``a * torch.mm(b, c)``. Each of the three
    tensor arguments is annotated as an int32 tensor of shape ``SIZE``; the
    leading ``None`` in the annotation list lines up with ``self``.
    """

    @export
    @annotate_args([
        None,
        (SIZE, torch.int32, True),
        (SIZE, torch.int32, True),
        (SIZE, torch.int32, True),
    ])
    def forward(self, a, b, c):
        # Matrix product of b and c, then scaled elementwise by a.
        return a * torch.mm(b, c)

# Instantiate the module and compile it to TorchScript; both the eager
# instance and the scripted version are used by the import step below.
program = MMult_Mult()
scripted = torch.jit.script(program)

In [3]:
# Pass pipeline applied to the imported module: the TorchScript-to-torch-backend
# pipeline followed by the torch-backend-to-linalg-on-tensors pipeline.
TORCH_TO_LINALG_PIPELINE = 'torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline'

# Populate a ClassAnnotator from the eager and scripted modules
# (presumably from the @export / @annotate_args decorators -- TODO confirm).
class_annotator = ClassAnnotator()
extract_annotations(program, scripted, class_annotator)

# Import the scripted module, together with its annotations, into an MLIR module.
mb = ModuleBuilder()
mb.import_module(scripted._c, class_annotator)

# Lower in place, then print the resulting IR.
pm = PassManager.parse(TORCH_TO_LINALG_PIPELINE, mb.module.context)
pm.run(mb.module)
print(mb.module)

#map = affine_map<(d0, d1) -> (d0, d1)>
module attributes {torch.debug_module_name = "MMult_Mult"} {
  func @forward(%arg0: tensor<128x128xi32>, %arg1: tensor<128x128xi32>, %arg2: tensor<128x128xi32>) -> tensor<?x?xi32> {
    %c0_i32 = arith.constant 0 : i32
    %0 = linalg.init_tensor [128, 128] : tensor<128x128xi32>
    %1 = linalg.fill(%c0_i32, %0) : i32, tensor<128x128xi32> -> tensor<128x128xi32> 
    %2 = linalg.matmul ins(%arg1, %arg2 : tensor<128x128xi32>, tensor<128x128xi32>) outs(%1 : tensor<128x128xi32>) -> tensor<128x128xi32>
    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %2 : tensor<128x128xi32>, tensor<128x128xi32>) outs(%0 : tensor<128x128xi32>) {
    ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):
      %5 = arith.muli %arg3, %arg4 : i32
      linalg.yield %5 : i32
    } -> tensor<128x128xi32>
    %4 = tensor.cast %3 : tensor<128x128xi32> to tensor<?x?xi32>
    return %4 : tensor<?x?xi32>
  }
}



In [4]:
# AIR linalg-on-tensors backend instance; compile() uses its `refbackend`
# attribute and a later cell calls its `load` method.
airbackend = backend.LinalgOnTensorsAirBackend()

In [5]:
import torch_mlir.ir

import air.mlir.ir
import air.mlir.passmanager
import air.compiler.aircc.main as aircc

def compile(imported_module: torch_mlir.ir.Module):
    """Lower a linalg-on-tensors module to the AIR dialect, build it, and compile the host interface.

    The module is re-parsed from its textual form inside a fresh AIR MLIR
    context, run through the tensor-to-memref pipeline and then the
    linalg-to-AIR pipeline, printed, and handed to aircc. Afterwards the file
    'air_project/refback.torch.mlir' is parsed in the context of
    `imported_module` and compiled with the module-level
    `airbackend.refbackend`.

    Args:
        imported_module: the torch-mlir module to compile; only its textual
            form and its context are used.

    Returns:
        The result of `airbackend.refbackend.compile` on the parsed
        'air_project/refback.torch.mlir' module.
    """
    # Passes that convert the (bufferized) linalg dialect to the air dialect.
    linalg_to_air_passes = (
        "air-linalg-codegen",
        "canonicalize",
        "cse",
        "affine-to-air",
        "canonicalize",
        "cse",
    )

    # Command line for aircc; the loader expects the output to be called
    # 'torch.mlir.so'.
    aircc_args = ['--shared', '-o', 'torch.mlir.so', '--sysroot=/', '-row-offset=3', '-col-offset=20', 'torch.mlir']

    with air.mlir.ir.Context():
        # Round-trip through text so the module lives in the AIR context.
        lowered = air.mlir.ir.Module.parse(str(imported_module))

        # Bufferize the linalg dialect.
        pass_manager = air.mlir.passmanager.PassManager.parse(air.compiler.util.LINALG_TENSOR_TO_MEMREF_PIPELINE)
        pass_manager.run(lowered)

        # Convert the linalg dialect to the air dialect.
        pass_manager = air.mlir.passmanager.PassManager.parse(",".join(linalg_to_air_passes))
        pass_manager.run(lowered)

        # Show the air dialect MLIR.
        print(lowered)

        # Run aircc to build the herds.
        aircc.run(lowered, aircc_args)

        # Load the torch-mlir refbackend interface to the AIR control program
        # so that the refbackend's jit and object loader can be reused on the cpu.
        with open('air_project/refback.torch.mlir') as refback_file:
            refback_text = refback_file.read()
        refback_module = torch_mlir.ir.Module.parse(refback_text, imported_module.context)
        return airbackend.refbackend.compile(refback_module)

In [6]:
# Run the notebook's compile() helper on the imported module; as a side effect
# it prints the AIR dialect IR shown below.
compiled = compile(mb.module)

#map0 = affine_map<()[s0] -> (s0 * 32)>
#map1 = affine_map<()[s0] -> (s0 * 64)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
module attributes {torch.debug_module_name = "MMult_Mult"} {
  func @forward(%arg0: memref<128x128xi32>, %arg1: memref<128x128xi32>, %arg2: memref<128x128xi32>) -> memref<?x?xi32> {
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c0_i32 = arith.constant 0 : i32
    %0 = memref.alloc() : memref<128x128xi32>
    linalg.fill(%c0_i32, %0) : i32, memref<128x128xi32> 
    %1 = memref.alloc() : memref<128x128xi32>
    linalg.copy(%0, %1) : memref<128x128xi32>, memref<128x128xi32> 
    air.launch_herd tile (%arg3, %arg4) in (%arg5=%c4, %arg6=%c4) args(%arg7=%arg1, %arg8=%arg2, %arg9=%1) : memref<128x128xi32>, memref<128x128xi32>, memref<128x128xi32> attributes {sym_name = "herd_0"} {
      %c1024 = arith.constant 1024 : index
      %c32 = arith.constant 32 : index
      %c128 = arith.constant 128 : index
      %c0 = arith.constant 0 : index
      %

In [7]:
# Load the compiled artifact through the backend; the returned object exposes
# the forward() entry point called in the next cell.
jit_module = airbackend.load(compiled)

In [8]:
# Three random int32 operands of shape SIZE with values in [0, 100),
# generated in the order a, b, c.
a, b, c = (torch.randint(100, SIZE, dtype=torch.int32) for _ in range(3))

# run the model on the device, passing the operands as numpy arrays
o = jit_module.forward(*(operand.numpy() for operand in (a, b, c)))

# print the results
d = torch.tensor(o)
print(f"input:\n{a}\n{b}\n{c}\noutput:\n{d}")

metal: info:      Registered shmem provider linux_shm.
metal: info:      Registered shmem provider ion.reserved.
metal: info:      Registered shmem provider ion.ion_system_contig_heap.
metal: info:      Registered shmem provider ion.ion_system_heap.
metal: info:      device xilinx-aiengine in use by driver uio_dmem_genirq
metal: info:      metal_uio_dev_open: No IRQ for device f70a0000.aie-npi.


input:
tensor([[21, 48, 86,  ..., 93, 28, 79],
        [ 2, 66, 76,  ..., 71, 78, 87],
        [ 8, 71,  9,  ..., 86, 84, 81],
        ...,
        [75, 59, 18,  ..., 73, 91, 91],
        [47,  1, 19,  ..., 11,  8, 63],
        [86, 73,  3,  ..., 36, 75, 99]], dtype=torch.int32)
tensor([[16, 21, 59,  ..., 23, 83, 84],
        [30, 42, 42,  ..., 97, 27, 25],
        [24, 92, 67,  ..., 70, 20, 66],
        ...,
        [15, 18, 43,  ..., 39,  6, 95],
        [64, 88, 63,  ..., 28,  6, 36],
        [52,  5, 73,  ..., 33, 75, 86]], dtype=torch.int32)
tensor([[ 1, 19, 88,  ..., 28,  8, 16],
        [56, 94, 95,  ..., 90, 49, 98],
        [93, 95, 48,  ..., 86, 97, 35],
        ...,
        [12, 27, 20,  ..., 48, 38, 36],
        [99, 23, 69,  ..., 60, 60, 15],
        [47, 91, 71,  ..., 57, 11, 29]], dtype=torch.int32)
output:
tensor([[ 6003564, 14551056, 24231876,  ..., 27490893,  8669500, 23371597],
        [  582832, 21115446, 23704932,  ..., 20622376, 25264200, 26259906],
        [ 2669

In [9]:
# check the results against the same expression evaluated directly with PyTorch
expected = a * torch.mm(b, c)
print("PASS!" if torch.equal(expected, d) else "failed.")

PASS!
