In [1]:
import torch
import torch_mlir
import numpy

from air.backend import linalg_on_tensors as backend

In [2]:
# Shape shared by all three operands (used for the example inputs at compile
# time and for the random inputs at run time) and their element type.
SIZE = [128, 128]
DTYPE = torch.int32

class MMult_Mult(torch.nn.Module):
    """Matrix-multiply `b` by `c`, then scale the product element-wise by `a`.

    forward(a, b, c) returns ``a * torch.mm(b, c)``. The module has no
    parameters; it exists so the computation can be traced/compiled.
    """

    def __init__(self):
        super().__init__()

    def forward(self, a, b, c):
        # torch.mm is evaluated first, then the element-wise multiply by `a`.
        return a * torch.mm(b, c)


program = MMult_Mult()

In [3]:
# One all-ones example tensor per forward() argument (a, b, c), using the
# configured SIZE/DTYPE; compile `program` requesting LINALG_ON_TENSORS output.
example_inputs = tuple(torch.ones(SIZE, dtype=DTYPE) for _ in range(3))
module = torch_mlir.compile(
    program,
    example_inputs,
    output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
)

In [4]:
# Create the AIR backend and compile the linalg-on-tensors `module` with it.
# NOTE(review): verbose=False presumably suppresses compiler logging — confirm
# against the air.backend documentation.
airbackend = backend.LinalgOnTensorsAirBackend()
compiled = airbackend.compile(module, verbose=False)

In [5]:
# Load the compiled artifact; the result is invoked as
# `jit_module.forward(...)` with numpy arrays when running the model.
jit_module = airbackend.load(compiled)

In [6]:
# Random inputs drawn from [0, 100). They are generated with the same
# SIZE/DTYPE the module was compiled for, so changing DTYPE keeps the run-time
# inputs consistent with the compiled signature. A seeded local generator
# makes Restart & Run All reproduce the same inputs. For int32 with
# SIZE = [128, 128], the largest possible output entry is
# 99 * (128 * 99 * 99) ~= 1.24e8, which fits without overflow.
rng = torch.Generator().manual_seed(0)
a = torch.randint(100, SIZE, dtype=DTYPE, generator=rng)
b = torch.randint(100, SIZE, dtype=DTYPE, generator=rng)
c = torch.randint(100, SIZE, dtype=DTYPE, generator=rng)

# run the model on the device
o = jit_module.forward(a.numpy(), b.numpy(), c.numpy())

# print the results
d = torch.tensor(o)
print(f"input:\n{a}\n{b}\n{c}\noutput:\n{d}")

input:
tensor([[50, 65,  6,  ..., 98,  2, 27],
        [66, 70, 66,  ..., 55, 67, 77],
        [59, 23, 18,  ..., 63, 87, 45],
        ...,
        [21, 50, 63,  ..., 76, 24, 47],
        [30, 86, 64,  ..., 34, 31, 82],
        [37, 99, 12,  ..., 63, 92,  1]], dtype=torch.int32)
tensor([[49, 94, 59,  ..., 98, 79, 48],
        [85, 39, 53,  ...,  4, 68, 69],
        [75, 11, 64,  ...,  0, 98, 33],
        ...,
        [86, 51, 78,  ..., 75, 36,  7],
        [ 0, 61, 37,  ..., 70, 65, 20],
        [ 4, 44, 59,  ..., 88, 53, 72]], dtype=torch.int32)
tensor([[45, 71, 53,  ..., 80, 18, 21],
        [23, 36, 17,  ..., 45, 80, 18],
        [36,  1, 70,  ..., 45, 88, 68],
        ...,
        [85, 26, 81,  ...,  9, 46,  2],
        [98, 28, 43,  ..., 62, 52, 66],
        [24, 39, 38,  ..., 64, 73, 40]], dtype=torch.int32)
output:
tensor([[16198150, 21693295,  1865142,  ..., 35965412,   690436,  7566993],
        [19832274, 21378210, 19271736,  ..., 17902940, 20828692, 20548143],
        [17618

In [7]:
# check the results: the device output `d` must exactly equal eager CPU
# execution of the same `program` that was compiled. Using the module itself
# as the reference means the check cannot drift from MMult_Mult.forward.
expected = program(a, b, c)
if torch.equal(expected, d):
    print("PASS!")
else:
    print("failed.")

PASS!


Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.