In [1]:
from typing import Tuple

import dace
import numpy as np
from dace.transformation import helpers as xfutil
from dace.transformation.dataflow import MapTiling, InLocalStorage, MapExpansion, MapCollapse
from dace.transformation.interstate import GPUTransformSDFG, StateFusion
from dace.transformation.optimizer import Optimizer

def find_map_by_param(sdfg: dace.SDFG, pname: str) -> Tuple[dace.nodes.MapEntry, dace.SDFGState]:
    """
    Finds the first map entry node that has the given parameter name.

    :param sdfg: The SDFG to search; nodes are visited in the order produced
                 by ``sdfg.all_nodes_recursive()``.
    :param pname: Name of a map parameter (matched against ``MapEntry.params``).
    :return: A 2-tuple ``(map_entry, state)``: the matching map entry node and
             the second element that ``all_nodes_recursive()`` yields with it
             (used in this notebook as the state containing the map).
    :raises StopIteration: If no map entry has a parameter named ``pname``.
    """
    return next((n, state) for n, state in sdfg.all_nodes_recursive()
                if isinstance(n, dace.nodes.MapEntry) and pname in n.params)

def find_map_by_name(sdfg: dace.SDFG, name: str) -> Tuple[dace.nodes.MapEntry, dace.SDFGState]:
    """
    Finds the first map entry node whose label equals the given name.

    :param sdfg: The SDFG to search; nodes are visited in the order produced
                 by ``sdfg.all_nodes_recursive()``.
    :param name: Exact label of the map to look for (compared with ``==``
                 against ``MapEntry.label``).
    :return: A 2-tuple ``(map_entry, state)``: the matching map entry node and
             the second element that ``all_nodes_recursive()`` yields with it
             (used in this notebook as the state containing the map).
    :raises StopIteration: If no map entry is labeled ``name``.
    """
    return next((n, state) for n, state in sdfg.all_nodes_recursive()
                if isinstance(n, dace.nodes.MapEntry) and name == n.label)

# Symbolic sizes used in the array annotations below: A is M x K, B is K x N.
M = dace.symbol('M')
N = dace.symbol('N')
K = dace.symbol('K')

# Matrix-matrix product written as a DaCe program. A and B share the symbol K
# as their inner dimension; the result of ``A @ B`` is returned.
@dace.program
def matmul(A: dace.float64[M, K], B: dace.float64[K, N]):
    return A @ B

In [2]:
# Convert the DaCe program into an SDFG; the trailing bare name displays it as the cell output.
sdfg = matmul.to_sdfg()
sdfg

Applied 1 StateFusion.


In [3]:
# Expand the library nodes of the SDFG (the cell log shows the matrix-multiplication
# node being specialized to GEMM and then expanded with the "pure" implementation).
sdfg.expand_library_nodes()
sdfg

Automatically expanded library node "_MatMult_" with implementation "specialize".
Automatically expanded library node "_MatMult_gemm" with implementation "pure".


In [4]:
# Apply the GPUTransformSDFG transformation to the SDFG, then display the result.
sdfg.apply_transformations(GPUTransformSDFG)
sdfg

Applied 2 StateFusion.
Applied 1 GPUTransformSDFG.


In [5]:
# Look up the map labeled "gemm_map" together with the state that contains it.
gemm, state = find_map_by_name(sdfg, "gemm_map")
# Tile the map with sizes 128 x 128 x 8 for parameters __i0, __i1, __i2.
# state.parent is passed as the SDFG argument; the two positional booleans are
# the divides_evenly and skew flags of dace.transformation.helpers.tile.
xfutil.tile(state.parent, gemm, True, True, __i0=128, __i1=128, __i2=8)
sdfg

In [6]:
# Apply the MapCollapse transformation (the cell log reports one application), then display the result.
sdfg.apply_transformations(MapCollapse)
sdfg

Applied 1 MapCollapse.
