In [1]:
import torch
import torch.nn as nn
from model import HoloConfig, HoloModel, HoloForCausalLM, \
                  HoloBlock, HoloAttentionV1, HoloAttentionV2
from transformers import PretrainedConfig
import gc
import contextlib

In [2]:
def format_params(num):
    """Format a parameter count as a short human-readable string.

    Args:
        num: Parameter count (expected to be a non-negative number).

    Returns:
        str: ``"<x.xx> B"`` for counts of at least one billion,
        ``"<x.xx> M"`` for at least one million, ``"<x.xx> K"`` for at least
        one thousand, and the number itself (no suffix) for anything smaller.
    """
    # Checked from largest to smallest so the first matching unit wins.
    units = (
        (1_000_000_000, "B"),
        (1_000_000, "M"),
        (1_000, "K"),
    )
    for threshold, suffix in units:
        if num >= threshold:
            return f"{num / threshold:.2f} {suffix}"
    # Below one thousand a fractional "K" (e.g. "0.50 K") is harder to read
    # than the number itself, so print it as-is.
    return f"{num}"

In [3]:
def get_peak_memory_mb():
    """Report the CUDA allocator's peak allocated memory, converted to MB.

    The value comes from ``torch.cuda.max_memory_allocated()``, so it covers
    the period since the peak statistics were last reset.
    """
    bytes_per_mb = 1024 * 1024
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / bytes_per_mb

