# TileLang-enabled Flash Attention on AMD Instinct MI300X GPUs

When training or deploying large language models (LLMs), such as DeepSeek-V3 or gpt-oss, on AMD Instinct™ MI300X GPUs, the Flash Attention kernel can become a critical performance bottleneck. Traditional implementations either suffer from high latency or require tedious low-level [HIP](https://rocm.docs.amd.com/projects/HIP/en/latest/index.html) coding to tap the hardware's potential. [TileLang](https://tilelang.com/), a high-level domain-specific language for writing GPU kernels, addresses this problem. It lets you write and run an optimized Flash Attention kernel on an Instinct MI300X GPU in concise Python code, and it offers some advantages over Triton.

This tutorial guides you through the entire process: setting up the ROCm environment, running the TileLang-based Flash Attention kernel, and verifying its correctness and performance. By the end, you'll know how to use TileLang to accelerate core LLM kernels on the AMD Instinct MI300X.

## Prerequisites

This tutorial was developed and tested using the following setup.

### Operating system

* **Ubuntu 22.04**: Ensure your system is running Ubuntu 22.04.

### Hardware

* **AMD Instinct MI300X GPU**: This tutorial was tested on an AMD Instinct MI300X GPU. Ensure you are using an AMD Instinct GPU with ROCm support and that your system meets the [official requirements](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html).

### Software

* **ROCm 7.0**: Install and verify ROCm by following the [ROCm install guide](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html).

  After installation, confirm your setup using:

  ```bash
  amd-smi
  ```

  This command lists your AMD GPUs with relevant details.

* **Docker**: Ensure Docker is installed and configured correctly. Follow the Docker installation guide for your operating system.

  **Note**: Ensure the Docker permissions are correctly configured. To configure permissions to allow non-root access, run the following commands:

  ```bash
  sudo usermod -aG docker $USER
  newgrp docker
  ```

  Verify Docker is working correctly with:

  ```bash
  docker run hello-world
  ```

## Set up the TileLang development environment

Follow these steps to set up the environment, launch Jupyter Notebooks, and verify the TileLang installation.

### Step 1: Launch the Docker container

From your host machine, run this command to launch the Docker container:

```bash
docker run \
  -it \
  --device=/dev/kfd \
  --device=/dev/dri \
  --security-opt seccomp=unconfined \
  --network=host \
  --cap-add=SYS_PTRACE \
  --group-add video \
  --shm-size 32g \
  --ipc=host \
  --name tilelang-FA-notebook \
  danielamd/tilelang-amd-mi300:v0.1.7
```

**Note**: This command launches a Docker container where you can perform all the work in this tutorial. You can download this notebook from the [AI Developer Hub GitHub repository](https://github.com/ROCm/gpuaidev).

### Step 2: Launch Jupyter Notebooks in the container

Inside the Docker container, install Jupyter using the following command:

```bash
pip install jupyter
```

Start the Jupyter server:

```bash
jupyter-lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root
```

**Note**: Ensure port `8888` is not already in use on your system before running the above command. If it is, you can specify a different port by replacing `--port=8888` with another port number, for example, `--port=8890`. The rest of this tutorial can run as interactive blocks in your Jupyter notebook after you upload this tutorial to your server.

### Step 3: Verify the TileLang installation

Verify that TileLang is installed correctly by importing the library and checking the version:

```python
%%bash
python -c "import tilelang; print(tilelang.__version__)"
```

## Run Flash Attention on the Instinct MI300X

This section demonstrates the full workflow using a typical LLM scenario (`batch=1`, `heads=8`, `seq_len=4096`, and `dim=128`). The workflow prepares the data, runs the kernel, verifies the results, and measures performance.

### Environment initialization

Import the required libraries and set the computing device:

```python
import os
import itertools
from functools import partial

import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.primitives.gemm.base import GemmWarpPolicy

# Add TileLang to the Python path (only needed if it is not already importable)
# import sys
# os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":/root/TileLang"
# sys.path.append("/root/TileLang")

# Select the AMD GPU (PyTorch on ROCm exposes HIP devices through the "cuda" device type)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device.type == "cuda", "This tutorial requires an AMD GPU with ROCm"
print(f"Running on device: {torch.cuda.get_device_name(device)}")
```

The expected output is: `Running on device: AMD Instinct MI300X` (or your MI300 series GPU).

### Define the PyTorch reference implementation

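The following code implements a reference attention in plain PyTorch. It is standard scaled dot-product attention, not Flash Attention: it materializes the full `seq_len × seq_len` score matrix. The tutorial uses it to check the TileLang kernel's output and as the performance baseline. Inputs use the `[batch, seq_len, heads, dim]` layout, and `groups > 1` repeats each K/V head for grouped-query attention.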

```python
def ref_program(Q, K, V, is_causal, groups=1):
    # Q: [batch, seq_len, heads, dim]; K, V: [batch, seq_len, heads // groups, dim]
    assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
    dim = Q.size(-1)
    # Expand K/V heads so that every query head has a matching K/V head
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    # Scaled dot-product scores: [batch, heads, seq_q, seq_k]
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        # Mask out keys that come after the query position
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    # Weighted sum of values, back to [batch, seq_len, heads, dim]
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output
```

### Define auxiliary functions

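The following code generates the candidate configurations for autotuning. The autotuner searches over the block tile sizes (`block_M`, `block_N`), the number of programs that the Q dimension is split across (`num_split_q`), the number of threads per block, the number of pipeline stages, and the rasterization settings (`enable_rasterization`, `panel_size`) that improve L2 cache reuse. The `supply_tensors_gpu` function creates the random input tensors used during tuning directly on the GPU.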

```python
def get_configs():
    """Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
    block_M = [32, 64, 128, 256]
    block_N = [32, 64, 128, 256]
    threads = [128, 256, 512]
    num_split_q = [64, 128, 256]
    num_stages = [0, 1]
    enable_rasterization = [True]
    k_pack = [2]
    panel_size = [7, 8]
    qk_coalesced_width = [8]
    v_coalesced_width = [4]

    valid_configs = []

    for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
                                                                  threads, num_stages,
                                                                  enable_rasterization, k_pack,
                                                                  panel_size, qk_coalesced_width,
                                                                  v_coalesced_width):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "num_split_q": s,
            "threads": t,
            "num_stages": stages,
            "enable_rasterization": r,
            "k_pack": k,
            "panel_size": p,
            "qk_coalesced_width": qkw,
            "v_coalesced_width": vw,
        })
    return valid_configs

# Custom supply function to ensure tensors are created on GPU
def supply_tensors_gpu(params):
    """Supply function that creates tensors on GPU for ROCm/HIP."""
    tensors = []
    for param in params:
        if hasattr(param, 'shape') and hasattr(param, 'dtype'):
            # Force creation on GPU device
            shape = [int(s) for s in param.shape]
            tensor = torch.randn(shape, dtype=param.dtype, device='cuda')
            tensors.append(tensor)
        else:
            tensors.append(param)
    return tensors
```
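Because `get_configs()` takes the full Cartesian product of its candidate lists, the autotuner compiles and benchmarks every combination on the first run. The sketch below counts that search space. The lists are copied from `get_configs()`; the helpers `search_space_size` and `enumerate_configs` are illustrative names, not part of TileLang.

```python
import itertools
import math

# Candidate values copied from get_configs() above
search_space = {
    "block_M": [32, 64, 128, 256],
    "block_N": [32, 64, 128, 256],
    "num_split_q": [64, 128, 256],
    "threads": [128, 256, 512],
    "num_stages": [0, 1],
    "enable_rasterization": [True],
    "k_pack": [2],
    "panel_size": [7, 8],
    "qk_coalesced_width": [8],
    "v_coalesced_width": [4],
}


def search_space_size(space):
    """Number of configurations in the Cartesian product of all candidate lists."""
    return math.prod(len(values) for values in space.values())


def enumerate_configs(space):
    """Yield one dict per combination, in the same order as get_configs()."""
    keys = list(space)
    for combo in itertools.product(*space.values()):
        yield dict(zip(keys, combo))


print(search_space_size(search_space))  # 576
```

Each list multiplies the tuning cost, so narrowing one list (for example, limiting `block_M` to `[64, 128]`) halves the number of candidates to 288.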

### Implement the TileLang Flash Attention kernel

The following code implements the core `fast_flashattn` function with tiled computation logic, optimized for the Instinct MI300X:

```python
@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu)
@tilelang.jit(out_idx=[3])
def fast_flashattn(
    batch,
    heads,
    seq_len,
    dim,
    is_causal,
    groups,
    block_M: int,
    block_N: int,
    num_split_q: int,
    threads: int,
    num_stages: int,
    enable_rasterization: bool,
    k_pack: int,
    panel_size: int,
    qk_coalesced_width: int,
    v_coalesced_width: int,
):
    scale = (1.0 / dim)**0.5
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
    dtype = "float16"
    accum_dtype = "float"

    vec_size = qk_coalesced_width
    v_vec_size = v_coalesced_width

    @T.prim_func
    def main(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(kv_shape, dtype),
            V: T.Tensor(kv_shape, dtype),
            Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
            T.use_swizzle(panel_size, enable=enable_rasterization)

            bz = byz_combined // heads
            by = byz_combined % heads

            num_q_blocks = T.ceildiv(seq_len, block_M)

            bx = T.alloc_var("int32")
            bx = b_split

            with T.While(bx < num_q_blocks):
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                m_i = T.alloc_fragment([block_M], accum_dtype)
                l_i = T.alloc_fragment([block_M], accum_dtype)
                T.fill(acc_o, 0)
                T.fill(m_i, -T.infinity(accum_dtype))
                T.fill(l_i, 0)

                current_bx = bx
                q_block_offset = current_bx * block_M

                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                # Use register fragment for P instead of shared memory to reduce LDS usage
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)

                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                m_prev = T.alloc_fragment([block_M], accum_dtype)
                scale_factor = T.alloc_fragment([block_M], accum_dtype)

                T.copy(
                    Q[bz, q_block_offset:q_block_offset + block_M, by, :],
                    Q_shared,
                    coalesced_width=vec_size)

                loop_end_k = T.ceildiv(q_block_offset + block_M,
                                       block_N) if is_causal else T.ceildiv(seq_len, block_N)

                row_sum = T.alloc_fragment([block_M], accum_dtype)

                for k in T.Pipelined(loop_end_k, num_stages=num_stages):
                    kv_idx = k * block_N

                    T.copy(
                        K[bz, kv_idx:kv_idx + block_N, by // groups, :],
                        K_shared,
                        coalesced_width=vec_size)
                    T.copy(
                        V[bz, kv_idx:kv_idx + block_N, by // groups, :],
                        V_shared,
                        coalesced_width=v_vec_size)

                    if is_causal:
                        for i, j in T.Parallel(block_M, block_N):
                            acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0,
                                                         -T.infinity(acc_s.dtype))
                    else:
                        T.clear(acc_s)
                    T.gemm(
                        Q_shared,
                        K_shared,
                        acc_s,
                        transpose_B=True,
                        k_pack=k_pack,
                        policy=GemmWarpPolicy.FullRow,
                    )

                    T.copy(m_i, m_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)
                    for i in T.Parallel(block_M):
                        m_i[i] = T.max(m_i[i], m_prev[i])

                    for i in T.Parallel(block_M):
                        sf = T.exp(m_prev[i] * scale - m_i[i] * scale)
                        l_i[i] *= sf
                        scale_factor[i] = sf

                    for i, j in T.Parallel(block_M, dim):
                        acc_o[i, j] *= scale_factor[i]

                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.exp(acc_s[i, j] * scale - m_i[i] * scale)

                    T.reduce_sum(acc_s, row_sum, dim=1)
                    for i in T.Parallel(block_M):
                        l_i[i] += row_sum[i]

                    # Cast acc_s (accum_dtype) to dtype in registers and directly GEMM with V
                    T.copy(acc_s, acc_s_cast)

                    T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow)

                l_inv = T.alloc_fragment([block_M], accum_dtype)
                for i in T.Parallel(block_M):
                    safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0)
                    l_inv[i] = 1.0 / safe_l

                for i, j in T.Parallel(block_M, dim):
                    Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i]

                bx = current_bx + num_split_q

    return main
```

The function uses two key decorators:

* **`@tilelang.autotune`**: Enables autotuning over the candidate configurations from `get_configs()`, caches the generated input tensors across candidates (`cache_input_tensors=True`), and uses `supply_tensors_gpu` to create those tensors on the GPU.
* **`@tilelang.jit`**: Enables just-in-time compilation of the TileLang code into GPU-executable HIP kernels. `out_idx=[3]` marks the fourth parameter of `main` (`Output`) as the output tensor, so the compiled kernel allocates `Output` itself and returns it.

#### Understanding the Flash Attention kernel implementation

The kernel implementation consists of the following key stages:

**Initialization**

* Computes the softmax scaling factor `scale = 1/sqrt(dim)` and the number of KV heads (`head_kv = heads // groups`).
* Defines the tensor shapes and data types (`float16` for the inputs and GEMM operands, and `float` (FP32) for accumulation to preserve precision).
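To make the initialization concrete, the sketch below evaluates `scale` for `dim=128` and shows which KV head each query head reads when `groups > 1` (grouped-query attention). The setting `heads=8, groups=2` is hypothetical; the tutorial itself uses `groups=1`. The `by // groups` mapping in the kernel matches how `ref_program` expands K and V with `repeat_interleave(groups, dim=2)`. The helper names are illustrative.

```python
import math


def attention_scale(dim):
    """Softmax scaling factor 1/sqrt(dim), computed as in fast_flashattn."""
    return (1.0 / dim) ** 0.5


def kv_head_for_query_head(by, groups):
    """KV head read by query head `by`; the kernel indexes K and V with by // groups."""
    return by // groups


heads, groups = 8, 2  # hypothetical GQA setting: 8 query heads share 4 KV heads
head_kv = heads // groups
mapping = [kv_head_for_query_head(h, groups) for h in range(heads)]
print(head_kv, mapping)               # 4 [0, 0, 1, 1, 2, 2, 3, 3]
print(round(attention_scale(128), 4))  # 0.0884
```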

**Kernel definition**

* Uses `@T.prim_func` to define the core computation function.
* `T.Kernel` launches a 2D grid of `num_split_q × (batch * heads)` thread blocks. The first index (`b_split`) selects which `Q` tiles a block processes, and the second (`byz_combined`) encodes the batch and attention head.
* The `threads` parameter specifies the number of threads per thread block.

**Hardware optimization**

* `T.use_swizzle` rasterizes (reorders) the thread-block execution order in panels of `panel_size` to improve L2 cache reuse on the Instinct MI300X.
* Decomposes `byz_combined` into batch index (`bz`) and attention head index (`by`) for parallel processing.

**Q tile processing and cache initialization**

* Calculates the number of `Q` tiles (`num_q_blocks`) and iterates over them with a `T.While` loop. Each thread block starts at tile `b_split` and advances by `num_split_q`, so the blocks share the `Q` tiles in a strided pattern.
* Creates `acc_o` (output accumulator), `m_i` (row-wise maximum), and `l_i` (row-wise sum) for numerically stable softmax computation.
* Allocates shared memory (`Q_shared`, `K_shared`, and `V_shared`) to cache tile data, and register fragments (`acc_s`, `acc_s_cast`) to hold intermediate results.
* Loads the current `Q` tile from HBM into shared memory, using `coalesced_width` for coalesced, vectorized loads that improve memory bandwidth utilization.
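The grid mapping and the strided `while` loop can be simulated on the host. This pure-Python sketch mirrors the index arithmetic in the kernel (`byz_combined // heads`, `byz_combined % heads`, and `bx += num_split_q`); the function names are hypothetical helpers for illustration, not TileLang APIs.

```python
def ceildiv(a, b):
    """Integer ceiling division, like T.ceildiv."""
    return (a + b - 1) // b


def q_tiles_for_program(b_split, seq_len, block_M, num_split_q):
    """Q tile indices visited by one thread block: bx starts at b_split and advances by num_split_q."""
    num_q_blocks = ceildiv(seq_len, block_M)
    tiles = []
    bx = b_split
    while bx < num_q_blocks:
        tiles.append(bx)
        bx += num_split_q
    return tiles


def batch_and_head(byz_combined, heads):
    """Split the second grid index into (bz, by) as the kernel does."""
    return byz_combined // heads, byz_combined % heads


# seq_len=4096, block_M=32 -> 128 Q tiles shared by 64 thread blocks: two tiles each
print(q_tiles_for_program(0, 4096, 32, 64))    # [0, 64]
# seq_len=4096, block_M=128 -> only 32 Q tiles, so thread block 40 has no work
print(q_tiles_for_program(40, 4096, 128, 64))  # []
print(batch_and_head(13, 8))                   # (1, 5)
```

When `num_split_q` exceeds the number of `Q` tiles, the extra thread blocks exit immediately, which is one reason the autotuner's choice of `num_split_q` interacts with `block_M`.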

**K/V tile traversal and core computation**

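* Determines how many K/V tiles to visit based on causality. In causal mode, the loop stops after the tile that contains the diagonal, so only K/V positions up to the last query row of the current `Q` tile are processed.
* Loads the K and V tiles into shared memory with the same coalescing strategy as `Q`. With `groups > 1`, query head `by` reads KV head `by // groups`.
* In causal mode, initializes `acc_s` to `-∞` where the key index exceeds the query index and to `0` elsewhere; otherwise, clears `acc_s` to `0`.
* Calls TileLang's built-in GEMM primitive to accumulate `QK^T` into `acc_s` (`transpose_B=True`), using `k_pack` for data-packing optimization and `GemmWarpPolicy.FullRow`, which partitions the warps along the row (M) dimension so that each warp covers complete rows of the tile.
* Updates the softmax incrementally (online softmax): it updates the running row maximum `m_i`, rescales `l_i` and `acc_o` by `exp((m_prev - m_i) * scale)`, and exponentiates the scaled scores. Subtracting the running maximum prevents overflow in `exp`.
* Casts the unnormalized probabilities in `acc_s` to `float16` (`acc_s_cast`) and multiplies them by V, accumulating the result into `acc_o`. Normalization by the row sum is deferred until the end.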
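The online softmax update can be checked on a single query row in plain Python. This is a minimal sketch under simplifying assumptions: no causal mask, and one scalar value per key standing in for a row of V. It follows the same update order as the kernel loop (new maximum, rescale `l_i` and `acc_o`, exponentiate, accumulate) and compares the result against a direct softmax over all scores.

```python
import math


def online_softmax_row(scores, values, block_N, scale):
    """Blockwise (online) softmax-weighted sum for one query row.

    scores: raw dot products q . k_j for every key j (unscaled)
    values: one scalar per key, standing in for a row of V
    """
    m_i = -math.inf  # running maximum of the unscaled scores
    l_i = 0.0        # running sum of exponentials
    acc_o = 0.0      # running unnormalized weighted sum of values
    for start in range(0, len(scores), block_N):
        s_blk = scores[start:start + block_N]
        v_blk = values[start:start + block_N]
        m_prev = m_i
        m_i = max(m_prev, max(s_blk))
        # Rescale what was accumulated under the previous maximum
        sf = math.exp(m_prev * scale - m_i * scale)
        l_i *= sf
        acc_o *= sf
        p = [math.exp(s * scale - m_i * scale) for s in s_blk]
        l_i += sum(p)
        acc_o += sum(pj * vj for pj, vj in zip(p, v_blk))
    return acc_o / l_i  # final normalization (the l_inv step)


def direct_softmax_row(scores, values, scale):
    """Reference: softmax over all scores at once, then a weighted sum."""
    m = max(scores)
    w = [math.exp((s - m) * scale) for s in scores]
    return sum(wj * vj for wj, vj in zip(w, values)) / sum(w)


scores = [3.0 * (i % 7) for i in range(20)]
values = [float(i) for i in range(20)]
scale = 1.0 / math.sqrt(128)
blockwise = online_softmax_row(scores, values, block_N=8, scale=scale)
reference = direct_softmax_row(scores, values, scale)
print(math.isclose(blockwise, reference, rel_tol=1e-9))  # True
```

On the first block, `m_prev` is `-inf`, so the rescaling factor is `exp(-inf) = 0` and the empty accumulators stay zero; this is why the kernel can initialize `m_i` to `-T.infinity(accum_dtype)`.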

**Output generation**

* Calculates the reciprocal of the row sum (`l_inv`, guarding against a zero sum), normalizes the accumulator `acc_o`, and writes the result to `Output`.
* Advances `bx` by `num_split_q` to process this thread block's next `Q` tile.

### Prepare the input parameters

Configure the input parameters for the Flash Attention kernel. The following example uses a typical LLM scenario:

```python
# Scenario parameters (typical LLM prefill/inference)
batch = 1
heads = 8
seq_len = 4096
dim = 128
is_causal = True  # Causal (autoregressive) masking
groups = 1        # Standard multi-head attention (no grouped-query attention)

# FLOP count for benchmarking: two matmuls (QK^T and PV), halved for causal masking
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
    total_flops *= 0.5
```
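The FLOP count above assumes that causal masking removes exactly half of the work. The kernel, however, visits whole K/V tiles up to `loop_end_k`, so diagonal tiles are computed in full and then masked. This sketch (helper names are illustrative) reproduces the tutorial's FLOP formula and counts the (`Q` tile, K/V tile) pairs that the causal loop actually visits.

```python
def ceildiv(a, b):
    return (a + b - 1) // b


def attention_flops(batch, heads, seq_len, dim, is_causal):
    """FLOP count used in the tutorial: two matmuls, halved for causal masking."""
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total = 2 * flops_per_matmul
    return total * 0.5 if is_causal else total


def causal_tile_fraction(seq_len, block_M, block_N):
    """Fraction of tile pairs visited by the causal loop (loop_end_k per Q tile)."""
    num_q_blocks = ceildiv(seq_len, block_M)
    num_kv_blocks = ceildiv(seq_len, block_N)
    visited = sum(ceildiv(bx * block_M + block_M, block_N) for bx in range(num_q_blocks))
    return visited / (num_q_blocks * num_kv_blocks)


print(attention_flops(1, 8, 4096, 128, is_causal=True))  # 34359738368.0
print(causal_tile_fraction(4096, 128, 128))              # 0.515625
```

With 128 × 128 tiles, the causal kernel visits 528 of the 1,024 tile pairs (about 51.6%), slightly more than the 50% that the reported FLOP count assumes.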

### Run TileLang Flash Attention

Trigger autotuning to search for the optimal configuration:

```python
# Run with autotuning. The first run compiles and benchmarks every candidate
# configuration, so it can take a while.
print("Starting autotuning for FlashAttention-V2...")
kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups=groups)
print(f"Autotuning finished. Best Configuration: {kernel.config}")
```

### Verify correctness and test performance

After successfully invoking the kernel and finding an optimal configuration, verify its correctness by comparing your results against the native PyTorch implementation, then benchmark the performance.

#### Correctness verification

TileLang includes a built-in profiler. Its `assert_allclose` method runs the tuned kernel and the PyTorch reference on the same inputs and checks that the results agree within the given tolerances (`rtol=0.01`, `atol=0.01`):

```python
# PyTorch reference with the scenario's causal/groups settings bound
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
# Profiler for the tuned kernel; inputs are drawn from a normal distribution
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
print("Verifying correctness...")
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
```

The expected output is:

    Verifying correctness...
    All checks pass.
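With `rtol=0.01` and `atol=0.01`, an element passes when its absolute error stays within `atol + rtol * |reference|`. This pure-Python sketch shows that criterion as `torch.allclose` defines it; TileLang's `assert_allclose` may implement the check differently, so treat it as an approximation of the profiler's behavior.

```python
def allclose(actual, expected, rtol=0.01, atol=0.01):
    """Return True if |a - e| <= atol + rtol * |e| for every element pair.

    Mirrors the torch.allclose criterion; a sketch, not TileLang's implementation.
    """
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


print(allclose([1.005, 2.0], [1.0, 2.0]))  # True
print(allclose([0.5], [0.45]))             # False
```

Because the relative term scales with the reference value, large outputs tolerate larger absolute differences (`allclose([100.5], [100.0])` passes), while near-zero outputs are held to roughly `atol`.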

#### Performance testing

Benchmark the kernel performance using the profiler's built-in benchmark tool, which measures kernel latency by synchronizing before and after kernel execution:

```python
# Benchmark the PyTorch reference (do_bench reports latency in milliseconds)
latency = profiler.do_bench(ref_program_processed, warmup=100)
print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")

# Benchmark the tuned TileLang kernel
latency = profiler.do_bench(warmup=100)
print(f"Fast Flash Attention V2 (TileLang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
```
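The throughput figure converts the FLOP count and the measured latency, reported in milliseconds, into TFLOPS. With batch size $b$, heads $h$, sequence length $s$, and head dimension $d$:

$$
\text{TFLOPS} = \frac{\text{FLOPs}}{t_{\text{ms}} \times 10^{-3}\,\text{s}} \times 10^{-12} = \frac{\text{FLOPs}}{t_{\text{ms}}} \times 10^{-9},
\qquad
\text{FLOPs} = 4\,b\,h\,s^{2}\,d \times
\begin{cases}
1/2 & \text{causal} \\
1 & \text{otherwise}
\end{cases}
$$

The factor $4\,b\,h\,s^{2}\,d$ is `2 * flops_per_matmul`: each of the two matmuls ($QK^\top$ and $PV$) costs $2\,b\,h\,s^{2}\,d$ floating-point operations.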

The expected output is:

    Reference (PyTorch): 0.89 ms | 77.18 TFlops
    Fast Flash Attention V2 (TileLang): 0.37 ms | 187.52 TFlops

In this run, the TileLang kernel is about 2.4x faster than the PyTorch reference (0.37 ms versus 0.89 ms).

## Summary

Congratulations! By following this TileLang Flash Attention tutorial, you learned how to implement and optimize Flash Attention on an AMD Instinct MI300X GPU using TileLang.

**Key takeaways:**

- **Environment setup**: ROCm 7.0 and the prebuilt TileLang Docker image provide a tested environment for the Instinct MI300X.
- **Performance optimization**: Tiled computation with autotuned block sizes and hardware-oriented optimizations (such as block rasterization, coalesced loads, and register-resident intermediates) reduced latency by about 2.4x compared to the PyTorch reference in this example.
- **Practical application**: The kernel can be integrated into LLM frameworks, like vLLM or SGLang, for end-to-end acceleration.

## Next steps

1. **Experiment with different scenarios**: Modify `batch`, `seq_len`, or `heads` to match your model's requirements (for example, `batch=32` for batch inference).
2. **Advanced autotuning**: Extend `get_configs()` with additional tile sizes to explore further performance improvements.
3. **Framework integration**: Replace the attention operator in models such as DeepSeek or Llama with `fast_flashattn` for full-model acceleration.

## Additional resources

- [TileLang GitHub](https://github.com/tile-ai/TileLang)
- [ROCm documentation](https://rocm.docs.amd.com/en/latest/index.html)
- [AMD AI Developer Hub](https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/)