# Activation checkpointing

This notebook explores how much memory activation checkpointing saves under different setups, e.g. different model sizes, different sequence lengths, and different choices of which layers to checkpoint.

In [13]:
from dataclasses import dataclass
from collections import OrderedDict
import itertools

from rich.console import Console
from rich.table import Table


@dataclass()
class ModelConfig:
    """Hyperparameters of a GPT-2-style transformer.

    The defaults describe a 12-layer, 12-head, 768-wide model with a
    1024-position context and a 50257-token vocabulary, with LayerNorm and
    MLP biases disabled and the output projection tied to the embeddings.
    """

    n_ctx: int = 1024  # number of positions (sizes the position embedding)
    n_layer: int = 12  # number of repeated transformer blocks
    n_head: int = 12  # attention head count (not read by the param estimate)
    d_model: int = 768  # model / residual width
    d_mlp: int = 3072  # MLP hidden width (4 * 768); NOT derived from d_model
    vocab_size: int = 50257  # number of rows in the token embedding
    ln_bias: bool = False  # whether LayerNorms carry a bias vector
    mlp_bias: bool = False  # presumably toggles MLP linear-layer biases
    share_embd_params: bool = True  # output layer reuses the token embedding


# Registry of named GPT-2 presets. Only n_ctx, n_layer, n_head and d_model
# differ between presets; every entry uses d_mlp = 4 * d_model, the
# 50257-token vocabulary, LayerNorm and MLP biases, and tied embeddings.
MODEL_CONFIG_ARGS = {
    preset: ModelConfig(
        n_ctx=ctx,
        n_layer=layers,
        n_head=heads,
        d_model=width,
        d_mlp=4 * width,
        vocab_size=50257,
        ln_bias=True,
        mlp_bias=True,
        share_embd_params=True,
    )
    for preset, ctx, layers, heads, width in (
        # (name, n_ctx, n_layer, n_head, d_model)   approx. param count
        ("gpt2-tiny", 128, 2, 4, 256),  # 14M
        ("gpt2", 1024, 12, 12, 768),  # 124M
        ("gpt2-medium", 1024, 24, 16, 1024),  # 350M
        ("gpt2-large", 1024, 36, 20, 1280),  # 774M
        ("gpt2-xl", 1024, 48, 25, 1600),  # 1558M
    )
}


def load_config(name: str) -> ModelConfig:
    """Return the preset ``ModelConfig`` registered under ``name``.

    The instance stored in ``MODEL_CONFIG_ARGS`` is returned as-is (not a
    copy), so mutating the result also changes the registry entry.

    Raises:
        KeyError: if ``name`` is not a key of ``MODEL_CONFIG_ARGS``.
    """
    # An explicit error instead of `assert`: asserts are stripped under
    # `python -O` and gave no hint about which names are valid.
    try:
        return MODEL_CONFIG_ARGS[name]
    except KeyError:
        raise KeyError(
            f"Unknown model config {name!r}; "
            f"expected one of {sorted(MODEL_CONFIG_ARGS)}"
        ) from None


# TODO Change this as desired: any key of MODEL_CONFIG_ARGS works
# ("gpt2-tiny", "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl").
model_cfg = load_config(name="gpt2")

