Ever stared at a CUDA Out of Memory error and wondered where it all went wrong? Or perhaps you've noticed your GPU utilization sitting at a measly 30% while training your shiny new Transformer? Welcome to MemWall.
MemWall is a comprehensive Python library designed to help ML practitioners plan, profile, and optimize their hardware utilization. We bridge the gap between abstract model architectures and the harsh reality of hardware physics.
At the heart of optimization lies the Roofline Model, an intuitive visual model that connects computational performance, memory bandwidth, and arithmetic intensity.
What is the Ridge Point? Every piece of hardware (like your A100 or RTX 4090) has two critical ceilings:
- Peak Memory Bandwidth (GB/s): How fast can it move data?
- Peak FLOPs (TFLOP/s): How fast can it crunch numbers?
The Ridge Point is the exact ratio where these two meet, calculated simply as:
Ridge Point = Peak FLOPs / Peak Bandwidth
This single number is your North Star. By calculating the Arithmetic Intensity (FLOPs / Bytes) of your operation, you can compare it to the Ridge Point:
- Arithmetic Intensity < Ridge Point: You are Memory-Bound. Your GPU is starved for data. Adding more compute power won't help; you need faster memory or better data reuse.
- Arithmetic Intensity > Ridge Point: You are Compute-Bound. You are crunching numbers efficiently. Your bottleneck is the raw math power of the GPU.
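The classification above is just a few lines of arithmetic. Here is a minimal, library-agnostic sketch (not MemWall's API); the hardware numbers are NVIDIA's published A100 80GB SXM peaks for dense BF16 tensor-core math and HBM2e bandwidth:

```python
def ridge_point(peak_flops: float, peak_bandwidth: float) -> float:
    """The FLOPs/byte threshold separating memory-bound from compute-bound."""
    return peak_flops / peak_bandwidth

def classify(flops: float, bytes_moved: float,
             peak_flops: float, peak_bandwidth: float) -> str:
    intensity = flops / bytes_moved  # arithmetic intensity in FLOPs/byte
    return ("compute-bound" if intensity > ridge_point(peak_flops, peak_bandwidth)
            else "memory-bound")

# Published A100 80GB SXM peaks: 312 TFLOP/s (BF16 tensor cores), 2039 GB/s HBM2e.
PEAK_FLOPS = 312e12
PEAK_BW = 2039e9
print(ridge_point(PEAK_FLOPS, PEAK_BW))  # ~153 FLOPs/byte

# An elementwise add moves 3 floats (2 reads, 1 write) per FLOP, so its
# intensity is far below the ridge point:
print(classify(1e9, 12e9, PEAK_FLOPS, PEAK_BW))  # memory-bound
```

Any operation whose intensity sits below ~153 FLOPs/byte on this GPU will never saturate the tensor cores, no matter how well it is tuned.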
MemWall makes understanding and visualizing this dynamic effortless.
- Transformer VRAM Estimator: Instantly compute memory needs for weights, KV cache, and activations.
- Roofline Model Calculator: Automatically classify operations as memory-bound or compute-bound.
- PyTorch Integration: Lightweight hooks to profile real-world peak memory and layer-by-layer usage.
- Optimization Advisors: Actionable recommendations for mixed precision and batch sizes.
pip install memwall

To really understand what MemWall can do, let's walk through the three core examples included in the library. Think of this as your interactive guide to solving ML performance bottlenecks.
The Question: "I want to run Llama-7B with a batch size of 4 and a sequence length of 2048. Will it fit on my GPU?"
Before you even load a single tensor or write a PyTorch script, MemWall can tell you exactly what you're getting into. The basic_estimation.py script demonstrates how to load a preset estimator for models like Llama-7B.
What it does: It mathematically estimates the exact VRAM breakdown. It doesn't just guess a single number; it breaks it down into:
- Weights Memory: The static size of the model parameters.
- KV Cache Memory: The dynamic memory required for generating tokens during inference.
- Activation Memory: The memory needed for intermediate calculations.
Result: You get a clean readout of exactly how many gigabytes you need, saving you from trial-and-error OOM crashes.
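This isn't MemWall's implementation, but the arithmetic behind such an estimate is standard. A back-of-the-envelope sketch using common Llama-7B shapes (32 layers, hidden size 4096, FP16 weights and cache; activation memory is omitted for brevity):

```python
def estimate_vram_gib(n_params, n_layers, hidden, batch, seq_len, bytes_per_el=2):
    """Rough inference VRAM: static weights plus the KV cache (FP16 by default)."""
    weights = n_params * bytes_per_el
    # K and V each store one hidden-sized vector per layer, per token, per sequence.
    kv_cache = 2 * n_layers * hidden * batch * seq_len * bytes_per_el
    gib = 1024 ** 3
    return weights / gib, kv_cache / gib

# Llama-7B has ~6.74B parameters; the question from the text: batch 4, seq 2048.
w, kv = estimate_vram_gib(n_params=6.74e9, n_layers=32, hidden=4096,
                          batch=4, seq_len=2048)
print(f"weights ~= {w:.1f} GiB, KV cache ~= {kv:.1f} GiB")
```

With these shapes the KV cache alone adds 4 GiB on top of the roughly 12.6 GiB of FP16 weights, which is exactly the kind of breakdown the estimator reports before you load anything.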
The Question: "My training script is eating 40GB of VRAM, but my model is only 2GB. Where is the memory going?"
Sometimes math isn't enough—you need to see what PyTorch is actually doing. The pytorch_profiling.py example highlights MemWall's lightweight PyTorch hooks.
What it does:
By wrapping your model with MemWall's profile_model function and passing dummy data, the library tracks every single forward pass operation. It outputs:
- Peak Memory: The high-water mark your GPU memory hit during the pass.
- Layer Breakdown: A microscopic view of exactly how much memory (in MB) every single layer (fc1, relu, fc2) consumed.
- Incremental VRAM: How much additional memory was allocated beyond the base model weights.
Result: You instantly spot the specific layers that are hoarding memory, allowing you to selectively apply gradient checkpointing or optimize your architecture.
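MemWall does this with real PyTorch forward hooks; as a dependency-free illustration of the accounting those hooks perform, here is the per-layer activation footprint for a hypothetical toy MLP matching the fc1/relu/fc2 names above (batch 64, widths 1024 → 4096 → 4096 → 1024, FP32 activations — all shapes are made up for illustration):

```python
BYTES_FP32 = 4

def activation_mb(batch, out_features, bytes_per_el=BYTES_FP32):
    """Size of one layer's forward output in MiB."""
    return batch * out_features * bytes_per_el / 2**20

# Each layer's output stays resident until it is consumed (and, in training,
# until backward), so summing output sizes approximates the incremental
# activation memory a forward hook would attribute to each layer.
layers = {"fc1": 4096, "relu": 4096, "fc2": 1024}
breakdown = {name: activation_mb(64, feats) for name, feats in layers.items()}
for name, mb in breakdown.items():
    print(f"{name}: {mb:.2f} MB")
```

Even this toy accounting makes the point: the wide intermediate layers (fc1, relu) dominate, which is precisely where gradient checkpointing pays off.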
The Question: "Why is my matrix multiplication taking so long? Do I need a faster GPU, or just a different algorithm?"
This script brings the Ridge Point concept to life.
What it does:
- You select a target GPU (e.g., A100_80GB_SXM), and MemWall loads its peak constraints.
- You define specific operations by their workload: how many FLOPs they execute and how many bytes they move.
- MemWall calculates the Arithmetic Intensity for each operation.
- Finally, it compares this to the GPU's Ridge Point to print out a diagnosis: Is this layer Memory-Bound or Compute-Bound? What is the predicted utilization?
As a bonus, the script automatically generates a beautiful roofline_example.png plot, mapping your operations directly against the theoretical limits of your hardware.
Result: You stop guessing why your code is slow. You get a quantitative diagnosis pointing you toward either memory optimization or algorithmic improvements.
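The diagnosis in steps 3-4 comes straight from the roofline formula: attainable throughput is capped at min(peak FLOPs, intensity × bandwidth). A minimal sketch of that prediction (same published A100 80GB SXM peaks as before; the two operation workloads are invented for illustration):

```python
def roofline(flops, bytes_moved, peak_flops, peak_bw):
    """Return (bound, predicted utilization of peak FLOPs) for one operation."""
    intensity = flops / bytes_moved                 # FLOPs per byte
    attainable = min(peak_flops, intensity * peak_bw)
    bound = "compute-bound" if attainable >= peak_flops else "memory-bound"
    return bound, attainable / peak_flops

PEAK_FLOPS, PEAK_BW = 312e12, 2039e9  # A100 80GB SXM: BF16 peak, HBM2e bandwidth

# A large GEMM reuses each byte ~1000 times: it rides the compute roof.
print(roofline(1e12, 1e9, PEAK_FLOPS, PEAK_BW))
# An elementwise op reuses almost nothing: it is pinned to the memory roof,
# with predicted utilization well under 1% of peak FLOPs.
print(roofline(1e9, 12e9, PEAK_FLOPS, PEAK_BW))
```

The utilization number is the punchline: for the memory-bound case, no amount of kernel tuning can raise it past intensity × bandwidth ÷ peak FLOPs, so the fix has to come from data reuse (fusion, tiling, batching), not from a faster chip.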
MemWall ships with precise hardware specifications for modern accelerators:
- NVIDIA A100 (80GB/40GB SXM & PCIe)
- NVIDIA H100
- NVIDIA RTX 4090 / 3090
- ...or define your own custom peak FLOPs and Bandwidth!
MIT License. Built for the community, by the community.