# Imports

In [1]:
from tvm import relax
from AllocationFinder import AllocationFinder

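# Relax Model Definition

A small two-layer Conv2D + ReLU network built with `relax.frontend.nn`. Despite the `RelaxMnist` name, it is not MNIST-shaped: it takes a 3-channel 128×128 NCHW input and is exported to a Relax IRModule.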

In [2]:
class RelaxMnist(relax.frontend.nn.Module):
    """Two-layer convolutional network: Conv2D -> ReLU -> Conv2D -> ReLU.

    Despite the name, the model is not tied to MNIST. With the default
    arguments the first convolution expects 3 input channels. Both
    convolutions use a 5x5 kernel, stride 1 and padding 2, so the spatial
    size (H, W) of the input is preserved: (H + 2*2 - 5) / 1 + 1 == H.

    Args:
        in_channels: channels of the input tensor (default 3, as in the
            original hard-coded model).
        hidden_channels: output channels of the first convolution (default 32).
        out_channels: output channels of the second convolution (default 64).
    """

    def __init__(self, in_channels=3, hidden_channels=32, out_channels=64):
        # Zero-argument super() is the Python 3 idiom; it is equivalent to
        # super(RelaxMnist, self) but does not repeat the class name.
        super().__init__()
        self.conv1 = relax.frontend.nn.Conv2D(
            in_channels, hidden_channels, kernel_size=5, stride=1, padding=2, bias=True
        )
        self.relu1 = relax.frontend.nn.ReLU()
        self.conv2 = relax.frontend.nn.Conv2D(
            hidden_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=True
        )
        self.relu2 = relax.frontend.nn.ReLU()

    def forward(self, x):
        """Apply conv1 -> relu1 -> conv2 -> relu2 to ``x`` and return the result."""
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        return x


# Export the model's `forward` method to a Relax IRModule, declaring a single
# float32 input `x` of shape input_shape; export_tvm also returns the
# parameter list.
input_shape = (1, 3, 128, 128)
forward_spec = {"x": relax.frontend.nn.spec.Tensor(input_shape, "float32")}
rconv_mod, rconv_params = RelaxMnist().export_tvm({"forward": forward_spec})
rconv_mod.show()

In [3]:
# Lower the Relax module to TIR. The pass order follows TVM Relax's "zero"
# pipeline: legalize high-level ops into TIR functions, annotate each TIR
# function's op pattern, fold constants, then fuse ops and their TIR bodies.
transforms = [
    relax.transform.LegalizeOps(),
    relax.transform.AnnotateTIROpPattern(),
    relax.transform.FoldConstant(),
    relax.transform.FuseOps(),
    relax.transform.FuseTIR(),
]

# Apply each pass in order; every pass returns a new IRModule.
new_mod = rconv_mod
for mod_pass in transforms:
    new_mod = mod_pass(new_mod)

new_mod.show()

In [4]:
# Walk the lowered module and list the memory blocks found in each function.
alloc_finder = AllocationFinder(new_mod)
alloc_finder.walk()

for func_name, func_blocks in alloc_finder.memblocks.items():
    print(f"Function: {func_name}")
    for block in func_blocks:
        print("   ", block)

Function: reshape1
    MemBlock(T_reshape:8482231a, shape=(1, 64, 1, 1), dtype=float32, size=256, origin=reshape1)
Function: fused_conv2d1_add1_relu1
    MemBlock(pad_temp:2f37fa5a, shape=(1, 32, 132, 132), dtype=float32, size=2230272, origin=fused_conv2d1_add1_relu1)
    MemBlock(conv2d_nchw_intermediate:f04ff303, shape=(1, 64, 128, 128), dtype=float32, size=4194304, origin=fused_conv2d1_add1_relu1)
    MemBlock(T_add_intermediate:805d1ebb, shape=(1, 64, 128, 128), dtype=float32, size=4194304, origin=fused_conv2d1_add1_relu1)
    MemBlock(compute_intermediate:fac748c1, shape=(1, 64, 128, 128), dtype=float32, size=4194304, origin=fused_conv2d1_add1_relu1)
Function: fused_conv2d_add_relu
    MemBlock(pad_temp:9066512c, shape=(1, 3, 132, 132), dtype=float32, size=209088, origin=fused_conv2d_add_relu)
    MemBlock(conv2d_nchw_intermediate:83fb3d4d, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=fused_conv2d_add_relu)
    MemBlock(T_add_intermediate:91939108, shape=(1, 32, 12

In [7]:
# Show the dependency graph among the blocks of the first fused conv kernel:
# for each block, which blocks it depends on and which it links to.
for block in alloc_finder.memblocks["fused_conv2d_add_relu"]:
    upstream = block.depends_on
    downstream = block.links_to
    print(f"{block} ")
    print(f"    - Depends on {upstream}")
    print(f"    - Links to {downstream}")

MemBlock(pad_temp:9066512c, shape=(1, 3, 132, 132), dtype=float32, size=209088, origin=fused_conv2d_add_relu) 
    - Depends on []
    - Links to [MemBlock(conv2d_nchw_intermediate:83fb3d4d, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=fused_conv2d_add_relu)]
MemBlock(conv2d_nchw_intermediate:83fb3d4d, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=fused_conv2d_add_relu) 
    - Depends on [MemBlock(pad_temp:9066512c, shape=(1, 3, 132, 132), dtype=float32, size=209088, origin=fused_conv2d_add_relu)]
    - Links to [MemBlock(T_add_intermediate:91939108, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=fused_conv2d_add_relu)]
MemBlock(T_add_intermediate:91939108, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=fused_conv2d_add_relu) 
    - Depends on [MemBlock(conv2d_nchw_intermediate:83fb3d4d, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=fused_conv2d_add_relu)]
    - Links to [MemBlock(compute_intermediate:16375aa8, s

In [8]:
# Same dependency dump for the Relax-level `forward` function. Per the
# captured output, its blocks include the function inputs and the
# results of call_tir calls.
for block in alloc_finder.memblocks["forward"]:
    print(
        f"{block} ",
        f"    - Depends on {block.depends_on}",
        f"    - Links to {block.links_to}",
        sep="\n",
    )


MemBlock(x:8bb5d9e1, shape=(1, 3, 128, 128), dtype=float32, size=196608, origin=relax.input) 
    - Depends on []
    - Links to [MemBlock(lv:d814e86e, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=relax.call_tir.fused_conv2d_add_relu)]
MemBlock(conv1_weight:6f62c1cf, shape=(32, 3, 5, 5), dtype=float32, size=9600, origin=relax.input) 
    - Depends on []
    - Links to [MemBlock(lv:d814e86e, shape=(1, 32, 128, 128), dtype=float32, size=2097152, origin=relax.call_tir.fused_conv2d_add_relu)]
MemBlock(conv1_bias:daf698da, shape=(32,), dtype=float32, size=128, origin=relax.input) 
    - Depends on []
    - Links to [MemBlock(lv1:cd457789, shape=(1, 32, 1, 1), dtype=float32, size=128, origin=relax.call_tir.reshape)]
MemBlock(conv2_weight:65b9869c, shape=(64, 32, 5, 5), dtype=float32, size=204800, origin=relax.input) 
    - Depends on []
    - Links to [MemBlock(gv:bfb394eb, shape=(1, 64, 128, 128), dtype=float32, size=4194304, origin=relax.call_tir.fused_conv2d1_add1_relu1)]
