In [1]:
# Re-import project modules (AllocationFinder, MinizincGenerator) before each
# cell runs, so edits to their .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Imports

In [2]:
from tvm import relax

In [3]:
import sys
import os

# Notebooks have no __file__, so derive the project root from the working
# directory instead. This assumes the kernel was started inside notebooks/.
project_root = os.path.abspath(os.pardir)
if project_root not in sys.path:
    # Prepend so the project's own modules take priority over installed ones.
    sys.path[:0] = [project_root]

from AllocationFinder import AllocationFinder

# Relax Model Definition

In [4]:
class RelaxMnist(relax.frontend.nn.Module):
    """Two conv stages, each Conv2D(5x5, stride 1, padding 2, bias) -> ReLU.

    Despite the name, this is not an MNIST classifier. It has no pooling or
    linear head, and this notebook feeds it a 3-channel 128x128 input.

    With kernel 5, padding 2 and stride 1 the spatial size is preserved, so
    channels go 3 -> 32 -> 64 while H and W stay unchanged.
    """

    def __init__(self):
        super().__init__()
        # Both convolutions share the same geometry; only channel counts differ.
        conv_opts = dict(kernel_size=5, stride=1, padding=2, bias=True)
        self.conv1 = relax.frontend.nn.Conv2D(3, 32, **conv_opts)
        self.relu1 = relax.frontend.nn.ReLU()
        self.conv2 = relax.frontend.nn.Conv2D(32, 64, **conv_opts)
        self.relu2 = relax.frontend.nn.ReLU()

    def forward(self, x):
        """Apply conv1 -> relu1 -> conv2 -> relu2 and return the result."""
        return self.relu2(self.conv2(self.relu1(self.conv1(x))))


# Batch 1, 3 channels, 128x128 -- the layout is NCHW, matching the
# conv2d_nchw kernels that appear after lowering.
input_shape = (1, 3, 128, 128)
forward_spec = {"x": relax.frontend.nn.spec.Tensor(input_shape, "float32")}
rconv_mod, rconv_params = RelaxMnist().export_tvm({"forward": forward_spec})
rconv_mod.show()

In [5]:
# Phase 2: lower Relax to TIR using the passes of TVM Relax's "zero" pipeline.
# Order matters: ops are legalized to TIR first, then annotated and fused.
transforms = [
    relax.transform.LegalizeOps(),
    relax.transform.AnnotateTIROpPattern(),
    relax.transform.FoldConstant(),
    relax.transform.FuseOps(),
    relax.transform.FuseTIR(),
]

new_mod = rconv_mod
for transform_pass in transforms:
    new_mod = transform_pass(new_mod)

new_mod.show()

In [6]:
import tvm

