---
---
---
<a name="MATH"></a>
## E) Memory Efficient Backprop [Difficulty: Medium to Hard] [Max points: 10]

In LLMs, the last layer is a projection matrix to calculate the probabilities of the next token, ie $\sigma(XW)$. However, if the vocabulary size is very large, say 128K, then the materialization of the logits causes VRAM spikes.

For example, if the `bsz = 4, qlen = 4096, hd = 4096, vocab = 128K`, then the memory usage for the logits in bfloat16 would be 4GB. In the worst case, we might even need to upcast logits to float32, so 8GB is needed.

In Unsloth, we utilize [Apple's Cut Cross Entropy Loss](https://machinelearning.apple.com/research/cut-your-losses) to reduce VRAM usage, by allowing a Triton kernel to create the logits on the fly to calculate the cross entropy loss. But this does not generalize well to other functions.

Our goal is to generalize this ultimately, but directly creating logits on the fly will be hard. Instead, let's take a slightly less complex approach. Let's first review some stuff. We first notice that during the normal case after forming the intermediate logits for 2 batches, we then do a gather function to aggregate the intermediate results into a single column:
$$
\begin{align}
\begin{bmatrix} x_1 \\ x_2 \end{bmatrix} \times W &= \begin{bmatrix} x_1 W \\ x_2 W \end{bmatrix} \\
f \bigg( \begin{bmatrix} x_1 W \\ x_2 W \end{bmatrix} \bigg) &= \begin{pmatrix} y_1 \\ y_2 \end{pmatrix}
\end{align}
$$

So, if we can somehow skip the materialization of the intermediate logits, and just output the output of `f`, we can save a lot of VRAM!

Notice during backpropagation we can use the chain rule:
$$
\begin{align}
\frac{dL}{dX} &= \frac{dL}{dy} \frac{dy}{dX} ; \frac{dL}{dW} = \frac{dL}{dy} \frac{dy}{dW} \\
\frac{dL}{dy} &= \text{Downstream from backprop} \\
\frac{dy}{dX} &= W^T \\
\frac{dy}{dW} &= X^T \\
\frac{dL}{dX} &= \frac{dL}{dy} W^T \\
\frac{dL}{dW} &= X^T \frac{dL}{dy} \\
\end{align}
$$

If we simply compute the intermediate tensors on the fly via batches, say we do batch 1, then batch 2, we can reduce VRAM usage from 4GB to 2GB!

$$
\begin{align}
\frac{dL}{dX} &= \begin{bmatrix} \frac{dL_1}{dy_1} W^T \\ \frac{dL_2}{dy_2} W^T \end{bmatrix} \\
\frac{dL}{dW} &= \bigg( X_1^T \frac{dL_1}{dy_1} + X_2^T  \frac{dL_2}{dy_2} \bigg)
\end{align}
$$

1. Your goal is to write a `torch.autograd.Function` with a `forward` and `backward` pass showcasing this memory efficient implementation.

2. You must NOT hard code the derivatives - move the transformation function from the logits / intermeditate tensors to a smaller tensor as a separate function which can allow `autograd` to pass through it.

3. As a hint, look at `torch.checkpoint` at https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py. Also, don't forget about the upstream gradients! We need to multiply them to the current gradients!

4. Make the Cross Entropy Loss work. You must show other functions working as well.

## Marking Criteria for E) Max points = 10
```python
if attemped_E:
    E_score = 0
    if VRAM_50_percent_reduction: E_score += 2
    if remove_float32_upcast: E_score = 0
    if show_ce_loss_works: E_score += 1
    if show_other_functions_work: E_score += 1
    if hardcoded_gradients: E_score = 0
    if allows_dynamic_chunk_sizes: E_score += 1
    if llama_1B_training_loss_matches: E_score += 1
    else: E_score = 0
    if GRPO_memory_efficient_linear_works: E_score += 4
    final_score += E_score
else:
    final_score += 0
```

In [2]:
# Helpful functions used through the entire notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os

major_version, minor_version = torch.cuda.get_device_capability()
HAS_BFLOAT16 = major_version >= 8
from inspect import currentframe as _C, getframeinfo

_F = lambda c: getframeinfo(c).lineno  # Gets line number
WARN = lambda x: print(f"\033[31m{x}\033[0m")  # Red colored warnings


# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    names = [var_name for var_name, var_val in callers_local_vars if var_val is var]
    return names[0] if len(names) != 0 else ""


def assert_same(x, y, line, dtype):
    assert x.dtype == dtype
    try:
        torch.testing.assert_close(x, y, check_stride=True)
    except Exception as error:
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {NAME(x)}, {NAME(y)}\n{str(error)}"
        )


