# Kernel Fusion Development Example

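This notebook checks the development environment for writing fused GPU kernels. It reports library versions, runs a basic GPU operation, implements a simple element-wise Triton kernel, and benchmarks that kernel against PyTorch.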

In [None]:
import torch
import triton
import numpy as np
import matplotlib.pyplot as plt

# Record the toolchain versions so later timings and results can be tied to them.
print(f"PyTorch version: {torch.__version__}")
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
if cuda_available:
    # Device details are only queried when a GPU is usable.
    device_name = torch.cuda.get_device_name()
    print(f"CUDA device: {device_name}")
    print(f"CUDA version: {torch.version.cuda}")
print(f"Triton version: {triton.__version__}")

## Test Basic GPU Operations

In [None]:
# Smoke-test the selected device: allocate two random matrices and multiply them.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Two 1000 x 1000 operands, drawn in the same order (x first, then y).
x, y = (torch.randn(1000, 1000, device=device) for _ in range(2))

# `@` dispatches to torch.matmul.
z = x @ y
print(f"Matrix multiplication result shape: {z.shape}")
print(f"Result tensor device: {z.device}")

## Example: Simple Triton Kernel

In [None]:
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """
    Element-wise sum over flat offsets: output[i] = x[i] + y[i] for i < n_elements.

    Each program instance covers one contiguous run of BLOCK_SIZE offsets,
    starting at program_id * BLOCK_SIZE. Offsets at or beyond n_elements are
    masked out of both loads and the store.
    """
    # Flat offsets owned by this program instance.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    tl.store(output_ptr + idx, lhs + rhs, mask=in_bounds)


def add_triton(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    """
    Wrapper function for Triton addition kernel: returns ``x + y`` computed by ``add_kernel``.

    Args:
        x: CUDA input tensor.
        y: CUDA input tensor with the same shape as ``x``, on the same device.
        block_size: number of elements handled per Triton program instance.
            Must be a positive power of 2 (``tl.arange`` requires it).
            Defaults to 1024, the previously hard-coded value.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: if either input is not a CUDA tensor, the inputs are on
            different devices or have different shapes, or ``block_size`` is not
            a positive power of 2.
    """
    # Validate before allocating anything.
    assert x.is_cuda and y.is_cuda, "add_triton requires CUDA tensors"
    assert x.device == y.device, f"inputs on different devices: {x.device} vs {y.device}"
    # The kernel reads y using x's element count; a smaller y would be read out of bounds.
    assert x.shape == y.shape, f"shape mismatch: {tuple(x.shape)} vs {tuple(y.shape)}"
    assert block_size > 0 and (block_size & (block_size - 1)) == 0, (
        f"block_size must be a positive power of 2, got {block_size}"
    )

    # The kernel indexes memory with flat offsets, so it needs contiguous layouts.
    # .contiguous() returns the tensor itself when it is already contiguous.
    x = x.contiguous()
    y = y.contiguous()
    output = torch.empty_like(x)
    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to compute; skip launching an empty grid.
        return output

    grid = (triton.cdiv(n_elements, block_size),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size)
    return output

# Correctness check: compare the Triton kernel against PyTorch's reference add.
if torch.cuda.is_available():
    torch.manual_seed(0)  # reproducible inputs across Restart & Run All
    # 1000 elements fit in a single 1024-wide block (partially masked).
    # 98432 spans many blocks and is not a multiple of 1024, so the tail mask is exercised too.
    for n_elems in (1000, 98432):
        a = torch.randn(n_elems, device='cuda')
        b = torch.randn(n_elems, device='cuda')

        result_triton = add_triton(a, b)
        result_torch = a + b

        print(f"Size {n_elems}: max difference: {torch.max(torch.abs(result_triton - result_torch)).item()}")
        # Fail loudly rather than printing, so Run All stops on a wrong kernel.
        torch.testing.assert_close(result_triton, result_torch)
    print("Triton kernel test passed!")
else:
    print("CUDA not available, skipping Triton kernel test")

## Benchmarking Setup

In [None]:
import time

def benchmark_function(func, *args, warmup=10, repeat=100, **kwargs):
    """
    Return the mean time per call of ``func(*args, **kwargs)`` in milliseconds.

    Args:
        func: callable to time.
        *args: positional arguments forwarded to ``func`` on every call.
        warmup: number of untimed calls made first (e.g. so one-time JIT
            compilation is excluded from the measurement).
        repeat: number of timed calls; must be >= 1.
        **kwargs: keyword arguments forwarded to ``func`` on every call.

    Returns:
        Average time per call, in milliseconds, as a float.

    Raises:
        ValueError: if ``repeat`` is less than 1.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    use_cuda = torch.cuda.is_available()

    for _ in range(warmup):
        func(*args, **kwargs)

    # CUDA work is queued asynchronously. Drain it so warm-up work is not counted
    # in the timed region, and so the timed region includes its own GPU work.
    if use_cuda:
        torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(repeat):
        func(*args, **kwargs)
    if use_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    return elapsed / repeat * 1000.0  # seconds -> milliseconds

# Example benchmark: torch.add vs. the Triton kernel across tensor sizes.
if torch.cuda.is_available():
    # NOTE(review): at these sizes (<= 8000 elements) both timings are likely
    # dominated by launch/Python overhead rather than memory bandwidth.
    # Add larger sizes (e.g. 2**20 and up) to compare throughput. TODO confirm.
    sizes = [1000, 2000, 4000, 8000]
    torch_times = []
    triton_times = []

    for size in sizes:
        a = torch.randn(size, device='cuda')
        b = torch.randn(size, device='cuda')

        torch_time = benchmark_function(torch.add, a, b)
        triton_time = benchmark_function(add_triton, a, b)

        torch_times.append(torch_time)
        triton_times.append(triton_time)

        print(f"Size {size}: PyTorch {torch_time:.3f}ms, Triton {triton_time:.3f}ms")

    # Use the explicit Figure/Axes API instead of the pyplot state machine.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(sizes, torch_times, 'b-o', label='PyTorch')
    ax.plot(sizes, triton_times, 'r-s', label='Triton')
    # Sizes double at each step, so a log2 axis spaces the points evenly.
    ax.set_xscale('log', base=2)
    ax.set_xlabel('Tensor size (elements)')
    ax.set_ylabel('Time per call (ms)')
    ax.set_title('Performance Comparison: PyTorch vs Triton')
    ax.legend()
    ax.grid(True)
    plt.show()
else:
    print("CUDA not available, skipping benchmark")