问题：默认流水线的阶段结构在 lowered IR 中只体现为 barrier/TMA/MMA 的相对顺序，不会还原为 `order/stage/group` 数组；且我们批量推断的摘要粒度过粗，难以区分不同配置。

方案：对 `tilelang.language.pipeline.Pipelined` 做“可恢复的 monkeypatch”记录器，采集每一次构图时传入的关键字参数（`num_stages/order/stage/group/sync`），连同当前配置（block_M/N、threads 等）一起记录。


In [75]:
# 安装/恢复 Pipelined 记录器（monkeypatch）
import threading, time, json, csv
import tilelang.language as TL

_PIPELINED_ORIG = getattr(TL, 'Pipelined')
_PIPELINE_LOG = []
_PIPELINE_CTX = threading.local()

# 提供上下文配置，便于在记录里带出当前 config
class PipelineLogScope:
    def __init__(self, tag: str = None, **cfg):
        self.tag = tag
        self.cfg = cfg
    def __enter__(self):
        setattr(_PIPELINE_CTX, 'cfg', self.cfg)
        setattr(_PIPELINE_CTX, 'tag', self.tag)
    def __exit__(self, exc_type, exc, tb):
        setattr(_PIPELINE_CTX, 'cfg', None)
        setattr(_PIPELINE_CTX, 'tag', None)


def _pipelined_logger(*args, **kwargs):
    # 记录关键参数
    rec = {
        'ts': time.time(),
        'tag': getattr(_PIPELINE_CTX, 'tag', None),
        'cfg': getattr(_PIPELINE_CTX, 'cfg', None),
        'num_stages': kwargs.get('num_stages'),
        'order': kwargs.get('order'),
        'stage': kwargs.get('stage'),
        'group': kwargs.get('group'),
        'sync': kwargs.get('sync'),
    }
    _PIPELINE_LOG.append(rec)
    # 调用原始 Pipelined
    return _PIPELINED_ORIG(*args, **kwargs)


def install_pipelined_logger():
    TL.Pipelined = _pipelined_logger
    return True


def restore_pipelined():
    TL.Pipelined = _PIPELINED_ORIG
    return True

print('Pipelined logger ready. Use install_pipelined_logger()/restore_pipelined() and PipelineLogScope.')



Pipelined logger ready. Use install_pipelined_logger()/restore_pipelined() and PipelineLogScope.


In [76]:
install_pipelined_logger()

from tilelang.jit.kernel import JITKernel
explicit_cfgs = [
    {'block_M':64,'block_N':64,'num_stages':2,'threads':128},
    {'block_M':128,'block_N':128,'num_stages':2,'threads':256},
]

for i,cfg in enumerate(explicit_cfgs):
    with PipelineLogScope(tag=f'explicit-{i}', **cfg):
        kern = B.flashattn(batch=1, heads=1, seq_len=256, dim=64, is_causal=False, **cfg)
        assert isinstance(kern, JITKernel)

default_cfgs = [
    {'block_M':64,'block_N':64,'num_stages':1,'threads':128},
    {'block_M':128,'block_N':128,'num_stages':2,'threads':128},
]

for i,cfg in enumerate(default_cfgs):
    with PipelineLogScope(tag=f'default-{i}', **cfg):
        kern = A.flashattn(batch=1, heads=1, seq_len=256, dim=64, is_causal=False, **cfg)
        assert isinstance(kern, JITKernel)

restore_pipelined()

# 导出日志
log_json = OUT_DIR / 'pipeline_tune_calls.json'
log_csv = OUT_DIR / 'pipeline_tune_calls.csv'

with open(log_json, 'w') as f:
    json.dump(_PIPELINE_LOG, f, indent=2)

if _PIPELINE_LOG:
    keys = sorted(set().union(*[set(x.keys()) for x in _PIPELINE_LOG]))
    with open(log_csv, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=keys)
        writer.writeheader()
        writer.writerows(_PIPELINE_LOG)

print('Logged calls:', len(_PIPELINE_LOG))
print('Sample (first 3):')
for r in _PIPELINE_LOG[:3]:
    print(r)



