<a href="https://colab.research.google.com/github/HosseinEyvazi/Deep-Learning/blob/main/Flash_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



### **FlashAttention: The Complete Technical Booklet**
**_Table of Contents_**

**1. __The Hardware Foundation__**
*   _1.1 The Two-Memory System (SRAM vs. HBM)_
*   _1.2 The Golden Rule of GPU Computing_

**2. __Why Standard Attention Creates $O(N^2)$ Space__**
*   _2.1 The Attention Mechanism_
*   _2.2 The “Every Token Talks to Every Token” Problem_
*   _2.3 The Space Bottleneck_

**3. __Standard Attention’s “Forced” Data Movement__**
*   _3.1 The Loop of Forced I/O (Tiled Implementation)_

**4. __FlashAttention’s Solution__**
*   _4.1 Core Insight (Computation vs. Bandwidth)_
*   _4.2 Three Pillars (Kernel Fusion, Online Softmax, Recomputation)_

**5. __FlashAttention’s Memory Structure__**
*   _5.1 Chunk / Tile Dimensions_
*   _5.2 Why FlashAttention is $O(N)$ Space_
*   _5.3 Comparison: Standard vs. FlashAttention_

**6. __Mathematical Proof: Why Accumulation Works__**
*   _6.1 Softmax and Output Equations_
*   _6.2 Partial / Block Processing Problem_
*   _6.3 Online Softmax with Rescaling_
*   _6.4 Proof of Correctness_

**7. __Numerical Stability: Why We Subtract the Max__**
*   _7.1 Overflow Risk in Float32_
*   _7.2 Stable Softmax by Subtracting the Max_

**8. __Time Complexity vs. Runtime__**
*   _8.1 Same FLOPs, Different Runtime_
*   _8.2 Runtime is Dominated by Data Movement_
*   _8.3 Practical Physics_

**9. __On-Chip Memory (SRAM) Explained__**
*   _9.1 Physical Reality_
*   _9.2 Why SRAM is Small_
*   _9.3 Cache / Memory Hierarchy_

**10. __Summary Table__**

**11. __Final Takeaways__**

**__Appendix: Key Formulas at a Glance__**

***

### **Booklet Structure Summary**

The table below maps the booklet's logic flow so you can scan its critical path. Read it with a critical eye toward how hardware constraints dictate the algorithmic solution.

| **Section Block** | **Chapters** | **Core Concept (Focus)** | **Key Takeaway for Research** |
| :--- | :--- | :--- | :--- |
| **The Problem** | Ch 1–3 | **Hardware Bottlenecks** | The GPU is limited by HBM bandwidth (the "memory wall"), not compute power. Standard attention wastes bandwidth by moving $N \times N$ matrices through HBM. |
| **The Solution** | Ch 4–5 | **Algorithmic Optimization** | **Tiling & Fusion:** Keep data in fast SRAM and write only the final result to HBM. Memory use drops from $O(N^2)$ to $O(N)$. |
| **The Proof** | Ch 6–7 | **Mathematical Validity** | **Online Softmax:** Softmax can be computed incrementally (block by block) without seeing the full row at once, using rescaling factors. |
| **The Impact** | Ch 8–11 | **Performance Reality** | Even with recomputation (more FLOPs in the backward pass), the speedup is significant because slow HBM transfers are avoided. |

# FlashAttention: The Complete Technical Booklet

## From Memory Bottleneck to Hardware-Aware Algorithm

*Author’s Note:* This booklet is designed for Data Scientists and Systems Engineers who need to understand **why** FlashAttention is revolutionary, not just that it exists.

---

## Chapter 1 — The Hardware Foundation

### 1.1 The Two-Memory System

Your GPU is fundamentally split into two memory systems.

**SRAM (On-Chip Memory)**

* **Location:** Inside the silicon of the Compute Unit (Streaming Multiprocessor).
* **Size:** Roughly 100–256 KB per SM (e.g., 192 KB on an A100), which adds up to about 20 MB across the whole chip.
* **Speed:** Very high bandwidth (the "fast lane").
* **Latency:** Nanoseconds.
* **Analogy:** Your brain's working memory — what you are actively thinking about.

**HBM (High-Bandwidth Memory — Off-Chip)**

* **Location:** Outside the processor die, soldered near the GPU chip.
* **Size:** Tens of gigabytes (e.g., 40–80 GB).
* **Speed:** Lower bandwidth vs SRAM (e.g., ~1.5–3 TB/s on modern GPUs).
* **Latency:** Hundreds of nanoseconds (hundreds of clock cycles), far slower than SRAM.
* **Analogy:** Your library — large storage, but you must walk to fetch books.

### 1.2 The Golden Rule of GPU Computing

**The GPU compute units can only operate on data that is present in on-chip SRAM.**

Every byte must follow this journey:

1. Read from HBM into SRAM.
2. Compute in SRAM (or registers).
3. Write results back to HBM.

**Cost analysis (illustrative):**

* Moving 1 byte (HBM ↔ SRAM): large number of cycles (orders of magnitude more than an arithmetic op).
* Computing on 1 byte: a few cycles.

**Conclusion:** Data movement is far more expensive than arithmetic. Minimize data movement.

---

## Chapter 2 — Why Standard Attention Creates $O(N^2)$ Space

### 2.1 The Attention Mechanism

Scaled dot-product attention is:

$$
\text{Attention}(Q,K,V) \;=\; \text{Softmax}\!\Big(\frac{QK^\top}{\sqrt{d}}\Big)\,V.
$$

Where:

* $Q$ is the query matrix ($N\times d$).
* $K$ is the key matrix ($N\times d$).
* $V$ is the value matrix ($N\times d$).
* $N$ is the sequence length.
* $d$ is the head dimension.

Computing $QK^\top$ produces an $N\times N$ score matrix of pairwise similarities.
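As a reference point, the formula above can be sketched in a few lines of pure Python. This version deliberately materializes the full $N\times N$ score and probability matrices, which is the behaviour the later chapters set out to avoid. The function name `naive_attention` and the list-of-lists layout are illustrative choices, not part of any library.

```python
import math


def naive_attention(Q, K, V):
    """Scaled dot-product attention that materializes the full N x N matrices.

    Q, K, V are lists of rows (N x d). Illustrative only: no batching, no heads.
    """
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)

    # S = Q K^T / sqrt(d): one score per (query, key) pair, so N x N entries.
    S = [[scale * sum(q_c * k_c for q_c, k_c in zip(q, k)) for k in K] for q in Q]

    # P = row-wise softmax(S), subtracting the row max for numerical stability.
    P = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])

    # O = P V: each output row is a weighted average of the value rows.
    d_v = len(V[0])
    return [[sum(p * v[c] for p, v in zip(p_row, V)) for c in range(d_v)] for p_row in P]


# Two identical keys receive equal weight, so the output is the mean of the values.
out = naive_attention([[1.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]])
print(out)  # [[2.0, 3.0]]
```

Both `S` and `P` hold $N^2$ numbers at the same time here, so memory grows quadratically with the sequence length.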

### 2.2 The “Every Token Talks to Every Token” Problem

Attention computes a score between every pair of tokens. For $N$ tokens there are $N^2$ interactions.

Example, $N=4$ (A, B, C, D) — 16 scores.

Scale examples:

* $N=2{,}048$: $2{,}048^2 = 4{,}194{,}304$ scores (≈16 MB in 4-byte floats).
* $N=32{,}000$: $32{,}000^2 = 1{,}024{,}000{,}000$ scores (≈4 GB).
* $N=100{,}000$: $100{,}000^2 = 10^{10}$ scores (≈40 GB).
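The figures in this list are plain arithmetic and can be reproduced with a short helper. The name `score_matrix_bytes` is hypothetical, and the 4-byte default matches the float32 assumption used above.

```python
def score_matrix_bytes(n_tokens, bytes_per_score=4):
    """Bytes needed to hold one full N x N score matrix."""
    return n_tokens * n_tokens * bytes_per_score


for n in (2_048, 32_000, 100_000):
    size = score_matrix_bytes(n)
    print(f"N={n:>7,}: {n * n:>14,} scores, {size / 1e9:.3f} GB")
```

Passing `bytes_per_score=2` halves every figure (16-bit scores), but the quadratic growth is unchanged.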

### 2.3 The Space Bottleneck

Standard attention typically **materializes** the full $N\times N$ score matrix (and often the probability matrix) in memory. For very long contexts this becomes infeasible — e.g., a single 100K × 100K score matrix at 4 bytes per entry is ~40 GB, leaving almost no GPU memory for model weights or gradients.

Thus standard attention requires $O(N^2)$ memory, which does not scale to extremely long sequences.

---

## Chapter 3 — Standard Attention’s “Forced” Data Movement

### 3.1 The Loop of Forced I/O (Tiled Implementation)

Even with tiling (block size $B$), a naive implementation forces many HBM transfers:

1. **Load** a block of $Q$ and a block of $K$ from HBM to SRAM.
2. **Compute** $S_{\text{block}} = Q_{\text{block}} K_{\text{block}}^\top$ in SRAM (size $B\times B$).
3. **Write** that partial score block to HBM (evicting it from SRAM). This write is the bottleneck.
4. Repeat until the entire $N\times N$ $S$ matrix is assembled in HBM.
5. To compute row-wise Softmax you must **read** the row’s $N$ scores back into SRAM.
6. Compute probabilities and **write** the $P$ blocks to HBM.
7. Read $P$ and $V$ to compute final output $O = PV$.

**Total I/O complexity:** $\Theta(Nd + N^2)$ element reads/writes, dominated by the $N^2$ term. The GPU spends a large fraction of its time on memory traffic instead of arithmetic.

For example, at $N = 32{,}000$ each of $S$ and $P$ occupies about 4 GB in 4-byte floats, and each is written to HBM once and read back once.
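To make the traffic concrete, the seven steps above can be tallied as element transfers. This is a rough sketch under a simplifying assumption: each matrix crosses the HBM/SRAM boundary exactly once per step that touches it, which is a lower bound for a real tiled kernel. The function name and dictionary keys are illustrative.

```python
def standard_attention_hbm_elements(n_tokens, head_dim):
    """Tally HBM element transfers for the naive attention steps.

    Assumes every matrix is moved once per step that uses it (a lower bound).
    """
    nd = n_tokens * head_dim
    nn = n_tokens * n_tokens
    traffic = {
        "read Q, K": 2 * nd,
        "write S": nn,
        "read S for softmax": nn,
        "write P": nn,
        "read P, V": nn + nd,
        "write O": nd,
    }
    traffic["total"] = sum(traffic.values())
    return traffic


t = standard_attention_hbm_elements(n_tokens=32_000, head_dim=64)
quadratic_share = 4 * 32_000 ** 2 / t["total"]
print(f"total elements moved: {t['total']:,}")
print(f"share due to the N x N intermediates: {quadratic_share:.1%}")
```

With these inputs almost all of the traffic comes from writing and re-reading $S$ and $P$; the inputs and the output are a rounding error by comparison.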

---

## Chapter 4 — FlashAttention’s Solution

### 4.1 Core Insight

FlashAttention keeps intermediate results on-chip and fuses operations so that the algorithm:

* Performs essentially the same arithmetic in the forward pass (the same $O(N^2)$ FLOPs, plus a small rescaling overhead), **but**
* Minimizes HBM traffic by never materializing the full $N\times N$ intermediate matrices.

Key principle: computation is cheap relative to bandwidth; trade extra arithmetic or recomputation for fewer memory accesses.

### 4.2 Three Pillars

**Pillar 1 — Kernel fusion**
Fuse the $QK^\top$ computation, scaling, stable Softmax, and the final multiply with $V$ into a single monolithic kernel. Load input blocks once, keep temporaries in SRAM, and write only final outputs.

**Pillar 2 — Online (block-by-block) Softmax**
Process scores in tiles while maintaining running statistics (running maximum and running sum) to perform numerically stable Softmax incrementally. This avoids needing the full row in SRAM at once.

**Pillar 3 — Recomputation in backpropagation**
In the backward pass, recompute $S$ (and related intermediates) on the fly from $Q,K,V$ rather than reading huge saved matrices from HBM. Recompute is faster than fetching gigabytes from off-chip memory.

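Together, these pillars cut the extra memory from $O(N^2)$ to $O(N)$ and reduce HBM accesses from $\Theta(Nd + N^2)$ to $\Theta(N^2 d^2 / M)$, where $M$ is the SRAM size in elements, while preserving exactness. Because $d^2$ is typically many times smaller than $M$, this is a large constant-factor reduction in memory traffic.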

---

## Chapter 5 — FlashAttention’s Memory Structure

### 5.1 Chunk / Tile Dimensions (example)

Using $B = 128$ tokens and $d = 64$:

* Q block: $B\times d = 128 \times 64$ → ~32 KB (4-byte floats).
* K block: same as Q block → ~32 KB.
* V block: same → ~32 KB.
* Score tile (temporary): $B\times B = 128 \times 128$ → ~64 KB. **Ephemeral**; created, consumed (Softmax), and discarded, never written to HBM.
* Output accumulator: $B\times d$ → ~32 KB; persists for the row block until final write.

Total on-chip scratch per tile: about 192 KB in this example ($4 \times 32$ KB + 64 KB), constant with respect to sequence length.
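The tile sizes listed above follow directly from the block shape and element width. The helper below is a hypothetical calculator for that budget; it only reproduces the arithmetic of this example and makes no claim about how any particular kernel lays out its shared memory or registers.

```python
def tile_budget_bytes(block_tokens=128, head_dim=64, bytes_per_el=4):
    """On-chip bytes for one Q, K, V block, one score tile, and one accumulator."""
    block = block_tokens * head_dim * bytes_per_el        # Q, K, V, accumulator
    scores = block_tokens * block_tokens * bytes_per_el   # B x B score tile
    return {
        "Q": block,
        "K": block,
        "V": block,
        "accumulator": block,
        "scores": scores,
        "total": 4 * block + scores,
    }


budget = tile_budget_bytes()
for name, size in budget.items():
    print(f"{name:>12}: {size / 1024:6.0f} KiB")
```

The budget depends only on $B$, $d$, and the element width, never on $N$; calling `tile_budget_bytes(bytes_per_el=2)` shows that 16-bit storage halves it.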

### 5.2 Why FlashAttention is $O(N)$ Space (practical)

At any time FlashAttention keeps only a fixed number of blocks in SRAM (one Q block, one K block, one V block, the accumulator, and a score tile). This is a constant amount of scratch memory independent of $N$. The only HBM storage required beyond the inputs is the final output $O$ plus a few per-row softmax statistics ($O(N)$ values), so overall memory usage is $O(N)$ rather than $O(N^2)$.

### 5.3 Comparison: Standard vs FlashAttention

* **Standard Attention:** HBM must store $Q,K,V$ plus the full $N\times N$ $S$ and $P$ matrices → $O(N^2)$ memory.
* **FlashAttention:** HBM stores only $Q,K,V$ and output $O$. Intermediate tiles remain ephemeral in SRAM → $O(N)$ memory.

---

## Chapter 6 — Mathematical Proof: Why Accumulation Works

### 6.1 Softmax and Output

Softmax probabilities:

$$
P_{ij} = \frac{e^{S_{ij}}}{\sum_k e^{S_{ik}}}.
$$

Output for token $i$:

$$
O_i = \sum_j P_{ij}\,V_j \;=\; \sum_j \frac{e^{S_{ij}}}{\sum_k e^{S_{ik}}}\,V_j.
$$

### 6.2 Partial / Block Processing Problem

Processing only a subset (a block) of $K$ and $V$ yields only partial scores $S_{i,\text{block}}$. The denominator in Softmax requires the sum over all blocks.

### 6.3 Online Softmax with Rescaling (procedure)

Process blocks sequentially while maintaining numerically stable running statistics.

**Process Block 1:**

* Let $m_1 = \max(\text{scores in Block 1})$.
* Compute unnormalized contributions: $p_{1,j} = e^{S_{ij} - m_1}$.
* Sum: $\ell_1 = \sum_{j\in\text{block1}} p_{1,j}$.
* Partial accumulator: $O_{\text{partial}} = \sum_{j\in\text{block1}} p_{1,j}\,V_j$ (note: unnormalized).

**Process Block 2:**

* Let $m_2$ be the max over seen scores (block1 ∪ block2).
* The old contributions must be rescaled to the new reference: define $\alpha = e^{m_1 - m_2}$.
* Rescale old normalizer: $\ell_1' = \ell_1 \cdot \alpha$.
* New block normalizer: $\ell_2 = \sum_{j\in\text{block2}} e^{S_{ij} - m_2}$.
* Combined normalizer: $\ell_{\text{new}} = \ell_1' + \ell_2$.
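* Rescale the accumulator to the new reference and add the block 2 contributions: $O_{\text{partial}} \leftarrow \alpha\,O_{\text{partial}} + \sum_{j\in\text{block2}} e^{S_{ij} - m_2}\,V_j$.

Repeat for all remaining blocks, then normalize once at the end: $O_i = O_{\text{partial}} / \ell_{\text{new}}$.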
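The procedure above can be sketched for a single query row in pure Python. This is a minimal illustration, not the fused GPU kernel: it assumes the scores $S_{ij}$ for the row are already available as a list, and the names `online_attention_row` and `full_row` are hypothetical. The running state is just the maximum, the normalizer, and an unnormalized accumulator.

```python
import math


def online_attention_row(scores, values, block_size):
    """Softmax-weighted sum of `values` for one row, one key block at a time."""
    m = -math.inf                  # running maximum of the scores seen so far
    ell = 0.0                      # running normalizer, relative to m
    acc = [0.0] * len(values[0])   # unnormalized accumulator, relative to m

    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]

        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)    # rescale factor; exp(-inf) is 0.0 on the first block
        p = [math.exp(s - m_new) for s in s_blk]

        ell = alpha * ell + sum(p)
        acc = [alpha * a + sum(p_j * v[c] for p_j, v in zip(p, v_blk))
               for c, a in enumerate(acc)]
        m = m_new

    return [a / ell for a in acc]      # normalize once, after the last block


def full_row(scores, values):
    """Reference: stable softmax over the whole row at once."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    return [sum(p_j * v[c] for p_j, v in zip(p, values)) / total
            for c in range(len(values[0]))]


scores = [0.5, 3.0, -1.0, 7.0, 2.0, 6.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0], [0.5, 0.5], [3.0, 1.0]]
blockwise = online_attention_row(scores, values, block_size=2)
reference = full_row(scores, values)
print(max(abs(a - b) for a, b in zip(blockwise, reference)))  # rounding-level difference
```

The block size changes only how the row is traversed; any value, including one that does not divide $N$ evenly, gives the same result up to floating-point rounding.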

### 6.4 Correctness

After all blocks, the accumulator divided by the final normalizer equals

$$
O_i = \frac{\sum_{j=1}^N e^{S_{ij} - m_{\text{global}}}\,V_j}{\sum_{j=1}^N e^{S_{ij} - m_{\text{global}}}},
$$

where $m_{\text{global}} = \max_j S_{ij}$. Because Softmax is invariant to additive shifts, this equals the standard Softmax result. Therefore the online accumulation is **exact** (no approximation) and numerically stable (subtracting the max prevents overflow).
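The step that makes this work is that multiplying by $\alpha = e^{m_1 - m_2}$ re-expresses every previously accumulated term relative to the new maximum. Written out for the two-block case of Section 6.3 (the same argument applies inductively to any number of blocks):

$$
\alpha\,O_{\text{partial}}
= e^{m_1 - m_2}\sum_{j\in\text{block1}} e^{S_{ij}-m_1}\,V_j
= \sum_{j\in\text{block1}} e^{S_{ij}-m_2}\,V_j,
\qquad
\alpha\,\ell_1 = \sum_{j\in\text{block1}} e^{S_{ij}-m_2}.
$$

The block 2 terms are already expressed relative to $m_2$, so after adding them both the accumulator and the normalizer are sums over block1 ∪ block2 with the single reference $m_2$. Their ratio is therefore the Softmax-weighted sum over everything seen so far.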

---

## Chapter 7 — Numerical Stability: Why We Subtract the Max

### 7.1 Overflow Risk

Float32 has a limited dynamic range (largest finite value ≈ $3.4\times10^{38}$), so $e^x$ overflows once $x$ exceeds about 88.7; a value such as $e^{1000}$ is far out of range. Attention scores can easily exceed this threshold.

### 7.2 Stable Softmax by Subtracting the Max

Softmax is shift invariant:

$$
\frac{e^a}{e^a+e^b} \;=\; \frac{e^{a-m}}{e^{a-m}+e^{b-m}} \quad\text{for any } m.
$$

Practical step: subtract $m_i=\max_k S_{ik}$ from all scores in row $i$ (or use running $m$ per block). This makes the largest exponent $e^0=1$ and prevents overflow. The online algorithm uses this technique per block and then rescales as needed.
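A small experiment shows the difference. One caveat about this sketch: Python floats are float64, so overflow starts near $x \approx 709$ rather than $\approx 88.7$ as in float32, and `math.exp` raises `OverflowError` instead of returning infinity. The principle is the same. The function names are illustrative.

```python
import math


def naive_softmax(scores):
    exps = [math.exp(s) for s in scores]       # overflows for large scores
    total = sum(exps)
    return [e / total for e in exps]


def stable_softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]   # largest term is exp(0) == 1
    total = sum(exps)
    return [e / total for e in exps]


row = [1000.0, 999.0]

try:
    naive_softmax(row)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

probs = stable_softmax(row)
print(naive_overflowed)   # True
print(probs)              # roughly [0.731, 0.269]
```

Only the difference between the two scores matters, so the stable version returns the same probabilities it would for the row `[1.0, 0.0]`.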

---

## Chapter 8 — Time Complexity vs Runtime

### 8.1 Same FLOPs, different runtime

Both standard attention and FlashAttention perform $O(N^2)$ arithmetic operations (FLOPs). The difference is **where time is spent**: arithmetic vs memory traffic.

### 8.2 Runtime is dominated by data movement

* **Standard attention**: memory-bound — many HBM transfers; compute units idle waiting on data.
* **FlashAttention**: compute-heavy on-chip; HBM transfers minimized; higher sustained GPU utilization.

### 8.3 Practical physics (illustrative numbers)

Compute throughput on modern GPUs can reach hundreds of TFLOPS, while HBM bandwidth is a few TB/s, so the hardware can perform on the order of a hundred arithmetic operations in the time it takes to fetch one byte from HBM. Recomputing is therefore frequently cheaper than reloading large intermediates from HBM.

**Rule of thumb:** Prefer extra on-chip computation if it avoids off-chip memory traffic.
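A back-of-the-envelope estimate illustrates the rule of thumb. The two hardware figures below are assumed, round peak numbers in the range quoted in this booklet (hundreds of TFLOPS, 1.5–3 TB/s); they are not measurements, and real kernels reach only a fraction of peak. The sketch compares reading a stored float32 $N\times N$ score matrix from HBM with recomputing it from $Q$ and $K$.

```python
def read_time_s(n_tokens, bytes_per_el, hbm_bytes_per_s):
    """Time to stream one N x N matrix from HBM at peak bandwidth."""
    return n_tokens * n_tokens * bytes_per_el / hbm_bytes_per_s


def recompute_time_s(n_tokens, head_dim, flops_per_s):
    """Time to recompute S = Q K^T at peak throughput (2 FLOPs per multiply-add)."""
    return 2 * n_tokens * n_tokens * head_dim / flops_per_s


# Assumed, illustrative peak figures (not measured values).
PEAK_FLOPS = 312e12       # FLOP/s
HBM_BANDWIDTH = 1.5e12    # bytes/s

n, d = 32_000, 64
t_read = read_time_s(n, 4, HBM_BANDWIDTH)
t_recompute = recompute_time_s(n, d, PEAK_FLOPS)

print(f"machine balance: {PEAK_FLOPS / HBM_BANDWIDTH:.0f} FLOPs per byte")
print(f"read S from HBM: {t_read * 1e3:.2f} ms")
print(f"recompute S:     {t_recompute * 1e3:.2f} ms")
```

Under these assumptions recomputation comes out several times cheaper than the read. The estimate ignores the cost of loading $Q$ and $K$ themselves, which is small because they hold only $2Nd$ elements.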

---

## Chapter 9 — On-Chip Memory (SRAM) Explained

### 9.1 Physical reality

* SRAM (shared memory, L1) resides inside the GPU die, very close to the ALUs, with nanosecond-scale latencies and extremely high bandwidth.
* HBM is off-die and reached over longer wires, with latencies of hundreds of nanoseconds and lower effective bandwidth.

### 9.2 Why SRAM is small

Latency and bandwidth scale with physical proximity; to maintain nanosecond access times the on-chip memory must be small. There is a tradeoff: small but very fast SRAM vs large but slower HBM.

### 9.3 Cache / memory hierarchy (typical view)

| Level                     |   Typical size | Purpose                    |
| ------------------------- | -------------: | -------------------------- |
| Registers                 |     per thread | fastest per-thread storage |
| Shared memory (SRAM / L1) | ~100 KB per SM | programmer-managed scratch |
| L2 cache                  | a few to tens of MB | GPU-wide cache        |
| HBM (global)              |     tens of GB | main memory                |

FlashAttention maximizes use of shared memory to keep intermediates on-chip.

---

## Chapter 10 — Summary Table

| Aspect                  |     Standard Attention |                  FlashAttention |
| ----------------------- | ---------------------: | ------------------------------: |
| Time complexity (FLOPs) |               $O(N^2)$ |                        $O(N^2)$ |
| Space complexity        |               $O(N^2)$ |                          $O(N)$ |
| I/O complexity (HBM accesses) | $\Theta(Nd + N^2)$ | $\Theta(N^2 d^2 / M)$ |
| GPU utilization         |  Low (bandwidth-bound) |            High (compute-bound) |
| Intermediates           |          Stored in HBM |               Ephemeral in SRAM |
| Softmax strategy        |                 Global |         Online (block-by-block) |
| Backward pass           |    Read saved matrices |            Recompute on the fly |
| Result accuracy         |                  Exact |                           Exact |
| Best use case           | Short sequences (< 1k) |           Long sequences (≫ 1k) |

---

## Chapter 11 — Final Takeaways

1. **FlashAttention is exact.** It computes the same attention result as standard attention.
2. **Bandwidth is the bottleneck, not arithmetic.** Data movement dominates runtime on modern GPUs.
3. **Online Softmax + kernel fusion + recomputation** enable exact, memory-efficient attention.
4. **Recomputation is often cheaper than reloading from HBM.** Use arithmetic to avoid off-chip traffic.
5. **Hardware-aware algorithm design wins.** Algorithms that are aware of SRAM vs HBM constraints deliver real performance.

**Next steps:** implement or use a proven FlashAttention implementation (CUDA/Triton), profile your model with/without it, and measure context window scaling.

---

## Appendix — Key Formulas at a Glance

**Standard Softmax (row $i$):**

$$
O_i \;=\; \sum_{j=1}^N \frac{\exp(S_{ij})}{\sum_{k=1}^N \exp(S_{ik})}\,V_j
$$

**Numerically stable Softmax (subtract max):**

$$
O_i \;=\; \sum_{j=1}^N \frac{\exp(S_{ij} - m_i)}{\sum_{k=1}^N \exp(S_{ik} - m_i)}\,V_j,
\qquad
m_i = \max_k S_{ik}.
$$

**FlashAttention — blockwise (online) accumulation:**

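When combining the running (normalized) output with a new block, both the old normalizer and the old output must be rescaled to the new maximum:

$$
O_i \leftarrow
\frac{\alpha\,\ell_{\text{old}}\,O_i \;+\; \displaystyle\sum_{j\in\text{new block}} e^{S_{ij}-m_{\text{new}}}\,V_j}
{\alpha\,\ell_{\text{old}} + \ell_{\text{new}}},
\qquad
\alpha = e^{m_{\text{old}} - m_{\text{new}}},
$$

where $m_{\text{new}}$ is the maximum over all scores seen so far (old blocks and the new block) and

$$
\ell_{\text{old}} = \sum_{j\in\text{old blocks}} e^{S_{ij}-m_{\text{old}}}, \quad
\ell_{\text{new}} = \sum_{j\in\text{new block}} e^{S_{ij}-m_{\text{new}}}.
$$

After the update, the running normalizer becomes $\alpha\,\ell_{\text{old}} + \ell_{\text{new}}$. Here $O_i$ is kept normalized after every block; Chapter 6 keeps it unnormalized and divides once at the end, which is equivalent. This yields the correct final Softmax weighted sum after all blocks are processed.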
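The normalized form of the update has a convenient property: after every block, the running output is already the exact attention result over the keys seen so far. The sketch below checks this on a small example, with the rescaling factor $e^{m_{\text{old}} - m_{\text{new}}}$ applied explicitly to the old state. The function names are hypothetical and the code is pure Python for illustration only.

```python
import math


def update_normalized(m_old, ell_old, o_old, s_blk, v_blk):
    """One blockwise update; o_old and the returned output are both normalized."""
    m_new = max(m_old, max(s_blk))
    alpha = math.exp(m_old - m_new)             # rescales the old state to m_new
    p = [math.exp(s - m_new) for s in s_blk]
    ell_new = sum(p)                            # normalizer of the new block
    denom = alpha * ell_old + ell_new
    o_new = [
        (alpha * ell_old * o + sum(p_j * v[c] for p_j, v in zip(p, v_blk))) / denom
        for c, o in enumerate(o_old)
    ]
    return m_new, denom, o_new


def softmax_weighted_sum(scores, values):
    """Reference: stable softmax over all given scores at once."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    return [sum(p_j * v[c] for p_j, v in zip(p, values)) / total
            for c in range(len(values[0]))]


scores = [2.0, -1.0, 4.0, 0.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]

m, ell, o = -math.inf, 0.0, [0.0, 0.0]
prefix_errors = []
for start in range(0, len(scores), 2):
    end = start + 2
    m, ell, o = update_normalized(m, ell, o, scores[start:end], values[start:end])
    expected = softmax_weighted_sum(scores[:end], values[:end])
    prefix_errors.append(max(abs(a - b) for a, b in zip(o, expected)))

print(prefix_errors)  # every entry is at rounding-error level
```

Keeping the output normalized costs one extra division per block; deferring the division to the end, as in Chapter 6, avoids that work and gives the same final result.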

**I/O complexity (sketch):**

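$$
\Theta\!\left(\frac{N^2 d^2}{M}\right)
$$

where $M$ denotes the available on-chip scratch (in elements). Standard attention needs $\Theta(Nd + N^2)$ HBM accesses, so for typical values ($d^2$ many times smaller than $M$) FlashAttention performs several times fewer accesses, while the extra HBM memory it needs is only $O(N)$.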
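Plugging numbers into the expression gives a feel for the saving. This is an order-of-magnitude sketch only: the asymptotic expressions hide constant factors, and the SRAM size used (192 KB of float32, i.e. 49,152 elements) is an assumed example value. The function name is hypothetical.

```python
def hbm_access_estimates(n_tokens, head_dim, sram_elements):
    """Leading-order HBM access counts (constant factors dropped)."""
    standard = n_tokens * head_dim + n_tokens ** 2          # Theta(N d + N^2)
    flash = n_tokens ** 2 * head_dim ** 2 / sram_elements   # Theta(N^2 d^2 / M)
    return standard, flash


standard, flash = hbm_access_estimates(n_tokens=32_000, head_dim=64, sram_elements=48 * 1024)
print(f"standard attention: {standard:,.0f}")
print(f"FlashAttention:     {flash:,.0f}")
print(f"ratio:              {standard / flash:.1f}x")
```

When $N \gg d$ the ratio is close to $M / d^2$, so a larger SRAM budget or a smaller head dimension increases the saving.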

---

**End of Booklet**

---

