In [None]:
import os
import copy
import json
import torch
from transformers import AutoModelForCausalLM, AutoConfig
from safetensors.torch import save_file
from accelerate import init_empty_weights

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the published Llama-3 8B and 70B configs; their dimensions are the two
# reference points used by estimate_param() to extrapolate a larger model.
llama_8b_config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
llama_70b_config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")

# Optional extra reference point (left disabled):
# nemotron_config = AutoConfig.from_pretrained("nvidia/Nemotron-4-340B-Instruct")

In [None]:
# Display the 8B config so its dimensions can be compared with the derived configs below.
llama_8b_config

LlamaConfig {
  "_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128009,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.42.4",
  "use_cache": true,
  "vocab_size": 128256
}

In [None]:
# Nemotron 340B dimensions, recorded for comparison only: none of these four
# names is read by any other cell visible in this notebook.
hidden_size = 18432
ffn_hidden_size = 73728
num_layers = 96
num_attention_heads = 96

In [None]:
def estimate_param(dim70b, dim8b, nearest_multiple=None, mult_factor=1, model_size=420):
    """Extrapolate a config dimension for a larger model from its 70B and 8B values.

    The raw estimate is
    ``dim70b * (model_size / 70) * (dim70b / dim8b) * (8 / 70) * mult_factor``.

    Args:
        dim70b: Value of the dimension in the 70B config.
        dim8b: Value of the same dimension in the 8B config.
        nearest_multiple: If given, the estimate is rounded to the nearest
            multiple of this number (using Python's ``round``).
        mult_factor: Extra multiplier applied to the raw estimate (used to
            make a config wider).
        model_size: Target model size, in the same unit as the literal ``70``
            and ``8`` in the formula (presumably billions of parameters).

    Returns:
        An ``int`` multiple of ``nearest_multiple`` when it is provided;
        otherwise the unrounded ``float`` estimate.
    """
    x = dim70b * (model_size / 70) * (dim70b / dim8b) * (8 / 70)
    x *= mult_factor
    if nearest_multiple is None:
        # Previously the function fell off the end here and returned None.
        return x
    return nearest_multiple * round(x / nearest_multiple)

In [None]:
def esimate_num_hidden_layers(model, target_model_size=None):
    """Estimate how many decoder layers are needed to reach a target model size.

    The parameter count of ``model.model.layers[0]`` is taken as the size of
    every decoder layer; the embedding and LM-head weights are subtracted from
    the target before dividing.

    Args:
        model: A causal-LM model exposing ``model.layers``,
            ``model.embed_tokens.weight`` and ``lm_head.weight``.
        target_model_size: Target total size in billions of parameters. When
            omitted, the notebook-level ``TARGET_MODEL_SIZE`` constant is used,
            which must already be defined at call time.

    Returns:
        The estimated layer count, rounded to the nearest multiple of 8.
    """
    if target_model_size is None:
        # Backward-compatible fallback to the global defined in a later cell.
        target_model_size = TARGET_MODEL_SIZE

    # Size of a single decoder layer, and of everything outside the decoder stack.
    decoder_total_params = sum(p.numel() for p in model.model.layers[0].parameters())
    embed_lm_head_total_params = model.model.embed_tokens.weight.numel() + model.lm_head.weight.numel()

    # Work in billions of parameters; truncate to an integer before rounding, as before.
    num_estimated_hidden_layers = int(
        (target_model_size - embed_lm_head_total_params / 1e9) / (decoder_total_params / 1e9)
    )
    # nearest multiple of 8
    return 8 * round(num_estimated_hidden_layers / 8)

