### BNB with Tensor Parallelism

In [1]:
import torch
import copy
from bitsandbytes.nn.modules import Params4bit, Linear4bit
from bitsandbytes.functional import dequantize_4bit
import bitsandbytes as bnb
from bitsandbytes.functional import dequantize_4bit, QuantState

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

In [7]:
# Quantization settings reused by the cells below.
blocksize = 64
quant_type = "nf4"
quant_storage = torch.uint8

# Random bf16 matrix (128 x 256) standing in for a linear-layer weight.
data = torch.randn(128, 256).to(torch.bfloat16)
param = Params4bit(
    data,
    blocksize=blocksize,
    quant_type=quant_type,
    quant_storage=quant_storage,
    compress_statistics=False,
)

In [8]:
# Number of 4-bit values packed into one element of the storage dtype.
# Every byte holds two 4-bit codes, so a uint8 element packs 2 values.
# Deriving it from the element size keeps pack_factor defined for any
# quant_storage dtype instead of only torch.uint8 (which previously left
# the name undefined and caused a NameError in the reshape cell).
pack_factor = 2 * torch.empty((), dtype=quant_storage).element_size()

In [9]:
param

Parameter containing:
Parameter(Params4bit([[ 1.3125,  1.9609, -0.0356,  ...,  0.2197,  0.0903,  0.2695],
            [-2.0312, -0.8555,  0.9688,  ...,  0.4043,  0.6953,  0.2500],
            [ 0.2275,  0.7500,  0.1885,  ..., -1.0391,  1.3516,  1.3516],
            ...,
            [ 0.5742,  0.3477, -0.2676,  ...,  0.8906,  0.3848,  0.6484],
            [ 2.3281, -1.3125,  0.9180,  ..., -1.9609,  0.4219, -0.3652],
            [ 1.2969,  0.2656, -0.5664,  ...,  0.8164,  0.6562, -1.2109]],
           dtype=torch.bfloat16))

In [10]:
# Move the parameter to the GPU. Before this call `param` displays as the
# bf16 values; afterwards it reports a flat (16384, 1) uint8 tensor plus a
# populated quant_state (see the outputs of the following cells).
param.cuda();

In [11]:
param.shape, param.quant_type, param.quant_storage, param.blocksize

(torch.Size([16384, 1]), 'nf4', torch.uint8, 64)

In [12]:
param.numel() / data.numel()

0.5

In [13]:
param.quant_state.as_dict()

