In [1]:
import torch
import triton
import triton.language as tl
import time

In [2]:
def is_cuda():
    """Return True when the backend of Triton's currently active target is "cuda"."""
    active_target = triton.runtime.driver.active.get_current_target()
    return active_target.backend == "cuda"

In [3]:
def is_hip_mi200():
    """Return True when Triton's active target is the 'hip' backend with arch 'gfx90a'.

    NOTE(review): the function name implies 'gfx90a' is the MI200 architecture
    string; that mapping is not established anywhere in this file.
    """
    active_target = triton.runtime.driver.active.get_current_target()
    # Guard clause: the arch is only inspected once the backend is known to be HIP.
    if active_target.backend != 'hip':
        return False
    return active_target.arch == 'gfx90a'

In [None]:
"""
Fused MatMul + Add + ReLU optimization: computes ReLU(A @ B + C).
The kernel uses several optimization techniques:

  1. Shared memory tiling.
  2. Register tiling.
  3. Cooperative fetching.
  4. Operator Fusion
  5. Write cache / epilogue fusion.

"""

# -----------------------------------------------------------------------------
# Tiling parameters - You will need to change these to achieve better results.
# -----------------------------------------------------------------------------
BLOCK_M = 128  # Rows of the output tile computed by one kernel program (M dimension).
BLOCK_N = 256 # Columns of the output tile computed by one kernel program (N dimension).
BLOCK_K = 16 # Depth of the A/B slices consumed per iteration of the kernel's K loop.
GROUP_SIZE_M = 4 # Number of consecutive row-tiles grouped together ("super-grouping") when mapping program ids to tiles; intended as an L2 cache optimization.


# -----------------------------------------------------------------------------
# Triton Kernel: Matrix Multiplication + ReLU + Add
#
# The kernel uses:
#   Step 1: Tile assignment (each kernel computes a tile of C)
#   Step 2: Register tiling: Use a register accumulator.
#   Step 3: Shared memory tiling + Cooperative Fetching: Load tiles of A and B.
#   Step 4: Add and ReLU fusion
#   Step 5: Write cache/Epilogue: Write the final tile back to global memory.
# -----------------------------------------------------------------------------

