In [25]:
import time

In [1]:
import torch
import torch.nn as nn
import hqq_aten

In [None]:
from hqq.core.quantize import Quantizer, HQQLinear, BaseQuantizeConfig, HQQBackend

In [3]:
from typing import List
from torch import Tensor
from torch.nn import functional as F

In [6]:
from accelerate.utils import set_seed
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

In [7]:
from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
import safetensors

In [30]:
from fastcore.parallel import parallel

In [14]:
# Optionally use the context manager to ensure one of the fused kernels is run
query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
with torch.backends.cuda.sdp_kernel(True, False, False):
    F.scaled_dot_product_attention(query,key,value)

In [15]:
set_seed(42)

In [16]:
m = torch.nn.Linear(16,128)

In [19]:
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, offload_meta=False)
hqq_linear = HQQLinear(m, quant_config=quant_config)

In [24]:
hqq_linear.compute_dtype

torch.float16

In [23]:
next(hqq_linear.parameters())

Parameter containing:
tensor([[-1.8690e+31, -1.7469e-07, -9.8312e-20,  4.3347e+23, -1.0372e-23,
         -5.6423e+16,  1.3304e-05,  6.1785e-24],
        [-5.7602e+10,  5.1494e+18, -1.7353e+27, -7.9082e-32,  8.7318e+06,
         -4.3186e-06,  1.4261e-18,  3.5633e+17],
        [ 2.8733e-02, -6.6121e-15,  4.6052e-22, -5.8633e+18,  1.6486e+06,
          1.2226e-18,  9.0436e+25,  5.9841e-04],
        [ 6.3572e-37,  2.1430e-10,  5.6341e-01, -5.9994e-36,  1.9233e+11,
          2.9263e-09,  3.3071e-09,  1.0180e-20],
        [-1.0810e-13,  8.8023e+08,  6.2707e+18,  1.3579e-24, -4.7377e+23,
          3.5615e+17,  2.6324e-14,  4.2122e-09],
        [ 2.4662e-25, -3.4900e+27,  9.6193e+29,  2.6624e+03,  2.2651e-29,
          3.0514e+14,  6.9221e+30,  1.6402e+19],
        [ 7.4646e+22, -9.6859e-28, -4.3350e-10,  5.1519e-34, -4.1487e-07,
         -7.7171e+37,  9.2547e+13,  8.3544e+23],
        [-1.6869e-09, -2.6847e+18, -8.0041e-29,  9.5645e-38,  1.3935e-02,
         -1.4938e-13,  1.0959e-11,  1.0414e

In [127]:
w = m.weight.data

In [128]:
w.shape

torch.Size([128, 16])

In [129]:
W_q, meta = Quantizer.quantize(w, round_zero=True, optimize=True, view_as_float=False)

In [130]:
W_q.shape, W_q.dtype

(torch.Size([32, 32]), torch.uint8)

In [131]:
meta['scale'].dtype

torch.float16

In [138]:
w_dq = Quantizer.dequantize(W_q, meta)

In [139]:
w, w_dq

(tensor([[ 0.1196,  0.0683, -0.0960,  ..., -0.2410, -0.1544, -0.0864],
         [-0.0278, -0.0483,  0.1141,  ...,  0.0873,  0.0023,  0.2011],
         [ 0.0982, -0.0460,  0.0086,  ...,  0.0627, -0.0216, -0.0140],
         ...,
         [-0.0208,  0.1148, -0.0562,  ..., -0.0961,  0.2354,  0.2077],
         [ 0.1820,  0.1345, -0.0235,  ...,  0.0432, -0.1749,  0.1510],
         [-0.2125,  0.0024, -0.2045,  ..., -0.1916,  0.1080,  0.0231]]),
 tensor([[ 0.1224,  0.0717, -0.0930,  ..., -0.2524, -0.1595, -0.0937],
         [-0.0320, -0.0627,  0.1289,  ...,  0.0945,  0.0091,  0.1919],
         [ 0.0917, -0.0519,  0.0014,  ...,  0.0705, -0.0320,  0.0009],
         ...,
         [-0.0320,  0.1304, -0.0645,  ..., -0.0981,  0.2344,  0.1919],
         [ 0.1841,  0.1334, -0.0301,  ...,  0.0382, -0.1595,  0.1584],
         [-0.2222,  0.0016, -0.1934,  ..., -0.1943,  0.1057,  0.0273]],
        dtype=torch.float16))

In [140]:
torch.norm(w - w_dq, p=0.7)

tensor(390.0982)

In [35]:
BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, offload_meta=False)

{'weight_quant_params': {'nbits': 4,
  'channel_wise': True,
  'group_size': 64,
  'optimize': True,
  'round_zero': True},
 'scale_quant_params': None,
 'zero_quant_params': None,
 'offload_meta': False}

In [None]:
quant_configs = [
                 BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, offload_meta=False),
                 BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=False, offload_meta=False),
                 BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=True, offload_meta=False),
                 BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=True, offload_meta=False),
                 BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True),
                 BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, offload_meta=True)
]

