In [None]:
# | default_exp utils/activation_checkpointing

# Imports

In [None]:
# | export

from collections.abc import Callable

from torch import nn
from torch.utils.checkpoint import checkpoint

# Main class

In [None]:
# | export


class ActivationCheckpointing(nn.Module):
    """This class is used to perform activation checkpointing during training. Users can set a level of checkpointing
    for each module / function in their architecture. While training, the module / function will be checkpointed if the
    training checkpoint level is greater than or equal to the checkpoint level set for the module / function.

    A general guide to the activation checkpointing levels used in this repository:

    - **Level 0**: No checkpointing
    - **Level 1**: Single layers are checkpointed e.g. linear layer + activation, conv layer + dropout
    - **Level 2**: Small blocks are checkpointed e.g. residual blocks, attention blocks, MLP blocks
    - **Level 3**: Medium-sized modules are checkpointed e.g. transformer layers, decoder blocks
    - **Level 4**: Large modules are checkpointed e.g. groups of transformer layers, decoder stages
    - **Level 5**: Very large modules are checkpointed e.g. entire encoders, decoders etc.
    """

    def __init__(self, fn_checkpoint_level: int, training_checkpoint_level: int):
        """Initialize the ActivationCheckpointing class.

        Args:
            fn_checkpoint_level: Level at which the module / function should be checkpointed
            training_checkpoint_level: Checkpointing level at which the model is being trained

        Example:
            .. code-block:: python

                class MyModel(nn.Module):
                    def __init__(self, training_checkpointing_level: int = 0):
                        super().__init__()
                        self.my_network = nn.Sequential(
                            nn.Linear(784, 256),
                            nn.ReLU(),
                            nn.Linear(256, 10)
                        )

                        self.activation_checkpointing_level2 = ActivationCheckpointing(2, training_checkpointing_level)

                    def forward(self, x):
                        y = self.activation_checkpointing_level2(self.my_network, x)
                        return y

            In this example, a ``training_checkpointing_level`` greater than or equal to 2 will checkpoint
            ``my_network`` during training. If it's less than 2, the network will not be checkpointed.
        """
        super().__init__()

        self.fn_checkpoint_level = fn_checkpoint_level
        self.training_checkpoint_level = training_checkpoint_level

        # Decided once at construction; __call__ additionally requires the module to be in training mode.
        self.perform_checkpointing = fn_checkpoint_level <= training_checkpoint_level

    def __call__(self, fn: Callable, *fn_args, use_reentrant: bool = False, **fn_kwargs):
        """Run ``fn(*fn_args, **fn_kwargs)``, checkpointing it when enabled.

        Checkpointing is applied only when this module is in training mode and the training checkpoint level is
        greater than or equal to the level set for the module / function. Otherwise ``fn`` is called directly.

        Args:
            fn: The module / function to run
            *fn_args: Positional arguments to pass to the module / function. They are handed to
                ``torch.utils.checkpoint.checkpoint`` as explicit inputs, which the reentrant variant needs in order
                to propagate gradients (it requires at least one such tensor input with ``requires_grad=True``).
            use_reentrant: Passed on to torch.utils.checkpoint.checkpoint. Defaults to False.
            **fn_kwargs: Keyword arguments to pass to the module / function. They are bound in a closure rather than
                forwarded to ``checkpoint``, so they cannot collide with ``checkpoint``'s own keyword arguments.

        Returns:
            The output of ``fn``.
        """
        if self.training and self.perform_checkpointing:
            # Pass positional args through checkpoint() as real inputs: a zero-argument closure would leave the
            # reentrant implementation with no tensor inputs, so its output would not require grad.
            def run(*args):
                return fn(*args, **fn_kwargs)

            return checkpoint(run, *fn_args, use_reentrant=use_reentrant)
        return fn(*fn_args, **fn_kwargs)

    def extra_repr(self):
        """Summarize whether checkpointing is enabled and at which level, for ``repr(module)``."""
        return f"enabled={self.perform_checkpointing}, checkpointing_level={self.fn_checkpoint_level}"

Test memory savings and backward-pass time

In [None]:
from time import perf_counter

import torch
from torch import nn