In [None]:
def estimate_4bit_qdora_layer_mem(model, group_size=128, lora_rank=64):
    """Estimate the memory (in GB, 1e9 bytes) of one 4-bit QDoRA decoder layer.

    Only ``model.model.layers[0]`` is inspected. Parameters whose name contains
    ``"proj"`` are counted as quantized weights plus quantization metadata plus
    LoRA A/B matrices; every other parameter is counted at 2 bytes per element.

    Args:
        model: A causal-LM model exposing ``model.layers``.
        group_size: Number of weights sharing one set of quantization
            metadata. Defaults to 128.
        lora_rank: Rank of the LoRA A/B matrices. Defaults to 64.

    Returns:
        Estimated size of a single decoder layer in GB.
    """
    # NOTE(review): only the LoRA A/B matrices are counted for the adapter; any
    # additional DoRA-specific tensors are not included — confirm this is acceptable.
    tot_mem_in_gb = 0
    for n, p in model.model.layers[0].named_parameters():
        if "proj" in n:
            # numel/4 elements at 2 bytes each, i.e. half a byte (4 bits) per weight.
            quant_mem = (p.numel() / 4) * 2 / 1e9
            # Two 2-byte values per group (presumably zero-point and scale).
            quant_zero_scale_mem = 2 * (p.numel() / group_size) * 2 / 1e9
            # LoRA matrices of shape (out, r) and (in, r) at 2 bytes per element.
            # Assumes the projection parameter is 2-D (no bias terms).
            lora_ab_mem = (p.size(0) * lora_rank + p.size(1) * lora_rank) * 2 / 1e9
            layer_mem = quant_mem + quant_zero_scale_mem + lora_ab_mem
        else:
            # Non-projection parameters are kept unquantized at 2 bytes per element.
            layer_mem = p.numel() * 2 / 1e9
        tot_mem_in_gb += layer_mem
    return tot_mem_in_gb

def estimate_mixed_bit_qdora_layer_mem(model):
    """Estimate the memory (in GB, 1e9 bytes) of one mixed-precision QDoRA decoder layer.

    Attention projections (q/k/v/o) are costed at 4 bits per weight with a
    group size of 128; every other parameter whose name contains "proj" is
    costed at 2 bits per weight with a group size of 64. All remaining
    parameters are counted at 2 bytes per element. Only
    ``model.model.layers[0]`` is inspected.
    """
    LORA_RANK = 64
    attn_proj_names = ['q_proj', 'k_proj', 'v_proj', 'o_proj']

    def param_mem_in_gb(name, param):
        # Anything that is not a projection stays unquantized (2 bytes/element).
        if "proj" not in name:
            return param.numel() * 2 / 1e9

        is_attention = any(attn_name in name for attn_name in attn_proj_names)
        if is_attention:
            # 4bit attn
            weights_per_two_bytes, group_size = 4, 128
        else:
            # 2bit mlp
            weights_per_two_bytes, group_size = 8, 64

        quantized = (param.numel() / weights_per_two_bytes) * 2 / 1e9
        zero_and_scale = 2 * (param.numel() / group_size) * 2 / 1e9
        lora_ab = (param.size(0) * LORA_RANK + param.size(1) * LORA_RANK) * 2 / 1e9
        return quantized + zero_and_scale + lora_ab

    first_layer = model.model.layers[0]
    return sum(param_mem_in_gb(name, param) for name, param in first_layer.named_parameters())

In [None]:
# Target total model size in billions of parameters; esimate_num_hidden_layers()
# compares it against parameter counts divided by 1e9.
TARGET_MODEL_SIZE = 420
# Number of GPUs the total memory estimates are divided by further down.
NUM_GPUS = 8

### Config A

- deep: base widths (default `mult_factor=1`), which yields the largest estimated layer count (256)

1 LAYER TRAINING: ~16GB

