In [None]:
import torch
import bitsandbytes as bnb
import safetensors
from safetensors.torch import save_file


BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=
If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH
For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64
Loading CUDA version: BNB_CUDA_VERSION=123


  warn((f'\n\n{"="*80}\n'


In [None]:
from bitsandbytes.nn import Linear4bit, Params4bit
import bitsandbytes.functional as F
from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME

In [None]:
from transformers import AutoConfig, AutoModelForCausalLM
import torch.nn as nn

### Test Linear4bit Memory-Efficient Loading

This tests that each rank holds the correct quant state and params, and also compares the dequantized weights with the original weights loaded from the checkpoint.

In [None]:
# Load the gathered ("summoned") layer-0 q_proj base-layer params saved per rank.
# map_location="cpu" keeps the cell loadable on machines that lack the GPU the
# tensors were saved from, consistent with how the quant states are loaded.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
params_rank0 = torch.load("../data/summoned_lora_layer0_q_proj_base_layer_params_rank0.pt", map_location="cpu")
params_rank1 = torch.load("../data/summoned_lora_layer0_q_proj_base_layer_params_rank1.pt", map_location="cpu")

In [None]:
# Load the saved quant state of the layer-0 q_proj for rank 0 and rank 1,
# in that order, mapping every tensor onto the CPU.
quant_state_rank0, quant_state_rank1 = (
    torch.load(state_path, map_location="cpu")
    for state_path in (
        "../data/summoned_lora_layer0_q_proj_quant_state_rank0.pt",
        "../data/summoned_lora_layer0_q_proj_quant_state_rank1.pt",
    )
)

In [None]:
# Check that the gathered quantized weights are the same on both ranks.
# zip() stops at the shorter sequence, so verify the counts match first.
assert len(params_rank0) == len(params_rank1), (
    f"rank 0 has {len(params_rank0)} params but rank 1 has {len(params_rank1)}"
)
for idx, (p1, p2) in enumerate(zip(params_rank0, params_rank1)):
    assert p1.shape == p2.shape, (
        f"param {idx}: shape {tuple(p1.shape)} on rank 0 vs {tuple(p2.shape)} on rank 1"
    )
    # The stored values may contain NaNs (presumably packed quantized bits viewed
    # through a float storage dtype - TODO confirm). equal_nan=True matches NaNs
    # position by position; filtering NaNs out of each tensor independently could
    # compare misaligned elements whenever the NaN positions differ.
    assert torch.allclose(p1, p2, equal_nan=True), f"param {idx} differs between ranks"

In [None]:
# Check that the quant states are the same on both ranks.
# Build each dict once instead of calling as_dict() on every iteration.
state_dict_rank0 = quant_state_rank0.as_dict()
state_dict_rank1 = quant_state_rank1.as_dict()

# Iterating only over rank 0's keys would miss entries that exist only on rank 1.
assert state_dict_rank0.keys() == state_dict_rank1.keys(), (
    f"quant state keys differ: {sorted(state_dict_rank0.keys() ^ state_dict_rank1.keys())}"
)

for k, v in state_dict_rank0.items():
    print(k)
    other = state_dict_rank1[k]
    if isinstance(v, torch.Tensor):
        # Tensors must match exactly (same shape and element values).
        assert torch.equal(v, other), f"quant state tensor {k!r} differs between ranks"
    else:
        assert v == other, f"quant state entry {k!r} differs: {v!r} vs {other!r}"

quant_type
absmax
blocksize
quant_map
dtype
shape
nested_absmax
nested_blocksize
nested_quant_map
nested_dtype
nested_offset


In [None]:
# Wrap the first gathered rank-0 tensor in a Params4bit, reusing the loaded
# rank-0 quant state, its quant type, and the tensor's own dtype as storage dtype.
# bnb_quantized=True presumably marks the data as already quantized - TODO confirm.
quantized_param = Params4bit(
    data=params_rank0[0],
    requires_grad=False,
    quant_state=quant_state_rank0,
    quant_type=quant_state_rank0.quant_type,
    quant_storage=params_rank0[0].dtype,
    bnb_quantized=True,
)

In [None]:
# Move the rank-0 quant state to the GPU. The call is made for its effect on
# `quant_state_rank0` itself (the return value is discarded); the trailing `;`
# suppresses the cell output.
# NOTE(review): `quant_state_rank1` is loaded on CPU, so re-running the rank
# comparison cell after this one may hit a device mismatch - confirm.
quant_state_rank0.to("cuda");

In [None]:
# Display the rank-0 quant state as a dict to inspect its fields
# (quant type, absmax, blocksize, quant map, dtype, shape and nested entries).
quant_state_rank0.as_dict()

{'quant_type': 'nf4',
 'absmax': tensor([230, 149,  74,  ..., 194, 175, 203], device='cuda:0',
        dtype=torch.uint8),
 'blocksize': 64,
 'quant_map': tensor([-1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911,  0.0000,
          0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5626,  0.7230,  1.0000]),
 'dtype': 'bfloat16',
 'shape': (8192, 8192),
 'nested_absmax': tensor([0.0736, 0.0258, 0.0224,  ..., 0.0658, 0.0902, 0.0638], device='cuda:0'),
 'nested_blocksize': 256,
 'nested_quant_map': tensor([-9.9297e-01, -9.7891e-01, -9.6484e-01, -9.5078e-01, -9.3672e-01,
         -9.2266e-01, -9.0859e-01, -8.9453e-01, -8.8047e-01, -8.6641e-01,
         -8.5234e-01, -8.3828e-01, -8.2422e-01, -8.1016e-01, -7.9609e-01,
         -7.8203e-01, -7.6797e-01, -7.5391e-01, -7.3984e-01, -7.2578e-01,
         -7.1172e-01, -6.9766e-01, -6.8359e-01, -6.6953e-01, -6.5547e-01,
         -6.4141e-01, -6.2734e-01, -6.1328e-01, -5.9922e-01, -5.8516e-01,
         -5.7109e-01, -5.5703e-01, -5.4297e-01,

In [None]:
# Take the underlying tensor of the first gathered rank-0 param and move it to
# the GPU so it can be passed to the dequantization call.
data = params_rank0[0].data.to("cuda")

In [None]:
# Dequantize the gathered 4-bit data with the rank-0 quant state; the result is
# compared against the original checkpoint weight at the end of the notebook.
dequantized_weight = F.dequantize_4bit(data, quant_state_rank0)

In [None]:
# Set this to the model name that was used when the summoned weights were saved;
# the original weights are looked up from this model's checkpoint for comparison.
model_name = "codellama/CodeLlama-34b-hf"

In [None]:
# Find the original layer-0 q_proj weight in the model's sharded safetensors
# checkpoint so it can be compared with the dequantized weight.
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)

target_name = "model.layers.0.self_attn.q_proj.weight"
orig_weight = None
for filename in files:
    # safe_open loads tensors lazily, so only the requested tensor is read
    # instead of materializing every tensor of every shard.
    with safetensors.safe_open(filename, framework="pt", device="cpu") as shard:
        if target_name in shard.keys():
            orig_weight = shard.get_tensor(target_name)
            # Stop scanning shards once the weight is found.
            break

# Fail here with a clear message rather than later with an opaque error on None.
assert orig_weight is not None, f"{target_name!r} not found in any shard of {model_name!r}"

In [None]:
# Some deviation is expected from dequantization.
# Tolerances taken from: peft/tests/.../test_4bit_merge_and_disable_lora.
# NOTE(review): torch.allclose passes when |a - b| <= atol + rtol * |b|, so
# rtol=10 accepts errors of up to ten times each reference weight's magnitude.
# Stricter tolerance values are probably needed - confirm against real data.
assert orig_weight is not None, "original q_proj weight was not found in the checkpoint"
assert dequantized_weight.shape == orig_weight.shape, (
    f"shape mismatch: dequantized {tuple(dequantized_weight.shape)} "
    f"vs original {tuple(orig_weight.shape)}"
)

# Compare on CPU in float32 so a dtype difference between the two tensors
# cannot make torch.allclose raise instead of reporting a result.
dequantized_fp32 = dequantized_weight.detach().cpu().float()
orig_fp32 = orig_weight.detach().cpu().float()
assert torch.allclose(dequantized_fp32, orig_fp32, atol=0.01, rtol=10), (
    f"dequantized weight deviates from the original; "
    f"max abs diff = {(dequantized_fp32 - orig_fp32).abs().max().item():.4g}"
)