In [1]:
import torch
import torch_mlir
import numpy

from air.backend import linalg_on_tensors as backend

In [2]:
# Shape of every matrix used in this notebook (example inputs for compilation
# and the random test inputs are all created with this shape).
SIZE = [128, 128]
# Element type used for the example tensors handed to the compiler.
DTYPE = torch.int32

class MMult_Mult(torch.nn.Module):
    """Matrix multiply followed by an elementwise multiply: ``a * (b @ c)``.

    The module holds no parameters or state, so the constructor inherited
    from ``torch.nn.Module`` is used as-is.
    """

    def forward(self, a, b, c):
        """Return ``a`` multiplied elementwise by the matrix product of ``b`` and ``c``.

        Args:
            a: tensor multiplied elementwise with the matrix product.
            b: left operand of the matrix product (``torch.mm``).
            c: right operand of the matrix product (``torch.mm``).

        Returns:
            The tensor ``a * torch.mm(b, c)``.
        """
        return a * torch.mm(b, c)

# Instantiate the module; this instance is what gets handed to torch_mlir.compile.
program = MMult_Mult()

In [3]:
# Compile `program` with torch-mlir down to the linalg-on-tensors output type.
# forward() takes three arguments (a, b, c), so three all-ones example tensors
# of shape SIZE and element type DTYPE are supplied.
module = torch_mlir.compile(
    program,
    tuple(torch.ones(SIZE, dtype=DTYPE) for _ in range(3)),
    output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
)

In [4]:
# Backend object used for the rest of the notebook: compile() goes through its
# `refbackend` attribute and the compiled result is loaded with its `load()`.
airbackend = backend.LinalgOnTensorsAirBackend()

In [5]:
import torch_mlir.ir

import air.mlir.ir
import air.mlir.passmanager
import air.compiler.aircc.main as aircc

def compile(imported_module: torch_mlir.ir.Module):
    """Lower a torch-mlir linalg-on-tensors module to AIR and build it with aircc.

    Steps, in order:
      1. Re-parse the module's textual form inside a fresh AIR MLIR context.
      2. Bufferize the linalg dialect (tensor -> memref pipeline).
      3. Run a custom linalg -> AIR pass pipeline (see the comment below).
      4. Print the resulting AIR-dialect module.
      5. Invoke aircc to build the shared object 'torch.mlir.so'.
      6. Read 'air_project/refback.torch.mlir' (presumably written by the
         aircc step -- confirm), parse it in the context of
         ``imported_module`` and compile it with the backend's refbackend.

    Note: this helper intentionally keeps the name ``compile`` (shadowing the
    Python builtin within the notebook) and relies on the notebook-level
    ``airbackend`` object.

    Args:
        imported_module: the module produced by ``torch_mlir.compile``.

    Returns:
        The result of ``airbackend.refbackend.compile`` on the refbackend
        interface module; it is later passed to ``airbackend.load``.
    """
    # CUSTOM linalg -> AIR lowering. This replaces the stock sequence
    # (air-linalg-codegen, canonicalize, cse, affine-to-air, canonicalize, cse)
    # with one air-linalg-codegen invocation per filtered op, each with its
    # own herd size and L1 tile size, followed by herd/DMA conversion.
    air_lowering_passes = (
        "air-linalg-name",
        "air-linalg-codegen{input-filter=linalg.matmul1 herd-size=8,2 l1-tile-size=16,64,32}",
        "air-linalg-codegen{input-filter=linalg.generic2 herd-size=8,1 l1-tile-size=16,128,32}",
        "air-rm-linalg-name",
        "canonicalize",
        "cse",
        "air-par-to-herd",
        "air-copy-to-dma",
        "canonicalize",
        "cse",
    )

    # the loader expects the output to be called 'torch.mlir.so'
    aircc_args = [
        '--shared',
        '-o', 'torch.mlir.so',
        '--sysroot=/',
        '-row-offset=3',
        '-col-offset=20',
        'torch.mlir',
    ]

    with air.mlir.ir.Context():
        air_module = air.mlir.ir.Module.parse(str(imported_module))

        # Stage order matters: bufferize the linalg dialect first, then
        # convert the bufferized linalg to the air dialect.
        for pipeline in (
            air.compiler.util.LINALG_TENSOR_TO_MEMREF_PIPELINE,
            ",".join(air_lowering_passes),
        ):
            air.mlir.passmanager.PassManager.parse(pipeline).run(air_module)

        # print the air dialect mlir
        print(air_module)

        # run aircc to build the herds
        aircc.run(air_module, aircc_args)

        # generate a torch-mlir refbackend interface to the AIR control program so
        # that we can reuse the refbackend's jit and object loader on the cpu.
        with open('air_project/refback.torch.mlir') as refback_file:
            refback_text = refback_file.read()
        refback_module = torch_mlir.ir.Module.parse(refback_text, imported_module.context)
        return airbackend.refbackend.compile(refback_module)