os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [3]:
def transformation_function(batch, linear, labels):
    x = linear(batch).float()  # Up projection to large space
    from torch.nn import CrossEntropyLoss

    down_projection_function = CrossEntropyLoss(reduction="mean")
    # Down projection to small space
    loss = down_projection_function(x.view(-1, x.shape[-1]), labels.view(-1))
    return loss



device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

assert device.type == "cuda", "vram measurements won't work; check environment setup"

In [4]:
class NaiveLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, linear, labels, forward_function):
        """
        Naive forward pass that does NOT use chunking.
        This computes the full projection XW at once.

        Args:
          - X: Input tensor (bs x seq_len x hidden) or (seq_len x hidden)
          - linear: Linear projection function
          - labels: Target labels (bs x seq_len) or (seq_len)
          - forward_function: Function computing f(XW)

        Returns:
          - Computed loss (scalar) or transformed tensor
        """

        with torch.no_grad():
            # Compute f(XW), e.g., cross-entropy loss
            loss = forward_function(X, linear, labels)

        # Save context for backward pass
        # ctx.save_for_backward(X, labels)
        ctx.save_for_backward(X, labels)
        ctx.linear = linear
        ctx.forward_function = forward_function

        return loss

    @staticmethod
    def backward(ctx, dY):
        """
        Naive backward pass that computes gradients from the saved full logits.

        Args:
          - dY: Gradient from the next layer (scalar or tensor)

        Returns:
          - Gradients w.r.t. X, None, None, None
        """

        X, labels = ctx.saved_tensors
        linear = ctx.linear
        forward_function = ctx.forward_function

        # Re-enable autograd for X and logits
        with torch.enable_grad():
            X.requires_grad = True
            linear.weight.requires_grad = True

            # Compute f(XW) again with autograd enabled to track gradients
            loss = forward_function(X, linear, labels)

        # Compute gradients with respect to X and W
        loss.backward(dY)

        return X.grad, None, None, None  # Grad w.r.t. X, W, None, None


This implementation of a memory efficient linear requires the reduction method to be specified
ie. `ctx.reduction` must be set to either mean or sum depending on the reduction method used in the loss function