{'quant_type': 'nf4',
 'absmax': tensor([2.5156, 2.8906, 2.3906, 2.8438, 2.5625, 3.0625, 1.9219, 2.3906, 2.8594,
         2.9531, 2.5469, 2.2812, 2.7812, 3.7969, 3.0625, 2.4062, 2.2500, 2.8125,
         2.0781, 2.7656, 2.8281, 2.7500, 3.0625, 3.2969, 3.2656, 1.9922, 2.3750,
         2.8594, 2.2344, 2.4844, 3.3125, 3.0625, 1.8828, 2.4688, 2.2031, 2.7344,
         2.8594, 2.7656, 2.2344, 2.9844, 2.1719, 2.9062, 2.2188, 1.8281, 2.6406,
         2.1094, 3.4688, 3.0000, 3.3750, 2.6250, 2.6719, 3.2969, 2.4688, 2.5312,
         2.1562, 2.7969, 2.2031, 2.6719, 3.0781, 2.1875, 2.1406, 2.7969, 2.3594,
         2.2969, 2.3750, 2.7188, 2.9844, 2.7656, 2.9688, 3.0781, 2.4219, 2.7969,
         2.8594, 3.1875, 3.2344, 3.2031, 2.6250, 2.2344, 2.1250, 2.3750, 2.3750,
         2.3594, 2.1875, 3.0000, 2.1562, 2.8594, 2.0156, 2.6250, 2.8906, 2.1562,
         2.2812, 2.4375, 2.6250, 3.1562, 2.2031, 2.3750, 2.2344, 2.3594, 2.5000,
         3.1875, 2.3281, 3.0625, 2.7344, 3.1719, 2.2031, 3.0781, 2.5156, 2.23

In [14]:
# Chunk sizes for the tensor-parallel experiments; output_size_per_partition
# is passed to qweight.split().
# NOTE(review): qweight's second dimension is packed (two 4-bit values per
# uint8), so a chunk of 64 along dim=1 corresponds to 128 dequantized
# columns — confirm this is the intended meaning of "output size".
input_size_per_partition = 64
output_size_per_partition = 64

In [15]:
# The packed data is stored flat; restore the row-major 2-D layout of `data`
# (with packed columns) so it can be sliced for vLLM tensor parallelism.
qweight = param.data.reshape(data.shape[0], data.shape[1] // pack_factor)
qweight.shape

torch.Size([128, 128])

In [16]:
qweight

tensor([[222, 115, 183,  ..., 106, 120, 120],
        [ 20, 185, 205,  ..., 217,  25, 168],
        [138, 132,  81,  ..., 160,  99, 221],
        ...,
        [152, 103, 182,  ..., 165,  43, 154],
        [226, 181, 167,  ..., 133, 145, 149],
        [200,  82, 216,  ..., 183, 154, 163]], device='cuda:0',
       dtype=torch.uint8)

In [17]:
qweight.view(-1,1)

tensor([[222],
        [115],
        [183],
        ...,
        [183],
        [154],
        [163]], device='cuda:0', dtype=torch.uint8)

In [18]:
# Round trip: flatten the 2-D packed weight back to the (-1, 1) layout that
# `param` reported after quantization and dequantize it with the original
# quant_state. This is the reference result for the sharded experiments.
deqweight = dequantize_4bit(qweight.view(-1,1), param.quant_state, blocksize=blocksize)

In [19]:
# Baseline for scale: distance between `data` and an unrelated random matrix.
(data - torch.randn_like(data)).norm()

tensor(256., dtype=torch.bfloat16)

In [20]:
# Quantization error of the round trip; it should be far below the
# random-matrix baseline computed in the previous cell.
(data - deqweight.cpu()).norm()

tensor(16.7500, dtype=torch.bfloat16)

In [21]:
# Activation batch: 4 rows with 128 features, matching the first dimension of
# `data` (128 x 256) so that x @ W is well defined.
x = torch.randn(4, 128).cuda().to(torch.bfloat16)

In [25]:
param.quant_state.as_dict()

{'quant_type': 'nf4',
 'absmax': tensor([2.5156, 2.8906, 2.3906, 2.8438, 2.5625, 3.0625, 1.9219, 2.3906, 2.8594,
         2.9531, 2.5469, 2.2812, 2.7812, 3.7969, 3.0625, 2.4062, 2.2500, 2.8125,
         2.0781, 2.7656, 2.8281, 2.7500, 3.0625, 3.2969, 3.2656, 1.9922, 2.3750,
         2.8594, 2.2344, 2.4844, 3.3125, 3.0625, 1.8828, 2.4688, 2.2031, 2.7344,
         2.8594, 2.7656, 2.2344, 2.9844, 2.1719, 2.9062, 2.2188, 1.8281, 2.6406,
         2.1094, 3.4688, 3.0000, 3.3750, 2.6250, 2.6719, 3.2969, 2.4688, 2.5312,
         2.1562, 2.7969, 2.2031, 2.6719, 3.0781, 2.1875, 2.1406, 2.7969, 2.3594,
         2.2969, 2.3750, 2.7188, 2.9844, 2.7656, 2.9688, 3.0781, 2.4219, 2.7969,
         2.8594, 3.1875, 3.2344, 3.2031, 2.6250, 2.2344, 2.1250, 2.3750, 2.3750,
         2.3594, 2.1875, 3.0000, 2.1562, 2.8594, 2.0156, 2.6250, 2.8906, 2.1562,
         2.2812, 2.4375, 2.6250, 3.1562, 2.2031, 2.3750, 2.2344, 2.3594, 2.5000,
         3.1875, 2.3281, 3.0625, 2.7344, 3.1719, 2.2031, 3.0781, 2.5156, 2.23

### Column Parallel

The linear layer is defined as Y = XA + b. A is parallelized along its second dimension as A = [A_1, ..., A_p].

In [82]:
qweight.shape

torch.Size([128, 128])

In [83]:
output_size_per_partition

64

In [84]:
# Column parallelism: slice the packed weight along dim=1. The chunk size is
# measured in packed uint8 columns.
qweight_partitioned = qweight.split(output_size_per_partition, dim=1)

In [85]:
len(qweight_partitioned)

2

In [86]:
for w in qweight_partitioned: print(w.shape)

torch.Size([128, 64])
torch.Size([128, 64])


In [87]:
orig_absmax = param.quant_state.absmax

In [88]:
# absmax holds one scale per quantization block. Reshape it so each row lines
# up with a row of `data`, which spans data.size(1) // blocksize blocks.
orig_absmax_reshaped = orig_absmax.reshape(-1, data.size(1) // blocksize)

In [89]:
orig_absmax_reshaped.dtype

torch.float32

In [90]:
orig_absmax_reshaped.shape

torch.Size([128, 4])

In [91]:
num_partitions = len(qweight_partitioned)

In [92]:
# Split the per-row block scales the same way as the weight columns.
# This assumes every column shard covers a whole number of blocks (here each
# shard gets 2 of the 4 blocks per row).
absmax_partitioned = orig_absmax_reshaped.split(orig_absmax_reshaped.size(1) // num_partitions, dim=1)

In [93]:
for a in absmax_partitioned: print(a.shape)

torch.Size([128, 2])
torch.Size([128, 2])


In [94]:
len(qweight_partitioned), len(absmax_partitioned)

(2, 2)

In [95]:
quant_state = copy.deepcopy(param.quant_state)

In [96]:
# quant_state.shape is the dequantized (unpacked) shape, so a column shard
# keeps all rows and gets 1/num_partitions of the columns.
quant_state.shape = torch.Size([quant_state.shape[0], quant_state.shape[1]//num_partitions])

In [97]:
quant_state.shape

torch.Size([128, 128])

In [115]:
# Dequantize each column shard with its own block scales.
# The shared quant_state copy is mutated in place between the two calls, so
# statement order matters: after this cell quant_state.absmax belongs to
# shard 1, which the later matmul_4bit cell relies on.
# split() returns strided views, hence .contiguous() before flattening.
quant_state.absmax = absmax_partitioned[0].contiguous().view(-1)
deqweight_part1 = dequantize_4bit(qweight_partitioned[0].contiguous().view(-1,1), quant_state=quant_state)

quant_state.absmax = absmax_partitioned[1].contiguous().view(-1)
deqweight_part2 = dequantize_4bit(qweight_partitioned[1].contiguous().view(-1,1), quant_state=quant_state)

In [116]:
deqweight_part1.shape, deqweight_part2.shape

(torch.Size([128, 128]), torch.Size([128, 128]))

In [117]:
# quant_state.as_dict()

In [118]:
deqweight.shape

torch.Size([128, 256])

In [119]:
torch.cat([deqweight_part1, deqweight_part2], dim=1).shape

torch.Size([128, 256])

In [120]:
# Column-sharded dequantization must reproduce the full dequantized weight
# exactly when the shards are concatenated back along dim=1.
assert torch.equal(deqweight, torch.cat([deqweight_part1, deqweight_part2], dim=1))

In [121]:
# Reference output for the *second* column shard, computed with the
# dequantized weights (note: out1 is built from deqweight_part2).
out1 = (x @ deqweight_part2)

In [122]:
# Same product through the fused 4-bit kernel on the packed shard.
# quant_state still carries shard 1's absmax from the dequantize cell above,
# so this cell must run after it and before quant_state is reassigned.
out2 = bnb.matmul_4bit(x, qweight_partitioned[1].contiguous().view(-1,1), quant_state=quant_state)

In [126]:
assert torch.equal(out1, out2)

In [127]:
out2.shape

torch.Size([4, 128])

### Row Parallel

The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:

```
    -   -
    | A_1 |
    | .   |
A = | .   |        X = [X_1, ..., X_p]
    | .   |
    | A_p |
    -   -
```

In [37]:
# Row parallelism shards A along dim=0, the input dimension (x has 128
# features), so the chunk size is the per-partition *input* size. dim=0 is
# not packed, so this counts real weight rows.
qweight_partitioned = qweight.split(input_size_per_partition, dim=0)

In [38]:
num_partitions = len(qweight_partitioned); num_partitions

2

In [39]:
for w in qweight_partitioned: print(w.shape)

torch.Size([64, 128])
torch.Size([64, 128])


In [50]:
orig_absmax = param.quant_state.absmax

In [52]:
# NOTE(review): this reshaped view is not used in the row-parallel section;
# the split below operates on the flat orig_absmax directly.
orig_absmax_reshaped = orig_absmax.reshape(-1, data.size(1) // blocksize)

In [54]:
# Row shards are contiguous runs of whole rows, so the flat absmax vector is
# cut into equal contiguous chunks (presumably stored in row-major block
# order, consistent with the equality check at the end of this section).
absmax_partitioned = orig_absmax.split(len(orig_absmax) // num_partitions, dim=0)

In [55]:
len(absmax_partitioned)

2

In [56]:
quant_state = copy.deepcopy(param.quant_state)

In [57]:
# A row shard keeps all columns and gets 1/num_partitions of the rows of the
# dequantized shape.
quant_state.shape = torch.Size([quant_state.shape[0]//num_partitions, quant_state.shape[1]]); quant_state.shape

torch.Size([64, 256])

In [58]:
# Dequantize each row shard with its own slice of the block scales.
# The shared quant_state copy is mutated in place between the two calls, so
# after this cell it holds shard 1's absmax.
quant_state.absmax = absmax_partitioned[0].contiguous().view(-1)
deqweight_part1 = dequantize_4bit(qweight_partitioned[0].contiguous().view(-1,1), quant_state=quant_state)

quant_state.absmax = absmax_partitioned[1].contiguous().view(-1)
deqweight_part2 = dequantize_4bit(qweight_partitioned[1].contiguous().view(-1,1), quant_state=quant_state)

In [59]:
# Row-sharded dequantization must reproduce the full dequantized weight
# exactly when the shards are stacked back along dim=0.
assert torch.equal(deqweight, torch.cat([deqweight_part1, deqweight_part2], dim=0))

### Loading

In [3]:
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)

INFO 03-29 13:23:10 pynccl_utils.py:13] vLLM is using nccl==2.18.1


In [2]:
weights_iterator = hf_model_weights_iterator("meta-llama/Llama-2-7b-hf")

In [3]:
# Pull only the first (name, tensor) pair to inspect what the iterator yields.
for name, loaded_weight in weights_iterator: break

INFO 03-29 08:48:45 weight_utils.py:177] Using model weights format ['*.safetensors']


model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

In [5]:
name, loaded_weight

('model.embed_tokens.weight',
 tensor([[ 1.2517e-06, -1.7881e-06, -4.3511e-06,  ...,  8.9407e-07,
          -6.5565e-06,  8.9407e-07],
         [ 1.8616e-03, -3.3722e-03,  3.9864e-04,  ..., -8.3008e-03,
           2.5787e-03, -3.9368e-03],
         [ 1.0986e-02,  9.8877e-03, -5.0964e-03,  ...,  2.5177e-03,
           7.7057e-04, -5.0049e-03],
         ...,
         [-1.3977e-02, -2.7313e-03, -1.9897e-02,  ..., -1.0437e-02,
           9.5825e-03, -1.8005e-03],
         [-1.0742e-02,  9.3384e-03,  1.2939e-02,  ..., -3.3203e-02,
          -1.6357e-02,  3.3875e-03],
         [-8.3008e-03, -4.0588e-03, -1.1063e-03,  ...,  3.4790e-03,
          -1.2939e-02,  3.1948e-05]], dtype=torch.float16))

In [29]:
weights_iterator = hf_model_weights_iterator("TheBloke/CodeUp-Alpha-13B-HF-AWQ")

In [32]:
# List the auxiliary quantization tensors (scales / zero points) stored in
# the AWQ checkpoint.
for name, loaded_weight in weights_iterator:
    if any(tag in name for tag in ("scales", "zeros")):
        print(name)

model.layers.0.mlp.down_proj.qzeros
model.layers.0.mlp.down_proj.scales
model.layers.0.mlp.gate_proj.qzeros
model.layers.0.mlp.gate_proj.scales
model.layers.0.mlp.up_proj.qzeros
model.layers.0.mlp.up_proj.scales
model.layers.0.self_attn.k_proj.qzeros
model.layers.0.self_attn.k_proj.scales
model.layers.0.self_attn.o_proj.qzeros
model.layers.0.self_attn.o_proj.scales
model.layers.0.self_attn.q_proj.qzeros
model.layers.0.self_attn.q_proj.scales
model.layers.0.self_attn.v_proj.qzeros
model.layers.0.self_attn.v_proj.scales
model.layers.1.mlp.down_proj.qzeros
model.layers.1.mlp.down_proj.scales
model.layers.1.mlp.gate_proj.qzeros
model.layers.1.mlp.gate_proj.scales
model.layers.1.mlp.up_proj.qzeros
model.layers.1.mlp.up_proj.scales
model.layers.1.self_attn.k_proj.qzeros
model.layers.1.self_attn.k_proj.scales
model.layers.1.self_attn.o_proj.qzeros
model.layers.1.self_attn.o_proj.scales
model.layers.1.self_attn.q_proj.qzeros
model.layers.1.self_attn.q_proj.scales
model.layers.1.self_attn.v_pro

### Create Quantized Model Files

In [29]:
from pathlib import Path
import os, json
from safetensors.torch import save_file

In [19]:
# Output directory for the quantized checkpoint; created (including parents)
# when it does not exist yet.
model_dir = Path("/home/ubuntu/models/llama-7b-hf-nf4-quantized")
model_dir.mkdir(parents=True, exist_ok=True)

In [4]:
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# original quantized layers from fsdp_qlora/train.py
# ["k_proj", "q_proj", "v_proj", "up_proj", "down_proj", "gate_proj"]

In [5]:
# Name fragments of the linear layers to quantize. Compared with the
# fsdp_qlora list noted in the previous cell, this also includes "o_proj".
quantized_layers = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

In [8]:
# Start from a full copy so tensors that are not quantized are carried over
# unchanged; entries for the targeted layers are overwritten in the loop below.
quantized_state_dict = copy.deepcopy(model.state_dict())

In [9]:
# Constants for the export loop: two 4-bit values per uint8 element, and
# 64 values per quantization block.
pack_factor = 2
blocksize = 64

In [12]:
for n, p in model.state_dict().items():
    # Only weight tensors of the targeted projection layers are quantized.
    if "weight" not in n or not any(layer in n for layer in quantized_layers):
        continue

    # Stored weights are output_size x input_size; quantize the transpose so
    # the packed layout is input_size x output_size.
    print(n, p.shape, p.t().shape)
    input_size, output_size = p.t().shape
    param = Params4bit(
        p.t(),
        quant_type="nf4",
        blocksize=blocksize,
        compress_statistics=False,
        quant_storage=torch.uint8,
    )
    param.cuda()

    # reshape for tensor parallelism: 2-D packed weight plus one row of block
    # scales per weight row
    qweight = param.data.cpu().reshape(input_size, output_size // pack_factor)
    absmax = param.quant_state.absmax.cpu().reshape(-1, output_size // blocksize)

    quantized_state_dict[n] = qweight
    quantized_state_dict[n.replace(".weight", ".absmax")] = absmax

    # Release the GPU copy before moving on to the next layer.
    param = None
    torch.cuda.empty_cache()

model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.0.self_attn.k_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.0.self_attn.v_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.0.mlp.gate_proj.weight torch.Size([11008, 4096]) torch.Size([4096, 11008])
model.layers.0.mlp.up_proj.weight torch.Size([11008, 4096]) torch.Size([4096, 11008])
model.layers.0.mlp.down_proj.weight torch.Size([4096, 11008]) torch.Size([11008, 4096])
model.layers.1.self_attn.q_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.1.self_attn.k_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.1.self_attn.v_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.1.self_attn.o_proj.weight torch.Size([4096, 4096]) torch.Size([4096, 4096])
model.layers.1.mlp.gate_pr

In [20]:
# save quantized weights
# Every tensor (quantized and untouched) goes into a single safetensors file
# inside model_dir.
save_file(quantized_state_dict, model_dir/"model_state_dict.safetensors")

In [22]:
# create and save quantization config
quant_config_filename = model_dir/"quantize_config.json"

In [28]:
# Quantization settings written next to the weights. The values repeat what
# the export loop above used (nf4, blocksize 64, uint8 storage, uncompressed
# statistics) and must be kept in sync with it by hand.
quant_config_dict = {
    "weight_bits" : 4,
    "blocksize" : 64,
    "quant_type" : "nf4",
    "quant_storage" : "uint8",
    "compress_statistics" : False
}

In [30]:
# Serialize the quantization settings to quantize_config.json.
with open(quant_config_filename, "w+") as f: json.dump(quant_config_dict, f)

In [31]:
model_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

In [33]:
# save model config
model_config_filename = model_dir/"config.json"

In [37]:
# Write the base model's configuration (as a plain dict) to config.json in
# the quantized model directory.
with open(model_config_filename, "w+") as f: json.dump(model_config.to_dict(), f)

### BNB-Quantized vLLM

In [1]:
from vllm import LLM, SamplingParams

In [2]:
# Directory holding the quantized checkpoint produced in the previous section.
model_dir = "/home/ubuntu/models/llama-7b-hf-nf4-quantized"

In [3]:
# Load the exported checkpoint through vLLM with the "bnb" quantization
# method, reusing the original Llama-2 tokenizer.
# NOTE(review): per the recorded output, weights load (about 3.96 GB) but the
# run then fails with a split_with_sizes RuntimeError
# (split_sizes=[4096, 4096, 4096] against a last dimension of 4096) —
# presumably a fused QKV output is being split while only one projection's
# width is present; TODO confirm against the bnb linear method.
llm = LLM(model=model_dir, tokenizer="meta-llama/Llama-2-7b-hf", dtype="bfloat16", quantization="bnb")

INFO 03-29 14:30:46 llm_engine.py:70] Initializing an LLM engine (v0.3.3) with config: model='/home/ubuntu/models/llama-7b-hf-nf4-quantized', tokenizer='meta-llama/Llama-2-7b-hf', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=True, quantization=bnb, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 03-29 14:30:47 pynccl_utils.py:13] vLLM is using nccl==2.18.1
INFO 03-29 14:30:48 selector.py:44] flash_attn is not found.
INFO 03-29 14:30:48 selector.py:20] Using XFormers backend.
INFO 03-29 14:30:51 model_runner.py:104] Loading model weights took 3.9585 GB
INFO 03-29 14:30:53 gpu_executor.py:94] # GPU blocks: 2087, # CPU blocks: 512
INFO 03-29 14:30:55 model_runner.py:770] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in ea

RuntimeError: split_with_sizes expects split_sizes to sum exactly to 4096 (input tensor's size at dimension -1), but got split_sizes=[4096, 4096, 4096]