In [1]:
import os
import time
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoConfig
import torch
import torch.nn as nn
from ShardedFP8ModelLoader import ShardedFP8ModelLoader, FP8Format

In [2]:
# Relative path to the local checkpoint directory; later cells list the
# "model-*.safetensors" shard files inside it.
model_dir = "./Meta-Llama-3.1-8B-Instruct-FP8"

In [3]:
# Gather the loader settings in one mapping, then build the loader from it.
# device_ids are the CUDA indices reported later as "cuda:<id>".
# NOTE(review): the exact effect of memory_efficient and of FP8Format(e4m3=True)
# is defined in the ShardedFP8ModelLoader module, which is not shown here.
loader_settings = {
    "model_dir": model_dir,
    "device_ids": [0, 1],
    "memory_efficient": True,
    "fp8_format": FP8Format(e4m3=True),
}
model_loader = ShardedFP8ModelLoader(**loader_settings)

In [4]:
# Stop with an explicit error when PyTorch cannot see a usable CUDA device;
# the cells that follow query and use CUDA devices directly.
cuda_ready = torch.cuda.is_available()
if not cuda_ready:
    raise RuntimeError("CUDA is not available")

In [5]:
# Report the name and total memory (bytes / 1024**3) of every visible CUDA device.
for gpu_index in range(torch.cuda.device_count()):
    gpu_name = torch.cuda.get_device_name(gpu_index)
    print(f"GPU {gpu_index}: {gpu_name}")
    total_memory = torch.cuda.get_device_properties(gpu_index).total_memory
    print(f"Memory: {total_memory / 1024**3:.2f} GB")

GPU 0: NVIDIA L4
Memory: 21.95 GB
GPU 1: NVIDIA L4
Memory: 21.95 GB


In [6]:
# Collect the checkpoint shards: directory entries named "model-*.safetensors",
# sorted by name so they are processed in a stable order.
shard_files = sorted(
    entry
    for entry in os.listdir(model_dir)
    if entry.startswith("model-") and entry.endswith(".safetensors")
)
print(f"Found shard files: {shard_files}")

Found shard files: ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']


In [7]:
# Merge the tensors of every shard into one name -> tensor dictionary.
# dict.update means a name present in several shards keeps the value from the
# shard loaded last.
shard_paths = [os.path.join(model_dir, shard_name) for shard_name in shard_files]
all_weights = {}
for shard_path in shard_paths:
    print(f"Loading shard: {shard_path}")
    all_weights.update(load_file(shard_path))

Loading shard: ./Meta-Llama-3.1-8B-Instruct-FP8/model-00001-of-00002.safetensors
Loading shard: ./Meta-Llama-3.1-8B-Instruct-FP8/model-00002-of-00002.safetensors


In [8]:
# Measure the loading time.
# CUDA work is queued asynchronously, so each target GPU is synchronized before
# the clock is read on either side of the load; otherwise pending device work
# could fall outside the measured window.
# NOTE(review): assumes model_loader.device_ids holds the CUDA indices passed to
# the constructor ([0, 1]) — confirm against ShardedFP8ModelLoader.
for device_id in model_loader.device_ids:
    torch.cuda.synchronize(device_id)
# perf_counter() is monotonic, so the interval cannot be skewed by wall-clock adjustments.
start_time = time.perf_counter()
model = model_loader.load_model_from_weights(all_weights)
for device_id in model_loader.device_ids:
    torch.cuda.synchronize(device_id)
end_time = time.perf_counter()

Model distributed across 2 GPUs


In [9]:
# Summarize the load: target devices, model class, parameter count (in billions)
# and the elapsed time measured by the previous cell.
target_devices = [f'cuda:{i}' for i in model_loader.device_ids]
print(f"\nModel loaded successfully on devices: {target_devices}")
model_class_name = type(model).__name__
print(f"Model type: {model_class_name}")
parameter_total = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {parameter_total/1e9:.2f}B")
load_seconds = end_time - start_time
print(f"Model loaded in {load_seconds:.2f} seconds")