In [5]:
class MemoryEfficientLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, linear, labels, forward_function):
        """
        Forward pass for memory-efficient linear projection + transformation.

        - X: Input tensor, could be (bs x seq_len x hidden) or (seq_len x hidden)
        - linear: Linear layer (performs projection to large space)
        - labels: Target labels for cross-entropy loss ((bs) x seq_len x vocab)
        - forward_function: Function that computes f(XW)
        - chunk_size: How many rows to process at once

        Returns:
            The total loss computed in chunks
        """
        outputs = []
        # EDIT THIS FUNCTION
        # we want to compute f(XW) without ever triggering a compute of XW
        # hence we do this in parts.

        # we get >50 vram reduction at 4 chunks, can be changed.
        num_chunks = 4
        seq_dim = len(X.shape) - 2
        seq_len = X.shape[seq_dim]
        num_chunks = min(seq_len, num_chunks)
        chunk_size = (seq_len + num_chunks - 1) // num_chunks

        # we unfortunately also need to be aware of the reduction method.
        # in the case of CE loss above, it's mean
        ctx.reduction = "mean"

        # save context for backward step
        ctx.linear = linear
        ctx.forward_function = forward_function
        ctx.num_chunks = num_chunks
        ctx.chunk_size = chunk_size
        ctx.seq_dim = seq_dim
        ctx.seq_len = seq_len
        ctx.save_for_backward(X, labels)  # save chunked version instead?

        with torch.no_grad():
            for i in range(num_chunks):
                start = i * chunk_size
                end = min((i + 1) * chunk_size, seq_len)

                X_chunk = X.narrow(seq_dim, start, end - start).contiguous()
                labels_chunk = labels.narrow(seq_dim, start, end - start).contiguous()

                output_chunk = forward_function(X_chunk, linear, labels_chunk)
                outputs.append(output_chunk)

        # based on the type of output and reduction method involved, aggregate it:
        # - scalar -> reduce to scalar
        # - vector/tensors -> combine into a single vector/tensor by stacking rows

        if isinstance(outputs[0], torch.Tensor) and outputs[0].dim() > 0:
            # stack outputs
            result = torch.cat(outputs, dim=seq_dim)  # validate choice of dimension
        else:
            # handle scalar case based on reduction method:
            if ctx.reduction == "sum":
                result = sum(outputs)
            if ctx.reduction == "mean":
                # handle weighting of outputs accordingly, based on chunk sizes
                # some numel weighing?
                # handle naively for now, assuming equal chunk sizing.
                result = sum(outputs) / len(outputs)
        return result

    @staticmethod
    def backward(ctx, dY):
        # restore context:
        X, labels = ctx.saved_tensors
        linear = ctx.linear
        forward_function = ctx.forward_function
        num_chunks = ctx.num_chunks
        chunk_size = ctx.chunk_size
        seq_dim = ctx.seq_dim
        seq_len = ctx.seq_len
        reduction = ctx.reduction

        dX = torch.zeros_like(X)  # Initialize gradient storage
        # dW = torch.zeros_like(linear.weight)  # Store accumulated W gradient; not necessary

        torch.cuda.reset_peak_memory_stats(device)
        print(
            f"Peak VRAM, pre-compute: {torch.cuda.max_memory_allocated(device) / 1e6}"
        )

        for i in range(num_chunks):
            start = i * chunk_size
            end = min((i + 1) * chunk_size, seq_len)

            X_chunk = X.narrow(seq_dim, start, end - start).contiguous()
            labels_chunk = labels.narrow(seq_dim, start, end - start).contiguous()

            # Recompute f(X_chunk W) to get the gradient
            X_chunk.requires_grad = True

            with torch.enable_grad():
                output_chunk = forward_function(X_chunk, linear, labels_chunk)

                # Compute gradients for chunk
                if output_chunk.dim() == 0:  # scalar
                    if reduction == "mean":
                        output_chunk.backward(
                            dY / num_chunks
                        )  # rescale gradient based on reduction method
                    if reduction == "sum":
                        output_chunk.backward(dY)
                else:
                    # For non-scalar outputs (not our current case)
                    # TODO: fix the chunking of dY; using the x_chunk shape is a bad reference
                    # instead use output_chunk shape to guide this step
                    output_chunk.backward(dY.narrow(seq_dim, start, end - start))

                # Accumulate gradients
                dX.narrow(seq_dim, start, end - start).copy_(
                    X_chunk.grad * 2
                )  # rescale gradients for dX

                # do a cleanup
                # since dW is additive, it shouldn't be contributing to peak ram.
                # instead focus is on cleanups for dX
                X_chunk.grad.detach_()
                del X_chunk, labels_chunk, output_chunk
                torch.cuda.empty_cache()


        return dX, None, None, None