@triton.jit
def matmul_add_relu_kernel_fp16(
    a_ptr, b_ptr, c_ptr, d_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    stride_am: tl.constexpr, stride_ak: tl.constexpr,
    stride_bk: tl.constexpr, stride_bn: tl.constexpr,
    stride_cm: tl.constexpr, stride_cn: tl.constexpr,
    stride_dm: tl.constexpr, stride_dn: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    """Compute one BLOCK_M x BLOCK_N tile of D = max(A @ B + C, 0).

    Each program (identified by its id on grid axis 0) owns one output tile,
    accumulates the A @ B product for that tile over the K dimension, adds the
    matching tile of C, clamps negatives to zero and stores the result to D.

    Args:
        a_ptr, b_ptr, c_ptr, d_ptr: base pointers of A, B, C and the output D.
            The indexing below treats A as (M, K), B as (K, N) and C / D as
            (M, N).
        M, N, K: problem dimensions, passed as compile-time constants.
        stride_am, stride_ak: element strides of A along M and K.
        stride_bk, stride_bn: element strides of B along K and N.
        stride_cm, stride_cn: element strides of C along M and N.
        stride_dm, stride_dn: element strides of D along M and N.
        BLOCK_M, BLOCK_N, BLOCK_K: tile sizes along M, N and K.
        GROUP_SIZE_M: number of row-tiles grouped together when program ids
            are mapped to tiles.

    The accumulator and the tl.dot output are float16.
    NOTE(review): presumably this trades accumulation precision for speed and
    is why the accuracy check in this file uses loose tolerances - confirm.
    """
    # -------------------------------------------------------------------------
    # Step 1: Tile: Assignment
    #
    # Each kernel instance is mapped to a tile in the output matrix C.
    # Compute the tile coordinates (pid_m, pid_n) for this program.
    # -------------------------------------------------------------------------
    pid = tl.program_id(axis=0)

    # Number of tiles along each output dimension (rounded up).
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)

    # Program ids are handed out group by group: one group spans GROUP_SIZE_M
    # row-tiles across all column-tiles.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group


    # First row-tile belonging to this group.
    first_pid_m = group_id * GROUP_SIZE_M

    # The last group is shorter when num_pid_m is not a multiple of GROUP_SIZE_M.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)

    # Within a group the row-tile index varies fastest, the column-tile slowest.
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # -------------------------------------------------------------------------
    # Step 2: Register Tiling
    # -------------------------------------------------------------------------
    # Running sum for this output tile, kept in float16.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16)

    # -------------------------------------------------------------------------
    # Step 3: Shared Memory Tiling & Cooperative Fetching.
    # Compute pointers to the sub-tiles of A and B that are needed to compute
    # the current C tile. The offsets here serve to load BLOCK_M x BLOCK_K
    # and BLOCK_K x BLOCK_N blocks from A and B respectively.
    # -------------------------------------------------------------------------
    # The modulo wraps row/column indices past the matrix edge back into range,
    # so the loads below only need a mask along K; the wrapped rows/columns are
    # dropped later by out_mask when the tile is stored.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Walk the K dimension one BLOCK_K-wide slice at a time.
    for k in range(0, tl.cdiv(K, BLOCK_K)):

        # mask_a = (offs_am[:, None] < M) & ((k * BLOCK_K + offs_k[None, :]) < K)
        # mask_b = ((k * BLOCK_K + offs_k[:, None]) < K) & (offs_bn[None, :] < N)
        # Elements at or beyond K in the final slice are loaded as 0.0 so they
        # contribute nothing to the product.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)

        # accumulator = a @ b + accumulator, with a float16 result.
        accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float16)

        # Advance both tile pointers to the next K slice.
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # -------------------------------------------------------------------------
    # Step 4: Add C to the accumulator, then apply ReLU
    # -------------------------------------------------------------------------
    # Unwrapped (no modulo) offsets: out_mask disables every element that falls
    # outside the M x N output.
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
    c = tl.load(c_ptrs, mask=out_mask, other=0.0)

    accumulator += c

    # ReLU: clamp negative entries to zero.
    accumulator = tl.maximum(accumulator, 0.0)

    # -------------------------------------------------------------------------
    # Step 5: Write Cache / Epilogue Fusion: Write the computed tile to D.
    # -------------------------------------------------------------------------
    d_ptrs = d_ptr + stride_dm * offs_m[:, None] + stride_dn * offs_n[None, :]
    tl.store(d_ptrs, accumulator, mask=out_mask)

In [5]:
def matmul_add_relu_fp16(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """
    Computes Output = ReLU(A @ B + C) using fp16 precision for maximum throughput.

    Launches one `matmul_add_relu_kernel_fp16` program per BLOCK_M x BLOCK_N
    output tile, using the module-level BLOCK_M / BLOCK_N / BLOCK_K /
    GROUP_SIZE_M settings that are current at call time.

    Args:
        a: left operand of shape (M, K).
        b: right operand of shape (K, N).
        c: addend of shape (M, N). It is NOT broadcast: the kernel indexes it
           with the same offsets and mask as the output.

    Returns:
        A new float16 tensor of shape (M, N) on `a`'s device.

    Raises:
        AssertionError: if the inner dimensions of `a` and `b` differ, if `c`
            is not exactly (M, N), or if the inputs live on different devices.
    """
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "Incompatible dimensions"
    # The kernel reads C at every (row, col) of the M x N output. A smaller or
    # differently shaped C would be read out of bounds, so reject it up front,
    # before anything is allocated or launched.
    assert tuple(c.shape) == (M, N), (
        f"c must have shape ({M}, {N}) to match a @ b, got {tuple(c.shape)}"
    )
    # All pointers are dereferenced by a single kernel launch.
    assert a.device == b.device == c.device, "a, b and c must be on the same device"

    d = torch.empty((M, N), device=a.device, dtype=torch.float16)

    # 1-D launch grid: one program per output tile.
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), )

    matmul_add_relu_kernel_fp16[grid](
        a, b, c, d,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        d.stride(0), d.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        GROUP_SIZE_M=GROUP_SIZE_M
    )
    return d

