# GPU-vRAM Usage Estimation for Diffusion Models
## Objective
Derive an analytical equation to estimate peak vRAM usage during inference with the `stable-diffusion-v1-5/stable-diffusion-v1-5` model for arbitrary input image sizes.

## Background
vRAM consumption during diffusion model inference differs significantly from model size on disk. Peak memory depends on:
 - Model weights (fixed)
 - Intermediate activations (vary with image dimensions and prompt length)
 - Framework overhead (CUDA kernels, workspace buffers)
 - Attention mechanism memory scaling (O(N²) with sequence length)

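The estimate should be expressed in terms of:
 - `H`, `W` = input image height and width, in pixels
 - `prompt_length` = tokenized prompt length (the CLIP tokenizer used by this pipeline pads or truncates prompts to 77 tokens)
 - Any additional factors you identify as affecting vRAM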
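
To make the quadratic attention term concrete, the sketch below counts the bytes needed to materialize a full self-attention score matrix at latent resolution. It assumes an 8x VAE downsampling factor, FP16 scores, and 8 attention heads; the head count and the helper name `attn_score_bytes` are illustrative assumptions, and memory-efficient attention kernels avoid materializing this matrix at all.

```python
def attn_score_bytes(h: int, w: int, heads: int = 8, bytes_per_elem: int = 2) -> int:
    """Bytes for one naive (heads x N x N) self-attention score matrix."""
    # One token per latent pixel, assuming the VAE downsamples by 8 per side.
    n_tokens = (h // 8) * (w // 8)
    return heads * n_tokens * n_tokens * bytes_per_elem


for side in (512, 1024):
    mib = attn_score_bytes(side, side) / 1024**2
    print(f"{side}x{side}: {mib:.0f} MiB")
```

Doubling each image side multiplies the token count by 4 and the score matrix by 16, which is why this term can dominate at high resolutions when attention is computed naively.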

## Requirements
 - Analyze the architecture: Understand UNet, VAE, CLIP text encoder, and how tensors flow through the pipeline
 - Account for precision: Assume `FP16` (2 bytes/parameter)
 - Model fully on GPU: Ignore pipeline.enable_model_cpu_offload() in your equation
 - Peak, not average: Find the stage with maximum memory allocation
 - Document assumptions: Clearly state what you include/exclude (e.g., gradient storage, optimizer states)
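
The "peak, not average" requirement can be sketched as follows: memory that stays resident (such as weights) is added to the largest single-stage transient, rather than to the sum of all stages. The helper name `peak_bytes` and the stage figures are hypothetical placeholders, and the sketch assumes each stage's temporaries are released before the next stage starts.

```python
def peak_bytes(resident: int, transient_by_stage: dict) -> int:
    """Peak = always-resident memory + the largest single-stage transient."""
    return resident + max(transient_by_stage.values())


# Hypothetical placeholder figures, for illustration only (not measurements).
weights = 2_000
stages = {"text_encode": 10, "unet_step": 500, "vae_decode": 800}

print(peak_bytes(weights, stages))        # resident + largest stage
print(weights + sum(stages.values()))     # what summing every stage would give
```

Summing every stage gives an upper bound; taking the maximum is tighter when the stages do not overlap in time.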

## Deliverables
 - Equation with explanation of each term
 - Derivation notes showing how you arrived at each component
 - Validation (optional but encouraged): Compare equation predictions against actual nvidia-smi measurements using the provided test code
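
For the optional validation step, one possible approach is to read PyTorch's own peak counter around a pipeline call and compare it with the prediction. The sketch below uses `torch.cuda.reset_peak_memory_stats()` and `torch.cuda.max_memory_allocated()`; the helper names `measure_peak_bytes` and `relative_error` are hypothetical. Expect `nvidia-smi` to read higher than this counter, since it also includes the CUDA context and blocks cached by the allocator.

```python
def measure_peak_bytes(run):
    """Peak bytes held by PyTorch tensors while `run()` executes, or None without CUDA."""
    import torch  # imported lazily so relative_error() below works without PyTorch

    if not torch.cuda.is_available():
        return None
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()  # peak restarts from what is currently allocated
    run()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()


def relative_error(predicted_bytes: float, measured_bytes: float) -> float:
    """Signed relative error of the prediction against the measurement."""
    return (predicted_bytes - measured_bytes) / measured_bytes
```

With the pipeline and estimator from this notebook, a call might look like `measure_peak_bytes(lambda: pipeline(prompt, image=init_image))`, with the result compared against `f(h, w, 77)["M_peak_bytes"]`.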

In [None]:
import os

# Create the 'data' directory if it doesn't exist
os.makedirs('data', exist_ok=True)

# Unzip data.zip into the 'data' directory
!unzip -o data.zip -d data

Archive:  data.zip
  inflating: data/balloon--low-res.jpeg  
  inflating: data/truck--high-res.jpg  
  inflating: data/bench--high-res.jpg  
  inflating: data/groceries--low-res.jpg  


In [None]:
# pip install torch torchvision "diffusers[torch]" transformers accelerate

import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import make_image_grid, load_image

pipeline = AutoPipelineForImage2Image.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipeline = pipeline.to("cuda" if torch.cuda.is_available() else "cpu")

# Uncomment this if you have limited GPU vRAM (although, this assignment can be done without any GPU use!)
# pipeline.enable_model_cpu_offload()

# Uncomment the following line only if xFormers is installed and you are on PyTorch < 2.0
# pipeline.enable_xformers_memory_efficient_attention()

# prepare image
img_src = [{
    "url": "./data/balloon--low-res.jpeg",
    "prompt": "aerial view, colorful hot air balloon, lush green forest canopy, springtime, warm climate, vibrant foliage, soft sunlight, gentle shadow, white birds flying alongside, harmony, freedom, bright natural colors, serene atmosphere, highly detailed, realistic, photorealistic, cinematic lighting"
}, {
    'url': "./data/bench--high-res.jpg",
    'prompt': "photorealistic, high resolution, realistic lighting, natural shadows, detailed textures, lush green grass, wooden bench with grain detail, expansive valley, agricultural fields, blue-toned mountains, fluffy cumulus clouds, wispy cirrus clouds, bright blue sky, clear sunny day, soft sunlight, tranquil atmosphere, cinematic realism"
}, {
    'url': "./data/groceries--low-res.jpg",
    'prompt': "cartoon style, bold outlines, simplified shapes, vibrant colors, playful atmosphere, exaggerated proportions, stylized SUV trunk, whimsical paper grocery bags, fresh produce with bright highlights, baguette with cartoon detail, cheerful parking area, greenery with simplified textures, sunny day, lighthearted mood, 2D illustration, animated landscape aesthetic"
}, {
    'url': "./data/truck--high-res.jpg",
    'prompt': "Michelangelo style, Renaissance painting, classical composition, rich earthy tones, detailed brushwork, divine atmosphere, expressive lighting, monumental presence, artistic grandeur, fresco-inspired texture, high contrast shadows, timeless aesthetic"
}]

results = list()

# This loop demonstrates that the model's vRAM usage depends on image size
# and prompt length (among other factors). To observe vRAM usage while the
# model is running, execute the following command in a separate terminal:
#
#    watch -n 1.0 nvidia-smi
#
# You may modify this loop according to your needs.
for _src in img_src:
    init_image = load_image(_src.get('url'))
    prompt = _src.get('prompt')

    # pass prompt and image to pipeline
    image = pipeline(prompt, image=init_image, guidance_scale=5.0).images[0]
    results.append(make_image_grid([init_image, image], rows=1, cols=2))

results[0].show()

safety_checker/model.fp16.safetensors:   0%|          | 0.00/608M [00:00<?, ?B/s]

text_encoder/model.fp16.safetensors:   0%|          | 0.00/246M [00:00<?, ?B/s]

unet/diffusion_pytorch_model.fp16.safete(…):   0%|          | 0.00/1.72G [00:00<?, ?B/s]

vae/diffusion_pytorch_model.fp16.safeten(…):   0%|          | 0.00/167M [00:00<?, ?B/s]


## Your Task
Derive a formula of the form `peak_vRAM = f(H, W, prompt_length, ...)` and implement it as a Python function:


In [None]:
def f(h: int, w: int, prompt_length: int, *,
      # Parameter counts (FP16 params)
      P_unet: float = 860e6,
      P_vae: float = 35e6,
      P_clip: float = 87e6,

      # Memory precision
      bytes_per_param: int = 2,   # FP16 = 2 bytes

      # UNet activation approximation
      C_peak: int = 1024,         # assumed effective channel width (the SD v1.5 UNet bottleneck itself has 1280)
      L_act: int = 4,             # effective activation layers

      # Text encoder & attention
      d_text: int = 768,
      H_att: int = 12,
      F_text: float = 3.0,        # multiplier for text hidden activations

      # System overhead
      overhead_frac: float = 0.15,
      const_buffer_bytes: int = 300 * 1024**2,

      # Optional spatial self-attention (quadratic)
      use_spatial_self_attn: bool = False,
      spatial_attn_factor: float = 0.20
     ) -> dict:
    """
    Estimate peak vRAM (in bytes and GiB) for Stable Diffusion v1.5 inference.

    Sums model weights, UNet activations, text-encoder activations, and
    attention buffers, then applies a fractional and a constant overhead.
    """

    # -------------------------------------
    # Latent resolution (fixed VAE factor 8)
    # -------------------------------------
    def ceil_div(x, y):
        return (x + y - 1) // y

    lat_h = ceil_div(h, 8)
    lat_w = ceil_div(w, 8)
    N_lat = lat_h * lat_w

    b = bytes_per_param

    # -------------------------------------
    # 1. Model weights (UNet + VAE + CLIP)
    # -------------------------------------
    M_weights = b * (P_unet + P_vae + P_clip)

    # -------------------------------------
    # 2. UNet activations (spatial + channels)
    # -------------------------------------
    # Approx: N_lat * C_peak * L_act
    M_unet_acts = b * N_lat * C_peak * L_act

    # -------------------------------------
    # 3. Text encoder memory
    # -------------------------------------
    M_text = b * prompt_length * d_text * F_text

    # -------------------------------------
    # 4. Cross-attention memory (UNet ↔ text)
    # -------------------------------------
    # Score matrix (N_lat × prompt_len)
    M_att_scores = b * (N_lat * prompt_length)

    # Key/Value (2 × prompt_len × d_text)
    M_att_kv = b * (2 * prompt_length * d_text)

    # Total cross-attention memory. H_att is accepted for API compatibility but
    # does not affect the result: scores are counted for a single head.
    M_attention = M_att_scores + M_att_kv

    # -------------------------------------
    # 5. Optional spatial self-attention (O(N_lat²))
    # -------------------------------------
    M_self_attn = (
        b * (N_lat ** 2) * spatial_attn_factor
        if use_spatial_self_attn
        else 0
    )

    # -------------------------------------
    # Base memory
    # -------------------------------------
    M_base = (
        M_weights +
        M_unet_acts +
        M_text +
        (M_attention + M_self_attn)
    )

    # -------------------------------------
    # Peak VRAM = base + overhead
    # -------------------------------------
    M_peak = M_base * (1 + overhead_frac) + const_buffer_bytes

    return {
        "lat_h": lat_h,
        "lat_w": lat_w,
        "N_lat": N_lat,

        "M_weights_bytes": int(M_weights),
        "M_unet_acts_bytes": int(M_unet_acts),
        "M_text_bytes": int(M_text),
        "M_attention_bytes": int(M_attention + M_self_attn),

        "M_base_bytes": int(M_base),
        "M_peak_bytes": int(M_peak),
        "M_peak_gib": float(M_peak) / (1024**3)
    }
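

As a cross-check on the default parameter counts, the sketch below divides the FP16 checkpoint sizes shown in the download log by 2 bytes per parameter. It assumes the progress bars report decimal units (1 M = 10^6 bytes) and that file-format overhead is negligible. Under those assumptions the UNet comes out at about 860M parameters, matching `P_unet`, while the VAE (about 83.5M) and text encoder (about 123M) come out larger than the `P_vae` and `P_clip` defaults. The safety checker (about 304M), which this pipeline also downloads, has no term in `f`.

```python
BYTES_PER_PARAM = 2  # FP16

# Checkpoint sizes as printed in the download log (assumed decimal units).
checkpoint_bytes = {
    "unet": 1.72e9,
    "vae": 167e6,
    "text_encoder": 246e6,
    "safety_checker": 608e6,
}

# Approximate parameter count = file size / bytes per parameter.
approx_params = {name: size / BYTES_PER_PARAM for name, size in checkpoint_bytes.items()}

for name, count in approx_params.items():
    print(f"{name:15s} ~{count / 1e6:.1f}M parameters")
```

These counts can be passed to the estimator through its keyword arguments, for example `f(h, w, 77, P_vae=83.5e6, P_clip=123e6)`.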


## Tips
- No GPU is needed to accomplish this task; it can be done by analyzing the code and architecture
- Use PyTorch documentation and model architecture inspection

## Evaluation Criteria
- Correctness: Formula accounts for major memory consumers
- Completeness: All image-dependent and prompt-dependent factors identified
- Rigor: Derivation shows understanding of PyTorch memory model and diffusion architecture
- Clarity: Equation is readable and well-documented

In [None]:
from PIL import Image, ImageOps
import os

def test_image(image_path, prompt_length=77):
    """
    Loads an image safely, corrects orientation, converts to RGB,
    extracts (H, W), and sends them into the VRAM estimator f().
    """

    # Validate path
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Load image
    img = Image.open(image_path)

    # Fix EXIF orientation if needed
    try:
        img = ImageOps.exif_transpose(img)
    except Exception:
        pass  # ignore silently (safe fallback)

    # Ensure standard RGB
    img = img.convert("RGB")

    # PIL gives (width, height)
    w, h = img.size

    # Call the estimator
    return f(h, w, prompt_length)


image_path = "/content/data/bench--high-res.jpg"
result = test_image(image_path)
result


{'lat_h': 256,
 'lat_w': 256,
 'N_lat': 65536,
 'M_weights_bytes': 1964000000,
 'M_unet_acts_bytes': 536870912,
 'M_text_bytes': 354816,
 'M_attention_bytes': 10329088,
 'M_base_bytes': 2511554816,
 'M_peak_bytes': 3202860838,
 'M_peak_gib': 2.9828966021537777}
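

The reported peak can be reproduced by hand from the components above, which also shows how much each term contributes. The sketch below copies the byte counts from the output and reuses the estimator's default overhead settings (15% fractional overhead plus a 300 MiB constant buffer); the variable names are illustrative.

```python
# Component sizes copied from the estimator output above (bytes).
components = {
    "weights": 1964000000,
    "unet_acts": 536870912,
    "text": 354816,
    "attention": 10329088,
}
overhead_frac = 0.15
const_buffer_bytes = 300 * 1024**2

base = sum(components.values())
peak = base * (1 + overhead_frac) + const_buffer_bytes

print(base, int(peak))
for name, value in components.items():
    print(f"{name:10s} {value / peak:6.1%} of peak")
```

With these defaults the weights account for most of the estimate, and the text and cross-attention terms are small by comparison.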