1) 环境初始化与路径/版本打印

In [1]:
# Inspect TileLang Pipelined IR: default vs explicit order/stage/group
import os, sys, importlib.util, textwrap, re
from pathlib import Path
import tilelang
from tilelang.jit.kernel import JITKernel

ROOT = Path('/home/chenxi')
A_PATH = ROOT / 'tilelang/examples/flash_attention/example_mha_fwd_bshd.py'
B_PATH = ROOT / 'tilelang/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined-2.py'
OUT_DIR = ROOT / 'tilelang/notebooks/artifacts'
OUT_DIR.mkdir(parents=True, exist_ok=True)

print('TileLang version:', tilelang.__version__)
print('Using tilelang from:', tilelang.__file__)
print('A:', A_PATH)
print('B:', B_PATH)

TileLang version: 0.1.5
Using tilelang from: /home/chenxi/miniconda3/envs/tilelang/lib/python3.12/site-packages/tilelang/__init__.py
A: /home/chenxi/tilelang/examples/flash_attention/example_mha_fwd_bshd.py
B: /home/chenxi/tilelang/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined-2.py


2) 用 `importlib` 的 `spec_from_file_location` 安全加载两个示例模块，并打印加载的模块名

In [2]:
# Helpers to safely import example modules without triggering autotune side-effects

def load_module(path: str, mod_name: str):
    spec = importlib.util.spec_from_file_location(mod_name, path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod

A = load_module(str(A_PATH), 'example_mha_fwd_bshd')
B = load_module(str(B_PATH), 'example_mha_fwd_bshd_wgmma_pipelined_2')

print('Loaded modules:', A.__name__, B.__name__)



Joint配置数量: 6
Tile配置数量: 4 (固定stage=1)
Loaded modules: example_mha_fwd_bshd example_mha_fwd_bshd_wgmma_pipelined_2


3) 分别调用两份脚本的 `flashattn(...)` 得到 `JITKernel`，从中取 `prim_func`，即编译前的 TIR ，保存，用于直观看到 `T.Pipelined(...)` 的源级调用位置与上下文

In [3]:
# Build kernels via JIT (avoid manual lower), then extract PrimFuncs
from tilelang.jit.kernel import JITKernel

# A: use flashattn (decorated), returns JITKernel; autotune will be bypassed if all tunables provided
_tmp_A = A.flashattn(batch=1, heads=1, seq_len=256, dim=64, is_causal=False,
                     block_M=64, block_N=64, num_stages=2, threads=128)
# B: explicit pipelined version
_tmp_B = B.flashattn(batch=1, heads=1, seq_len=256, dim=64, is_causal=False,
                     block_M=64, block_N=64, num_stages=2, threads=128)

assert isinstance(_tmp_A, JITKernel), f"A should be JITKernel, got {_tmp_A}"
assert isinstance(_tmp_B, JITKernel), f"B should be JITKernel, got {_tmp_B}"

kern_A: JITKernel = _tmp_A
kern_B: JITKernel = _tmp_B

prim_A = kern_A.prim_func
prim_B = kern_B.prim_func

print('A kernel ok:', isinstance(kern_A, JITKernel))
print('B kernel ok:', isinstance(kern_B, JITKernel))
print('A prim type:', type(prim_A))
print('B prim type:', type(prim_B))

pre_A = prim_A.script()
pre_B = prim_B.script()
(OUT_DIR / 'pre_default.py').write_text(pre_A)
(OUT_DIR / 'pre_explicit.py').write_text(pre_B)
print('Saved pre-lower scripts to', OUT_DIR)



