### BNB with Tensor Parallelism

In [1]:
import torch
import copy
from bitsandbytes.nn.modules import Params4bit, Linear4bit
from bitsandbytes.functional import dequantize_4bit
import bitsandbytes as bnb
from bitsandbytes.functional import dequantize_4bit, QuantState

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

In [3]:
blocksize = 64
quant_type = "nf4"
quant_storage = torch.uint8

data = torch.randn(128,256).to(torch.bfloat16)
param = Params4bit(data, blocksize=blocksize, quant_type=quant_type, 
                   quant_storage=quant_storage, compress_statistics=False)

In [4]:
if quant_storage == torch.uint8:
    pack_factor = 2

In [5]:
param

Parameter containing:
Parameter(Params4bit([[-0.1260, -1.6250,  0.0508,  ...,  0.3379, -1.7500,  1.1484],
            [-0.2305, -1.3047, -2.0312,  ..., -0.1328, -0.3125,  1.8594],
            [-1.1875,  0.9766,  0.0840,  ..., -2.2812,  0.6016,  1.1328],
            ...,
            [ 1.1875,  1.0234,  1.3281,  ...,  1.5781, -0.7227,  0.5156],
            [ 0.0610, -0.3477,  1.2578,  ..., -1.7109,  1.0781,  0.6914],
            [ 0.2002, -0.7422,  0.0527,  ...,  1.1953, -0.0952,  0.0854]],
           dtype=torch.bfloat16))

In [6]:
param.cuda();

In [7]:
param.shape, param.quant_type, param.quant_storage, param.blocksize

(torch.Size([16384, 1]), 'nf4', torch.uint8, 64)

In [8]:
param.shape

torch.Size([16384, 1])

In [9]:
param.numel() / data.numel()

0.5

In [10]:
param.quant_state.as_dict()

