In [1]:
# Installing needed package (triton)
! pip install triton



In [2]:
# Importing needed libraries
import torch
import torch.nn.functional as F
import triton
import triton.language as tl

In [3]:
# Global configuration: target device, problem size, filter, padding, tile size.
DEVICE = "cuda"

# Output feature maps and input channels.
C_OUT, C_IN = 64, 3

# Input spatial size (rows, columns).
H = W = 1024

# Filter spatial size and zero padding applied to each border.
FH = FW = 3
P = 1

# Output tile computed by a single Triton program instance.
BLOCK_H = BLOCK_W = 16

In [4]:
# Input and weight tensors for the convolution.
# Seed first so Restart & Run All regenerates identical data (torch.manual_seed
# seeds the CUDA generators as well).
torch.manual_seed(0)
tensor_I = torch.rand(1, C_IN, H, W, device=DEVICE)  # input (1, C_IN, H, W): batch size 1
tensor_F = torch.rand(C_OUT, C_IN, FH, FW, device=DEVICE)  # weights (C_OUT, C_IN, FH, FW)

In [5]:
# Reference result from PyTorch's convolution; used for the correctness check.
# cuDNN may run float32 convolutions in TF32 on Ampere+ GPUs (allowed by default),
# which is much less precise than the fp32 Triton kernel and can exceed
# assert_close's default float32 tolerances. Disable TF32 for this call only
# and restore the previous setting afterwards.
_prev_allow_tf32 = torch.backends.cudnn.allow_tf32
torch.backends.cudnn.allow_tf32 = False
try:
    golden_out = F.conv2d(tensor_I, tensor_F, padding=P)
finally:
    torch.backends.cudnn.allow_tf32 = _prev_allow_tf32
print(golden_out.shape)  # (1, C_OUT, OUT_H, OUT_W)

torch.Size([1, 64, 1024, 1024])


In [6]:
@triton.jit
def my_triton_kernel(
    input_ptr, kernel_ptr, output_ptr,
    H, W, C_IN, C_OUT, FH, FW, OUT_H, OUT_W,
    stride_h, stride_w, padding_h, padding_w,
    BLOCK_H: tl.constexpr, BLOCK_W: tl.constexpr
):
    """Direct 2-D convolution (cross-correlation, as in F.conv2d) of one image.

    Each program instance computes one BLOCK_H x BLOCK_W tile of one output
    channel:
        program_id(0) -> tile index along output rows
        program_id(1) -> tile index along output columns
        program_id(2) -> output channel `co`

    The addressing below assumes contiguous buffers laid out as
        input  : (C_IN, H, W)
        weights: (C_OUT, C_IN, FH, FW)
        output : (C_OUT, OUT_H, OUT_W)
    with no batch offset, so the caller passes pointers to a single image.
    Input taps outside [0, H) x [0, W) are read as 0.0 (zero padding), and
    output positions outside [0, OUT_H) x [0, OUT_W) are not written.
    Accumulation is done in float32.
    """
    tile_h = tl.program_id(0)
    tile_w = tl.program_id(1)
    co = tl.program_id(2)

    # Output coordinates covered by this tile.
    h_offsets = tile_h * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = tile_w * BLOCK_W + tl.arange(0, BLOCK_W)

    # Input coordinate of the top-left filter tap for every output position.
    # Multiplying by the stride here makes stride_h / stride_w take effect
    # (they were previously accepted but ignored).
    h_start = h_offsets * stride_h - padding_h
    w_start = w_offsets * stride_w - padding_w

    acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)

    # Offset of this output channel's filters in the weight buffer.
    filt_co_base = co * C_IN * FH * FW

    # Tap loops are outermost so the bounds masks and spatial offsets, which do
    # not depend on the input channel, are computed once per tap rather than
    # once per (tap, channel).
    for fh in range(FH):
        h_in = h_start + fh
        mask_h = (h_in >= 0) & (h_in < H)
        for fw in range(FW):
            w_in = w_start + fw
            mask_w = (w_in >= 0) & (w_in < W)
            mask = mask_h[:, None] & mask_w[None, :]
            spatial_idx = h_in[:, None] * W + w_in[None, :]
            for ci in range(0, C_IN):
                # Masked-out (padding) taps contribute 0.0.
                input_val = tl.load(
                    input_ptr + ci * H * W + spatial_idx, mask=mask, other=0.0
                )
                kernel_val = tl.load(
                    kernel_ptr + filt_co_base + ci * FH * FW + fh * FW + fw
                )
                acc += input_val * kernel_val

    # Write back only the in-range part of the tile.
    output_idx = (
        co * OUT_H * OUT_W
        + h_offsets[:, None] * OUT_W
        + w_offsets[None, :]
    )
    mask_h_out = h_offsets < OUT_H
    mask_w_out = w_offsets < OUT_W
    mask_out = mask_h_out[:, None] & mask_w_out[None, :]

    tl.store(output_ptr + output_idx, acc, mask=mask_out)

