In [22]:
import torch
import triton
import triton.language as tl
import time

In [23]:
def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"

In [24]:
def is_hip_mi200():
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == 'hip' and target.arch == 'gfx90a'

In [26]:
"""
PA2 Part 2: MatMul+Relu+Add Fused Optimization.
The kernel uses several optimization techniques:

  1. Shared memory tiling.
  2. Register tiling.
  3. Cooperative fetching.
  4. Operator Fusion
  5. Write cache / epilogue fusion.

Fill in the missing parts (marked with TODO).
"""

# -----------------------------------------------------------------------------
# Tiling parameters - You will need to change these to achieve better results.
# -----------------------------------------------------------------------------
BLOCK_M = 128  # Tile size in the M dimension.
BLOCK_N = 256 # Tile size in the N dimension.
BLOCK_K = 32 # Tile size in the K dimension.


# -----------------------------------------------------------------------------
# Triton Kernel: Matrix Multiplication + ReLU + Add
#
# The kernel uses:
#   Step 1: Tile assignment (each kernel computes a tile of C)
#   Step 2: Shared memory tiling + Cooperative Fetching: Load tiles of A and B.
#   Step 3: Register tiling: Use a register accumulator.
#   Step 4: Add and ReLU fusion
#   Step 5: Write cache/Epilogue: Write the final tile back to global memory.
# -----------------------------------------------------------------------------
# @triton.autotune(
#     configs=get_cuda_autotune_config(),
#     key=['M', 'N', 'K'],
# )
@triton.jit
def matmul_add_relu_kernel_fp16(
    a_ptr, b_ptr, c_ptr, d_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    stride_am: tl.constexpr, stride_ak: tl.constexpr,
    stride_bk: tl.constexpr, stride_bn: tl.constexpr,
    stride_cm: tl.constexpr, stride_cn: tl.constexpr,
    stride_dm: tl.constexpr, stride_dn: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # -------------------------------------------------------------------------
    # Step 1: Tile: Assignment
    #
    # Each kernel instance is mapped to a tile in the output matrix C.
    # Compute the starting indices (m_start, n_start) for this tile.
    # -------------------------------------------------------------------------
    # TODO: Compute the tile indices using program_id(0) for M and program_id(1) for N.
    pid1 = tl.program_id(0)
    pid2 = tl.program_id(1)

    # -------------------------------------------------------------------------
    # Step 2: Register Tiling
    # -------------------------------------------------------------------------
    # TODO: Initialize the accumulator "acc" with zeros (dtype: float16).
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16)

    # -------------------------------------------------------------------------
    # Step 3: Shared Memory Tiling & Cooperative Fetching.
    # Compute pointers to the sub-tiles of A and B that are needed to compute
    # the current C tile. The offsets here serve to load BLOCK_M x BLOCK_K
    # and BLOCK_K x BLOCK_N blocks from A and B respectively.
    # -------------------------------------------------------------------------
    # TODO: Finish code below
    offs_m = pid1 * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid2 * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + (offs_m[:, None] * stride_am)
    b_ptrs = b_ptr + (offs_n[None, :] * stride_bn)


    mask_a = (offs_m[:, None] < M)
    mask_b = (offs_n[None, :] < N)

    for k in range(0, K // BLOCK_K):
        a = tl.load(a_ptrs + offs_k[None, :] * stride_ak, mask=mask_a & (offs_k[None, :] <  K - k * BLOCK_K), other=0.0)
        b = tl.load(b_ptrs + offs_k[:, None] * stride_bk, mask=mask_b & (offs_k[:, None] <  K - k * BLOCK_K), other=0.0)
        acc += tl.dot(a, b, out_dtype=tl.float16)

        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    # -------------------------------------------------------------------------
    # Step 4: Apply ReLU and Add C to the accumulator
    # -------------------------------------------------------------------------
    # TODO: Finish code below
    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    res = tl.maximum(0, acc + tl.load(c_ptrs, mask=mask, other=0.0))

    # -------------------------------------------------------------------------
    # Step 5: Write Cache / Epilogue Fusion: Write the computed tile to D.
    # -------------------------------------------------------------------------
    # TODO: Finish code below
    tl.store(d_ptr + (offs_m[:, None] * stride_dm + offs_n[None, :] * stride_dn), res, mask=mask)

In [27]:
def matmul_add_relu_fp16(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """
    Computes Output = ReLU(A @ B + C) using fp16 precision for maximum throughput.
    """
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "Incompatible dimensions"

    d = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # Create launch grid
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    matmul_add_relu_kernel_fp16[grid](
        a, b, c, d,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        d.stride(0), d.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K
    )
    return d

In [28]:
# Reference implementation using PyTorch
def reference_matmul_add_relu(A, B, C):
    result = torch.matmul(A, B).add(C).relu_()
    return result

In [29]:
# -----------------------------------------------------------------------------
# Accuracy Tests
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    sz = 512
    a = torch.randn((sz, sz), device=torch.device("cuda"), dtype=torch.float16)
    b = torch.randn((sz, sz), device=torch.device("cuda"), dtype=torch.float16)
    c = torch.randn((sz, sz), device=torch.device("cuda"), dtype=torch.float16)
    triton_output = matmul_add_relu_fp16(a, b, c)
    torch_output = reference_matmul_add_relu(a, b, c)
    print(f"triton_output_with_fp16_inputs={triton_output}")
    print(f"torch_output_with_fp16_inputs={torch_output}")
    rtol = 1e-2 if is_hip_mi200() else 0.032
    if torch.allclose(triton_output, torch_output, atol=0.15, rtol=rtol):
        print("✅ Triton and Torch match")
    else:
        diff = triton_output - torch_output
        abs_diff = torch.abs(diff)
        max_abs_diff = torch.max(abs_diff)
        print(f"❌ Triton and Torch differ: {max_abs_diff=}")

triton_output_with_fp16_inputs=tensor([[ 0.0000,  6.1250,  0.0000,  ..., 10.0625,  0.0000,  0.0000],
        [ 7.9102, 15.6328, 26.6094,  ..., 11.4609,  5.3750, 18.6250],
        [ 2.7246,  0.0000,  0.0000,  ...,  0.0000, 26.0781,  0.0000],
        ...,
        [ 0.4448, 75.1875,  0.0000,  ..., 26.2812,  0.0000,  0.0000],
        [ 6.9492,  1.1230,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [27.6094, 26.9531, 22.9219,  ..., 13.5391,  6.0508, 21.6250]],
       device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[ 0.0000,  6.1289,  0.0000,  ..., 10.0391,  0.0000,  0.0000],
        [ 7.9102, 15.6328, 26.6250,  ..., 11.4531,  5.3945, 18.6562],
        [ 2.7266,  0.0000,  0.0000,  ...,  0.0000, 26.1250,  0.0000],
        ...,
        [ 0.4316, 75.2500,  0.0000,  ..., 26.2812,  0.0000,  0.0000],
        [ 6.9570,  1.1260,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [27.6406, 26.9531, 22.9375,  ..., 13.5625,  6.0391, 21.6406]],
       device='cuda:0', dt

In [30]:
# -----------------------------------------------------------------------------
# Performance Benchmark
# IMPORTANT: DO NOT CHANGE THIS CODE.
# THIS IS THE EXACT CODE THAT WILL BE USED TO GRADE YOUR IMPLEMENTATION.
# ANY CHANGES TO THIS CODE (INCLUDING DIMENSIONS, REPEATS, etc.)
# WILL CAUSE YOU TO HAVE DIFFERENT SPEEDUP RESULTS.
# -----------------------------------------------------------------------------
M = 2048
K = 2048
N = 2048

# KEEP THESE MATRICES IN FP16. FP32 WILL NOT PROVIDE ACCURATE RESULTS
A = torch.randn((M, K), device="cuda", dtype=torch.float16)
B = torch.randn((K, N), device="cuda", dtype=torch.float16)
C = torch.randn((M, N), device="cuda", dtype=torch.float16)

# warmup
_ = matmul_add_relu_fp16(A, B, C)
_ = reference_matmul_add_relu(A, B, C)

REPEATS = 5000

# time your implementation
print("Triton implementation")
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(REPEATS):
    _ = matmul_add_relu_fp16(A, B, C)
torch.cuda.synchronize()
triton_time = (time.perf_counter() - start) / REPEATS

# time pytorch
print("PyTorch implementation")
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(REPEATS):
    _ = reference_matmul_add_relu(A, B, C)
torch.cuda.synchronize()
torch_time = (time.perf_counter() - start) / REPEATS

print(f"Performance comparison for matrix multiplication ({M}x{K} @ {K}x{N}):")
print(f"Triton implementation: {triton_time*1000:.2f} ms")
print(f"PyTorch implementation: {torch_time*1000:.2f} ms")

print(f"\nSpeedup of Triton vs PyTorch: {torch_time/triton_time:.2f}x")

Triton implementation
PyTorch implementation
Performance comparison for matrix multiplication (2048x2048 @ 2048x2048):
Triton implementation: 0.77 ms
PyTorch implementation: 1.02 ms

Speedup of Triton vs PyTorch: 1.34x


In [34]:
import random

# Initialize counters
passed_tests = 0
failed_tests = 0
failed_shapes = []

# Function to generate random dimensions
def generate_random_dimensions():
    M = random.randint(1, 2048)  # Rows in a
    K = random.randint(1, 2048)  # Columns in a and rows in b
    N = random.randint(1, 2048)  # Columns in b
    M, K, N = 512, 512, 512
    return M, K, N

# Run 100 tests
num = 100
for i in range(num):
    print("Test", i)
    # Generate random dimensions
    M, K, N = generate_random_dimensions()

    # Set a manual seed for reproducibility
    torch.manual_seed(i)

    # Generate random tensors with the generated dimensions
    s = 0.1
    a = torch.randn((M, K), device=torch.device("cuda"), dtype=torch.float16) * s
    b = torch.randn((K, N), device=torch.device("cuda"), dtype=torch.float16) * s
    c = torch.randn((M, N), device=torch.device("cuda"), dtype=torch.float16) * s

    # Compute outputs using Triton and reference implementations
    triton_output = matmul_add_relu_fp16(a, b, c)
    torch_output = reference_matmul_add_relu(a, b, c)

    # Define relative tolerance
    rtol = 1e-2 if is_hip_mi200() else 0.032

    # Check if outputs are close within the specified tolerance
    if torch.allclose(triton_output, torch_output, atol=0.15, rtol=rtol):
        passed_tests += 1  # Increment counter if the test passes
    else:
        failed_tests += 1  # Increment counter if the test fails
        failed_shapes.append((a.shape, b.shape, c.shape))  # Log the shapes

# After the loop, print the number of passed and failed tests
print(f"{passed_tests} out of {num} tests passed.")
print(f"{failed_tests} out of {num} tests failed.")

# Print the shapes of matrices that failed
if failed_tests > 0:
    print("Shapes of matrices that failed:")
    for idx, shapes in enumerate(failed_shapes):
        a_shape, b_shape, c_shape = shapes
        print(f"Test {idx + 1}:")
        print(f"  a.shape: {a_shape}")
        print(f"  b.shape: {b_shape}")
        print(f"  c.shape: {c_shape}")

Test 0
Test 1
Test 2
Test 3
Test 4
Test 5
Test 6
Test 7
Test 8
Test 9
Test 10
Test 11
Test 12
Test 13
Test 14
Test 15
Test 16
Test 17
Test 18
Test 19
Test 20
Test 21
Test 22
Test 23
Test 24
Test 25
Test 26
Test 27
Test 28
Test 29
Test 30
Test 31
Test 32
Test 33
Test 34
Test 35
Test 36
Test 37
Test 38
Test 39
Test 40
Test 41
Test 42
Test 43
Test 44
Test 45
Test 46
Test 47
Test 48
Test 49
Test 50
Test 51
Test 52
Test 53
Test 54
Test 55
Test 56
Test 57
Test 58
Test 59
Test 60
Test 61
Test 62
Test 63
Test 64
Test 65
Test 66
Test 67
Test 68
Test 69
Test 70
Test 71
Test 72
Test 73
Test 74
Test 75
Test 76
Test 77
Test 78
Test 79
Test 80
Test 81
Test 82
Test 83
Test 84
Test 85
Test 86
Test 87
Test 88
Test 89
Test 90
Test 91
Test 92
Test 93
Test 94
Test 95
Test 96
Test 97
Test 98
Test 99
100 out of 100 tests passed.
0 out of 100 tests failed.


In [32]:
# -----------------------------------------------------------------------------
# Performance Benchmark
# IMPORTANT: DO NOT CHANGE THIS CODE.
# THIS IS THE EXACT CODE THAT WILL BE USED TO GRADE YOUR IMPLEMENTATION.
# ANY CHANGES TO THIS CODE (INCLUDING DIMENSIONS, REPEATS, etc.)
# WILL CAUSE YOU TO HAVE DIFFERENT SPEEDUP RESULTS.
# -----------------------------------------------------------------------------
M = 2048
K = 2048
N = 2048

# KEEP THESE MATRICES IN FP16. FP32 WILL NOT PROVIDE ACCURATE RESULTS
A = torch.randn((M, K), device="cuda", dtype=torch.float16)
B = torch.randn((K, N), device="cuda", dtype=torch.float16)
C = torch.randn((M, N), device="cuda", dtype=torch.float16)

# warmup
_ = matmul_add_relu_fp16(A, B, C)
_ = reference_matmul_add_relu(A, B, C)

REPEATS = 5000

# -----------------------------------------------------------------------------
# Grid Search for Best Parameters
# -----------------------------------------------------------------------------
best_time = float('inf')
best_block_M = 0
best_block_N = 0
best_block_K = 0

for block_M in [32, 64, 128, 512]:
    for block_N in [32, 64, 128, 512]:
        for block_K in [32, 64, 128, 512]:
            if block_M*block_N*block_K > 128*128*32:
                continue
            # Print current parameters being tested
            print(f"Testing: BLOCK_M={block_M}, BLOCK_N={block_N}, BLOCK_K={block_K}")

            # Time your Triton implementation
            print("Triton implementation")
            torch.cuda.synchronize()
            start = time.perf_counter()
            for _ in range(REPEATS):
                _ = matmul_add_relu_fp16(A, B, C)
            torch.cuda.synchronize()
            triton_time = (time.perf_counter() - start) / REPEATS

            # Time PyTorch implementation
            print("PyTorch implementation")
            torch.cuda.synchronize()
            start = time.perf_counter()
            for _ in range(REPEATS):
                _ = reference_matmul_add_relu(A, B, C)
            torch.cuda.synchronize()
            torch_time = (time.perf_counter() - start) / REPEATS

            # Compare performance and store best time
            speedup = torch_time / triton_time
            print(f"Speedup of Triton vs PyTorch: {speedup:.2f}x")

            if triton_time < best_time:
                best_time = triton_time
                best_block_M = block_M
                best_block_N = block_N
                best_block_K = block_K

# Print the best performing block sizes
print(f"\nBest parameters:")
print(f"BLOCK_M={best_block_M}, BLOCK_N={best_block_N}, BLOCK_K={best_block_K}")
print(f"Best Triton time: {best_time*1000:.2f} ms")

Testing: BLOCK_M=32, BLOCK_N=32, BLOCK_K=32
Triton implementation
PyTorch implementation
Speedup of Triton vs PyTorch: 1.35x
Testing: BLOCK_M=32, BLOCK_N=32, BLOCK_K=64
Triton implementation
PyTorch implementation
Speedup of Triton vs PyTorch: 1.34x
Testing: BLOCK_M=32, BLOCK_N=32, BLOCK_K=128
Triton implementation
PyTorch implementation
Speedup of Triton vs PyTorch: 1.34x
Testing: BLOCK_M=32, BLOCK_N=32, BLOCK_K=512
Triton implementation
PyTorch implementation
Speedup of Triton vs PyTorch: 1.32x
Testing: BLOCK_M=32, BLOCK_N=64, BLOCK_K=32
Triton implementation
PyTorch implementation
Speedup of Triton vs PyTorch: 1.30x
Testing: BLOCK_M=32, BLOCK_N=64, BLOCK_K=64
Triton implementation
PyTorch implementation
Speedup of Triton vs PyTorch: 1.31x
Testing: BLOCK_M=32, BLOCK_N=64, BLOCK_K=128
Triton implementation
PyTorch implementation
Speedup of Triton vs PyTorch: 1.32x
Testing: BLOCK_M=32, BLOCK_N=128, BLOCK_K=32
Triton implementation
PyTorch implementation
Speedup of Triton vs PyTorch: 1.