Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ cython_debug/

# Cursor
.cursorignore
.vscode/

# intermediate files
**/temp/
Expand Down
5 changes: 3 additions & 2 deletions modeling/transformers/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ RUN python -m pip install --upgrade pip setuptools wheel
# Install PyTorch with CUDA 13.0 support
RUN pip install --no-cache-dir --pre "torch==${TORCH_VERSION}" --index-url https://download.pytorch.org/whl/cu130

# Install cuda-tile and accelerate
# Install cuda-tile and transformers models' dependencies.
RUN pip install --no-cache-dir cuda-tile && \
pip install --no-cache-dir --no-deps accelerate
pip install --no-cache-dir --no-deps accelerate && \
pip install --no-cache-dir sentencepiece protobuf

WORKDIR /workspace/tilegym

Expand Down
52 changes: 52 additions & 0 deletions modeling/transformers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ End-to-end inference examples for transformer language models accelerated with T
| LLaMA-3.1-8B | `meta-llama/Meta-Llama-3.1-8B` | RoPE, SwiGLU, RMSNorm, Attention*, Flash Decoding* |
| DeepSeek-V2-Lite-Chat | `deepseek-ai/DeepSeek-V2-Lite-Chat` | RoPE, SwiGLU, RMSNorm, MoE, MLADecoding*, Attention* |
| Qwen2-7B | `Qwen/Qwen2-7B` | RoPE, SwiGLU, RMSNorm, Attention* |
| Gemma-3-4B-IT | `google/gemma-3-4b-it` | RoPE, GEGLU, RMSNorm, Attention* |
| GPT-OSS | `openai/gpt-oss-20b` | RoPE, RMSNorm, Attention Sink* |
| Mistral-7B-Instruct-v0.3 | `mistralai/Mistral-7B-Instruct-v0.3` | RoPE, SwiGLU, RMSNorm, Attention* |

*Optional: enable with `--use_attn` to use the attention implementation provided by TileGym.

Expand Down Expand Up @@ -101,6 +104,15 @@ Run benchmark scripts for automated comparison:

# Qwen2-7B benchmark
./bench_qwen.sh

# Gemma-3-4B-IT benchmark
./bench_gemma3.sh

# GPT-OSS benchmark
./bench_gpt_oss.sh

# Mistral-7B benchmark
./bench_mistral.sh
```

### Manual Benchmark
Expand Down Expand Up @@ -169,6 +181,46 @@ python infer.py \
--output_length 100
```

#### Gemma-3-4B-IT Benchmark
```bash
# PyTorch baseline
python infer.py \
--model_id google/gemma-3-4b-it \
--profile \
--sentence_file sample_inputs/input_prompt_small.txt \
--output_length 100

# TileGym CUTILE backend
python infer.py \
--model_id google/gemma-3-4b-it \
--use_tilegym \
--use_cutile \
--use_attn \
--profile \
--sentence_file sample_inputs/input_prompt_small.txt \
--output_length 100
```

#### Mistral-7B Benchmark
```bash
# PyTorch baseline
python infer.py \
--model_id mistralai/Mistral-7B-Instruct-v0.3 \
--profile \
--sentence_file sample_inputs/input_prompt_32K.txt \
--output_length 100

# TileGym CUTILE backend
python infer.py \
--model_id mistralai/Mistral-7B-Instruct-v0.3 \
--use_tilegym \
--use_cutile \
--use_attn \
--profile \
--sentence_file sample_inputs/input_prompt_32K.txt \
--output_length 100
```

## Command Line Options

| Option | Description | Default |
Expand Down
57 changes: 57 additions & 0 deletions modeling/transformers/bench_gemma3.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

# Benchmark script for Gemma3 model
# Compares PyTorch baseline vs TileGym CUTILE backend

# Fail fast on errors, unset variables, and failures inside pipelines.
set -euo pipefail

MODEL_ID="google/gemma-3-4b-it"
INPUT_FILE="sample_inputs/input_prompt_small.txt"
OUTPUT_LENGTH=100
SUMMARY_FILE="gemma3_benchmark_summary.txt"

echo "========================================"
echo " Gemma3 Performance Benchmark"
echo "========================================"
echo ""
echo "Model: ${MODEL_ID}"
echo "Input: ${INPUT_FILE}"
echo "Output length: ${OUTPUT_LENGTH} tokens"
echo ""

# Clean previous results so both runs write into a fresh summary file.
# All expansions are quoted to avoid word splitting/globbing (ShellCheck SC2086).
rm -f "${SUMMARY_FILE}"

