In [1]:
import torch
import torch._dynamo as dynamo

from air.backend import linalg_on_tensors as backend

In [2]:
# Customized linalg-to-AIR pass pipeline. The passes are listed in the order
# they appear in the final pipeline string; the two air-linalg-codegen entries
# differ in their input-filter, herd-size and l1-tile-size options.
air_passes = [
    "air-linalg-name",
    "air-linalg-codegen{input-filter=linalg.matmul1 herd-size=8,2 l1-tile-size=16,64,32}",
    "air-linalg-codegen{input-filter=linalg.generic2 herd-size=8,1 l1-tile-size=16,128,32}",
    "air-rm-linalg-name",
    "canonicalize",
    "cse",
    "air-par-to-herd",
    "air-copy-to-dma",
    "canonicalize",
    "cse",
]
# Comma-separate the passes and nest them under a single builtin.module(...).
pipeline = f"builtin.module({','.join(air_passes)})"
# Build the dynamo backend from the customized pipeline string.
# NOTE(review): the meaning of partition_offset=[20, 3] is not visible in this
# notebook -- confirm against air.backend.linalg_on_tensors before changing it.
air_backend = backend.make_dynamo_backend(
    pipeline,
    partition_offset=[20, 3],
)

In [3]:
# Dimensions and element type used when the input tensors are created.
SIZE = [128, 128]
DTYPE = torch.int32


class MMult_Mult(torch.nn.Module):
    """Matrix multiply followed by an elementwise multiply: ``a * mm(b, c)``."""

    def forward(self, a, b, c):
        """Return the elementwise product of ``a`` with ``torch.mm(b, c)``.

        Args:
            a: tensor multiplied elementwise with the matrix product.
            b: left operand of the matrix multiply.
            c: right operand of the matrix multiply.
        """
        return a * torch.mm(b, c)

In [4]:
# Draw the inputs from a privately seeded generator so that every run uses the
# same tensors; if the check later reports a failure, the exact inputs can be
# regenerated. A local generator is used so the global RNG state is untouched.
rng = torch.Generator().manual_seed(0)

# Three integer tensors of shape SIZE with values in [0, 100).
a = torch.randint(100, SIZE, dtype=DTYPE, generator=rng)
b = torch.randint(100, SIZE, dtype=DTYPE, generator=rng)
c = torch.randint(100, SIZE, dtype=DTYPE, generator=rng)

# Eager (uncompiled) module instance; it is compiled in a later cell.
model = MMult_Mult()

In [5]:
# Wrap the model with the AIR dynamo backend, then run it on the inputs.
compile_with_air = dynamo.optimize(air_backend)
dynamo_model = compile_with_air(model)
result = dynamo_model(a, b, c)

# Show the three inputs followed by the computed output.
print(f"input:\n{a}\n{b}\n{c}\noutput:\n{result}")

 AIE Compilation: ━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 0:00:04 17/17 1 Workers
 AIE Compilation: ━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 0:00:01 9/9 1 Workers
input:
tensor([[59, 73, 74,  ..., 13, 26, 29],
        [24, 41, 86,  ..., 23, 72, 59],
        [32, 76, 84,  ..., 42, 82, 94],
        ...,
        [44, 33,  5,  ..., 37, 33, 57],
        [54, 87, 93,  ..., 88, 29, 80],
        [38, 26, 36,  ..., 23, 41, 96]], dtype=torch.int32)
tensor([[91, 29, 81,  ..., 89, 52, 72],
        [23, 76, 45,  ..., 28, 80, 68],
        [84, 32, 60,  ..., 78, 44, 66],
        ...,
        [46, 50, 74,  ..., 45, 62, 74],
        [ 5, 13, 66,  ..., 41, 92, 19],
        [96,  6, 75,  ..., 59, 12, 89]], dtype=torch.int32)
tensor([[56, 93, 53,  ..., 88, 93, 31],
        [60, 60, 90,  ..., 64, 18, 86],
        [30, 19, 89,  ..., 49, 28, 43],
        ...,
        [73, 63, 84,  ..., 87,  6, 52],
        [59, 39, 73,  ..., 93, 81, 15],
        [24, 69, 83,  ..., 19, 40, 44]], dtype=torch.int32)
output:
tens

In [6]:
# Compare the compiled model's output with the same expression evaluated
# directly with torch ops; report PASS only on an exact match.
expected = a * torch.mm(b, c)
print("PASS!" if torch.equal(expected, result) else "failed.")

PASS!


Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.

SPDX-License-Identifier: MIT