Logged calls: 0
Sample (first 3):


In [77]:
# 扩大量：多样化 shape 以强制重新构图，多次记录显式/默认的实际调用
install_pipelined_logger()

from tilelang.jit.kernel import JITKernel

# 通过改变 seq_len / dim 来避开 JIT 的内存缓存键
explicit_runs = [
    {'block_M':64,'block_N':64,'num_stages':2,'threads':128, 'batch':1,'heads':1,'seq_len':192,'dim':64},
    {'block_M':64,'block_N':64,'num_stages':2,'threads':128, 'batch':1,'heads':1,'seq_len':320,'dim':64},
    {'block_M':128,'block_N':128,'num_stages':2,'threads':256,'batch':1,'heads':1,'seq_len':256,'dim':96},
]

for i, cfg in enumerate(explicit_runs):
    tag = f'explicit-shape-{i}'
    call_cfg = {k:cfg[k] for k in ['block_M','block_N','num_stages','threads']}
    with PipelineLogScope(tag=tag, **call_cfg):
        kern = B.flashattn(**cfg, is_causal=False)
        assert isinstance(kern, JITKernel)

# 默认版同理
default_runs = [
    {'block_M':64,'block_N':64,'num_stages':1,'threads':128, 'batch':1,'heads':1,'seq_len':192,'dim':64},
    {'block_M':128,'block_N':128,'num_stages':2,'threads':128,'batch':1,'heads':1,'seq_len':320,'dim':64},
]

for i, cfg in enumerate(default_runs):
    tag = f'default-shape-{i}'
    call_cfg = {k:cfg[k] for k in ['block_M','block_N','num_stages','threads']}
    with PipelineLogScope(tag=tag, **call_cfg):
        kern = A.flashattn(**cfg, is_causal=False)
        assert isinstance(kern, JITKernel)

restore_pipelined()

print('Total logged calls (cumulative):', len(_PIPELINE_LOG))
print('Last 5 samples:')
for r in _PIPELINE_LOG[-5:]:
    print(r)



