<a href="https://colab.research.google.com/github/Sudhansh6/cse234-pa2/blob/main/matmul_triton_sara.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import torch
import triton
import triton.language as tl
import time

In [15]:
def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"

In [16]:
def is_hip_mi200():
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == 'hip' and target.arch == 'gfx90a'

In [17]:
"""
PA2 Part 2: MatMul+Relu+Add Fused Optimization.
The kernel uses several optimization techniques:

  1. Shared memory tiling.
  2. Register tiling.
  3. Cooperative fetching.
  4. Operator Fusion
  5. Write cache / epilogue fusion.

Fill in the missing parts (marked with TODO).
"""

# -----------------------------------------------------------------------------
# Tiling parameters - You will need to change these to achieve better results.
# -----------------------------------------------------------------------------
BLOCK_M = 128  # Tile size in the M dimension.
BLOCK_N = 256 # Tile size in the N dimension.
BLOCK_K = 32 # Tile size in the K dimension.


# -----------------------------------------------------------------------------
# Triton Kernel: Matrix Multiplication + ReLU + Add
#
# The kernel uses:
#   Step 1: Tile assignment (each kernel computes a tile of C)
#   Step 2: Shared memory tiling + Cooperative Fetching: Load tiles of A and B.
#   Step 3: Register tiling: Use a register accumulator.
#   Step 4: Add and ReLU fusion
#   Step 5: Write cache/Epilogue: Write the final tile back to global memory.
# -----------------------------------------------------------------------------
@triton.jit
def matmul_add_relu_kernel_fp16(
    a_ptr, b_ptr, c_ptr, d_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    stride_am: tl.constexpr, stride_ak: tl.constexpr,
    stride_bk: tl.constexpr, stride_bn: tl.constexpr,
    stride_cm: tl.constexpr, stride_cn: tl.constexpr,
    stride_dm: tl.constexpr, stride_dn: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # -------------------------------------------------------------------------
    # Step 1: Tile: Assignment
    #
    # Each kernel instance is mapped to a tile in the output matrix C.
    # Compute the starting indices (m_start, n_start) for this tile.
    # -------------------------------------------------------------------------
    # TODO: Compute the tile indices using program_id(0) for M and program_id(1) for N.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    m_start = pid_m * BLOCK_M
    n_start = pid_n * BLOCK_N

    # -------------------------------------------------------------------------
    # Step 2: Register Tiling
    # -------------------------------------------------------------------------
    # TODO: Initialize the accumulator "acc" with zeros (dtype: float16).
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16)

    # -------------------------------------------------------------------------
    # Step 3: Shared Memory Tiling & Cooperative Fetching.
    # Compute pointers to the sub-tiles of A and B that are needed to compute
    # the current C tile. The offsets here serve to load BLOCK_SIZE_M x BLOCK_SIZE_K
    # and BLOCK_SIZE_K x BLOCK_SIZE_N blocks from A and B respectively.
    # -------------------------------------------------------------------------
    # TODO: Finish code below
    offs_am = m_start + tl.arange(0, BLOCK_M)[:, None]
    offs_bn = n_start + tl.arange(0, BLOCK_N)[None, :]
    for k in range(0, K, BLOCK_K):

        offs_k = k + tl.arange(0, BLOCK_K)

        a_ptrs = a_ptr + (offs_am * stride_am + offs_k * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn * stride_bn)

        a_tile = tl.load(a_ptrs, mask=(offs_am < M) & (offs_k < K), other=0.0)
        b_tile = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_bn < N), other=0.0)

        acc += tl.dot(a_tile, b_tile, out_dtype=tl.float16)
    # -------------------------------------------------------------------------
    # Step 4: Apply ReLU and Add C to the accumulator
    # -------------------------------------------------------------------------
    # TODO: Finish code below
    offs_cm = m_start + tl.arange(0, BLOCK_M)[:, None]
    offs_cn = n_start + tl.arange(0, BLOCK_N)[None, :]
    c_ptrs = c_ptr + stride_cm * offs_cm + stride_cn * offs_cn

    acc = tl.maximum(acc + tl.load(c_ptrs, mask=(offs_cm < M) & (offs_cn < N), other=0.0), 0)

    # -------------------------------------------------------------------------
    # Step 5: Write Cache / Epilogue Fusion: Write the computed tile to D.
    # -------------------------------------------------------------------------
    # TODO: Finish code below
    d_ptrs = d_ptr + stride_dm * offs_cm + stride_dn * offs_cn
    tl.store(d_ptrs, acc, mask=(offs_cm < M) & (offs_cn < N))

