In [None]:
def weighted_sum(X, w):
    """
    Reference implementation of the matrix-vector product X @ w.

    Multiplies X by w elementwise (with the usual broadcasting) and then
    reduces the products over the last axis.
    """
    products = X * w
    return products.sum(axis=-1)

In [None]:
import triton
import triton.language as tl

"""
In Triton, pointer offsets and strides are counted in elements, not in bytes
(as you would with raw byte offsets in C).

That is, we step through memory as element 0, element 1, element 2, ...
"""

@triton.jit
def weighted_sum_fwd(
    X_ptr, w_ptr,    # Pointers to the matrix X and the vector w
    out_ptr,         # Pointer to the output vector
    X_row_stride,    # Number of elements to skip to reach the next row of X
    X_stride_dim,    # Number of elements to skip to reach the next column of X
    w_stride_dim,    # Number of elements between consecutive entries of w
    out_stride_row,  # Number of elements between consecutive entries of the output
    X_ROW, X_D,      # Shape of the matrix X: (rows, columns)
    TILE_ROWS_SIZE: tl.constexpr, TILE_D_SIZE: tl.constexpr,  # Shape of each tile (compile-time constants)
):
    """
    Forward kernel: out[r] = sum_d X[r, d] * w[d].

    Each program instance handles one tile of TILE_ROWS_SIZE rows of X. It
    sweeps across the columns TILE_D_SIZE at a time, accumulating the partial
    sums in a float32 buffer, and stores that buffer to `out` at the end.

    The tile sizes are `tl.constexpr` because they are used as block shapes
    (`make_block_ptr`, `tl.zeros`), which must be known at compile time.
    """
    # `tl.program_id` tells us which program instance (thread block) we are,
    # i.e. which tile of rows this instance is responsible for.
    row_tile_idx = tl.program_id(0)

    # Block pointers select a tile out of an N-dimensional region of memory.
    # A block pointer must know:
    #   - ptr:         pointer to the first element of the tensor
    #   - shape:       overall shape of the tensor, used for out-of-bounds checks
    #   - strides:     stride of each dimension, in elements
    #   - offsets:     N-D coordinates of the tile's starting corner
    #   - block_shape: shape of the tile that is loaded/stored at a time
    #   - order:       the dimensions sorted from the most contiguous (smallest
    #                  stride) to the least contiguous, i.e. np.argsort(strides)
    #
    # Example for `order`, with
    #     A = [[ 0,  1,  2,  3],
    #          [ 4,  5,  6,  7],
    #          [ 8,  9, 10, 11],
    #          [12, 13, 14, 15]]
    #   - row-major storage (memory: 0,1,2,3,4,5,...): dim 1 is contiguous,
    #     so order=(1, 0).
    #   - column-major storage (memory: 0,4,8,12,1,5,9,13,...): dim 0 is
    #     contiguous, so order=(0, 1).

    # Tile of X, loaded as X[TILE_ROWS_SIZE][TILE_D_SIZE].
    x_block_ptr = tl.make_block_ptr(
        X_ptr,  # pointer to X[0, 0]
        shape=(X_ROW, X_D),  # for the boundary check
        strides=(X_row_stride, X_stride_dim),  # +1 row / +1 column in elements
        offsets=(row_tile_idx * TILE_ROWS_SIZE, 0),  # starting corner; we parallelize over row tiles
        block_shape=(TILE_ROWS_SIZE, TILE_D_SIZE),
        order=(1, 0),  # dim 1 is more contiguous than dim 0 (row-major layout)
    )

    # Tile of w, loaded as w[TILE_D_SIZE].
    weight_block_ptr = tl.make_block_ptr(
        w_ptr,
        shape=(X_D,),
        strides=(w_stride_dim,),
        offsets=(0,),  # w is shared by every row tile, so every program starts at 0
        block_shape=(TILE_D_SIZE,),
        order=(0,),
    )

    # Tile of the output, stored as out[TILE_ROWS_SIZE] (one value per row).
    out_block_ptr = tl.make_block_ptr(
        out_ptr,
        shape=(X_ROW,),
        strides=(out_stride_row,),  # trailing comma: strides must be a tuple
        offsets=(row_tile_idx * TILE_ROWS_SIZE,),  # starting corner; we parallelize over row tiles
        block_shape=(TILE_ROWS_SIZE,),
        order=(0,),
    )

    # Accumulator for this tile's row sums.
    output = tl.zeros((TILE_ROWS_SIZE,), dtype=tl.float32)

    # Sweep across the columns of X.
    for i in range(tl.cdiv(X_D, TILE_D_SIZE)):
        # The tile sizes may not evenly divide the rows/columns of X, so both
        # axes are bounds-checked; out-of-bounds elements are padded with 0
        # and therefore contribute nothing to the sum.
        # (Named x_tile so the X_ROW shape parameter is not overwritten.)
        x_tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")
        weight = tl.load(weight_block_ptr, boundary_check=(0,), padding_option="zero")

        # Weighted sum of this tile, reduced over the column axis.
        output += tl.sum(x_tile * weight[None, :], axis=1)

        # Move both pointers to the next tile along the column axis.
        x_block_ptr = x_block_ptr.advance((0, TILE_D_SIZE))
        weight_block_ptr = weight_block_ptr.advance((TILE_D_SIZE,))

    # Write the accumulated sums; rows past X_ROW are masked out.
    tl.store(out_block_ptr, output, boundary_check=(0,))