In [None]:
# Config A: start from a copy of the 70B config and extrapolate each width
# with the default multiplier (mult_factor=1).
llama_400b_config = copy.deepcopy(llama_70b_config)
llama_400b_config.hidden_size = estimate_param(
    llama_70b_config.hidden_size,
    llama_8b_config.hidden_size,
    nearest_multiple=128,
)
llama_400b_config.intermediate_size = estimate_param(
    llama_70b_config.intermediate_size,
    llama_8b_config.intermediate_size,
    nearest_multiple=128,
)
llama_400b_config.num_attention_heads = estimate_param(
    llama_70b_config.num_attention_heads,
    llama_8b_config.num_attention_heads,
    nearest_multiple=8,
)
llama_400b_config.num_key_value_heads = estimate_param(
    llama_70b_config.num_key_value_heads,
    llama_8b_config.num_key_value_heads,
    nearest_multiple=8,
)
(
    llama_400b_config.hidden_size,
    llama_400b_config.intermediate_size,
    llama_400b_config.num_attention_heads,
    llama_400b_config.num_key_value_heads,
)

(11264, 39296, 88, 8)

In [None]:
# Sanity check: hidden size selected for Config A.
llama_400b_config.hidden_size

11264

In [None]:
# Expected Config A dimensions, kept here for reference:
#   hidden_size         = 11264
#   intermediate_size   = 39296
#   num_attention_heads = 88
#   num_key_value_heads = 8
(
    llama_400b_config.hidden_size,
    llama_400b_config.intermediate_size,
    llama_400b_config.num_attention_heads,
    llama_400b_config.num_key_value_heads,
)

(11264, 39296, 88, 8)

In [None]:
# Use a single decoder layer so a small model can be instantiated and its
# per-layer parameter count measured; the full layer count is estimated from it below.
llama_400b_config.num_hidden_layers = 1

In [None]:
# Build the single-layer model under accelerate's init_empty_weights so that no
# real weight memory is allocated; only parameter shapes and counts are needed here.
with init_empty_weights():
	model = AutoModelForCausalLM.from_config(llama_400b_config)

In [None]:
# Number of decoder layers needed to reach TARGET_MODEL_SIZE with these widths.
num_estimated_hidden_layers = esimate_num_hidden_layers(model)
num_estimated_hidden_layers

256

In [None]:
# Parameters outside the decoder stack: token embeddings plus the LM head.
embed_lm_head_total_params = sum(
    weight.numel() for weight in (model.model.embed_tokens.weight, model.lm_head.weight)
)

In [None]:
# Whole-model estimate in GB: the 4-bit per-layer estimate times the layer count,
# plus embeddings and LM head at 2 bytes per parameter.
total_mem_in_gb_4bit = (
    estimate_4bit_qdora_layer_mem(model) * num_estimated_hidden_layers
    + embed_lm_head_total_params * 2 / 1e9
)
total_mem_in_gb_4bit

231.28335974400002

In [None]:
# Whole-model estimate in GB: the mixed-bit per-layer estimate times the layer
# count, plus embeddings and LM head at 2 bytes per parameter.
total_mem_in_gb_mixed_bit = (
    estimate_mixed_bit_qdora_layer_mem(model) * num_estimated_hidden_layers
    + embed_lm_head_total_params * 2 / 1e9
)
total_mem_in_gb_mixed_bit

156.921495552

In [None]:
# Per-GPU share of each total when divided evenly across NUM_GPUS.
tuple(total / NUM_GPUS for total in (total_mem_in_gb_4bit, total_mem_in_gb_mixed_bit))

(28.910419968000003, 19.615186944)

In [None]:
# Allocate real tensors on the GPU with the same shapes/dtypes as the model's
# (empty-initialised) state dict entries.
state_dict = {
    k: torch.empty_like(v, device="cuda")
    for k, v in model.state_dict().items()
}
# Fill with dummy values: tensors with fewer than 2 dims get the constant 0.01,
# matrices get Xavier-uniform values. The in-place `xavier_uniform_` replaces
# the deprecated `xavier_uniform` alias, which emitted a warning on every run.
state_dict = {
    k: v.fill_(0.01) if len(v.size()) < 2 else torch.nn.init.xavier_uniform_(v)
    for k, v in state_dict.items()
}

  state_dict = {k:v.fill_(0.01) if len(v.size()) < 2 else torch.nn.init.xavier_uniform(v) for k,v in state_dict.items()}


