In [None]:
import time

In [None]:
import torch
import torch.nn as nn
import hqq_aten

In [None]:
from hqq.core.quantize import Quantizer, HQQLinear, BaseQuantizeConfig, HQQBackend

[36mhqq_aten package available. Set backend to HQQBackend.ATEN for faster inference and HQQBackend.ATEN_BACKPROP for faster training![0m


In [None]:
from typing import List
from torch import Tensor
from torch.nn import functional as F

In [None]:
from accelerate.utils import set_seed
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

In [None]:
from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
import safetensors

In [None]:
from fastcore.parallel import parallel

In [None]:
# Smoke-test fused attention on the GPU: restrict scaled_dot_product_attention
# to the flash kernel (math and memory-efficient kernels disabled) and run a
# single fp16 call on random inputs.
query, key, value = (
    torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") for _ in range(3)
)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    F.scaled_dot_product_attention(query, key, value)

In [None]:
# Fix the random seed (accelerate's set_seed) before any layers are created below.
set_seed(42)

In [None]:
# Small reference linear layer (16 inputs -> 128 outputs) that the cells below quantize.
m = torch.nn.Linear(16,128)

### FSDP

In [None]:
# 4-bit config with group size 64; zero/scale quantization and metadata offloading are all switched off.
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, offload_meta=False)
# Wrap the reference layer `m` in an HQQLinear built from that config.
hqq_linear = HQQLinear(m, quant_config=quant_config)

In [None]:
# Check which dtype the quantized layer reports as its compute dtype.
hqq_linear.compute_dtype

torch.float16

In [None]:
# Peek at the first registered parameter of the quantized layer.
# NOTE(review): the displayed values look like packed data reinterpreted as floats — confirm.
next(hqq_linear.parameters())

Parameter containing:
tensor([[-1.8690e+31, -1.7469e-07, -9.8312e-20,  4.3347e+23, -1.0372e-23,
         -5.6423e+16,  1.3304e-05,  6.1785e-24],
        [-5.7602e+10,  5.1494e+18, -1.7353e+27, -7.9082e-32,  8.7318e+06,
         -4.3186e-06,  1.4261e-18,  3.5633e+17],
        [ 2.8733e-02, -6.6121e-15,  4.6052e-22, -5.8633e+18,  1.6486e+06,
          1.2226e-18,  9.0436e+25,  5.9841e-04],
        [ 6.3572e-37,  2.1430e-10,  5.6341e-01, -5.9994e-36,  1.9233e+11,
          2.9263e-09,  3.3071e-09,  1.0180e-20],
        [-1.0810e-13,  8.8023e+08,  6.2707e+18,  1.3579e-24, -4.7377e+23,
          3.5615e+17,  2.6324e-14,  4.2122e-09],
        [ 2.4662e-25, -3.4900e+27,  9.6193e+29,  2.6624e+03,  2.2651e-29,
          3.0514e+14,  6.9221e+30,  1.6402e+19],
        [ 7.4646e+22, -9.6859e-28, -4.3350e-10,  5.1519e-34, -4.1487e-07,
         -7.7171e+37,  9.2547e+13,  8.3544e+23],
        [-1.6869e-09, -2.6847e+18, -8.0041e-29,  9.5645e-38,  1.3935e-02,
         -1.4938e-13,  1.0959e-11,  1.0414e

In [None]:
# Raw weight tensor of the reference layer, used as the ground truth for the error checks below.
w = m.weight.data

In [None]:
# Shape of the reference weight.
w.shape

torch.Size([128, 16])

In [None]:
# Quantize the weight directly through the low-level Quantizer API; yields the quantized tensor and its metadata.
W_q, meta = Quantizer.quantize(w, round_zero=True, optimize=True, view_as_float=False)

In [None]:
# Shape and dtype of the quantized tensor.
W_q.shape, W_q.dtype

(torch.Size([32, 32]), torch.uint8)

In [None]:
# Dtype of the scale stored in the quantization metadata.
meta['scale'].dtype

torch.float16

In [None]:
# Reconstruct an approximation of the weight from the quantized tensor and its metadata.
w_dq = Quantizer.dequantize(W_q, meta)

In [None]:
# Show the original and dequantized weights side by side.
w, w_dq

(tensor([[ 0.1196,  0.0683, -0.0960,  ..., -0.2410, -0.1544, -0.0864],
         [-0.0278, -0.0483,  0.1141,  ...,  0.0873,  0.0023,  0.2011],
         [ 0.0982, -0.0460,  0.0086,  ...,  0.0627, -0.0216, -0.0140],
         ...,
         [-0.0208,  0.1148, -0.0562,  ..., -0.0961,  0.2354,  0.2077],
         [ 0.1820,  0.1345, -0.0235,  ...,  0.0432, -0.1749,  0.1510],
         [-0.2125,  0.0024, -0.2045,  ..., -0.1916,  0.1080,  0.0231]]),
 tensor([[ 0.1224,  0.0717, -0.0930,  ..., -0.2524, -0.1595, -0.0937],
         [-0.0320, -0.0627,  0.1289,  ...,  0.0945,  0.0091,  0.1919],
         [ 0.0917, -0.0519,  0.0014,  ...,  0.0705, -0.0320,  0.0009],
         ...,
         [-0.0320,  0.1304, -0.0645,  ..., -0.0981,  0.2344,  0.1919],
         [ 0.1841,  0.1334, -0.0301,  ...,  0.0382, -0.1595,  0.1584],
         [-0.2222,  0.0016, -0.1934,  ..., -0.1943,  0.1057,  0.0273]],
        dtype=torch.float16))

In [None]:
# Reconstruction error of the dequantized weight, measured with a p=0.7 norm.
torch.norm(w - w_dq, p=0.7)

tensor(390.0982)

In [None]:
# Display the configuration that BaseQuantizeConfig produces for these settings.
BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, offload_meta=False)

{'weight_quant_params': {'nbits': 4,
  'channel_wise': True,
  'group_size': 64,
  'optimize': True,
  'round_zero': True},
 'scale_quant_params': None,
 'zero_quant_params': None,
 'offload_meta': False}

In [None]:
# Configurations to compare. All are 4-bit with group size 64; they differ in
# whether zero/scale are themselves quantized and whether metadata is offloaded.
quant_configs = [
    BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=quant_zero, quant_scale=quant_scale, offload_meta=offload_meta)
    for quant_zero, quant_scale, offload_meta in (
        (False, False, False),
        (True,  False, False),
        (False, True,  False),
        (True,  True,  False),
        (True,  True,  True),
        (False, False, True),
    )
]

# Dequantized weight produced by each configuration, in the same order.
w_dqs = []
for quant_cfg in quant_configs:
    # Override the group sizes used for quantizing the metadata itself.
    if quant_cfg['scale_quant_params']:
        quant_cfg['scale_quant_params'].update(group_size=8)
    if quant_cfg['zero_quant_params']:
        offloaded = bool(quant_cfg['offload_meta'])
        quant_cfg['zero_quant_params'].update(
            group_size=8 if offloaded else None,
            channel_wise=offloaded,
        )
    # Construct without initializing, select the backend, then initialize explicitly.
    mq = HQQLinear(m, quant_cfg, compute_dtype=torch.bfloat16, initialize=False)
    HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
    mq.initialize()
    print(mq.W_q.dtype, mq.meta)
    print()
    w_dqs.append(mq.dequantize_aten())

In [None]:
# Reconstruction error (p=0.7 norm) of the first five configurations.
# NOTE(review): the sweep builds six configurations; the last one is not compared here.
tuple(torch.norm(w.cuda() - w_dq, p=0.7) for w_dq in w_dqs[:5])

(tensor(390.9176, device='cuda:0'),
 tensor(390.5967, device='cuda:0'),
 tensor(390.7930, device='cuda:0'),
 tensor(390.1439, device='cuda:0'),
 tensor(392.0999, device='cuda:0'))

In [None]:
def replace_linear_hqq(model:nn.Module, quant_config, skip_modules:List[str]=["lm_head"], **kwargs):
    """
    Recursively swap the `torch.nn.Linear` children of `model` for `HQQLinear` modules.

    Parameters:
        model (`torch.nn.Module`):
            Module whose children are converted; the function recurses into every
            child that itself has children.
        quant_config (`Dict[str, Any]`):
            Quantization configuration passed to each new `HQQLinear`.
        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
            Child names (the attribute name on the direct parent) left unconverted.
        **kwargs:
            Forwarded unchanged to the `HQQLinear` constructor.

    Returns:
        The same `model` object, modified in place.
    """
    for child_name, child in model.named_children():
        # Descend first so nested linears are handled before this level.
        if any(True for _ in child.children()):
            replace_linear_hqq(child, quant_config, skip_modules, **kwargs)

        if not isinstance(child, torch.nn.Linear) or child_name in skip_modules:
            continue
        model._modules[child_name] = HQQLinear(child, quant_config, **kwargs)
    return model

In [None]:
def load_and_quantize_hqq(module:nn.Module, name:str, value:Tensor, device:torch.device=None, dtype:torch.dtype=None,
                                  skip_names:list[str]=[], is_meta_rank:bool=False, low_memory:bool=True, verbose:bool=False):
    """
    Load the checkpoint tensor `value` into the submodule of `module` addressed by `name`.

    `name` is a dotted key: everything before the last dot selects the submodule,
    the last component is the parameter/buffer name.

    * `HQQLinear` submodule: a "weight" is copied into the wrapped `linear_layer`
      (materialised on CPU) and quantized by `initialize()`. The resulting `W_q` is
      then moved to "meta" if `is_meta_rank`, else to "cpu" if `low_memory`.
      A "bias" raises `ValueError`.
    * Any other submodule: `value` is converted to `dtype` and placed on "meta"
      (`is_meta_rank`), "cpu" (`low_memory`) or `device` (neither flag set), then
      assigned as a parameter of the same class as the existing one, or as a plain
      tensor attribute when `name` is not a registered parameter (a buffer).

    Parameters:
        module: root module containing the target submodule.
        name: dotted state-dict key of the tensor.
        value: tensor to load.
        device: target device used when neither `is_meta_rank` nor `low_memory` is set.
        dtype: dtype the tensor is converted to on the non-HQQ path.
        skip_names: the tensor is skipped if any of these is a substring of `name`.
        is_meta_rank: place results on the "meta" device and suppress progress output.
        low_memory: place results on "cpu" (ignored when `is_meta_rank` is set).
        verbose: report skipped names.

    Raises:
        ValueError: when asked to load a bias into an `HQQLinear`.
    """
    def place_on_device(value):
        # Resolve the target into its own local. Assigning to `device` here would make
        # it local to this helper and leave it unbound when neither flag is set.
        if is_meta_rank:
            target = 'meta'
        elif low_memory:
            target = 'cpu'
        else:
            target = device
        return value.to(device=target, dtype=dtype)

    if any(skip_name in name for skip_name in skip_names):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    module_key, _, value_key = name.rpartition('.')
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as e:
        print(f"Module {module_key} not found:\n{e}")
        return

    start = time.time()

    # Handled outside the parameter/buffer try-block below so that a genuine
    # AttributeError raised while quantizing is not mistaken for "it's a buffer".
    if isinstance(submodule, HQQLinear):
        if value_key == "weight":
            # init meta weights as empty on cpu
            submodule.linear_layer.to_empty(device="cpu")
            # copy pretrained weights
            submodule.linear_layer.weight.data.copy_(value)
            # quantize and update metadata
            submodule.initialize()

            if is_meta_rank:
                setattr(submodule, "W_q", nn.Parameter(submodule.W_q.to("meta")))
            elif low_memory:
                setattr(submodule, "W_q", nn.Parameter(submodule.W_q.to("cpu")))
            submodule.in_gpu = False

        if value_key == "bias":
            raise ValueError("Bias not supported in HQQLinear yet!")

        end = time.time()
        if not is_meta_rank:
            print(f"Loaded HQQLinear quantized {module_key} in {end-start:.3f} seconds")
        return

    try:
        param = submodule.get_parameter(value_key)
    except AttributeError:
        # Not a registered parameter: it's a buffer.
        value = place_on_device(value)
    else:
        # Re-wrap in the existing parameter class.
        value = type(param)(place_on_device(value).data)

    setattr(submodule, value_key, value)
    end = time.time()
    torch.cuda.empty_cache()
    if not is_meta_rank:
        print(f"Loaded {module_key} and {value_key} in {end-start:.3f} seconds")

In [None]:
# Locate the checkpoint's safetensors index and resolve the list of shard files.
# `model_name` is defined here because this cell runs before the model-building
# cells that also set it; without it a fresh Restart & Run All raises NameError.
model_name = "meta-llama/Llama-2-7b-hf"
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)

In [None]:
# Dtype handed to every HQQLinear as its compute dtype.
compute_dtype = torch.bfloat16

model_name = "meta-llama/Llama-2-7b-hf"

# Load the model configuration and adjust it before building the model.
cfg = AutoConfig.from_pretrained(model_name)
cfg.use_cache = False
cfg._attn_implementation = "sdpa"
# cfg.num_hidden_layers = 8 # DEBUG

# Build the model under init_empty_weights and replace the nn.Linear layers of
# model.model with HQQLinear (initialize=False defers the actual quantization).
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(cfg)
    # TODO: Tune BaseQuantizeConfig.
    quant_config = BaseQuantizeConfig(nbits=4, 
                                      group_size=64, 
                                      quant_zero=True, 
                                      quant_scale=True, 
                                      offload_meta=True)
    # Only model.model is converted, so the top-level lm_head is not visited.
    model.model = replace_linear_hqq(model.model, quant_config, device_n=torch.cuda.current_device(),
                                    compute_dtype=compute_dtype, del_orig=True, initialize=False)     
    HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
# NOTE(review): presumably read by downstream HF/PEFT code to detect 4-bit models — confirm.
model.is_loaded_in_4bit = True

In [None]:
# Single-process settings. With rank == 0 the is_meta_rank expression below is
# False, so tensors are really loaded (and placed on CPU because low_memory is True).
local_rank = 0
low_memory = True
load_param_skip_names = []
rank = 0

print("Loading model", rank)
start = time.time()
# Load shard by shard, one tensor at a time, quantizing HQQLinear weights as they arrive.
for filename in files:
    weights = safetensors.torch.load_file(filename)
    for name, param in weights.items():
        load_and_quantize_hqq(model, name, param, dtype=torch.bfloat16, device=local_rank, skip_names=load_param_skip_names,
                                is_meta_rank=(low_memory and rank!=0), verbose=True)
print(f"Loaded model weights in {time.time()-start:.3f} seconds")

Loading model 0
Loaded model.embed_tokens and weight in 0.067 seconds
Loaded model.layers.0.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.mlp.down_proj in 0.271 seconds
Loaded HQQLinear quantized model.layers.0.mlp.gate_proj in 0.243 seconds
Loaded HQQLinear quantized model.layers.0.mlp.up_proj in 0.236 seconds
Loaded model.layers.0.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.k_proj in 0.065 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.o_proj in 0.062 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.q_proj in 0.063 seconds
Loaded model.layers.0.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.v_proj in 0.060 seconds
Loaded model.layers.1.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.1.mlp.down_proj in 0.239 seconds
Loaded HQQLinear quantized model.layers.1.mlp.gate_proj 

In [None]:
def load_and_quantize_parallel(name_param, load_func, model, **kwargs):
    """
    Adapter that lets a loader be mapped over `(name, tensor)` pairs.

    Unpacks one pair from `name_param` and forwards it as
    `load_func(model, name, tensor, **kwargs)`. Returns nothing.
    """
    key, tensor = name_param
    load_func(model, key, tensor, **kwargs)

In [None]:
# Second, identically configured model; it is filled by the thread-parallel loader below.
compute_dtype = torch.bfloat16

model_name = "meta-llama/Llama-2-7b-hf"

# Load the model configuration and adjust it before building the model.
cfg = AutoConfig.from_pretrained(model_name)
cfg.use_cache = False
cfg._attn_implementation = "sdpa"
# cfg.num_hidden_layers = 8 # DEBUG

# Build the model under init_empty_weights and replace the nn.Linear layers of
# model_fast.model with HQQLinear (initialize=False defers the actual quantization).
with init_empty_weights():
    model_fast = AutoModelForCausalLM.from_config(cfg)
    # TODO: Tune BaseQuantizeConfig.
    quant_config = BaseQuantizeConfig(nbits=4, 
                                      group_size=64, 
                                      quant_zero=True, 
                                      quant_scale=True, 
                                      offload_meta=True)
    model_fast.model = replace_linear_hqq(model_fast.model, quant_config, device_n=torch.cuda.current_device(),
                                          compute_dtype=compute_dtype, del_orig=True, initialize=False)     
    HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
# NOTE(review): presumably read by downstream HF/PEFT code to detect 4-bit models — confirm.
model_fast.is_loaded_in_4bit = True

In [None]:
# Same single-process settings as the sequential load (is_meta_rank evaluates to False).
local_rank = 0
low_memory = True
load_param_skip_names = []
rank = 0

print("Loading model", rank)
start = time.time()
# For each shard, dispatch every (name, tensor) pair to a pool of 8 threads;
# the extra keyword arguments are forwarded to load_and_quantize_parallel.
for filename in files:
    weights = safetensors.torch.load_file(filename)
    parallel(load_and_quantize_parallel, weights.items(), n_workers=8, threadpool=True, 
             load_func=load_and_quantize_hqq, model=model_fast, 
             dtype=torch.bfloat16, device=local_rank, skip_names=load_param_skip_names, 
             is_meta_rank=(low_memory and rank!=0), verbose=True)
print(f"Loaded model weights in {time.time()-start:.3f} seconds")

Loading model 0
Loaded model.layers.0.input_layernorm and weight in 0.003 seconds
Loaded model.layers.0.post_attention_layernorm and weight in 0.004 seconds
Loaded model.layers.0.self_attn.rotary_emb and inv_freq in 0.032 seconds
Loaded model.embed_tokens and weight in 0.203 seconds
Loaded model.layers.1.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.k_proj in 1.016 seconds
Loaded HQQLinear quantized model.layers.0.mlp.gate_proj in 1.065 seconds
Loaded HQQLinear quantized model.layers.0.mlp.down_proj in 1.201 seconds
Loaded model.layers.1.post_attention_layernorm and weight in 0.008 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.v_proj in 1.155 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.q_proj in 1.211 seconds
Loaded HQQLinear quantized model.layers.0.mlp.up_proj in 1.252 seconds
Loaded model.layers.1.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.o_pro

In [None]:
# Verify that the thread-parallel load produced the same parameters as the sequential one.
# Names and counts are asserted explicitly: silently skipping mismatched names (or
# letting zip truncate) would allow this check to pass without comparing anything.
params_seq = list(model.named_parameters())
params_par = list(model_fast.named_parameters())
assert len(params_seq) == len(params_par), "models expose a different number of parameters"
for (n1,p1), (n2,p2) in zip(params_seq, params_par):
    assert n1 == n2, f"parameter order mismatch: {n1} vs {n2}"
    if "proj" in n1:
        # Projection weights are compared as raw bytes, which must match exactly.
        assert torch.equal(p1.view(torch.uint8), p2.view(torch.uint8)), n1
    else:
        assert torch.allclose(p1, p2), n1

In [None]:
class HQQDORA(nn.Module):
    """
    DoRA-style adapter around a quantized base layer.

    The effective weight is `m * (W + A @ B) / norm(W + A @ B)`, where `W` is the
    dequantized base weight, `A @ B` is the low-rank update, the norm is the L2
    norm taken over dim 0 (kept as a broadcastable row), and `m` is a learnable
    magnitude initialised to that norm of `W`. Because `lora_B` starts at zero,
    a freshly built layer reproduces the dequantized base layer up to
    floating-point rounding. No bias is applied.

    Args:
        base_layer: module exposing `in_features`, `out_features`, `parameters()`
            and `dequantize_aten()` (e.g. an `HQQLinear`); optionally `compute_dtype`.
        lora_rank: rank of the low-rank update.
        lora_dropout: accepted for interface compatibility; currently unused.
    """
    def __init__(self, base_layer, lora_rank, lora_dropout):
        super().__init__()
        self.base_layer = base_layer
        # Prefer the base layer's compute dtype; otherwise use its first parameter's dtype.
        dtype = getattr(base_layer, "compute_dtype", next(base_layer.parameters()).dtype)
        device = next(base_layer.parameters()).device

        std_dev = 1 / torch.sqrt(torch.tensor(lora_rank).float())
        # A is random and B is zero, so the initial update A @ B is exactly zero.
        self.lora_A = nn.Parameter(torch.randn(base_layer.out_features, lora_rank).to(device=device,dtype=dtype)*std_dev)
        self.lora_B = nn.Parameter(torch.zeros(lora_rank, base_layer.in_features).to(device=device,dtype=dtype))

        # Magnitude vector: L2 norm of the dequantized weight over dim 0.
        self.m = nn.Parameter(self.base_layer.dequantize_aten().clone().norm(p=2, dim=0, keepdim=True))

    def forward(self, x):
        """Apply the re-normalised, magnitude-scaled weight to `x` (last dim = in_features)."""
        lora = torch.matmul(self.lora_A, self.lora_B)
        # Dequantize once per call and reuse the result.
        adapted = self.base_layer.dequantize_aten() + lora
        column_norm = adapted.norm(p=2, dim=0, keepdim=True)

        # The identities `m == column_norm` and `calc_weights ~= W` hold only at
        # initialisation, so they are not asserted here: doing so makes the forward
        # pass fail as soon as `lora_B` or `m` has been updated.
        calc_weights = self.m * (adapted / column_norm)

        return torch.matmul(x, calc_weights.t())

In [None]:
# Sanity check: a freshly initialised HQQDORA should match a plain matmul with the
# dequantized base weight. The cell's value is the fraction of output elements
# that torch.isclose reports as close.
quant_config = BaseQuantizeConfig(nbits=4, 
                                  group_size=64, 
                                  quant_zero=True, 
                                  quant_scale=True, 
                                  offload_meta=True)

base_layer = HQQLinear(nn.Linear(128,256), quant_config, compute_dtype=torch.float32)
dora = HQQDORA(base_layer, 8, 0)
x = torch.randn(2,4,128).cuda()
torch.isclose(dora(x), torch.matmul(x, base_layer.dequantize_aten().t())).float().mean()

tensor(0.9985, device='cuda:0')

In [None]:
class DoRALayer(nn.Module):
    """
    Unquantized reference implementation of a DoRA linear layer.

    The frozen `weight` (d_out x d_in) is combined with a low-rank update
    `lora_A @ lora_B`, re-normalised by its L2 norm over dim 0 and rescaled by the
    learnable magnitude `m`. Because `lora_B` starts at zero, a fresh layer
    computes the same result as `F.linear(x, weight, bias)` up to rounding.

    Args:
        d_in: input feature size.
        d_out: output feature size.
        rank: rank of the low-rank update.
        weight: optional pretrained weight of shape (d_out, d_in). When omitted,
            the weight is initialised with the scheme `nn.Linear` uses by default.
        bias: optional pretrained bias of shape (d_out,). When omitted, a zero
            bias is used so that it does not alter the output.
    """
    def __init__(self, d_in, d_out, rank=4, weight=None, bias=None):
        super().__init__()

        if weight is not None:
            self.weight = nn.Parameter(weight, requires_grad=False)
        else:
            # Give the default weight real values: uninitialised memory here would
            # also make the magnitude `m` below arbitrary (possibly inf/nan).
            self.weight = nn.Parameter(torch.empty(d_out, d_in), requires_grad=False)
            nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)

        if bias is not None:
            self.bias = nn.Parameter(bias, requires_grad=False)
        else:
            # Zero rather than uninitialised: the bias is always added in forward().
            self.bias = nn.Parameter(torch.zeros(d_out), requires_grad=False)

        # m = Magnitude column-wise across output dimension
        self.m = nn.Parameter(self.weight.norm(p=2, dim=0, keepdim=True))

        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        # A is random and B is zero, so the initial update A @ B is exactly zero.
        self.lora_A = nn.Parameter(torch.randn(d_out, rank)*std_dev)
        self.lora_B = nn.Parameter(torch.zeros(rank, d_in))

    def forward(self, x):
        """Apply the magnitude-scaled, re-normalised weight (plus bias) to `x`."""
        lora = torch.matmul(self.lora_A, self.lora_B)
        adapted = self.weight + lora
        column_norm = adapted.norm(p=2, dim=0, keepdim=True)
        norm_adapted = adapted / column_norm
        calc_weights = self.m * norm_adapted
        return F.linear(x, calc_weights, self.bias)