In [None]:
import torch
import einops
from einops import rearrange
from triton import cdiv

class WeightedSumFunc(torch.autograd.Function):
    """
    Autograd wrapper that runs the `weighted_sum_fwd` Triton kernel.

    Only `forward` is defined in this class; a `backward` still has to be
    added before gradients can flow through it.
    """

    @staticmethod
    def forward(ctx, X, weight):
        """
        Compute the weighted sum of X along its last dimension.

        Args:
            ctx: autograd context; receives the saved tensors, the tile sizes
                and the original input shape for use by a backward pass.
            X: CUDA tensor whose last dimension has size D.
            weight: 1-D CUDA tensor of size D.

        Returns:
            Tensor of shape `X.shape[:-1]`.

        Raises:
            AssertionError: if `weight` is not 1-D of size D, if either tensor
                is not on CUDA, or if the flattened X is not contiguous.
        """
        X_D = X.shape[-1]  # size of the last dimension
        input_shape = X.shape

        # Flatten all leading dimensions so the kernel sees a 2-D matrix.
        X = rearrange(X, "... d -> (...) d")

        # Validate before anything is stored on ctx or launched.
        assert len(weight.shape) == 1 and weight.shape[0] == X_D, "ValueError: Matrix Vector Dimension Mismatch"
        assert X.is_cuda and weight.is_cuda, "TypeError: Expect CUDA Tensors, got other"
        assert X.is_contiguous(), "TypeError: Expect a Contiguous Tensor X"

        # ctx carries state from forward to backward.
        ctx.save_for_backward(X, weight)

        # Triton kernel constants.
        ctx.TILE_D_SIZE = triton.next_power_of_2(X_D)  # tile width is a power of 2, e.g. 1024, 2048, ...
        ctx.TILE_ROWS_SIZE = 16  # number of rows per tile
        ctx.input_shape = input_shape

        # Allocate the result as a flat vector with one entry per row of the
        # flattened X, so that y.stride(0) really is the distance between
        # consecutive outputs whatever the number of leading dimensions.
        # torch.empty leaves the memory uninitialized; the kernel fills it.
        n_rows = X.shape[0]
        y = torch.empty(n_rows, device=X.device)

        # Launch one program per tile of rows.
        weighted_sum_fwd[(cdiv(n_rows, ctx.TILE_ROWS_SIZE),)](
            X, weight,
            y,
            X.stride(0), X.stride(1),
            weight.stride(0),
            y.stride(0),
            X_ROW=n_rows, X_D=X_D,
            TILE_ROWS_SIZE=ctx.TILE_ROWS_SIZE, TILE_D_SIZE=ctx.TILE_D_SIZE
        )

        # Restore the leading dimensions of the original input.
        return y.view(input_shape[:-1])