In [None]:
# Destination directory for the Config A artefacts (hard-coded absolute path;
# adjust for the local environment). Created if it does not exist yet.
output_dir = "/workspace/models/meta-llama/Meta-Llama-3-400B-Instruct-A"
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Write the dummy-initialised tensors to a single safetensors file.
save_file(state_dict, os.path.join(output_dir, "model_state_dict.safetensors"))

In [None]:
# Save the config next to the weights as config.json.
# NOTE(review): llama_400b_config.num_hidden_layers was set to 1 earlier and is
# not updated to num_estimated_hidden_layers before this dump — confirm that a
# single-layer config is what should be written.
with open(os.path.join(output_dir, "config.json"), "w") as f:
	json.dump(llama_400b_config.to_dict(), f)

### Config B

- mid: widths scaled with `mult_factor=1.5`, giving an intermediate estimated layer count (112)

1 LAYER TRAINING: ~25GB

In [None]:
# Config B: start from a copy of the 70B config and extrapolate each width
# with a 1.5x multiplier.
llama_400b_config = copy.deepcopy(llama_70b_config)
llama_400b_config.hidden_size = estimate_param(
    llama_70b_config.hidden_size,
    llama_8b_config.hidden_size,
    nearest_multiple=128,
    mult_factor=1.5,
)
llama_400b_config.intermediate_size = estimate_param(
    llama_70b_config.intermediate_size,
    llama_8b_config.intermediate_size,
    nearest_multiple=128,
    mult_factor=1.5,
)
llama_400b_config.num_attention_heads = estimate_param(
    llama_70b_config.num_attention_heads,
    llama_8b_config.num_attention_heads,
    nearest_multiple=8,
    mult_factor=1.5,
)
llama_400b_config.num_key_value_heads = estimate_param(
    llama_70b_config.num_key_value_heads,
    llama_8b_config.num_key_value_heads,
    nearest_multiple=8,
    mult_factor=1.5,
)
(
    llama_400b_config.hidden_size,
    llama_400b_config.intermediate_size,
    llama_400b_config.num_attention_heads,
    llama_400b_config.num_key_value_heads,
)

(16896, 59008, 128, 8)

In [None]:
# Expected Config B dimensions, kept here for reference:
#   hidden_size         = 16896
#   intermediate_size   = 59008
#   num_attention_heads = 128
#   num_key_value_heads = 8
(
    llama_400b_config.hidden_size,
    llama_400b_config.intermediate_size,
    llama_400b_config.num_attention_heads,
    llama_400b_config.num_key_value_heads,
)

(16896, 59008, 128, 8)

In [None]:
# Use a single decoder layer so a small model can be instantiated and its
# per-layer parameter count measured; the full layer count is estimated from it below.
llama_400b_config.num_hidden_layers = 1

In [None]:
# Build the single-layer model under accelerate's init_empty_weights so that no
# real weight memory is allocated; only parameter shapes and counts are needed here.
with init_empty_weights():
	model = AutoModelForCausalLM.from_config(llama_400b_config)

In [None]:
# Number of decoder layers needed to reach TARGET_MODEL_SIZE with these widths.
num_estimated_hidden_layers = esimate_num_hidden_layers(model)
num_estimated_hidden_layers

112

In [None]:
# Parameters outside the decoder stack: token embeddings plus the LM head.
embed_lm_head_total_params = sum(
    weight.numel() for weight in (model.model.embed_tokens.weight, model.lm_head.weight)
)

In [None]:
# Whole-model estimate in GB: the 4-bit per-layer estimate times the layer count,
# plus embeddings and LM head at 2 bytes per parameter.
total_mem_in_gb_4bit = (
    estimate_4bit_qdora_layer_mem(model) * num_estimated_hidden_layers
    + embed_lm_head_total_params * 2 / 1e9
)
total_mem_in_gb_4bit

227.48277964799996