In [18]:
def matmul_add_relu_fp16(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """
    Computes Output = ReLU(A @ B + C) using fp16 precision for maximum throughput.
    """
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "Incompatible dimensions"

    d = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # Create launch grid
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    matmul_add_relu_kernel_fp16[grid](
        a, b, c, d,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        d.stride(0), d.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K
    )
    return d

In [19]:
# Reference implementation using PyTorch
def reference_matmul_add_relu(A, B, C):
    result = torch.matmul(A, B).add(C).relu_()
    return result

In [20]:
# # -----------------------------------------------------------------------------
# # Accuracy Tests
# # -----------------------------------------------------------------------------
# if __name__ == "__main__":
#     torch.manual_seed(0)
#     a = torch.randn((1139, 3213), device=torch.device("cuda"), dtype=torch.float16)
#     b = torch.randn((3213, 4215), device=torch.device("cuda"), dtype=torch.float16)
#     c = torch.randn((1139, 4215), device=torch.device("cuda"), dtype=torch.float16)
#     triton_output = matmul_add_relu_fp16(a, b, c)
#     torch_output = reference_matmul_add_relu(a, b, c)
#     print(f"triton_output_with_fp16_inputs={triton_output}")
#     print(f"torch_output_with_fp16_inputs={torch_output}")
#     rtol = 1e-2 if is_hip_mi200() else 0.032
#     if torch.allclose(triton_output, torch_output, atol=0.15, rtol=rtol):
#         print("✅ Triton and Torch match")
#     else:
#         diff = triton_output - torch_output
#         abs_diff = torch.abs(diff)
#         max_abs_diff = torch.max(abs_diff)
#         print(f"❌ Triton and Torch differ: {max_abs_diff=}")

In [21]:
if __name__ == "__main__":
    torch.manual_seed(0)

    # Normalize input values to prevent fp16 overflows
    a = (torch.randn((1139, 3213), device="cuda", dtype=torch.float16) * 0.1)
    b = (torch.randn((3213, 4215), device="cuda", dtype=torch.float16) * 0.1)
    c = (torch.randn((1139, 4215), device="cuda", dtype=torch.float16) * 0.1)

    # Compute outputs
    triton_output = matmul_add_relu_fp16(a, b, c)
    torch_output = reference_matmul_add_relu(a, b, c)

    # Print outputs for debugging
    print(f"Triton Output:\n {triton_output}")
    print(f"Torch Output:\n {torch_output}")

    rtol = 1e-2 if is_hip_mi200() else 0.032
    if torch.allclose(triton_output, torch_output, atol=0.15, rtol=rtol):
        print("✅ Triton and Torch match")
    else:
        diff = triton_output - torch_output
        abs_diff = torch.abs(diff)
        max_abs_diff = torch.max(abs_diff)
        print(f"❌ Triton and Torch differ: {max_abs_diff=}")

Triton Output:
 tensor([[0.2659, 0.0644, 1.0361,  ..., 0.6006, 0.0045, 0.0000],
        [0.6382, 0.0000, 1.1904,  ..., 0.2001, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.1382, 0.0000, 0.3589],
        ...,
        [0.0000, 0.1774, 0.0000,  ..., 0.2593, 0.0639, 0.3738],
        [0.2800, 0.0000, 0.5938,  ..., 0.0000, 0.5625, 0.6572],
        [0.5381, 0.0000, 0.4954,  ..., 0.0000, 0.8140, 0.0000]],
       device='cuda:0', dtype=torch.float16)
Torch Output:
 tensor([[0.2668, 0.0653, 1.0322,  ..., 0.6006, 0.0046, 0.0000],
        [0.6387, 0.0000, 1.1836,  ..., 0.2008, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.1393, 0.0000, 0.3564],
        ...,
        [0.0000, 0.1783, 0.0000,  ..., 0.2583, 0.0638, 0.3735],
        [0.2795, 0.0000, 0.5938,  ..., 0.0000, 0.5625, 0.6558],
        [0.5386, 0.0000, 0.4944,  ..., 0.0000, 0.8115, 0.0000]],
       device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match


In [22]:
# -----------------------------------------------------------------------------
# Performance Benchmark
# IMPORTANT: DO NOT CHANGE THIS CODE.
# THIS IS THE EXACT CODE THAT WILL BE USED TO GRADE YOUR IMPLEMENTATION.
# ANY CHANGES TO THIS CODE (INCLUDING DIMENSIONS, REPEATS, etc.)
# WILL CAUSE YOU TO HAVE DIFFERENT SPEEDUP RESULTS.
# -----------------------------------------------------------------------------
M = 2048
K = 2048
N = 2048

# KEEP THESE MATRICES IN FP16. FP32 WILL NOT PROVIDE ACCURATE RESULTS
A = torch.randn((M, K), device="cuda", dtype=torch.float16)
B = torch.randn((K, N), device="cuda", dtype=torch.float16)
C = torch.randn((M, N), device="cuda", dtype=torch.float16)

# warmup
_ = matmul_add_relu_fp16(A, B, C)
_ = reference_matmul_add_relu(A, B, C)

REPEATS = 5000

# time your implementation
print("Triton implementation")
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(REPEATS):
    _ = matmul_add_relu_fp16(A, B, C)
torch.cuda.synchronize()
triton_time = (time.perf_counter() - start) / REPEATS

# time pytorch
print("PyTorch implementation")
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(REPEATS):
    _ = reference_matmul_add_relu(A, B, C)
torch.cuda.synchronize()
torch_time = (time.perf_counter() - start) / REPEATS

print(f"Performance comparison for matrix multiplication ({M}x{K} @ {K}x{N}):")
print(f"Triton implementation: {triton_time*1000:.2f} ms")
print(f"PyTorch implementation: {torch_time*1000:.2f} ms")

print(f"\nSpeedup of Triton vs PyTorch: {torch_time/triton_time:.2f}x")

Triton implementation
PyTorch implementation
Performance comparison for matrix multiplication (2048x2048 @ 2048x2048):
Triton implementation: 0.75 ms
PyTorch implementation: 1.04 ms

Speedup of Triton vs PyTorch: 1.39x


In [23]:
import random

# Initialize counters
passed_tests = 0
failed_tests = 0
failed_shapes = []

# Function to generate random dimensions
def generate_random_dimensions():
    M = random.randint(1, 2048)  # Rows in a
    K = random.randint(1, 2048)  # Columns in a and rows in b
    N = random.randint(1, 2048)  # Columns in b
    return M, K, N

# Run 100 tests
for i in range(50):
    print("Test", i)
    # Generate random dimensions
    M, K, N = generate_random_dimensions()

    # Set a manual seed for reproducibility
    torch.manual_seed(i)

    # Generate random tensors with the generated dimensions
    s = 0.1
    a = torch.randn((M, K), device=torch.device("cuda"), dtype=torch.float16) * s
    b = torch.randn((K, N), device=torch.device("cuda"), dtype=torch.float16) * s
    c = torch.randn((M, N), device=torch.device("cuda"), dtype=torch.float16) * s

    # Compute outputs using Triton and reference implementations
    triton_output = matmul_add_relu_fp16(a, b, c)
    torch_output = reference_matmul_add_relu(a, b, c)

    # Define relative tolerance
    rtol = 1e-2 if is_hip_mi200() else 0.032

    # Check if outputs are close within the specified tolerance
    if torch.allclose(triton_output, torch_output, atol=0.15, rtol=rtol):
        passed_tests += 1  # Increment counter if the test passes
    else:
        failed_tests += 1  # Increment counter if the test fails
        failed_shapes.append((a.shape, b.shape, c.shape))  # Log the shapes

# After the loop, print the number of passed and failed tests
print(f"{passed_tests} out of 50 tests passed.")
print(f"{failed_tests} out of 50 tests failed.")

# Print the shapes of matrices that failed
if failed_tests > 0:
    print("Shapes of matrices that failed:")
    for idx, shapes in enumerate(failed_shapes):
        a_shape, b_shape, c_shape = shapes
        print(f"Test {idx + 1}:")
        print(f"  a.shape: {a_shape}")
        print(f"  b.shape: {b_shape}")
        print(f"  c.shape: {c_shape}")

Test 0
Test 1
Test 2
Test 3
Test 4
Test 5
Test 6
Test 7
Test 8
Test 9
Test 10
Test 11
Test 12
Test 13
Test 14
Test 15
Test 16
Test 17
Test 18
Test 19
Test 20
Test 21
Test 22
Test 23
Test 24
Test 25
Test 26
Test 27
Test 28
Test 29
Test 30
Test 31
Test 32
Test 33
Test 34
Test 35
Test 36
Test 37
Test 38
Test 39
Test 40
Test 41
Test 42
Test 43
Test 44
Test 45
Test 46
Test 47
Test 48
Test 49
50 out of 50 tests passed.
0 out of 50 tests failed.


In [24]:
#GRID SEARCH


BLOCK_M_VALUES = [64, 128]
BLOCK_N_VALUES = [128, 256]
BLOCK_K_VALUES = [32, 64]

def is_safe_config(BLOCK_M, BLOCK_N, BLOCK_K):
    shared_mem_size = (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N) * 2
    return shared_mem_size <= (48 * 1024)

def benchmark_triton(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, repeats=5000):
    A = torch.randn((M, K), device="cuda", dtype=torch.float16)
    B = torch.randn((K, N), device="cuda", dtype=torch.float16)
    C = torch.randn((M, N), device="cuda", dtype=torch.float16)
    D = torch.empty((M, N), device="cuda", dtype=torch.float16)

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    for _ in range(50):
        matmul_add_relu_kernel_fp16[grid](
            A, B, C, D,
            M, N, K,
            A.stride(0), A.stride(1),
            B.stride(0), B.stride(1),
            C.stride(0), C.stride(1),
            D.stride(0), D.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K
        )

    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        matmul_add_relu_kernel_fp16[grid](
            A, B, C, D,
            M, N, K,
            A.stride(0), A.stride(1),
            B.stride(0), B.stride(1),
            C.stride(0), C.stride(1),
            D.stride(0), D.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K
        )
    torch.cuda.synchronize()
    elapsed_time = (time.perf_counter() - start) / repeats
    return elapsed_time

M, N, K = 2048, 2048, 2048
best_time = float("inf")
best_config = None

for BLOCK_M in BLOCK_M_VALUES:
    for BLOCK_N in BLOCK_N_VALUES:
        for BLOCK_K in BLOCK_K_VALUES:
            if is_safe_config(BLOCK_M, BLOCK_N, BLOCK_K):
                exec_time = benchmark_triton(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K)
                print(f"BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}, BLOCK_K={BLOCK_K} -> Time: {exec_time:.4f} ms")
                if exec_time < best_time:
                    best_time = exec_time
                    best_config = (BLOCK_M, BLOCK_N, BLOCK_K)

print(f"Best Config: BLOCK_M={best_config[0]}, BLOCK_N={best_config[1]}, BLOCK_K={best_config[2]} with time {best_time:.4f} ms")


BLOCK_M=64, BLOCK_N=128, BLOCK_K=32 -> Time: 0.0012 ms
BLOCK_M=64, BLOCK_N=128, BLOCK_K=64 -> Time: 0.0012 ms
BLOCK_M=64, BLOCK_N=256, BLOCK_K=32 -> Time: 0.0011 ms
BLOCK_M=64, BLOCK_N=256, BLOCK_K=64 -> Time: 0.0013 ms
BLOCK_M=128, BLOCK_N=128, BLOCK_K=32 -> Time: 0.0011 ms
BLOCK_M=128, BLOCK_N=128, BLOCK_K=64 -> Time: 0.0009 ms
BLOCK_M=128, BLOCK_N=256, BLOCK_K=32 -> Time: 0.0008 ms
BLOCK_M=128, BLOCK_N=256, BLOCK_K=64 -> Time: 0.0009 ms
Best Config: BLOCK_M=128, BLOCK_N=256, BLOCK_K=32 with time 0.0008 ms
