In [1]:
import os
import time
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoConfig
import torch
import torch.nn as nn
from ShardedFP8ModelLoader import ShardedFP8ModelLoader, FP8Format

In [2]:
# Local checkpoint directory: handed to the loader as `model_dir` and scanned
# for "model-*.safetensors" shard files further down.
model_dir = "./Meta-Llama-3.1-8B-Instruct-FP8"

In [3]:
# Build the loader for the sharded checkpoint, configured for GPUs 0 and 1.
# NOTE(review): the exact effect of `memory_efficient` and of `e4m3=True` is
# defined inside ShardedFP8ModelLoader / FP8Format, which are not visible in
# this notebook — confirm against that module before changing them.
model_loader = ShardedFP8ModelLoader(
    fp8_format=FP8Format(e4m3=True),
    memory_efficient=True,
    device_ids=[0, 1],
    model_dir=model_dir,
)

In [4]:
# Fail fast when the GPUs this notebook relies on are missing.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available")

# The loader was configured with explicit device ids, so "some GPU exists" is
# not enough: make sure every requested id maps to a visible device. Failing
# here is much cheaper than failing part-way through the multi-minute load.
# Assumes `model_loader.device_ids` holds integer device indices (it is built
# from a list of ints above) — TODO confirm against ShardedFP8ModelLoader.
visible_gpu_count = torch.cuda.device_count()
missing_device_ids = [
    device_id
    for device_id in model_loader.device_ids
    if device_id >= visible_gpu_count
]
if missing_device_ids:
    raise RuntimeError(
        f"Requested CUDA devices {missing_device_ids} are not visible "
        f"(found {visible_gpu_count} GPU(s))"
    )

In [5]:
# Report the name and total memory of every CUDA device PyTorch can see.
for gpu_index in range(torch.cuda.device_count()):
    gpu_name = torch.cuda.get_device_name(gpu_index)
    print(f"GPU {gpu_index}: {gpu_name}")
    # total_memory is in bytes; dividing by 1024**3 gives the figure printed as "GB".
    total_memory_gb = torch.cuda.get_device_properties(gpu_index).total_memory / 1024**3
    print(f"Memory: {total_memory_gb:.2f} GB")

GPU 0: NVIDIA L4
Memory: 21.95 GB
GPU 1: NVIDIA L4
Memory: 21.95 GB


In [6]:
# Collect the checkpoint shards ("model-*.safetensors") in sorted name order.
shard_files = sorted(
    entry
    for entry in os.listdir(model_dir)
    if entry.startswith("model-") and entry.endswith(".safetensors")
)

# An empty listing means the directory is wrong or the download is incomplete;
# stop here instead of silently building an empty weight dictionary later.
if not shard_files:
    raise FileNotFoundError(
        f"No 'model-*.safetensors' shard files found in {model_dir!r}"
    )

print(f"Found shard files: {shard_files}")

Found shard files: ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']


In [7]:
# Merge the contents of every shard into one flat dictionary. Shards are read
# in the order of `shard_files`; on a duplicate key the later shard wins.
all_weights = {}
for shard_name in shard_files:
    shard_path = os.path.join(model_dir, shard_name)
    print(f"Loading shard: {shard_path}")
    all_weights.update(load_file(shard_path))

Loading shard: ./Meta-Llama-3.1-8B-Instruct-FP8/model-00001-of-00002.safetensors
Loading shard: ./Meta-Llama-3.1-8B-Instruct-FP8/model-00002-of-00002.safetensors


In [8]:
# Measure how long it takes to build the model from the in-memory weights.
# perf_counter() is a monotonic, high-resolution clock intended for timing
# intervals; time.time() follows the wall clock and can jump if the system
# time is adjusted while the (long) load is running. Only the difference
# `end_time - start_time` is meaningful.
start_time = time.perf_counter()
model = model_loader.load_model_from_weights(all_weights)
end_time = time.perf_counter()

Model distributed across 2 GPUs


In [9]:
# Summarise the load: target devices, model class, parameter count, duration.
device_names = [f'cuda:{i}' for i in model_loader.device_ids]
print(f"\nModel loaded successfully on devices: {device_names}")

model_class_name = type(model).__name__
print(f"Model type: {model_class_name}")

# Count every element of every parameter, reported in billions.
parameter_count = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {parameter_count/1e9:.2f}B")

elapsed_seconds = end_time - start_time
print(f"Model loaded in {elapsed_seconds:.2f} seconds")


Model loaded successfully on devices: ['cuda:0', 'cuda:1']
Model type: LlamaForCausalLM
Total parameters: 8.03B
Model loaded in 119.61 seconds


In [10]:
# Show the first 50 key names from the model's state dict and from the loaded
# checkpoint so the two naming schemes can be compared by eye.
model_key_sample = [*model.state_dict()][:50]
print("Model keys example:", model_key_sample)

checkpoint_key_sample = [*all_weights][:50]
print("Checkpoint keys example:", checkpoint_key_sample)

Model keys example: ['model.embed_tokens.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.mlp.gate_proj.weig