In [1]:
# alpha_probe.py  -- run in the same venv as Axolotl
import os, gc, torch
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Match your YAML
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
USE_BF16 = True
USE_GC   = True   # gradient checkpointing
B        = 4      # micro_batch_size
S1, S2   = 1536, 2048   # two lengths inside your 2048 cap

# Make sure we use flash-attn if installed (memory is then linear in s)
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

device = "cuda"
dtype  = torch.bfloat16 if USE_BF16 else torch.float16

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map={"":0}  # single GPU
).train()

# Match Axolotl's training toggles
if USE_GC:
    model.gradient_checkpointing_enable()
# ensure no KV cache is kept during training
if hasattr(model.config, "use_cache"):
    model.config.use_cache = False
# prefer flash-attn v2 if available in this Transformers build
try:
    model.config.attn_implementation = "flash_attention_2"
except Exception:
    pass

def peak_bytes_for(seq_len: int) -> int:
    gc.collect(); torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats()
    # random tokens; labels=input_ids triggers full LM loss/backward like SFT
    x = torch.randint(low=0, high=tok.vocab_size, size=(B, seq_len), device=device)
    out = model(input_ids=x, attention_mask=torch.ones_like(x), labels=x)
    loss = out.loss
    loss.backward()
    torch.cuda.synchronize()
    peak = torch.cuda.max_memory_allocated()
    # cleanup grads to avoid accumulation
    model.zero_grad(set_to_none=True)
    gc.collect(); torch.cuda.empty_cache()
    return peak

bytes1 = peak_bytes_for(S1)
bytes2 = peak_bytes_for(S2)

# Read actual model dims so we don't hardcode L,H
cfg = model.config
L = getattr(cfg, "num_hidden_layers", None)
H = getattr(cfg, "hidden_size", None)
assert L and H, "Could not read num_hidden_layers/hidden_size"

bytes_per_elem = 2 if USE_BF16 else 2  # both bf16/fp16 are 2 bytes

# α from slope:  peak ≈ const + (B * s * L * H * α * bytes_per_elem)
slope_bytes_per_token = (bytes2 - bytes1) / (S2 - S1)
alpha = slope_bytes_per_token / (B * L * H * bytes_per_elem)

gb = 1024**3
print(f"Peak @ s={S1}: {bytes1/gb:.2f} GB, @ s={S2}: {bytes2/gb:.2f} GB")
print(f"L={L}, H={H}, B={B}, dtype bytes={bytes_per_elem}")
print(f"alpha ≈ {alpha:.3f}  (BF16, GC={'on' if USE_GC else 'off'})")

Fetching 4 files: 100%|██████████| 4/4 [00:15<00:00,  3.87s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


Peak @ s=1536: 31.84 GB, @ s=2048: 32.88 GB
L=32, H=4096, B=4, dtype bytes=2
alpha ≈ 2.089  (BF16, GC=on)