echo "Running PyTorch baseline..."
python infer.py \
    --model_id "${MODEL_ID}" \
    --profile \
    --sentence_file "${INPUT_FILE}" \
    --output_length "${OUTPUT_LENGTH}" \
    --summary_file "${SUMMARY_FILE}"

echo ""
echo "Running TileGym CUTILE backend..."
python infer.py \
    --model_id "${MODEL_ID}" \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --profile \
    --sentence_file "${INPUT_FILE}" \
    --output_length "${OUTPUT_LENGTH}" \
    --summary_file "${SUMMARY_FILE}"

echo ""
echo "========================================"
echo " Benchmark Results"
echo "========================================"
# Print the combined summary produced by the two runs, then remove it.
if [ -f "${SUMMARY_FILE}" ]; then
    cat "${SUMMARY_FILE}"
    rm -f "${SUMMARY_FILE}"
else
    echo "Summary file not found."
fi
echo "========================================"
57 changes: 57 additions & 0 deletions modeling/transformers/bench_gpt_oss.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

# Benchmark script for GPT-OSS model
# Compares PyTorch baseline vs TileGym CUTILE backend

# Fail fast on errors, unset variables, and failures inside pipelines.
set -euo pipefail

MODEL_ID="openai/gpt-oss-20b"
INPUT_FILE="sample_inputs/input_prompt_small.txt"
OUTPUT_LENGTH=100
SUMMARY_FILE="gpt_oss_benchmark_summary.txt"

echo "========================================"
echo " GPT-OSS Performance Benchmark"
echo "========================================"
echo ""
echo "Model: ${MODEL_ID}"
echo "Input: ${INPUT_FILE}"
echo "Output length: ${OUTPUT_LENGTH} tokens"
echo ""

# Clean previous results so both runs write into a fresh summary file.
# All expansions are quoted to avoid word splitting/globbing (ShellCheck SC2086).
rm -f "${SUMMARY_FILE}"

echo "Running PyTorch baseline..."
python infer.py \
    --model_id "${MODEL_ID}" \
    --profile \
    --sentence_file "${INPUT_FILE}" \
    --output_length "${OUTPUT_LENGTH}" \
    --summary_file "${SUMMARY_FILE}"

echo ""
echo "Running TileGym CUTILE backend..."
python infer.py \
    --model_id "${MODEL_ID}" \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --profile \
    --sentence_file "${INPUT_FILE}" \
    --output_length "${OUTPUT_LENGTH}" \
    --summary_file "${SUMMARY_FILE}"

echo ""
echo "========================================"
echo " Benchmark Results"
echo "========================================"
# Print the combined summary produced by the two runs, then remove it.
if [ -f "${SUMMARY_FILE}" ]; then
    cat "${SUMMARY_FILE}"
    rm -f "${SUMMARY_FILE}"
else
    echo "Summary file not found."
fi
echo "========================================"
57 changes: 57 additions & 0 deletions modeling/transformers/bench_mistral.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

# Benchmark script for Mistral-7B-Instruct-v0.3 model
# Compares PyTorch baseline vs TileGym CUTILE backend

# Fail fast on errors, unset variables, and failures inside pipelines.
set -euo pipefail

MODEL_ID="mistralai/Mistral-7B-Instruct-v0.3"
INPUT_FILE="sample_inputs/input_prompt_32K.txt"
OUTPUT_LENGTH=50
SUMMARY_FILE="mistral_benchmark_summary.txt"

echo "========================================"
echo " Mistral-7B Performance Benchmark"
echo "========================================"
echo ""
echo "Model: ${MODEL_ID}"
echo "Input: ${INPUT_FILE}"
echo "Output length: ${OUTPUT_LENGTH} tokens"
echo ""

# Clean previous results so both runs write into a fresh summary file.
# All expansions are quoted to avoid word splitting/globbing (ShellCheck SC2086).
rm -f "${SUMMARY_FILE}"

echo "Running PyTorch baseline..."
python infer.py \
    --model_id "${MODEL_ID}" \
    --profile \
    --sentence_file "${INPUT_FILE}" \
    --output_length "${OUTPUT_LENGTH}" \
    --summary_file "${SUMMARY_FILE}"

echo ""
# Message capitalization matches the sibling bench_* scripts ("CUTILE").
echo "Running TileGym CUTILE backend..."
python infer.py \
    --model_id "${MODEL_ID}" \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --profile \
    --sentence_file "${INPUT_FILE}" \
    --output_length "${OUTPUT_LENGTH}" \
    --summary_file "${SUMMARY_FILE}"

echo ""
echo "========================================"
echo " Benchmark Results"
echo "========================================"
# Print the combined summary produced by the two runs, then remove it.
if [ -f "${SUMMARY_FILE}" ]; then
    cat "${SUMMARY_FILE}"
    rm -f "${SUMMARY_FILE}"
