In [1]:
import torch
import triton
import triton.language as tl

@triton.jit
def test_kernel(x_ptr, output_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tl.store(output_ptr + offsets, tl.load(x_ptr + offsets))

print("Triton kernel test passed.")


Triton kernel test passed.


In [2]:
def weighted_sum(x, weight):
    # Here, assume that x has n-dim shape [..., D], and weight has 1D shape [D]
    return (weight * x).sum(axis=-1)

In [3]:
@triton.jit
def weighted_sum_fwd(
    x_ptr, weight_ptr, # Input pointers
    output_ptr, # Output pointer
    x_stride_row, x_stride_dim, # Strides tell us how to move one element in each axis of a tensor
    weight_stride_dim, # Likely 1
    output_stride_row, # Likely 1
    NUM_ROWS, D,
    ROWS_TILE_SIZE: tl.constexpr, D_TILE_SIZE: tl.constexpr, # Tile shapes must be known at compile time
    ):
    # Each instance will compute the weighted sum of a tile of rows of x.
    # `tl.program_id` gives us a way to check which thread block we're running in
    row_tile_idx = tl.program_id(0)

    # Block pointers give us a way to select from an ND region of memory
    # and move our selection around.
    # The block pointer must know:
    # - The pointer to the first element of the tensor
    # - The overall shape of the tensor to handle out-of-bounds access
    # - The strides of each dimension to use the memory layout properly
    # - The ND coordinates of the starting block, i.e., "offsets"
    # - The block shape to use load/store at a time
    # - The order of the dimensions in memory from major to minor
    #   axes (= np.argsort(strides)) for optimizations, especially useful on H100

    x_block_ptr = tl.make_block_ptr(
        x_ptr,
        shape=(NUM_ROWS, D,),
        strides=(x_stride_row, x_stride_dim),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),
        )

    weight_block_ptr = tl.make_block_ptr(
        weight_ptr,
        shape=(D,),
        strides=(weight_stride_dim,),
        offsets=(0,),
        block_shape=(D_TILE_SIZE,),
        order=(0,),
        )

    output_block_ptr = tl.make_block_ptr(
        output_ptr,
        shape=(NUM_ROWS,),
        strides=(output_stride_row,),
        offsets=(row_tile_idx * ROWS_TILE_SIZE,),
        block_shape=(ROWS_TILE_SIZE,),
        order=(0,),
        )

    # Initialize a buffer to write to
    output = tl.zeros((ROWS_TILE_SIZE,), dtype=tl.float32)

    for i in range(tl.cdiv(D, D_TILE_SIZE)):
        # Load the current block pointer
        # Since ROWS_TILE_SIZE might not divide ROWS, and D_TILE_SIZE might not divide D,
        # we need boundary checks for both dimensions
        row = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero") # (ROWS_TILE_SIZE, D_TILE_SIZE)
        weight = tl.load(weight_block_ptr, boundary_check=(0,), padding_option="zero") # (D_TILE_SIZE,)

        # Compute the weighted sum of the row.
        output += tl.sum(row * weight[None, :], axis=1)

        # Move the pointers to the next tile.
        # These are (rows, columns) coordinate deltas
        x_block_ptr = x_block_ptr.advance((0, D_TILE_SIZE)) # Move by D_TILE_SIZE in the last dimension
        weight_block_ptr = weight_block_ptr.advance((D_TILE_SIZE,)) # Move by D_TILE_SIZE

    # Write output to the output block pointer (a single scalar per row).
    # Since ROWS_TILE_SIZE might not divide ROWS, we need boundary checks
    tl.store(output_block_ptr, output, boundary_check=(0,))