# Display the lowered module; its rich repr renders as TVMScript.
new_mod

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_conv2d1_add1_relu1(relu: T.Buffer((T.int64(1), T.int64(32), T.int64(128), T.int64(128)), "float32"), conv2_weight: T.Buffer((T.int64(64), T.int64(32), T.int64(5), T.int64(5)), "float32"), lv3: T.Buffer((T.int64(1), T.int64(64), T.int64(1), T.int64(1)), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(64), T.int64(128), T.int64(128)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(132), T.int64(132)))
        conv2d_nchw_intermediate = T.alloc_buffer((T.int64(1), T.int64(64), T.int64(128), T.int64(128)))
        T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(64), T.int64(128), T.int64(128)))
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(132), T.int64(

In [7]:
# NOTE(review): INSIGHT is not referenced in any visible cell -- confirm it is
# needed before removing this import.
from AllocationFinder import INSIGHT

# Walk the lowered module and record MemBlocks for its buffers. Judging by the
# `origin` fields printed below, this covers Relax inputs, call_tir outputs,
# and TIR alloc_buffers inside the fused kernels.
alloc_finder = AllocationFinder(new_mod)
alloc_finder.walk()

In [8]:
# All recorded MemBlocks, keyed by the short hex id that also appears after the
# ':' in each MemBlock's repr.
alloc_finder.id_to_memblock

{'d007b287': MemBlock(x:d007b287, shape=(1, 3, 128, 128), dtype=float32, size=196608, origin=relax.input),
 'ce6abce4': MemBlock(conv1_weight:ce6abce4, shape=(32, 3, 5, 5), dtype=float32, size=9600, origin=relax.input),
 '4523b3ab': MemBlock(conv1_bias:4523b3ab, shape=(32,), dtype=float32, size=128, origin=relax.input),
 '5e992f67': MemBlock(conv2_weight:5e992f67, shape=(64, 32, 5, 5), dtype=float32, size=204800, origin=relax.input),
 '56ee3aa3': MemBlock(conv2_bias:56ee3aa3, shape=(64,), dtype=float32, size=256, origin=relax.input),
 '516d9447': MemBlock(lv1:516d9447, shape=(1, 32, 1, 1), dtype=float32, size=128, origin=relax.call_tir.reshape),
 '32d0cbd3': MemBlock(lv:32d0cbd3, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=relax.call_tir.fused_conv2d_add_relu),
 'a6dcc9de': MemBlock(lv3:a6dcc9de, shape=(1, 64, 1, 1), dtype=float32, size=256, origin=relax.call_tir.reshape1),
 '7abdc45b': MemBlock(gv:7abdc45b, shape=(1, 64, 128, 128), dtype=float32, size=4194304, origin=

In [9]:
# List every recorded MemBlock, grouped by the function it belongs to.
for func_name, func_memblocks in alloc_finder.memblocks.items():
    print(f"Function: {func_name}")
    for memblock in func_memblocks:
        print("   ", memblock)

Function: forward
    MemBlock(x:d007b287, shape=(1, 3, 128, 128), dtype=float32, size=196608, origin=relax.input)
    MemBlock(conv1_weight:ce6abce4, shape=(32, 3, 5, 5), dtype=float32, size=9600, origin=relax.input)
    MemBlock(conv1_bias:4523b3ab, shape=(32,), dtype=float32, size=128, origin=relax.input)
    MemBlock(conv2_weight:5e992f67, shape=(64, 32, 5, 5), dtype=float32, size=204800, origin=relax.input)
    MemBlock(conv2_bias:56ee3aa3, shape=(64,), dtype=float32, size=256, origin=relax.input)
    MemBlock(lv1:516d9447, shape=(1, 32, 1, 1), dtype=float32, size=128, origin=relax.call_tir.reshape)
    MemBlock(lv:32d0cbd3, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=relax.call_tir.fused_conv2d_add_relu)
    MemBlock(lv3:a6dcc9de, shape=(1, 64, 1, 1), dtype=float32, size=256, origin=relax.call_tir.reshape1)
    MemBlock(gv:7abdc45b, shape=(1, 64, 128, 128), dtype=float32, size=4194304, origin=relax.call_tir.fused_conv2d1_add1_relu1)
Function: fused_conv2d1_add1_r

In [10]:
# Dataflow inside the first fused kernel. For each buffer, print the MemBlocks
# it is computed from (depends_on) and the MemBlocks it flows into (links_to).
conv1_memblocks = alloc_finder.memblocks["fused_conv2d_add_relu"]
for memblock in conv1_memblocks:
    print(f"{memblock} ")
    print(f"    - Depends on {memblock.depends_on}")
    print(f"    - Links to {memblock.links_to}")

MemBlock(pad_temp:d8d42eab, shape=(1, 3, 132, 132), dtype=float32, size=209088, origin=tir.fused_conv2d_add_relu.pad_temp) 
    - Depends on [MemBlock(x:d007b287, shape=(1, 3, 128, 128), dtype=float32, size=196608, origin=relax.input)]
    - Links to [MemBlock(conv2d_nchw_intermediate:eba01b42, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=tir.fused_conv2d_add_relu.conv2d_nchw_intermediate)]
MemBlock(conv2d_nchw_intermediate:eba01b42, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=tir.fused_conv2d_add_relu.conv2d_nchw_intermediate) 
    - Depends on [MemBlock(pad_temp:d8d42eab, shape=(1, 3, 132, 132), dtype=float32, size=209088, origin=tir.fused_conv2d_add_relu.pad_temp), MemBlock(conv1_weight:ce6abce4, shape=(32, 3, 5, 5), dtype=float32, size=9600, origin=relax.input)]
    - Links to [MemBlock(T_add_intermediate:d93b7024, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=tir.fused_conv2d_add_relu.T_add_intermediate)]
MemBlock(T_add_intermediate:

In [11]:
# Dataflow inside the second fused kernel. Its pad_temp depends on `lv`, the
# output of the first kernel, so this view also shows the cross-kernel edge.
conv2_memblocks = alloc_finder.memblocks["fused_conv2d1_add1_relu1"]
for memblock in conv2_memblocks:
    print(f"{memblock} ")
    print(f"    - Depends on {memblock.depends_on}")
    print(f"    - Links to {memblock.links_to}")

MemBlock(pad_temp:de3aec92, shape=(1, 32, 132, 132), dtype=float32, size=2230272, origin=tir.fused_conv2d1_add1_relu1.pad_temp) 
    - Depends on [MemBlock(lv:32d0cbd3, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=relax.call_tir.fused_conv2d_add_relu)]
    - Links to [MemBlock(conv2d_nchw_intermediate:f16efbd9, shape=(1, 64, 128, 128), dtype=float32, size=4194304, origin=tir.fused_conv2d1_add1_relu1.conv2d_nchw_intermediate)]
MemBlock(conv2d_nchw_intermediate:f16efbd9, shape=(1, 64, 128, 128), dtype=float32, size=4194304, origin=tir.fused_conv2d1_add1_relu1.conv2d_nchw_intermediate) 
    - Depends on [MemBlock(pad_temp:de3aec92, shape=(1, 32, 132, 132), dtype=float32, size=2230272, origin=tir.fused_conv2d1_add1_relu1.pad_temp), MemBlock(conv2_weight:5e992f67, shape=(64, 32, 5, 5), dtype=float32, size=204800, origin=relax.input)]
    - Links to [MemBlock(T_add_intermediate:f2dceb7f, shape=(1, 64, 128, 128), dtype=float32, size=4194304, origin=tir.fused_conv2d1_add1_relu1

In [12]:
# Buffers visible at the Relax `forward` level, and the TIR buffers inside the
# fused kernels that they link to.
forward_memblocks = alloc_finder.memblocks["forward"]
for memblock in forward_memblocks:
    print(f"{memblock} ")
    print(f"    - Depends on {memblock.depends_on}")
    print(f"    - Links to {memblock.links_to}")


MemBlock(x:d007b287, shape=(1, 3, 128, 128), dtype=float32, size=196608, origin=relax.input) 
    - Depends on []
    - Links to [MemBlock(pad_temp:d8d42eab, shape=(1, 3, 132, 132), dtype=float32, size=209088, origin=tir.fused_conv2d_add_relu.pad_temp)]
MemBlock(conv1_weight:ce6abce4, shape=(32, 3, 5, 5), dtype=float32, size=9600, origin=relax.input) 
    - Depends on []
    - Links to [MemBlock(conv2d_nchw_intermediate:eba01b42, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=tir.fused_conv2d_add_relu.conv2d_nchw_intermediate)]
MemBlock(conv1_bias:4523b3ab, shape=(32,), dtype=float32, size=128, origin=relax.input) 
    - Depends on []
    - Links to [MemBlock(lv1:516d9447, shape=(1, 32, 1, 1), dtype=float32, size=128, origin=relax.call_tir.reshape)]
MemBlock(conv2_weight:5e992f67, shape=(64, 32, 5, 5), dtype=float32, size=204800, origin=relax.input) 
    - Depends on []
    - Links to [MemBlock(conv2d_nchw_intermediate:f16efbd9, shape=(1, 64, 128, 128), dtype=float32, siz

In [13]:
from MinizincGenerator import generate_minizinc_model


mzn_code = generate_minizinc_model(alloc_finder)

# Resolve the output path from project_root (set in the sys.path cell) rather
# than the current directory. Also create generated_mzn/ on first use, because
# open(..., "w") raises FileNotFoundError when the directory is missing.
mzn_dir = os.path.join(project_root, "generated_mzn")
os.makedirs(mzn_dir, exist_ok=True)
mzn_path = os.path.join(mzn_dir, "model.mzn")
with open(mzn_path, "w", encoding="utf-8") as mzn_file:
    mzn_file.write(mzn_code)