2025-09-09 16:36:24  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `kernel_impl` with `out_idx=[3]`
2025-09-09 16:36:35  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `kernel_impl`
2025-09-09 16:36:35  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[3]`
2025-09-09 16:36:45  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `main`
A kernel ok: True
B kernel ok: True
A prim type: <class 'tvm.tir.function.PrimFunc'>
B prim type: <class 'tvm.tir.function.PrimFunc'>
Saved pre-lower scripts to /home/chenxi/tilelang/notebooks/artifacts


4) 通过 `kernel.artifact.device_mod.script()` 拿到 lowered 设备端 IR，并保存到 `artifacts/lower_*.py`，用于观察实际生成的 barrier/TMA/WGMMA 等指令结构

In [4]:
assert isinstance(kern_A, JITKernel) and isinstance(kern_B, JITKernel), "请先构建 kern_A/kern_B"

def get_device_mod(kernel: JITKernel):
    mod = getattr(kernel.artifact, 'device_mod', None)
    if mod is None:
        mod = getattr(kernel.adapter, 'device_mod', None)
    assert mod is not None, "device_mod 获取失败：请重跑上方构建单元或改用不同 shape 触发重新编译"
    return mod

low_A_mod = get_device_mod(kern_A)
low_B_mod = get_device_mod(kern_B)

low_A = low_A_mod.script()
low_B = low_B_mod.script()
(OUT_DIR / 'lower_default.py').write_text(low_A)
(OUT_DIR / 'lower_explicit.py').write_text(low_B)
print('Saved lowered device IR to', OUT_DIR)
print(low_A[:800]); print('---'); print(low_B[:800])

Saved lowered device IR to /home/chenxi/tilelang/notebooks/artifacts
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def kernel_impl_kernel(K_desc: T.handle("uint8x128", "grid_constant"), Output_desc: T.handle("uint8x128", "grid_constant"), Q_desc: T.handle("uint8x128", "grid_constant"), V_desc: T.handle("uint8x128", "grid_constant")):
        T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 40960, "target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 4, "blockIdx.y": 1, "blockIdx.z": 1, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "blockIdx.z", "threadIdx.x", "th
---
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main_k

5) 基于关键词（mbarrier/tma/wgmma 等）从 lowered IR 提取上下文片段，快速定位流水线/同步相关代码

In [8]:
# Extract and highlight pipeline/barrier-related lines

def grep_keywords(text: str, kws):
    lines = text.splitlines()
    hits = []
    for i,l in enumerate(lines):
        if any(k in l for k in kws):
            start = max(0, i-3)
            end = min(len(lines), i+4)
            hits.append('\n'.join(lines[start:end]))
    return '\n\n'.join(hits)

KWs = [
    'mbarrier', 'barrier', 'cp.async', 'tma', 'descriptor', 'async', 'wait_group', 'commit_group',
    'tlx.async_descriptor_load', 'ldmatrix', 'WGMMA', 'wgmma', 'wmma', 'warpgroup'
]

print('=== Default (lowered) key sections ===')
print(grep_keywords(low_A, KWs)[:2000])
print('\n=== Explicit (lowered) key sections ===')
print(grep_keywords(low_B, KWs)[:2000])

=== Default (lowered) key sections ===
        by = T.launch_thread("blockIdx.y", 1)
        bz = T.launch_thread("blockIdx.z", 1)
        tx = T.launch_thread("threadIdx.x", 256)
        T.create_barriers(9)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", Q_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", K_desc)

        tx = T.launch_thread("threadIdx.x", 256)
        T.create_barriers(9)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", Q_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", K_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", V_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", Output_desc)

        T.create_barriers(9)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", Q_desc)
            T.call_extern("handle", "tl::prefetch_

In [9]:
# 安装/恢复 Pipelined 记录器（monkeypatch）
import threading, time, json, csv
import tilelang.language as TL

_PIPELINED_ORIG = getattr(TL, 'Pipelined')
_PIPELINE_LOG = []
_PIPELINE_CTX = threading.local()

# 提供上下文配置，便于在记录里带出当前 config
class PipelineLogScope:
    def __init__(self, tag: str = None, **cfg):
        self.tag = tag
        self.cfg = cfg
    def __enter__(self):
        setattr(_PIPELINE_CTX, 'cfg', self.cfg)
        setattr(_PIPELINE_CTX, 'tag', self.tag)
    def __exit__(self, exc_type, exc, tb):
        setattr(_PIPELINE_CTX, 'cfg', None)
        setattr(_PIPELINE_CTX, 'tag', None)


def _pipelined_logger(*args, **kwargs):
    # 记录关键参数
    rec = {
        'ts': time.time(),
        'tag': getattr(_PIPELINE_CTX, 'tag', None),
        'cfg': getattr(_PIPELINE_CTX, 'cfg', None),
        'num_stages': kwargs.get('num_stages'),
        'order': kwargs.get('order'),
        'stage': kwargs.get('stage'),
        'group': kwargs.get('group'),
        'sync': kwargs.get('sync'),
    }
    _PIPELINE_LOG.append(rec)
    # 调用原始 Pipelined
    return _PIPELINED_ORIG(*args, **kwargs)


def install_pipelined_logger():
    TL.Pipelined = _pipelined_logger
    return True


def restore_pipelined():
    TL.Pipelined = _PIPELINED_ORIG
    return True

print('Pipelined logger ready. Use install_pipelined_logger()/restore_pipelined() and PipelineLogScope.')

Pipelined logger ready. Use install_pipelined_logger()/restore_pipelined() and PipelineLogScope.


In [None]:
install_pipelined_logger()

from tilelang.jit.kernel import JITKernel

# 通过改变 seq_len / dim 来避开 JIT 的内存缓存键
explicit_runs = [
    {'block_M':64,'block_N':64,'num_stages':2,'threads':128, 'batch':1,'heads':1,'seq_len':192,'dim':64},
    {'block_M':64,'block_N':64,'num_stages':2,'threads':128, 'batch':1,'heads':1,'seq_len':320,'dim':64},
    {'block_M':128,'block_N':128,'num_stages':2,'threads':256,'batch':1,'heads':1,'seq_len':256,'dim':96},
]

for i, cfg in enumerate(explicit_runs):
    tag = f'explicit-shape-{i}'
    call_cfg = {k:cfg[k] for k in ['block_M','block_N','num_stages','threads']}
    with PipelineLogScope(tag=tag, **call_cfg):
        kern = B.flashattn(**cfg, is_causal=False)
        assert isinstance(kern, JITKernel)

restore_pipelined()

print('Total logged calls (cumulative):', len(_PIPELINE_LOG))
print('Last 5 samples:')
for r in _PIPELINE_LOG[-5:]:
    print(r)

2025-09-09 16:36:45  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[3]`
2025-09-09 16:36:55  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `main`
2025-09-09 16:36:55  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[3]`
2025-09-09 16:37:05  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `main`
2025-09-09 16:37:05  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[3]`
2025-09-09 16:37:17  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `main`
2025-09-09 16:37:17  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `kernel_impl` with `out_idx=[3]`
2025-09-09 16:37:26  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `kernel_impl`
2025-09-09 16:37:27  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `kernel_impl` 