In [1]:
import dace
from dace.transformation.interstate import GPUTransformSDFG, StateFusion
from dace.transformation.dataflow import MapTiling, InLocalStorage, MapExpansion, MapCollapse, MapReduceFusion
from dace.transformation.optimizer import Optimizer
from dace.transformation import helpers as xfutil

def find_map_by_param(sdfg: dace.SDFG, pname: str) -> "tuple[dace.nodes.MapEntry, dace.SDFGState]":
    """ Finds the first map entry node that has the given parameter name.

    Iterates over ``sdfg.all_nodes_recursive()`` and stops at the first
    ``MapEntry`` whose ``params`` contain ``pname``.

    :param sdfg: Graph to search; only ``all_nodes_recursive()`` is called on it.
    :param pname: Map parameter name to look for (e.g. ``"__i0"``).
    :return: A ``(map_entry, state)`` pair, where ``state`` is the element
             yielded together with the node by ``all_nodes_recursive()``.
    :raises StopIteration: If no map entry has a parameter named ``pname``.
    """
    return next((n, state) for n, state in sdfg.all_nodes_recursive()
                if isinstance(n, dace.nodes.MapEntry) and pname in n.params)


def find_map_by_name(sdfg: dace.SDFG, name: str) -> "tuple[dace.nodes.MapEntry, dace.SDFGState]":
    """ Finds the first map entry node whose label equals the given name.

    Iterates over ``sdfg.all_nodes_recursive()`` and stops at the first
    ``MapEntry`` whose ``label`` is exactly ``name``.

    :param sdfg: Graph to search; only ``all_nodes_recursive()`` is called on it.
    :param name: Map label to match (e.g. ``"gemm_map"``).
    :return: A ``(map_entry, state)`` pair, where ``state`` is the element
             yielded together with the node by ``all_nodes_recursive()``.
    :raises StopIteration: If no map entry is labeled ``name``.
    """
    return next((n, state) for n, state in sdfg.all_nodes_recursive()
                if isinstance(n, dace.nodes.MapEntry) and name == n.label)

# Symbolic sizes used in the array annotations of `matmul` below.
M = dace.symbol('M')
N = dace.symbol('N')
K = dace.symbol('K')
# NOTE(review): `matmul` declares its own `alpha`/`beta` parameters, which
# shadow these two symbols inside the program, and both names are rebound to
# integers further down in this cell -- confirm these symbols are still needed.
alpha = dace.symbol('alpha')
beta = dace.symbol('beta')

# DaCe program that returns alpha * (A @ B) + beta * C.
# A is annotated as float64[M, K], B as float64[K, N] and C as float64[M, N];
# alpha and beta are float64 scalars passed as arguments.
@dace.program
def matmul(A: dace.float64[M, K], B: dace.float64[K, N], C: dace.float64[M, N], alpha: dace.float64, beta: dace.float64):
    return alpha * (A @ B) + beta * C

# Concrete values. These assignments rebind the module-level names M, N, K,
# alpha and beta (previously dace symbols) to plain integers; the annotations
# of `matmul` were already evaluated with the symbols above.
M = 640
N = 640
K = 640
beta = 1
alpha = 1

In [2]:
# Convert the DaCe program into an SDFG; the trailing bare `sdfg` displays it
# as the cell output.
sdfg = matmul.to_sdfg()
sdfg

Applied 4 StateFusion.


In [3]:
# Expand the library nodes of the SDFG in place. The recorded output reports
# "_MatMult_" expanded with "specialize" and "_MatMult_gemm" with "pure".
sdfg.expand_library_nodes()
sdfg

Automatically expanded library node "_MatMult_" with implementation "specialize".
Automatically expanded library node "_MatMult_gemm" with implementation "pure".


In [4]:
# Apply the GPUTransformSDFG transformation to the SDFG (the recorded output
# reports one application) and display the result.
sdfg.apply_transformations(GPUTransformSDFG)
sdfg

Applied 2 StateFusion.
Applied 1 GPUTransformSDFG.


In [5]:
# First tiling level: look up the map labeled "gemm_map" and tile its
# parameters with sizes __i0=128, __i1=128, __i2=8. `state.parent` is what is
# passed as the SDFG argument of `tile`.
# NOTE(review): the two positional `True` flags are presumably the
# `divides_evenly` and `skew` options of dace.transformation.helpers.tile --
# confirm against the installed DaCe version.
gemm, state = find_map_by_name(sdfg, "gemm_map")
xfutil.tile(state.parent, gemm, True, True, __i0 = 128, __i1 = 128, __i2 = 8)
sdfg

