In [1]:
import torch
import specdecodes.models.llm.modeling_llama as modeling_llama

from hqq.core.quantize import *
from hqq.models.hf.base import AutoHQQHFModel
from hf_share_sd.base import AutoHQQHFShareSDModel


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_path = 'meta-llama/Llama-2-7b-chat-hf'
dtype = torch.bfloat16
device = 'cuda:0'

model = modeling_llama.LlamaForCausalLM.from_pretrained(
    model_path,
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
    device_map=device,
    _attn_implementation="sdpa",
)

#Quantize
quant_config = BaseQuantizeConfig(nbits=2, group_size=32) 
# AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=dtype, device=device)
AutoHQQHFShareSDModel.quantize_model(model, quant_config=quant_config, compute_dtype=dtype, device=device)

Loading checkpoint shards: 100%|██████████| 2/2 [00:19<00:00,  9.63s/it]
100%|██████████| 99/99 [00:00<00:00, 5108.46it/s]
100%|██████████| 32/32 [00:16<00:00,  1.88it/s]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (quant_q_proj): HQQLinear(
            in_features=4096, out_features=4096, bias=False
            (linear_layer): Linear(in_features=4096, out_features=4096, bias=False)
          )
          (quant_k_proj): HQQLinear(
            in_features=4096, out_features=4096, bias=False
            (linear_layer): Linear(in_features=4096, out_features=4096, bias=False)
          )
          (quant_v_proj): HQQLinear(
            in_features=4096, out_features=4096, bias=False
            (linear_layer): 

In [3]:
save_dir = 'quantized_model'

#Save: Make sure to save the model BEFORE any patching
AutoHQQHFModel.save_quantized(model, save_dir)

#Load
model = AutoHQQHFModel.from_quantized(save_dir)

  return torch.load(cls.get_weight_file(save_dir), map_location=map_location)
100%|██████████| 99/99 [00:00<00:00, 35985.45it/s]
100%|██████████| 225/225 [00:00<00:00, 4264.51it/s]


In [3]:
save_dir = 'quantized_model'
model = AutoHQQHFModel.from_quantized(save_dir)

  return torch.load(cls.get_weight_file(save_dir), map_location=map_location)
100%|██████████| 99/99 [00:00<00:00, 7518.17it/s]
100%|██████████| 225/225 [00:00<00:00, 5977.44it/s]


In [4]:
size_model = 0
for param in model.parameters():
    if param.data.is_floating_point():
        size_model += param.numel() * torch.finfo(param.data.dtype).bits
    else:
        size_model += param.numel() * torch.iinfo(param.data.dtype).bits
print(f"model size: {size_model} / bit | {size_model / 8e6:.2f} / MB")

model size: 30102585344 / bit | 3762.82 / MB


In [3]:
size_model = 0
for param in model.parameters():
    if param.data.is_floating_point():
        size_model += param.numel() * torch.finfo(param.data.dtype).bits
    else:
        size_model += param.numel() * torch.iinfo(param.data.dtype).bits
print(f"model size: {size_model} / bit | {size_model / 8e6:.2f} / MB")

model size: 120766660608 / bit | 15095.83 / MB


In [3]:
size_model = 0
for param in model.parameters():
    if param.data.is_floating_point():
        size_model += param.numel() * torch.finfo(param.data.dtype).bits
    else:
        size_model += param.numel() * torch.iinfo(param.data.dtype).bits
print(f"model size: {size_model} / bit | {size_model / 8e6:.2f} / MB")

model size: 107814649856 / bit | 13476.83 / MB


In [6]:
size_model = 0
for n, param in model.named_parameters():
    # ignore embed_tokens, lm_head
    print(f"{n} {param.numel()}")
    if 'embed_tokens' in n or 'lm_head' in n:
        print(f"ignored {n}")
        continue
    if param.data.is_floating_point():
        size_model += param.numel() * torch.finfo(param.data.dtype).bits
    else:
        size_model += param.numel() * torch.iinfo(param.data.dtype).bits
print(f"model size: {size_model} / bit | {size_model / 8e6:.2f} / MB")

model.embed_tokens.weight 131072000
ignored model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight 16777216
model.layers.0.self_attn.k_proj.weight 16777216
model.layers.0.self_attn.v_proj.weight 16777216
model.layers.0.self_attn.o_proj.weight 16777216
model.layers.0.self_attn.quant_q_proj.W_q 4194304
model.layers.0.self_attn.quant_k_proj.W_q 4194304
model.layers.0.self_attn.quant_v_proj.W_q 4194304
model.layers.0.self_attn.quant_o_proj.W_q 4194304
model.layers.0.mlp.gate_proj.weight 45088768
model.layers.0.mlp.up_proj.weight 45088768
model.layers.0.mlp.down_proj.weight 45088768
model.layers.0.mlp.quant_gate_proj.W_q 11272192
model.layers.0.mlp.quant_up_proj.W_q 11272192
model.layers.0.mlp.quant_down_proj.W_q 11272192
model.layers.0.input_layernorm.weight 4096
model.layers.0.post_attention_layernorm.weight 4096
model.layers.1.self_attn.q_proj.weight 16777216
model.layers.1.self_attn.k_proj.weight 16777216
model.layers.1.self_attn.v_proj.weight 16777216
model.layers.1.self_attn