In [None]:
# Rebinds `m` (previously the 16 -> 128 layer) to a bias-free 128 -> 256 layer on the GPU.
m = nn.Linear(128,256,bias=False).cuda()

In [None]:
# Rebinds `dora` to the unquantized reference layer, built from m's weight.
dora = DoRALayer(128,256,weight=m.weight).cuda()

In [None]:
# Output of the freshly initialised DoRA layer; compare with m(x) in the next cell.
dora(x)

tensor([[[-0.2144, -0.1476, -0.0111,  ...,  0.3745,  0.1425, -0.1142],
         [ 0.3202, -0.2039,  0.7589,  ..., -0.2859, -1.4159,  0.9623],
         [-0.1714,  0.4437, -0.3377,  ...,  1.4839,  1.1261,  0.1933],
         [-0.5015,  0.3812,  1.3170,  ...,  0.3666,  0.0282,  0.3237]],

        [[ 0.2638,  0.0497,  0.2547,  ...,  0.5097,  0.0237,  0.8447],
         [ 0.2788, -0.1295, -0.6743,  ...,  0.1924,  1.0936,  0.3154],
         [-0.4722,  0.2377,  0.0317,  ..., -0.6017, -0.4683, -0.1920],
         [-0.4582,  0.4022, -0.5113,  ...,  0.9794,  1.3093, -0.3878]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [None]:
# Output of the plain linear layer, for comparison with dora(x) above.
m(x)

tensor([[[-0.2144, -0.1476, -0.0111,  ...,  0.3745,  0.1425, -0.1142],
         [ 0.3202, -0.2039,  0.7589,  ..., -0.2859, -1.4159,  0.9623],
         [-0.1714,  0.4437, -0.3377,  ...,  1.4839,  1.1261,  0.1933],
         [-0.5015,  0.3812,  1.3170,  ...,  0.3666,  0.0282,  0.3237]],

        [[ 0.2638,  0.0497,  0.2547,  ...,  0.5097,  0.0237,  0.8447],
         [ 0.2788, -0.1295, -0.6743,  ...,  0.1924,  1.0936,  0.3154],
         [-0.4722,  0.2377,  0.0317,  ..., -0.6017, -0.4683, -0.1920],
         [-0.4582,  0.4022, -0.5113,  ...,  0.9794,  1.3093, -0.3878]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [None]:
# Confirm the input tensor is a real tensor, not on the meta device.
x.is_meta

False

### Tests

In [None]:
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer

[36mhqq_aten package available. Set backend to HQQBackend.ATEN for faster inference and HQQBackend.ATEN_BACKPROP for faster training![0m


In [None]:
compute_dtype = torch.bfloat16
model_name = "meta-llama/Llama-2-7b-hf"

# Load the model configuration and shrink it to two decoder layers for a quick test.
cfg = AutoConfig.from_pretrained(model_name)
cfg.use_cache = False
cfg._attn_implementation = "sdpa"
cfg.num_hidden_layers = 2 # DEBUG

# Build the model directly from the config. There is no init_empty_weights here,
# so this is not a meta-device build; its nn.Linear layers are quantized in a later cell.
model = AutoModelForCausalLM.from_config(cfg)

In [None]:
# Show the module tree before quantization.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): L

In [None]:
# Quantize `model` through HQQ's Hugging Face engine (4-bit, group size 64, view_as_float=True).
# NOTE(review): presumably in place — the following cells read .meta / .W_q from the same `model`.
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, view_as_float=True)
HQQModelForCausalLM.quantize_model_(model, quant_config, compute_dtype=torch.bfloat16)

100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 144.69it/s]
100%|████████████████████████████████████████████████████████████████████████| 2/2 [00:06<00:00,  3.38s/it]


In [None]:
# Inspect the quantization metadata of one quantized projection.
model.model.layers[0].self_attn.q_proj.meta

In [None]:
# Inspect the quantized weight tensor of the same projection.
model.model.layers[0].self_attn.q_proj.W_q

In [None]:
# Write the quantized model to disk.
# NOTE(review): hard-coded absolute, user-specific path — adjust before running elsewhere.
model.save_quantized("/weka/home-keremturgutlu/models")

In [None]:
import json

# Read back the two files from the save directory: the JSON config and the weights.
# The context manager closes the file handle (a bare open() call leaves it open).
with open("/weka/home-keremturgutlu/models/config.json") as config_file:
    quantized_config = json.load(config_file)
# NOTE: torch.load unpickles the file; only use it on checkpoints you trust.
quantized_weights = torch.load("/weka/home-keremturgutlu/models/qmodel.pt")

In [None]:
# Show the saved configuration.
quantized_config

In [None]:
# List the keys stored in the saved weights file.
list(quantized_weights.keys())

In [None]:
# Inspect the saved entry for one quantized projection.
quantized_weights['model.layers.0.self_attn.q_proj']

In [None]:
# Reload the quantized model from the save directory for a round-trip comparison.
model_qt = HQQModelForCausalLM.from_quantized("/weka/home-keremturgutlu/models")

100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 1804.39it/s]
100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 364.04it/s]


In [None]:
# Names of all modules in the reloaded model (the second tuple element is the module itself).
list(n for n,p in model_qt.named_modules())

['',
 'model',
 'model.embed_tokens',
 'model.layers',
 'model.layers.0',
 'model.layers.0.self_attn',
 'model.layers.0.self_attn.q_proj',
 'model.layers.0.self_attn.k_proj',
 'model.layers.0.self_attn.v_proj',
 'model.layers.0.self_attn.o_proj',
 'model.layers.0.self_attn.rotary_emb',
 'model.layers.0.mlp',
 'model.layers.0.mlp.gate_proj',
 'model.layers.0.mlp.up_proj',
 'model.layers.0.mlp.down_proj',
 'model.layers.0.mlp.act_fn',
 'model.layers.0.input_layernorm',
 'model.layers.0.post_attention_layernorm',
 'model.layers.1',
 'model.layers.1.self_attn',
 'model.layers.1.self_attn.q_proj',
 'model.layers.1.self_attn.k_proj',
 'model.layers.1.self_attn.v_proj',
 'model.layers.1.self_attn.o_proj',
 'model.layers.1.self_attn.rotary_emb',
 'model.layers.1.mlp',
 'model.layers.1.mlp.gate_proj',
 'model.layers.1.mlp.up_proj',
 'model.layers.1.mlp.down_proj',
 'model.layers.1.mlp.act_fn',
 'model.layers.1.input_layernorm',
 'model.layers.1.post_attention_layernorm',
 'model.norm',
 'lm_hea

In [None]:
def assert_state_dict(v1,v2):
    """
    Assert that two state-dict values match.

    * Tensor: more than 99% of the elements must be close (`torch.isclose`, rtol=1e-5).
    * dict: every entry of `v1` must match the entry under the same key in `v2`,
      exactly for tensors (`torch.equal`) and with `==` for anything else.
    * Any other type is not checked.
    """
    if isinstance(v1, torch.Tensor):
        close_fraction = torch.isclose(v1, v2, rtol=1e-5).float().mean().item()
        assert close_fraction > 0.99
        return
    if not isinstance(v1, dict):
        return
    for key, left in v1.items():
        if isinstance(left, torch.Tensor):
            assert torch.equal(left, v2[key])
        else:
            assert left == v2[key]

In [None]:
# Round-trip check: for every parameter of the in-memory quantized model, compare the
# state_dict of the module that owns it with the same module in the reloaded model.
# A module with several parameters is therefore compared once per parameter.
for n,p in model.named_parameters():
    
    # Everything before the last dot is the owning module's path.
    module_key, _, value_key = n.rpartition('.')
    
    d1 = model.get_submodule(module_key).state_dict()
    d2 = model_qt.get_submodule(module_key).state_dict()
    
    # Entries are paired by position; the key assertion guards against ordering differences.
    for (k1,v1),(k2,v2) in zip(d1.items(), d2.items()):
        assert k1 == k2
        assert_state_dict(v1,v2)

In [None]:
import safetensors
from safetensors.torch import save_file
import torch

In [None]:
# Load two saved state dicts to compare (the "_init" one and the other one).
# Note that this rebinds `weights`, which earlier cells used for checkpoint shards.
# NOTE(review): hard-coded absolute, user-specific paths.
weights_init = safetensors.torch.load_file("/weka/home-keremturgutlu/models/hqq_lora_dummy_init/model_state_dict.safetensors")
weights = safetensors.torch.load_file("/weka/home-keremturgutlu/models/hqq_lora_dummy/model_state_dict.safetensors")

In [None]:
# Display the loaded state dict (large output).
weights

{'_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.down_proj.lora_AB.0.weight': tensor([[-9.1553e-03,  6.0120e-03, -1.9379e-03,  ..., -7.8201e-04,
          -6.0120e-03,  7.2861e-04],
         [ 1.8616e-03,  8.5449e-03,  6.9275e-03,  ..., -1.3885e-03,
           7.6599e-03,  3.2043e-03],
         [ 7.6599e-03,  3.3417e-03,  4.3030e-03,  ...,  4.6082e-03,
          -5.3711e-03, -1.1139e-03],
         ...,
         [-4.0894e-03, -4.3945e-03,  8.1787e-03,  ...,  5.4321e-03,
          -8.4839e-03, -8.4839e-03],
         [-6.6757e-05,  3.9368e-03,  6.0272e-04,  ..., -5.1270e-03,
          -4.8218e-03, -5.3711e-03],
         [ 4.9744e-03,  1.6556e-03, -1.5640e-03,  ...,  4.1504e-03,
           7.7515e-03,  6.8359e-03]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.down_proj.lora_AB.1.weight': tensor([[-6.2943e-05,  7.9155e-05, -7.9632e-05,  ...,

In [None]:
# Report every tensor that differs between the two state dicts.
for k, v in weights_init.items():
    # Tensors whose key mentions base_layer or W_q are compared as raw bytes;
    # everything else is compared value-for-value.
    as_bytes = ('base_layer' in k) or ('W_q' in k)
    unchanged = (
        torch.equal(v.view(torch.uint8), weights[k].view(torch.uint8))
        if as_bytes
        else torch.equal(v, weights[k])
    )
    if not unchanged:
        print("Changed", k)

Changed model.layers.0.mlp.down_proj.lora_AB.0.weight
Changed model.layers.0.mlp.down_proj.lora_AB.1.weight
Changed model.layers.0.mlp.gate_proj.lora_AB.0.weight
Changed model.layers.0.mlp.gate_proj.lora_AB.1.weight
Changed model.layers.0.mlp.up_proj.lora_AB.0.weight
Changed model.layers.0.mlp.up_proj.lora_AB.1.weight
Changed model.layers.0.self_attn.k_proj.lora_AB.0.weight
Changed model.layers.0.self_attn.k_proj.lora_AB.1.weight
Changed model.layers.0.self_attn.q_proj.lora_AB.0.weight
Changed model.layers.0.self_attn.q_proj.lora_AB.1.weight
Changed model.layers.0.self_attn.v_proj.lora_AB.0.weight
Changed model.layers.0.self_attn.v_proj.lora_AB.1.weight
Changed model.layers.1.mlp.down_proj.lora_AB.0.weight
Changed model.layers.1.mlp.down_proj.lora_AB.1.weight
Changed model.layers.1.mlp.gate_proj.lora_AB.0.weight
Changed model.layers.1.mlp.gate_proj.lora_AB.1.weight
Changed model.layers.1.mlp.up_proj.lora_AB.0.weight
Changed model.layers.1.mlp.up_proj.lora_AB.1.weight
Changed model.laye