# CPU offloading example

Transformer Engine offers a CPU offloading capability that can lower GPU memory usage by offloading a portion of activations into host memory. Because offload operations are overlapped with ongoing computations, you can reclaim significant GPU memory with only a minimal performance impact. To achieve full overlap, you can enable offloading for a subset of your model’s layers.

This approach is particularly advantageous on systems with high-bandwidth NVLink connections between CPU and GPU. For instance, the GB200 Grace Blackwell Superchip links two GPUs, with up to 372 GB of GPU memory in total, to 480 GB of CPU memory via NVLink at up to 900 GB/s.
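To get a feel for what this bandwidth means, here is a rough back-of-the-envelope sketch. It uses the bfloat16 tensor shape `(4096, 128, 512)` from the benchmark later in this tutorial and treats the quoted 900 GB/s as an idealized, fully sustained rate (with 1 GB taken as 10^9 bytes), which real transfers will generally not reach. The helper name `transfer_time_ms` is made up for this illustration.

```python
import math


def transfer_time_ms(shape, bytes_per_element, bandwidth_gb_per_s):
    """Idealized time (ms) to move one tensor at a fixed bandwidth (1 GB = 1e9 bytes)."""
    num_bytes = math.prod(shape) * bytes_per_element
    return num_bytes / (bandwidth_gb_per_s * 1e9) * 1e3


shape = (4096, 128, 512)  # tensor shape used in the benchmark below
bf16_bytes = 2            # bfloat16 stores 2 bytes per element

size_mib = math.prod(shape) * bf16_bytes / 1024**2
print(f"One tensor of this shape: {size_mib:.0f} MiB")
print(f"Idealized transfer time at 900 GB/s: {transfer_time_ms(shape, bf16_bytes, 900):.3f} ms")
```

The estimate covers a single tensor of that shape only; a layer typically saves several activations, so the real volume per layer is larger.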


CPU offloading in Transformer Engine can be easily integrated with any transformer training setup, because it supports offloading the activations of all layers, not only those provided by Transformer Engine. For TE layers, it additionally supports offloading of FP8 activations.

Our tutorial covers two scenarios:

1. A basic offloading setup, and
2. A customized offload schedule illustrating a more complex use case: pipeline-parallel execution.

The tutorial was run on GB200 superchips.


## Basic offloading setup

Let's demonstrate the default CPU offloading functionality in Transformer Engine. To illustrate that CPU offloading works with any transformer layer (not just TE-specific ones), we'll create a custom layer implementation consisting of a TE `TransformerLayer` and a torch linear layer.

In [1]:
import transformer_engine.pytorch as te
import torch
import warnings
warnings.filterwarnings('ignore')

class CustomTransformerLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.torch_linear = torch.nn.Linear(512, 512).to(torch.bfloat16)
        self.te_transformer = te.TransformerLayer(512, 512, 4, params_dtype=torch.bfloat16)

    def forward(self, x):
        x = self.torch_linear(x)
        x = self.te_transformer(x)
        return x

Let's see how the CPU offloading API works:

In [2]:
# We want to offload activations of 4 out of 20 layers.
cpu_offload, sync = te.cpu_offload.get_cpu_offload_context(
    enabled=True, num_layers=4, model_layers=20, offload_activations=True)


def fwd_without_offload(x):
    with te.fp8_autocast():
        for layer in model:
            x = layer(x)
    return x

def fwd_with_offload(x):
    with te.fp8_autocast():
        # There are 2 things that need to be done in case of cpu offload:
        # - put every layer forward computation inside the cpu offload context,
        # - run synchronization function on every layer's output,
        for layer in model:
            with cpu_offload:
                x = layer(x)
            x = sync(x)
    return x

Let's compare memory usage and execution time between offloaded and non-offloaded versions:

In [3]:
import time 

x = torch.randn(((4096, 128, 512)), dtype=torch.bfloat16).cuda()
model = [CustomTransformerLayer().cuda() for _ in range(20)]


# warm-up
x = fwd_without_offload(x)
x.sum().backward()

# without offload
x = torch.randn(((4096, 128, 512)), dtype=torch.bfloat16).cuda()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
t0 = time.perf_counter()
x = fwd_without_offload(x)
x.sum().backward()
print("Memory usage without offload = ", round(torch.cuda.max_memory_allocated() / 1024 ** 3, 2), "GB")
torch.cuda.synchronize()
t1 = time.perf_counter()
print(f"Time without offloading: {(t1 - t0)*1000:.2f} ms")

# offload warm-up
x = torch.randn(((4096, 128, 512)), dtype=torch.bfloat16).cuda()
x = fwd_with_offload(x)
x.sum().backward()

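# with offload
x = torch.randn(((4096, 128, 512)), dtype=torch.bfloat16).cuda()
torch.cuda.reset_peak_memory_stats()
# Wait for pending GPU work before starting the timer, as in the run without offload.
torch.cuda.synchronize()
t0 = time.perf_counter()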
x = fwd_with_offload(x)
x.sum().backward()
print("Memory usage with offload = ", round(torch.cuda.max_memory_allocated() / 1024 ** 3, 2), "GB")
torch.cuda.synchronize()
t1 = time.perf_counter()
print(f"Time with offloading: {(t1 - t0)*1000:.2f} ms")

OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 47.39 GiB of which 98.69 MiB is free. Process 3744630 has 47.29 GiB memory in use. Of the allocated memory 46.48 GiB is allocated by PyTorch, and 402.00 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Let's investigate the nsys profile:

![](./offloading_trace.png)

We can see that compute and offloading/reloading are fully overlapped.

It's worth describing the default offloading/reloading strategy we apply.
If we offload `k` out of `N` layers, we offload layers `0`, `1`, ..., `k - 1`. Layer `i` (with `0 <= i <= k - 1`) needs to finish offloading before layer `N - k + i` starts its computation. Layer `i` starts reloading after layer `N - (k - i)` finishes its backward pass, and the reload needs to finish before the backward pass of layer `i`. This strategy minimizes the memory peak in most training workflows (with preallocated gradient buffers).
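The index arithmetic in the paragraph above can be tabulated with a few lines of plain Python. This is only a sketch of the description, not Transformer Engine code: the function name and dictionary keys are invented for illustration, and it assumes that the operation beginning after the backward pass of layer `N - (k - i)` is the reload of layer `i`.

```python
def default_offload_schedule(k, N):
    """Tabulate the sync points described above for offloading layers 0..k-1 of N.

    Illustrative only: it restates the text and calls no Transformer Engine API.
    """
    if not 0 <= k <= N:
        raise ValueError("k must satisfy 0 <= k <= N")
    schedule = {}
    for i in range(k):
        partner = N - k + i  # identical to N - (k - i)
        schedule[i] = {
            "offload_done_before_compute_of": partner,
            "reload_starts_after_backward_of": partner,
            "reload_done_before_backward_of": i,
        }
    return schedule


# The basic example in this tutorial offloads 4 out of 20 layers.
for layer, points in default_offload_schedule(4, 20).items():
    print(layer, points)
```

For `k = 4` and `N = 20`, layers `0` to `3` are paired with layers `16` to `19`, so each offload has the forward passes of the layers in between to complete.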

There are some situations where this strategy may not be optimal and we may want to define our own. We demonstrate such a case in the next section.

## Custom synchronization of CPU offload: pipeline parallelism

Some transformer training workflows are more complicated than the plain forward-backward scenario. Consider pipeline parallelism, for example. Suppose we have the following schedule on one node:

| Step | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
|------|---|---|---|---|---|---|---|---|---|---|----|----|----|----|----|----|
| Operation | 0 fwd | 1 fwd | 2 fwd | 3 fwd | 0 bwd | 4 fwd | 1 bwd | 5 fwd | 2 bwd | 6 fwd | 3 bwd | 7 fwd | 4 bwd | 5 bwd | 6 bwd | 7 bwd |
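The ordering in the table follows a simple pattern: a few forward passes up front, then alternating backward and forward passes, then the remaining backward passes. The sketch below regenerates it in plain Python. The function name and the `warmup` parameter are hypothetical and only serve this illustration.

```python
def interleaved_schedule(num_items, warmup):
    """Return (index, phase) steps: warm-up forwards, alternating bwd/fwd, then bwd drain."""
    steps = [(i, "fwd") for i in range(warmup)]
    next_fwd, next_bwd = warmup, 0
    # Steady state: one backward pass followed by one forward pass.
    while next_fwd < num_items:
        steps.append((next_bwd, "bwd"))
        next_bwd += 1
        steps.append((next_fwd, "fwd"))
        next_fwd += 1
    # Drain: backward passes that are still outstanding.
    while next_bwd < num_items:
        steps.append((next_bwd, "bwd"))
        next_bwd += 1
    return steps


schedule = interleaved_schedule(num_items=8, warmup=4)
print(" | ".join(f"{index} {phase}" for index, phase in schedule))
```

With `num_items=8` and `warmup=4` this yields the 16 steps of the table, in the same order.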

One possible idea is to offload layers 1-6 with the following synchronization:
- layer 1 will end offload and release memory before layer 0 starts backward, and it will start reload before forward of layer 4,
- layer 2 will end offload and release memory before layer 0 starts backward, and it will start reload before forward of layer 5,
- layer 3 will end offload and release memory before layer 4 starts forward, and it will start reload before forward of layer 6,
- layer 4 will end offload and release memory before layer 5 starts forward, and it will start reload before forward of layer 7,
- layer 5 will end offload and release memory before layer 6 starts forward, and it will start reload before backward of layer 4,
- layer 6 will end offload and release memory before layer 7 starts forward, and it will start reload before backward of layer 5.

To implement such a scenario, we can use the `synchronization_dict` argument of the `get_cpu_offload_context()` function.
It maps each offloaded layer (the key) to a tuple `(offload_fwd: bool, offload_num: int, reload_fwd: bool, reload_num: int)` (the value).
A layer finishes its offload when layer `offload_num` begins its forward or backward pass (forward if `offload_fwd` is `True`, backward if it is `False`).
A layer starts its reload when layer `reload_num` begins its forward or backward pass (forward if `reload_fwd` is `True`, backward if it is `False`).
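To make the tuple layout easier to read, the following sketch turns one entry into a sentence. It is plain Python with a made-up helper name, `describe_sync_entry`, and it does not call Transformer Engine. It only restates the field semantics given above, reading `offload_num` and `reload_num` as layer indices.

```python
def describe_sync_entry(layer, entry):
    """Render one (offload_fwd, offload_num, reload_fwd, reload_num) tuple as text."""
    offload_fwd, offload_num, reload_fwd, reload_num = entry
    offload_pass = "forward" if offload_fwd else "backward"
    reload_pass = "forward" if reload_fwd else "backward"
    return (
        f"layer {layer}: offload finishes before {offload_pass} of layer {offload_num}; "
        f"reload starts before {reload_pass} of layer {reload_num}"
    )


# Two entries from the schedule used in this tutorial.
print(describe_sync_entry(3, (True, 4, True, 6)))
print(describe_sync_entry(5, (True, 6, False, 4)))
```

Printing every entry of a dictionary this way is a quick check that it matches the intended schedule before running it on a GPU.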

So let's create the synchronization dict and see it in action.

In [9]:
synchronization_dict = {
    1: (False, 0, True, 4), 
    2: (False, 0, True, 5), 
    3: (True, 4, True, 6), 
    4: (True, 5, True, 7), 
    5: (True, 6, False, 4),  
    6: (True, 7, False, 5), 
}

cpu_offload, sync = te.cpu_offload.get_cpu_offload_context(
    enabled=True, model_layers=8, offload_activations=True, synchronization_dict=synchronization_dict
)

inp = [torch.randn((512 * 128, 2, 512), dtype=torch.bfloat16).cuda() for _ in range(8)]
out = [None] * 8

model = CustomTransformerLayer().cuda()

The following code demonstrates a simplified pipeline-parallel scenario with 8 batches.
While this is not a complete pipeline-parallel implementation (it only shows forward and backward
passes on a single node), it illustrates how computation and communication can overlap when
forward and backward passes are interleaved in this way. We also do not claim that the provided custom schedule is optimal.

In [10]:
with te.fp8_autocast(), cpu_offload:
    out[0] = model(inp[0])
out[0] = sync(out[0])
with te.fp8_autocast(), cpu_offload:
    out[1] = model(inp[1])
out[1] = sync(out[1])
with te.fp8_autocast(), cpu_offload:
    out[2] = model(inp[2])
out[2] = sync(out[2])
with te.fp8_autocast(), cpu_offload:
    out[3] = model(inp[3])
out[3] = sync(out[3])
out[0].sum().backward()
with te.fp8_autocast(), cpu_offload:
    out[4] = model(inp[4])
out[4] = sync(out[4])
out[1].sum().backward()
with te.fp8_autocast(), cpu_offload:
    out[5] = model(inp[5])
out[5] = sync(out[5])
out[2].sum().backward()
with te.fp8_autocast(), cpu_offload:
    out[6] = model(inp[6])
out[6] = sync(out[6])
out[3].sum().backward()
with te.fp8_autocast(), cpu_offload:
    out[7] = model(inp[7])
out[7] = sync(out[7])
out[4].sum().backward()
out[5].sum().backward()
out[6].sum().backward()
out[7].sum().backward()
torch.cuda.synchronize()
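The unrolled cell above repeats two patterns, so the same ordering could also be driven from a list of steps. The sketch below shows that structure with stand-in callables so that it runs without a GPU. `run_schedule`, `run_forward` and `run_backward` are hypothetical names; in the tutorial they would wrap the forward block (`with te.fp8_autocast(), cpu_offload:` followed by `sync`) and `out[i].sum().backward()` respectively.

```python
def run_schedule(schedule, run_forward, run_backward):
    """Execute (index, phase) steps in order using the supplied callables."""
    for index, phase in schedule:
        if phase == "fwd":
            run_forward(index)
        elif phase == "bwd":
            run_backward(index)
        else:
            raise ValueError(f"unknown phase: {phase!r}")


# Same ordering as the unrolled cell above.
schedule = [
    (0, "fwd"), (1, "fwd"), (2, "fwd"), (3, "fwd"),
    (0, "bwd"), (4, "fwd"), (1, "bwd"), (5, "fwd"),
    (2, "bwd"), (6, "fwd"), (3, "bwd"), (7, "fwd"),
    (4, "bwd"), (5, "bwd"), (6, "bwd"), (7, "bwd"),
]

# Stand-ins that only record the call order instead of running a model.
log = []
run_schedule(
    schedule,
    run_forward=lambda i: log.append(f"{i} fwd"),
    run_backward=lambda i: log.append(f"{i} bwd"),
)
print(log)
```

Keeping the order in data rather than in unrolled statements makes it easier to compare against the keys and values of `synchronization_dict`.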

Now let's see the nsys profile:

![](./offloading_trace_pp.png)

We can see that offload/reload is fully overlapped with compute.