In [6]:
# Reference implementation using PyTorch
def reference_matmul_add_relu(A, B, C):
    """Return ReLU(A @ B + C) computed with stock PyTorch operations."""
    # matmul and add each produce a fresh tensor, so the in-place ReLU only
    # touches that temporary and never the caller's inputs.
    pre_activation = torch.add(torch.matmul(A, B), C)
    return pre_activation.relu_()

In [7]:
# -----------------------------------------------------------------------------
# Accuracy Tests
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Fixed seed so the three random fp16 operands are reproducible; they are
    # drawn in the order a, b, c.
    torch.manual_seed(0)
    a, b, c = (
        torch.randn((512, 512), device=torch.device("cuda"), dtype=torch.float16)
        for _ in range(3)
    )

    # Run the Triton kernel and the PyTorch reference on the same inputs.
    triton_output = matmul_add_relu_fp16(a, b, c)
    torch_output = reference_matmul_add_relu(a, b, c)
    print(f"triton_output_with_fp16_inputs={triton_output}")
    print(f"torch_output_with_fp16_inputs={torch_output}")

    # Relative tolerance depends on the target; absolute tolerance is fixed.
    rtol = 0.032 if not is_hip_mi200() else 1e-2
    if not torch.allclose(triton_output, torch_output, atol=0.15, rtol=rtol):
        # Report the worst element-wise deviation to help judge the mismatch.
        diff = triton_output - torch_output
        abs_diff = diff.abs()
        max_abs_diff = abs_diff.max()
        print(f"❌ Triton and Torch differ: {max_abs_diff=}")
    else:
        print("✅ Triton and Torch match")


