# PyTorch Tutorial: Performance Engineering (Triton & Profiling)

At FAANG-scale companies, making a model 10% faster can save millions of dollars in compute. This chapter moves beyond "making it work" to "making it fast".

## Learning Objectives
- **Profile** your code to find bottlenecks.
- Use **`torch.compile`** for free speedups.
- Write a custom GPU kernel using **Triton**.

## 1. Vocabulary First

- **Latency**: Time per request. How long a single inference takes.
- **Throughput**: Requests per second. How many inferences you complete per unit of time, usually raised by batching or running work in parallel.
- **Kernel**: A function that runs on the GPU. Every PyTorch operation (matmul, relu, add) launches one or more GPU kernels.
- **Fusion**: Combining multiple operations (Add + Multiply) into one kernel to save memory bandwidth.
- **Triton**: A language from OpenAI to write GPU kernels in Python (instead of CUDA C++).

### Compute-Bound vs Memory-Bound (The Most Important Concept)

To a first approximation, every GPU operation is bottlenecked by one of two things:

**Compute-bound**: The GPU's math units (FLOPS) are the bottleneck.
- Large matrix multiplications (e.g., `torch.matmul` on big tensors)
- Convolutions with many channels
- The GPU cores are doing math as fast as they can

**Memory-bound**: Moving data between GPU memory (HBM) and compute cores is the bottleneck.
- Element-wise operations (ReLU, Add) and normalizations (LayerNorm, softmax)
- Small matrix operations
- The compute cores are **idle**, waiting for data to arrive

**Why this matters**: Outside of large matmuls and convolutions, most PyTorch operations are memory-bound, because the GPU can compute far faster than it can read data from HBM. This is why **operator fusion** is so powerful: it eliminates unnecessary reads and writes.
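The difference between the two regimes can be estimated with simple arithmetic. The sketch below counts floating-point operations and bytes moved for a square matmul and for an element-wise add. The helper names are hypothetical, and the byte counts assume each tensor is read or written exactly once (ideal caching), which real kernels only approximate.

```python
def matmul_intensity(n: int, bytes_per_element: int = 4) -> float:
    """FLOPs per byte for an (n x n) @ (n x n) matmul, assuming ideal caching."""
    flops = 2 * n ** 3                           # one multiply and one add per inner-product term
    bytes_moved = 3 * n * n * bytes_per_element  # read A, read B, write C
    return flops / bytes_moved


def add_intensity(bytes_per_element: int = 4) -> float:
    """FLOPs per byte for the element-wise op z = x + y."""
    flops = 1                                    # one add per output element
    bytes_moved = 3 * bytes_per_element          # read x, read y, write z
    return flops / bytes_moved


print(f"matmul, n=4096:   {matmul_intensity(4096):.1f} FLOPs/byte")
print(f"element-wise add: {add_intensity():.3f} FLOPs/byte")
```

The matmul performs hundreds of operations per byte moved, while the add performs a fraction of one, which is why the first tends to be compute-bound and the second memory-bound.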

### The Roofline Model (How to Think About Performance)

```
Attainable FLOP/s
    │
    │              ┌─────────────── Compute ceiling (peak FLOP/s)
    │             ╱
    │            ╱
    │           ╱   Slope = memory bandwidth
    │          ╱
    │         ╱
    │        ╱
    │       ╱
    │   Memory-bound │ Compute-bound
    └────────────────────────────── Arithmetic Intensity (FLOPs/byte)
```

- **Left side**: Operations with low arithmetic intensity are memory-bound
- **Right side**: Operations with high arithmetic intensity are compute-bound
- **Goal**: Move operations to the right (higher arithmetic intensity) via fusion
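The roofline reduces to a single `min`: attainable performance is capped either by peak compute or by memory bandwidth times arithmetic intensity. The sketch below uses made-up hardware numbers (100 TFLOP/s and 2 TB/s are assumptions, not the specs of any particular GPU), and the function name is hypothetical.

```python
def attainable_flops(intensity: float, peak_flops: float, peak_bandwidth: float) -> float:
    """Roofline model: the lower of the compute ceiling and the bandwidth-limited rate."""
    return min(peak_flops, intensity * peak_bandwidth)


# Hypothetical accelerator, chosen only for illustration
PEAK_FLOPS = 100e12      # 100 TFLOP/s
PEAK_BANDWIDTH = 2e12    # 2 TB/s

# Intensity at which the two limits meet (the "ridge point")
ridge = PEAK_FLOPS / PEAK_BANDWIDTH

for name, intensity in [("element-wise add", 1 / 12), ("large matmul", 680.0)]:
    perf = attainable_flops(intensity, PEAK_FLOPS, PEAK_BANDWIDTH)
    regime = "memory-bound" if intensity < ridge else "compute-bound"
    print(f"{name}: {perf / 1e12:.2f} TFLOP/s attainable ({regime})")
```

Under these assumed numbers, any operation below 50 FLOPs/byte cannot reach the compute ceiling no matter how well its kernel is written.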

### What is Operator Fusion?

Without fusion (3 separate kernel launches, 3 memory round-trips):
```
Read X from HBM → Compute matmul → Write Y to HBM
Read Y from HBM → Compute ReLU  → Write Z to HBM
Read Z from HBM → Compute Add   → Write W to HBM
```

With fusion (1 kernel launch, 1 memory round-trip):
```
Read X from HBM → Compute (matmul + ReLU + Add) → Write W to HBM
```

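`torch.compile` applies this kind of fusion automatically, most reliably to chains of element-wise and reduction operations. Large matmuls are usually still dispatched to tuned library kernels such as cuBLAS, with the surrounding element-wise work fused into fewer kernels.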
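The saving from fusion can be estimated by counting HBM traffic. The sketch below does this for just the element-wise part of the chain, `relu(x) + y`, with hypothetical helper names and the assumption that every tensor is read or written exactly once per kernel.

```python
def hbm_bytes_unfused(n_elements: int, bytes_per_element: int = 4) -> int:
    """Traffic for relu(x) + y run as two separate kernels."""
    relu_kernel = 2 * n_elements * bytes_per_element  # read x, write tmp
    add_kernel = 3 * n_elements * bytes_per_element   # read tmp, read y, write out
    return relu_kernel + add_kernel


def hbm_bytes_fused(n_elements: int, bytes_per_element: int = 4) -> int:
    """Traffic for the same computation in one kernel: tmp stays in registers."""
    return 3 * n_elements * bytes_per_element         # read x, read y, write out


n = 1_000_000
unfused, fused = hbm_bytes_unfused(n), hbm_bytes_fused(n)
print(f"unfused: {unfused / 1e6:.0f} MB, fused: {fused / 1e6:.0f} MB, "
      f"ratio: {unfused / fused:.2f}x")
```

Because these operations are memory-bound, the reduction in bytes moved translates fairly directly into a reduction in run time, and longer chains save proportionally more.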

```python
import torch
import time

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")
```

## 2. Profiling with `torch.profiler`

Stop guessing where your code is slow. Measure it.

### How to Read Profiler Output

The profiler shows you:
- **Self CPU time**: Time spent in the operation itself (not counting sub-operations)
- **CPU total**: Total time including sub-operations
- **CUDA time**: Time spent on GPU (if applicable)
- **# Calls**: How many times the operation was called

**Reading strategy**: Sort by `cpu_time_total` or `cuda_time_total`. The top entries are your bottlenecks. If an operation shows high CPU time but low CUDA time, you may have a CPU bottleneck (data preprocessing, Python overhead). If CUDA time dominates, optimize the GPU operations.

### Common Bottlenecks You'll Find

1. **Data loading**: CPU can't feed data fast enough. Fix: increase `num_workers` in DataLoader.
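2. **CPU-GPU sync**: Operations like `.item()`, `print(tensor)`, or `if tensor > 0` force the CPU to wait until the GPU has finished all queued work. Fix: keep values on the GPU and read them back rarely, for example once every N steps.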
3. **Small kernel launches**: Many tiny GPU operations. Fix: use `torch.compile` for fusion.
4. **Memory copies**: Frequent `.to(device)` calls. Fix: move data to GPU once, keep it there.
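As an illustration of the CPU-GPU sync item, the sketch below accumulates a running total in two ways. The list of random scalars is a stand-in for per-step losses (an assumption for the example, not part of any real training loop).

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-ins for per-step loss values produced on the device
losses = [torch.rand((), device=device) for _ in range(100)]

# Slow pattern: .item() copies to the host and, on GPU, waits for queued work every step
total_slow = 0.0
for loss in losses:
    total_slow += loss.item()

# Faster pattern: accumulate on the device and read the value back once at the end
total_on_device = torch.zeros((), device=device)
for loss in losses:
    total_on_device += loss
total_fast = total_on_device.item()

print(f"per-step sync: {total_slow:.4f}, single sync: {total_fast:.4f}")
```

Both loops compute the same total up to floating-point rounding; the second simply lets the GPU run ahead of the CPU instead of stalling on every iteration.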

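```python
def heavy_computation(x):
    return torch.matmul(x, x) + torch.relu(x)

x = torch.randn(1000, 1000, device=device)

# Always record CPU activity; record GPU kernels too when CUDA is available.
activities = [torch.profiler.ProfilerActivity.CPU]
if device == "cuda":
    activities.append(torch.profiler.ProfilerActivity.CUDA)

with torch.profiler.profile(activities=activities, record_shapes=True) as prof:
    heavy_computation(x)

# Sort by total CPU time, matching the reading strategy above.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```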

## 3. `torch.compile` (PyTorch 2.0)

The easiest way to speed up PyTorch code. It fuses operations automatically.

### How `torch.compile` Works Under the Hood

1. **Graph capture**: TorchDynamo traces your Python code and captures the computation graph (what operations are called and in what order).
2. **Graph optimization**: The compiler identifies fusible operations and eliminates redundant memory accesses.
3. **Code generation**: The default backend, TorchInductor, generates Triton kernels on GPU (C++ on CPU) for the fused operations, and calls existing library kernels where those are faster.
4. **Caching**: The compiled code is cached, so subsequent calls with compatible input shapes and dtypes are fast.

### Compilation Modes

```python
# Default: Good balance of compile time and speed
model = torch.compile(model)

# Maximum optimization: autotunes kernel choices (slowest compile, often fastest execution)
model = torch.compile(model, mode="max-autotune")

# Reduce per-call overhead using CUDA graphs (helps small batches; may use more memory)
model = torch.compile(model, mode="reduce-overhead")
```

### When `torch.compile` Doesn't Help

- **Dynamic shapes**: If tensor sizes change every iteration, the compiler may recompile repeatedly (slow). Pass `dynamic=True` to ask for shape-generic kernels and mitigate this.
- **Data-dependent control flow**: `if x.sum() > 0` breaks the graph because the compiler can't predict the branch at compile time.
- **Custom C extensions**: The compiler cannot see inside custom C/C++ or CUDA extensions, so it only optimizes the PyTorch operations around them.
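For data-dependent control flow, one common workaround is to compute both branches and select the result on the device. The sketch below is illustrative only: the function names are hypothetical, and it trades a little extra compute for a branch-free function that should be easier for `torch.compile` to capture as one graph.

```python
import torch


def branchy(x: torch.Tensor) -> torch.Tensor:
    # A Python `if` on a tensor value: the branch depends on runtime data
    if x.sum() > 0:
        return x * 2
    return x - 1


def branchless(x: torch.Tensor) -> torch.Tensor:
    # Compute both candidates and let torch.where pick one on the device;
    # the 0-dim boolean condition broadcasts across x
    return torch.where(x.sum() > 0, x * 2, x - 1)


x = torch.tensor([1.0, 2.0, 3.0])
print(torch.equal(branchy(x), branchless(x)))
print(torch.equal(branchy(-x), branchless(-x)))

# Optional: compiled = torch.compile(branchless)
```

This rewrite only makes sense when both branches are cheap and free of side effects; when one branch is expensive, accepting the graph break may be the better choice.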

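### Typical Speedups

The ranges below are rough, illustrative figures. Actual gains depend on the model, the hardware, and the PyTorch version, so measure on your own workload.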

| Workload | Speedup with `torch.compile` |
|----------|------|
| Transformer training | 1.3-1.5x |
| CNN inference | 1.2-1.4x |
| Element-wise ops (memory-bound) | 2-3x |
| Already-optimized code (cuBLAS matmuls) | ~1x (no gain) |
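To check numbers like these on your own workload, `torch.utils.benchmark.Timer` is a convenient option because it handles warm-up, repetition, and CUDA synchronization. The sketch below wraps it in a small helper; `median_ms` and `eager_fn` are hypothetical names, and only the eager function is timed so the example runs without a compiler toolchain.

```python
import torch
import torch.utils.benchmark as benchmark


def median_ms(fn, *args, min_run_time: float = 0.2) -> float:
    """Median wall-clock time of fn(*args) in milliseconds."""
    timer = benchmark.Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    # Measurement.median is reported in seconds
    return timer.blocked_autorange(min_run_time=min_run_time).median * 1e3


def eager_fn(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x) + torch.cos(x)


x = torch.randn(512, 512)
print(f"eager: {median_ms(eager_fn, x):.3f} ms")

# To compare against the compiled version, call it once first to trigger compilation:
# compiled_fn = torch.compile(eager_fn)
# compiled_fn(x)
# print(f"compiled: {median_ms(compiled_fn, x):.3f} ms")
```

Comparing the two medians on the target hardware gives a speedup figure that is specific to your model, rather than a range taken from a table.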

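```python
@torch.compile
def fast_computation(x):
    return torch.sin(x) + torch.cos(x)

def timed_call(fn, *args):
    """Time one call; synchronize so queued GPU work is included."""
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    fn(*args)
    if device == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start

# First run triggers compilation (can take several seconds)
print(f"First run (compilation): {timed_call(fast_computation, x):.4f}s")

# Second run reuses the cached compiled code
print(f"Second run (cached): {timed_call(fast_computation, x):.4f}s")
```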

## 4. Writing Custom Kernels with Triton

When `torch.compile` isn't enough, you write your own kernels. Triton makes this accessible to Python engineers.

### Why Triton Over CUDA?

| Aspect | CUDA C++ | Triton (Python) |
|--------|----------|-----------------|
| Language | C++ | Python |
| Learning curve | Months | Days |
| Memory management | Manual (shared memory, tiling) | Automatic |
| Performance | Maximum (hand-tuned) | Often close to hand-tuned CUDA |
| Portability | NVIDIA only | NVIDIA (AMD support improving) |
| Used by | NVIDIA engineers | ML researchers, PyTorch team |

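**Key insight**: The Triton compiler handles shared-memory allocation, memory coalescing, and scheduling within a block, which are among the hardest parts of GPU programming. You choose the block sizes (optionally letting `@triton.autotune` search over them) and focus on the algorithm.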

### Flash Attention (The Most Important Custom Kernel)

Standard attention computes:
```
Q, K, V are [batch, heads, seq_len, dim]
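Attention = softmax(Q @ K^T / sqrt(dim)) @ V
```

The problem: `Q @ K^T` creates a `[seq_len, seq_len]` matrix for every head and every batch element. For seq_len=8192 in fp32, that is 8192 × 8192 × 4 bytes = 256 MiB per head per batch element, far more than fits in fast on-chip memory.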
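The size of the score matrix follows directly from its shape. The sketch below reproduces the arithmetic; the helper name and the batch and head counts in the last lines are arbitrary choices for illustration.

```python
def attention_scores_bytes(seq_len: int, bytes_per_element: int = 4) -> int:
    """Bytes needed for one [seq_len, seq_len] attention score matrix."""
    return seq_len * seq_len * bytes_per_element


MIB = 1024 ** 2
GIB = 1024 ** 3

fp32 = attention_scores_bytes(8192)                        # 4 bytes per element
fp16 = attention_scores_bytes(8192, bytes_per_element=2)   # 2 bytes per element
print(f"fp32: {fp32 / MIB:.0f} MiB, fp16: {fp16 / MIB:.0f} MiB per head per batch element")

# Hypothetical model: batch of 8 with 16 heads, all score matrices held at once
total = 8 * 16 * fp32
print(f"batch=8, heads=16, fp32: {total / GIB:.0f} GiB")
```

Memory grows with the square of the sequence length, so doubling `seq_len` quadruples these figures.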

**Flash Attention** fuses the entire attention computation into one kernel:
- Never materializes the full `[seq_len, seq_len]` attention matrix
- Processes attention in tiles (blocks) that fit in fast SRAM
- Reduces the extra memory for attention from O(N^2) to O(N) in the sequence length N
- Often 2-4x faster than standard attention, with the largest gains on long sequences

```python
# Fused attention in PyTorch (built in since 2.0)
from torch.nn.functional import scaled_dot_product_attention
# Dispatches to a Flash Attention kernel when the device, dtype, and shapes allow it
output = scaled_dot_product_attention(Q, K, V)
```
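A quick way to build confidence in the fused call is to compare it against the textbook formula on small random inputs. The sketch below does that on CPU; the tensor sizes are arbitrary, and the tolerance is an assumption chosen to absorb float32 rounding differences between the two implementations.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, dim = 2, 4, 128, 64  # arbitrary small sizes

Q = torch.randn(batch, heads, seq_len, dim)
K = torch.randn(batch, heads, seq_len, dim)
V = torch.randn(batch, heads, seq_len, dim)

# Textbook attention: materializes the full [seq_len, seq_len] score matrix
scores = Q @ K.transpose(-2, -1) / math.sqrt(dim)
reference = torch.softmax(scores, dim=-1) @ V

# Fused attention: same math, backend chosen by PyTorch
fused = F.scaled_dot_product_attention(Q, K, V)

print("max abs difference:", (reference - fused).abs().max().item())
print("match:", torch.allclose(reference, fused, atol=1e-4))
```

Note that `K.transpose(-2, -1)` swaps only the last two dimensions, which is what the `K^T` in the formula means for batched 4-D tensors.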

*(Note: The Triton kernel below requires a GPU to run)*

```python
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of BLOCK_SIZE elements.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out-of-range offsets in the final, partially filled block.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)

def triton_add(x: torch.Tensor, y: torch.Tensor):
    assert x.is_cuda and y.is_cuda and x.shape == y.shape
    output = torch.empty_like(x)
    n_elements = x.numel()
    # Launch one program per block, rounding up so the tail is covered.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

if torch.cuda.is_available():
    x = torch.randn(1000, device='cuda')
    y = torch.randn(1000, device='cuda')
    out = triton_add(x, y)
    # Check the kernel against PyTorch's own addition.
    assert torch.allclose(out, x + y)
    print("Triton add successful!")
```

## Key Takeaways

1. **Profile first**: Don't optimize blindly. Use `torch.profiler` to find actual bottlenecks.
2. **Understand the bottleneck type**: Is it compute-bound or memory-bound? The fix is different for each.
3. **Use `torch.compile`**: It gives automatic operator fusion with a one-line change. The main cost is compile time on the first call.
4. **Flash Attention**: One of the most impactful optimizations for Transformers. Use `scaled_dot_product_attention` in PyTorch 2.0+.
5. **Triton**: The secret weapon for custom high-performance layers when you need to go beyond what `torch.compile` offers.

### Performance Optimization Checklist

```
1. Profile your code (find the actual bottleneck)
   ↓
2. Is it data loading? → Increase num_workers, use pin_memory=True
   ↓
3. Is it many small GPU ops? → Use torch.compile for fusion
   ↓
4. Is it attention? → Use scaled_dot_product_attention (Flash Attention)
   ↓
5. Is it a custom operation? → Write a Triton kernel
   ↓
6. Still slow? → Consider mixed precision (torch.autocast), quantization, or better hardware
```