80 changes: 80 additions & 0 deletions modeling/transformers/bench_qwen3_5.sh
@@ -0,0 +1,80 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

# Benchmark script for Qwen3.5-0.8B model
# Compares PyTorch baseline vs TileGym CUTILE backend

set -e

MODEL_ID="Qwen/Qwen3.5-0.8B"
INPUT_FILE="sample_inputs/input_prompt_small.txt"
OUTPUT_LENGTH=50
SUMMARY_FILE="qwen3_5_benchmark_summary.txt"
BATCH_SIZE=1
LOG_DIR="${LOG_DIR:-${TMPDIR:-/tmp}/tilegym_bench}"
mkdir -p "${LOG_DIR}"  # ensure the log directory exists before nsys/infer.py write to it

echo "========================================"
echo " Qwen3.5-0.8B Performance Benchmark"
echo "========================================"
echo ""
echo "Model: ${MODEL_ID}"
echo "Input: ${INPUT_FILE}"
echo "Output length: ${OUTPUT_LENGTH} tokens"
echo "Batch size: ${BATCH_SIZE}"
echo ""

# Clean previous results
rm -f "${SUMMARY_FILE}"

echo "Running PyTorch baseline..."
python infer.py \
    --model_id "${MODEL_ID}" \
    --profile \
    --log_dir "${LOG_DIR}" \
    --sentence_file "${INPUT_FILE}" \
    --batch_size "${BATCH_SIZE}" \
    --output_length "${OUTPUT_LENGTH}" \
    --summary_file "${SUMMARY_FILE}"

echo ""
echo "Running TileGym CUTILE backend..."
python infer.py \
    --model_id "${MODEL_ID}" \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --profile \
    --log_dir "${LOG_DIR}" \
    --sentence_file "${INPUT_FILE}" \
    --batch_size "${BATCH_SIZE}" \
    --output_length "${OUTPUT_LENGTH}" \
    --summary_file "${SUMMARY_FILE}"

echo ""
echo "========================================"
echo " Benchmark Results"
echo "========================================"
if [ -f "${SUMMARY_FILE}" ]; then
    cat "${SUMMARY_FILE}"
    rm -f "${SUMMARY_FILE}"
else
    echo "Summary file not found."
fi
echo "========================================"

echo ""
echo "========================================"
echo " TileGym Kernel Coverage"
echo "========================================"
python infer.py \
    --model_id "${MODEL_ID}" \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --report_kernel_coverage \
    --log_dir "${LOG_DIR}" \
    --sentence_file "${INPUT_FILE}" \
    --batch_size "${BATCH_SIZE}" \
    --output_length "${OUTPUT_LENGTH}"
echo "========================================"
28 changes: 24 additions & 4 deletions modeling/transformers/infer.py
@@ -29,6 +29,7 @@
 from tilegym.transformers import apply_tilegym_kernel_to_mistral
 from tilegym.transformers import apply_tilegym_kernel_to_phi3
 from tilegym.transformers import apply_tilegym_kernel_to_qwen2
+from tilegym.transformers import apply_tilegym_kernel_to_qwen3


 def check_and_setup_model_cache(model_id):
@@ -148,9 +149,9 @@ def _fix_tokenizer_decoder_if_needed(tokenizer, model_id):
         print(f"Fixed tokenizer decoder: replaced ByteFallback with ByteLevel (from tokenizer.json)")


-def load_tokenizer_with_cache(model_id):
+def load_tokenizer_with_cache(model_id, **kwargs):
     """Load tokenizer with cache checking."""
-    tokenizer = _load_with_fallback(model_id, AutoTokenizer, "tokenizer")
+    tokenizer = _load_with_fallback(model_id, AutoTokenizer, "tokenizer", **kwargs)
     _fix_tokenizer_decoder_if_needed(tokenizer, model_id)
     return tokenizer

@@ -267,6 +268,10 @@ def apply_tilegym_patch(model_id, use_attn=False, use_cutile=False):
         apply_tilegym_kernel_to_gpt_oss(rope=True, rms_norm=True, swiglu=False, attn=use_attn, use_cutile=use_cutile)
     elif "mistral" in model_name:
         apply_tilegym_kernel_to_mistral(rope=True, rms_norm=True, swiglu=True, attn=use_attn, use_cutile=use_cutile)
+    elif "qwen3.5" in model_name or "qwen3_5" in model_name:
+        apply_tilegym_kernel_to_qwen3(
+            rope=True, rms_norm=True, swiglu=True, attn=use_attn, gated_delta_rule=True, use_cutile=use_cutile
+        )
     elif "qwen" in model_name:
         apply_tilegym_kernel_to_qwen2(rope=True, rms_norm=True, swiglu=True, attn=use_attn, use_cutile=use_cutile)
     elif "gemma" in model_name:
@@ -304,6 +309,13 @@ def __init__(self):
             # MLA kernels
             "naive_absorb_mla",
             "_mla_decoding",
+            # Gated delta rule kernels (Qwen3.5 linear attention)
+            "recurrent_gated_delta_rule",
+            "chunk_gated_delta_rule",
+            "_ct_chunk_inter_recurrence_kernel",
+            "_ct_intra_chunk_prepare_kernel",
+            # Partial RoPE kernel (Qwen3.5)
+            "rope_partial_kernel",
             # Reduce kernels
             "splitk_reduce_kernel",
             # GEMM kernels
@@ -344,7 +356,11 @@ def run(self):

         print(f"Running nsys profile command:\n {shlex.join(nsys_cmd)}\n")

-        proc = subprocess.Popen(nsys_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
+        env = os.environ.copy()
+        # nsys writes scratch files to TMPDIR (defaults to /tmp/nvidia/nsight_systems).
+        # Use log_dir as TMPDIR so nsys works in environments where /tmp is read-only.
+        env["TMPDIR"] = self.log_dir
+        proc = subprocess.Popen(nsys_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, env=env)
         for line in proc.stdout:
             print(line, end="")
         proc.wait()
@@ -578,7 +594,11 @@ def main():

     # Load tokenizer and model with cache support
     print(f"Loading model {args.model_id}...")
-    tokenizer = load_tokenizer_with_cache(args.model_id)
+    tokenizer_kwargs = {}
+    if "qwen3.5" in args.model_id.lower() or "qwen3_5" in args.model_id.lower():
+        # Qwen3.5 slow tokenizer produces empty outputs; use the fast tokenizer
+        tokenizer_kwargs["use_fast"] = True
+    tokenizer = load_tokenizer_with_cache(args.model_id, **tokenizer_kwargs)
     backend = "base"
     if args.use_tilegym:
         if args.use_cutile:
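Background on the new coverage entries: recurrent_gated_delta_rule and chunk_gated_delta_rule implement the gated delta rule that Qwen3.5 uses for its linear-attention layers (per the gated_delta_rule=True flag above). As rough orientation, the recurrence can be sketched in a few lines of plain PyTorch; the function name, shapes, and gate parameterization below are illustrative assumptions for exposition, not TileGym's actual API.

import torch

def recurrent_gated_delta_rule_ref(q, k, v, g, beta):
    """Reference recurrence for the gated delta rule (illustrative only).

    q, k: [T, d_k]; v: [T, d_v]; g, beta: [T] gates in (0, 1).
    S is a [d_k, d_v] fast-weight state carried across time steps.
    """
    S = torch.zeros(k.shape[1], v.shape[1])
    outs = []
    for t in range(q.shape[0]):
        S = g[t] * S                              # gated decay of the state
        err = v[t] - k[t] @ S                     # delta rule: prediction error for k_t
        S = S + beta[t] * torch.outer(k[t], err)  # rank-1 correction toward v_t
        outs.append(q[t] @ S)                     # read the state with the query
    return torch.stack(outs)                      # [T, d_v]

The chunked entry point (chunk_gated_delta_rule, with the _ct_chunk_inter_recurrence_kernel and _ct_intra_chunk_prepare_kernel helpers) suggests the usual chunkwise formulation: tokens are processed in fixed-size chunks as matmuls, with the sequential recurrence above carried only across chunk boundaries.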
14 changes: 10 additions & 4 deletions src/tilegym/ops/cutile/attention.py
@@ -88,10 +88,16 @@ def fmha_kernel_impl(
     l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
     acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

-    # Load query tile for this batch, head, and M-chunk
-    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)).reshape(
-        (TILE_M, TILE_D)
-    ) # [TILE_M, TILE_D]
+    # Load query tile for this batch, head, and M-chunk.
+    # PaddingMode.ZERO ensures OOB rows (when TILE_M > q_len) read as
+    # zeros rather than stale memory, preventing NaN from softmax on
+    # recycled allocations.
+    q = ct.load(
+        Q,
+        index=(batch_idx, head_idx, bid_x, 0),
+        shape=(1, 1, TILE_M, TILE_D),
+        padding_mode=ct.PaddingMode.ZERO,
+    ).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D]

     # Loop over k, v and update accumulator
     m_end = input_pos + (bid_x + 1) * TILE_M
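Why the padding mode matters: with TILE_M fixed, the last M-chunk of a short sequence loads rows past q_len. The snippet below (plain PyTorch, not CUTILE) reproduces the failure mode the comment describes, assuming the out-of-bounds rows come back as stale, NaN-filled memory.

import torch

TILE_M, TILE_D, kv_len = 4, 8, 16
q_len = 2  # only 2 of the tile's 4 query rows are valid

k = torch.randn(kv_len, TILE_D)
q_valid = torch.randn(q_len, TILE_D)

# Stale memory in the padded rows (here NaN): every attention score in
# those rows is NaN, and softmax keeps it NaN.
pad_garbage = torch.full((TILE_M - q_len, TILE_D), float("nan"))
p = torch.softmax(torch.cat([q_valid, pad_garbage]) @ k.T, dim=-1)
print(p[q_len:].isnan().any())  # tensor(True)

# Zero-filled padded rows: they score 0.0 against every key, so softmax
# yields a harmless uniform row. Since each query row's softmax over keys
# is independent, the padding cannot poison the valid rows' results.
pad_zero = torch.zeros(TILE_M - q_len, TILE_D)
p = torch.softmax(torch.cat([q_valid, pad_zero]) @ k.T, dim=-1)
print(p.isnan().any())  # tensor(False)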