In [None]:
@triton.jit
def weighted_sum_backward(
    X_ptr, w_ptr,                    # Forward inputs: matrix X and vector w
    grad_y_ptr,                      # Upstream gradient, one value per row of X
    grad_X_ptr, tile_grad_w_ptr,     # Outputs: grad of X and per-tile partial grad of w
    stride_X_ROW, stride_X_D,
    stride_w_D,
    stride_grad_y_ROW,
    stride_grad_X_ROW, stride_grad_X_D,
    stride_tile_grad_w_ROW, stride_tile_grad_w_D,
    X_ROW, X_D,                      # Shape of the matrix X: (rows, columns)
    TILE_ROW: tl.constexpr, TILE_D: tl.constexpr,  # Tile shape (compile-time constants)
):
    """
    Backward kernel for the weighted sum y[r] = sum_d X[r, d] * w[d].

    The gradient of X can be computed fully in parallel:
        grad_X[r, d] = grad_y[r] * w[d]

    The gradient of w cannot be finished in one pass, because grad_w[d] needs
    a sum over all rows. Each program therefore only reduces the rows of its
    own tile and writes that partial result to row `program_id` of
    `tile_grad_w` (shape: number of row tiles x X_D). The caller still has to
    sum `tile_grad_w` over its first dimension to obtain grad_w.

    The tile sizes are `tl.constexpr` because they are used as block shapes.
    """
    # Which tile of rows this program handles, and how many tiles there are.
    rowtile_idx = tl.program_id(0)
    n_rowtile = tl.num_programs(0)

    # 1-D vector: upstream gradient for this tile's rows.
    grad_y_block_ptr = tl.make_block_ptr(
        grad_y_ptr,
        shape=(X_ROW,),
        strides=(stride_grad_y_ROW,),
        offsets=(rowtile_idx * TILE_ROW,),
        block_shape=(TILE_ROW,),
        order=(0,),
    )

    # 2-D matrix: tile of X.
    X_block_ptr = tl.make_block_ptr(
        X_ptr,
        shape=(X_ROW, X_D),
        strides=(stride_X_ROW, stride_X_D),
        offsets=(rowtile_idx * TILE_ROW, 0),
        block_shape=(TILE_ROW, TILE_D),
        order=(1, 0),
    )

    # 1-D vector: tile of w, shared by every row tile.
    w_block_ptr = tl.make_block_ptr(
        w_ptr,
        shape=(X_D,),
        strides=(stride_w_D,),
        offsets=(0,),
        block_shape=(TILE_D,),
        order=(0,),
    )

    # 2-D matrix: tile of grad_X, same layout as the tile of X.
    grad_X_block_ptr = tl.make_block_ptr(
        grad_X_ptr,
        shape=(X_ROW, X_D),
        strides=(stride_grad_X_ROW, stride_grad_X_D),
        offsets=(rowtile_idx * TILE_ROW, 0),
        block_shape=(TILE_ROW, TILE_D),
        order=(1, 0),
    )

    # Partially reduced gradient of w: one row per row tile, not yet summed
    # across tiles.
    tile_grad_w_block_ptr = tl.make_block_ptr(
        tile_grad_w_ptr,
        shape=(n_rowtile, X_D,),
        strides=(stride_tile_grad_w_ROW, stride_tile_grad_w_D),
        offsets=(rowtile_idx, 0),
        block_shape=(1, TILE_D),
        order=(1, 0),
    )

    # grad_y does not depend on the column position, so load it once.
    # Rows past X_ROW are padded with 0 and contribute nothing below.
    grad_y = tl.load(grad_y_block_ptr, boundary_check=(0,), padding_option="zero")  # (TILE_ROW,)

    # Sweep across the columns.
    for i in range(tl.cdiv(X_D, TILE_D)):
        w = tl.load(w_block_ptr, boundary_check=(0,), padding_option="zero")  # (TILE_D,)

        # dL/dX = outer product of grad_y and w: (TILE_ROW, 1) * (1, TILE_D).
        grad_X = grad_y[:, None] * w[None, :]
        tl.store(grad_X_block_ptr, grad_X, boundary_check=(0, 1))

        # Partial dL/dw: reduce this tile's rows of X weighted by grad_y.
        x_tile = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero")  # (TILE_ROW, TILE_D)
        tile_grad_w = tl.sum(x_tile * grad_y[:, None], axis=0)[None, :]  # (1, TILE_D)
        # Only the column axis can run out of bounds; the row index is
        # always a valid tile index.
        tl.store(tile_grad_w_block_ptr, tile_grad_w, boundary_check=(1,))

        # Move every column-dependent pointer to the next tile.
        X_block_ptr = X_block_ptr.advance((0, TILE_D))
        w_block_ptr = w_block_ptr.advance((TILE_D,))
        grad_X_block_ptr = grad_X_block_ptr.advance((0, TILE_D))
        tile_grad_w_block_ptr = tile_grad_w_block_ptr.advance((0, TILE_D))

NameError: name 'triton' is not defined