In [None]:
# | default_exp utils/activation_checkpointing

# Imports

In [None]:
# | export

from typing import Callable

from torch.utils.checkpoint import checkpoint

# Main class

In [None]:
# | export


class ActivationCheckpointing:
    """Callable wrapper that runs a function with or without activation checkpointing.

    Each wrapped function is tagged with a checkpoint level, and the training run selects a
    level of its own. The function is checkpointed when its level is less than or equal to
    the training level. The levels are:

    Level 0: No checkpointing
    Level 1: Single layers are checkpointed e.g. linear layer + activation, conv layer + dropout
    Level 2: Small blocks are checkpointed e.g. residual blocks, attention blocks, MLP blocks
    Level 3: Medium-sized modules are checkpointed e.g. transformer layers, decoder blocks
    Level 4: Large modules are checkpointed e.g. groups of transformer layers, decoder stages
    Level 5: Very large modules are checkpointed e.g. entire encoders, decoders etc.
    """

    def __init__(self, fn_checkpoint_level: int, training_checkpoint_level: int):
        """Decide once, at construction time, whether calls will be checkpointed.

        Args:
            fn_checkpoint_level: Level assigned to the function(s) this instance will wrap.
            training_checkpoint_level: Level selected for the training run.
        """
        super().__init__()

        self.perform_checkpointing = training_checkpoint_level >= fn_checkpoint_level

    def __call__(self, fn: Callable, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` and return its result.

        When checkpointing is enabled for this instance the call goes through
        ``torch.utils.checkpoint.checkpoint`` with ``use_reentrant=False``; otherwise ``fn`` is
        invoked directly.
        """
        if not self.perform_checkpointing:
            return fn(*args, **kwargs)

        # ``checkpoint`` only receives a zero-argument closure: the positional and keyword
        # arguments are bound here and handed to ``fn`` alone.
        def run_fn():
            return fn(*args, **kwargs)

        return checkpoint(run_fn, use_reentrant=False)

# Test memory savings

In [None]:
from time import perf_counter

import torch
from torch import nn

In [None]:
class SampleModule(nn.Module):
    """Toy network for comparing activation checkpointing levels.

    Holds two lists of 100 small Linear/ReLU stacks that are applied alternately. Each stack
    of the first list is called through a level-1 checkpoint wrapper, and the whole pass over
    both lists is called through a level-2 wrapper.
    """

    def __init__(self, checkpointing_level):
        """Build the layers and the wrappers for the given training checkpoint level."""
        super().__init__()

        def build_stack():
            # Linear -> ReLU -> Linear -> ReLU on 10 features
            return nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())

        self.sequences_level1 = nn.ModuleList([build_stack() for _ in range(100)])
        self.sequences_level2 = nn.ModuleList([build_stack() for _ in range(100)])

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)
        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    def run_sequences(self, x):
        """Apply the stacks of both lists pairwise, in order, and return the final output."""
        hidden = x
        for first_stage, second_stage in zip(self.sequences_level1, self.sequences_level2):
            # Only the first stack of each pair goes through the level-1 wrapper
            hidden = self.checkpointing_level1(first_stage, hidden)
            hidden = second_stage(hidden)
        return hidden

    def forward(self, x):
        """Run the full pass through the level-2 wrapper."""
        return self.checkpointing_level2(self.run_sequences, x)

    def loss_fn(self, output):
        """Arbitrary scalar (sum of the output) so that backward can be run."""
        total = output.sum()
        return total

In [None]:
# Random (50000, 10) batch created with requires_grad=True, then copied to the GPU
sample_input = torch.randn((50000, 10), requires_grad=True).to("cuda")

In [None]:
# Reset the peak-memory counter before building and running the model
torch.cuda.reset_peak_memory_stats()

model = SampleModule(2).cuda()
print("Activation checkpointing level = 2")

output = model(sample_input)

# Peak allocation is read after the forward pass, before backward runs
print("Memory used: ", torch.cuda.max_memory_allocated() / 2**30, "GB")

loss = model.loss_fn(output)
# CUDA kernels are launched asynchronously: wait for any pending forward work before starting
# the timer, and for the backward kernels to finish before stopping it.
torch.cuda.synchronize()
tic = perf_counter()
loss.backward()
torch.cuda.synchronize()
toc = perf_counter()
print("Time taken for backward: ", toc - tic, "s")

# Drop every reference created in this cell so the tensors can be released
del model, output, loss

Activation checkpointing level = 2
Memory used:  0.20971155166625977 GB
Time taken for backward:  0.14360837265849113 s


  return F.linear(input, self.weight, self.bias)


In [None]:
# Reset the peak-memory counter before building and running the model
torch.cuda.reset_peak_memory_stats()

model = SampleModule(1).cuda()
print("Activation checkpointing level = 1")

output = model(sample_input)

# Peak allocation is read after the forward pass, before backward runs
print("Memory used: ", torch.cuda.max_memory_allocated() / 2**30, "GB")

loss = model.loss_fn(output)
# CUDA kernels are launched asynchronously: wait for any pending forward work before starting
# the timer, and for the backward kernels to finish before stopping it.
torch.cuda.synchronize()
tic = perf_counter()
loss.backward()
torch.cuda.synchronize()
toc = perf_counter()
print("Time taken for backward: ", toc - tic, "s")

# Drop every reference created in this cell so the tensors can be released
del model, output, loss

Activation checkpointing level = 1
Memory used:  0.6066136360168457 GB
Time taken for backward:  0.07752787880599499 s


In [None]:
# Reset the peak-memory counter before building and running the model
torch.cuda.reset_peak_memory_stats()

model = SampleModule(0).cuda()
print("Activation checkpointing level = 0")

output = model(sample_input)

# Peak allocation is read after the forward pass, before backward runs
print("Memory used: ", torch.cuda.max_memory_allocated() / 2**30, "GB")

loss = model.loss_fn(output)
# CUDA kernels are launched asynchronously: wait for any pending forward work before starting
# the timer, and for the backward kernels to finish before stopping it.
torch.cuda.synchronize()
tic = perf_counter()
loss.backward()
torch.cuda.synchronize()
toc = perf_counter()
print("Time taken for backward: ", toc - tic, "s")

# Drop every reference created in this cell so the tensors can be released
del model, output, loss

Activation checkpointing level = 0
Memory used:  0.8019261360168457 GB
Time taken for backward:  0.035065365955233574 s


# nbdev

In [10]:
# Shell escape: runs the nbdev export command from inside the notebook
!nbdev_export