In [17]:
# set to auto reload modules
%load_ext autoreload
%autoreload 2
!CUDA_VISIBLE_DEVICES=0,1

import torch
import yaml 
import os 
import glob
import argparse
from src.model.llama import LlamaForCausalLM
from transformers import LlamaForCausalLM as OrigLlama
from transformers import AutoConfig

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
base_model = "meta-llama/Llama-2-70b-hf"
checkpoints_path = "/data/lliu/huffman/models/meta-llama/Llama-2-70b-hf/compressed/run_18/checkpoints.yaml"
hf_model_save_path = "/data/lliu/huffman/models/meta-llama/Llama-2-70b-hf/compressed_hf/run_18/"
add_bias = True

In [19]:
orig_config = AutoConfig.from_pretrained(base_model,dtype = "auto",
                                         device_map="cpu",
                                        attn_implementation='sdpa')
orig_model = OrigLlama.from_pretrained(base_model, config=orig_config, torch_dtype="auto",
                                        device_map="cpu",
                                        low_cpu_mem_usage=True, attn_implementation='sdpa')

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [20]:
for name, param in orig_model.named_parameters():
    assert param.dtype == torch.float16, f"{name} is not fp16, it is {param.dtype}"

In [21]:
checkpoints_dict = yaml.load(open(checkpoints_path, "r"), Loader=yaml.FullLoader)


compression_kwargs = yaml.load(open((checkpoints_dict[list(checkpoints_dict.keys())[0]]).replace("compressed.pt", "compressed_args.yaml")),
                                Loader=yaml.FullLoader)
#check that all the other checkpoints have the same compression args
for checkpoint in checkpoints_dict.values():
    assert compression_kwargs == yaml.load(open(checkpoint.replace("compressed.pt", "compressed_args.yaml"), "r"), Loader=yaml.FullLoader)

#remove dtype from the compression kwargs
compression_kwargs.pop("dtype", None)

compression_type = compression_kwargs["compression_type"]


compression_config = {"compression_kwargs": compression_kwargs, "compression_type": compression_type,
                        "add_bias": add_bias, "skip_list":None}

orig_config.compress_config = compression_config


In [22]:
model = LlamaForCausalLM(orig_config)
model.to(orig_config.torch_dtype)
#iterate through all parameters and assert that they are fp16
for name, param in model.named_parameters():
    assert param.dtype == torch.float16, f"{name} is not fp16, it is {param.dtype}"
model.load_state_dict(orig_model.state_dict(), strict=False)