w_dqs = []
for quant_cfg in quant_configs:
    if quant_cfg['scale_quant_params']: 
        quant_cfg['scale_quant_params']['group_size'] = 8
    if quant_cfg['zero_quant_params']: 
        if quant_cfg['offload_meta']:
            quant_cfg['zero_quant_params']['group_size'] = 8
            quant_cfg['zero_quant_params']['channel_wise'] = True
        else:
            quant_cfg['zero_quant_params']['group_size'] = None
            quant_cfg['zero_quant_params']['channel_wise'] = False
    mq = HQQLinear(m, quant_cfg, compute_dtype=torch.bfloat16, initialize=False)
    HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
    mq.initialize()
    print(mq.W_q.dtype, mq.meta)
    print()
    w_dqs.append(mq.dequantize_aten())

In [143]:
(torch.norm(w.cuda() - w_dqs[0], p=0.7),
torch.norm(w.cuda() - w_dqs[1], p=0.7),
torch.norm(w.cuda() - w_dqs[2], p=0.7),
torch.norm(w.cuda() - w_dqs[3], p=0.7),
torch.norm(w.cuda() - w_dqs[4], p=0.7))

(tensor(390.9176, device='cuda:0'),
 tensor(390.5967, device='cuda:0'),
 tensor(390.7930, device='cuda:0'),
 tensor(390.1439, device='cuda:0'),
 tensor(392.0999, device='cuda:0'))

In [10]:
def replace_linear_hqq(model:nn.Module, quant_config, skip_modules:List[str]=["lm_head"], **kwargs):
    """
    Replace linear modules with a new Linear module.
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        quant_config (`Dict[str, Any]`):
            The quantization configuration for the new linear module.
        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
            List of modules names not to convert. Defaults to `lm_head`.
    """
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear_hqq(module, quant_config, skip_modules, **kwargs)

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            model._modules[name] = HQQLinear(
                module,
                quant_config,
                **kwargs
            )
    return model

In [21]:
def load_and_quantize_hqq(module:nn.Module, name:str, value:Tensor, device:torch.device=None, dtype:torch.dtype=None,
                                  skip_names:list[str]=[], is_meta_rank:bool=False, low_memory:bool=True, verbose:bool=False):
    """
    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.

    Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
    """
    def place_on_device(value):
        if is_meta_rank:
            device = 'meta'
        elif low_memory:
            device = 'cpu'
        return value.to(device=device, dtype=dtype)

    if any([skip_name in name for skip_name in skip_names]):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    module_key, _, value_key = name.rpartition('.')
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as e:
        print(f"Module {module_key} not found:\n{e}")
        return

    start = time.time()
    try:
        if isinstance(submodule, HQQLinear):
            if value_key == "weight":
                # init meta weights as empty on cpu
                submodule.linear_layer.to_empty(device="cpu")
                # copy pretrained weights
                submodule.linear_layer.weight.data.copy_(value)
                # quantize and update metadata
                submodule.initialize()
                
                if is_meta_rank:
                    setattr(submodule, "W_q", nn.Parameter(submodule.W_q.to("meta")))
                elif low_memory:
                    setattr(submodule, "W_q", nn.Parameter(submodule.W_q.to("cpu")))
                submodule.in_gpu = False

            if value_key == "bias":
                raise ValueError("Bias not supported in HQQLinear yet!")
        
            end = time.time()
            if not is_meta_rank:
                print(f"Loaded HQQLinear quantized {module_key} in {end-start:.3f} seconds")
            return
        
        else:
            param = submodule.get_parameter(value_key)
            value = type(param)(place_on_device(value).data)

    except AttributeError:
        # it's a buffer
        value = place_on_device(value)
        pass
    
    setattr(submodule, value_key, value)
    end = time.time()
    torch.cuda.empty_cache()
    if not is_meta_rank:
        print(f"Loaded {module_key} and {value_key} in {end-start:.3f} seconds")

In [49]:
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)

In [97]:
compute_dtype = torch.bfloat16

model_name = "meta-llama/Llama-2-7b-hf"

cfg = AutoConfig.from_pretrained(model_name)
cfg.use_cache = False
cfg._attn_implementation = "sdpa"
# cfg.num_hidden_layers = 8 # DEBUG