In [None]:
# Whole-model estimate in GB: the mixed-bit per-layer estimate times the layer
# count, plus embeddings and LM head at 2 bytes per parameter.
total_mem_in_gb_mixed_bit = (
    estimate_mixed_bit_qdora_layer_mem(model) * num_estimated_hidden_layers
    + embed_lm_head_total_params * 2 / 1e9
)
total_mem_in_gb_mixed_bit

154.20334079999998

In [None]:
# Per-GPU share of each total when divided evenly across NUM_GPUS.
tuple(total / NUM_GPUS for total in (total_mem_in_gb_4bit, total_mem_in_gb_mixed_bit))

(28.435347455999995, 19.275417599999997)

In [None]:
# Allocate real tensors on the GPU with the same shapes/dtypes as the model's
# (empty-initialised) state dict entries.
state_dict = {
    k: torch.empty_like(v, device="cuda")
    for k, v in model.state_dict().items()
}
# Fill with dummy values: tensors with fewer than 2 dims get the constant 0.01,
# matrices get Xavier-uniform values. The in-place `xavier_uniform_` replaces
# the deprecated `xavier_uniform` alias, which emitted a warning on every run.
state_dict = {
    k: v.fill_(0.01) if len(v.size()) < 2 else torch.nn.init.xavier_uniform_(v)
    for k, v in state_dict.items()
}

  state_dict = {k:v.fill_(0.01) if len(v.size()) < 2 else torch.nn.init.xavier_uniform(v) for k,v in state_dict.items()}


In [None]:
# Destination directory for the Config B artefacts (hard-coded absolute path;
# adjust for the local environment). Created if it does not exist yet.
output_dir = "/workspace/models/meta-llama/Meta-Llama-3-400B-Instruct-B"
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Write the dummy-initialised tensors to a single safetensors file.
save_file(state_dict, os.path.join(output_dir, "model_state_dict.safetensors"))

In [None]:
# Save the config next to the weights as config.json.
# NOTE(review): llama_400b_config.num_hidden_layers was set to 1 earlier and is
# not updated to num_estimated_hidden_layers before this dump — confirm that a
# single-layer config is what should be written.
with open(os.path.join(output_dir, "config.json"), "w") as f:
	json.dump(llama_400b_config.to_dict(), f)

### Config C

- wide: widths scaled with `mult_factor=2`, giving the smallest estimated layer count (64)

1 LAYER TRAINING: ~40GB

In [None]:
# Config C: start from a copy of the 70B config and extrapolate each width
# with a 2x multiplier.
llama_400b_config = copy.deepcopy(llama_70b_config)
llama_400b_config.hidden_size = estimate_param(
    llama_70b_config.hidden_size,
    llama_8b_config.hidden_size,
    nearest_multiple=128,
    mult_factor=2,
)
llama_400b_config.intermediate_size = estimate_param(
    llama_70b_config.intermediate_size,
    llama_8b_config.intermediate_size,
    nearest_multiple=128,
    mult_factor=2,
)
llama_400b_config.num_attention_heads = estimate_param(
    llama_70b_config.num_attention_heads,
    llama_8b_config.num_attention_heads,
    nearest_multiple=8,
    mult_factor=2,
)
llama_400b_config.num_key_value_heads = estimate_param(
    llama_70b_config.num_key_value_heads,
    llama_8b_config.num_key_value_heads,
    nearest_multiple=8,
    mult_factor=2,
)
(
    llama_400b_config.hidden_size,
    llama_400b_config.intermediate_size,
    llama_400b_config.num_attention_heads,
    llama_400b_config.num_key_value_heads,
)

(22528, 78592, 176, 8)

In [None]:
# Expected Config C dimensions, kept here for reference:
#   hidden_size         = 22528
#   intermediate_size   = 78592
#   num_attention_heads = 176
#   num_key_value_heads = 8
(
    llama_400b_config.hidden_size,
    llama_400b_config.intermediate_size,
    llama_400b_config.num_attention_heads,
    llama_400b_config.num_key_value_heads,
)

(22528, 78592, 176, 8)