In [14]:
def get_params(cfg: ModelConfig, checkpoint_layers: set[str] | None = None):
    """Estimates the number of parameters in the model."""
    out = OrderedDict()
    if checkpoint_layers is None:
        checkpoint_layers = set()

    # token and position embeddings
    if "embedding" not in checkpoint_layers:
        out["embedding/position"] = cfg.n_ctx * cfg.d_model
        out["embedding/token"] = cfg.vocab_size * cfg.d_model
        out["embedding"] = out["embedding/position"] + out["embedding/token"]
    else:
        out["embedding"] = 0

    # attention blocks
    if "attention" not in checkpoint_layers:
        out["attention/ln"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model
        out["attention/kqv"] = cfg.d_model * 3 * cfg.d_model
        out["attention/proj"] = cfg.d_model**2
        out["attention"] = (
            out["attention/ln"] + out["attention/kqv"] + out["attention/proj"]
        )
    else:
        out["attention"] = 0

    # MLP blocks
    if "mlp" not in checkpoint_layers:
        out["mlp/ln"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model
        out["mlp/ffw"] = cfg.d_model * cfg.d_mlp + int(cfg.ln_bias) * cfg.d_mlp
        out["mlp/proj"] = cfg.d_mlp * cfg.d_model + int(cfg.ln_bias) * cfg.d_model
        out["mlp"] = out["mlp/ln"] + out["mlp/ffw"] + out["mlp/proj"]
    else:
        out["mlp"] = 0

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]
    out["ln_f"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model  # final layernorm
    if cfg.share_embd_params:
        # 0 because of parameter sharing. This layer uses the weights from the embedding layer
        out["out_embedding"] = 0
    else:
        out["out_embedding"] = cfg.d_model * cfg.vocab_size

    # total
    out["total"] = (
        out["embedding"] + out["transformer"] + out["ln_f"] + out["out_embedding"]
    )

    return out

In [22]:
possible_checkpoints = ["embedding", "attention", "mlp"]

# Per-token parameter count with nothing checkpointed; this notebook uses it
# as the proxy for how many values each token keeps for the backward pass.
# NOTE(review): real activation memory depends on d_model, n_layer, batch
# size etc. rather than on the parameter count. Treat these numbers as a rough
# relative comparison between policies, not absolute sizes.
base_params = get_params(model_cfg)["total"]
n_ctx = model_cfg.n_ctx

# Assume 16-bit precision (i.e. 2 bytes per stored value).
bytes_per_param = 2

params_by_checkpoint_policy = {
    "None": base_params,
}
# Iterate over every non-empty combination of checkpointed components, e.g.
# ("embedding",), ("attention",), ..., ("embedding", "attention", "mlp").
for checkpoint_combo in itertools.chain.from_iterable(
    itertools.combinations(possible_checkpoints, i)
    for i in range(1, len(possible_checkpoints) + 1)
):
    params = get_params(model_cfg, set(checkpoint_combo))["total"]
    params_by_checkpoint_policy[str(checkpoint_combo)] = params


# Memory model for one sequence of n_ctx tokens:
#   model weights          -> base_params
#   forward activations    -> forward_params * n_ctx   (reduced by checkpointing)
#   backward activations   -> base_params * n_ctx      (never reduced)
# With no checkpointing forward_params == base_params, which gives the
# baseline below.
total_base_params = base_params + (base_params * n_ctx * 2)
base_memory_gb = total_base_params * bytes_per_param / 1024**3

table_headers = [
    "Checkpoint Policy",
    "Forward Params Stored",
    "Total Params Stored",
    "Total Memory (GB)",
    "Params Savings",
    "Memory Savings (%)",
]
table_rows = []
for checkpoint_policy, forward_params in params_by_checkpoint_policy.items():
    # Checkpointing still stores the full model + backward activations, just
    # not some forward activations.
    n_forward_activations = forward_params * n_ctx
    n_backward_activations = base_params * n_ctx
    total_params_stored = base_params + n_forward_activations + n_backward_activations
    total_memory_gb = total_params_stored * bytes_per_param / 1024**3
    # Reduction in forward values stored per token, relative to no checkpointing.
    params_savings = (base_params - forward_params) / base_params * 100
    # Reduction in total memory relative to no checkpointing (0% for "None").
    memory_savings = (base_memory_gb - total_memory_gb) / base_memory_gb * 100
    table_rows.append(
        [
            checkpoint_policy,
            str(forward_params),
            str(total_params_stored),
            f"{total_memory_gb:.2f}",
            f"{params_savings:.2f}%",
            f"{memory_savings:.2f}%",
        ]
    )

table = Table(title="Activation Checkpointing")
for header in table_headers:
    table.add_column(header)
for row in table_rows:
    table.add_row(*row)

console = Console()
console.print(table)