# load model on meta device without calling init and replace nn.Linear with Linear4bit
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(cfg)
    # TODO: Tune BaseQuantizeConfig.
    quant_config = BaseQuantizeConfig(nbits=4, 
                                      group_size=64, 
                                      quant_zero=True, 
                                      quant_scale=True, 
                                      offload_meta=True)
    model.model = replace_linear_hqq(model.model, quant_config, device_n=torch.cuda.current_device(),
                                    compute_dtype=compute_dtype, del_orig=True, initialize=False)     
    HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
model.is_loaded_in_4bit = True

In [98]:
local_rank = 0
low_memory = True
load_param_skip_names = []
rank = 0

print("Loading model", rank)
start = time.time()
for filename in files:
    weights = safetensors.torch.load_file(filename)
    for name, param in weights.items():
        load_and_quantize_hqq(model, name, param, dtype=torch.bfloat16, device=local_rank, skip_names=load_param_skip_names,
                                is_meta_rank=(low_memory and rank!=0), verbose=True)
print(f"Loaded model weights in {time.time()-start:.3f} seconds")

Loading model 0
Loaded model.embed_tokens and weight in 0.067 seconds
Loaded model.layers.0.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.mlp.down_proj in 0.271 seconds
Loaded HQQLinear quantized model.layers.0.mlp.gate_proj in 0.243 seconds
Loaded HQQLinear quantized model.layers.0.mlp.up_proj in 0.236 seconds
Loaded model.layers.0.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.k_proj in 0.065 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.o_proj in 0.062 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.q_proj in 0.063 seconds
Loaded model.layers.0.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.v_proj in 0.060 seconds
Loaded model.layers.1.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.1.mlp.down_proj in 0.239 seconds
Loaded HQQLinear quantized model.layers.1.mlp.gate_proj 

In [104]:
def load_and_quantize_parallel(name_param, load_func, model, **kwargs):
    name, param = name_param
    load_func(model, name, param, **kwargs)

In [106]:
compute_dtype = torch.bfloat16

model_name = "meta-llama/Llama-2-7b-hf"

cfg = AutoConfig.from_pretrained(model_name)
cfg.use_cache = False
cfg._attn_implementation = "sdpa"
# cfg.num_hidden_layers = 8 # DEBUG

# load model on meta device without calling init and replace nn.Linear with Linear4bit
with init_empty_weights():
    model_fast = AutoModelForCausalLM.from_config(cfg)
    # TODO: Tune BaseQuantizeConfig.
    quant_config = BaseQuantizeConfig(nbits=4, 
                                      group_size=64, 
                                      quant_zero=True, 
                                      quant_scale=True, 
                                      offload_meta=True)
    model_fast.model = replace_linear_hqq(model_fast.model, quant_config, device_n=torch.cuda.current_device(),
                                          compute_dtype=compute_dtype, del_orig=True, initialize=False)     
    HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
model_fast.is_loaded_in_4bit = True

In [107]:
local_rank = 0
low_memory = True
load_param_skip_names = []
rank = 0

print("Loading model", rank)
start = time.time()
for filename in files:
    weights = safetensors.torch.load_file(filename)
    parallel(load_and_quantize_parallel, weights.items(), n_workers=8, threadpool=True, 
             load_func=load_and_quantize_hqq, model=model_fast, 
             dtype=torch.bfloat16, device=local_rank, skip_names=load_param_skip_names, 
             is_meta_rank=(low_memory and rank!=0), verbose=True)
print(f"Loaded model weights in {time.time()-start:.3f} seconds")

Loading model 0
Loaded model.layers.0.input_layernorm and weight in 0.003 seconds
Loaded model.layers.0.post_attention_layernorm and weight in 0.004 seconds
Loaded model.layers.0.self_attn.rotary_emb and inv_freq in 0.032 seconds
Loaded model.embed_tokens and weight in 0.203 seconds
Loaded model.layers.1.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.k_proj in 1.016 seconds
Loaded HQQLinear quantized model.layers.0.mlp.gate_proj in 1.065 seconds
Loaded HQQLinear quantized model.layers.0.mlp.down_proj in 1.201 seconds
Loaded model.layers.1.post_attention_layernorm and weight in 0.008 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.v_proj in 1.155 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.q_proj in 1.211 seconds
Loaded HQQLinear quantized model.layers.0.mlp.up_proj in 1.252 seconds
Loaded model.layers.1.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.o_pro

In [108]:
for (n1,p1), (n2,p2) in zip(model.named_parameters(), model_fast.named_parameters()):
    if n1 == n2:
        if "proj" in n1:
            assert torch.allclose(p1.view(torch.uint8), p2.view(torch.uint8))
        else:
            assert torch.allclose(p1, p2)