Model loaded successfully on devices: ['cuda:0', 'cuda:1']
Model type: LlamaForCausalLM
Total parameters: 8.03B
Model loaded in 119.80 seconds


In [10]:
# print("Model keys example:", list(model.state_dict().keys())[:50])
# print("Checkpoint keys example:", list(all_weights.keys())[:50])

In [None]:
import os
import time
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoConfig
import torch
from ShardedFP8ModelLoader import ShardedFP8ModelLoader, FP8Format

# Instrumented version of the loading workflow: repeats the load steps of the
# earlier cells and additionally reports GPU memory before and after the load.


def _report_gpu_memory():
    """Print allocated and reserved CUDA memory (bytes / 1024**2) for every visible GPU."""
    for gpu_index in range(torch.cuda.device_count()):
        print(f"GPU {gpu_index} Memory Allocated: {torch.cuda.memory_allocated(gpu_index) / 1024**2:.2f} MB")
        print(f"GPU {gpu_index} Memory Reserved: {torch.cuda.memory_reserved(gpu_index) / 1024**2:.2f} MB")


def _synchronize_gpus(device_ids):
    """Wait until the queued CUDA work on each device in ``device_ids`` has finished."""
    for device_id in device_ids:
        torch.cuda.synchronize(device_id)


# Check CUDA availability before any other torch.cuda call, so a machine without
# a usable GPU gets this explicit error instead of a failure inside torch.cuda.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available")

# empty_cache() only returns cached blocks that are no longer in use; tensors that
# are still referenced in the kernel (e.g. a model loaded by an earlier cell)
# remain allocated.
torch.cuda.empty_cache()
torch.cuda.synchronize()
print("Cleared GPU memory cache.")

# Path to model directory
model_dir = "./Meta-Llama-3.1-8B-Instruct-FP8"

# Initialize model loader
model_loader = ShardedFP8ModelLoader(
    model_dir=model_dir,
    device_ids=[0, 1],
    memory_efficient=True,
    fp8_format=FP8Format(e4m3=True)
)

# Print GPU info and initial memory usage
print("\nGPU Information and Initial Memory Usage:")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    total_mem = torch.cuda.get_device_properties(i).total_memory / 1024**3
    print(f"Total Memory: {total_mem:.2f} GB")
    print(f"Memory Allocated: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MB")
    print(f"Memory Reserved: {torch.cuda.memory_reserved(i) / 1024**2:.2f} MB")

# Gather shard files (sorted by name for a stable processing order)
shard_files = sorted(
    f for f in os.listdir(model_dir) if f.startswith("model-") and f.endswith(".safetensors")
)
if not shard_files:
    # Without this guard an empty weight dictionary would be handed to the loader.
    raise FileNotFoundError(f"No 'model-*.safetensors' shard files found in {model_dir}")
print(f"\nFound shard files: {shard_files}")

# Load shards into a dictionary
all_weights = {}
for sf in shard_files:
    shard_path = os.path.join(model_dir, sf)
    print(f"Loading shard: {shard_path}")
    shard_weights = load_file(shard_path)
    all_weights.update(shard_weights)

print("\nMemory Usage Before Model Loading:")
_report_gpu_memory()

# Measure loading time. The clock is started only after the report above and
# after the target GPUs are idle, so the interval covers the load alone.
# NOTE(review): assumes model_loader.device_ids holds the CUDA indices passed to
# the constructor — confirm against ShardedFP8ModelLoader.
_synchronize_gpus(model_loader.device_ids)
start_time = time.perf_counter()

# Load the model
model = model_loader.load_model_from_weights(all_weights)

# CUDA work is asynchronous: wait for every target GPU before stopping the clock.
_synchronize_gpus(model_loader.device_ids)
end_time = time.perf_counter()

# Print post-loading memory usage
print("\nMemory Usage After Model Loading:")
_report_gpu_memory()

# Final model summary
print(f"\nModel loaded successfully on devices: {[f'cuda:{i}' for i in model_loader.device_ids]}")
print(f"Model type: {type(model).__name__}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")
print(f"Model loaded in {end_time - start_time:.2f} seconds")


