In [None]:
def weighted_sum(X, w):
    """Return the weighted sum of ``X`` along its last axis, weighted by ``w``.

    Multiplies ``X`` by ``w`` element-wise (using the array library's
    broadcasting rules) and reduces over the last axis. When ``X`` has shape
    ``(rows, D)`` and ``w`` has shape ``(D,)`` this equals the matrix-vector
    product ``X @ w``. It is NOT a general matmul when ``w`` is 2-D.

    Args:
        X: array-like (e.g. NumPy array or torch tensor) of shape ``(..., D)``.
        w: weights broadcastable against ``X``; typically shape ``(D,)``.

    Returns:
        Array of shape ``X.shape[:-1]`` (for a 1-D ``w``) holding the
        per-row weighted sums.
    """
    weighted = X * w
    return weighted.sum(axis=-1)

In [None]:
import triton
import triton.language as tl

"""
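In Triton, pointer offsets and strides are counted in elements, not bytes:
`ptr + 1` addresses the next element whatever the dtype's size (just like
arithmetic on a typed pointer in C, e.g. `float *`).

So offsets advance one element at a time: 0, 1, 2, ...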
"""

@triton.jit
def weighted_sum_kernel(
    X_ptr, w_ptr,    # Mat, Vec ptr
    out_ptr,         # Output pointer
    X_row_stride,    # Number of element to get to the nxt row in X
    X_stride_dim,    # Dimension of X's stride
    w_stride_dim,    # Dimesnion of w's stride
    out_stride_row,  # Number of element to get to the nxt row in output
    X_ROWS, X_D,     # Shape of the matrix X
    TILE_ROW, TILE_D # Shape of each tile
):
    # Each instance will compute the weighted sum of a (tile of rows) of x.
    # `tl.program_id` gives us a way to check which `thread block` we're running in
    row_tile_idx = tl.program_id(0)


    # Block pointers give us a way to select from an ND (N-dimensional) region of memory
    """
    The block pointer must know:
        - The pointer to the first element of the tensor
        - The overall shape of the tensor to handle out-of-bounds access
        - The strides of each dimension to use the memory layout properly
        - The ND coordinates of the starting block, i.e., "offsets"
        - The block shape to use load/store at a time
        - The order of the dimensions in memory from major to minor
    """
    # axes (= np.argsort(strides)) for optimizations, especially useful on H100
    