def my_conv2d(input, kernel, padding=1):
    """Run `my_triton_kernel` as a stride-1, zero-padded 2-D convolution and time it.

    Args:
        input: tensor of shape (B, C_IN, H, W) on a CUDA device.
        kernel: weight tensor of shape (C_OUT, C_IN, FH, FW) on the same device.
        padding: zero padding applied to every spatial border. The default of 1
            matches the value that was previously hard-coded.

    Returns:
        (output, execution_time). `output` has shape (B, C_OUT, OUT_H, OUT_W)
        and input's dtype. `execution_time` is the GPU time in milliseconds of
        one timed pass over the whole batch, measured with CUDA events after a
        warm-up pass (the warm-up also absorbs Triton JIT compilation).

    Raises:
        ValueError: if the tensors are not on the same CUDA device, the channel
            counts disagree, or the padded input is smaller than the filter.
    """
    if input.device != kernel.device or input.device.type != "cuda":
        raise ValueError("my_conv2d expects input and kernel on the same CUDA device")

    B, C_IN, H, W = input.shape
    C_OUT, kernel_c_in, FH, FW = kernel.shape
    if kernel_c_in != C_IN:
        raise ValueError(
            f"channel mismatch: input has {C_IN} channels, kernel expects {kernel_c_in}"
        )

    # Standard conv output size for stride 1 (the launch below always uses
    # stride 1). This equals (H, W) only when padding == (F - 1) / 2.
    OUT_H = H + 2 * padding - FH + 1
    OUT_W = W + 2 * padding - FW + 1
    if OUT_H <= 0 or OUT_W <= 0:
        raise ValueError(
            f"padded input ({H}+2*{padding}, {W}+2*{padding}) is smaller than filter ({FH}, {FW})"
        )

    # The kernel addresses raw memory assuming a contiguous layout;
    # .contiguous() is a no-op for tensors that already are.
    input_c = input.contiguous()
    kernel_c = kernel.contiguous()
    output = torch.empty((B, C_OUT, OUT_H, OUT_W), device=input.device, dtype=input.dtype)

    grid = (
        triton.cdiv(OUT_H, BLOCK_H),
        triton.cdiv(OUT_W, BLOCK_W),
        C_OUT,
    )

    def _launch():
        # The kernel has no batch offset, so launch it once per image:
        # input_c[b] and output[b] are contiguous views starting at image b.
        for b in range(B):
            my_triton_kernel[grid](
                input_c[b], kernel_c, output[b],
                H, W, C_IN, C_OUT, FH, FW, OUT_H, OUT_W,
                1, 1, padding, padding,
                BLOCK_H=BLOCK_H, BLOCK_W=BLOCK_W,
            )

    # Warm-up (triggers compilation), then wait so it is not timed.
    _launch()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    _launch()
    # Record the end marker on the stream *before* waiting on the host;
    # recording it after a synchronize would add host latency to the
    # measurement and could query an event that has not completed yet.
    end_event.record()
    end_event.synchronize()

    execution_time = start_event.elapsed_time(end_event)  # milliseconds
    return output, execution_time

In [8]:
# Testing: compare my_conv2d against the PyTorch reference convolution.
my_output, execution_time = my_conv2d(tensor_I, tensor_F)
# assert_close(actual, expected): the Triton result is the value under test,
# golden_out is the expected reference (affects the failure message labels).
torch.testing.assert_close(my_output, golden_out)
print(f"Execution Time for triton kernel: {execution_time:.3f} ms")

Execution Time for triton kernel:  4.057 ms