In [None]:
import os
import time
import torch
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM

# Utility function to display memory usage
def print_gpu_memory_usage():
    """Print allocated and reserved CUDA memory for every visible GPU.

    For each device index reported by ``torch.cuda.device_count()`` two lines
    are printed: ``torch.cuda.memory_allocated`` and
    ``torch.cuda.memory_reserved``, each divided by 1024**2 and labelled "MB".
    Nothing is printed when no device is visible. Returns ``None``.
    """
    bytes_per_mb = 1024**2
    gpu_total = torch.cuda.device_count()
    for gpu_index in range(gpu_total):
        allocated_mb = torch.cuda.memory_allocated(gpu_index) / bytes_per_mb
        print(f"GPU {gpu_index} Memory Allocated: {allocated_mb:.2f} MB")
        reserved_mb = torch.cuda.memory_reserved(gpu_index) / bytes_per_mb
        print(f"GPU {gpu_index} Memory Reserved: {reserved_mb:.2f} MB")

# Set the allocator option before this cell's first torch.cuda call, then clear the cache.
# NOTE(review): the caching allocator presumably reads PYTORCH_CUDA_ALLOC_CONF when
# it initializes, so this likely has no effect if CUDA was already used in this
# kernel — confirm, and restart the kernel before relying on it.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()


def _synchronize_all_gpus():
    """Wait for queued CUDA work on every visible GPU (the model may span several)."""
    for gpu_index in range(torch.cuda.device_count()):
        torch.cuda.synchronize(gpu_index)


# Path to model directory
model_dir = "./Meta-Llama-3.1-8B-Instruct-FP8"

# Display initial memory usage
print("=== Initial GPU Memory Usage ===")
print_gpu_memory_usage()

# Load model with mixed precision and auto device map
print("\nInitializing the model with mixed precision and device map...")
start_time = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=torch.float16,
    device_map="auto"
)
# device_map="auto" may place layers on several GPUs, so wait for all of them
# (not only the current device) before stopping the clock.
_synchronize_all_gpus()
end_time = time.perf_counter()

print("\n=== GPU Memory Usage After Model Initialization ===")
print_gpu_memory_usage()
print(f"Model initialization time: {end_time - start_time:.2f} seconds")

# Gather shard files (sorted by name for a stable processing order)
shard_files = sorted(
    f for f in os.listdir(model_dir) if f.startswith("model-") and f.endswith(".safetensors")
)
print(f"\nFound shard files: {shard_files}")

# Load and apply shards sequentially
# NOTE(review): from_pretrained() above was given the same checkpoint directory, so
# it presumably loaded these weights already; confirm this second pass is intended.
print("\nLoading and applying shards...")
for i, sf in enumerate(shard_files):
    shard_path = os.path.join(model_dir, sf)
    print(f"\nLoading shard {i+1}/{len(shard_files)}: {shard_path}")

    # Load shard
    start_shard_time = time.perf_counter()
    shard_weights = load_file(shard_path)

    # strict=False is needed because one shard holds only part of the parameters,
    # but it also hides checkpoint keys that match nothing in the model, so keep
    # the result and report those keys below.
    load_result = model.load_state_dict(shard_weights, strict=False)
    _synchronize_all_gpus()
    end_shard_time = time.perf_counter()

    # Print memory usage after loading the shard
    print(f"Shard {i+1} loaded in {end_shard_time - start_shard_time:.2f} seconds")
    if load_result.unexpected_keys:
        print(
            f"Warning: {len(load_result.unexpected_keys)} keys in this shard matched no model "
            f"parameter or buffer, e.g. {load_result.unexpected_keys[:5]}"
        )
    print("GPU Memory Usage After Shard Loading:")
    print_gpu_memory_usage()

    # Free up memory
    del shard_weights, load_result
    torch.cuda.empty_cache()

# Final memory usage
print("\n=== Final GPU Memory Usage After All Shards Loaded ===")
print_gpu_memory_usage()
print("\nModel loaded successfully!")