def count_parameters(model):
    """Return the total number of elements over all of ``model``'s parameters.

    The result is a plain integer count (not a formatted string), obtained by
    adding up ``numel()`` for everything yielded by ``model.parameters()``.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

In [4]:
def get_detailed_stats(model, config, name):
    """Collect the statistics shown in the model-size comparison table.

    Args:
        model: Model exposing ``parameters()`` and ``get_input_embeddings()``.
        config: Config exposing ``num_hidden_layers``, ``d_model`` and ``hd_dim``.
        name: Preset name; it is upper-cased and prefixed with ``"Holo-"``.

    Returns:
        dict: Keys ``"Name"``, ``"Layers"``, ``"Model Dim"``, ``"Holo Dim"``,
        ``"Total Params"`` and ``"Non-Embed Params"`` (both formatted strings),
        and ``"Exact Count"`` (the raw integer total).
    """
    total_params = count_parameters(model)

    # Parameters owned by the input-embedding module only.
    embed_params = count_parameters(model.get_input_embeddings())

    # Everything that is not part of the input embedding.
    non_embed_params = total_params - embed_params

    return {
        "Name": f"Holo-{name.upper()}",
        "Layers": config.num_hidden_layers,
        "Model Dim": config.d_model,
        "Holo Dim": config.hd_dim,
        "Total Params": format_params(total_params),
        "Non-Embed Params": format_params(non_embed_params),
        "Exact Count": total_params
    }

In [5]:
# Parameter-count table for the version-2 presets.
print(f"{'='*85}")
print(f"{'Model':<15} | {'Layers':<6} | {'d_model':<8} | {'hd_dim':<8} | {'Total Params':<12} | {'Active Params':<12}")
print(f"{'-'*85}")

sizes = ["small", "medium", "large"]

for size in sizes:
    config = HoloConfig.from_preset(size, use_version = 2)

    # Only parameter shapes are needed for counting, so try to build the model
    # on the "meta" device first.
    try:
        with torch.device("meta"):
            model = HoloForCausalLM(config)
    except Exception:
        # Fallback to a regular build if the meta device fails (older pytorch).
        # `except Exception` instead of a bare `except:` lets KeyboardInterrupt
        # and SystemExit propagate instead of silently starting a second build.
        model = HoloForCausalLM(config)

    # Gather and print one table row for this preset
    stats = get_detailed_stats(model, config, size)
    print(f"{stats['Name']:<15} | {stats['Layers']:<6} | {stats['Model Dim']:<8} | {stats['Holo Dim']:<8} | {stats['Total Params']:<12} | {stats['Non-Embed Params']:<12}")

    print(f"{'='*85}")

print("\n* 'Active Params' excludes the vocabulary embedding matrix (which is static lookup).")

Model           | Layers | d_model  | hd_dim   | Total Params | Active Params
-------------------------------------------------------------------------------------
Holo-SMALL      | 12     | 768      | 3072     | 180.26 M     | 141.66 M    
Holo-MEDIUM     | 24     | 1024     | 8192     | 857.04 M     | 805.58 M    
Holo-LARGE      | 36     | 1280     | 10240    | 1.95 B       | 1.89 B      

* 'Active Params' excludes the vocabulary embedding matrix (which is static lookup).


In [6]:
# Parameter-count table for the version-1 presets.
print(f"{'='*85}")
print(f"{'Model':<15} | {'Layers':<6} | {'d_model':<8} | {'hd_dim':<8} | {'Total Params':<12} | {'Active Params':<12}")
print(f"{'-'*85}")

sizes = ["small", "medium", "large"]

for size in sizes:
    config = HoloConfig.from_preset(size, use_version = 1)

    # Only parameter shapes are needed for counting, so try to build the model
    # on the "meta" device first.
    try:
        with torch.device("meta"):
            model = HoloForCausalLM(config)
    except Exception:
        # Fallback to a regular build if the meta device fails (older pytorch).
        # `except Exception` instead of a bare `except:` lets KeyboardInterrupt
        # and SystemExit propagate instead of silently starting a second build.
        model = HoloForCausalLM(config)

    # Gather and print one table row for this preset
    stats = get_detailed_stats(model, config, size)
    print(f"{stats['Name']:<15} | {stats['Layers']:<6} | {stats['Model Dim']:<8} | {stats['Holo Dim']:<8} | {stats['Total Params']:<12} | {stats['Non-Embed Params']:<12}")

    print(f"{'='*85}")

print("\n* 'Active Params' excludes the vocabulary embedding matrix (which is static lookup).")

Model           | Layers | d_model  | hd_dim   | Total Params | Active Params
-------------------------------------------------------------------------------------
Holo-SMALL      | 12     | 768      | 3072     | 180.26 M     | 141.66 M    
Holo-MEDIUM     | 24     | 1024     | 8192     | 857.04 M     | 805.58 M    
Holo-LARGE      | 36     | 1280     | 10240    | 1.95 B       | 1.89 B      

* 'Active Params' excludes the vocabulary embedding matrix (which is static lookup).


### Profiling GPU Memory Usage

In [7]:
def profile_config(size_name, config, batch_size=4, seq_len=1024,
                   use_autocast = True, use_checkpointing = False):
    """Measure peak CUDA memory for one inference pass and one training step.

    The model is built from ``config`` and moved to the GPU. It is then run
    once in eval mode under ``torch.no_grad()``, and once in train mode with
    a forward and backward pass, using the input ids as labels. Peak memory
    statistics are reset before each phase.

    Args:
        size_name: Preset name, used for log lines and the ``"size"`` entry.
        config: Config passed to ``HoloForCausalLM``; ``vocab_size`` is read
            from it to generate random input ids.
        batch_size: Number of sequences in the dummy batch.
        seq_len: Length of each dummy sequence.
        use_autocast: Run forward passes under ``torch.amp.autocast``
            (bfloat16 when supported, otherwise float16) instead of FP32.
        use_checkpointing: Call ``model.gradient_checkpointing_enable()``
            before profiling.

    Returns:
        dict | None: ``{"size", "params", "inference_mb", "training_mb"}`` on
        success. ``None`` when no GPU is available, on CUDA out-of-memory, or
        on any other exception (the error is printed, not raised).
    """
    print(f"\n--- Profiling Holo-{size_name.upper()} (Ckpt={use_checkpointing}) ---")

    # Check for a GPU before calling any other torch.cuda API, so a CPU-only
    # host gets the message below rather than an error from those calls.
    if not torch.cuda.is_available():
        print("❌ GPU not detected. Cannot profile CUDA memory.")
        return None
    device = torch.device("cuda")

    # 1. Setup Environment
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Define the Context Manager based on the flag
    # Use bfloat16 for modern GPUs (Ampere+), otherwise fallback to float16 usually
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    if use_autocast:
        amp_ctx = torch.amp.autocast(device_type = "cuda", dtype = dtype)
        print(f"   » Precision: Mixed ({dtype})")
    else:
        amp_ctx = contextlib.nullcontext()
        print("   » Precision: FP32 (Full)")

    # Pre-bind everything that may hold GPU memory so the `finally` block can
    # release it no matter where the try body stops.
    model = outputs = loss = input_ids = None

    try:
        # 2. Initialize Model
        print("   Initializing model...")
        model = HoloForCausalLM(config).to(device)

        if use_checkpointing:
            model.gradient_checkpointing_enable()
            print("   » Gradient Checkpointing: ENABLED")

        param_count = count_parameters(model)
        # Peak allocation since the reset above, i.e. while loading the model.
        model_mem = get_peak_memory_mb()
        print(f"   Model Parameters: {format_params(param_count)}")
        print(f"   Static Model VRAM: {model_mem:.2f} MB")

        # 3. Prepare Dummy Data
        input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len)).to(device)

        # --- 4. Profile Inference (Forward Pass) ---
        torch.cuda.reset_peak_memory_stats()
        print(f"  Running Inference (B={batch_size}, L={seq_len})...")
        model.eval() # Ensure eval mode
        with torch.no_grad():
            with amp_ctx: # Applies autocast if enabled
                outputs = model(input_ids)

        inference_peak = get_peak_memory_mb()
        print(f"  Peak Inference VRAM: {inference_peak:.2f} MB")
        # Cleanup inference tensors to get a clean slate for training
        outputs = None
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # --- 5. Profile Training (Forward + Backward) ---
        print(f"   Running Training Step (Forward + Backward)...")

        model.train() # Switch to train mode

        # Forward Pass (Inside Autocast)
        with amp_ctx:
            outputs = model(input_ids, labels=input_ids)
            loss = outputs.loss

        # Backward Pass (Outside Autocast)
        # Note: For bfloat16, GradScaler is usually not needed.
        # If using float16, you would need scaler.scale(loss).backward() here.
        loss.backward()

        training_peak = get_peak_memory_mb()
        print(f"  Peak Training VRAM: {training_peak:.2f} MB")

        return {
            "size": size_name,
            "params": param_count,
            "inference_mb": inference_peak,
            "training_mb": training_peak
        }

    except torch.cuda.OutOfMemoryError:
        print(f"❌ OOM: Holo-{size_name.upper()} is too large for this GPU.")
        return None
    except Exception as e:
        print(f"❌ Error: {e}")
        return None
    finally:
        # Runs on success, OOM and generic errors alike: drop every reference
        # that can pin GPU memory, then return cached blocks to the driver so
        # the next profiling run starts from a clean allocator.
        model = outputs = loss = input_ids = None
        gc.collect()
        torch.cuda.empty_cache()

# Presets to profile, and the list that collects every successful run.
sizes = ["small", "medium", "large"]
results = []

# Workload shape shared by all runs
BATCH_SIZE = 2
SEQ_LEN = 512 # Keep modest for testing

print(f"Global Settings: Batch Size={BATCH_SIZE}, Seq Len={SEQ_LEN}")

for size in sizes:
    # Build the preset config, then profile it
    cfg = HoloConfig.from_preset(size, use_version = 2)
    res = profile_config(size, cfg, batch_size=BATCH_SIZE, seq_len=SEQ_LEN, use_checkpointing = True)

    # A falsy result means the run produced nothing to report; skip it.
    if not res:
        continue
    results.append(res)

# Summary table: one row per successful run
table_rule = "=" * 60
print("\n" + table_rule)
print(f"{'Model Size':<10} | {'Params':<15} | {'Inference (MB)':<15} | {'Training (MB)':<15}")
print("-" * 60)
for r in results:
    cells = (
        f"{r['size'].upper():<10}",
        f"{format_params(r['params']):<15}",
        f"{r['inference_mb']:<15.2f}",
        f"{r['training_mb']:<15.2f}",
    )
    print(" | ".join(cells))
print(table_rule)

Global Settings: Batch Size=2, Seq Len=512

--- Profiling Holo-SMALL (Ckpt=True) ---
   » Precision: Mixed (torch.bfloat16)
   Initializing model...
   » Gradient Checkpointing: ENABLED
   Model Parameters: 180.26 M
   Static Model VRAM: 688.54 MB
  Running Inference (B=2, L=512)...
  Peak Inference VRAM: 1173.75 MB
   Running Training Step (Forward + Backward)...
  Peak Training VRAM: 1710.31 MB

--- Profiling Holo-MEDIUM (Ckpt=True) ---
   » Precision: Mixed (torch.bfloat16)
   Initializing model...
   » Gradient Checkpointing: ENABLED
   Model Parameters: 857.04 M
   Static Model VRAM: 3289.24 MB
  Running Inference (B=2, L=512)...
  Peak Inference VRAM: 5365.47 MB
   Running Training Step (Forward + Backward)...
  Peak Training VRAM: 7212.78 MB

--- Profiling Holo-LARGE (Ckpt=True) ---
   » Precision: Mixed (torch.bfloat16)
   Initializing model...
   » Gradient Checkpointing: ENABLED
   Model Parameters: 1.95 B
   Static Model VRAM: 7540.49 MB
  Running Inference (B=2, L=512)...
 

### Version 2 for testing

In [8]:
# Test Settings
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 2
SEQ_LEN = 512 # Keep modest for testing

print(f"Global Settings: Batch Size={BATCH_SIZE}, Seq Len={SEQ_LEN}")

for size in sizes:
    # Load the preset config
    cfg = HoloConfig.from_preset(size, use_version = 2)
    
    # Run profiler
    res = profile_config(size, cfg, batch_size=BATCH_SIZE, seq_len=SEQ_LEN, use_checkpointing = True)
    if res:
        results.append(res)

# Print Summary Table
print("\n" + "="*60)
print(f"{'Model Size':<10} | {'Params':<15} | {'Inference (MB)':<15} | {'Training (MB)':<15}")
print("-" * 60)
for r in results:
    print(f"{r['size'].upper():<10} | {format_params(r['params']):<15} | {r['inference_mb']:<15.2f} | {r['training_mb']:<15.2f}")
print("="*60)

Global Settings: Batch Size=2, Seq Len=512

--- Profiling Holo-SMALL (Ckpt=True) ---
   » Precision: Mixed (torch.bfloat16)
   Initializing model...
   » Gradient Checkpointing: ENABLED
   Model Parameters: 180.26 M
   Static Model VRAM: 707.66 MB
  Running Inference (B=2, L=512)...
  Peak Inference VRAM: 1183.75 MB
   Running Training Step (Forward + Backward)...
  Peak Training VRAM: 1709.47 MB

--- Profiling Holo-MEDIUM (Ckpt=True) ---
   » Precision: Mixed (torch.bfloat16)
   Initializing model...
   » Gradient Checkpointing: ENABLED
   Model Parameters: 857.04 M
   Static Model VRAM: 3289.24 MB
  Running Inference (B=2, L=512)...
  Peak Inference VRAM: 5365.47 MB
   Running Training Step (Forward + Backward)...
  Peak Training VRAM: 7212.78 MB

--- Profiling Holo-LARGE (Ckpt=True) ---
   » Precision: Mixed (torch.bfloat16)
   Initializing model...
   » Gradient Checkpointing: ENABLED
   Model Parameters: 1.95 B
   Static Model VRAM: 7540.49 MB
  Running Inference (B=2, L=512)...
 

In [9]:
# # Test Settings
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BATCH_SIZE = 2
# SEQ_LEN = 512 # Keep modest for testing

# cfg = HoloConfig.from_preset(size = "small", use_version = 2)
# model = HoloForCausalLM(cfg).to(DEVICE)

# input_ids = torch.randint(0, cfg.vocab_size, (BATCH_SIZE, SEQ_LEN)).to(DEVICE)

# out = model(input_ids)
# print(out.shape)