_IncompatibleKeys(missing_keys=['model.layers.0.self_attn.q_proj.bias', 'model.layers.0.self_attn.q_proj.codebook', 'model.layers.0.self_attn.q_proj.assignments', 'model.layers.0.self_attn.q_proj.normalizer.norms.0', 'model.layers.0.self_attn.q_proj.normalizer.norms.1', 'model.layers.0.self_attn.q_proj.normalizer.zeros.0', 'model.layers.0.self_attn.q_proj.normalizer.zeros.1', 'model.layers.0.self_attn.k_proj.bias', 'model.layers.0.self_attn.k_proj.codebook', 'model.layers.0.self_attn.k_proj.assignments', 'model.layers.0.self_attn.k_proj.normalizer.norms.0', 'model.layers.0.self_attn.k_proj.normalizer.norms.1', 'model.layers.0.self_attn.k_proj.normalizer.zeros.0', 'model.layers.0.self_attn.k_proj.normalizer.zeros.1', 'model.layers.0.self_attn.v_proj.bias', 'model.layers.0.self_attn.v_proj.codebook', 'model.layers.0.self_attn.v_proj.assignments', 'model.layers.0.self_attn.v_proj.normalizer.norms.0', 'model.layers.0.self_attn.v_proj.normalizer.norms.1', 'model.layers.0.self_attn.v_proj.no

In [23]:
#for each checkpoint, load the right weight
for checkpoint_name,checkpoint_path in checkpoints_dict.items():
    print(checkpoint_name)
    #first remove the base_model name from it
    checkpoint_name = checkpoint_name.replace(base_model, "")
    #now split by /
    checkpoint_name = checkpoint_name.split("/")[-2:]
    #from the first part, we can get which layer it is
    i_layer = int(checkpoint_name[0].replace("layer_", ""))
    #from the second part we can get which module (self_attn, mlp, etc) and which layer it is
    submodule_name, linear_name = checkpoint_name[1].split(".")
    
    #now we get the right module
    layer = getattr(getattr(model.model.layers[i_layer], submodule_name), linear_name)
    #record the original dtype
    orig_dtype = layer.codebook.dtype
    orig_device = layer.codebook.device
    print(orig_dtype, orig_device)
    #load the state dict
    layer.load_state_dict(torch.load(checkpoint_path, map_location=orig_device), strict=False)
    #convert to the right dtype
    layer.to(orig_dtype)
    # raise ValueError("stop here")

meta-llama/Llama-2-70b-hf/layer_0/mlp.down_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_0/mlp.gate_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_0/mlp.up_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_0/self_attn.k_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_0/self_attn.o_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_0/self_attn.q_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_0/self_attn.v_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_1/mlp.down_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_1/mlp.gate_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_1/mlp.up_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_1/self_attn.k_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_1/self_attn.o_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_1/self_attn.q_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_1/self_attn.v_proj
torch.float16 cpu
meta-llama/Llama-2-70b-hf/layer_10/mlp.down_proj
torch.f

In [14]:
#iterate through all parameters and assert that they are fp16
for name, param in model.named_parameters():
    assert param.dtype == torch.float16, f"{name} is not fp16, it is {param.dtype}"

In [15]:
original_model = OrigLlama.from_pretrained(base_model, device_map="cpu", torch_dtype=torch.float32)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [14]:
model.config.torch_dtype

'float16'

In [24]:
#save the model
model.save_pretrained(hf_model_save_path)

In [16]:
model.model.layers[0].self_attn.q_proj.reconstruct().dtype

torch.float16

In [22]:
#try to load the model
loaded_model = LlamaForCausalLM.from_pretrained(hf_model_save_path,
                                                torch_dtype = 'auto',
                                                low_cpu_mem_usage=True,
                                                attn_implementation='sdpa')

device None dtype torch.float16
kwargs {'d': 6, 'ignore_norms': True, 'initialize_kwargs': {'deterministic': False, 'multiple_each_time': 1.0}, 'initialize_method': 'kmeans++', 'n_bits': 2, 'n_inits': 1, 'n_iters': 100, 'normalizer_kwargs': {'norm_order': [0, 1], 'p': 2, 'zero': [False, False]}}
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float16
device None dtype torch.float16
kwargs {'d': 6, 'ignore_norms': True, 'initialize_kwargs': {'deterministic': False, 'multiple_each_time': 1.0}, 'initialize_method': 'kmeans++', 'n_bits': 2, 'n_inits': 1, 'n_iters': 100, 'normalizer_kwargs': {'norm_order': [0, 1], 'p': 2, 'zero': [False, False]}}
codebook shape:  torch.Size([4096, 6]) device:  meta dtype:  torch.float16
device None dtype torch.float16
kwargs {'d': 6, 'ignore_norms': True, 'initialize_kwargs': {'deterministic': False, 'multiple_each_time': 1.0}, 'initialize_method': 'kmeans++', 'n_bits': 2, 'n_inits': 1, 'n_iters': 100, 'normalizer_kwargs': {'norm_order': 

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.66it/s]


In [23]:
loaded_model.model.layers[0].self_attn.q_proj.reconstruct().dtype

torch.float16

In [16]:
x = torch.randn(1, 1, 4096).cpu()
loaded_model.model.layers[0].self_attn.q_proj(x)

tensor([[[ 0.2513,  0.5720,  0.5787,  ...,  0.3616, -0.9614,  0.5276]]],
       grad_fn=<ViewBackward0>)

In [19]:
loaded_model.model.layers[0].self_attn.q_proj.cache_reconstruct()
loaded_model.model.layers[0].self_attn.q_proj(x)

tensor([[[ 0.2513,  0.5720,  0.5787,  ...,  0.3616, -0.9614,  0.5276]]],
       grad_fn=<ViewBackward0>)

In [7]:
original_model = OrigLlama.from_pretrained(base_model, device_map="cpu", torch_dtype="auto",
                                             attn_implementation='sdpa')

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.02it/s]


In [8]:
#try to save the model to a temp dir
temp_dir = "/data/lliu/huffman/temp/Llama-2-7b-hf/"
original_model.save_pretrained(temp_dir)

In [10]:
loaded_orig_llama = OrigLlama.from_pretrained(temp_dir, device_map="cpu", torch_dtype="auto")


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.88it/s]


In [14]:
loaded_orig_llama.model.layers[0].self_attn.q_proj.weight.dtype

torch.float16

In [21]:
#iterate through all parameters and assert that they are fp16
for name, param in loaded_orig_llama.named_parameters():
    assert param.dtype == torch.float16, f"{name} is not fp16, it is {param.dtype}"