# CS336 Assignment 2 (systems): Systems and Parallelism

## 1.1 Profiling and Benchmarking

maybe useful: https://zhuanlan.zhihu.com/p/1927445674070347876

### Problem (benchmarking_script): 4 points

Write a script to perform basic end-to-end benchmarking of the forward and backward passes in your model. Specifically, your script should support the following:

- Given hyperparameters (e.g., number of layers), initialize a model.
- Generate a random batch of data.
- Run $w$ warm-up steps (before you start measuring time), then time the execution of $n$ steps (either only forward, or both forward and backward passes, depending on an argument). For timing, you can use the Python `timeit` module (e.g., either using the `timeit` function, or using `timeit.default_timer()`, which gives you the system’s highest resolution clock, thus a better default for benchmarking than `time.time()`).
- Call `torch.cuda.synchronize()` after each step.

**Deliverable**: A script that will initialize a basics Transformer model with the given hyperparameters, create a random batch of data, and time forward and backward passes.

---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

[benchmark.py](./scripts/benchmark.py)

</div>
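
The timing loop described in this problem can be sketched with the standard library alone. This is a minimal sketch, not the contents of `benchmark.py`: the names `benchmark`, `step_fn`, and `sync_fn` are hypothetical, and `sync_fn` stands in for `torch.cuda.synchronize` when the step runs on a GPU.

```python
import statistics
from timeit import default_timer


def benchmark(step_fn, warmup_steps=5, measure_steps=10, sync_fn=None):
    """Time `step_fn` and return (mean_ms, std_ms) over `measure_steps` runs.

    `sync_fn` is a hypothetical hook: pass torch.cuda.synchronize for GPU work
    so that each measurement waits for the queued kernels to finish.
    `measure_steps` must be at least 2 for the standard deviation to be defined.
    """
    def run_once():
        step_fn()
        if sync_fn is not None:
            sync_fn()

    # Warm-up runs are executed but not timed.
    for _ in range(warmup_steps):
        run_once()

    timings_ms = []
    for _ in range(measure_steps):
        start = default_timer()
        run_once()
        timings_ms.append((default_timer() - start) * 1000.0)

    return statistics.mean(timings_ms), statistics.stdev(timings_ms)


# CPU-only stand-in workload, used here just to exercise the loop.
mean_ms, std_ms = benchmark(lambda: sum(range(10_000)))
print(f"{mean_ms:.3f} ms ± {std_ms:.3f} ms")
```

Passing the synchronization call as an argument keeps the loop usable for both CPU and GPU steps; setting `warmup_steps=0` gives the no-warm-up comparison asked about later in this problem.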

---

Time the forward and backward passes for the model sizes described in $\S 1.1.2$. Use $5$ warmup steps and compute the average and standard deviation of timings over 10 measurement steps.

How long does a forward pass take? How about a backward pass? Do you see high variability across measurements, or is the standard deviation small?

**Deliverable**: A 1-2 sentence response with your timings.

---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

**Small model**: A forward pass takes ~25ms, and a backward pass takes ~69ms.

**Medium model**: A forward pass takes ~79ms, and a backward pass takes ~202ms.

Larger models result in a "CUDA out of memory" error.

The measurements are very stable.

</div>

---

One caveat of benchmarking is not performing the warm-up steps. Repeat your analysis without the warm-up steps. How does this affect your results? Why do you think this happens? Also try to run the script with 1 or 2 warm-up steps. Why might the result still be different?

**Deliverable**: A 2-3 sentence response.

---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

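Without warm-up (warm-up=0), the first measured step is much slower (70.95 ms vs. 25–26 ms for later steps), which inflates both the mean and the standard deviation.

This is due to one-time costs such as CUDA context initialization, initial memory allocation, and JIT compilation.

Just 1–2 warm-up steps remove most of this overhead, showing that it is paid early; results can still differ slightly from the 5-step case because some of these one-time costs may not be fully paid after only one or two steps.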

</div>

---

### Problem (nsys_profile): 5 points

Profile your forward pass, backward pass, and optimizer step using `nsys` with each of the model sizes described in Table 1 and context lengths of `128`, `256`, `512` and `1024` (you may run out of memory with some of these context lengths for the larger models, in which case just note it in your report).

small_fwd
![small_fwd](./data/imgs/small_fwd.png)

ctx1024_fwd
![ctx1024_fwd](./data/imgs/ctx1024_fwd.png)


(a) What is the total time spent on your forward pass? Does it match what we had measured before with the Python standard library?

**Deliverable**: A 1-2 sentence response.

---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

The total time spent on a forward pass, measured by the Python script during profiling, was ~14.09 ms for the small_fwd model and ~23.33 ms for the ctx1024_fwd model. 

This timing is consistent with the duration of the corresponding NVTX time ranges observed in the Nsight Systems profiler's timeline view.

</div>
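
NVTX ranges like the ones referred to above can be added with `torch.cuda.nvtx.range`. The helper below is a hedged sketch: `nvtx_range` is a hypothetical name, and the fallback to `nullcontext` is an assumption made so that the same script still runs on a machine without PyTorch or without a GPU.

```python
from contextlib import nullcontext


def nvtx_range(name):
    """Return an NVTX range context manager, or a no-op when CUDA is unavailable."""
    try:
        import torch  # imported lazily so the helper works without PyTorch installed
    except ImportError:
        return nullcontext()
    if not torch.cuda.is_available():
        return nullcontext()
    return torch.cuda.nvtx.range(name)


# Each `with` block shows up as a named range on the nsys timeline when run on a GPU.
with nvtx_range("forward"):
    result = sum(range(1000))  # placeholder for the model's forward pass
print(result)
```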

---

(b) What CUDA kernel takes the most cumulative GPU time during the forward pass? How many times is this kernel invoked during a single forward pass of your model? Is it the same kernel that takes the most runtime when you do both forward and backward passes? (Hint: look at the “CUDA GPU Kernel Summary” under “Stats Systems View”, and filter using NVTX ranges to identify which parts of the model are responsible for which kernels.)

**Deliverable**: A 1-2 sentence response.

---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

The CUDA kernel that takes the most cumulative GPU time is a **GEMM** (general matrix multiplication) kernel (`ampere_sgemm_128x64_tn`, `ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_32x1_tn`).

For the small_fwd model, this kernel was invoked 42 times during a single forward pass (calculated from 630 total invocations over 15 steps). This same GEMM kernel type also dominates the runtime when performing both forward and backward passes.

</div>

---

(c) Although the vast majority of `FLOPs` take place in matrix multiplications, you will notice that several other kernels still take a non-trivial amount of the overall runtime. What other kernels besides matrix multiplies do you see accounting for non-trivial CUDA runtime in the forward pass?

**Deliverable**: A 1-2 sentence response.

---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

Other non-trivial CUDA runtime comes mainly from **element-wise** kernels (e.g., `at::native::elementwise_kernel`, `vectorized_elementwise_kernel`) used in activation functions, residual connections, and parts of the Layer Normalization computation.

</div>

---

(d) Profile running one complete training step with your implementation of AdamW (i.e., the forward pass, computing the loss and running a backward pass, and finally an optimizer step, as you’d do during training). How does the fraction of time spent on matrix multiplication change, compared to doing inference (forward pass only)? How about other kernels?

**Deliverable**: A 1-2 sentence response.

---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

Compared to forward-only, a full training step (forward + backward + optimizer) sees GEMM’s share of time drop, while **element-wise** kernels take a larger share. 

This is because AdamW is memory-bound and made of element-wise ops, adding step time without adding GEMMs, thus diluting GEMM’s percentage.

</div>

---

small_fwd_annot
![small_fwd_annot](./data/imgs/small_fwd_annot.png)

small_fwbw_annot
![small_fwbw_annot](./data/imgs/small_fwbw_annot.png)

(e) Compare the runtime of the softmax operation versus the matrix multiplication operations within the self-attention layer of your model during a forward pass. How does the difference in runtimes compare to the difference in FLOPs?

**Deliverable**: A 1-2 sentence response.

---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

In self-attention, GEMM takes much longer than Softmax, but the gap is far smaller than their FLOP difference. 

This is because GEMM is **compute-bound** with high arithmetic intensity, fully utilizing the GPU, while Softmax is **memory-bound**, limited by data access, making its runtime disproportionately high relative to its low FLOPs.

</div>
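
The compute-bound versus memory-bound argument can be made concrete with a rough arithmetic-intensity estimate. This is a back-of-the-envelope sketch under stated assumptions: only the $QK^T$ product is counted on the matmul side, softmax is assumed to cost about 5 FLOPs per score, each tensor is read or written exactly once in FP32, and the function name and the example sizes are hypothetical.

```python
def attention_intensity(seq_len, d_head, bytes_per_elem=4, softmax_flops_per_elem=5):
    """Estimate FLOPs, bytes moved, and FLOPs/byte for QK^T and for softmax."""
    # QK^T: (L x d) @ (d x L) costs 2*L*L*d FLOPs; reads Q and K, writes L x L scores.
    matmul_flops = 2 * seq_len * seq_len * d_head
    matmul_bytes = bytes_per_elem * (2 * seq_len * d_head + seq_len * seq_len)

    # Softmax: a few FLOPs per score (assumed), reads and writes the L x L matrix once.
    softmax_flops = softmax_flops_per_elem * seq_len * seq_len
    softmax_bytes = bytes_per_elem * 2 * seq_len * seq_len

    return {
        "flop_ratio": matmul_flops / softmax_flops,
        "matmul_intensity": matmul_flops / matmul_bytes,
        "softmax_intensity": softmax_flops / softmax_bytes,
    }


stats = attention_intensity(seq_len=1024, d_head=64)
for name, value in stats.items():
    print(f"{name}: {value:.3f}")
```

Under these assumptions softmax performs well under one FLOP per byte moved, while the matmul performs tens of FLOPs per byte, which is consistent with the runtime gap being much smaller than the FLOP gap.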

---

### Problem (mixed_precision_accumulation): 1 point

Run the following code and comment on the (accuracy of the) results.

**Deliverable**: A 2-3 sentence response.

In [1]:
import torch

s = torch.tensor(0,dtype=torch.float32)
for i in range(1000):
	s += torch.tensor(0.01,dtype=torch.float32)
print(s)
s = torch.tensor(0,dtype=torch.float16)
for i in range(1000):
	s += torch.tensor(0.01,dtype=torch.float16)
print(s)
s = torch.tensor(0,dtype=torch.float32)
for i in range(1000):
	s += torch.tensor(0.01,dtype=torch.float16)
print(s)
s = torch.tensor(0,dtype=torch.float32)
for i in range(1000):
	x = torch.tensor(0.01,dtype=torch.float16)
	s += x.type(torch.float32)
print(s)

tensor(10.0001)
tensor(9.9531, dtype=torch.float16)
tensor(10.0021)
tensor(10.0021)


---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

The first result (`tensor(10.0001)`) shows accurate FP32 accumulation. 

The second (`tensor(9.9531, dtype=torch.float16)`) shows significant precision loss from FP16 rounding errors over many iterations. 

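Results 3 and 4 (`tensor(10.0021)`) demonstrate that FP32 accumulation, whether FP16 values are auto-promoted (case 3) or explicitly converted (case 4), yields identical results, since PyTorch promotes types automatically in mixed-precision addition. The remaining error comes from rounding 0.01 to FP16 (≈0.0100021) before it is added, not from the accumulation itself.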

</div>
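
The two error sources can be separated without PyTorch by emulating FP16 rounding with the standard library's `struct` half-precision format. This is a sketch under assumptions: `round_to_fp16` is a hypothetical helper, and Python's 64-bit float stands in for the FP32 accumulator.

```python
import struct


def round_to_fp16(x):
    """Round a Python float to the nearest FP16 value (round-to-nearest-even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


x16 = round_to_fp16(0.01)  # 0.01 is not exactly representable in FP16

# Case 2: the accumulator itself is FP16, so every partial sum is rounded.
fp16_sum = 0.0
for _ in range(1000):
    fp16_sum = round_to_fp16(fp16_sum + x16)

# Cases 3 and 4: FP16 addends, but the accumulator has higher precision.
wide_sum = 0.0
for _ in range(1000):
    wide_sum += x16

print(f"FP16 value of 0.01:      {x16!r}")
print(f"FP16 accumulator:        {fp16_sum!r}")
print(f"wide accumulator:        {wide_sum!r}")
```

The wide accumulator only carries the input rounding error (1000 times the FP16 rounding of 0.01), whereas the FP16 accumulator also loses precision at every addition once the running sum becomes large relative to the addend.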

---

### Problem (benchmarking_mixed_precision): 2 points

Consider the following model:

```python
class ToyModel(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 10, bias=False)
        self.ln = nn.LayerNorm(10)
        self.fc2 = nn.Linear(10, out_features, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.ln(x)
        x = self.fc2(x)
        return x
```

Suppose we are training the model on a GPU and that the model parameters are originally in FP32. We’d like to use autocasting mixed precision with FP16. What are the data types of:

  * the model parameters within the autocast context,
  * the output of the first feed-forward layer (`ToyModel.fc1`),
  * the output of layer norm (`ToyModel.ln`),
  * the model’s predicted logits,
  * the loss,
  * and the model’s gradients?

**Deliverable**: The data types for each of the components listed above.

---

``` sh
--- Running with autocast ---
Input dtype: torch.float32
After fc1 and relu dtype: torch.float16
After ln dtype: torch.float32
After fc2 (output) dtype: torch.float16

--- Final Output Tensor Properties ---
Final y dtype: torch.float16
Final y device: cuda:0
```

---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

Data types:
- **model parameters**: FP32; autocast does not change the stored parameters.
- **output of fc1**: FP16, since matrix multiplications are autocast to FP16.
- **output of ln**: FP32; LayerNorm is kept in FP32 by autocast to maintain numerical stability.
- **logits**: FP16, since the last layer is a linear (fully connected) layer.
- **loss**: FP32, since it requires higher precision.
- **gradients**: FP32, same as the parameters.

</div>

---

You should have seen that FP16 mixed precision autocasting treats the layer normalization layer differently than the feed-forward layers. What parts of layer normalization are sensitive to mixed precision? If we use BF16 instead of FP16, do we still need to treat layer normalization differently? Why or why not?

**Deliverable**: A 2-3 sentence response.

In [3]:
import torch
import struct
import numpy as np

# A simple number that is exactly representable in these formats
number = 9.75

fp32_bytes = struct.pack('!f', number)
fp32_bin = ''.join(f'{byte:08b}' for byte in fp32_bytes)
fp16_bytes = struct.pack('!e', np.float16(number))
fp16_bin = ''.join(f'{byte:08b}' for byte in fp16_bytes)

# BF16 (16 bits total) is the first 16 bits of FP32
bf16_bin = fp32_bin[:16]


print(f"Original Number: {number}\n")
# FP32: 1 Sign, 8 Exponent, 23 Mantissa
print("FP32 (float)")
print(f"Value:  {torch.tensor(number, dtype=torch.float32).item()}")
print(f"Binary: {fp32_bin[0]} {fp32_bin[1:9]} {fp32_bin[9:]}")
print(f"Format: S EEEEEEEE MMMMMMMMMMMMMMMMMMMMMMM\n")

# FP16: 1 Sign, 5 Exponent, 10 Mantissa
print("FP16 (half)")
print(f"Value:  {torch.tensor(number, dtype=torch.float16).item()}")
print(f"Binary: {fp16_bin[0]} {fp16_bin[1:6]} {fp16_bin[6:]}")
print(f"Format: S EEEEE MMMMMMMMMM\n")

# BF16: 1 Sign, 8 Exponent, 7 Mantissa
print("BF16 (bfloat16)")
print(f"Value:  {torch.tensor(number, dtype=torch.bfloat16).item()}")
print(f"Binary: {bf16_bin[0]} {bf16_bin[1:9]} {bf16_bin[9:]}")
print(f"Format: S EEEEEEEE MMMMMMM\n")

Original Number: 9.75

FP32 (float)
Value:  9.75
Binary: 0 10000010 00111000000000000000000
Format: S EEEEEEEE MMMMMMMMMMMMMMMMMMMMMMM

FP16 (half)
Value:  9.75
Binary: 0 10010 0011100000
Format: S EEEEE MMMMMMMMMM

BF16 (bfloat16)
Value:  9.75
Binary: 0 10000010 0011100
Format: S EEEEEEEE MMMMMMM



---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

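LayerNorm is sensitive in FP16 because computing the mean and variance can lose precision with FP16’s 10-bit mantissa, and the squared terms, division, and square root can overflow or underflow FP16’s narrow exponent range.

BF16 has the same exponent range as FP32, so it avoids the overflow/underflow problems; however, it has even fewer mantissa bits (7), so keeping the mean and variance reductions in FP32 is still the safer choice.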

</div>
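
The range difference between FP16 and BF16 can be illustrated with the standard library. This is a hedged sketch: the helper names are hypothetical, BF16 is emulated by truncating an FP32 bit pattern to its top 16 bits (the same trick used in the cell above, which truncates rather than rounds), and `struct` raises `OverflowError` where PyTorch would instead produce `inf`.

```python
import struct


def fits_in_fp16(x):
    """Return True if x is within FP16's finite range."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:  # struct refuses values beyond the FP16 maximum of 65504
        return False


def to_bf16_truncated(x):
    """Emulate BF16 by keeping only the top 16 bits of the FP32 encoding."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


activation = 300.0
squared = activation * activation  # the kind of term a variance computation produces

print("300^2 fits in FP16:", fits_in_fp16(squared))
print("300^2 in BF16:     ", to_bf16_truncated(squared))
```

The squared activation exceeds FP16's largest finite value, while BF16 represents it with a relative error below one percent: wider range, coarser precision.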

---

Modify your benchmarking script to optionally run the model using mixed precision with BF16. Time the forward and backward passes with and without mixed-precision for each language model size described in $\S 1.1.2$. Compare the results of using full vs. mixed precision, and comment on any trends as model size changes. You may find the `nullcontext` no-op context manager to be useful.

**Deliverable**: A 2-3 sentence response with your timings and commentary.

---

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">



</div>
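
The `nullcontext` suggestion in this problem can be sketched as a small helper that returns either a BF16 autocast context or a no-op. The function name `precision_context` and its parameters are hypothetical; the lazy import is an assumption made so that the full-precision path runs even where PyTorch is not installed.

```python
from contextlib import nullcontext


def precision_context(use_bf16, device_type="cuda"):
    """Return a BF16 autocast context if requested, otherwise a no-op context."""
    if not use_bf16:
        return nullcontext()
    import torch  # only needed for the mixed-precision path
    return torch.autocast(device_type=device_type, dtype=torch.bfloat16)


# The benchmarked step is wrapped the same way in both modes.
with precision_context(use_bf16=False):
    total = sum(range(10))  # placeholder for the forward (and backward) pass
print(total)
```

Keeping a single `with` statement means the timed code path is identical in both modes, so any timing difference comes from the precision setting rather than from the harness.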

---


### Problem (memory_profiling): 4 points

Profile your forward pass, backward pass, and optimizer step of the 2.7B model from Table 1 with context lengths of 128, 256 and 512.

-----

<span style="background-color: #29B6F6; color: black">

</span>

-----

(a) Add an option to your profiling script to run your model through the memory profiler. It may be helpful to reuse some of your previous infrastructure (e.g., to activate mixed-precision, load specific model sizes, etc). Then, run your script to get a memory profile of the 2.7B model when either doing inference only (just forward pass) or a full training step. What do your memory timelines look like? Can you tell which stage is running based on the peaks you see?

**Deliverable**: Two images of the “Active memory timeline” of a 2.7B model, from the `memory_viz` tool: one for the forward pass, and one for running a full training step (forward and backward passes, then optimizer step), and a 2-3 sentence response.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

(b) What is the peak memory usage of each context length when doing a forward pass? What about when doing a full training step?

**Deliverable**: A table with two numbers per context length.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

(c) Find the peak memory usage of the 2.7B model when using mixed-precision, for both a forward pass and a full optimizer step. Does mixed-precision significantly affect memory usage?

**Deliverable**: A 2-3 sentence response.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

(d) Consider the 2.7B model. At our reference hyperparameters, what is the size of a tensor of activations in the Transformer residual stream, in single-precision? Give this size in MB (i.e., divide the number of bytes by $1024^2$).

**Deliverable**: A 1-2 sentence response with your derivation.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>
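
The size of one residual-stream activation tensor follows from its shape, (batch size, context length, $d_{model}$), times 4 bytes per FP32 element. The sketch below shows the arithmetic only; the values `batch_size=4`, `context_length=512`, and `d_model=2560` are assumptions that are not stated in this document and should be replaced with the reference hyperparameters from the assignment.

```python
def residual_activation_mb(batch_size, context_length, d_model, bytes_per_elem=4):
    """Size in MB (bytes / 1024^2) of one (batch, context, d_model) activation tensor."""
    num_bytes = batch_size * context_length * d_model * bytes_per_elem
    return num_bytes / 1024**2


# Assumed hyperparameters, for illustration only.
size_mb = residual_activation_mb(batch_size=4, context_length=512, d_model=2560)
print(f"{size_mb} MB")
```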

-----

(e) Now look closely at the “Active Memory Timeline” from `pytorch.org/memory_viz` of a memory snapshot of the 2.7B model doing a forward pass. When you reduce the “Detail” level, the tool hides the smallest allocations to the corresponding level (e.g., putting “Detail” at 10% only shows the 10% largest allocations). What is the size of the largest allocations shown? Looking through the stack trace, can you tell where those allocations come from?

**Deliverable**: A 1-2 sentence response.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

## 1.2 Optimizing Attention with FlashAttention-2

### Problem (pytorch_attention): 2 points

Benchmark your attention implementation at different scales. Write a script that will:

1. Fix the batch size to 8 and don’t use multihead attention (i.e. remove the head dimension).
2. Iterate through the cartesian product of [16, 32, 64, 128] for the head embedding dimension $d_{model}$, and [256, 1024, 4096, 8192, 16384] for the sequence length.
3. Create random inputs $Q, K, V$ for the appropriate size.
4. Time 100 forward passes through attention using the inputs.
5. Measure how much memory is in use before the backward pass starts, and time 100 backward passes.
6. Make sure to warm up, and to call `torch.cuda.synchronize()` after each forward/backward pass.

Report the timings (or out-of-memory errors) you get for these configurations. At what size do you get out-of-memory errors? Do the accounting for the memory usage of attention in one of the smallest configurations you find that runs out of memory (you can use the equations for memory usage of Transformers from Assignment 1). How does the memory saved for backward change with the sequence length? What would you do to eliminate this memory cost?

**Deliverable**: A table with your timings, your working out for the memory usage, and a 1-2 paragraph response.

-----

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">


**Table of Results**

| d_model | seq_len | Status | fwd_mean_ms | bwd_mean_ms | mem_before_bwd |
| -------- | -------- | ------ | ------------- | ------------- | ---------------- |
| 16       | 256      | OK     | 0.16          | 0.42          | 18.75 MB         |
| 16       | 1024     | OK     | 0.28          | 0.69          | 50.25 MB         |
| 16       | 4096     | OK     | 3.76          | 7.70          | 536.25 MB        |
| 16       | 8192     | OK     | 14.91         | 30.46         | 2.03 GB          |
| 16       | 16384    | OOM    | –             | –             | –                |
| 32       | 256      | OK     | 0.22          | 0.63          | 19.25 MB         |
| 32       | 1024     | OK     | 0.31          | 0.69          | 52.25 MB         |
| 32       | 4096     | OK     | 3.82          | 7.80          | 544.25 MB        |
| 32       | 8192     | OK     | 15.04         | 30.65         | 2.05 GB          |
| 32       | 16384    | OOM    | –             | –             | –                |
| 64       | 256      | OK     | 0.23          | 0.63          | 20.25 MB         |
| 64       | 1024     | OK     | 0.29          | 0.73          | 56.25 MB         |
| 64       | 4096     | OK     | 3.88          | 7.86          | 560.25 MB        |
| 64       | 8192     | OK     | 15.20         | 30.94         | 2.08 GB          |
| 64       | 16384    | OOM    | –             | –             | –                |
| 128      | 256      | OK     | 0.24          | 0.67          | 22.25 MB         |
| 128      | 1024     | OK     | 0.32          | 0.74          | 64.25 MB         |
| 128      | 4096     | OK     | 4.19          | 8.26          | 592.25 MB        |
| 128      | 8192     | OK     | 16.57         | 32.41         | 2.14 GB          |
| 128      | 16384    | OOM    | –             | –             | –                |

OOM occurs at `seq_len = 16384` for all `d_model` values tested.

For the smallest OOM case (d_model=16, seq_len=16384, batch=8), attention stores Q, K, V, the attention weights, and intermediate gradients.

Forward activations (Q, K, V):

$$
3 \times (B \times L \times d) \times 4 \text{ bytes} = 3 \times (8 \times 16384 \times 16) \times 4 \text{ bytes} \approx 25 \text{ MB} \ (24 \text{ MiB})
$$

Attention scores:

$$
(B \times L \times L) \times 4 \text{ bytes} = (8 \times 16384^2) \times 4 \text{ bytes} \approx 8.59 \text{ GB} \ (8 \text{ GiB})
$$

The quadratic term in $L$ dominates, leading to OOM.
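
The accounting above can be reproduced with a few lines of Python. This is a minimal sketch: the function name is hypothetical, and it counts only Q, K, V and a single $L \times L$ score matrix in FP32, ignoring allocator overhead and any extra intermediates.

```python
def attention_memory_bytes(batch, seq_len, d, bytes_per_elem=4):
    """Return (bytes for Q, K and V, bytes for one batch of L x L scores)."""
    qkv_bytes = 3 * batch * seq_len * d * bytes_per_elem
    score_bytes = batch * seq_len * seq_len * bytes_per_elem
    return qkv_bytes, score_bytes


for seq_len in (4096, 8192, 16384):
    qkv, scores = attention_memory_bytes(batch=8, seq_len=seq_len, d=16)
    print(f"L={seq_len}: QKV = {qkv / 1024**2:.0f} MiB, scores = {scores / 1024**3:.2f} GiB")
```

Doubling the sequence length doubles the Q, K, V term but quadruples the score term, which matches the roughly 4× jump in `mem_before_bwd` between `seq_len = 4096` and `seq_len = 8192` in the table.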

**Observation**
Memory grows as \~$O(L^2)$ because of the score matrix, which is saved for the backward pass. Backward roughly doubles activation memory, so longer sequences drastically increase peak usage.

**How to Reduce Memory**
Use memory-efficient attention: process the sequence in tiles and recompute $QK^T$ in the backward pass instead of storing the full $L \times L$ attention matrix.

</div>
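
Tiled attention works because softmax can be computed incrementally, keeping only a running maximum and a running normalizer instead of a full row of scores. The following is an illustrative pure-Python sketch of that online-softmax idea for a single query row with scalar values; the function names are hypothetical, and it is not the FlashAttention-2 kernel itself.

```python
import math


def softmax_weighted_sum_tiled(scores, values, tile_size):
    """Compute sum_j softmax(scores)_j * values_j while visiting one tile at a time."""
    running_max = float("-inf")
    running_denom = 0.0
    running_out = 0.0
    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        new_max = max(running_max, max(s_tile))
        # Rescale what has been accumulated so far to the new running maximum.
        scale = math.exp(running_max - new_max)
        probs = [math.exp(s - new_max) for s in s_tile]
        running_denom = running_denom * scale + sum(probs)
        running_out = running_out * scale + sum(p * v for p, v in zip(probs, v_tile))
        running_max = new_max
    return running_out / running_denom


def softmax_weighted_sum_full(scores, values):
    """Reference version that materializes the whole softmax row."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return sum(e * v for e, v in zip(exps, values)) / sum(exps)


scores = [0.5, -1.0, 3.0, 2.0, 0.0, 4.5, -2.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
print(softmax_weighted_sum_tiled(scores, values, tile_size=3))
print(softmax_weighted_sum_full(scores, values))
```

Both functions agree up to floating-point rounding, but the tiled version never holds more than `tile_size` scores at once, which is what removes the $O(L^2)$ storage.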

-----

## 1.3 Benchmarking JIT-Compiled Attention

maybe useful: https://zhuanlan.zhihu.com/p/1927458713821749678

### Problem (torch_compile): 2 points

(a) Extend your attention benchmarking script to include a compiled version of your PyTorch implementation of attention, and compare its performance to the uncompiled version with the same configuration as the `pytorch_attention` problem above.

**Deliverable**: A table comparing your forward and backward pass timings for your compiled attention module with the uncompiled version from the `pytorch_attention` problem above.

-----

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

| method | d_model | seq_len | status | fwd_mean_ms | fwd_std_ms | bwd_mean_ms | bwd_std_ms | mem_before_bwd |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| eager | 16 | 256 | OK | 0.131 | 0.037 | 0.389 | 0.133 | 18.75 MB |
| compiled | 16 | 256 | OK | 0.273 | 0.069 | 0.384 | 0.032 | 0.38 MB |
| eager | 16 | 1024 | OK | 0.262 | 0.009 | 0.621 | 0.010 | 50.25 MB |
| compiled | 16 | 1024 | OK | 0.527 | 0.034 | 0.674 | 0.012 | 1.50 MB |
| eager | 16 | 4096 | OK | 3.764 | 0.127 | 7.675 | 0.055 | 536.25 MB |
| compiled | 16 | 4096 | OK | 2.725 | 0.065 | 4.967 | 0.060 | 6.00 MB |
| eager | 16 | 8192 | OK | 14.851 | 0.126 | 30.323 | 0.202 | 2.03 GB |
| compiled | 16 | 8192 | OK | 10.344 | 0.047 | 20.212 | 0.151 | 12.00 MB |
| eager | 16 | 16384 | OOM |  |  |  |  | OOM |
| compiled | 16 | 16384 | OOM |  |  |  |  | OOM |
| eager | 32 | 256 | OK | 0.185 | 0.041 | 0.628 | 0.023 | 19.25 MB |
| compiled | 32 | 256 | OK | 0.545 | 0.103 | 0.509 | 0.132 | 0.75 MB |
| eager | 32 | 1024 | OK | 0.278 | 0.010 | 0.701 | 0.136 | 52.25 MB |
| compiled | 32 | 1024 | OK | 0.598 | 0.055 | 0.666 | 0.012 | 3.00 MB |
| eager | 32 | 4096 | OK | 3.785 | 0.022 | 7.740 | 0.073 | 544.25 MB |
| compiled | 32 | 4096 | OK | 3.398 | 0.108 | 5.650 | 0.081 | 12.00 MB |
| eager | 32 | 8192 | OOM |  |  |  |  | OOM |
| compiled | 32 | 8192 | OK | 14.697 | 0.075 | 23.801 | 0.131 | 24.00 MB |
| eager | 32 | 16384 | OOM |  |  |  |  | OOM |
| compiled | 32 | 16384 | OOM |  |  |  |  | OOM |
| eager | 64 | 256 | OK | 0.217 | 0.017 | 0.667 | 0.132 | 20.25 MB |
| compiled | 64 | 256 | OK | 0.586 | 0.074 | 0.452 | 0.097 | 1.50 MB |
| eager | 64 | 1024 | OK | 0.306 | 0.165 | 0.691 | 0.056 | 56.25 MB |
| compiled | 64 | 1024 | OK | 0.623 | 0.044 | 0.748 | 0.108 | 6.00 MB |
| eager | 64 | 4096 | OK | 3.840 | 0.029 | 7.832 | 0.068 | 560.25 MB |
| compiled | 64 | 4096 | OK | 2.824 | 0.017 | 5.095 | 0.046 | 24.00 MB |
| eager | 64 | 8192 | OK | 15.147 | 0.035 | 30.779 | 0.045 | 2.08 GB |
| compiled | 64 | 8192 | OK | 10.771 | 0.214 | 19.169 | 0.090 | 48.00 MB |
| eager | 64 | 16384 | OOM |  |  |  |  | OOM |
| compiled | 64 | 16384 | OK | 40.929 | 0.321 | 79.471 | 0.892 | 96.00 MB |
| eager | 128 | 256 | OK | 0.125 | 0.038 | 0.623 | 0.179 | 22.25 MB |
| compiled | 128 | 256 | OK | 0.577 | 0.084 | 0.558 | 0.035 | 3.00 MB |
| eager | 128 | 1024 | OK | 0.295 | 0.008 | 0.697 | 0.009 | 64.25 MB |
| compiled | 128 | 1024 | OK | 0.650 | 0.036 | 0.763 | 0.030 | 12.00 MB |
| eager | 128 | 4096 | OOM |  |  |  |  | OOM |
| compiled | 128 | 4096 | OK | 3.247 | 0.036 | 5.464 | 0.083 | 48.00 MB |
| eager | 128 | 8192 | OOM |  |  |  |  | OOM |
| compiled | 128 | 8192 | OK | 12.300 | 0.076 | 20.624 | 0.181 | 96.00 MB |
| eager | 128 | 16384 | OOM |  |  |  |  | OOM |
| compiled | 128 | 16384 | OK | 47.612 | 0.627 | 85.768 | 0.654 | 192.00 MB |

</div>

-----

(b) Now, compile your entire Transformer model in your end-to-end benchmarking script. How does the performance of the forward pass change? What about the combined forward and backward passes and optimizer steps?

**Deliverable**: A table comparing your vanilla and compiled Transformer model.

-----

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

Experiment on a Transformer with **d_model=512, 6 layers, 8 heads, d_ff=2048, batch size 4, and context length 512**.


| Method   | Mode        | Time per step (ms) | Speedup | Mem before bwd |
| -------- | ----------- | -----------------: | ------: | -------------: |
| Eager    | fwd         |     13.297 ± 0.090 |       – |        1.31 GB |
| Compiled | fwd         |      6.894 ± 0.208 | \~1.93× |        1.06 GB |
| Eager    | fwd+bwd+opt |     42.674 ± 3.807 |       – |        1.72 GB |
| Compiled | fwd+bwd+opt |     20.463 ± 0.060 | \~2.09× |      524.83 MB |

Compiling the full Transformer nearly halves runtime for both forward-only and full training steps.

Memory use is also significantly reduced, especially in full training.

</div>

-----

### 1.3.1 Example - Weighted Sum

![](./data/imgs/weighted_sum.png)

In [None]:
import torch
import triton
import triton.language as tl
from einops import rearrange

def cdiv(a, b):
    return (a + b - 1) // b

@triton.jit
def weighted_sum_fwd(
    x_ptr, weight_ptr, # Input pointers
    output_ptr, # Output pointer
    x_stride_row, x_stride_dim, # Strides tell us how to move one element in each axis of a tensor
    weight_stride_dim, # Likely 1
    output_stride_row, # Likely 1
    ROWS, D,
    ROWS_TILE_SIZE: tl.constexpr, D_TILE_SIZE: tl.constexpr, # Tile shapes must be known at compile time
):
    # Each instance will compute the weighted sum of a tile of rows of x.
    # `tl.program_id` gives us a way to check which thread block we're running in
    row_tile_idx = tl.program_id(0)

    # Block pointers give us a way to select from an ND region of memory
    # and move our selection around.
    # The block pointer must know:
    # - The pointer to the first element of the tensor
    # - The overall shape of the tensor to handle out-of-bounds access
    # - The strides of each dimension to use the memory layout properly
    # - The ND coordinates of the starting block, i.e., "offsets"
    # - The block shape to use load/store at a time
    # - The order of the dimensions in memory from major to minor
    # axes (= np.argsort(strides)) for optimizations, especially useful on H100
    
    x_block_ptr = tl.make_block_ptr(
        x_ptr,
        shape=(ROWS, D,),
        strides=(x_stride_row, x_stride_dim),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),
    )
    
    weight_block_ptr = tl.make_block_ptr(
        weight_ptr,
        shape=(D,),
        strides=(weight_stride_dim,),
        offsets=(0,),
        block_shape=(D_TILE_SIZE,),
        order=(0,),
    )
    
    output_block_ptr = tl.make_block_ptr(
        output_ptr,
        shape=(ROWS,),
        strides=(output_stride_row,),
        offsets=(row_tile_idx * ROWS_TILE_SIZE,),
        block_shape=(ROWS_TILE_SIZE,),
        order=(0,),
    )
    
    # Initialize a buffer to write to
    output = tl.zeros((ROWS_TILE_SIZE,), dtype=tl.float32)
    
    for i in range(tl.cdiv(D, D_TILE_SIZE)):
        # Load the current block pointer
        # Since ROWS_TILE_SIZE might not divide ROWS, and D_TILE_SIZE might not divide D,
        # we need boundary checks for both dimensions
        row = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero") # (ROWS_TILE_SIZE, D_TILE_SIZE)
        weight = tl.load(weight_block_ptr, boundary_check=(0,), padding_option="zero") # (D_TILE_SIZE,)
        
        # Compute the weighted sum of the row.
        output += tl.sum(row * weight[None, :], axis=1)
        
        # Move the pointers to the next tile.
        # These are (rows, columns) coordinate deltas
        x_block_ptr = x_block_ptr.advance((0, D_TILE_SIZE)) # Move by D_TILE_SIZE in the last dimension
        weight_block_ptr = weight_block_ptr.advance((D_TILE_SIZE,)) # Move by D_TILE_SIZE
    
    # Write output to the output block pointer (a single scalar per row).
    # Since ROWS_TILE_SIZE might not divide ROWS, we need boundary checks
    tl.store(output_block_ptr, output, boundary_check=(0,))

@triton.jit
def weighted_sum_backward(
    x_ptr, weight_ptr,  # Input
    grad_output_ptr,  # Grad input
    grad_x_ptr, partial_grad_weight_ptr,  # Grad outputs
    stride_xr, stride_xd,
    stride_wd,
    stride_gr,
    stride_gxr, stride_gxd,
    stride_gwb, stride_gwd,
    NUM_ROWS, D,
    ROWS_TILE_SIZE: tl.constexpr, D_TILE_SIZE: tl.constexpr,
):
    row_tile_idx = tl.program_id(0)
    n_row_tiles = tl.num_programs(0)

    # Inputs
    grad_output_block_ptr = tl.make_block_ptr(
        grad_output_ptr,
        shape=(NUM_ROWS,), strides=(stride_gr,),
        offsets=(row_tile_idx * ROWS_TILE_SIZE,),
        block_shape=(ROWS_TILE_SIZE,),
        order=(0,),
    )

    x_block_ptr = tl.make_block_ptr(
        x_ptr,
        shape=(NUM_ROWS, D,), strides=(stride_xr, stride_xd),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),
    )

    weight_block_ptr = tl.make_block_ptr(
        weight_ptr,
        shape=(D,), strides=(stride_wd,),
        offsets=(0,), block_shape=(D_TILE_SIZE,),
        order=(0,),
    )

    grad_x_block_ptr = tl.make_block_ptr(
        grad_x_ptr,
        shape=(NUM_ROWS, D,), strides=(stride_gxr, stride_gxd),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),
    )

    partial_grad_weight_block_ptr = tl.make_block_ptr(
        partial_grad_weight_ptr,
        shape=(n_row_tiles, D,), strides=(stride_gwb, stride_gwd),
        offsets=(row_tile_idx, 0),
        block_shape=(1, D_TILE_SIZE),
        order=(1, 0),
    )

    for i in range(tl.cdiv(D, D_TILE_SIZE)):
        grad_output = tl.load(grad_output_block_ptr, boundary_check=(0,), padding_option="zero")  # (ROWS_TILE_SIZE,)

        # Outer product for grad_x
        weight = tl.load(weight_block_ptr, boundary_check=(0,), padding_option="zero")  # (D_TILE_SIZE,)
        grad_x_row = grad_output[:, None] * weight[None, :]
        tl.store(grad_x_block_ptr, grad_x_row, boundary_check=(0, 1))

        # Reduce as many rows as possible for the grad_weight result
        row = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")  # (ROWS_TILE_SIZE, D_TILE_SIZE)
        grad_weight_row = tl.sum(row * grad_output[:, None], axis=0, keep_dims=True)
        tl.store(partial_grad_weight_block_ptr, grad_weight_row, boundary_check=(1,))  # Never out of bounds for dim 0

        # Move the pointers to the next tile along D
        x_block_ptr = x_block_ptr.advance((0, D_TILE_SIZE))
        weight_block_ptr = weight_block_ptr.advance((D_TILE_SIZE,))
        partial_grad_weight_block_ptr = partial_grad_weight_block_ptr.advance((0, D_TILE_SIZE))
        grad_x_block_ptr = grad_x_block_ptr.advance((0, D_TILE_SIZE))

class WeightedSumFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        # Cache x and weight to be used in the backward pass, when we
        # only receive the gradient wrt. the output tensor, and
        # need to compute the gradients wrt. x and weight.
        D, output_dims = x.shape[-1], x.shape[:-1]

        # Reshape input tensor to 2D
        input_shape = x.shape
        x = rearrange(x, "... d -> (...) d")

        ctx.save_for_backward(x, weight)

        assert len(weight.shape) == 1 and weight.shape[0] == D, "Dimension mismatch"
        assert x.is_cuda and weight.is_cuda, "Expected CUDA tensors"
        assert x.is_contiguous(), "Our pointer arithmetic will assume contiguous x"

        ctx.D_TILE_SIZE = triton.next_power_of_2(D) // 16  # Roughly 16 loops through the embedding dimension
        ctx.ROWS_TILE_SIZE = 16  # Each thread processes 16 batch elements at a time
        ctx.input_shape = input_shape

        # Need to initialize empty result tensor. Note that these elements are not necessarily 0!
        y = torch.empty(output_dims, device=x.device)

        # Launch our kernel with n instances in our 1D grid.
        n_rows = y.numel()
        weighted_sum_fwd[(cdiv(n_rows, ctx.ROWS_TILE_SIZE),)](
            x, weight,
            y,
            x.stride(0), x.stride(1),
            weight.stride(0),
            y.stride(0),
            ROWS=n_rows, D=D,
            ROWS_TILE_SIZE=ctx.ROWS_TILE_SIZE, D_TILE_SIZE=ctx.D_TILE_SIZE,
        )

        return y.view(input_shape[:-1])

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        ROWS_TILE_SIZE, D_TILE_SIZE = ctx.ROWS_TILE_SIZE, ctx.D_TILE_SIZE  # These don't have to be the same
        n_rows, D = x.shape

        # Our strategy is for each thread block to first write to a partial buffer,
        # then we reduce over this buffer to get the final gradient.
        partial_grad_weight = torch.empty((cdiv(n_rows, ROWS_TILE_SIZE), D), device=x.device, dtype=x.dtype)
        grad_x = torch.empty_like(x)

        weighted_sum_backward[(cdiv(n_rows, ROWS_TILE_SIZE),)](
            x, weight,
            grad_out,
            grad_x, partial_grad_weight,
            x.stride(0), x.stride(1),
            weight.stride(0),
            grad_out.stride(0),
            grad_x.stride(0), grad_x.stride(1),
            partial_grad_weight.stride(0), partial_grad_weight.stride(1),
            NUM_ROWS=n_rows, D=D,
            ROWS_TILE_SIZE=ROWS_TILE_SIZE, D_TILE_SIZE=D_TILE_SIZE,
        )
        grad_weight = partial_grad_weight.sum(axis=0)
        return grad_x, grad_weight  

In [2]:
f_weightedsum = WeightedSumFunc.apply

torch.manual_seed(0)
x = torch.randn(128, 64, device="cuda", dtype=torch.float32, requires_grad=True)
w = torch.randn(64, device="cuda", dtype=torch.float32, requires_grad=True)

y = f_weightedsum(x, w)
loss = y.sum()
loss.backward()

print("y:", y[:5])
print("grad_x[0]:", x.grad[0, :5])
print("grad_w[:5]:", w.grad[:5])

y: tensor([-1.5445,  7.5998,  2.6024,  8.2226,  1.1747], device='cuda:0',
       grad_fn=<SliceBackward0>)
grad_x[0]: tensor([ 0.1808, -0.5523,  0.9238, -0.7350,  1.3800], device='cuda:0')
grad_w[:5]: tensor([-11.8855, -15.2647, -13.0576,  -6.6209, -13.9196], device='cuda:0')


### Problem (flash_forward): 15 points

(a) Write a pure PyTorch (no Triton) `autograd.Function` that implements the FlashAttention-2 forward pass. This will be a lot slower than the regular PyTorch implementation, but will help you debug your Triton kernel.

Your implementation should take inputs $Q, K$, and $V$ as well as a flag `is_causal` and produce the output $O$ and the `logsumexp` value $L$. You can ignore the `is_causal` flag for this task. The `autograd.Function` forward should then save $L, Q, K, V, O$ for the backward pass and return $O$. Remember that the forward method of an `autograd.Function` always takes the context as its first parameter. Any `autograd.Function` class needs to implement a backward method, but for now you can make it just raise `NotImplementedError`. If you need something to compare against, you can implement Equations 4 to 6 and 12 in PyTorch and compare your outputs.

The interface is then `def forward(ctx, Q, K, V, is_causal=False)`. Determine your own tile sizes, but make sure they are at least of size $16 \times 16$. We will always test your code with dimensions that are clean powers of 2 and at least 16, so you don’t need to worry about out-of-bounds accesses.

**Deliverable**: A `torch.autograd.Function` subclass that implements FlashAttention-2 in the forward pass. To test your code, implement `[adapters.get_flashattention_autograd_function_pytorch]`. Then, run the test with `uv run pytest -k test_flash_forward_pass_pytorch` and make sure your implementation passes it.
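
The tiled forward pass hinges on the online softmax recurrence. As a minimal sketch (pure Python on nested lists, not the `autograd.Function` the problem asks for; the names `naive_attention_row` and `tiled_attention_row` are hypothetical), the following processes a single query row against key tiles of size `tile`, keeping a running max `m`, a running denominator `l`, and an unnormalized output `o`, and returns the output row together with its logsumexp $L = m + \log l$:

```python
import math
import random


def naive_attention_row(q, K, V, scale):
    """Reference: softmax(scale * q K^T) V for one query row, plus logsumexp."""
    s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    l = sum(p)
    d_v = len(V[0])
    o = [sum(p[j] * V[j][c] for j in range(len(V))) / l for c in range(d_v)]
    return o, m + math.log(l)


def tiled_attention_row(q, K, V, scale, tile):
    """Online-softmax version: visit the keys one tile at a time."""
    d_v = len(V[0])
    m = float("-inf")  # running row max
    l = 0.0            # running softmax denominator
    o = [0.0] * d_v    # running unnormalized output
    for start in range(0, len(K), tile):
        K_j, V_j = K[start:start + tile], V[start:start + tile]
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K_j]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)  # rescales the accumulators to the new max
        p = [math.exp(x - m_new) for x in s]
        l = alpha * l + sum(p)
        o = [
            alpha * o[c] + sum(p[j] * V_j[j][c] for j in range(len(V_j)))
            for c in range(d_v)
        ]
        m = m_new
    return [x / l for x in o], m + math.log(l)


random.seed(0)
d, n_keys = 4, 8
q = [random.gauss(0, 1) for _ in range(d)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_keys)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_keys)]
scale = 1 / math.sqrt(d)

o_ref, L_ref = naive_attention_row(q, K, V, scale)
o_tiled, L_tiled = tiled_attention_row(q, K, V, scale, tile=2)
```

The tiled result should agree with the reference up to floating-point rounding, which is the same check the problem suggests running against Equations 4 to 6 and 12.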

In [10]:
# uv run pytest -k test_flash_forward_pass_pytorch
!python -m pytest -k test_flash_forward_pass_pytorch

platform win32 -- Python 3.11.0rc2, pytest-8.3.5, pluggy-1.5.0
rootdir: e:\Code\CS336\assignment2-systems
configfile: pyproject.toml
plugins: jaxtyping-0.3.1
collected 16 items / 15 deselected / 1 selected

tests/test_attention.py::test_flash_forward_pass_pytorch PASSED



(b) Write a Triton kernel for the forward pass of FlashAttention-2 following Algorithm 1. Then, write another subclass of `torch.autograd.Function` that calls this (fused) kernel in the forward pass, instead of computing the result in PyTorch. A few problem-specific tips:

  * To debug, we suggest comparing the results of each Triton operation you perform with the tiled PyTorch implementation you wrote in part (a).
  * Your launch grid should be set as `(Tq, batch_size)`, meaning each Triton program instance will load only elements from a single batch index, and only read/write to a single query tile of $Q, O$, and $L$.
  * The kernel should only have a single loop, which will iterate over key tiles $1 \le j \le T_k$.
  * Advance block pointers at the end of the loop.
  * Use the function declaration below (using the block pointer we give you, you should be able to infer the setup of the rest of the pointers):

```python
@triton.jit
def flash_fwd_kernel(
    Q_ptr, K_ptr, V_ptr,
    O_ptr, L_ptr,
    stride_qb, stride_qq, stride_qd,
    stride_kb, stride_kk, stride_kd,
    stride_vb, stride_vk, stride_vd,
    stride_ob, stride_oq, stride_od,
    stride_lb, stride_lq,
    N_QUERIES, N_KEYS,
    scale,
    D: tl.constexpr,
    Q_TILE_SIZE: tl.constexpr,
    K_TILE_SIZE: tl.constexpr,
):
    # Program indices
    query_tile_index = tl.program_id(0)
    batch_index = tl.program_id(1)

    # Offset each pointer with the corresponding batch index
    # multiplied with the batch stride for each tensor
    Q_block_ptr = tl.make_block_ptr(
        Q_ptr + batch_index * stride_qb,
        shape=(N_QUERIES, D),
        strides=(stride_qq, stride_qd),
        offsets=(query_tile_index * Q_TILE_SIZE, 0),
        block_shape=(Q_TILE_SIZE, D),
        order=(1, 0),
    )

    ...
```

where `scale` is $\frac{1}{\sqrt{d}}$ and `Q_TILE_SIZE` and `K_TILE_SIZE` are $B_q$ and $B_k$ respectively. You can tune these later.

These additional guidelines may help you avoid precision issues:

  * The on-chip buffers ($O_i, l, m$) should have dtype `tl.float32`. If you’re accumulating into an output buffer, use the `acc` argument (`acc = tl.dot(..., acc=acc)`).
  * Cast $\tilde{P}^{(j)}_i$ to the dtype of $V^{(j)}$ before multiplying them, and cast $O_i$ to the appropriate dtype before writing it to global memory. Casting is done with `tensor.to`. You can get the dtype of a tensor with `tensor.dtype`, and the dtype of a block pointer/pointer with `*_block_ptr.type.element_ty`.

**Deliverable**: A `torch.autograd.Function` subclass that implements FlashAttention-2 in the forward pass using your Triton kernel. Implement `[adapters.get_flash_autograd_function_triton]`. Then, run the test with `uv run pytest -k test_flash_forward_pass_triton` and make sure your implementation passes it.

(c) Add a flag as the last argument to your `autograd.Function` implementation for causal masking. This should be a boolean flag that when set to True enables an index comparison for causal masking. Your Triton kernel should have a corresponding additional parameter `is_causal: tl.constexpr` (this is a required type annotation). In Triton, construct appropriate index vectors for queries and keys, and compare them to form a square mask of size $B_q \times B_k$. For elements that are masked out, add the constant value of $-1e6$ to the corresponding elements of the attention score matrix $S^{(j)}_i$. Make sure to save the mask flag for backward using `ctx.is_causal = is_causal`.

**Deliverable**: An additional flag for your `torch.autograd.Function` subclass that implements the FlashAttention-2 forward pass with causal masking using your Triton kernel. Make sure that the flag is optional with default `False` so the previous tests still pass.
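
The index comparison for one $B_q \times B_k$ tile can be sketched in plain Python. This is only an illustration of the masking logic, not Triton code; the helper names `causal_tile_mask` and `apply_causal_mask` are hypothetical, and in the kernel the index vectors would presumably be built from `tl.arange` offsets instead of Python lists.

```python
def causal_tile_mask(query_tile_index, key_tile_index, q_tile_size, k_tile_size):
    """Return a q_tile_size x k_tile_size mask; True means the key is visible."""
    q_idx = [query_tile_index * q_tile_size + i for i in range(q_tile_size)]
    k_idx = [key_tile_index * k_tile_size + j for j in range(k_tile_size)]
    # A query may attend to keys at the same or an earlier position.
    return [[k <= q for k in k_idx] for q in q_idx]


def apply_causal_mask(scores, mask, fill=-1e6):
    """Add `fill` to every masked-out score, leaving visible scores unchanged."""
    return [
        [s if keep else s + fill for s, keep in zip(row, mask_row)]
        for row, mask_row in zip(scores, mask)
    ]


diagonal = causal_tile_mask(1, 1, 2, 2)      # queries 2..3 against keys 2..3
above_diagonal = causal_tile_mask(0, 1, 2, 2)  # queries 0..1 against keys 2..3
below_diagonal = causal_tile_mask(1, 0, 2, 2)  # queries 2..3 against keys 0..1
masked = apply_causal_mask([[1.0, 2.0], [3.0, 4.0]], diagonal)
```

Only tiles that straddle the diagonal need a mixed mask: tiles below it are fully visible and tiles above it are fully masked.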

In [1]:
# uv run pytest -k test_flash_forward_pass_triton
!python -m pytest -k test_flash_forward_pass_triton

platform win32 -- Python 3.11.0rc2, pytest-8.3.5, pluggy-1.5.0
rootdir: e:\Code\CS336\assignment2-systems
configfile: pyproject.toml
plugins: jaxtyping-0.3.1
collected 16 items / 14 deselected / 2 selected

tests/test_attention.py::test_flash_forward_pass_triton[False] PASSED
tests/test_attention.py::test_flash_forward_pass_triton[True] PASSED



---

### Problem (flash_backward): 5 points

Implement the backward pass for your FlashAttention-2 `autograd.Function` using PyTorch (not Triton) and `torch.compile`. Your implementation should take the $Q, K, V, O, dO$, and $L$ tensors as input, and return $dQ, dK$, and $dV$. Remember to compute and use the $D$ vector. You may follow along with the computations of Equations 13 to 19.

**Deliverable**: To test your implementation, run `uv run pytest -k test_flash_backward`.
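
For reference, the recomputation-based backward pass can be written as follows. This is a sketch that assumes Equations 13 to 19 take the standard FlashAttention-2 form, with $L$ and $D$ broadcast along the key dimension and $\circ$ denoting the elementwise product; check each line against the handout before relying on it.

$$
\begin{aligned}
D &= \mathrm{rowsum}(dO \circ O) \\
S &= QK^\top / \sqrt{d} \\
P &= \exp(S - L) \\
dV &= P^\top dO \\
dP &= dO\, V^\top \\
dS &= P \circ (dP - D) \\
dQ &= dS\, K / \sqrt{d} \\
dK &= dS^\top Q / \sqrt{d}
\end{aligned}
$$

With causal masking, the same mask used in the forward pass would be applied to $S$ before the exponential so that masked entries of $P$ are zero.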

In [2]:
# uv run pytest -k test_flash_backward
!python -m pytest -k test_flash_backward

platform win32 -- Python 3.11.0rc2, pytest-8.3.5, pluggy-1.5.0
rootdir: e:\Code\CS336\assignment2-systems
configfile: pyproject.toml
plugins: jaxtyping-0.3.1
collected 16 items / 13 deselected / 3 selected

tests/test_attention.py::test_flash_backward_pytorch PASSED
tests/test_attention.py::test_flash_backward_triton[False] PASSED
tests/test_attention.py::test_flash_backward_triton[True] PASSED




### Problem (flash_benchmarking): 5 points

Write a benchmarking script using `triton.testing.do_bench` that compares the performance of your (partially) Triton implementation of FlashAttention-2 forward and backward passes with a regular PyTorch implementation (i.e., not using FlashAttention). Specifically, you will report a table that includes latencies for forward, backward, and the end-to-end forward-backward pass, for both your Triton and PyTorch implementations. Randomly generate any necessary inputs before you start benchmarking, and run the benchmark on a single H100. Always use batch size 1 and causal masking. Sweep over the cartesian product of sequence lengths of various powers of 2 from 128 up to 65536, embedding dimension sizes of various powers of 2 from 16 up to size 128, and precisions of `torch.bfloat16` and `torch.float32`. You will likely need to adjust tile sizes depending on the input sizes.

**Deliverable**: A table of results comparing your implementation of FlashAttention-2 with the PyTorch implementation, using the settings above and reporting forward, backward, and end-to-end latencies.
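
The sweep described above can be enumerated with `itertools.product`. The sketch below only builds the list of configurations (the dtype entries are plain strings that a real script would map to `torch.bfloat16` and `torch.float32`, and the dictionary keys mirror the result-table columns); each configuration would then be timed with `triton.testing.do_bench`.

```python
import itertools

seq_lens = [2 ** k for k in range(7, 17)]   # 128, 256, ..., 65536
head_dims = [2 ** k for k in range(4, 8)]   # 16, 32, 64, 128
dtypes = ["bfloat16", "float32"]            # names only; map to torch dtypes when benchmarking
methods = ["triton", "torch"]

# Order matches the result table: dtype, then embedding size, then sequence length.
configs = [
    {"method": method, "dtype": dtype, "d_model": d, "seq_len": n}
    for dtype, d, n, method in itertools.product(dtypes, head_dims, seq_lens, methods)
]

print(len(configs), "configurations")
```

Configurations that run out of memory (for example the plain PyTorch baseline at long sequence lengths) can be caught and skipped rather than aborting the whole sweep.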

-----

<div style="background-color: rgb(196, 196, 196); padding: 10px; color: #333;">

| method | dtype | d_model | seq_len | fwd_ms | bwd_ms | e2e_ms |
|---|---|---|---|---|---|---|
| triton | bfloat16 | 16 | 128 | 0.013 | 0.143 | 0.166 |
| torch | bfloat16 | 16 | 128 | 0.117 | 0.164 | 0.237 |
| triton | bfloat16 | 16 | 256 | 0.015 | 0.126 | 0.155 |
| torch | bfloat16 | 16 | 256 | 0.092 | 0.162 | 0.234 |
| triton | bfloat16 | 16 | 512 | 0.019 | 0.147 | 0.159 |
| torch | bfloat16 | 16 | 512 | 0.123 | 0.218 | 0.323 |
| triton | bfloat16 | 16 | 1024 | 0.029 | 0.173 | 0.181 |
| torch | bfloat16 | 16 | 1024 | 0.218 | 0.311 | 0.527 |
| triton | bfloat16 | 16 | 2048 | 0.053 | 0.174 | 0.230 |
| torch | bfloat16 | 16 | 2048 | 1.048 | 0.993 | 2.033 |
| triton | bfloat16 | 16 | 4096 | 0.147 | 0.342 | 0.479 |
| torch | bfloat16 | 16 | 4096 | 3.263 | 3.603 | 6.859 |
| triton | bfloat16 | 16 | 8192 | 0.440 | 1.000 | 1.442 |
| torch | bfloat16 | 16 | 8192 | 11.573 | 13.796 | 25.353 |
| triton | bfloat16 | 16 | 16384 | 1.540 | 3.648 | 5.216 |
| triton | bfloat16 | 16 | 32768 | 5.927 | 14.598 | 20.593 |
| triton | bfloat16 | 16 | 65536 | 24.180 | 60.227 | 84.836 |
| triton | bfloat16 | 32 | 128 | 0.014 | 0.123 | 0.155 |
| torch | bfloat16 | 32 | 128 | 0.090 | 0.156 | 0.233 |
| triton | bfloat16 | 32 | 256 | 0.015 | 0.185 | 0.183 |
| torch | bfloat16 | 32 | 256 | 0.100 | 0.157 | 0.246 |
| triton | bfloat16 | 32 | 512 | 0.021 | 0.153 | 0.169 |
| torch | bfloat16 | 32 | 512 | 0.127 | 0.419 | 0.570 |
| triton | bfloat16 | 32 | 1024 | 0.035 | 0.163 | 0.236 |
| torch | bfloat16 | 32 | 1024 | 0.218 | 0.318 | 0.529 |
| triton | bfloat16 | 32 | 2048 | 0.073 | 0.222 | 0.291 |
| torch | bfloat16 | 32 | 2048 | 1.051 | 0.971 | 2.023 |
| triton | bfloat16 | 32 | 4096 | 0.195 | 0.537 | 0.733 |
| torch | bfloat16 | 32 | 4096 | 3.281 | 3.580 | 6.861 |
| triton | bfloat16 | 32 | 8192 | 0.618 | 1.763 | 2.385 |
| torch | bfloat16 | 32 | 8192 | 11.715 | 13.921 | 25.665 |
| triton | bfloat16 | 32 | 16384 | 2.120 | 6.546 | 8.876 |
| triton | bfloat16 | 32 | 32768 | 8.961 | 27.094 | 35.812 |
| triton | bfloat16 | 32 | 65536 | 37.069 | 108.055 | 144.384 |
| triton | bfloat16 | 64 | 128 | 0.017 | 0.272 | 0.291 |
| torch | bfloat16 | 64 | 128 | 0.100 | 0.170 | 0.229 |
| triton | bfloat16 | 64 | 256 | 0.018 | 0.138 | 0.160 |
| torch | bfloat16 | 64 | 256 | 0.098 | 0.161 | 0.223 |
| triton | bfloat16 | 64 | 512 | 0.025 | 0.138 | 0.189 |
| torch | bfloat16 | 64 | 512 | 0.106 | 0.198 | 0.293 |
| triton | bfloat16 | 64 | 1024 | 0.043 | 0.164 | 0.199 |
| torch | bfloat16 | 64 | 1024 | 0.219 | 0.344 | 0.545 |
| triton | bfloat16 | 64 | 2048 | 0.113 | 0.279 | 0.399 |
| torch | bfloat16 | 64 | 2048 | 1.050 | 0.973 | 2.020 |
| triton | bfloat16 | 64 | 4096 | 0.319 | 0.798 | 1.125 |
| torch | bfloat16 | 64 | 4096 | 3.314 | 3.698 | 7.006 |
| triton | bfloat16 | 64 | 8192 | 1.005 | 2.772 | 3.800 |
| torch | bfloat16 | 64 | 8192 | 11.951 | 14.231 | 26.583 |
| triton | bfloat16 | 64 | 16384 | 3.869 | 11.036 | 14.831 |
| triton | bfloat16 | 64 | 32768 | 15.929 | 44.985 | 59.731 |
| triton | bfloat16 | 64 | 65536 | 64.536 | 175.378 | 241.275 |
| triton | bfloat16 | 128 | 128 | 0.087 | 0.160 | 0.192 |
| torch | bfloat16 | 128 | 128 | 0.155 | 0.283 | 0.249 |
| triton | bfloat16 | 128 | 256 | 0.023 | 0.145 | 0.179 |
| torch | bfloat16 | 128 | 256 | 0.102 | 0.171 | 0.242 |
| triton | bfloat16 | 128 | 512 | 0.035 | 0.146 | 0.191 |
| torch | bfloat16 | 128 | 512 | 0.119 | 0.203 | 0.298 |
| triton | bfloat16 | 128 | 1024 | 0.074 | 0.206 | 0.279 |
| torch | bfloat16 | 128 | 1024 | 0.258 | 0.362 | 0.606 |
| triton | bfloat16 | 128 | 2048 | 0.193 | 0.508 | 0.707 |
| torch | bfloat16 | 128 | 2048 | 1.109 | 1.112 | 2.218 |
| triton | bfloat16 | 128 | 4096 | 0.538 | 1.610 | 2.147 |
| torch | bfloat16 | 128 | 4096 | 3.563 | 4.127 | 7.712 |
| triton | bfloat16 | 128 | 8192 | 2.038 | 5.930 | 8.025 |
| torch | bfloat16 | 128 | 8192 | 13.369 | 16.519 | 29.539 |
| triton | bfloat16 | 128 | 16384 | 8.426 | 23.505 | 31.836 |
| triton | bfloat16 | 128 | 32768 | 35.007 | 92.706 | 129.925 |
| triton | bfloat16 | 128 | 65536 | 145.978 | 367.843 | 533.232 |
| triton | float32 | 16 | 128 | 0.013 | 0.148 | 0.178 |
| torch | float32 | 16 | 128 | 0.151 | 0.234 | 0.273 |
| triton | float32 | 16 | 256 | 0.017 | 0.147 | 0.154 |
| torch | float32 | 16 | 256 | 0.122 | 0.178 | 0.295 |
| triton | float32 | 16 | 512 | 0.021 | 0.147 | 0.187 |
| torch | float32 | 16 | 512 | 0.148 | 0.242 | 0.360 |
| triton | float32 | 16 | 1024 | 0.036 | 0.152 | 0.191 |
| torch | float32 | 16 | 1024 | 0.342 | 0.576 | 0.942 |
| triton | float32 | 16 | 2048 | 0.081 | 0.232 | 0.307 |
| torch | float32 | 16 | 2048 | 1.403 | 1.815 | 3.218 |
| triton | float32 | 16 | 4096 | 0.216 | 0.636 | 0.862 |
| torch | float32 | 16 | 4096 | 5.138 | 7.127 | 12.257 |
| triton | float32 | 16 | 8192 | 0.708 | 2.091 | 2.798 |
| torch | float32 | 16 | 8192 | 19.049 | 27.785 | 46.795 |
| triton | float32 | 16 | 16384 | 2.456 | 7.868 | 10.417 |
| triton | float32 | 16 | 32768 | 10.343 | 31.788 | 42.199 |
| triton | float32 | 16 | 65536 | 43.439 | 129.405 | 173.832 |
| triton | float32 | 32 | 128 | 0.015 | 0.159 | 0.177 |
| torch | float32 | 32 | 128 | 0.099 | 0.148 | 0.208 |
| triton | float32 | 32 | 256 | 0.017 | 0.154 | 0.143 |
| torch | float32 | 32 | 256 | 0.100 | 0.213 | 0.425 |
| triton | float32 | 32 | 512 | 0.025 | 0.129 | 0.165 |
| torch | float32 | 32 | 512 | 0.119 | 0.234 | 0.354 |
| triton | float32 | 32 | 1024 | 0.044 | 0.197 | 0.251 |
| torch | float32 | 32 | 1024 | 0.346 | 0.558 | 0.915 |
| triton | float32 | 32 | 2048 | 0.114 | 0.364 | 0.476 |
| torch | float32 | 32 | 2048 | 1.419 | 1.871 | 3.299 |
| triton | float32 | 32 | 4096 | 0.315 | 1.116 | 1.433 |
| torch | float32 | 32 | 4096 | 5.299 | 7.299 | 12.586 |
| triton | float32 | 32 | 8192 | 1.027 | 3.860 | 4.897 |
| torch | float32 | 32 | 8192 | 19.438 | 28.232 | 47.559 |
| triton | float32 | 32 | 16384 | 4.282 | 14.641 | 18.928 |
| triton | float32 | 32 | 32768 | 17.697 | 58.232 | 76.344 |
| triton | float32 | 32 | 65536 | 71.228 | 233.577 | 304.153 |
| triton | float32 | 64 | 128 | 0.076 | 0.346 | 0.242 |
| torch | float32 | 64 | 128 | 0.113 | 0.175 | 0.246 |
| triton | float32 | 64 | 256 | 0.022 | 0.136 | 0.186 |
| torch | float32 | 64 | 256 | 0.122 | 0.222 | 0.272 |
| triton | float32 | 64 | 512 | 0.033 | 0.128 | 0.176 |
| torch | float32 | 64 | 512 | 0.120 | 0.239 | 0.354 |
| triton | float32 | 64 | 1024 | 0.072 | 0.243 | 0.301 |
| torch | float32 | 64 | 1024 | 0.402 | 0.651 | 1.027 |
| triton | float32 | 64 | 2048 | 0.191 | 0.576 | 0.760 |
| torch | float32 | 64 | 2048 | 1.504 | 2.048 | 3.603 |
| triton | float32 | 64 | 4096 | 0.553 | 1.771 | 2.320 |
| torch | float32 | 64 | 4096 | 5.338 | 7.643 | 13.033 |
| triton | float32 | 64 | 8192 | 2.067 | 6.594 | 8.724 |
| torch | float32 | 64 | 8192 | 21.432 | 30.872 | 52.177 |
| triton | float32 | 64 | 16384 | 8.537 | 25.905 | 34.863 |
| triton | float32 | 64 | 32768 | 35.241 | 103.257 | 138.800 |
| triton | float32 | 64 | 65536 | 143.147 | 412.156 | 559.069 |
| triton | float32 | 128 | 128 | 0.155 | 0.600 | 0.177 |
| torch | float32 | 128 | 128 | 0.109 | 0.155 | 0.217 |
| triton | float32 | 128 | 256 | 0.036 | 0.138 | 0.180 |
| torch | float32 | 128 | 256 | 0.095 | 0.175 | 0.260 |
| triton | float32 | 128 | 512 | 0.057 | 0.228 | 0.278 |
| torch | float32 | 128 | 512 | 0.135 | 0.257 | 0.398 |
| triton | float32 | 128 | 1024 | 0.148 | 0.468 | 0.610 |
| torch | float32 | 128 | 1024 | 0.416 | 0.702 | 1.112 |
| triton | float32 | 128 | 2048 | 0.442 | 1.439 | 1.884 |
| torch | float32 | 128 | 2048 | 1.728 | 2.539 | 4.329 |
| triton | float32 | 128 | 4096 | 1.574 | 5.150 | 6.732 |
| torch | float32 | 128 | 4096 | 6.447 | 9.785 | 16.366 |
| triton | float32 | 128 | 8192 | 5.980 | 19.558 | 25.451 |
| torch | float32 | 128 | 8192 | 25.560 | 38.994 | 65.210 |
| triton | float32 | 128 | 16384 | 23.313 | 76.732 | 99.668 |
| triton | float32 | 128 | 32768 | 93.283 | 304.343 | 391.926 |
| triton | float32 | 128 | 65536 | 385.503 | 1196.686 | 1600.045 |


</div>

-----


---

## 2.1 Single-Node Distributed Communication in PyTorch

maybe useful: https://zhuanlan.zhihu.com/p/1927774560222175926

### Problem (distributed_communication_single_node): 5 points

Write a script to benchmark the runtime of the all-reduce operation in the single-node multi-process setup. The example code above may provide a reasonable starting point. Experiment with varying the following settings:

  * Backend + device type: Gloo + CPU, NCCL + GPU.
  * all-reduce data size: float32 data tensors ranging over 1MB, 10MB, 100MB, 1GB.
  * Number of processes: 2, 4, or 6 processes.
  * Resource requirements: Up to 6 GPUs. Each benchmarking run should take less than 5 minutes.
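
The settings above form a small grid. As a sketch of how the grid might be enumerated (assuming decimal megabytes and 4-byte float32 elements; `numel_for_size` and `algorithm_bandwidth` are hypothetical helpers, not part of any library):

```python
BYTES_PER_FLOAT32 = 4


def numel_for_size(size_bytes):
    """Number of float32 elements needed for a tensor of the given size."""
    return size_bytes // BYTES_PER_FLOAT32


def algorithm_bandwidth(size_bytes, seconds):
    """Bytes of per-rank data divided by the time taken to finish the all-reduce."""
    return size_bytes / seconds


sizes = {"1MB": 10**6, "10MB": 10**7, "100MB": 10**8, "1GB": 10**9}
world_sizes = [2, 4, 6]
backends = [("gloo", "cpu"), ("nccl", "cuda")]

# One entry per (backend, device, world size, size label, element count).
grid = [
    (backend, device, world_size, label, numel_for_size(nbytes))
    for backend, device in backends
    for world_size in world_sizes
    for label, nbytes in sizes.items()
]
```

Each grid entry corresponds to one timed all-reduce configuration, so the full sweep has 24 runs.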

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

**Deliverable**: Plot(s) and/or table(s) comparing the various settings, with 2-3 sentences of commentary about your results and thoughts about how the various factors interact.

## 2.2 A Naïve Implementation of Distributed Data Parallel Training

### Problem (naive_ddp): 5 points

**Deliverable**: Write a script to naively perform distributed data parallel training by all-reducing individual parameter gradients after the backward pass. To verify the correctness of your DDP implementation, use it to train a small toy model on randomly-generated data and verify that its weights match the results from single-process training.
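
The reason the weights should match single-process training is that, for a mean-reduced loss and equally sized shards, the average of the per-rank gradients equals the full-batch gradient. The following toy simulation illustrates this without `torch.distributed`; `all_reduce_mean` is a hypothetical stand-in for an all-reduce followed by division by the world size.

```python
import random


def mse_grad(w, xs, ys):
    """Gradient of mean((w . x - y)^2) with respect to w."""
    n = len(xs)
    grad = [0.0] * len(w)
    for x, y in zip(xs, ys):
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for k, xk in enumerate(x):
            grad[k] += 2.0 * err * xk / n
    return grad


def all_reduce_mean(per_rank_grads):
    """Simulated all-reduce: average each gradient entry across ranks."""
    world_size = len(per_rank_grads)
    return [sum(column) / world_size for column in zip(*per_rank_grads)]


random.seed(0)
w = [0.5, -1.0, 2.0]
xs = [[random.gauss(0, 1) for _ in range(3)] for _ in range(8)]
ys = [random.gauss(0, 1) for _ in range(8)]

world_size = 2
shard = len(xs) // world_size
per_rank = [
    mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
    for r in range(world_size)
]
ddp_grad = all_reduce_mean(per_rank)
full_grad = mse_grad(w, xs, ys)
```

If the shards had different sizes, a plain average would weight examples unevenly and the two gradients would no longer agree.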

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

### Problem (naive_ddp_benchmarking): 3 points

In this naïve DDP implementation, parameters are individually all-reduced across ranks after each backward pass. To better understand the overhead of data parallel training, create a script to benchmark your previously-implemented language model when trained with this naïve implementation of DDP. Measure the total time per training step and the proportion of time spent on communicating gradients. Collect measurements in the single-node setting (1 node $\times$ 2 GPUs) for the XL model size described in $\S 1.1.2$.


**Deliverable**: A description of your benchmarking setup, along with the measured time per training iteration and time spent communicating gradients for each setting.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

## 2.3 Improving Upon the Minimal DDP Implementation

### Problem (minimal_ddp_flat_benchmarking): 2 points

Modify your minimal DDP implementation to communicate a tensor with flattened gradients from all parameters. Compare its performance with the minimal DDP implementation that issues an all-reduce for each parameter tensor under the previously-used conditions (1 node $\times$ 2 GPUs, XL model size as described in $\S 1.1.2$).

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

**Deliverable**: The measured time per training iteration and time spent communicating gradients under distributed data parallel training with a single batched all-reduce call. 1-2 sentences comparing the results when batching vs. individually communicating gradients.

### Problem (ddp_overlap_individual_parameters): 5 points

Implement a Python class to handle distributed data parallel training. The class should wrap an arbitrary PyTorch `nn.Module` and take care of broadcasting the weights before training (so all ranks have the same initial parameters) and issuing communication calls for gradient averaging. We recommend the following public interface:

```python
def __init__(self, module: torch.nn.Module):
```

Given an instantiated PyTorch `nn.Module` to be parallelized, construct a DDP container that will handle gradient synchronization across ranks.

```python
def forward(self, *inputs, **kwargs):
```

Calls the wrapped module’s `forward()` method with the provided positional and keyword arguments.

```python
def finish_gradient_synchronization(self):
```

When called, wait for asynchronous communication calls to be queued on GPU.

To use this class to perform distributed training, we’ll pass it a module to wrap, and then add a call to `finish_gradient_synchronization()` before we run `optimizer.step()` to ensure that the optimizer step, an operation that depends on the gradients, may be queued:

```python
model = ToyModel().to(device)
ddp_model = DDP(model)
for _ in range(train_steps):
    x, y = get_batch()
    logits = ddp_model(x)
    loss = loss_fn(logits, y)
    loss.backward()
    ddp_model.finish_gradient_synchronization()
    optimizer.step()
```

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

**Deliverable**: Implement a container class to handle distributed data parallel training. This class should overlap gradient communication and the computation of the backward pass. To test your DDP class, first implement the adapters `[adapters.get_ddp_individual_parameters]` and `[adapters.ddp_individual_parameters_on_after_backward]` (the latter is optional; depending on your implementation, you may not need it). Then, to execute the tests, run `uv run pytest tests/test_ddp_individual_parameters.py`. We recommend running the tests multiple times (e.g., 5) to ensure that they pass reliably.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

### Problem (ddp_overlap_individual_parameters_benchmarking): 1 point

(a) Benchmark the performance of your DDP implementation when overlapping backward pass computation with communication of individual parameter gradients. Compare its performance with our previously-studied settings (the minimal DDP implementation that either issues an all-reduce for each parameter tensor, or a single all-reduce on the concatenation of all parameter tensors) with the same setup: 1 node, 2 GPUs, and the XL model size described in $\S 1.1.2$.

**Deliverable**: The measured time per training iteration when overlapping the backward pass with communication of individual parameter gradients, with 1-2 sentences comparing the results.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

(b) Instrument your benchmarking code (using the 1 node, 2 GPUs, XL model size setup) with the Nsight profiler, comparing between the initial DDP implementation and this DDP implementation that overlaps backward computation and communication. Visually compare the two traces, and provide a profiler screenshot demonstrating that one implementation overlaps compute with communication while the other doesn’t.

**Deliverable**: 2 screenshots (one from the initial DDP implementation, and another from this DDP implementation that overlaps compute with communication) that visually show that communication is or isn’t overlapped with the backward pass.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

### Problem (ddp_overlap_bucketed): 8 points

Implement a Python class to handle distributed data parallel training, using gradient bucketing to improve communication efficiency. The class should wrap an arbitrary input PyTorch `nn.Module` and take care of broadcasting the weights before training (so all ranks have the same initial parameters) and issuing bucketed communication calls for gradient averaging. We recommend the following interface:

```python
def __init__(self, module: torch.nn.Module, bucket_size_mb: float):
```

Given an instantiated PyTorch `nn.Module` to be parallelized, construct a DDP container that will handle gradient synchronization across ranks. Gradient synchronization should be bucketed, with each bucket holding at most `bucket_size_mb` of parameters.

```python
def forward(self, *inputs, **kwargs):
```

Calls the wrapped module’s `forward()` method with the provided positional and keyword arguments.

```python
def finish_gradient_synchronization(self):
```

When called, wait for asynchronous communication calls to be queued on GPU.

Beyond the addition of a `bucket_size_mb` initialization parameter, this public interface matches the interface of our previous DDP implementation that individually communicated each parameter. We suggest allocating parameters to buckets using the reverse order of `model.parameters()`, since the gradients will become ready in approximately that order during the backward pass.
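
One way to allocate parameters to buckets in reverse order is a single greedy pass. The sketch below works on element counts only, so it stays independent of PyTorch; the function name `assign_buckets`, the choice of 1 MB = 1024 × 1024 bytes, and the rule that an oversized parameter gets a bucket of its own are all assumptions for illustration.

```python
def assign_buckets(param_numels, bucket_size_mb, bytes_per_element=4):
    """Group parameter indices into buckets, walking parameters in reverse order.

    A bucket is closed as soon as adding the next parameter would exceed the
    limit. A single parameter larger than the limit gets its own bucket.
    """
    limit = bucket_size_mb * 1024 * 1024
    buckets, current, current_bytes = [], [], 0
    for idx in reversed(range(len(param_numels))):
        nbytes = param_numels[idx] * bytes_per_element
        if current and current_bytes + nbytes > limit:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets


# float32 parameters of 1 MiB, 0.5 MiB, 0.5 MiB, and 2 MiB, in model order.
numels = [262144, 131072, 131072, 524288]
small_buckets = assign_buckets(numels, bucket_size_mb=1)
one_bucket = assign_buckets(numels, bucket_size_mb=100)
```

A bucket's all-reduce can be launched once every parameter in it has its gradient ready, so keeping bucket membership in reverse parameter order lets the earliest buckets fill first during the backward pass.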

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

**Deliverable**: Implement a container class to handle distributed data parallel training. This class should overlap gradient communication and the computation of the backward pass. Gradient communication should be bucketed, to reduce the total number of communication calls. To test your implementation, complete `[adapters.get_ddp_bucketed]`, `[adapters.ddp_bucketed_on_after_backward]`, and `[adapters.ddp_bucketed_on_train_batch_start]` (the latter two are optional; depending on your implementation, you may not need them). Then, to execute the tests, run `pytest tests/test_ddp.py`. We recommend running the tests multiple times (e.g., 5) to ensure that they pass reliably.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

### Problem (ddp_bucketed_benchmarking): 3 points

(a) Benchmark your bucketed DDP implementation using the same config as the previous experiments (1 node, 2 GPUs, XL model size), varying the maximum bucket size (1, 10, 100, 1000 MB). Compare your results to the previous experiments without bucketing—do the results align with your expectations? If they don’t align, why not? You may have to use the PyTorch profiler as necessary to better understand how communication calls are ordered and/or executed. What changes in the experimental setup would you expect to yield results that are aligned with your expectations?

**Deliverable**: Measured time per training iteration for various bucket sizes. 3-4 sentence commentary about the results, your expectations, and potential reasons for any mismatch.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

(b) Assume that the time it takes to compute the gradients for a bucket is identical to the time it takes to communicate the gradient buckets. Write an equation that models the communication overhead of DDP (i.e., the amount of additional time spent after the backward pass) as a function of the total size (bytes) of the model parameters ($s$), the all-reduce algorithm bandwidth ($w$, computed as the size of each rank’s data divided by the time it takes to finish the all-reduce), the overhead (seconds) associated with each communication call ($o$), and the number of buckets ($n_b$). From this equation, write an equation for the optimal bucket size that minimizes DDP overhead.

**Deliverable**: Equation that models DDP overhead, and an equation for the optimal bucket size.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

## 2.4 4D Parallelism

### Problem (communication_accounting): 10 points

Consider a new model config, `XXL`, with $d_{model}=16384$, $d_{ff}=53248$, and `num_blocks` $=126$. Because the vast majority of FLOPs in very large models are in the feedforward networks, we make some simplifying assumptions. First, we omit attention, input embeddings, and output linear layers. Then, we assume that each FFN is simply two linear layers (ignoring the activation function), where the first has input size $d_{model}$ and output size $d_{ff}$, and the second has input size $d_{ff}$ and output size $d_{model}$. Your model consists of `num_blocks` blocks of these two linear layers. Don’t do any activation checkpointing, and keep your activations and gradient communications in `BF16`, while your accumulated gradients, master weights and optimizer state should be in `FP32`.

(a) How much memory would it take to store the master model weights, accumulated gradients and optimizer states in FP32 on a single device? How much memory is saved for backward (these will be in BF16)? How many H100 80GB GPUs worth of memory is this?

**Deliverable**: Your calculations and a one-sentence response.
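
The parameter-related part of this calculation is plain arithmetic and can be sketched as below. It assumes bias-free linear layers, an AdamW-style optimizer with two FP32 moments per parameter, and 80GB read as 80 GiB; none of these is stated in the problem. The activations saved for backward are left out because they depend on the batch size and sequence length, which the problem does not fix.

```python
d_model, d_ff, num_blocks = 16384, 53248, 126

# Two linear layers per block: d_model x d_ff and d_ff x d_model (no biases assumed).
params_per_block = 2 * d_model * d_ff
n_params = num_blocks * params_per_block

BYTES_FP32 = 4
master_weights = n_params * BYTES_FP32
accumulated_grads = n_params * BYTES_FP32
optimizer_state = 2 * n_params * BYTES_FP32  # assumption: AdamW first and second moments
total_bytes = master_weights + accumulated_grads + optimizer_state

GIB = 2 ** 30
total_gib = total_bytes / GIB
h100_equivalents = total_bytes / (80 * GIB)  # assumption: 80GB interpreted as 80 GiB

print(f"parameters: {n_params:,}")
print(f"weights + grads + optimizer state: {total_gib:,.0f} GiB")
print(f"H100 80GB equivalents: {h100_equivalents:.2f}")
```

Under these assumptions the fixed state costs 16 bytes per parameter, so it scales linearly with `num_blocks` and is independent of the batch size.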

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

(b) Now assume your master weights, optimizer state, gradients and half of your activations (in practice every second layer) are sharded across $N_{FSDP}$ devices. Write an expression for how much memory this would take per device. What value does $N_{FSDP}$ need to be for the total memory cost to be less than 1 v5p TPU (95GB per device)?

**Deliverable**: Your calculations and a one-sentence response.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

(c) Consider only the forward pass. Use the communication bandwidth of $W_{ici} = 2 \cdot 9 \cdot 10^{10}$ and FLOP/s of $C = 4.6 \cdot 10^{14}$ for TPU v5p as given in the TPU Scaling Book. Following the notation of the Scaling Book, use $M_X = 2, M_Y = 1$ (a 3D mesh), with $X = 16$ being your FSDP dimension, and $Y = 4$ being your TP dimension. At what per-device batch size is this model compute bound? What is the overall batch size in this setting?

**Deliverable**: Your calculations and a one-sentence response.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

(d) In practice, we want the overall batch size to be as small as possible, and we also always want to use our compute effectively (in other words, we never want to be communication bound). What other tricks can we employ to reduce the batch size of our model but retain high throughput?

**Deliverable**: A one-paragraph response. Back up your claims with references and/or equations.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

---

## 3 Optimizer State Sharding

### Problem (optimizer_state_sharding): 15 points

Implement a Python class to handle optimizer state sharding. The class should wrap an arbitrary input PyTorch `optim.Optimizer` and take care of synchronizing updated parameters after each optimizer step. We recommend the following public interface:

```python
def __init__(self, params, optimizer_cls: Type[Optimizer], **kwargs: Any):
```

Initializes the sharded state optimizer. `params` is a collection of parameters to be optimized (or parameter groups, in case the user wants to use different hyperparameters, such as learning rates, for different parts of the model); these parameters will be sharded across all the ranks. The `optimizer_cls` parameter specifies the type of optimizer to be wrapped (e.g., `optim.AdamW`). Finally, any remaining keyword arguments are forwarded to the constructor of the `optimizer_cls`. Make sure to call the `torch.optim.Optimizer` super-class constructor in this method.

```python
def step(self, closure, **kwargs):
```

Calls the wrapped optimizer’s `step()` method with the provided closure and keyword arguments. After updating the parameters, synchronize with the other ranks.

```python
def add_param_group(self, param_group: dict[str, Any]):
```

This method should add a parameter group to the sharded optimizer. This is called during construction of the sharded optimizer by the super-class constructor and may also be called during training (e.g., for gradually unfreezing layers in a model). As a result, this method should handle assigning the model’s parameters among the ranks.
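
The assignment of parameters to ranks can be prototyped without any distributed setup. The sketch below shows one possible policy, greedy balancing by element count; it is not prescribed by the assignment, and the helper name `assign_params_to_ranks` is hypothetical. It takes plain integers (for example, the `numel()` of each parameter) so that it runs on its own. Whatever policy is chosen, it should be deterministic, since every rank has to compute the same assignment independently.

```python
def assign_params_to_ranks(numels, world_size):
    """Greedily assign each parameter to the currently least-loaded rank.

    numels: element count of each parameter, in parameter order.
    Returns (owners, loads): owners[i] is the rank that owns parameter i,
    loads[r] is the total element count assigned to rank r.
    """
    loads = [0] * world_size
    owners = [0] * len(numels)
    # Place the largest parameters first; ties keep their original order
    # because sorted() is stable, so the result is deterministic.
    for idx in sorted(range(len(numels)), key=lambda i: -numels[i]):
        rank = min(range(world_size), key=lambda r: loads[r])
        owners[idx] = rank
        loads[rank] += numels[idx]
    return owners, loads


owners, loads = assign_params_to_ranks([100, 40, 60, 10], world_size=2)
print(owners, loads)
```

In a real `add_param_group`, each rank would pass only the parameters it owns to the wrapped optimizer and remember the owner of every parameter so that updated values can be synchronized after `step()`.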

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

**Deliverable**: Implement a container class to handle optimizer state sharding. To test your sharded optimizer, first implement the adapter `[adapters.get_sharded_optimizer]`. Then, to execute the tests, run `uv run pytest tests/test_sharded_optimizer.py`. We recommend running the tests multiple times (e.g., 5) to ensure that they pass reliably.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

### Problem (optimizer_state_sharding_accounting): 5 points

(a) Create a script to profile the peak memory usage when training language models with and without optimizer state sharding. Using the standard configuration (1 node, 2 GPUs, XL model size), report the peak memory usage after model initialization, directly before the optimizer step, and directly after the optimizer step. Do the results align with your expectations? Break down the memory usage in each setting (e.g., how much memory for parameters, how much for optimizer states, etc.).

**Deliverable**: 2-3 sentence response with peak memory usage results and a breakdown of how the memory is divided between different model and optimizer components.
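
Before measuring, it can help to write down what the breakdown is expected to look like so the profiled numbers have something to be compared against. The sketch below is a back-of-the-envelope estimate under assumptions that are not stated in the problem: FP32 parameters and gradients, an AdamW-style optimizer with two FP32 moments per parameter, perfectly even sharding, and no activation or allocator overhead. The function name `expected_state_bytes` is hypothetical, and the parameter count in the example is a round illustrative number, not the actual XL model size.

```python
def expected_state_bytes(num_params, world_size, shard_optimizer, bytes_per_elem=4):
    """Estimate per-rank bytes for parameters, gradients and optimizer state."""
    params = bytes_per_elem * num_params
    grads = bytes_per_elem * num_params
    # Assumption: two moments per parameter (AdamW-style).
    optimizer = 2 * bytes_per_elem * num_params
    if shard_optimizer:
        # Assumption: optimizer state is split evenly across ranks.
        optimizer //= world_size
    return {
        "params": params,
        "grads": grads,
        "optimizer": optimizer,
        "total": params + grads + optimizer,
    }


# Illustrative only: 2e9 parameters on 2 ranks.
unsharded = expected_state_bytes(2_000_000_000, world_size=2, shard_optimizer=False)
sharded = expected_state_bytes(2_000_000_000, world_size=2, shard_optimizer=True)
print(unsharded["total"] / 1e9, sharded["total"] / 1e9)
```

For the measured side, PyTorch exposes `torch.cuda.reset_peak_memory_stats()` and `torch.cuda.max_memory_allocated()`, which can be called around each phase (after initialization, before the optimizer step, after the optimizer step) to obtain peak values to compare with this estimate.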

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

(b) How does our implementation of optimizer state sharding affect training speed? Measure the time taken per iteration with and without optimizer state sharding for the standard configuration (1 node, 2 GPUs, XL model size).

**Deliverable**: 2-3 sentence response with your timings.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

(c) How does our approach to optimizer state sharding differ from ZeRO stage 1 (described as ZeRO-DP $P_{os}$ in Rajbhandari et al., 2020)?

**Deliverable**: 2-3 sentence summary of any differences, especially those related to memory and communication volume.

-----

<span style="background-color: #29B6F6; color: black">
a
</span>

-----

---