In [None]:
# Use a single decoder layer so a small model can be instantiated and its
# per-layer parameter count measured; the full layer count is estimated from it below.
llama_400b_config.num_hidden_layers = 1

In [None]:
# Display the complete Config C (single-layer) config to review every field before building the model.
llama_400b_config

LlamaConfig {
  "_name_or_path": "meta-llama/Meta-Llama-3-70B-Instruct",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128009,
  "hidden_act": "silu",
  "hidden_size": 22528,
  "initializer_range": 0.02,
  "intermediate_size": 78592,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 176,
  "num_hidden_layers": 1,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.42.4",
  "use_cache": true,
  "vocab_size": 128256
}

In [None]:
# Build the single-layer model under accelerate's init_empty_weights so that no
# real weight memory is allocated; only parameter shapes and counts are needed here.
with init_empty_weights():
	model = AutoModelForCausalLM.from_config(llama_400b_config)

In [None]:
# Number of decoder layers needed to reach TARGET_MODEL_SIZE with these widths.
num_estimated_hidden_layers = esimate_num_hidden_layers(model)
num_estimated_hidden_layers

64

In [None]:
# Parameters outside the decoder stack: token embeddings plus the LM head.
embed_lm_head_total_params = sum(
    weight.numel() for weight in (model.model.embed_tokens.weight, model.lm_head.weight)
)

In [None]:
# Whole-model estimate in GB: the 4-bit per-layer estimate times the layer count,
# plus embeddings and LM head at 2 bytes per parameter.
total_mem_in_gb_4bit = (
    estimate_4bit_qdora_layer_mem(model) * num_estimated_hidden_layers
    + embed_lm_head_total_params * 2 / 1e9
)
total_mem_in_gb_4bit

231.84487219199994

In [None]:
# Whole-model estimate in GB: the mixed-bit per-layer estimate times the layer
# count, plus embeddings and LM head at 2 bytes per parameter.
total_mem_in_gb_mixed_bit = (
    estimate_mixed_bit_qdora_layer_mem(model) * num_estimated_hidden_layers
    + embed_lm_head_total_params * 2 / 1e9
)
total_mem_in_gb_mixed_bit

157.48300799999996

In [None]:
# Per-GPU share of each total when divided evenly across NUM_GPUS.
tuple(total / NUM_GPUS for total in (total_mem_in_gb_4bit, total_mem_in_gb_mixed_bit))

(28.980609023999993, 19.685375999999994)

In [None]:
# Allocate real tensors on the GPU with the same shapes/dtypes as the model's
# (empty-initialised) state dict entries.
state_dict = {
    k: torch.empty_like(v, device="cuda")
    for k, v in model.state_dict().items()
}
# Fill with dummy values: tensors with fewer than 2 dims get the constant 0.01,
# matrices get Xavier-uniform values. The in-place `xavier_uniform_` replaces
# the deprecated `xavier_uniform` alias, which emitted a warning on every run.
state_dict = {
    k: v.fill_(0.01) if len(v.size()) < 2 else torch.nn.init.xavier_uniform_(v)
    for k, v in state_dict.items()
}

  state_dict = {k:v.fill_(0.01) if len(v.size()) < 2 else torch.nn.init.xavier_uniform(v) for k,v in state_dict.items()}


In [None]:
# Destination directory for the Config C artefacts (hard-coded absolute path;
# adjust for the local environment). Created if it does not exist yet.
output_dir = "/workspace/models/meta-llama/Meta-Llama-3-400B-Instruct-C"
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Write the dummy-initialised tensors to a single safetensors file.
save_file(state_dict, os.path.join(output_dir, "model_state_dict.safetensors"))

In [None]:
# Save the config next to the weights as config.json.
# NOTE(review): llama_400b_config.num_hidden_layers was set to 1 earlier and is
# not updated to num_estimated_hidden_layers before this dump — confirm that a
# single-layer config is what should be written.
with open(os.path.join(output_dir, "config.json"), "w") as f:
	json.dump(llama_400b_config.to_dict(), f)