In [6]:
def test_comparison(
    batch_size=4,
    seq_length=512,
    hidden_dim=1024,
    vocab_size=32_000,
    loss_function=transformation_function,
    seed=42,
):
    """
    Compare naive and memory-efficient implementations of linear projection.

    Args:
        batch_size: Batch size for test data
        seq_length: Sequence length for test data
        hidden_dim: Hidden dimension size
        vocab_size: Vocabulary size for output dimension
        loss_function: Function to transform linear output (default: transformation_function)
        seed: Random seed for reproducibility

    Returns:
        dict: Dictionary containing comparison results
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if device.type != "cuda":
        print("Warning: VRAM tracking only works on CUDA. Running on CPU.")
    else:
        torch.cuda.empty_cache()

    torch.manual_seed(seed)

    # Create test data
    X = torch.randn(
        batch_size,
        seq_length,
        hidden_dim,
        device=device,
        dtype=torch.float32,
        requires_grad=True,
    )
    labels = torch.randint(0, vocab_size, (batch_size, seq_length), device=device)

    # Create identical linear layers
    linear_naive = nn.Linear(hidden_dim, vocab_size, bias=False).to(device)
    linear_efficient = nn.Linear(hidden_dim, vocab_size, bias=False).to(device)
    linear_efficient.weight.data.copy_(linear_naive.weight.data)

    # Helper function to run a test with given implementation
    def run_test(implementation, X_input, linear_layer, base_memory=0):
        if device.type == "cuda":
            torch.cuda.reset_peak_memory_stats(device)

        # Forward and backward pass
        loss = implementation.apply(X_input, linear_layer, labels, loss_function)
        loss.backward()

        # Collect results
        results = {
            "loss": loss.item(),
            "X_grad": X_input.grad.detach().clone(),
            "W_grad": linear_layer.weight.grad.detach().clone(),
        }

        if device.type == "cuda":
            peak_memory = torch.cuda.max_memory_allocated(device)
            results["peak_memory"] = peak_memory
            results["memory_used"] = peak_memory - base_memory

        return results

    # Run naive implementation
    torch.cuda.reset_peak_memory_stats(device)
    naive_base_memory = torch.cuda.max_memory_allocated(device)

    naive_results = run_test(NaiveLinear, X, linear_naive, naive_base_memory)

    # Zero gradients and prepare for efficient implementation
    X.grad.zero_()
    linear_naive.weight.grad.zero_()
    X_eff = X.detach().clone().requires_grad_(True)

    # Run memory-efficient implementation
    torch.cuda.reset_peak_memory_stats(device)
    efficient_base_memory = torch.cuda.max_memory_allocated(device)

    efficient_results = run_test(
        MemoryEfficientLinear, X_eff, linear_efficient, efficient_base_memory
    )

    # Compare results
    comparisons = {
        "loss_match": torch.allclose(
            torch.tensor(naive_results["loss"]),
            torch.tensor(efficient_results["loss"]),
            rtol=1e-3,
            atol=1e-5,
        ),
        "X_grad_match": torch.allclose(
            naive_results["X_grad"], efficient_results["X_grad"], rtol=1e-3, atol=1e-5
        ),
        "W_grad_match": torch.allclose(
            naive_results["W_grad"], efficient_results["W_grad"], rtol=1e-3, atol=1e-5
        ),
    }

    # Print results
    print(f"Naive Loss: {naive_results['loss']:.6f}")
    print(f"Memory-Efficient Loss: {efficient_results['loss']:.6f}")
    print(f"Loss match: {comparisons['loss_match']}")
    print(f"Gradient w.r.t X match: {comparisons['X_grad_match']}")
    print(f"Gradient w.r.t W match: {comparisons['W_grad_match']}")

    if device.type == "cuda":
        print(f"Naive Base VRAM Usage (MB): {naive_base_memory / 1e6:.2f}")
        print(f"Naive Peak VRAM Usage (MB): {naive_results['peak_memory'] / 1e6:.2f}")
        print(f"Efficient Base VRAM Usage (MB): {efficient_base_memory / 1e6:.2f}")
        print(
            f"Memory-Efficient Peak VRAM Usage (MB): {efficient_results['peak_memory'] / 1e6:.2f}"
        )

        memory_saved_pct = (
            1 - (efficient_results["memory_used"] / naive_results["memory_used"])
        ) * 100
        print(f"Memory saved: {memory_saved_pct:.1f}%")
    else:
        print("VRAM usage tracking not available on CPU.")

    return {
        "naive": naive_results,
        "efficient": efficient_results,
        "comparisons": comparisons,
    }

In [7]:
test_comparison(
    batch_size=4,
    seq_length=512,
    hidden_dim=1024,
    vocab_size=32_000,
)


  return F.linear(input, self.weight, self.bias)


Peak VRAM, pre-compute: 585.38496
Naive Loss: 10.554384
Memory-Efficient Loss: 10.554384
Loss match: True
Gradient w.r.t X match: True
Gradient w.r.t W match: True
Naive Base VRAM Usage (MB): 272.65
Naive Peak VRAM Usage (MB): 1076.12
Efficient Base VRAM Usage (MB): 577.00
Memory-Efficient Peak VRAM Usage (MB): 918.31
Memory saved: 57.5%


{'naive': {'loss': 10.554384231567383,
  'X_grad': tensor([[[ 1.3807e-05, -1.6137e-05,  8.1403e-06,  ..., -1.3318e-05,
            -1.3838e-05,  1.6917e-05],
           [ 7.3444e-06, -8.2491e-06, -6.6505e-06,  ...,  9.7188e-06,
             4.3857e-06,  2.1376e-05],
           [ 2.0971e-05,  1.2372e-06,  2.5057e-05,  ..., -2.9860e-05,
             1.3536e-05,  2.0116e-05],
           ...,
           [-2.0055e-05,  9.8789e-06, -2.3446e-05,  ...,  2.1487e-05,
            -8.6920e-07,  1.6595e-05],
           [ 1.1875e-05, -2.1308e-05,  5.3893e-06,  ..., -1.1348e-05,
             2.1326e-05, -5.5247e-06],
           [-4.9283e-06,  3.3811e-06, -2.7616e-05,  ..., -5.9093e-06,
             2.8751e-05,  1.8444e-05]],
  
          [[-2.9359e-05, -2.4922e-05, -2.1881e-05,  ...,  2.0264e-05,
            -2.0535e-05, -7.9485e-06],
           [-1.0894e-05,  1.2523e-05, -2.0353e-05,  ...,  5.3034e-06,
             2.0384e-05, -2.4650e-06],
           [-2.0635e-05,  1.2297e-06,  1.5427e-05,  ..., -2

In [8]:
# alternative loss function
def mse_loss_function(batch, linear, labels):
    """
    Alternative loss function using MSE loss instead of cross-entropy.
    """
    x = linear(batch).float()
    from torch.nn import MSELoss

    # Create one-hot encoded targets for MSE loss
    one_hot_labels = torch.zeros(
        (labels.view(-1).size(0), linear.weight.size(0)), device=labels.device
    )
    one_hot_labels.scatter_(1, labels.view(-1, 1), 1)

    # Apply MSE loss
    loss_fn = MSELoss(reduction="mean")
    loss = loss_fn(x.view(-1, x.shape[-1]), one_hot_labels)
    return loss


test_comparison(
    batch_size=4,
    seq_length=512,
    hidden_dim=1024,
    vocab_size=32_000,
    loss_function=mse_loss_function,
)


Peak VRAM, pre-compute: 863.2576
Naive Loss: 0.332687
Memory-Efficient Loss: 0.332687
Loss match: True
Gradient w.r.t X match: True
Gradient w.r.t W match: True
Naive Base VRAM Usage (MB): 567.56
Naive Peak VRAM Usage (MB): 1878.28
Efficient Base VRAM Usage (MB): 854.87
Memory-Efficient Peak VRAM Usage (MB): 1261.72
Memory saved: 69.0%


{'naive': {'loss': 0.3326873779296875,
  'X_grad': tensor([[[ 9.1096e-08,  1.4354e-06, -1.5406e-07,  ...,  1.8944e-06,
             5.3283e-07,  1.6396e-06],
           [ 5.1661e-07, -1.4593e-06,  8.4060e-07,  ...,  8.0533e-07,
            -1.8818e-06, -6.5520e-07],
           [-3.1083e-07,  6.3797e-07,  1.2761e-07,  ..., -3.4181e-07,
             5.3196e-07, -1.5012e-06],
           ...,
           [ 4.4087e-07, -1.1070e-06, -4.0084e-07,  ...,  1.8360e-07,
             1.3163e-06, -1.0737e-06],
           [-2.0768e-09,  6.2180e-07, -3.3828e-07,  ..., -4.3182e-07,
            -2.8494e-07, -5.3675e-07],
           [-1.7479e-07,  1.5999e-07,  5.8126e-07,  ..., -8.1810e-07,
            -1.9800e-07,  4.7342e-07]],
  
          [[-3.0538e-07, -4.0315e-07, -1.8845e-07,  ..., -3.7491e-07,
            -6.4160e-07,  8.8729e-07],
           [-1.4186e-06,  8.8147e-08, -5.2613e-07,  ..., -5.2573e-07,
            -1.0772e-06, -3.0598e-07],
           [ 6.4900e-07,  1.3771e-07,  7.4652e-07,  ...,  2

In [9]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
    "expandable_segments:True," "roundup_power2_divisions:[32:256,64:128,256:64,>:32]"
)

max_seq_length = 1024
torch.set_default_dtype(torch.float16)
model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
dtype = torch.float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    attn_implementation="sdpa",
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Get LoRA and setup model
model = get_peft_model(model, lora_config)
with torch.no_grad():
    for name, param in model.named_parameters():
        if ".lora_A." in name or ".lora_B." in name:
            param.requires_grad_(True)
        else:
            param.requires_grad_(False)

# Currently GC will cause torch.compile to be disabled, so disable it
# model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# Get dataset
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files={"train": url}, split="train[:10%]")



In [10]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=tokenizer,
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        warmup_steps=1,
        max_steps=10,
        logging_steps=1,
        output_dir="outputs",
        seed=3407,
        max_seq_length=max_seq_length,
        fp16=model.get_input_embeddings().weight.dtype == torch.float16,
        bf16=model.get_input_embeddings().weight.dtype == torch.bfloat16,
        report_to="none",  # For W&B
        dataset_num_proc=4,
    ),
)

trainer.train()

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
1,1.5199
2,2.3928
3,2.5006
4,3.5288
5,2.1371
6,2.9752
7,2.2446
8,1.6234
9,2.2185
10,2.6773


TrainOutput(global_step=10, training_loss=2.3818094968795775, metrics={'train_runtime': 7.9697, 'train_samples_per_second': 2.51, 'train_steps_per_second': 1.255, 'total_flos': 10592155496448.0, 'train_loss': 2.3818094968795775})

In [13]:
type(model.lm_head)

torch.nn.modules.linear.Linear

In [14]:
# TODO: patch the head with memory efficient linear