In [6]:
# Apply MapCollapse through the generic transformation API (the recorded
# output reports a single application) and display the SDFG.
sdfg.apply_transformations(MapCollapse)
sdfg

Applied 1 MapCollapse.


In [7]:
# Second tiling level: take the first map that has a parameter named "__i0"
# and tile it with sizes __i0=64, __i1=32. Later outputs show the resulting
# tile parameters named `tile1___i0` / `tile1___i1`.
gemm, state = find_map_by_param(state.parent, "__i0")
xfutil.tile(state.parent, gemm, True, True, __i0 = 64, __i1 = 32)
sdfg

In [8]:
# Locate the maps over `tile1___i0` and `tile1___i1` and collapse them into a
# single map. The return value of `apply_to` is the last expression, so the
# cell displays the resulting (MapEntry, MapExit) pair.
warp_entry_outer, state = find_map_by_param(state.parent, "tile1___i0")
warp_entry_inner, state = find_map_by_param(state.parent, "tile1___i1")
MapCollapse.apply_to(state.parent, _outer_map_entry=warp_entry_outer, _inner_map_entry=warp_entry_inner)



(MapEntry (gemm_map[tile1___i0=0:128:64, tile1___i1=0:128:32]),
 MapExit (gemm_map[tile1___i0=0:128:64, tile1___i1=0:128:32]))

In [9]:
# Display the SDFG after the collapse performed in the previous cell.
sdfg

In [10]:
# Find the map over `tile___i2` (shown in the output as tile___i2=0:K:8).
# Note that `state` itself, not `state.parent`, is passed here, so the lookup
# calls `all_nodes_recursive()` on that state rather than on the whole SDFG.
smem_entry_outer, state = find_map_by_param(state, "tile___i2")
smem_entry_outer

MapEntry (gemm_map[tile___i2=0:K:8])

In [11]:
# Find the collapsed map that carries `tile1___i1` (the output shows it also
# carries `tile1___i0`). As in the previous lookup, the search starts from
# `state` rather than from `state.parent`.
smem_entry_inner, state = find_map_by_param(state, "tile1___i1")
smem_entry_inner

MapEntry (gemm_map[tile1___i0=0:128:64, tile1___i1=0:128:32])

In [12]:
# Apply InLocalStorage between the `tile___i2` map entry (node_a) and the
# collapsed `tile1___i0`/`tile1___i1` map entry (node_b).
# NOTE(review): the following cell repeats this exact call; presumably each
# application handles a different array flowing between the two maps -- confirm.
InLocalStorage.apply_to(state.parent, node_a=smem_entry_outer, node_b=smem_entry_inner)
sdfg

In [13]:
# Second application of InLocalStorage with the same node_a / node_b pair as
# the preceding cell.
# NOTE(review): presumably this targets another array between the same two
# map entries; confirm that the repetition is intentional and not a stray
# re-run of the previous cell.
InLocalStorage.apply_to(state.parent, node_a=smem_entry_outer, node_b=smem_entry_inner)
sdfg

In [14]:
# Give the map that carries `tile1___i0` the GPU thread-block schedule.
# The schedule is set through the public `map` property of the entry node
# instead of reaching into its private `_map` attribute.
threadblock_tile, state = find_map_by_param(sdfg, 'tile1___i0')
threadblock_tile.map.schedule = dace.ScheduleType.GPU_ThreadBlock
sdfg

In [16]:
# Third tiling level: take the first map that has a parameter named "__i0"
# and tile it with sizes __i0=4, __i1=4.
# NOTE(review): the names of the tile parameters created here are not shown
# in any recorded output; check them before looking the new maps up by name.
warp_tile, state = find_map_by_param(state.parent, "__i0")
xfutil.tile(state.parent, warp_tile, True, True, __i0 = 4, __i1 = 4)
sdfg

In [None]:
# Collapse the maps over `tile1___i0` and `tile1___i1` into one map.
# NOTE(review): this cell has no execution count and is a verbatim copy of the
# earlier collapse cell. Earlier outputs show `tile1___i0` and `tile1___i1`
# already living in one collapsed map, and find_map_by_param returns the first
# match, so both lookups may return that same entry. Verify the parameter
# names produced by the most recent tiling step before running this.
warp_entry_outer, state = find_map_by_param(state.parent, "tile1___i0")
warp_entry_inner, state = find_map_by_param(state.parent, "tile1___i1")
MapCollapse.apply_to(state.parent, _outer_map_entry=warp_entry_outer, _inner_map_entry=warp_entry_inner)