In [4]:
@triton.jit
def weighted_sum_backward(
    x_ptr, weight_ptr,  # Input
    grad_output_ptr,    # Grad input
    grad_x_ptr, partial_grad_weight_ptr,  # Grad outputs
    stride_xr, stride_xd,
    stride_wd,
    stride_gr,
    stride_gxr, stride_gxd,
    stride_gwb, stride_gwd,
    NUM_ROWS, D,
    ROWS_TILE_SIZE: tl.constexpr, D_TILE_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    row_offsets = pid * ROWS_TILE_SIZE + tl.arange(0, ROWS_TILE_SIZE)
    dim_offsets = tl.arange(0, D_TILE_SIZE)

    grad_weight_local = tl.zeros((D_TILE_SIZE,), dtype=tl.float32)

    for col in range(0, tl.cdiv(D, D_TILE_SIZE)):
        # Load x, weight, and grad_output
        x_tile = tl.load(
            x_ptr + row_offsets[:, None] * stride_xr + (col * D_TILE_SIZE + dim_offsets[None, :]) * stride_xd,
            mask=(row_offsets[:, None] < NUM_ROWS) & (col * D_TILE_SIZE + dim_offsets[None, :] < D),
            other=0.0
        )
        weight_tile = tl.load(
            weight_ptr + (col * D_TILE_SIZE + dim_offsets) * stride_wd,
            mask=(col * D_TILE_SIZE + dim_offsets < D),
            other=0.0
        )
        grad_output_tile = tl.load(
            grad_output_ptr + row_offsets * stride_gr,
            mask=(row_offsets < NUM_ROWS),
            other=0.0
        )[:, None]

        # Compute gradients
        grad_x_tile = grad_output_tile * weight_tile[None, :]
        grad_weight_local += tl.sum(x_tile * grad_output_tile, axis=0)

        # Store grad_x
        tl.store(
            grad_x_ptr + row_offsets[:, None] * stride_gxr + (col * D_TILE_SIZE + dim_offsets[None, :]) * stride_gxd,
            grad_x_tile,
            mask=(row_offsets[:, None] < NUM_ROWS) & (col * D_TILE_SIZE + dim_offsets[None, :] < D)
        )

    # Store partial gradients
    tl.store(
        partial_grad_weight_ptr + pid * stride_gwb + dim_offsets * stride_gwd,
        grad_weight_local,
        mask=(dim_offsets < D_TILE_SIZE)
    )

In [5]:
@triton.jit
def weighted_sum_backward_optimized(
    x_ptr, weight_ptr, grad_out_ptr,
    grad_x_ptr, partial_grad_weight_ptr,
    x_stride_row, x_stride_dim,
    weight_stride_dim,
    grad_out_stride_row,
    grad_x_stride_row, grad_x_stride_dim,
    partial_grad_weight_stride_row, partial_grad_weight_stride_dim,
    NUM_ROWS, D,
    ROWS_TILE_SIZE: tl.constexpr, D_TILE_SIZE: tl.constexpr
):
    pid = tl.program_id(0)
    row_start = pid * ROWS_TILE_SIZE
    row_offsets = row_start + tl.arange(0, ROWS_TILE_SIZE)
    dim_offsets = tl.arange(0, D_TILE_SIZE)

    # Local gradient accumulator (very fast)
    grad_weight_local = tl.zeros((D_TILE_SIZE,), dtype=tl.float32)

    for col in range(0, tl.cdiv(D, D_TILE_SIZE)):
        # Explicitly load tiles (fast shared loads)
        x_tile = tl.load(
            x_ptr + row_offsets[:, None] * x_stride_row + (col * D_TILE_SIZE + dim_offsets[None, :]) * x_stride_dim,
            mask=(row_offsets[:, None] < NUM_ROWS) & (col * D_TILE_SIZE + dim_offsets[None, :] < D),
            other=0.0
        )

        weight_tile = tl.load(
            weight_ptr + col * D_TILE_SIZE + dim_offsets,
            mask=(col * D_TILE_SIZE + dim_offsets < D),
            other=0.0
        )

        grad_out_tile = tl.load(
            grad_out_ptr + row_offsets * grad_out_stride_row,
            mask=(row_offsets < NUM_ROWS),
            other=0.0
        )[:, None]

        # Compute and store grad_x (once per tile explicitly)
        grad_x_tile = grad_out_tile * weight_tile[None, :]
        tl.store(
            grad_x_ptr + row_offsets[:, None] * grad_x_stride_row + (col * D_TILE_SIZE + dim_offsets[None, :]) * grad_x_stride_dim,
            grad_x_tile,
            mask=(row_offsets[:, None] < NUM_ROWS) & (col * D_TILE_SIZE + dim_offsets[None, :] < D)
        )

        grad_weight_local += tl.sum(x_tile * grad_out_tile, axis=0)

    # Store only one small partial gradient per block
    tl.store(
        partial_grad_weight_ptr + pid * partial_grad_weight_stride_row + tl.arange(0, D_TILE_SIZE),
        grad_weight_local,
        mask=(tl.arange(0, D_TILE_SIZE) < D)
    )


In [6]:
@triton.jit
def weighted_sum_backward_fully_optimized(
    x_ptr, weight_ptr, grad_out_ptr,
    grad_x_ptr, grad_weight_ptr,
    x_stride_row, x_stride_dim,
    weight_stride_dim,
    grad_out_stride_row,
    grad_x_stride_row, grad_x_stride_dim,
    NUM_ROWS, D,
    ROWS_TILE_SIZE: tl.constexpr, D_TILE_SIZE: tl.constexpr
):
    pid = tl.program_id(0)
    row_offsets = pid * ROWS_TILE_SIZE + tl.arange(0, ROWS_TILE_SIZE)

    # Loop over column tiles explicitly (unavoidable explicitly due to Triton constraints)
    for col in range(0, tl.cdiv(D, D_TILE_SIZE)):
        dim_offsets = col * D_TILE_SIZE + tl.arange(0, D_TILE_SIZE)

        mask = (row_offsets[:, None] < NUM_ROWS) & (dim_offsets[None, :] < D)

        # Load explicitly once per tile
        x_tile = tl.load(
            x_ptr + row_offsets[:, None] * x_stride_row + dim_offsets[None, :] * x_stride_dim,
            mask=mask, other=0.0
        )

        grad_out_tile = tl.load(
            grad_out_ptr + row_offsets * grad_out_stride_row,
            mask=(row_offsets < NUM_ROWS), other=0.0
        )[:, None]

        weight_tile = tl.load(
            weight_ptr + dim_offsets,
            mask=(dim_offsets < D), other=0.0
        )

        # local buffer explicitly constrained by Triton (D_TILE_SIZE explicitly)
        grad_weight_local = tl.sum(x_tile * grad_out_tile, axis=0)

        # Explicit atomic-add per tile explicitly unavoidable given Triton constraints explicitly!
        tl.atomic_add(
            grad_weight_ptr + dim_offsets,
            grad_weight_local,
            mask=(dim_offsets < D)
        )

        # Efficient grad_x computation explicitly
        grad_x_tile = weight_tile[None, :] * grad_out_tile

        tl.store(
            grad_x_ptr + row_offsets[:, None] * grad_x_stride_row + dim_offsets[None, :] * grad_x_stride_dim,
            grad_x_tile,
            mask=mask
        )

In [7]:
from einops import rearrange


class WeightedSumFunc(torch.autograd.Function):

    ROWS_TILE_SIZE = 16
    
    @staticmethod
    def forward(ctx, x, weight):
        D, output_dims = x.shape[-1], x.shape[:-1]
        input_shape = x.shape
        x = rearrange(x, "... d -> (...) d")

        ctx.save_for_backward(x, weight)
        assert x.is_cuda and weight.is_cuda, "CUDA tensors required"
        assert x.is_contiguous(), "Tensor x must be contiguous"

        n_rows = x.shape[0]

        # Important fix: explicitly setting attributes on ctx
        ctx.ROWS_TILE_SIZE = WeightedSumFunc.ROWS_TILE_SIZE
        ctx.D_TILE_SIZE = triton.next_power_of_2(D) // ctx.ROWS_TILE_SIZE
        #ctx.D_TILE_SIZE = min(D, max(16, triton.next_power_of_2(D) // ctx.ROWS_TILE_SIZE))

        ctx.input_shape = input_shape

        y = torch.empty((n_rows,), device=x.device)

        # Define the grid explicitly
        grid = (triton.cdiv(n_rows, ctx.ROWS_TILE_SIZE),)

        weighted_sum_fwd[grid](
            x, weight, y,
            x.stride(0), x.stride(1),
            weight.stride(0),
            y.stride(0),
            NUM_ROWS=n_rows, D=D,
            ROWS_TILE_SIZE=ctx.ROWS_TILE_SIZE, D_TILE_SIZE=ctx.D_TILE_SIZE,
        )

        return y.view(output_dims)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        n_rows, D = x.shape

        # Prepare explicitly smaller partial gradient buffer
        grid = (triton.cdiv(n_rows, ctx.ROWS_TILE_SIZE),)
        grad_weight = torch.zeros(D, device=x.device)
        grad_x = torch.empty_like(x)
        
        use_fully_optimized = False
        if use_fully_optimized:
            # Call optimized Triton backward explicitly
            weighted_sum_backward_fully_optimized[grid](
                x, weight, grad_out,
                grad_x, grad_weight,
                x.stride(0), x.stride(1),
                weight.stride(0),
                grad_out.stride(0),
                grad_x.stride(0), grad_x.stride(1),
                NUM_ROWS=n_rows, D=D,
                ROWS_TILE_SIZE=ctx.ROWS_TILE_SIZE, D_TILE_SIZE=ctx.D_TILE_SIZE,
            )
        else:
            partial_grad_weight = torch.zeros((grid[0], D), device=x.device, dtype=x.dtype)

            weighted_sum_backward[grid](
                x, weight, grad_out,
                grad_x, partial_grad_weight,
                x.stride(0), x.stride(1),
                weight.stride(0),
                grad_out.stride(0),
                grad_x.stride(0), grad_x.stride(1),
                partial_grad_weight.stride(0), partial_grad_weight.stride(1),
                n_rows, D,
                ROWS_TILE_SIZE=ctx.ROWS_TILE_SIZE,
                D_TILE_SIZE=D  # explicitly simple, whole dimension tile
            )

            grad_weight = partial_grad_weight.sum(dim=0)

        return grad_x.view(ctx.input_shape), grad_weight

In [8]:
import os

torch.manual_seed(42)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Set dimensions (medium-scale explicitly recommended first)
B, T, D = 256, 1024, 512

# Initialize input explicitly
x = torch.randn(B, T, D, device='cuda', requires_grad=True)
weight = torch.randn(D, device='cuda', requires_grad=True)

# Triton implementation (forward and backward explicitly)
y_triton = WeightedSumFunc.apply(x, weight)
loss_triton = y_triton.sum()
loss_triton.backward()

# Save gradients explicitly
grad_x_triton = x.grad.clone()
grad_weight_triton = weight.grad.clone()

# Reset gradients explicitly for PyTorch verification
x.grad.zero_()
weight.grad.zero_()

# PyTorch reference explicitly
y_torch = (x * weight).sum(dim=-1)
loss_torch = y_torch.sum()
loss_torch.backward()

# Explicitly verify correctness using torch.allclose
grad_x_correct = torch.allclose(grad_x_triton, x.grad, atol=1e-4, rtol=1e-4)
grad_weight_correct = torch.allclose(grad_weight_triton, weight.grad, atol=1e-4, rtol=1e-4)

print(f"Gradient x correct? {grad_x_correct}")
print(f"Gradient weight correct? {grad_weight_correct}")

if not grad_x_correct:
    print("Max difference in grad_x:", (grad_x_triton - x.grad).abs().max())

if not grad_weight_correct:
    print("Max difference in grad_weight:", (grad_weight_triton - weight.grad).abs().max())


Gradient x correct? True
Gradient weight correct? True


In [9]:
import timeit
import os
import torch
torch.manual_seed(42)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Inputs explicitly
# input_sizes = [[128, 512, 256], [256, 1024, 512], [256, 1024, 1024], [512, 2048, 512]]
input_sizes = [[128, 512, 256], [256, 1024, 512]]
tile_sizes = [16, 32, 64, 128, 256]

for [B, T, D] in input_sizes:

    # Initialize inputs explicitly
    x = torch.randn(B, T, D, device='cuda', requires_grad=True)
    weight = torch.randn(D, device='cuda', requires_grad=True)

    results = []

    for tile_size in tile_sizes:
        WeightedSumFunc.ROWS_TILE_SIZE = tile_size

        # Warm-up explicitly
        WeightedSumFunc.apply(x, weight).sum().backward()
        torch.cuda.synchronize()

        # Benchmark timing explicitly
        triton_time = timeit.timeit(
            lambda: (WeightedSumFunc.apply(x, weight).sum().backward(), torch.cuda.synchronize()),
            number=100
        )

        torch_time = timeit.timeit(
            lambda: ((x * weight).sum(dim=-1).sum().backward(), torch.cuda.synchronize()),
            number=100
        )

        avg_triton = triton_time / 100 * 1000
        avg_torch = torch_time / 100 * 1000
        ratio = avg_torch / avg_triton

        results.append((tile_size, avg_triton, avg_torch, ratio))

    # Explicitly show input dimensions at top of each table
    print(f"\n{'='*80}")
    print(f"Benchmark results for Input Size: B={B}, T={T}, D={D}")
    print(f"{'='*80}")
    print(f"{'Tile Size':<10} | {'Triton Time (ms)':<18} | {'PyTorch Time (ms)':<18} | {'Speedup (Torch/Triton)':<20}")
    print("-" * 80)
    for ts, triton_t, torch_t, r in results:
        print(f"{ts:<10} | {triton_t:<18.3f} | {torch_t:<18.3f} | {r:<20.2f}")



Benchmark results for Input Size: B=128, T=512, D=256
Tile Size  | Triton Time (ms)   | PyTorch Time (ms)  | Speedup (Torch/Triton)
--------------------------------------------------------------------------------
16         | 1.226              | 1.621              | 1.32                
32         | 1.325              | 1.586              | 1.20                
64         | 1.275              | 1.513              | 1.19                
128        | 2.309              | 1.393              | 0.60                
256        | 1.896              | 1.660              | 0.88                

Benchmark results for Input Size: B=256, T=1024, D=512
Tile Size  | Triton Time (ms)   | PyTorch Time (ms)  | Speedup (Torch/Triton)
--------------------------------------------------------------------------------
16         | 7.173              | 10.780             | 1.50                
32         | 7.076              | 10.815             | 1.53                
64         | 15.555             | 10.57

# ROWS_TILE_SIZE 16
Triton implementation avg time: 0.344 ms
PyTorch native avg time: 0.189 ms

# ROWS_TILE_SIZE 32
Triton implementation avg time: 0.178 ms
PyTorch native avg time: 0.063 ms

# ROWS_TILE_SIZE 64
Triton implementation avg time: 0.167 ms
PyTorch native avg time: 0.078 ms

In [10]:
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))

True
1
NVIDIA GeForce RTX 4090 Laptop GPU
