In [None]:
# FlashAttention: GPU/FPGA Kernel Implementation and Optimization

This notebook demonstrates the implementation and optimization of FlashAttention, a novel ML primitive that reduces the memory complexity of attention from $O(N^2)$ to $O(N)$ in the sequence length $N$ through tiling and recomputation strategies.

## Learning Objectives
1. **Algorithmic Optimization for Hardware**: Understand memory access patterns and compute-memory trade-offs
2. **Kernel Fusion Techniques**: Learn advanced operator fusion for better throughput
3. **Hardware-Specific Debugging**: Master GPU profiling and performance analysis
4. **Advanced Extensions**: Explore hardware modifications for further acceleration


In [None]:
import sys
sys.path.append('../')

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Any

# Import our custom implementations
from src.kernels.flash_attention import FlashAttention, FlashAttentionTriton
from src.kernels.moe_routing import MoERouter, ExpertGating
from src.fusion.operator_fusion import FusedLinearReLU, OperatorFuser
from benchmarks.profiler import GPUProfiler, MemoryProfiler, KernelProfiler

# Setup: choose the compute device and report basic hardware details.
# Falls back to CPU when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if device.type == 'cuda':
    # Query the selected device's properties once and print every field from
    # the same object. Previously the name/capability came from the current
    # device while the memory came from a hard-coded index 0, which can
    # describe two different GPUs on a multi-GPU machine.
    gpu_props = torch.cuda.get_device_properties(device)
    print(f"GPU: {gpu_props.name}")
    # Reported in decimal gigabytes (1e9 bytes), matching the "GB" label.
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")
    print(f"Compute Capability: {(gpu_props.major, gpu_props.minor)}")