triton_output_with_fp16_inputs=tensor([[ 0.0000,  6.1250,  0.0000,  ..., 10.0625,  0.0000,  0.0000],
        [ 7.9102, 15.6328, 26.6094,  ..., 11.4609,  5.3750, 18.6250],
        [ 2.7246,  0.0000,  0.0000,  ...,  0.0000, 26.0781,  0.0000],
        ...,
        [ 0.4448, 75.1875,  0.0000,  ..., 26.2812,  0.0000,  0.0000],
        [ 6.9492,  1.1230,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [27.6094, 26.9531, 22.9219,  ..., 13.5391,  6.0508, 21.6250]],
       device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[ 0.0000,  6.1289,  0.0000,  ..., 10.0391,  0.0000,  0.0000],
        [ 7.9102, 15.6328, 26.6250,  ..., 11.4531,  5.3945, 18.6562],
        [ 2.7266,  0.0000,  0.0000,  ...,  0.0000, 26.1250,  0.0000],
        ...,
        [ 0.4316, 75.2500,  0.0000,  ..., 26.2812,  0.0000,  0.0000],
        [ 6.9570,  1.1260,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [27.6406, 26.9531, 22.9375,  ..., 13.5625,  6.0391, 21.6406]],
       device='cuda:0', dt

In [13]:
# -----------------------------------------------------------------------------
# Performance Benchmark
# IMPORTANT: DO NOT CHANGE THIS CODE.
# THIS IS THE EXACT CODE THAT WILL BE USED TO GRADE YOUR IMPLEMENTATION.
# ANY CHANGES TO THIS CODE (INCLUDING DIMENSIONS, REPEATS, etc.)
# WILL CAUSE YOU TO HAVE DIFFERENT SPEEDUP RESULTS.
# -----------------------------------------------------------------------------
# Problem size: A is (M, K), B is (K, N), C is (M, N).
M = 2048
K = 2048
N = 2048

# KEEP THESE MATRICES IN FP16. FP32 WILL NOT PROVIDE ACCURATE RESULTS
A = torch.randn((M, K), device="cuda", dtype=torch.float16)
B = torch.randn((K, N), device="cuda", dtype=torch.float16)
C = torch.randn((M, N), device="cuda", dtype=torch.float16)

# warmup: call each implementation once so its first-call cost stays outside
# the timed loops below.
_ = matmul_add_relu_fp16(A, B, C)
_ = reference_matmul_add_relu(A, B, C)

# Number of timed calls per implementation; reported times are per-call means.
REPEATS = 5000

# time your implementation
print("Triton implementation")
# Synchronize before starting and before stopping the clock so the measured
# interval covers the GPU work of all REPEATS calls and nothing queued earlier.
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(REPEATS):
    _ = matmul_add_relu_fp16(A, B, C)
torch.cuda.synchronize()
triton_time = (time.perf_counter() - start) / REPEATS

# time pytorch (same procedure as above)
print("PyTorch implementation")
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(REPEATS):
    _ = reference_matmul_add_relu(A, B, C)
torch.cuda.synchronize()
torch_time = (time.perf_counter() - start) / REPEATS

# Times are in seconds per call; print them in milliseconds.
print(f"Performance comparison for matrix multiplication ({M}x{K} @ {K}x{N}):")
print(f"Triton implementation: {triton_time*1000:.2f} ms")
print(f"PyTorch implementation: {torch_time*1000:.2f} ms")

# Speedup > 1 means the Triton implementation was faster than PyTorch.
print(f"\nSpeedup of Triton vs PyTorch: {torch_time/triton_time:.2f}x")

Triton implementation
PyTorch implementation
Performance comparison for matrix multiplication (2048x2048 @ 2048x2048):
Triton implementation: 0.73 ms
PyTorch implementation: 1.02 ms

Speedup of Triton vs PyTorch: 1.41x


In [None]:
def grid_search():
    """Time the fused Triton matmul over a grid of tile configurations.

    Every combination of BLOCK_M / BLOCK_N / BLOCK_K / GROUP_SIZE_M from the
    candidate lists below is written to the module-level tiling globals (which
    `matmul_add_relu_fp16` reads at call time), warmed up once and timed on a
    fixed 2048 x 2048 x 2048 fp16 problem. The speedup of each configuration
    is measured against the PyTorch reference timed on the same inputs.
    Per-configuration results, the ten fastest configurations and the overall
    best are printed. A configuration that raises is reported and skipped.

    The tiling globals are restored to the values they had on entry before
    the function returns (also when it is interrupted), so running the search
    does not change which configuration later code runs with.

    Returns:
        The best (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_SIZE_M) tuple, or None when
        no configuration ran successfully.
    """
    import itertools  # local import keeps this cell self-contained

    global BLOCK_M, BLOCK_N, BLOCK_K, GROUP_SIZE_M

    # Define parameter ranges that are likely to work well on a T4 GPU with FP16
    # T4 has tensor cores that work best with multiples of 16 for FP16 operations
    block_m_values = [64, 128, 256]
    block_n_values = [64, 128, 256]
    block_k_values = [16, 32, 64]
    group_size_m_values = [4, 8, 16, 32]

    # Configurations whose BLOCK_M * BLOCK_N * BLOCK_K exceeds this are skipped:
    # very large tiles can cause issues with shared memory or register usage.
    max_tile_volume = 1048576

    # Use fewer iterations than the full benchmark.
    repeats = 500

    M = 2048
    K = 2048
    N = 2048

    A = torch.randn((M, K), device="cuda", dtype=torch.float16)
    B = torch.randn((K, N), device="cuda", dtype=torch.float16)
    C = torch.randn((M, N), device="cuda", dtype=torch.float16)

    def _mean_seconds_per_call(fn):
        """Return the mean wall-clock seconds of fn(A, B, C) over `repeats` calls."""
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(repeats):
            fn(A, B, C)
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / repeats

    # Baseline: warm up once, then time the PyTorch reference.
    _ = reference_matmul_add_relu(A, B, C)
    torch_time = _mean_seconds_per_call(reference_matmul_add_relu)

    print(f"Reference PyTorch implementation: {torch_time*1000:.2f} ms")

    best_speedup = 0
    best_params = None
    results = []

    # Remember the caller's configuration so it can be restored afterwards.
    saved_config = (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_SIZE_M)

    # Start grid search
    print("\nStarting grid search...")
    try:
        # product() iterates the last list fastest, i.e. the same order as four
        # nested loops over M, N, K, group size.
        for block_m, block_n, block_k, group_size in itertools.product(
            block_m_values, block_n_values, block_k_values, group_size_m_values
        ):
            # Skip configurations that might be problematic
            if block_m * block_n * block_k > max_tile_volume:
                continue

            # Update global parameters read by matmul_add_relu_fp16.
            BLOCK_M = block_m
            BLOCK_N = block_n
            BLOCK_K = block_k
            GROUP_SIZE_M = group_size

            try:
                # Warmup (kept outside the timed loop).
                _ = matmul_add_relu_fp16(A, B, C)

                # Benchmark
                triton_time = _mean_seconds_per_call(matmul_add_relu_fp16)

                # Calculate speedup
                speedup = torch_time / triton_time

                # Print result
                print(f"BLOCK_M={block_m}, BLOCK_N={block_n}, BLOCK_K={block_k}, GROUP_SIZE_M={group_size}, "
                      f"Triton: {triton_time*1000:.2f} ms, Speedup: {speedup:.2f}x")

                # Store result
                results.append((block_m, block_n, block_k, group_size, triton_time, speedup))

                # Update best parameters (strict '>' keeps the first of equal bests).
                if speedup > best_speedup:
                    best_speedup = speedup
                    best_params = (block_m, block_n, block_k, group_size)
            except Exception as e:
                # Best effort: one failing configuration must not abort the search.
                print(f"Error with BLOCK_M={block_m}, BLOCK_N={block_n}, BLOCK_K={block_k}, GROUP_SIZE_M={group_size}: {e}")
    finally:
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_SIZE_M = saved_config

    # Sort results by speedup
    results.sort(key=lambda x: x[5], reverse=True)

    print("\nTop 10 Configurations:")
    for i, (block_m, block_n, block_k, group_size, triton_time, speedup) in enumerate(results[:10]):
        print(f"{i+1}. BLOCK_M={block_m}, BLOCK_N={block_n}, BLOCK_K={block_k}, GROUP_SIZE_M={group_size}, "
              f"Triton: {triton_time*1000:.2f} ms, Speedup: {speedup:.2f}x")

    # Every configuration failed: there is no best entry to index into.
    if best_params is None:
        print("\nNo configuration completed successfully.")
        return None

    print(f"\nBest Configuration: BLOCK_M={best_params[0]}, BLOCK_N={best_params[1]}, "
          f"BLOCK_K={best_params[2]}, GROUP_SIZE_M={best_params[3]}, Speedup={best_speedup:.2f}x")

    return best_params

# Run the search, then pin the tile configuration used by the rest of the
# notebook, independent of whatever configurations the search explored.
best_param = grid_search()
BLOCK_M = 128
BLOCK_N = 256
BLOCK_K = 16
GROUP_SIZE_M = 4

Reference PyTorch implementation: 0.99 ms

Starting grid search...
BLOCK_M=64, BLOCK_N=64, BLOCK_K=16, GROUP_SIZE_M=4, Triton: 1.40 ms, Speedup: 0.71x
BLOCK_M=64, BLOCK_N=64, BLOCK_K=16, GROUP_SIZE_M=8, Triton: 1.33 ms, Speedup: 0.75x
BLOCK_M=64, BLOCK_N=64, BLOCK_K=16, GROUP_SIZE_M=16, Triton: 1.45 ms, Speedup: 0.68x
BLOCK_M=64, BLOCK_N=64, BLOCK_K=16, GROUP_SIZE_M=32, Triton: 1.83 ms, Speedup: 0.54x
BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, GROUP_SIZE_M=4, Triton: 1.28 ms, Speedup: 0.77x
BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, GROUP_SIZE_M=8, Triton: 1.20 ms, Speedup: 0.83x
BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, GROUP_SIZE_M=16, Triton: 1.36 ms, Speedup: 0.73x
BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, GROUP_SIZE_M=32, Triton: 1.72 ms, Speedup: 0.58x
BLOCK_M=64, BLOCK_N=64, BLOCK_K=64, GROUP_SIZE_M=4, Triton: 1.17 ms, Speedup: 0.85x
BLOCK_M=64, BLOCK_N=64, BLOCK_K=64, GROUP_SIZE_M=8, Triton: 1.11 ms, Speedup: 0.89x
BLOCK_M=64, BLOCK_N=64, BLOCK_K=64, GROUP_SIZE_M=16, Triton: 1.18 ms, Speedup: 0.84x
BLOC