In [6]:
# Run the compile() helper defined above on the linalg-on-tensors module.
# It prints the AIR-dialect MLIR as a side effect (output shown below).
compiled = compile(module)

#map0 = affine_map<()[s0] -> (s0 * 16)>
#map1 = affine_map<()[s0] -> (s0 * 64)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
module attributes {torch.debug_module_name = "MMult_Mult"} {
  func.func @forward(%arg0: memref<128x128xi32>, %arg1: memref<128x128xi32>, %arg2: memref<128x128xi32>) -> memref<128x128xi32> {
    %c2 = arith.constant 2 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %c0_i32 = arith.constant 0 : i32
    %0 = memref.alloc() {alignment = 128 : i64} : memref<128x128xi32>
    linalg.fill ins(%c0_i32 : i32) outs(%0 : memref<128x128xi32>)
    %1 = memref.alloc() {alignment = 128 : i64} : memref<128x128xi32>
    memref.copy %0, %1 : memref<128x128xi32> to memref<128x128xi32>
    air.herd @herd_0  tile (%arg3, %arg4) in (%arg5=%c8, %arg6=%c2) args(%arg7=%arg1, %arg8=%arg2, %arg9=%1) : memref<128x128xi32>, memref<128x128xi32>, memref<128x128xi32> {
      %c64 = arith.constant 64 : index
      %c1_0 = arith.constant 1 : index
      %c16 = arith.

In [7]:
# Load the compiled artifact through the AIR backend; the returned object
# exposes the model's forward() entry point used in the next cell.
jit_module = airbackend.load(compiled)

In [8]:
# Random integer test inputs drawn from [0, 100). The element type comes from
# the notebook-wide DTYPE (rather than a hard-coded torch.int32) so the inputs
# always match the type the module was compiled with.
a = torch.randint(100, SIZE, dtype=DTYPE)
b = torch.randint(100, SIZE, dtype=DTYPE)
c = torch.randint(100, SIZE, dtype=DTYPE)

# run the model on the device (the loaded module takes numpy arrays)
o = jit_module.forward(a.numpy(),b.numpy(),c.numpy())

# print the results; wrap the returned value in a tensor so it can be
# compared against the PyTorch reference in the next cell
d = torch.tensor(o)
print(f"input:\n{a}\n{b}\n{c}\noutput:\n{d}")

input:
tensor([[32, 32, 78,  ..., 51, 17, 98],
        [54, 59, 11,  ..., 34, 58,  5],
        [15, 60, 31,  ...,  9, 45, 61],
        ...,
        [35, 12, 76,  ..., 23, 69, 97],
        [27,  0, 23,  ..., 58, 56, 41],
        [99, 17, 11,  ..., 42, 53, 29]], dtype=torch.int32)
tensor([[83, 78,  1,  ..., 35, 75, 77],
        [92, 21,  2,  ..., 99, 15,  7],
        [18, 82, 98,  ..., 61, 68, 35],
        ...,
        [55, 52, 55,  ..., 69, 72,  0],
        [ 4, 23, 49,  ..., 34, 71, 88],
        [83, 71, 75,  ..., 98, 76, 19]], dtype=torch.int32)
tensor([[75, 28, 30,  ..., 51, 92, 65],
        [28,  7, 73,  ..., 20, 73, 61],
        [66, 79, 94,  ..., 13, 72, 78],
        ...,
        [47,  7, 12,  ..., 80, 71, 25],
        [28, 16, 85,  ...,  9, 54, 76],
        [93, 64, 18,  ..., 53, 19, 54]], dtype=torch.int32)
output:
tensor([[11425216, 10173472, 22450350,  ..., 15993039,  5420433, 31370094],
        [19196784, 18855869,  3056317,  ..., 10782454, 18140776,  1599750],
        [ 5276

In [9]:
# check the results: recompute a * (b @ c) with plain PyTorch and require an
# exact, element-for-element match with the output `d` from the previous cell
print("PASS!" if torch.equal(a*torch.mm(b,c),d) else "failed.")

PASS!


Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.

SPDX-License-Identifier: MIT