else
    echo "Summary file not found."
fi
echo "========================================"
10 changes: 9 additions & 1 deletion modeling/transformers/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from transformers import AutoTokenizer

from tilegym.transformers import apply_tilegym_kernel_to_deepseek_v2
from tilegym.transformers import apply_tilegym_kernel_to_gemma3
from tilegym.transformers import apply_tilegym_kernel_to_gpt_oss
from tilegym.transformers import apply_tilegym_kernel_to_llama
from tilegym.transformers import apply_tilegym_kernel_to_mistral
from tilegym.transformers import apply_tilegym_kernel_to_qwen2


Expand Down Expand Up @@ -210,9 +213,14 @@ def apply_tilegym_patch(model_id, use_attn=False, use_cutile=False):
apply_tilegym_kernel_to_deepseek_v2(
rope=True, rms_norm=True, swiglu=True, attn=use_attn, moe=True, use_cutile=use_cutile
)

elif "gpt-oss" in model_name:
apply_tilegym_kernel_to_gpt_oss(rope=True, rms_norm=True, swiglu=False, attn=use_attn, use_cutile=use_cutile)
elif "mistral" in model_name:
apply_tilegym_kernel_to_mistral(rope=True, rms_norm=True, swiglu=True, attn=use_attn, use_cutile=use_cutile)
elif "qwen" in model_name:
apply_tilegym_kernel_to_qwen2(rope=True, rms_norm=True, swiglu=True, attn=use_attn, use_cutile=use_cutile)
elif "gemma" in model_name:
apply_tilegym_kernel_to_gemma3(rope=True, rms_norm=True, mlp=True, attn=use_attn, use_cutile=use_cutile)
else:
print(f"Warning: Model {model_id} is not supported in tilegym patch. No optimizations will be applied.")

Expand Down
17 changes: 10 additions & 7 deletions src/tilegym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,15 @@
#
# SPDX-License-Identifier: MIT

# Apply experimental kernel tracking patch
from .experimental import _apply_patch as _apply_experimental_patch

# Import logging utilities
from .logger import get_logger
from .logger import set_env_log_level
from .logger import set_log_level
from .logger import warn_once

_apply_experimental_patch()

logger = get_logger()

# Initialize backend selector first to avoid import order issues
# Import other modules
from . import ops # Unified ops module
from .backend import get_available_backends
from .backend import get_available_backends_for_op
from .backend import get_current_backend
Expand All @@ -26,6 +19,16 @@
from .backend import print_registry_info
from .backend import set_backend

# Setup cutile integration
if is_backend_available("cutile"):
# Apply experimental kernel tracking patch
from .experimental import _apply_patch as _apply_experimental_patch

_apply_experimental_patch()

# Import other modules
from . import ops # Unified ops module

try:
import transformers
except ImportError:
Expand Down
6 changes: 6 additions & 0 deletions src/tilegym/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
from . import moe_interface

# Re-export key interfaces
from .attn_interface import attention_sink_interface
from .attn_interface import fmha_interface
from .attn_interface import get_attention_sink_interface
from .attn_interface import get_fmha_gemma3_interface
from .attn_interface import get_fmha_interface
from .attn_interface import mla_decoding_interface
from .attn_interface import mla_interface
Expand All @@ -43,8 +46,11 @@
"moe_interface",
# Re-exported submodules
# Key interfaces
"attention_sink_interface",
"fmha_interface",
"get_attention_sink_interface",
"get_fmha_interface",
"get_fmha_gemma3_interface",
"mla_interface",
"mla_decoding_interface",
"fused_moe_kernel_interface",
Expand Down
22 changes: 22 additions & 0 deletions src/tilegym/ops/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,28 @@ def relu(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError(f"relu is not implemented for {get_current_backend()}")


@dispatch("geglu")
def geglu(input: torch.Tensor, dim: int = -1, approximate: str = "none") -> torch.Tensor:
    """Gated GELU (GEGLU) activation: ``a * GELU(b)``.

    The input is split into two equal halves along ``dim``; ``a`` is the
    first half and ``b`` the second, and the result is their gated product.

    Args:
        input: Tensor to transform.
        dim: Dimension along which the input is split in half (default: -1).
        approximate: GELU approximation mode, ``'none'`` or ``'tanh'``.

    Returns:
        Tensor with GEGLU applied.

    Raises:
        NotImplementedError: Always — this is the dispatch fallback used
            when the current backend registers no ``geglu`` kernel.
    """
    raise NotImplementedError(f"geglu is not implemented for {get_current_backend()}")


@dispatch(
"gelu",
)
Expand Down
Loading