{'quant_type': 'nf4',
 'absmax': tensor([2.6406, 2.4062, 3.7812, 2.2344, 2.7500, 3.2500, 2.2500, 2.8594, 2.7656,
         2.4688, 3.2969, 3.4844, 2.7812, 2.2188, 2.9375, 3.2812, 2.1094, 2.8125,
         2.2344, 3.2344, 3.2188, 2.8125, 2.3906, 2.2812, 2.7031, 3.5938, 2.5156,
         2.9219, 2.1875, 3.0469, 2.5781, 2.5781, 2.2188, 2.8125, 2.3281, 2.6094,
         2.9219, 2.7344, 2.1719, 3.4375, 2.5000, 3.0625, 2.4688, 2.5625, 2.9531,
         2.8125, 2.7969, 2.2500, 3.6094, 2.4062, 2.0781, 2.4531, 3.3281, 2.6250,
         3.1250, 2.7969, 2.8594, 2.4844, 2.8906, 2.5000, 2.1719, 3.1406, 2.6094,
         2.4375, 3.0000, 2.2188, 3.4062, 3.9688, 2.6250, 2.5938, 2.2500, 3.0000,
         2.7188, 2.5625, 2.5156, 3.3750, 2.5469, 2.8125, 3.0781, 2.4688, 2.5781,
         3.3281, 2.5781, 1.7109, 2.9375, 2.0781, 2.2031, 2.3750, 2.0312, 2.3750,
         2.4688, 2.6406, 2.2812, 2.2344, 3.2031, 2.5312, 1.8203, 2.3438, 2.0625,
         2.7500, 2.5312, 2.6094, 2.4844, 2.7188, 2.6562, 2.1094, 1.8594, 2.46

In [11]:
input_size_per_partition = 64
output_size_per_partition = 64

In [12]:
# row-major quantization, reshape for vllm tensor parallelism
qweight = param.data.reshape(data.size(0), data.size(1) // pack_factor); qweight.shape

torch.Size([128, 128])

In [13]:
qweight

tensor([[ 97, 119,  55,  ...,  65, 201,  29],
        [ 98,  24,  97,  ..., 156, 134, 110],
        [ 59, 112,  71,  ...,  58,  81, 155],
        ...,
        [203, 201, 227,  ...,  20, 173,  73],
        [117, 217, 181,  ...,  65,  97, 202],
        [132, 122, 134,  ...,  38,  44, 119]], device='cuda:0',
       dtype=torch.uint8)

In [14]:
qweight.view(-1,1)

tensor([[ 97],
        [119],
        [ 55],
        ...,
        [ 38],
        [ 44],
        [119]], device='cuda:0', dtype=torch.uint8)

In [15]:
deqweight = dequantize_4bit(qweight.view(-1,1), param.quant_state, blocksize=blocksize)

In [16]:
(data - torch.randn_like(data)).norm()

tensor(254., dtype=torch.bfloat16)

In [17]:
(data - deqweight.cpu()).norm()

tensor(16.5000, dtype=torch.bfloat16)

In [18]:
x = torch.randn(4, 128).cuda().to(torch.bfloat16)

In [19]:
param.quant_state.as_dict()

{'quant_type': 'nf4',
 'absmax': tensor([2.6406, 2.4062, 3.7812, 2.2344, 2.7500, 3.2500, 2.2500, 2.8594, 2.7656,
         2.4688, 3.2969, 3.4844, 2.7812, 2.2188, 2.9375, 3.2812, 2.1094, 2.8125,
         2.2344, 3.2344, 3.2188, 2.8125, 2.3906, 2.2812, 2.7031, 3.5938, 2.5156,
         2.9219, 2.1875, 3.0469, 2.5781, 2.5781, 2.2188, 2.8125, 2.3281, 2.6094,
         2.9219, 2.7344, 2.1719, 3.4375, 2.5000, 3.0625, 2.4688, 2.5625, 2.9531,
         2.8125, 2.7969, 2.2500, 3.6094, 2.4062, 2.0781, 2.4531, 3.3281, 2.6250,
         3.1250, 2.7969, 2.8594, 2.4844, 2.8906, 2.5000, 2.1719, 3.1406, 2.6094,
         2.4375, 3.0000, 2.2188, 3.4062, 3.9688, 2.6250, 2.5938, 2.2500, 3.0000,
         2.7188, 2.5625, 2.5156, 3.3750, 2.5469, 2.8125, 3.0781, 2.4688, 2.5781,
         3.3281, 2.5781, 1.7109, 2.9375, 2.0781, 2.2031, 2.3750, 2.0312, 2.3750,
         2.4688, 2.6406, 2.2812, 2.2344, 3.2031, 2.5312, 1.8203, 2.3438, 2.0625,
         2.7500, 2.5312, 2.6094, 2.4844, 2.7188, 2.6562, 2.1094, 1.8594, 2.46

In [29]:
input_size, output_size = data.size()

In [34]:
input_size, output_size

(128, 256)

### Column Parallel

The linear layer is defined as Y = XA + b. A is parallelized along its second dimension as A = [A_1, ..., A_p].

In [37]:
num_partitions = 2

In [20]:
qweight.shape

torch.Size([128, 128])

In [21]:
output_size_per_partition

64

In [22]:
qweight_partitioned = qweight.split(output_size_per_partition, dim=1)

In [23]:
len(qweight_partitioned)

2

In [24]:
for w in qweight_partitioned: print(w.shape)

torch.Size([128, 64])
torch.Size([128, 64])


In [25]:
orig_absmax = param.quant_state.absmax

In [26]:
orig_absmax.shape

torch.Size([512])

In [31]:
orig_absmax_reshaped = orig_absmax.reshape(input_size, data.size(1) // blocksize)

In [33]:
orig_absmax_reshaped.dtype, orig_absmax_reshaped.shape

(torch.float32, torch.Size([128, 4]))

In [91]:
num_partitions = len(qweight_partitioned)

In [38]:
absmax_partitioned = orig_absmax_reshaped.split(orig_absmax_reshaped.size(1) // num_partitions, dim=1)

In [39]:
for a in absmax_partitioned: print(a.shape)

torch.Size([128, 2])
torch.Size([128, 2])


In [40]:
len(qweight_partitioned), len(absmax_partitioned)

(2, 2)

In [41]:
quant_state = copy.deepcopy(param.quant_state)

In [42]:
quant_state.shape = torch.Size([quant_state.shape[0], quant_state.shape[1]//num_partitions])

In [43]:
quant_state.shape

torch.Size([128, 128])

In [44]:
quant_state.absmax = absmax_partitioned[0].contiguous().view(-1)
deqweight_part1 = dequantize_4bit(qweight_partitioned[0].contiguous().view(-1,1), quant_state=quant_state)

quant_state.absmax = absmax_partitioned[1].contiguous().view(-1)
deqweight_part2 = dequantize_4bit(qweight_partitioned[1].contiguous().view(-1,1), quant_state=quant_state)

In [45]:
deqweight_part1.shape, deqweight_part2.shape

(torch.Size([128, 128]), torch.Size([128, 128]))

In [46]:
# quant_state.as_dict()

In [47]:
deqweight.shape

torch.Size([128, 256])

In [48]:
torch.cat([deqweight_part1, deqweight_part2], dim=1).shape

torch.Size([128, 256])

In [49]:
assert torch.equal(deqweight, torch.cat([deqweight_part1, deqweight_part2], dim=1))

In [50]:
out1 = (x @ deqweight_part2)

In [51]:
out2 = bnb.matmul_4bit(x, qweight_partitioned[1].contiguous().view(-1,1), quant_state=quant_state)

In [52]:
assert torch.equal(out1, out2)

In [53]:
out2.shape

torch.Size([4, 128])

### Row Parallel

The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:

```
    -   -
    | A_1 |
    | .   |
A = | .   |        X = [X_1, ..., X_p]
    | .   |
    | A_p |
    -   -
```

In [54]:
qweight_partitioned = qweight.split(output_size_per_partition, dim=0)

In [55]:
num_partitions = len(qweight_partitioned); num_partitions

2

In [56]:
for w in qweight_partitioned: print(w.shape)

torch.Size([64, 128])
torch.Size([64, 128])


In [57]:
orig_absmax = param.quant_state.absmax

In [58]:
orig_absmax_reshaped = orig_absmax.reshape(input_size, data.size(1) // blocksize)

In [59]:
absmax_partitioned = orig_absmax.split(len(orig_absmax) // num_partitions, dim=0)

In [60]:
len(absmax_partitioned)

2

In [61]:
quant_state = copy.deepcopy(param.quant_state)

In [62]:
quant_state.shape = torch.Size([quant_state.shape[0]//num_partitions, quant_state.shape[1]]); quant_state.shape

torch.Size([64, 256])

In [63]:
quant_state.absmax = absmax_partitioned[0].contiguous().view(-1)
deqweight_part1 = dequantize_4bit(qweight_partitioned[0].contiguous().view(-1,1), quant_state=quant_state)

quant_state.absmax = absmax_partitioned[1].contiguous().view(-1)
deqweight_part2 = dequantize_4bit(qweight_partitioned[1].contiguous().view(-1,1), quant_state=quant_state)

In [64]:
assert torch.equal(deqweight, torch.cat([deqweight_part1, deqweight_part2], dim=0))

### Loading

In [3]:
from vllm.model_executor.weight_utils import default_weight_loader, hf_model_weights_iterator

INFO 03-29 13:23:10 pynccl_utils.py:13] vLLM is using nccl==2.18.1


In [2]:
weights_iterator = hf_model_weights_iterator("meta-llama/Llama-2-7b-hf")

In [3]:
for name, loaded_weight in weights_iterator: break

INFO 03-29 08:48:45 weight_utils.py:177] Using model weights format ['*.safetensors']


model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

In [5]:
name, loaded_weight

('model.embed_tokens.weight',
 tensor([[ 1.2517e-06, -1.7881e-06, -4.3511e-06,  ...,  8.9407e-07,
          -6.5565e-06,  8.9407e-07],
         [ 1.8616e-03, -3.3722e-03,  3.9864e-04,  ..., -8.3008e-03,
           2.5787e-03, -3.9368e-03],
         [ 1.0986e-02,  9.8877e-03, -5.0964e-03,  ...,  2.5177e-03,
           7.7057e-04, -5.0049e-03],
         ...,
         [-1.3977e-02, -2.7313e-03, -1.9897e-02,  ..., -1.0437e-02,
           9.5825e-03, -1.8005e-03],
         [-1.0742e-02,  9.3384e-03,  1.2939e-02,  ..., -3.3203e-02,
          -1.6357e-02,  3.3875e-03],
         [-8.3008e-03, -4.0588e-03, -1.1063e-03,  ...,  3.4790e-03,
          -1.2939e-02,  3.1948e-05]], dtype=torch.float16))

In [29]:
weights_iterator = hf_model_weights_iterator("TheBloke/CodeUp-Alpha-13B-HF-AWQ")

In [32]:
for name, loaded_weight in weights_iterator: 
    if 'scales' in name or 'zeros' in name:
        print(name)

model.layers.0.mlp.down_proj.qzeros
model.layers.0.mlp.down_proj.scales
model.layers.0.mlp.gate_proj.qzeros
model.layers.0.mlp.gate_proj.scales
model.layers.0.mlp.up_proj.qzeros
model.layers.0.mlp.up_proj.scales
model.layers.0.self_attn.k_proj.qzeros
model.layers.0.self_attn.k_proj.scales
model.layers.0.self_attn.o_proj.qzeros
model.layers.0.self_attn.o_proj.scales
model.layers.0.self_attn.q_proj.qzeros
model.layers.0.self_attn.q_proj.scales
model.layers.0.self_attn.v_proj.qzeros
model.layers.0.self_attn.v_proj.scales
model.layers.1.mlp.down_proj.qzeros
model.layers.1.mlp.down_proj.scales
model.layers.1.mlp.gate_proj.qzeros
model.layers.1.mlp.gate_proj.scales
model.layers.1.mlp.up_proj.qzeros
model.layers.1.mlp.up_proj.scales
model.layers.1.self_attn.k_proj.qzeros
model.layers.1.self_attn.k_proj.scales
model.layers.1.self_attn.o_proj.qzeros
model.layers.1.self_attn.o_proj.scales
model.layers.1.self_attn.q_proj.qzeros
model.layers.1.self_attn.q_proj.scales
model.layers.1.self_attn.v_pro

### Create Quantized Model Files

In [29]:
from pathlib import Path
import os, json
from safetensors.torch import save_file

In [19]:
model_dir = Path("/home/ubuntu/models/llama-7b-hf-nf4-quantized")
os.makedirs(model_dir, exist_ok=True)

In [4]:
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# original quantized layers from fsdp_qlora/train.py
# ["k_proj", "q_proj", "v_proj", "up_proj", "down_proj", "gate_proj"]

In [5]:
# Similar to AWQ for now
quantized_layers = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

In [8]:
quantized_state_dict = copy.deepcopy(model.state_dict())

In [9]:
pack_factor = 2
blocksize = 64

In [12]:
for n,p in model.state_dict().items():
    if any(l in n for l in quantized_layers) and "weight" in n:
        # output_size x input_size
        print(n, p.shape, p.t().shape)
        param = Params4bit(p.t(), quant_type="nf4", blocksize=blocksize, compress_statistics=False, quant_storage=torch.uint8)
        input_size, output_size = p.t().shape
        param.cuda();

        # reshape for tensor parallelism
        qweight, absmax = param.data.cpu(), param.quant_state.absmax.cpu()        
        qweight = qweight.reshape(input_size, output_size // pack_factor)
        absmax = absmax.reshape(input_size, output_size // blocksize)
                
        quantized_state_dict[n] = qweight
        quantized_state_dict[n.replace(".weight", ".absmax")] = absmax

        param = None
        torch.cuda.empty_cache()

model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.0.self_attn.k_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.0.self_attn.v_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.0.mlp.gate_proj.weight torch.Size([11008, 4096]) torch.Size([4096, 11008])
model.layers.0.mlp.up_proj.weight torch.Size([11008, 4096]) torch.Size([4096, 11008])
model.layers.0.mlp.down_proj.weight torch.Size([4096, 11008]) torch.Size([11008, 4096])
model.layers.1.self_attn.q_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.1.self_attn.k_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.1.self_attn.v_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.1.self_attn.o_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.1.mlp.gate_pr

In [20]:
# save quantized weights
save_file(quantized_state_dict, model_dir/"model_state_dict.safetensors")

In [22]:
# create and save quantization config
quant_config_filename = model_dir/"quantize_config.json"

In [28]:
quant_config_dict = {
    "weight_bits" : 4,
    "blocksize" : 64,
    "quant_type" : "nf4",
    "quant_storage" : "uint8",
    "compress_statistics" : False
}

In [30]:
with open(quant_config_filename, "w+") as f: json.dump(quant_config_dict, f)

In [31]:
model_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

In [33]:
# save model config
model_config_filename = model_dir/"config.json"

In [37]:
with open(model_config_filename, "w+") as f: json.dump(model_config.to_dict(), f)

### BNB Quantized VLLM

In [52]:
import torch
import safetensors
import safetensors.torch
from pathlib import Path
import bitsandbytes as bnb
from bitsandbytes.functional import dequantize_4bit, QuantState

In [1]:
from vllm import LLM, SamplingParams

In [2]:
model_dir = "/home/ubuntu/models/llama-7b-hf-nf4-quantized"

In [3]:
llm = LLM(model=model_dir, tokenizer="meta-llama/Llama-2-7b-hf", dtype="bfloat16", quantization="bnb")

INFO 04-02 11:37:13 llm_engine.py:70] Initializing an LLM engine (v0.3.3) with config: model='/home/ubuntu/models/llama-7b-hf-nf4-quantized', tokenizer='meta-llama/Llama-2-7b-hf', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=True, quantization=bnb, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 04-02 11:37:14 pynccl_utils.py:13] vLLM is using nccl==2.18.1
INFO 04-02 11:37:15 selector.py:44] flash_attn is not found.
INFO 04-02 11:37:15 selector.py:20] Using XFormers backend.
INFO 04-02 11:37:18 model_runner.py:104] Loading model weights took 3.9585 GB
input_ids: torch.Size([4096])
hidden_states: torch.Size([4096, 4096])
INFO 04-02 11:37:20 gpu_executor.py:94] # GPU blocks: 2118, # CPU blocks: 512
INFO 04-02 11:37:22 model_runner.py:770] Capturing the model for CUDA graphs. This may lead to unex

In [4]:
outputs = llm.generate(['I'], sampling_params=SamplingParams(max_tokens=16, temperature=0.0))

Processed prompts:   0%|                                                                   | 0/1 [00:00<?, ?it/s]

input_ids: torch.Size([2])
hidden_states: torch.Size([2, 4096])


Processed prompts: 100%|███████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.93it/s]


In [5]:
outputs[0].outputs[0].text

"'ayay Vic Vic Vic Vic Vic Vic South South South South South South South"

In [14]:
llm.llm_engine.model_executor.model_config.quantization

'bnb'

In [35]:
model = llm.llm_engine.model_executor.driver_worker.model_runner.model

In [27]:
quantized_state_dict = safetensors.torch.load_file(Path(model_dir)/"model_state_dict.safetensors")

In [89]:
quant_map = torch.tensor([-1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911,  0.0000,
                           0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5626,  0.7230,  1.0000]).to(torch.cuda.current_device())

In [104]:
# Sanity Check if weights are loaded correctly into vllm model.

for n, p in model.named_parameters():

    print("Cheking:", n)
    
    if 'qkv_proj' in n:
        # Loaded qkv
        qkv_weight = model.get_parameter(n)
        qkv_absmax = model.get_parameter(n.replace(".weight", ".absmax"))
        
        quant_state = QuantState(qkv_absmax.contiguous().view(-1), dtype=torch.bfloat16)
        quant_state.shape = torch.Size([qkv_weight.shape[0], qkv_weight.shape[1] * 2])
        quant_state.blocksize = 64
        quant_state.quant_type = "nf4"
        quant_state.code = quant_map
        W_dq = bnb.functional.dequantize_4bit(qkv_weight.contiguous().view(-1,1), quant_state=quant_state)

        # Saved q proj
        q_proj_weight_name = n.replace("qkv_proj", "q_proj")
        q_proj_absmax_name = n.replace("qkv_proj", "q_proj").replace(".weight", ".absmax")
        quant_state = QuantState(quantized_state_dict[q_proj_absmax_name].cuda().contiguous().view(-1), dtype=torch.bfloat16)
        quant_state.shape = torch.Size([qkv_weight.shape[0], qkv_weight.shape[1] * 2 // 3])
        quant_state.blocksize = 64
        quant_state.quant_type = "nf4"
        quant_state.code = quant_map
        W_q_proj_dq = bnb.functional.dequantize_4bit(quantized_state_dict[q_proj_weight_name].contiguous().view(-1,1).cuda(), quant_state=quant_state)

        # Saved k proj
        k_proj_weight_name = n.replace("qkv_proj", "k_proj")
        k_proj_absmax_name = n.replace("qkv_proj", "k_proj").replace(".weight", ".absmax")
        quant_state = QuantState(quantized_state_dict[k_proj_absmax_name].cuda().contiguous().view(-1), dtype=torch.bfloat16)
        quant_state.shape = torch.Size([qkv_weight.shape[0], qkv_weight.shape[1] * 2 // 3])
        quant_state.blocksize = 64
        quant_state.quant_type = "nf4"
        quant_state.code = quant_map
        W_k_proj_dq = bnb.functional.dequantize_4bit(quantized_state_dict[k_proj_weight_name].contiguous().view(-1,1).cuda(), quant_state=quant_state)

        # Saved v proj
        v_proj_weight_name = n.replace("qkv_proj", "v_proj")
        v_proj_absmax_name = n.replace("qkv_proj", "v_proj").replace(".weight", ".absmax")
        quant_state = QuantState(quantized_state_dict[v_proj_absmax_name].cuda().contiguous().view(-1), dtype=torch.bfloat16)
        quant_state.shape = torch.Size([qkv_weight.shape[0], qkv_weight.shape[1] * 2 // 3])
        quant_state.blocksize = 64
        quant_state.quant_type = "nf4"
        quant_state.code = quant_map
        W_v_proj_dq = bnb.functional.dequantize_4bit(quantized_state_dict[v_proj_weight_name].contiguous().view(-1,1).cuda(), quant_state=quant_state)
        
        assert torch.equal(W_dq, torch.cat([W_q_proj_dq, W_k_proj_dq, W_v_proj_dq], dim=1))
    
    elif 'gate_up_proj' in n:
        # Loaded gate_up
        gate_up_weight = model.get_parameter(n)
        gate_up_absmax = model.get_parameter(n.replace(".weight", ".absmax"))
        
        quant_state = QuantState(gate_up_absmax.contiguous().view(-1), dtype=torch.bfloat16)
        quant_state.shape = torch.Size([gate_up_weight.shape[0], gate_up_weight.shape[1] * 2])
        quant_state.blocksize = 64
        quant_state.quant_type = "nf4"
        quant_state.code = quant_map
        W_dq = bnb.functional.dequantize_4bit(gate_up_weight.contiguous().view(-1,1), quant_state=quant_state)

        # Saved gate_proj
        gate_proj_weight_name = n.replace("gate_up_proj", "gate_proj")
        gate_proj_absmax_name = n.replace("gate_up_proj", "gate_proj").replace(".weight", ".absmax")
        quant_state = QuantState(quantized_state_dict[gate_proj_absmax_name].cuda().contiguous().view(-1), dtype=torch.bfloat16)
        quant_state.shape = torch.Size([gate_up_weight.shape[0], gate_up_weight.shape[1] * 2 // 2])
        quant_state.blocksize = 64
        quant_state.quant_type = "nf4"
        quant_state.code = quant_map
        W_gate_proj_dq = bnb.functional.dequantize_4bit(quantized_state_dict[gate_proj_weight_name].contiguous().view(-1,1).cuda(), quant_state=quant_state)

        # Saved up_proj
        up_proj_weight_name = n.replace("gate_up_proj", "up_proj")
        up_proj_absmax_name = n.replace("gate_up_proj", "up_proj").replace(".weight", ".absmax")
        quant_state = QuantState(quantized_state_dict[up_proj_absmax_name].cuda().contiguous().view(-1), dtype=torch.bfloat16)
        quant_state.shape = torch.Size([gate_up_weight.shape[0], gate_up_weight.shape[1] * 2 // 2])
        quant_state.blocksize = 64
        quant_state.quant_type = "nf4"
        quant_state.code = quant_map
        W_up_proj_dq = bnb.functional.dequantize_4bit(quantized_state_dict[up_proj_weight_name].contiguous().view(-1,1).cuda(), quant_state=quant_state)
        
        assert torch.equal(W_dq, torch.cat([W_up_proj_dq, W_up_proj_dq], dim=1))
    
    
    else:
        assert torch.equal(quantized_state_dict[n].data, p.data.cpu())

    print(p.view(-1)[:10])

Cheking: model.embed_tokens.weight
tensor([ 1.2517e-06, -1.7881e-06, -4.3511e-06,  8.0466e-06,  1.9073e-06,
        -5.6028e-06,  3.0994e-06,  1.1921e-06, -6.7949e-06, -1.6689e-06],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
Cheking: model.layers.0.self_attn.qkv_proj.weight
tensor([ 77,  36,  56, 101, 133, 154, 105,  85, 117,  40], device='cuda:0',
       dtype=torch.uint8)
Cheking: model.layers.0.self_attn.qkv_proj.absmax


AssertionError: 