# GPU-vRAM Usage Estimation for Diffusion Models
## Objective
Derive an analytical equation that estimates peak vRAM usage during inference for the `stable-diffusion-v1-5/stable-diffusion-v1-5` model at arbitrary input image sizes.

## Background
vRAM consumption during diffusion model inference differs significantly from model size on disk. Peak memory depends on:
 - Model weights (fixed)
 - Intermediate activations (vary with image dimensions and prompt length)
 - Framework overhead (CUDA kernels, workspace buffers)
 - Attention mechanism memory scaling (O(N²) with sequence length)

The estimate should be expressed in terms of:
 - `H`, `W` = input image height and width
 - `prompt_length` = tokenized prompt length
 - Any additional factors you identify as affecting vRAM
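
To make the O(N²) term concrete, here is a minimal sketch that sizes one materialized self-attention score matrix at the UNet's highest-resolution level. The function name and its defaults are illustrative assumptions: 8× VAE downsampling, 8 attention heads, FP16, and a naive attention implementation that stores the full `batch × heads × N × N` matrix. Fused kernels (PyTorch 2.x scaled-dot-product attention, xFormers) avoid storing it, so treat the numbers as an upper bound.

```python
def attn_scores_bytes(h, w, heads=8, batch=1, bytes_per_scalar=2, vae_downsample=8):
    """Bytes for one materialized self-attention score matrix (batch x heads x N x N)."""
    # N = number of latent tokens at the first UNet level
    n = (h // vae_downsample) * (w // vae_downsample)
    return batch * heads * n * n * bytes_per_scalar


for side in (256, 512, 768, 1024):
    mib = attn_scores_bytes(side, side) / 2**20
    print(f"{side}x{side}: {mib:,.0f} MiB")
```

Doubling the image side quadruples N and therefore multiplies this term by 16, which is why it overtakes the fixed weight cost at large resolutions.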

## Requirements
 - Analyze the architecture: Understand UNet, VAE, CLIP text encoder, and how tensors flow through the pipeline
 - Account for precision: Assume `FP16` (2 bytes/parameter)
 - Model fully on GPU: Ignore `pipeline.enable_model_cpu_offload()` in your equation
 - Peak, not average: Find the stage with maximum memory allocation
 - Document assumptions: Clearly state what you include/exclude (e.g., gradient storage, optimizer states)
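
As a small illustration of the precision requirement, the sketch below converts a parameter count into resident weight memory for a few dtypes. The helper name and the dtype table are assumptions made for this example; the 860M figure is the approximate UNet parameter count used later in this notebook.

```python
# Bytes per stored parameter for common inference dtypes.
BYTES_PER_SCALAR = {"fp32": 4, "fp16": 2, "bf16": 2}


def weight_bytes(param_count, dtype="fp16"):
    """Memory needed to hold `param_count` parameters in the given dtype."""
    return param_count * BYTES_PER_SCALAR[dtype]


unet_params = 860_000_000  # approximate; verify against the loaded model
for dtype in ("fp32", "fp16"):
    print(dtype, f"{weight_bytes(unet_params, dtype) / 1e9:.2f} GB")
```

Weights are independent of `H`, `W`, and `prompt_length`, so they form the constant term of the equation.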

## Deliverables
 - Equation with explanation of each term
 - Derivation notes showing how you arrived at each component
 - Validation (optional but encouraged): Compare equation predictions against actual `nvidia-smi` measurements using the provided test code
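
For the validation step, one possible approach is sketched below. It reads PyTorch's own peak counters around a single run and reports the signed relative error of a prediction. Both helper names are hypothetical. `torch` is imported lazily so the error helper works on a machine without a GPU. Note that `nvidia-smi` also counts the CUDA context and cached-but-free blocks, so it will typically read higher than `max_memory_allocated`.

```python
def relative_error(predicted_bytes, measured_bytes):
    """Signed relative error of a prediction against a measurement."""
    return (predicted_bytes - measured_bytes) / measured_bytes


def measure_peak_bytes(run_inference):
    """Run a zero-argument callable on the GPU and return PyTorch's peak stats in bytes."""
    import torch  # lazy import: only needed when a CUDA device is actually used

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    run_inference()
    torch.cuda.synchronize()
    return {
        "allocated": torch.cuda.max_memory_allocated(),  # tensors actually live at peak
        "reserved": torch.cuda.max_memory_reserved(),    # includes allocator cache
    }
```

A typical call would wrap one pipeline invocation, for example `measure_peak_bytes(lambda: pipeline(prompt, image=init_image, guidance_scale=5.0))`, and then pass the `"allocated"` value to `relative_error` together with the formula's prediction.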

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [11]:
# pip install torch torchvision 'diffusers[torch]' transformers accelerate

import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import make_image_grid, load_image

pipeline = AutoPipelineForImage2Image.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
# pipeline = AutoPipelineForImage2Image.from_pretrained(
#     "stable-diffusion-v1-5/stable-diffusion-v1-5",
#     torch_dtype=torch.float32,   # CPU requires FP32
#     use_safetensors=True
# ).to("cpu")

pipeline = pipeline.to("cuda" if torch.cuda.is_available() else "cpu")

# Uncomment this if you have limited GPU vRAM (note that this assignment can be done without any GPU).
# pipeline.enable_model_cpu_offload()

# Uncomment the following line only if xFormers is installed and you are on PyTorch < 2.0
# pipeline.enable_xformers_memory_efficient_attention()

# prepare image
img_src = [{
    "url": "/content/balloon--low-res.jpeg",
    "prompt": "aerial view, colorful hot air balloon, lush green forest canopy, springtime, warm climate, vibrant foliage, soft sunlight, gentle shadow, white birds flying alongside, harmony, freedom, bright natural colors, serene atmosphere, highly detailed, realistic, photorealistic, cinematic lighting"
}, {
    'url': "/content/bench--high-res.jpg",
    'prompt': "photorealistic, high resolution, realistic lighting, natural shadows, detailed textures, lush green grass, wooden bench with grain detail, expansive valley, agricultural fields, blue-toned mountains, fluffy cumulus clouds, wispy cirrus clouds, bright blue sky, clear sunny day, soft sunlight, tranquil atmosphere, cinematic realism"
}, {
    'url': "/content/groceries--low-res.jpg",
    'prompt': "cartoon style, bold outlines, simplified shapes, vibrant colors, playful atmosphere, exaggerated proportions, stylized SUV trunk, whimsical paper grocery bags, fresh produce with bright highlights, baguette with cartoon detail, cheerful parking area, greenery with simplified textures, sunny day, lighthearted mood, 2D illustration, animated landscape aesthetic"
}, {
    'url': "/content/truck--high-res.jpg",
    'prompt': "Michelangelo style, Renaissance painting, classical composition, rich earthy tones, detailed brushwork, divine atmosphere, expressive lighting, monumental presence, artistic grandeur, fresco-inspired texture, high contrast shadows, timeless aesthetic"
}]

results = list()

# This for loop demonstrates that the model's vRAM usage depends
# on image size and prompt length (among other factors). You may observe the
# vRAM usage while the model is running by executing the following command
# in a separate terminal and monitoring the changes in vRAM usage:
#    ```shell
#    watch -n 1.0 nvidia-smi
#    ```
#
# You may modify this for loop according to your needs.
for _src in img_src:
    init_image = load_image(_src.get('url'))
    prompt = _src.get('prompt')

    # pass prompt and image to pipeline
    image = pipeline(prompt, image=init_image, guidance_scale=5.0).images[0]
    results.append(make_image_grid([init_image, image], rows=1, cols=2))

results[0].show()


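Two pipeline defaults matter when turning the run above into formula inputs. Both are stated here as assumptions about the diffusers Stable Diffusion pipelines and are worth confirming against the installed version. First, with `guidance_scale > 1`, classifier-free guidance feeds the unconditional and conditional inputs to the UNet as one batch, doubling its batch size. Second, the prompt is truncated and padded to the CLIP tokenizer's maximum length, so the text length seen by the UNet does not change with prompt wording. The helper name and the constant below are hypothetical.

```python
CLIP_MAX_TOKENS = 77  # assumed tokenizer.model_max_length for this checkpoint


def effective_unet_inputs(prompt_tokens, num_images=1, guidance_scale=5.0, pad_to_max=True):
    """Return the batch size and text length the UNet sees under the assumed defaults."""
    # Classifier-free guidance stacks unconditional + conditional inputs in one batch.
    unet_batch = num_images * (2 if guidance_scale > 1.0 else 1)
    # Prompts are truncated to the tokenizer limit and, by assumption, padded up to it.
    text_tokens = CLIP_MAX_TOKENS if pad_to_max else min(prompt_tokens, CLIP_MAX_TOKENS)
    return {"unet_batch": unet_batch, "text_tokens": text_tokens}


print(effective_unet_inputs(prompt_tokens=30))
```

If these assumptions hold, activation and attention terms should be evaluated with a batch of 2 for the loop above, and `prompt_length` acts as a constant rather than a free variable.
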
## Your Task
Derive a formula `f(H, W, prompt_length)` that returns the estimated peak vRAM:

In [8]:
##################### My Implementation #####################
def f(h: int, w: int, prompt_length: int, **kwargs):
    """
    Estimate peak VRAM usage for Stable Diffusion v1.x inference (FP16).
    This follows a slightly simplified analytic approach based on
    UNet/VAE/CLIP sizes and attention scaling.
    """

    # ---------------------- basic configurable knobs ----------------------
    # Most users won't change these, but exposing them makes the function reusable.
    B = kwargs.get("bytes_per_scalar", 2)             # FP16 → 2 bytes
    vae_ds = kwargs.get("vae_downsample", 8)          # SD uses 8× downsample for latents (renamed so it does not shadow f)
    P_unet = kwargs.get("unet_params", 860e6)         # typical SD UNet param count
    P_clip = kwargs.get("clip_params", 123e6)         # CLIP text encoder params
    P_vae = kwargs.get("vae_params", 35e6)            # VAE params (approx; verify against the loaded model)
    C0 = kwargs.get("base_channels", 320)             # UNet base channel count
    level_muls = kwargs.get("channel_muls", [1, 2, 4, 8])  # UNet multipliers per stage
    include_attn_n2 = kwargs.get("include_attn_n2", True)  # whether to include n^2 attention
    k_overhead = kwargs.get("overhead_factor", 0.12)        # overhead for CUDA buffers, etc.
    D_text = kwargs.get("text_dim", 768)                     # CLIP embedding dim
    batch = kwargs.get("batch", 1)                           # UNet batch (classifier-free guidance doubles it)

    # ---------------------- latent resolution ----------------------
    # images get compressed to (H/8 x W/8) by the VAE encoder.
    h_l = max(1, h // vae_ds)
    w_l = max(1, w // vae_ds)
    S = h_l * w_l       # number of spatial tokens in latent
    L = prompt_length   # number of CLIP tokens processed

    # ---------------------- model weights (always resident) ----------------------
    # straightforward: total FP16 parameter memory
    mem_weights = B * (P_unet + P_clip + P_vae)

    # ---------------------- UNet stored activations ----------------------
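    # each downsampling stage halves both spatial dims → S / 4^i tokens at level i
    S_lvls = [max(1, S // (4**i)) for i in range(len(level_muls))]
    # channel count per level; the default multipliers give 320, 640, 1280, 2560
    # (SD v1.5's own config uses 320, 640, 1280, 1280, so the deepest level is overestimated)
    C_lvls = [C0 * m for m in level_muls]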

    # sum of feature map sizes that stay alive due to skip connections
    act_scalars = 0
    for c, s_ in zip(C_lvls, S_lvls):
        act_scalars += batch * c * s_
    mem_acts = B * act_scalars

    # ---------------------- attention temporary buffers ----------------------
    # for each level, estimate the temporary memory needed in attention ops:
    #   Q, K, V, output projections → around 4 * d * n
    #   cross-attn scores           → n * L
    #   self-attn scores            → batch * s^2 (optional; fused kernels avoid materializing it)
    # note: scores are counted for a single head; multi-head attention scales them by the head count
    att_sizes = []
    for c, s_ in zip(C_lvls, S_lvls):
        n = batch * s_          # total spatial tokens across the batch
        d = c
        qkv_out = 4 * d * n
        cross = n * L
        # each batch element attends only within itself: batch * s^2, not (batch * s)^2
        self_mat = batch * s_ * s_ if include_attn_n2 else 0
        att_sizes.append(B * (qkv_out + cross + self_mat))

    mem_attn_peak = max(att_sizes) if att_sizes else 0

    # ---------------------- CLIP activations ----------------------
    # small relative to UNet, but included for completeness
    mem_clip = B * (L * D_text + L * L)

    # ---------------------- VAE activations ----------------------
    # rough estimate (not a strict bound): one full-resolution feature map with
    # vae_act_channels channels held during decoding
    vae_feat_ch = kwargs.get("vae_act_channels", 64)
    mem_vae = B * batch * vae_feat_ch * h * w

    # ---------------------- sum before overhead ----------------------
    mem_pre = mem_weights + mem_acts + mem_attn_peak + mem_clip + mem_vae

    # ---------------------- framework overhead ----------------------
    mem_total = mem_pre * (1 + k_overhead)

    # return bytes + GB for readability along with breakdown
    return {
        "bytes": mem_total,
        "GB": mem_total / (1024**3),
        "breakdown": {
            "weights_bytes": mem_weights,
            "activations_bytes": mem_acts,
            "attention_peak_bytes": mem_attn_peak,
            "clip_bytes": mem_clip,
            "vae_bytes": mem_vae
        }
    }
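
Written out as an equation, the function above computes the following for the default batch of 1. The symbol names are chosen here for readability and are not part of the original code: $B$ is bytes per scalar, $m_i$ are the channel multipliers, $C_{\text{vae}}$ is `vae_act_channels`, and $k$ is the overhead factor.

```latex
\[
\begin{aligned}
h_\ell &= \max\!\left(1, \lfloor H/8 \rfloor\right), \quad
w_\ell = \max\!\left(1, \lfloor W/8 \rfloor\right), \quad
S = h_\ell\, w_\ell, \quad L = \texttt{prompt\_length} \\
S_i &= \max\!\left(1, \lfloor S/4^{i} \rfloor\right), \quad
C_i = C_0\, m_i, \qquad i = 0, \dots, 3 \\
M_{\text{weights}} &= B\,\left(P_{\text{unet}} + P_{\text{clip}} + P_{\text{vae}}\right) \\
M_{\text{acts}} &= B \sum_{i} C_i\, S_i \\
M_{\text{attn}} &= \max_{i}\; B\left(4\, C_i S_i + S_i L + S_i^{2}\right) \\
M_{\text{clip}} &= B\left(L\, D_{\text{text}} + L^{2}\right) \\
M_{\text{vae}} &= B\, C_{\text{vae}}\, H\, W \\
M_{\text{peak}} &= (1 + k)\left(M_{\text{weights}} + M_{\text{acts}} + M_{\text{attn}} + M_{\text{clip}} + M_{\text{vae}}\right)
\end{aligned}
\]
```

With the defaults $B = 2$, $C_0 = 320$, $m = (1, 2, 4, 8)$, $D_{\text{text}} = 768$, $C_{\text{vae}} = 64$ and $k = 0.12$, only $M_{\text{acts}}$, $M_{\text{attn}}$ and $M_{\text{vae}}$ depend on the image size, and only $M_{\text{attn}}$ and $M_{\text{clip}}$ depend on the prompt length.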


In [4]:
print(f(512, 512, 77))


{'bytes': 2373583226.5600004, 'GB': 2.2105716416239742, 'breakdown': {'weights_bytes': 2036000000.0, 'activations_bytes': 4915200, 'attention_peak_bytes': 44670976, 'clip_bytes': 130130, 'vae_bytes': 33554432}}


In [5]:
print(f(256, 256, 50))


{'bytes': 2296582624.0, 'GB': 2.1388592422008514, 'breakdown': {'weights_bytes': 2036000000.0, 'activations_bytes': 1228800, 'attention_peak_bytes': 4820992, 'clip_bytes': 81800, 'vae_bytes': 8388608}}


In [6]:
print(f(4, 4, 50))

{'bytes': 2280447713.6000004, 'GB': 2.1238324359059337, 'breakdown': {'weights_bytes': 2036000000.0, 'activations_bytes': 9600, 'attention_peak_bytes': 20582, 'clip_bytes': 81800, 'vae_bytes': 2048}}


## Tips
- No GPU is needed to accomplish this task; analyzing the code and architecture is sufficient
- Use PyTorch documentation and model architecture inspection
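
As one way to act on the inspection tip, the sketch below sums parameter memory straight from a loaded module instead of relying on quoted parameter counts. The two helper names are hypothetical. The code relies only on the standard `parameters()`, `numel()` and `element_size()` methods of PyTorch modules and tensors, and on the `unet`, `text_encoder` and `vae` attributes that Stable Diffusion pipelines expose.

```python
def module_bytes(module):
    """Sum numel * element_size over all parameters of a torch.nn.Module-like object."""
    return sum(p.numel() * p.element_size() for p in module.parameters())


def pipeline_weight_report(pipeline, names=("unet", "text_encoder", "vae")):
    """Per-component weight memory in bytes for a loaded diffusers pipeline."""
    return {name: module_bytes(getattr(pipeline, name)) for name in names}
```

Calling `pipeline_weight_report(pipeline)` after loading the model gives the exact weight term for the chosen dtype, and it runs on CPU as well.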

## Evaluation Criteria
- Correctness: Formula accounts for major memory consumers
- Completeness: All image-dependent and prompt-dependent factors identified
- Rigor: Derivation shows understanding of PyTorch memory model and diffusion architecture
- Clarity: Equation is readable and well-documented