In [1]:
import os
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoConfig
import torch
import torch.nn as nn
from ShardedFP8ModelLoader import ShardedFP8ModelLoader, FP8Format

In [2]:
model_dir = "./Meta-Llama-3.1-8B-Instruct-FP8"

In [3]:
model_loader = ShardedFP8ModelLoader(
    model_dir=model_dir,
    device_ids=[0, 1],
    memory_efficient=True,
    fp8_format=FP8Format(e4m3=True)
)


In [4]:
# First check if CUDA is available
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available")


In [5]:
# Print GPU info
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    print(f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB")


GPU 0: NVIDIA L4
Memory: 21.95 GB
GPU 1: NVIDIA L4
Memory: 21.95 GB


In [6]:
# Load the model
checkpoint_path = os.path.join(model_dir, "model-00001-of-00002.safetensors")
model = model_loader.load_model(checkpoint_path=checkpoint_path)


Missing keys: ['model.layers.28.input_layernorm.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.24.mlp.down_proj.weight', 'model.layers.24.mlp.gate_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.27.mlp.gate_proj.weight', 'model.layers.30.mlp.up_proj.weight', 'model.layers.25.mlp.up_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.28.mlp.down_proj.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.30.mlp.down_proj.weight', 'model.layers.29.post_attention_layernorm.weight', 'model.layers.21.input_layernorm.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.29.mlp.gate_proj.weight', 'model.layers.26.mlp.gate_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.29.input_layernorm.weight', 'model.layers.

In [7]:
# Verify model loading
print(f"\nModel loaded successfully on devices: {[f'cuda:{i}' for i in model_loader.device_ids]}")
print(f"Model type: {type(model).__name__}")



Model loaded successfully on devices: ['cuda:0', 'cuda:1']
Model type: LlamaForCausalLM


In [8]:
# Print model statistics
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params/1e9:.2f}B")

Total parameters: 8.03B