In [None]:
class SampleModule(nn.Module):
    """Toy network for comparing ActivationCheckpointing levels.

    Holds two stacks of 100 small MLP blocks, each ``Linear(10, 10) -> ReLU -> Linear(10, 10) -> ReLU``.
    During the forward pass, blocks from the first stack (wrapped in a level-1 checkpoint) alternate with
    plain blocks from the second stack, and the whole interleaved pass is wrapped in a level-2 checkpoint.
    """

    def __init__(self, checkpointing_level):
        super().__init__()

        def build_stack():
            # A fresh list of 100 identical 10 -> 10 -> 10 MLP blocks.
            blocks = [nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU()) for _ in range(100)]
            return nn.ModuleList(blocks)

        self.sequences_level1 = build_stack()
        self.sequences_level2 = build_stack()

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)
        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    def run_sequences(self, x):
        """Alternate a level-1-checkpointed block with an un-checkpointed block, pair by pair."""
        hidden = x
        for wrapped_block, plain_block in zip(self.sequences_level1, self.sequences_level2):
            hidden = self.checkpointing_level1(wrapped_block, hidden)
            hidden = plain_block(hidden)
        return hidden

    def forward(self, x):
        """Run every block pair, with the whole pass under the level-2 checkpoint."""
        return self.checkpointing_level2(self.run_sequences, x)

    def loss_fn(self, output):
        """Return an arbitrary scalar (the sum of ``output``) so that ``backward`` can be called."""
        total = output.sum()
        return total

In [None]:
# Create the input directly on the GPU so it is a leaf tensor. `torch.randn(..., requires_grad=True).cuda()` would
# return a non-leaf copy, and every backward would also copy the input gradient back to the hidden CPU leaf.
sample_input = torch.randn(50000, 10, device="cuda", requires_grad=True)

In [None]:
torch.cuda.reset_peak_memory_stats()

model = SampleModule(2).cuda()
print("Activation checkpointing level = 2")

output = model(sample_input)

# Peak allocation since the reset: parameters, input, and whatever the forward pass kept alive for backward.
print("Memory used: ", torch.cuda.max_memory_allocated() / 2**30, "GB")

loss = model.loss_fn(output)
# CUDA kernels run asynchronously: synchronize before starting and before stopping the timer so the measurement
# covers the backward computation itself, not only its kernel launches.
torch.cuda.synchronize()
tic = perf_counter()
loss.backward()
torch.cuda.synchronize()
toc = perf_counter()
print("Time taken for backward: ", toc - tic, "s")

del model, output, loss

Activation checkpointing level = 2
Memory used:  0.2188282012939453 GB
Time taken for backward:  0.08570726797915995 s


In [None]:
torch.cuda.reset_peak_memory_stats()

model = SampleModule(1).cuda()
print("Activation checkpointing level = 1")

output = model(sample_input)

# Peak allocation since the reset: parameters, input, and whatever the forward pass kept alive for backward.
print("Memory used: ", torch.cuda.max_memory_allocated() / 2**30, "GB")

loss = model.loss_fn(output)
# CUDA kernels run asynchronously: synchronize before starting and before stopping the timer so the measurement
# covers the backward computation itself, not only its kernel launches.
torch.cuda.synchronize()
tic = perf_counter()
loss.backward()
torch.cuda.synchronize()
toc = perf_counter()
print("Time taken for backward: ", toc - tic, "s")

del model, output, loss

Activation checkpointing level = 1
Memory used:  0.6066136360168457 GB
Time taken for backward:  0.07214503700379282 s


In [None]:
torch.cuda.reset_peak_memory_stats()

model = SampleModule(0).cuda()
print("Activation checkpointing level = 0")

output = model(sample_input)

# Peak allocation since the reset: parameters, input, and whatever the forward pass kept alive for backward.
print("Memory used: ", torch.cuda.max_memory_allocated() / 2**30, "GB")

loss = model.loss_fn(output)
# CUDA kernels run asynchronously: synchronize before starting and before stopping the timer so the measurement
# covers the backward computation itself, not only its kernel launches.
torch.cuda.synchronize()
tic = perf_counter()
loss.backward()
torch.cuda.synchronize()
toc = perf_counter()
print("Time taken for backward: ", toc - tic, "s")

del model, output, loss

Activation checkpointing level = 0
Memory used:  0.8019261360168457 GB
Time taken for backward:  0.03969525604043156 s


# nbdev

In [None]:
# IPython shell escape: run the nbdev exporter for this notebook (presumably it writes the "# | export" cells to the
# module named in the "# | default_exp" directive — confirm against the nbdev documentation).
!nbdev_export