In [None]:
import air.compiler.util

from air.mlir.dialects import func
from air.mlir.dialects import linalg
from air.mlir.ir import *
import air.mlir.passmanager

import sys

In [None]:
def matmul_on_tensors(m, n, k, dtype):
    """Create an MLIR module containing a single ``matmul`` function on tensors.

    Args:
        m: First dimension of the left operand and of the output tensor.
        n: Second dimension of the right operand and of the output tensor.
        k: Contraction dimension shared by the two input tensors.
        dtype: MLIR element type used for all three ranked tensors.

    Returns:
        The freshly created ``Module``. Its body holds ``matmul(lhs, rhs, out)``,
        whose body emits one ``linalg.matmul`` writing into ``out``.
    """
    # Operand types: (m x k) * (k x n) -> (m x n), all with the same element type.
    lhs_type = RankedTensorType.get((m, k), dtype)
    rhs_type = RankedTensorType.get((k, n), dtype)
    out_type = RankedTensorType.get((m, n), dtype)

    mlir_module = Module.create()
    with InsertionPoint(mlir_module.body):

        # The Python function name becomes the name of the generated func op,
        # which is what later cells pass to the runner ("matmul").
        @func.FuncOp.from_py_func(lhs_type, rhs_type, out_type)
        def matmul(lhs, rhs, out):
            linalg.matmul(lhs, rhs, outs=[out])

    return mlir_module

In [None]:
with air.mlir.ir.Context(), Location.unknown():

    # 512x512x512 bf16 matmul expressed as linalg on tensors.
    air_module = matmul_on_tensors(512, 512, 512, BF16Type.get())

    # tile and map to air
    pipeline = ",".join([
        "air-linalg-codegen{l1-tile-size=32,32,32 l1-promote=true l2-tile-size=64,64,64 l2-promote=true}",
        "affine-to-air{herd-assign-depth=1}",
        "canonicalize", "cse",
    ])

    # Pass pipelines, applied to the module in this order.
    stages = (
        # convert linalg on tensors to linalg on memrefs
        air.compiler.util.LINALG_TENSOR_TO_MEMREF_PIPELINE,
        # tile and map to air (built above)
        pipeline,
        # generate dependency information for runner
        "air-dependency,canonicalize,cse",
    )
    for stage in stages:
        pm = air.mlir.passmanager.PassManager.parse(stage)
        pm.run(air_module)
        # Debug hint: print(air_module) here after the second stage to see the
        # AIR dialect module before dependency information is added.

    print("\nAIR Dialect Module (async)\n")
    print(air_module)

In [None]:
# Architecture description handed to air.compiler.util.Runner in the cells below.
# NOTE(review): units are inferred from the key names (e.g. "clock" presumably
# in Hz) -- confirm against the runner's documentation.
arch = {
  "clock": 1_000_000_000,
  "cores": 1,
  "datatype": {
    "bytes": 2,
    "name": "fp16"
  },
  "devicename": "testdevice",
  # One link per ordered (src, dst) pair among endpoints 0, 1 and 2, all with
  # the same bandwidth.
  "interfaces": [
    {"bytes_per_second": 100_000_000_000, "dst": dst, "src": src}
    for src, dst in [(0, 1), (1, 0), (0, 2), (2, 0), (1, 2), (2, 1)]
  ],
  # Every modelled kernel is keyed by its op name and given efficiency 1.
  "kernels": {
    op_name: {"efficiency": 1, "name": op_name}
    for op_name in ("linalg.copy", "linalg.fill", "linalg.matmul")
  },
  "ops_per_core_per_cycle": 512,
  "num_herd_slots": 1,
  "num_dispatch_queues": 1
}


In [None]:
# Baseline run: build a Runner from the `arch` description and run the
# "matmul" function of the compiled module.
runner = air.compiler.util.Runner(arch)
# NOTE(review): `trace` is later passed to a text-mode file write, so it is
# presumably a string -- confirm against the Runner API.
trace = runner.run(air_module, "matmul")

In [None]:
# Second run: same architecture, but with 4 herd slots and 8 dispatch queues.
# The overrides go into a separate dict instead of mutating `arch` in place, so
# re-running the baseline cell afterwards still uses its original 1-slot /
# 1-queue settings (no hidden state between cells). `{**arch, ...}` keeps the
# original key order, so the runner sees the same layout as the mutated dict.
arch_parallel = {**arch, "num_herd_slots": 4, "num_dispatch_queues": 8}
runner = air.compiler.util.Runner(arch_parallel)
# Rebinds `trace`: from here on it holds the result of this second run.
trace = runner.run(air_module, "matmul")

In [None]:
# Write the most recent runner result to disk.
trace_path = "/work/trace.out"
with open(trace_path, mode="w") as trace_file:
    trace_file.write(trace)