In [1]:
from IPython.display import clear_output
!pip install --upgrade triton
clear_output()

In [2]:
import triton
import triton.language as tl
import torch

In [6]:
@triton.jit
def __add__(x_ptr, y_ptr, out_ptr, size_x, size_y, BLOCK_SIZE: tl.constexpr):
    """Element-wise sum of two matrices, one BLOCK_SIZE x BLOCK_SIZE tile per program.

    Computes ``out[r, c] = x[r, c] + y[r, c]`` for a matrix with ``size_y`` rows
    and ``size_x`` columns. Elements are addressed as ``r * size_x + c``, i.e.
    row-major storage with a row stride of ``size_x``.

    Launch contract: grid axis 0 tiles the rows (bounded by ``size_y``) and
    grid axis 1 tiles the columns (bounded by ``size_x``), so the grid should be
    ``(cdiv(size_y, BLOCK_SIZE), cdiv(size_x, BLOCK_SIZE))``. Lanes of edge
    tiles that fall outside the matrix are masked for both loads and the store.
    """
    # Origin of this program's tile along each axis.
    row_start = tl.program_id(0) * BLOCK_SIZE
    col_start = tl.program_id(1) * BLOCK_SIZE

    # (BLOCK_SIZE, 1) row ids and (1, BLOCK_SIZE) column ids; they broadcast to a tile.
    rows = (row_start + tl.arange(0, BLOCK_SIZE))[:, None]
    cols = (col_start + tl.arange(0, BLOCK_SIZE))[None, :]

    # Only lanes that lie inside the matrix may touch memory.
    in_bounds = (rows < size_y) & (cols < size_x)

    offsets = rows * size_x + cols
    lhs = tl.load(x_ptr + offsets, mask=in_bounds, other=0.0)
    rhs = tl.load(y_ptr + offsets, mask=in_bounds, other=0.0)
    tl.store(out_ptr + offsets, lhs + rhs, mask=in_bounds)


In [10]:
def test_addMatrix(sizeX=8, sizeY=8, BLOCK_SIZE=2, verbose=True):
    """Check the Triton ``__add__`` kernel against PyTorch's ``x + y`` on CUDA.

    Args:
        sizeX: number of columns of the test matrices.
        sizeY: number of rows of the test matrices.
        BLOCK_SIZE: tile edge length passed to the kernel as a constexpr.
        verbose: if True, print both inputs, the Triton result and the reference.

    Raises:
        AssertionError: if the Triton result does not match the PyTorch result.
    """
    x = torch.randn(sizeY, sizeX, device='cuda', dtype=torch.float32)
    y = torch.randn(sizeY, sizeX, device='cuda', dtype=torch.float32)
    out = torch.zeros_like(x)

    # The kernel addresses elements as `row * sizeX + col`, so hand it flat
    # buffers of the (contiguous) matrices.
    x_flat = x.flatten()
    y_flat = y.flatten()
    out_flat = out.flatten()

    # Grid axis 0 must tile rows and axis 1 columns, matching the kernel's use
    # of program_id(0) for rows (vs. sizeY) and program_id(1) for columns
    # (vs. sizeX). The previous (sizeX, sizeY) order only worked for square
    # matrices; otherwise some columns were never computed.
    grid = (triton.cdiv(sizeY, BLOCK_SIZE), triton.cdiv(sizeX, BLOCK_SIZE))
    __add__[grid](x_flat, y_flat, out_flat, sizeX, sizeY, BLOCK_SIZE=BLOCK_SIZE)

    out = out_flat.reshape(sizeY, sizeX)

    expected = x + y
    if verbose:
        print("Matrix A:\n", x)
        print("Matrix B:\n", y)
        print("Matrix C (Triton):\n", out)
        print("Expected (PyTorch):\n", expected)
    assert torch.allclose(out, expected), "Triton result does not match PyTorch result!"

test_addMatrix()
# Non-square shapes, and sizes not divisible by BLOCK_SIZE, exercise the
# row/column grid mapping and the edge masking, which the square case cannot.
test_addMatrix(sizeX=8, sizeY=4, verbose=False)
test_addMatrix(sizeX=5, sizeY=11, BLOCK_SIZE=4, verbose=False)

Matrix A:
 tensor([[-0.8801, -0.2961,  1.4796, -0.5410,  0.2491, -0.4705,  1.9539, -0.9597],
        [-2.0555,  1.2980, -0.9776,  0.8391, -0.0563,  0.6079, -0.4813, -0.7989],
        [-1.3494, -0.4939, -0.6455,  0.4303, -0.5407, -0.2800, -0.3566,  0.9767],
        [ 0.8018,  0.1906,  0.9631, -0.0593,  0.0785, -1.8621, -0.8929, -1.3323],
        [ 0.8464, -0.1313,  0.2787, -1.1805, -0.7900, -1.1620,  0.2588,  0.4047],
        [-1.3724, -1.1831,  1.7337, -0.2776,  0.3643,  2.0978,  0.7107, -0.3782],
        [-0.0148, -1.2617,  1.1113, -0.6828, -0.2707,  0.3632,  0.4519, -1.1577],
        [ 0.8871,  1.0357, -0.8301, -0.7055, -1.1847, -0.3851, -0.2040, -0.3510]],
       device='cuda:0')
Matrix B:
 tensor([[-0.0107,  1.4539,  0.7185, -0.2708, -0.3393,  0.5407,  0.2321,  0.9470],
        [ 1.2506, -0.0601, -1.9298,  0.8016, -0.2933, -0.4981,  0.5054,  1.9880],
        [-1.4856, -0.2434,  0.4407,  0.1842,  1.8926,  2.1400, -0.0332, -0.9954],
        [ 0.7496, -0.3823, -0.1183, -0.3345, -0.651