# Triton Tutorial: Weighted Sum Kernel

This notebook walks through implementing a custom weighted sum operation in Triton, based on the CS336 Assignment 2 example.

## What we'll learn:
1. How to write a Triton kernel with the `@triton.jit` decorator
2. How to use block pointers for memory access
3. How to implement forward and backward passes
4. How to integrate Triton kernels with PyTorch's autograd

## The Operation

Given:
- Input matrix `X` with shape `[..., D]` (can be batched)
- Weight vector `w` with shape `[D]`

Compute: `(w * X).sum(axis=-1)` - a weighted sum along the last dimension
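
As a concrete illustration, the sketch below evaluates the operation by hand on a tiny input using plain Python lists, so it runs without a GPU. The values and the helper name `weighted_sum_rows` are made up for this example.

```python
def weighted_sum_rows(X, w):
    """Weighted sum along the last dimension for a 2D list-of-lists X."""
    return [sum(w_j * x_ij for w_j, x_ij in zip(w, row)) for row in X]

X = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]      # shape [2, 3]
w = [0.5, -1.0, 2.0]       # shape [3]

print(weighted_sum_rows(X, w))  # [4.5, 9.0]
```

Each output element is the dot product of one row of `X` with `w`, so for a 2D input the operation is a matrix-vector product.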


In [1]:
import torch
import triton
import triton.language as tl
import time

# Helper function for ceiling division
def cdiv(a, b):
    return (a + b - 1) // b

print(f"PyTorch version: {torch.__version__}")
print(f"Triton version: {triton.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


PyTorch version: 2.6.0+cu124
Triton version: 3.2.0
CUDA available: True


## Step 1: PyTorch Reference Implementation

First, let's see what we're trying to implement in pure PyTorch:


In [2]:
def weighted_sum_pytorch(x, weight):
    """Reference implementation in PyTorch"""
    # x has shape [..., D], weight has shape [D]
    return (weight * x).sum(axis=-1)

# Test it
D = 8
n_rows = 4
x = torch.randn(n_rows, D, device='cuda')
weight = torch.randn(D, device='cuda')

result = weighted_sum_pytorch(x, weight)
print(f"Input shape: {x.shape}")
print(f"Weight shape: {weight.shape}")
print(f"Output shape: {result.shape}")
print(f"Output: {result}")


Input shape: torch.Size([4, 8])
Weight shape: torch.Size([8])
Output shape: torch.Size([4])
Output: tensor([-3.2948, -0.0807,  0.2607, -3.1335], device='cuda:0')


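## Step 2: Triton Forward Kernel

Each program instance handles one tile of rows and loops over `D` in tiles, accumulating the weighted sum.


In [3]: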
@triton.jit
def weighted_sum_fwd(
    x_ptr, weight_ptr,  # Input pointers
    output_ptr,  # Output pointer
    x_stride_row, x_stride_dim,  # Strides for x
    weight_stride_dim,  # Stride for weight
    output_stride_row,  # Stride for output
    ROWS, D,  # Dimensions
    ROWS_TILE_SIZE: tl.constexpr, D_TILE_SIZE: tl.constexpr,  # Tile sizes (compile-time constants)
):
    """
    Triton kernel for weighted sum forward pass.
    
    Each thread block processes a tile of rows:
    - row_tile_idx determines which tile of rows this block handles
    - We loop over D in tiles, accumulating the weighted sum
    """
    # Each thread block processes a tile of rows
    row_tile_idx = tl.program_id(0)
    
    # Create block pointer for x (2D: rows × D)
    x_block_ptr = tl.make_block_ptr(
        x_ptr,
        shape=(ROWS, D),
        strides=(x_stride_row, x_stride_dim),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),  # Start at our tile
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),  # Dim 1 (D) varies fastest in memory, i.e. row-major layout
    )
    
    # Create block pointer for weight (1D: D)
    weight_block_ptr = tl.make_block_ptr(
        weight_ptr,
        shape=(D,),
        strides=(weight_stride_dim,),
        offsets=(0,),
        block_shape=(D_TILE_SIZE,),
        order=(0,),
    )
    
    # Create block pointer for output (1D: rows)
    output_block_ptr = tl.make_block_ptr(
        output_ptr,
        shape=(ROWS,),
        strides=(output_stride_row,),
        offsets=(row_tile_idx * ROWS_TILE_SIZE,),
        block_shape=(ROWS_TILE_SIZE,),
        order=(0,),
    )
    
    # Load first tile to determine dtype, then initialize accumulator
    x_tile_first = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")
    weight_tile_first = tl.load(weight_block_ptr, boundary_check=(0,), padding_option="zero")
    
    # Compute first partial sum
    weighted_first = x_tile_first * weight_tile_first[None, :]
    acc = tl.sum(weighted_first, axis=1)
    
    # Advance pointers for next iteration
    x_block_ptr = tl.advance(x_block_ptr, (0, D_TILE_SIZE))
    weight_block_ptr = tl.advance(weight_block_ptr, (D_TILE_SIZE,))
    
    # Loop over remaining tiles in the D dimension
    for i in range(1, tl.cdiv(D, D_TILE_SIZE)):
        # Load tiles
        x_tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")
        weight_tile = tl.load(weight_block_ptr, boundary_check=(0,), padding_option="zero")
        
        # Compute weighted sum for this tile
        # x_tile: (ROWS_TILE_SIZE, D_TILE_SIZE)
        # weight_tile: (D_TILE_SIZE,)
        weighted = x_tile * weight_tile[None, :]  # Broadcast weight
        acc += tl.sum(weighted, axis=1)  # Sum along D dimension
        
        # Advance pointers to next tile
        x_block_ptr = tl.advance(x_block_ptr, (0, D_TILE_SIZE))
        weight_block_ptr = tl.advance(weight_block_ptr, (D_TILE_SIZE,))
    
    # Store result
    tl.store(output_block_ptr, acc, boundary_check=(0,))

print("✓ Forward kernel defined!")


✓ Forward kernel defined!
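
To see what the loop over `D` tiles does without a GPU, here is a minimal plain-Python sketch of the same accumulation for a single row. The helper `load_tile` is hypothetical; it only imitates a boundary-checked load with `padding_option="zero"`, and the input values are made up.

```python
def cdiv(a, b):
    return (a + b - 1) // b

def load_tile(values, start, size):
    """Mimic a boundary-checked load: positions past the end read as zero."""
    return [values[i] if i < len(values) else 0.0 for i in range(start, start + size)]

def tiled_weighted_sum(row, weight, d_tile_size):
    """Accumulate one row's weighted sum tile by tile along D."""
    acc = 0.0
    for t in range(cdiv(len(row), d_tile_size)):
        x_tile = load_tile(row, t * d_tile_size, d_tile_size)
        w_tile = load_tile(weight, t * d_tile_size, d_tile_size)
        acc += sum(x * w for x, w in zip(x_tile, w_tile))
    return acc

row = [1.0, 2.0, 3.0, 4.0, 5.0]        # D = 5, not a multiple of the tile size
weight = [1.0, 0.5, 2.0, -1.0, 3.0]

print(load_tile(row, 4, 2))                            # [5.0, 0.0]
print(tiled_weighted_sum(row, weight, d_tile_size=2))  # 19.0
```

Because out-of-range positions are padded with zeros, the last partial tile contributes nothing extra to the sum, which is why the kernel can use a fixed tile shape even when `D` is not a multiple of `D_TILE_SIZE`.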


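## Step 3: Triton Backward Kernel

The backward kernel writes `grad_x` directly and stores one partial `grad_weight` row per row tile, to be reduced afterwards.


In [4]: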
@triton.jit
def weighted_sum_backward(
    x_ptr, weight_ptr,  # Inputs from forward pass
    grad_output_ptr,  # Gradient w.r.t. output
    grad_x_ptr, partial_grad_weight_ptr,  # Gradient outputs
    stride_xr, stride_xd,
    stride_wd,
    stride_gr,
    stride_gxr, stride_gxd,
    stride_gwb, stride_gwd,
    NUM_ROWS, D,
    ROWS_TILE_SIZE: tl.constexpr, D_TILE_SIZE: tl.constexpr,
):
    """
    Triton kernel for weighted sum backward pass.
    
    Computes:
    - grad_x[i,j] = weight[j] * grad_output[i] (outer product)
    - grad_weight[j] = sum_i(x[i,j] * grad_output[i]) (reduction)
    
    For grad_weight, we compute partial sums per tile and reduce later.
    """
    row_tile_idx = tl.program_id(0)
    n_row_tiles = tl.num_programs(0)
    
    # Block pointer for grad_output (1D)
    grad_output_block_ptr = tl.make_block_ptr(
        grad_output_ptr,
        shape=(NUM_ROWS,),
        strides=(stride_gr,),
        offsets=(row_tile_idx * ROWS_TILE_SIZE,),
        block_shape=(ROWS_TILE_SIZE,),
        order=(0,),
    )
    
    # Block pointer for x (2D)
    x_block_ptr = tl.make_block_ptr(
        x_ptr,
        shape=(NUM_ROWS, D),
        strides=(stride_xr, stride_xd),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),
    )
    
    # Block pointer for weight (1D)
    weight_block_ptr = tl.make_block_ptr(
        weight_ptr,
        shape=(D,),
        strides=(stride_wd,),
        offsets=(0,),
        block_shape=(D_TILE_SIZE,),
        order=(0,),
    )
    
    # Block pointer for grad_x (2D)
    grad_x_block_ptr = tl.make_block_ptr(
        grad_x_ptr,
        shape=(NUM_ROWS, D),
        strides=(stride_gxr, stride_gxd),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),
    )
    
    # Block pointer for partial grad_weight (2D: n_tiles × D)
    partial_grad_weight_block_ptr = tl.make_block_ptr(
        partial_grad_weight_ptr,
        shape=(n_row_tiles, D),
        strides=(stride_gwb, stride_gwd),
        offsets=(row_tile_idx, 0),
        block_shape=(1, D_TILE_SIZE),
        order=(1, 0),
    )
    
    # Load grad_output once (same for all D tiles)
    grad_output = tl.load(grad_output_block_ptr, boundary_check=(0,), padding_option="zero")
    
    # Loop over D dimension
    for i in range(tl.cdiv(D, D_TILE_SIZE)):
        # Compute grad_x: outer product of grad_output and weight
        weight = tl.load(weight_block_ptr, boundary_check=(0,), padding_option="zero")
        grad_x_tile = grad_output[:, None] * weight[None, :]  # (ROWS_TILE_SIZE, D_TILE_SIZE)
        tl.store(grad_x_block_ptr, grad_x_tile, boundary_check=(0, 1))
        
        # Compute partial grad_weight: reduce over rows in this tile
        x_tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")
        grad_weight_tile = tl.sum(x_tile * grad_output[:, None], axis=0, keep_dims=True)
        tl.store(partial_grad_weight_block_ptr, grad_weight_tile, boundary_check=(1,))
        
        # Advance pointers
        x_block_ptr = tl.advance(x_block_ptr, (0, D_TILE_SIZE))
        weight_block_ptr = tl.advance(weight_block_ptr, (D_TILE_SIZE,))
        grad_x_block_ptr = tl.advance(grad_x_block_ptr, (0, D_TILE_SIZE))
        partial_grad_weight_block_ptr = tl.advance(partial_grad_weight_block_ptr, (0, D_TILE_SIZE))

print("✓ Backward kernel defined!")


✓ Backward kernel defined!
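
The two-stage reduction for `grad_weight` can be sketched in plain Python: each row tile produces one partial row, and a final sum over tiles gives the full gradient. The function name and the input values below are illustrative assumptions, not part of the kernel.

```python
def cdiv(a, b):
    return (a + b - 1) // b

def partial_grad_weight(x, grad_output, rows_tile_size):
    """One partial grad_weight row per row tile, mirroring the kernel's buffer."""
    n_rows, d = len(x), len(x[0])
    partials = []
    for tile in range(cdiv(n_rows, rows_tile_size)):
        start = tile * rows_tile_size
        stop = min(start + rows_tile_size, n_rows)  # rows past the end contribute nothing
        partials.append([
            sum(x[i][j] * grad_output[i] for i in range(start, stop))
            for j in range(d)
        ])
    return partials

x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # 3 rows, D = 2
grad_output = [1.0, -1.0, 2.0]

partials = partial_grad_weight(x, grad_output, rows_tile_size=2)
grad_weight = [sum(col) for col in zip(*partials)]  # final reduction over tiles

print(partials)     # [[-2.0, -2.0], [10.0, 12.0]]
print(grad_weight)  # [8.0, 10.0]
```

Writing partial sums to separate rows means no two program instances write to the same location, at the cost of a small extra reduction afterwards.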


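## Step 4: Wrapping the Kernels in `torch.autograd.Function`

The wrapper flattens the input to 2D, picks tile sizes, launches the kernels, and reduces the partial weight gradients.


In [5]: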
class WeightedSumFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        # Save the feature dimension
        D = x.shape[-1]
        
        # Reshape to 2D for kernel
        input_shape = x.shape
        x = x.reshape(-1, D)
        
        # Validation
        assert len(weight.shape) == 1 and weight.shape[0] == D, "Dimension mismatch"
        assert x.is_cuda and weight.is_cuda, "Expected CUDA tensors"
        assert x.is_contiguous(), "x must be contiguous"
        
        # Save for backward
        ctx.save_for_backward(x, weight)
        
        # Choose tile sizes
        ctx.D_TILE_SIZE = min(triton.next_power_of_2(D) // 16, 128)
        ctx.D_TILE_SIZE = max(ctx.D_TILE_SIZE, 1)
        ctx.ROWS_TILE_SIZE = 16
        ctx.input_shape = input_shape
        
        # Allocate output
        n_rows = x.shape[0]
        y = torch.empty(n_rows, device=x.device, dtype=x.dtype)
        
        # Launch kernel
        grid = (cdiv(n_rows, ctx.ROWS_TILE_SIZE),)
        weighted_sum_fwd[grid](
            x, weight,
            y,
            x.stride(0), x.stride(1),
            weight.stride(0),
            y.stride(0),
            ROWS=n_rows, D=D,
            ROWS_TILE_SIZE=ctx.ROWS_TILE_SIZE,
            D_TILE_SIZE=ctx.D_TILE_SIZE,
        )
        
        return y.view(input_shape[:-1])
    
    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        ROWS_TILE_SIZE, D_TILE_SIZE = ctx.ROWS_TILE_SIZE, ctx.D_TILE_SIZE
        n_rows, D = x.shape
        
        # Flatten grad_output
        grad_output = grad_output.reshape(-1).contiguous()
        
        # Allocate outputs
        grad_x = torch.empty_like(x)
        n_tiles = cdiv(n_rows, ROWS_TILE_SIZE)
        partial_grad_weight = torch.empty((n_tiles, D), device=x.device, dtype=x.dtype)
        
        # Launch kernel
        grid = (n_tiles,)
        weighted_sum_backward[grid](
            x, weight,
            grad_output,
            grad_x, partial_grad_weight,
            x.stride(0), x.stride(1),
            weight.stride(0),
            grad_output.stride(0),
            grad_x.stride(0), grad_x.stride(1),
            partial_grad_weight.stride(0), partial_grad_weight.stride(1),
            NUM_ROWS=n_rows, D=D,
            ROWS_TILE_SIZE=ROWS_TILE_SIZE,
            D_TILE_SIZE=D_TILE_SIZE,
        )
        
        # Reduce partial gradients
        grad_weight = partial_grad_weight.sum(axis=0)
        
        # Reshape grad_x back to original shape
        grad_x = grad_x.view(ctx.input_shape)
        
        return grad_x, grad_weight

# Create the function
weighted_sum_triton = WeightedSumFunc.apply

print("✓ Autograd function created!")


✓ Autograd function created!
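
To get a feel for the tile-size heuristic in `forward`, the sketch below reproduces the same arithmetic in plain Python. `next_power_of_2` here is a stand-in written for this example, assumed to match `triton.next_power_of_2` (smallest power of two greater than or equal to `n`); `choose_d_tile_size` is a hypothetical name.

```python
def next_power_of_2(n):
    """Smallest power of two >= n (pure-Python stand-in, assumed to match Triton's helper)."""
    return 1 << (n - 1).bit_length()

def choose_d_tile_size(d):
    """Same arithmetic as the wrapper: one sixteenth of the padded D, clamped to [1, 128]."""
    return max(min(next_power_of_2(d) // 16, 128), 1)

def cdiv(a, b):
    return (a + b - 1) // b

for d in (8, 64, 100, 512, 4096):
    tile = choose_d_tile_size(d)
    print(f"D={d:5d}  D_TILE_SIZE={tile:4d}  loop iterations={cdiv(d, tile)}")
```

For small `D` the tile collapses to a single element, so the kernel loops `D` times; for large `D` the cap of 128 bounds the tile size held per program instance.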


## Step 5: Test the Forward Pass

Compare the Triton output against the PyTorch reference.


In [6]:

torch.manual_seed(42)
n_rows, D = 32, 64
x = torch.randn(n_rows, D, device='cuda', requires_grad=True)
weight = torch.randn(D, device='cuda', requires_grad=True)

# PyTorch reference
output_pytorch = weighted_sum_pytorch(x, weight)

# Triton implementation
output_triton = weighted_sum_triton(x, weight)

# Compare
print(f"PyTorch output shape: {output_pytorch.shape}")
print(f"Triton output shape: {output_triton.shape}")
print(f"\nMax absolute difference: {(output_pytorch - output_triton).abs().max().item():.2e}")
print(f"Mean absolute difference: {(output_pytorch - output_triton).abs().mean().item():.2e}")
print(f"\n✓ Outputs match: {torch.allclose(output_pytorch, output_triton, rtol=1e-4, atol=1e-4)}")


PyTorch output shape: torch.Size([32])
Triton output shape: torch.Size([32])

Max absolute difference: 9.54e-07
Mean absolute difference: 3.02e-07

✓ Outputs match: True


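## Step 6: Test the Backward Pass

Compare gradients from the Triton backward kernel against PyTorch autograd, using `float64` inputs.


In [7]: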
torch.manual_seed(42)
n_rows, D = 32, 64
x_pt = torch.randn(n_rows, D, device='cuda', requires_grad=True, dtype=torch.float64)
weight_pt = torch.randn(D, device='cuda', requires_grad=True, dtype=torch.float64)

x_tr = x_pt.clone().detach().requires_grad_(True)
weight_tr = weight_pt.clone().detach().requires_grad_(True)

# Forward pass
output_pt = weighted_sum_pytorch(x_pt, weight_pt)
output_tr = weighted_sum_triton(x_tr, weight_tr)

# Backward pass
grad_output = torch.randn_like(output_pt)
output_pt.backward(grad_output)
output_tr.backward(grad_output)

# Compare gradients
print("Gradient w.r.t. x:")
print(f"  Max absolute difference: {(x_pt.grad - x_tr.grad).abs().max().item():.2e}")
print(f"  ✓ Gradients match: {torch.allclose(x_pt.grad, x_tr.grad, rtol=1e-4, atol=1e-4)}")

print("\nGradient w.r.t. weight:")
print(f"  Max absolute difference: {(weight_pt.grad - weight_tr.grad).abs().max().item():.2e}")
print(f"  ✓ Gradients match: {torch.allclose(weight_pt.grad, weight_tr.grad, rtol=1e-4, atol=1e-4)}")


Gradient w.r.t. x:
  Max absolute difference: 0.00e+00
  ✓ Gradients match: True

Gradient w.r.t. weight:
  Max absolute difference: 3.55e-15
  ✓ Gradients match: True


In [8]:
test_shapes = [
    (16, 32),      # Small
    (128, 256),    # Medium
    (1024, 512),   # Large
    (8, 16, 64),   # Batched (batch_size=8, seq_len=16, D=64)
]

for shape in test_shapes:
    D = shape[-1]
    x = torch.randn(*shape, device='cuda')
    weight = torch.randn(D, device='cuda')
    
    output_pt = weighted_sum_pytorch(x, weight)
    output_tr = weighted_sum_triton(x, weight)
    
    matches = torch.allclose(output_pt, output_tr, rtol=1e-4, atol=1e-4)
    max_diff = (output_pt - output_tr).abs().max().item()
    
    status = "✓" if matches else "✗"
    print(f"{status} Shape {shape}: max diff = {max_diff:.2e}")


✓ Shape (16, 32): max diff = 4.77e-07
✓ Shape (128, 256): max diff = 5.72e-06
✓ Shape (1024, 512): max diff = 1.14e-05
✓ Shape (8, 16, 64): max diff = 2.86e-06


## Step 7: PyTorch-Only Custom Autograd Function

Now let's implement the same weighted sum operation using **only PyTorch operations** with a custom `autograd.Function`. This demonstrates:

1. How to define custom forward and backward passes without Triton
2. The mathematical derivations from the assignment (Equation 2)
3. How PyTorch's autograd system works under the hood

### Mathematical Background

Given operation `f(x, w) = (w * x).sum(axis=-1)` where:
- `x` has shape `[n, D]` (or `[..., D]` for batched)
- `w` has shape `[D]`
- Output has shape `[n]` (or `[...]` for batched)

**Backward pass gradients** (from Equation 2 in the assignment):

1. **Gradient w.r.t. x**: `(∇_x L)_ij = w_j · (∇_f L)_i`
   - This is an outer product: `grad_x = grad_output[:, None] * weight[None, :]`

2. **Gradient w.r.t. w**: `(∇_w L)_j = Σ_i x_ij · (∇_f L)_i`
   - This is a reduction: `grad_weight = (x * grad_output[:, None]).sum(dim=0)`
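
Both formulas follow from the chain rule applied to the forward definition. The derivation below is a standard sketch written for this notebook rather than a quotation from the assignment; $\delta_{ki}$ denotes the Kronecker delta.

$$
\begin{aligned}
f_i &= \sum_{j} w_j \, x_{ij}, \\
\frac{\partial f_k}{\partial x_{ij}} &= \delta_{ki} \, w_j, \qquad \frac{\partial f_i}{\partial w_j} = x_{ij}, \\
(\nabla_x L)_{ij} &= \sum_{k} (\nabla_f L)_k \, \frac{\partial f_k}{\partial x_{ij}} = w_j \, (\nabla_f L)_i, \\
(\nabla_w L)_j &= \sum_{i} (\nabla_f L)_i \, \frac{\partial f_i}{\partial w_j} = \sum_{i} x_{ij} \, (\nabla_f L)_i .
\end{aligned}
$$

Each `x_ij` affects only output `f_i`, which is why `grad_x` needs no reduction, while each `w_j` affects every output and therefore sums over `i`.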


In [9]:
class WeightedSumPyTorchFunc(torch.autograd.Function):
    """
    Custom autograd function using pure PyTorch operations.
    
    This implementation manually defines forward and backward passes,
    showing how autograd works without custom kernels.
    """
    
    @staticmethod
    def forward(ctx, x, weight):
        """
        Forward pass: compute (weight * x).sum(axis=-1)
        
        Args:
            x: Input tensor of shape [..., D]
            weight: Weight vector of shape [D]
            
        Returns:
            output: Tensor of shape [...] (last dimension summed out)
        """
        # Save tensors for backward pass
        ctx.save_for_backward(x, weight)
        
        # Compute weighted sum using PyTorch operations
        # Broadcasting: weight has shape [D] and broadcasts against the last dimension of x
        output = (weight * x).sum(dim=-1)
        
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: compute gradients w.r.t. inputs
        
        Args:
            grad_output: Gradient w.r.t. output, shape [...]
            
        Returns:
            grad_x: Gradient w.r.t. x, shape [..., D]
            grad_weight: Gradient w.r.t. weight, shape [D]
        """
        # Retrieve saved tensors
        x, weight = ctx.saved_tensors
        
        # Add a trailing dimension so grad_output broadcasts against D:
        # grad_output has shape [...], we need [..., 1]
        grad_output_expanded = grad_output.unsqueeze(-1)  # [..., 1]
        
        # Gradient w.r.t. x: outer product of grad_output and weight
        # Formula: (∇_x L)_ij = w_j · (∇_f L)_i
        # grad_output_expanded: [..., 1], weight: [D]
        # Result: [..., D]
        grad_x = grad_output_expanded * weight
        
        # Gradient w.r.t. weight: weighted sum of x by grad_output
        # Formula: (∇_w L)_j = Σ_i x_ij · (∇_f L)_i
        # We need to sum over all dimensions except the last (D)
        # x: [..., D], grad_output_expanded: [..., 1]
        # First multiply, then sum over all batch dimensions
        grad_weight = (x * grad_output_expanded).sum(dim=tuple(range(grad_output.ndim)))
        
        return grad_x, grad_weight


# Create the function
weighted_sum_pytorch_custom = WeightedSumPyTorchFunc.apply

print("✓ PyTorch custom autograd function created!")


✓ PyTorch custom autograd function created!


### Test 1: Forward Pass Correctness

Let's verify that our custom PyTorch autograd function produces the same results as the reference implementation.


In [10]:
torch.manual_seed(42)
n_rows, D = 32, 64
x = torch.randn(n_rows, D, device='cuda', requires_grad=True)
weight = torch.randn(D, device='cuda', requires_grad=True)

# Reference implementation
output_ref = weighted_sum_pytorch(x, weight)

# Custom PyTorch autograd implementation
output_custom = weighted_sum_pytorch_custom(x, weight)

# Compare
print(f"Reference output shape: {output_ref.shape}")
print(f"Custom output shape: {output_custom.shape}")
print(f"\nMax absolute difference: {(output_ref - output_custom).abs().max().item():.2e}")
print(f"Mean absolute difference: {(output_ref - output_custom).abs().mean().item():.2e}")
print(f"\n✓ Outputs match: {torch.allclose(output_ref, output_custom, rtol=1e-5, atol=1e-8)}")

# Show that it has the correct grad_fn
print(f"\nGrad function: {output_custom.grad_fn}")


Reference output shape: torch.Size([32])
Custom output shape: torch.Size([32])

Max absolute difference: 0.00e+00
Mean absolute difference: 0.00e+00

✓ Outputs match: True

Grad function: <torch.autograd.function.WeightedSumPyTorchFuncBackward object at 0x784c1445e470>


### Test 2: Backward Pass Correctness

Now let's verify that the gradients computed by our custom backward pass match PyTorch's automatic differentiation.


In [11]:
torch.manual_seed(42)
n_rows, D = 32, 64

# Create separate tensors for reference and custom implementations
x_ref = torch.randn(n_rows, D, device='cuda', requires_grad=True, dtype=torch.float64)
weight_ref = torch.randn(D, device='cuda', requires_grad=True, dtype=torch.float64)

x_custom = x_ref.clone().detach().requires_grad_(True)
weight_custom = weight_ref.clone().detach().requires_grad_(True)

# Forward pass
output_ref = weighted_sum_pytorch(x_ref, weight_ref)
output_custom = weighted_sum_pytorch_custom(x_custom, weight_custom)

# Backward pass with same gradient
grad_output = torch.randn_like(output_ref)
output_ref.backward(grad_output)
output_custom.backward(grad_output)

# Compare gradients
print("Gradient w.r.t. x:")
print(f"  Max absolute difference: {(x_ref.grad - x_custom.grad).abs().max().item():.2e}")
print(f"  Mean absolute difference: {(x_ref.grad - x_custom.grad).abs().mean().item():.2e}")
print(f"  ✓ Gradients match: {torch.allclose(x_ref.grad, x_custom.grad, rtol=1e-5, atol=1e-8)}")

print("\nGradient w.r.t. weight:")
print(f"  Max absolute difference: {(weight_ref.grad - weight_custom.grad).abs().max().item():.2e}")
print(f"  Mean absolute difference: {(weight_ref.grad - weight_custom.grad).abs().mean().item():.2e}")
print(f"  ✓ Gradients match: {torch.allclose(weight_ref.grad, weight_custom.grad, rtol=1e-5, atol=1e-8)}")


Gradient w.r.t. x:
  Max absolute difference: 0.00e+00
  Mean absolute difference: 0.00e+00
  ✓ Gradients match: True

Gradient w.r.t. weight:
  Max absolute difference: 0.00e+00
  Mean absolute difference: 0.00e+00
  ✓ Gradients match: True


### Test 3: Gradient Checking with torch.autograd.gradcheck

PyTorch provides `gradcheck` to numerically verify that our backward pass is correct by comparing it to finite differences.


In [12]:
from torch.autograd import gradcheck

# Use small inputs and double precision for numerical stability
torch.manual_seed(42)
x_test = torch.randn(4, 8, device='cuda', dtype=torch.float64, requires_grad=True)
weight_test = torch.randn(8, device='cuda', dtype=torch.float64, requires_grad=True)

# gradcheck takes a function and inputs, then numerically checks gradients
test_passed = gradcheck(
    weighted_sum_pytorch_custom, 
    (x_test, weight_test),
    eps=1e-6,
    atol=1e-4,
    rtol=1e-3,
    raise_exception=False
)

print(f"✓ Gradient check passed: {test_passed}")

if test_passed:
    print("\n✅ Our custom backward implementation is mathematically correct!")
    print("   The gradients match numerical finite difference approximations.")
else:
    print("\n❌ Gradient check failed - there may be an error in the backward pass.")


✓ Gradient check passed: True

✅ Our custom backward implementation is mathematically correct!
   The gradients match numerical finite difference approximations.
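
The idea behind a finite-difference check can be sketched in plain Python. This is not how `gradcheck` is implemented internally; it is a simplified central-difference check for the weight gradient only, with made-up inputs and hypothetical helper names.

```python
def weighted_sum(x, w):
    """f_i = sum_j w_j * x_ij for a 2D list-of-lists x."""
    return [sum(wj * xij for wj, xij in zip(w, row)) for row in x]

def loss(x, w, grad_output):
    """Scalar L = sum_i grad_output_i * f_i, so dL/df_i equals grad_output_i."""
    return sum(g * f for g, f in zip(grad_output, weighted_sum(x, w)))

def numerical_grad_weight(x, w, grad_output, eps=1e-6):
    """Central differences: (L(w + eps*e_j) - L(w - eps*e_j)) / (2*eps)."""
    grads = []
    for j in range(len(w)):
        w_plus = w[:j] + [w[j] + eps] + w[j + 1:]
        w_minus = w[:j] + [w[j] - eps] + w[j + 1:]
        grads.append((loss(x, w_plus, grad_output) - loss(x, w_minus, grad_output)) / (2 * eps))
    return grads

x = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
w = [0.1, 0.2, 0.3]
grad_output = [1.0, -2.0]

# Analytic formula: grad_weight[j] = sum_i x[i][j] * grad_output[i]
analytic = [sum(x[i][j] * grad_output[i] for i in range(len(x))) for j in range(len(w))]
numeric = numerical_grad_weight(x, w, grad_output)

print(analytic)  # [-2.5, -1.5, 3.5]
print(all(abs(a - n) < 1e-6 for a, n in zip(analytic, numeric)))  # True
```

Double precision matters here because the difference of two nearly equal loss values is divided by a very small `eps`.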


### Test 4: Batched Inputs

Test with various shapes including batched inputs to ensure our implementation handles arbitrary dimensions correctly.


In [13]:
test_shapes = [
    (16, 32),           # 2D: Simple case
    (128, 256),         # 2D: Larger
    (8, 16, 64),        # 3D: Batched (batch_size=8, seq_len=16, D=64)
    (4, 8, 16, 32),     # 4D: Multi-batch (batch=4, heads=8, seq=16, D=32)
]

print("Testing custom PyTorch autograd function with various shapes:\n")

for shape in test_shapes:
    D = shape[-1]
    x = torch.randn(*shape, device='cuda', requires_grad=True)
    weight = torch.randn(D, device='cuda', requires_grad=True)
    
    # Forward pass
    output_ref = weighted_sum_pytorch(x, weight)
    output_custom = weighted_sum_pytorch_custom(x, weight)
    
    # Check forward pass
    forward_matches = torch.allclose(output_ref, output_custom, rtol=1e-5, atol=1e-8)
    max_diff = (output_ref - output_custom).abs().max().item()
    
    # Backward pass
    grad_output = torch.randn_like(output_ref)
    
    x_ref = x.clone().detach().requires_grad_(True)
    weight_ref = weight.clone().detach().requires_grad_(True)
    output_ref = weighted_sum_pytorch(x_ref, weight_ref)
    output_ref.backward(grad_output)
    
    x_custom = x.clone().detach().requires_grad_(True)
    weight_custom = weight.clone().detach().requires_grad_(True)
    output_custom = weighted_sum_pytorch_custom(x_custom, weight_custom)
    output_custom.backward(grad_output)
    
    # Check backward pass
    grad_x_matches = torch.allclose(x_ref.grad, x_custom.grad, rtol=1e-5, atol=1e-8)
    grad_w_matches = torch.allclose(weight_ref.grad, weight_custom.grad, rtol=1e-5, atol=1e-8)
    
    status = "✓" if (forward_matches and grad_x_matches and grad_w_matches) else "✗"
    print(f"{status} Shape {shape}:")
    print(f"   Forward max diff: {max_diff:.2e}")
    print(f"   Backward (x): {grad_x_matches}, Backward (w): {grad_w_matches}")


Testing custom PyTorch autograd function with various shapes:

✓ Shape (16, 32):
   Forward max diff: 0.00e+00
   Backward (x): True, Backward (w): True
✓ Shape (128, 256):
   Forward max diff: 0.00e+00
   Backward (x): True, Backward (w): True
✓ Shape (8, 16, 64):
   Forward max diff: 0.00e+00
   Backward (x): True, Backward (w): True
✓ Shape (4, 8, 16, 32):
   Forward max diff: 0.00e+00
   Backward (x): True, Backward (w): True
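
Summing over every leading dimension, as the custom backward does, gives the same weight gradient as first flattening `[..., D]` to `[n, D]` (as the Triton wrapper does) and then summing over rows. A small plain-Python sketch with a hypothetical `[2, 2, 3]` input:

```python
from math import prod

shape = (2, 2, 3)                  # [batch, seq, D]
n_rows = prod(shape[:-1])          # number of rows after flattening to [n, D]

x = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
     [[7.0, 8.0, 9.0], [1.0, 0.0, -1.0]]]
grad_output = [[1.0, 2.0], [0.5, -1.0]]   # shape [batch, seq]

# Option A: sum directly over both leading dimensions.
nested = [
    sum(x[b][s][j] * grad_output[b][s] for b in range(shape[0]) for s in range(shape[1]))
    for j in range(shape[-1])
]

# Option B: flatten the leading dimensions first, then sum over rows.
x_flat = [row for batch in x for row in batch]
g_flat = [g for batch in grad_output for g in batch]
flattened = [sum(x_flat[i][j] * g_flat[i] for i in range(n_rows)) for j in range(shape[-1])]

print(n_rows, nested, flattened)  # 4 [11.5, 16.0, 20.5] [11.5, 16.0, 20.5]
```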


### Test 5: Visualizing the Gradient Flow

Let's visualize how gradients flow through our custom operation in a simple computational graph.


In [14]:
torch.manual_seed(42)

# Create a simple example
x = torch.randn(3, 4, device='cuda', requires_grad=True)
weight = torch.randn(4, device='cuda', requires_grad=True)

print("Input x:")
print(x)
print(f"\nWeight w:")
print(weight)

# Forward pass
output = weighted_sum_pytorch_custom(x, weight)
print(f"\nOutput (weighted sum):")
print(output)
print(f"Output shape: {output.shape}")

# Create a simple loss (sum of outputs)
loss = output.sum()
print(f"\nLoss (sum of outputs): {loss.item():.4f}")

# Backward pass
loss.backward()

print(f"\nGradient w.r.t. x:")
print(x.grad)
print(f"Shape: {x.grad.shape}")

print(f"\nGradient w.r.t. weight:")
print(weight.grad)
print(f"Shape: {weight.grad.shape}")

print("\n" + "="*60)
print("Understanding the gradients:")
print("="*60)
print("\n1. grad_x = grad_output[:, None] * weight[None, :]")
print("   - Each element x[i,j] contributes to output[i] via weight[j]")
print("   - So grad_x[i,j] = grad_output[i] * weight[j]")
print("\n2. grad_weight = (x * grad_output[:, None]).sum(dim=0)")
print("   - Each weight[j] affects all outputs through x[:, j]")
print("   - So grad_weight[j] = sum_i(x[i,j] * grad_output[i])")


Input x:
tensor([[ 0.1940,  2.1614, -0.1721,  0.8491],
        [-1.9244,  0.6530, -0.6494, -0.8175],
        [ 0.5280, -1.2753, -1.6621, -0.3033]], device='cuda:0',
       requires_grad=True)

Weight w:
tensor([ 0.1391, -0.1082, -0.7174,  0.7566], device='cuda:0',
       requires_grad=True)

Output (weighted sum):
tensor([ 0.5590, -0.4911,  1.1744], device='cuda:0',
       grad_fn=<WeightedSumPyTorchFuncBackward>)
Output shape: torch.Size([3])

Loss (sum of outputs): 1.2423

Gradient w.r.t. x:
tensor([[ 0.1391, -0.1082, -0.7174,  0.7566],
        [ 0.1391, -0.1082, -0.7174,  0.7566],
        [ 0.1391, -0.1082, -0.7174,  0.7566]], device='cuda:0')
Shape: torch.Size([3, 4])

Gradient w.r.t. weight:
tensor([-1.2024,  1.5390, -2.4836, -0.2718], device='cuda:0')
Shape: torch.Size([4])

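============================================================
Understanding the gradients:
============================================================

1. grad_x = grad_output[:, None] * weight[None, :]
   - Each element x[i,j] contributes to output[i] via weight[j]
   - So grad_x[i,j] = grad_output[i] * weight[j]

2. grad_weight = (x * grad_output[:, None]).sum(dim=0)
   - Each weight[j] affects all outputs through x[:, j]
   - So grad_weight[j] = sum_i(x[i,j] * grad_output[i])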

## Comparison: PyTorch Custom vs Triton vs Reference

Let's compare all three implementations side by side.


In [15]:
torch.manual_seed(42)
n_rows, D = 128, 256

x = torch.randn(n_rows, D, device='cuda', dtype=torch.float32)
weight = torch.randn(D, device='cuda', dtype=torch.float32)

# Test all three implementations
output_ref = weighted_sum_pytorch(x, weight)
output_custom = weighted_sum_pytorch_custom(x, weight)
output_triton = weighted_sum_triton(x, weight)

print("Forward Pass Comparison:")
print("="*60)
print(f"Reference output shape: {output_ref.shape}")
print(f"Custom PyTorch output shape: {output_custom.shape}")
print(f"Triton output shape: {output_triton.shape}")

print(f"\nCustom vs Reference - Max diff: {(output_custom - output_ref).abs().max().item():.2e}")
print(f"Triton vs Reference - Max diff: {(output_triton - output_ref).abs().max().item():.2e}")
print(f"Custom vs Triton - Max diff: {(output_custom - output_triton).abs().max().item():.2e}")

print("\n" + "="*60)
print("Backward Pass Comparison:")
print("="*60)

# Test gradients
x_ref = x.clone().detach().requires_grad_(True)
weight_ref = weight.clone().detach().requires_grad_(True)
x_custom = x.clone().detach().requires_grad_(True)
weight_custom = weight.clone().detach().requires_grad_(True)
x_triton = x.clone().detach().requires_grad_(True)
weight_triton = weight.clone().detach().requires_grad_(True)

output_ref = weighted_sum_pytorch(x_ref, weight_ref)
output_custom = weighted_sum_pytorch_custom(x_custom, weight_custom)
output_triton = weighted_sum_triton(x_triton, weight_triton)

grad_output = torch.randn_like(output_ref)
output_ref.backward(grad_output)
output_custom.backward(grad_output)
output_triton.backward(grad_output)

print(f"\nGradient w.r.t. x:")
print(f"  Custom vs Reference - Max diff: {(x_custom.grad - x_ref.grad).abs().max().item():.2e}")
print(f"  Triton vs Reference - Max diff: {(x_triton.grad - x_ref.grad).abs().max().item():.2e}")

print(f"\nGradient w.r.t. weight:")
print(f"  Custom vs Reference - Max diff: {(weight_custom.grad - weight_ref.grad).abs().max().item():.2e}")
print(f"  Triton vs Reference - Max diff: {(weight_triton.grad - weight_ref.grad).abs().max().item():.2e}")

print("\n✅ All implementations produce consistent results!")


Forward Pass Comparison:
============================================================
Reference output shape: torch.Size([128])
Custom PyTorch output shape: torch.Size([128])
Triton output shape: torch.Size([128])

Custom vs Reference - Max diff: 0.00e+00
Triton vs Reference - Max diff: 5.72e-06
Custom vs Triton - Max diff: 5.72e-06

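============================================================
Backward Pass Comparison:
============================================================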

Gradient w.r.t. x:
  Custom vs Reference - Max diff: 0.00e+00
  Triton vs Reference - Max diff: 0.00e+00

Gradient w.r.t. weight:
  Custom vs Reference - Max diff: 0.00e+00
  Triton vs Reference - Max diff: 3.81e-06

✅ All implementations produce consistent results!
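
The nonzero Triton-vs-reference differences above are on the order of float32 rounding error. That is what one would expect if the two implementations add the same terms in a different order, although this notebook does not inspect PyTorch's reduction order, so treat that as an assumption. The plain-Python sketch below (float64, hypothetical helper names) shows that summation order can change the result:

```python
import math
import random

def sequential_sum(values):
    """Plain left-to-right accumulation."""
    acc = 0.0
    for v in values:
        acc += v
    return acc

def tiled_sum(values, tile):
    """Sum each tile separately, then add the per-tile partial sums."""
    acc = 0.0
    for start in range(0, len(values), tile):
        acc += sequential_sum(values[start:start + tile])
    return acc

# Floating-point addition is not associative.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False

random.seed(0)
values = [random.uniform(-1.0, 1.0) for _ in range(4096)]
exact = math.fsum(values)  # correctly rounded reference sum

# Both orderings land close to the reference, but need not agree bit for bit.
print(abs(sequential_sum(values) - exact), abs(tiled_sum(values, 128) - exact))
```

This is why the tests in this notebook compare with `torch.allclose` and a tolerance rather than exact equality.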


## Key Takeaways

### PyTorch Custom Autograd Functions

**When to use PyTorch-only custom autograd:**
- You need custom gradient behavior (e.g., clipping or rescaling gradients inside `backward`)
- You want to understand how autograd works internally
- You're implementing operations that compose well with existing PyTorch ops
- You don't need low-level performance optimization

**Key concepts:**
1. **`torch.autograd.Function`**: Base class for custom differentiable operations
2. **`forward(ctx, *args)`**: Compute output, save tensors needed for backward
3. **`backward(ctx, grad_output)`**: Compute gradients w.r.t. inputs using chain rule
4. **`ctx.save_for_backward()`**: Efficiently save tensors for backward pass
5. **Broadcasting**: Use `unsqueeze()` and broadcasting for gradient computation

**Mathematical formulas (from assignment Equation 2):**
- Forward: `f(x, w) = (w * x).sum(axis=-1)`
- Backward for x: `(∇_x L)_ij = w_j · (∇_f L)_i` (outer product)
- Backward for w: `(∇_w L)_j = Σ_i x_ij · (∇_f L)_i` (reduction)

---

### Triton Custom Kernels

**When to use Triton:**
- You need high performance for custom operations
- Memory access patterns matter (e.g., FlashAttention)
- You want GPU-level optimization without writing CUDA
- Standard PyTorch ops don't fuse well for your use case

**Key concepts:**
1. **Block Pointers**: `tl.make_block_ptr()` simplifies memory access
2. **Tiling Strategy**: Process data in tiles for memory locality and parallelism
3. **Program IDs**: `tl.program_id(0)` tells each program instance (thread block) which tile of rows it owns
4. **Boundary Checking**: Handle edge cases when tiles don't evenly divide inputs
5. **Explicit Memory Movement**: Every load and store is written out (`tl.load`, `tl.store`), and block pointers are advanced by hand with `tl.advance`
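
To make the block-pointer idea concrete, the sketch below computes by hand which flat memory offsets a `(ROWS_TILE_SIZE, D_TILE_SIZE)` block would touch in a contiguous row-major tensor, and which slots a boundary check would mask out. This is plain Python for illustration only; `block_offsets` is a hypothetical helper, not a Triton API.

```python
def block_offsets(shape, strides, offsets, block_shape):
    """Flat element offsets covered by a 2D block; None marks out-of-bounds slots."""
    rows, cols = shape
    out = []
    for r in range(offsets[0], offsets[0] + block_shape[0]):
        out.append([
            r * strides[0] + c * strides[1] if (r < rows and c < cols) else None
            for c in range(offsets[1], offsets[1] + block_shape[1])
        ])
    return out

# A contiguous [3, 5] tensor has strides (5, 1): stepping one row skips 5 elements.
shape, strides = (3, 5), (5, 1)

# First block: fully in bounds.
print(block_offsets(shape, strides, offsets=(0, 0), block_shape=(2, 2)))
# [[0, 1], [5, 6]]

# Last row tile and last D tile: only one slot is inside the tensor.
print(block_offsets(shape, strides, offsets=(2, 4), block_shape=(2, 2)))
# [[14, None], [None, None]]
```

In the kernels above, the masked slots are the ones that `boundary_check` protects: loads fill them with zeros and stores skip them.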

---

### Comparison Summary

| Aspect | PyTorch Custom | Triton Kernel |
|--------|---------------|---------------|
| **Complexity** | Low - just Python | Medium - need GPU concepts |
| **Performance** | Good - built from optimized PyTorch ops | Potentially higher - fusion and fine-grained control over tiling (not benchmarked in this notebook) |
| **Use Case** | Custom gradients, simple ops | Performance-critical kernels |
| **Debugging** | Easy - standard Python | Harder - GPU-specific issues |
| **Portability** | Works on CPU/GPU | GPU-only |

---

### Backward Pass Strategies

**PyTorch approach:**
- Use `.sum(dim=...)` to reduce over batch dimensions
- PyTorch handles the reduction efficiently

**Triton approach:**
1. Compute partial results per tile
2. Store in a buffer
3. Reduce outside the kernel with PyTorch

---

## Experiment Further!

Try modifying:
- **The operation**: Add a bias term, use different aggregation functions (max, mean)
- **Custom gradients**: Implement gradient clipping or custom scaling
- **Input shapes**: Test with different batch sizes and dimensions
- **Data types**: Try `float16`, `bfloat16`, or `float64`
- **Composition**: Chain multiple custom autograd functions together
- **Numerical stability**: Add epsilon terms or use stable formulations

This pattern extends to more complex operations:
- **Attention mechanisms**: Custom attention with different scoring functions
- **Layer normalization**: With learnable affine parameters
- **Custom activations**: GELU, Swish with custom backward passes
- **Sparse operations**: Custom sparse matrix operations
- **Quantization**: Custom quantization/dequantization ops

---

### Further Reading
- [PyTorch Autograd Mechanics](https://pytorch.org/docs/stable/notes/autograd.html)
- [Extending PyTorch](https://pytorch.org/docs/stable/notes/extending.html)
- [Triton Documentation](https://triton-lang.org/)
- [CS336 Assignment 2](https://stanford-cs336.github.io/spring2025/assignments/assignment2.html)