2025-09-09 16:25:43  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[3]`
2025-09-09 16:25:53  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `main`
2025-09-09 16:25:53  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[3]`
2025-09-09 16:26:03  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `main`
2025-09-09 16:26:03  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[3]`
2025-09-09 16:26:15  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `main`
2025-09-09 16:26:15  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `kernel_impl` with `out_idx=[3]`
2025-09-09 16:26:25  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `kernel_impl`
2025-09-09 16:26:25  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `kernel_impl` 

1) 环境初始化与路径/版本打印

In [78]:
# Inspect TileLang Pipelined IR: default vs explicit order/stage/group
import os, sys, importlib.util, textwrap, re
from pathlib import Path
import tilelang
from tvm import tir

ROOT = Path('/home/chenxi')
A_PATH = ROOT / 'tilelang/examples/flash_attention/example_mha_fwd_bshd.py'
B_PATH = ROOT / 'tilelang/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined-2.py'
OUT_DIR = ROOT / 'tilelang/notebooks/artifacts'
OUT_DIR.mkdir(parents=True, exist_ok=True)

print('TileLang version:', tilelang.__version__)
print('Using tilelang from:', tilelang.__file__)
print('A:', A_PATH)
print('B:', B_PATH)

TileLang version: 0.1.5
Using tilelang from: /home/chenxi/miniconda3/envs/tilelang/lib/python3.12/site-packages/tilelang/__init__.py
A: /home/chenxi/tilelang/examples/flash_attention/example_mha_fwd_bshd.py
B: /home/chenxi/tilelang/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined-2.py


2) 用 `importlib` 的 `spec_from_file_location` 安全加载两个示例模块，并打印加载的模块名

In [79]:
# Helpers to safely import example modules without triggering autotune side-effects

def load_module(path: str, mod_name: str):
    spec = importlib.util.spec_from_file_location(mod_name, path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod

A = load_module(str(A_PATH), 'example_mha_fwd_bshd')
B = load_module(str(B_PATH), 'example_mha_fwd_bshd_wgmma_pipelined_2')

print('Loaded modules:', A.__name__, B.__name__)



Joint配置数量: 6
Tile配置数量: 4 (固定stage=1)
Loaded modules: example_mha_fwd_bshd example_mha_fwd_bshd_wgmma_pipelined_2


3) 分别调用两份脚本的 `flashattn(...)` 得到 `JITKernel`，从中取 `prim_func`，即编译前的 TIR ，保存，用于直观看到 `T.Pipelined(...)` 的源级调用位置与上下文

In [80]:
# Build kernels via JIT (avoid manual lower), then extract PrimFuncs
from tilelang.jit.kernel import JITKernel

# A: use flashattn (decorated), returns JITKernel; autotune will be bypassed if all tunables provided
_tmp_A = A.flashattn(batch=1, heads=1, seq_len=256, dim=64, is_causal=False,
                     block_M=64, block_N=64, num_stages=2, threads=128)
# B: explicit pipelined version
_tmp_B = B.flashattn(batch=1, heads=1, seq_len=256, dim=64, is_causal=False,
                     block_M=64, block_N=64, num_stages=2, threads=128)

assert isinstance(_tmp_A, JITKernel), f"A should be JITKernel, got {_tmp_A}"
assert isinstance(_tmp_B, JITKernel), f"B should be JITKernel, got {_tmp_B}"

kern_A: JITKernel = _tmp_A
kern_B: JITKernel = _tmp_B

prim_A = kern_A.prim_func
prim_B = kern_B.prim_func

print('A kernel ok:', isinstance(kern_A, JITKernel))
print('B kernel ok:', isinstance(kern_B, JITKernel))
print('A prim type:', type(prim_A))
print('B prim type:', type(prim_B))

pre_A = prim_A.script()
pre_B = prim_B.script()
(OUT_DIR / 'pre_default.py').write_text(pre_A)
(OUT_DIR / 'pre_explicit.py').write_text(pre_B)
print('Saved pre-lower scripts to', OUT_DIR)



A kernel ok: True
B kernel ok: True
A prim type: <class 'tvm.tir.function.PrimFunc'>
B prim type: <class 'tvm.tir.function.PrimFunc'>
Saved pre-lower scripts to /home/chenxi/tilelang/notebooks/artifacts


4) 通过 `kernel.artifact.device_mod.script()` 拿到 lowered 设备端 IR，并保存到 `artifacts/lower_*.py`，用于观察实际生成的 barrier/TMA/WGMMA 等指令结构

In [81]:
# Lower to device IR from kernels (use compiled artifacts)
low_A = kern_A.artifact.device_mod.script()
low_B = kern_B.artifact.device_mod.script()
(OUT_DIR / 'lower_default.py').write_text(low_A)
(OUT_DIR / 'lower_explicit.py').write_text(low_B)
print('Saved lowered device IR to', OUT_DIR)
print(low_A[:800])
print('---')
print(low_B[:800])

Saved lowered device IR to /home/chenxi/tilelang/notebooks/artifacts
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def kernel_impl_kernel(K_desc: T.handle("uint8x128", "grid_constant"), Output_desc: T.handle("uint8x128", "grid_constant"), Q_desc: T.handle("uint8x128", "grid_constant"), V_desc: T.handle("uint8x128", "grid_constant")):
        T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 40960, "target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 4, "blockIdx.y": 1, "blockIdx.z": 1, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "blockIdx.z", "threadIdx.x", "th
---
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main_k

5) 基于关键词（mbarrier/tma/wgmma 等）从 lowered IR 提取上下文片段，快速定位流水线/同步相关代码

In [82]:
# Extract and highlight pipeline/barrier-related lines

def grep_keywords(text: str, kws):
    lines = text.splitlines()
    hits = []
    for i,l in enumerate(lines):
        if any(k in l for k in kws):
            start = max(0, i-3)
            end = min(len(lines), i+4)
            hits.append('\n'.join(lines[start:end]))
    return '\n\n'.join(hits)

KWs = [
    'mbarrier', 'barrier', 'cp.async', 'tma', 'descriptor', 'async', 'wait_group', 'commit_group',
    'tlx.async_descriptor_load', 'ldmatrix', 'WGMMA', 'wgmma', 'wmma', 'warpgroup'
]

print('=== Default (lowered) key sections ===')
print(grep_keywords(low_A, KWs)[:2000])
print('\n=== Explicit (lowered) key sections ===')
print(grep_keywords(low_B, KWs)[:2000])

=== Default (lowered) key sections ===
        by = T.launch_thread("blockIdx.y", 1)
        bz = T.launch_thread("blockIdx.z", 1)
        tx = T.launch_thread("threadIdx.x", 256)
        T.create_barriers(9)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", Q_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", K_desc)

        tx = T.launch_thread("threadIdx.x", 256)
        T.create_barriers(9)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", Q_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", K_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", V_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", Output_desc)

        T.create_barriers(9)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", Q_desc)
            T.call_extern("handle", "tl::prefetch_

In [83]:
# 针对默认流水线：基于 lowered IR 的启发式推断阶段化（统计每个循环步的TMA/Barrier/WGMMA顺序）
import re

def analyze_pipeline_order(lowered_text: str, max_lines: int = 10000):
    lines = lowered_text.splitlines()[:max_lines]
    # 以 mbarrier 初始化/arrive/wait 或者 prefetch_tma_descriptor 作为“段落”切分点
    markers = ['prefetch_tma_descriptor', 'ptx_init_barrier_thread_count', 'mbarrier', 'wgmma', 'wmma']
    events = []
    for i,l in enumerate(lines):
        tag = None
        if 'prefetch_tma_descriptor' in l or 'async_descriptor_load' in l or 'tma' in l:
            tag = 'TMA'
        elif 'ptx_init_barrier_thread_count' in l or 'create_barriers' in l or 'mbarrier' in l:
            tag = 'BARRIER'
        elif 'wgmma' in l.lower() or 'wmma' in l.lower():
            tag = 'MMA'
        if tag:
            events.append((i, tag, l.strip()))
    # 把邻近的事件聚成“阶段”，并统计顺序
    phases = []
    cur = []
    last_idx = -999
    for idx, tag, txt in events:
        if idx - last_idx > 20 and cur:
            phases.append(cur)
            cur = []
        cur.append((idx, tag, txt))
        last_idx = idx
    if cur:
        phases.append(cur)
    # 阶段摘要
    summary = []
    for p in phases:
        tags = [t for _, t, _ in p]
        summary.append('->'.join(sorted(set(tags), key=tags.index)))
    return phases, summary

phases_A, summary_A = analyze_pipeline_order(low_A)
phases_B, summary_B = analyze_pipeline_order(low_B)

print('Default inferred phase order (coarse):', summary_A[:12])
print('Explicit inferred phase order (coarse):', summary_B[:12])

# 可选：打印前几个阶段的上下文
for name, phases in [('Default', phases_A), ('Explicit', phases_B)]:
    print(f"\n{name} first 2 phases (context):")
    for pi, phase in enumerate(phases[:2]):
        print(f"  Phase {pi}:")
        for idx, tag, txt in phase[:8]:
            print(f"    [{idx:05d}] {tag}: {txt}")



Default inferred phase order (coarse): ['BARRIER->TMA', 'BARRIER->TMA']
Explicit inferred phase order (coarse): ['BARRIER->TMA', 'BARRIER', 'BARRIER->TMA']

Default first 2 phases (context):
  Phase 0:
    [00037] BARRIER: T.create_barriers(9)
    [00039] TMA: T.call_extern("handle", "tl::prefetch_tma_descriptor", Q_desc)
    [00040] TMA: T.call_extern("handle", "tl::prefetch_tma_descriptor", K_desc)
    [00041] TMA: T.call_extern("handle", "tl::prefetch_tma_descriptor", V_desc)
    [00042] TMA: T.call_extern("handle", "tl::prefetch_tma_descriptor", Output_desc)
    [00043] BARRIER: T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 1)
    [00044] BARRIER: T.ptx_init_barrier_thread_count(T.get_mbarrier(1), 1)
    [00045] BARRIER: T.ptx_init_barrier_thread_count(T.get_mbarrier(2), 1)
  Phase 1:
    [00111] BARRIER: T.mbarrier_wait_parity(T.get_mbarrier(k % 2 + 2), k // 2)
    [00113] BARRIER: T.ptx_arrive_barrier(T.get_mbarrier(k % 2 + 6), 0, tx == 0)
    [00118] TMA: T.ptx_stmatrix(0, 

7) 默认流水线批量推断与导出

    - 遍历一组 `block_M/block_N/num_stages` 组合，构建 JITKernel 并获取 lowered IR。
    - 用前面的 `analyze_pipeline_order` 提取每个配置的阶段摘要（coarse order）。
    - 导出为 `artifacts/pipeline_inferred_default.json` 与 `artifacts/pipeline_inferred_default.csv`，便于与显式版对比。


In [84]:
# 批量推断默认流水线阶段摘要并导出
import itertools, json, csv
from tilelang.jit.kernel import JITKernel

# 可根据需要调整搜索空间（保持小范围，避免编译时间过长）
BLOCK_M = [64, 128]
BLOCK_N = [64, 128]
NUM_STAGES = [1, 2]
THREADS = [128]

combos = list(itertools.product(BLOCK_M, BLOCK_N, NUM_STAGES, THREADS))

results = []
for bm, bn, ns, th in combos:
    try:
        # 构建默认流水线内核（不显式 order/stage/group）
        kern = A.flashattn(batch=1, heads=1, seq_len=256, dim=64, is_causal=False,
                           block_M=bm, block_N=bn, num_stages=ns, threads=th)
        assert isinstance(kern, JITKernel)
        low = kern.artifact.device_mod.script()
        # 提取阶段摘要
        _, summary = analyze_pipeline_order(low)
        results.append({
            'block_M': bm,
            'block_N': bn,
            'num_stages': ns,
            'threads': th,
            'summary': ' | '.join(summary[:8])  # 截断显示，避免太长
        })
    except Exception as e:
        results.append({
            'block_M': bm,
            'block_N': bn,
            'num_stages': ns,
            'threads': th,
            'summary': f'ERROR: {type(e).__name__}: {e}'
        })

json_out = OUT_DIR / 'pipeline_inferred_default.json'
csv_out = OUT_DIR / 'pipeline_inferred_default.csv'

with open(json_out, 'w') as f:
    json.dump(results, f, indent=2)

with open(csv_out, 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=list(results[0].keys()))
    writer.writeheader()
    writer.writerows(results)

print(f'Exported {len(results)} inferred entries:')
print(' -', json_out)
print(' -', csv_out)
print('Preview:')
for r in results[:5]:
    print(r)



Exported 8 inferred entries:
 - /home/chenxi/tilelang/notebooks/artifacts/pipeline_inferred_default.json
 - /home/chenxi/tilelang/notebooks/artifacts/pipeline_inferred_default.csv
Preview:
{'block_M': 64, 'block_N': 64, 'num_stages': 1, 'threads': 128, 'summary': 'BARRIER->TMA | BARRIER->TMA'}
{'block_M': 64, 'block_N': 64, 'num_stages': 2, 'threads': 128, 'summary': 'BARRIER->TMA | BARRIER->TMA'}
{'block_M': 64, 'block_N': 128, 'num_stages': 1, 'threads': 128, 'summary': 'BARRIER->TMA | BARRIER->TMA'}
{'block_M': 64, 'block_N': 128, 'num_stages': 2, 'threads': 128, 'summary': 'BARRIER->TMA | BARRIER->TMA'}
{'block_M': 128, 'block_N': 64, 'num_stages': 1, 'threads': 128, 'summary': 'BARRIER->TMA | BARRIER->TMA'}


In [85]:
# 更健壮的 T.Pipelined(...) 参数提取实现
import re, ast

def extract_pipelined_args(text: str):
    # 找到 T.Pipelined( 的起点，然后做括号配对拿到完整参数字符串
    start = text.find('T.Pipelined(')
    if start == -1:
        return None
    i = start + len('T.Pipelined(')
    depth = 1
    args = []
    while i < len(text) and depth > 0:
        ch = text[i]
        args.append(ch)
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        i += 1
    arg_str = ''.join(args[:-1])  # 去掉最后一个闭括号
    def grab(key):
        m = re.search(rf"{key}\s*=\s*(\[[^\]]*\])", arg_str)
        if not m:
            return None
        try:
            return ast.literal_eval(m.group(1))
        except Exception:
            return m.group(1)
    return {
        'order': grab('order'),
        'stage': grab('stage'),
        'group': grab('group'),
        'sync': grab('sync'),
        'arg_str': arg_str
    }

res = extract_pipelined_args(prim_B.script())
if res is None or (res['order'] is None and res['stage'] is None and res['group'] is None):
    print('PrimFunc script中未匹配到参数，回退到源码解析...')
    src = B_PATH.read_text()
    res = extract_pipelined_args(src)

print('Parsed from:', 'PrimFunc' if 'PrimFunc' in locals() else 'Source')
print('order =', res.get('order'))
print('stage =', res.get('stage'))
print('group =', res.get('group'))
print('sync  =', res.get('sync'))
if res and res.get('arg_str'):
    print('...context snippet...')
    print(res['arg_str'][:400])



PrimFunc script中未匹配到参数，回退到源码解析...
Parsed from: Source
order = [-1, 0, 3, 1, -1, 2]
stage = [-1, 0, 0, 1, -1, 1]
group = [[0]
sync  = None
...context snippet...

                    loop_range,
                    num_stages=num_stages,
                    order=[-1, 0, 3, 1, -1, 2],
                    stage=[-1, 0, 0, 1, -1, 1],
                    group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]


### 方法A：遍历 get_configs()，抓取每个配置的 order/stage/group/num_stages 并导出

- 对 `example_mha_fwd_bshd_wgmma_pipelined-2.py`：`order/stage/group` 是在源码里显式指定的，因此不同 config 会得到相同的三元组，但我们依旧记录每个 config 的形状（block_M/N、threads、num_stages）方便关联。
- 不进行 benchmark，只做构图与脚本解析，速度快。
- 导出为 `artifacts/pipeline_orders.json` 与 `artifacts/pipeline_orders.csv`。


In [86]:
# 遍历 get_configs() 并导出（方法A）
import json, csv

# 读取 -2.py 中的 configs 生成器（若无该函数则回退到简单集）
get_cfg = getattr(B, 'get_configs', None)
if get_cfg is None:
    cfgs = [{'block_M': 128, 'block_N': 128, 'num_stages': 2, 'threads': 256}]
else:
    cfgs = get_cfg()

records = []

# 复用前面已实现的 AST 解析器（优先脚本，回退源码）
from typing import Any, Dict

def get_pipelined_kwargs_from_script_or_source() -> Dict[str, Any]:
    cands = extract_all_pipelined_kwargs(prim_B.script())
    if not cands:
        cands = extract_all_pipelined_kwargs(B_PATH.read_text())
    # 取包含最多关键字段的那个
    best = None
    score = -1
    for seg, kw in cands:
        s = sum(1 for k in ('order','stage','group','sync','num_stages') if k in kw)
        if s > score:
            best, score = (seg, kw), s
    return best[1] if best else {}

kw_fixed = get_pipelined_kwargs_from_script_or_source()

for cfg in cfgs:
    entry = {
        'block_M': cfg.get('block_M'),
        'block_N': cfg.get('block_N'),
        'threads': cfg.get('threads'),
        'num_stages_cfg': cfg.get('num_stages'),
        # 来自脚本/源码的显式参数：
        'order': kw_fixed.get('order'),
        'stage': kw_fixed.get('stage'),
        'group': kw_fixed.get('group'),
        'sync': kw_fixed.get('sync'),
        'num_stages_kw': kw_fixed.get('num_stages'),
    }
    records.append(entry)

json_path = OUT_DIR / 'pipeline_orders.json'
csv_path = OUT_DIR / 'pipeline_orders.csv'

with open(json_path, 'w') as f:
    json.dump(records, f, indent=2)

with open(csv_path, 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=list(records[0].keys()))
    writer.writeheader()
    writer.writerows(records)

print(f'Exported {len(records)} entries to:')
print(' -', json_path)
print(' -', csv_path)
print('Preview (first 3):')
for r in records[:3]:
    print(r)



Exported 36 entries to:
 - /home/chenxi/tilelang/notebooks/artifacts/pipeline_orders.json
 - /home/chenxi/tilelang/notebooks/artifacts/pipeline_orders.csv
Preview (first 3):
{'block_M': 64, 'block_N': 64, 'threads': 128, 'num_stages_cfg': 1, 'order': [-1, 0, 3, 1, -1, 2], 'stage': [-1, 0, 0, 1, -1, 1], 'group': [[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]], 'sync': None, 'num_stages_kw': 'num_stages'}
{'block_M': 64, 'block_N': 64, 'threads': 128, 'num_stages_cfg': 2, 'order': [-1, 0, 3, 1, -1, 2], 'stage': [-1, 0, 0, 1, -1, 1], 'group': [[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]], 'sync': None, 'num_stages_kw': 'num_stages'}
{'block_M': 64, 'block_N': 64, 'threads': 128, 'num_stages_cfg': 3, 'order': [-1, 0, 3, 1, -1, 2], 'stage': [-1, 0, 0, 1, -1, 1], 'group': [[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]], 'sync': None, 'num_stages_kw': 'num_stages'}


In [87]:
# 使用 AST 解析 T.Pipelined(...)，支持嵌套列表（group）与多处匹配
import ast

def find_all_pipelined_segments(text: str):
    segs = []
    i = 0
    while True:
        start = text.find('T.Pipelined(', i)
        if start == -1:
            break
        j = start + len('T.Pipelined(')
        depth = 1
        args = []
        while j < len(text) and depth > 0:
            ch = text[j]
            args.append(ch)
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
            j += 1
        seg = ''.join(args[:-1])
        segs.append(seg)
        i = j
    return segs


def parse_kwargs_by_ast(arg_str: str):
    # 构造一个可被 ast 解析的表达式：F(<arg_str>)
    expr = ast.parse('F(' + arg_str + ')', mode='eval')
    call = expr.body
    assert isinstance(call, ast.Call)
    out = {}
    for kw in call.keywords:
        key = kw.arg
        try:
            val = ast.literal_eval(kw.value)
        except Exception:
            # 退而求其次，用源码片段
            val = ast.get_source_segment('F(' + arg_str + ')', kw.value)
        out[key] = val
    return out


def extract_all_pipelined_kwargs(text: str):
    results = []
    for seg in find_all_pipelined_segments(text):
        try:
            kw = parse_kwargs_by_ast(seg)
            results.append((seg, kw))
        except Exception:
            continue
    return results

# 先从 PrimFunc script 中解析，找包含 order/stage/group 的那一个；否则回退源码
candidates = extract_all_pipelined_kwargs(prim_B.script())
if not candidates:
    candidates = extract_all_pipelined_kwargs(B_PATH.read_text())

best = None
score = -1
for seg, kw in candidates:
    s = 0
    for k in ('order','stage','group','sync','num_stages'):
        if k in kw:
            s += 1
    if s > score:
        best, score = (seg, kw), s

if best is None:
    print('未能解析到任何 T.Pipelined(...) 关键字参数。')
else:
    seg, kw = best
    print('Matched kwargs keys:', sorted(list(kw.keys())))
    print('order =', kw.get('order'))
    print('stage =', kw.get('stage'))
    print('group =', kw.get('group'))
    print('sync  =', kw.get('sync'))
    print('num_stages =', kw.get('num_stages'))



Matched kwargs keys: ['group', 'num_stages', 'order', 'stage']
order = [-1, 0, 3, 1, -1, 2]
stage = [-1, 0, 0, 1, -1, 1]
group = [[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]
sync  = None
num_stages = num_stages
