In [1]:
import torch
import bitsandbytes as bnb
import safetensors
from safetensors.torch import save_file


BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=
If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH
For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64
Loading CUDA version: BNB_CUDA_VERSION=123


  warn((f'\n\n{"="*80}\n'


In [2]:
from bitsandbytes.nn import Linear4bit, Params4bit
import bitsandbytes.functional as F
from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME

In [15]:
from transformers import AutoConfig, AutoModelForCausalLM
import torch.nn as nn

### Custom QLORA

In [9]:
def replace_linear(model, linear_replacement, skip_modules=("lm_head",), **kwargs):
    """
    Recursively replace `torch.nn.Linear` children of `model` with a new linear module.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`Callable`):
            Called as `linear_replacement(in_features, out_features, bias, **kwargs)`,
            where `bias` is a bool telling whether the old layer had a bias.
            If other arguments need to be passed, use `kwargs` or a lambda.
        skip_modules (`Iterable[str]`, *optional*, defaults to `("lm_head",)`):
            Child attribute names (not dotted paths) that are never converted.
        **kwargs:
            Forwarded to every `linear_replacement` call.

    Returns:
        `torch.nn.Module`: the same `model` object, modified in place.

    Note:
        The replacement is constructed from the layer sizes only; the old weights and
        bias values are not copied into it.
    """
    for name, module in model.named_children():
        # Recurse into containers first so nested Linear layers are reached.
        if next(module.children(), None) is not None:
            replace_linear(module, linear_replacement, skip_modules, **kwargs)

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            # setattr goes through nn.Module.__setattr__, which registers the new submodule
            # (and rejects non-Module replacements) instead of poking the private _modules dict.
            setattr(
                model,
                name,
                linear_replacement(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    **kwargs,
                ),
            )
    return model

In [46]:
# LoRA hyperparameters read by the QLORA wrapper.
# NOTE(review): 'lora_target_modules' is not read by any visible cell; the list actually
# used for wrapping is the separate `lora_target_modules` variable.
args = {
    'lora_rank': 64,  # LoRA rank for lora/qlora
    'lora_alpha': 16,  # LoRA alpha for lora/qlora
    'lora_dropout': 0.1,  # LoRA dropout for lora/qlora
    'lora_target_modules': "all",
}

In [82]:
class QLORA(nn.Module):
    """Wrap a (possibly 4-bit quantized) linear layer with a trainable LoRA adapter.

    The forward pass computes
    ``base_layer(x) + scaling * lora_B(lora_A(dropout(x)))`` with
    ``scaling = lora_alpha / lora_rank``.

    Parameters:
        base_layer (`nn.Module`):
            Linear-like layer (e.g. `Linear4bit`) exposing `in_features` and
            `out_features`. Its `compute_dtype` (if present) is used for the adapter.
        device (`str` or `torch.device`, defaults to `"cpu"`):
            Device on which the adapter weights are created.
        lora_rank, lora_alpha, lora_dropout (*optional*):
            Adapter hyperparameters. When omitted they are read from the notebook-level
            `args` dict (`args["lora_rank"]`, `args["lora_alpha"]`, `args["lora_dropout"]`).
    """

    def __init__(self, base_layer, device="cpu", lora_rank=None, lora_alpha=None, lora_dropout=None):
        super().__init__()
        if lora_rank is None:
            lora_rank = args["lora_rank"]
        if lora_alpha is None:
            lora_alpha = args["lora_alpha"]
        if lora_dropout is None:
            lora_dropout = args["lora_dropout"]

        self.base_layer = base_layer
        dtype = getattr(base_layer, "compute_dtype", None)
        # Size the adapter from in_features/out_features, not weight.shape: a linear weight
        # is stored as (out_features, in_features), and a quantized Params4bit holds a packed
        # storage shape rather than the logical one.
        self.lora_A = nn.Linear(base_layer.in_features, lora_rank, bias=False, device=device, dtype=dtype)
        self.lora_B = nn.Linear(lora_rank, base_layer.out_features, bias=False, device=device, dtype=dtype)
        # Standard LoRA init: B = 0, so the adapter starts as a no-op (delta W = B @ A = 0).
        nn.init.zeros_(self.lora_B.weight)
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(lora_dropout)
        self.scaling = self.lora_alpha / lora_rank

        for p in self.lora_A.parameters():
            p.requires_grad = True
        for p in self.lora_B.parameters():
            p.requires_grad = True

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Return ``base_layer(x) + scaling * lora_B(lora_A(dropout(x)))``.

        Extra positional/keyword arguments are forwarded to `base_layer` only.
        """
        result = self.base_layer(x, *args, **kwargs)
        # As per Tim Dettmers, for 4bit, we need to defensively clone here.
        # The reason is that in some cases, an error can occur that backprop
        # does not work on a manipulated view. This issue may be solved with
        # newer PyTorch versions but this would need extensive testing to be
        # sure.
        result = result.clone()

        # Outside autocast, cast the input to the adapter dtype and the adapter output
        # back to the base layer's output dtype.
        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            expected_dtype = result.dtype
            x = x.to(self.lora_A.weight.dtype)

        output = self.lora_B(self.lora_A(self.lora_dropout(x)))
        if requires_conversion:
            output = output.to(expected_dtype)
        output = output * self.scaling
        result += output

        return result

In [7]:
# Llama-2-7b architecture config, truncated to 2 decoder layers for quick debugging.
cfg = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
cfg.use_cache = False
cfg._attn_implementation = "flash_attention_2"
cfg.num_hidden_layers = 2  # debug mode.

In [91]:
# Inspect the hidden size of the loaded config (the saved output shows 4096).
cfg.hidden_size

4096

In [78]:
# Build the truncated architecture from the config alone (from_config does not load
# pretrained weights).
# NOTE(review): the saved `model` repr shows LlamaSdpaAttention, so setting
# cfg._attn_implementation = "flash_attention_2" does not appear to take effect here;
# passing attn_implementation="flash_attention_2" to from_config may be needed — confirm.
model = AutoModelForCausalLM.from_config(cfg)

In [90]:
# Display the module tree. The saved output comes from execution count 90, i.e. after the
# Linear4bit replacement (In [79]) and QLORA wrapping (In [83]) cells had run, which is why
# it already shows QLORA wrappers.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): QLORA(
            (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
            (lora_A): Linear(in_features=4096, out_features=64, bias=False)
            (lora_B): Linear(in_features=64, out_features=4096, bias=False)
            (lora_dropout): Dropout(p=0.1, inplace=False)
          )
          (k_proj): QLORA(
            (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
            (lora_A): Linear(in_features=4096, out_features=64, bias=False)
            (lora_B): Linear(in_features=64, out_features=4096, bias=False)
            (lora_dropout): Dropout(p=0.1, inplace=False)
          )
          (v_proj): QLORA(
            (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
            (lora_A): Linear(i

In [79]:
# Swap every nn.Linear inside the decoder (model.model) for a 4-bit NF4 Linear4bit.
# lm_head presumably sits on the outer LlamaForCausalLM, outside model.model, so it is
# not touched here.
linear4bit_kwargs = dict(
    compute_dtype=torch.bfloat16,
    quant_type="nf4",
    quant_storage=torch.bfloat16,
)
model.model = replace_linear(model.model, Linear4bit, **linear4bit_kwargs)

In [80]:
# Projection layers that receive a LoRA adapter. o_proj is not listed, so it keeps only
# its frozen base weight (no lora_A/lora_B in the requires_grad listing).
lora_target_modules = [
    "k_proj",
    "q_proj",
    "v_proj",
    "up_proj",
    "down_proj",
    "gate_proj",
]

In [83]:
# Wrap every target projection in a QLORA adapter.
# Targets are collected up front so the module tree is not mutated while named_modules()
# is walking it, and layers that are already QLORA are skipped so re-running this cell
# does not wrap a QLORA inside another QLORA.
qlora_targets = [
    (name, module)
    for name, module in model.named_modules()
    if name.rpartition('.')[2] in lora_target_modules and not isinstance(module, QLORA)
]
for name, module in qlora_targets:
    module_key, _, value_key = name.rpartition('.')
    parent_module = model.get_submodule(module_key)
    setattr(parent_module, value_key, QLORA(module))

NameError: name 'base_layer' is not defined

In [88]:
# Freeze everything except the LoRA adapter matrices, printing each parameter's flag.
for param_name, param in model.named_parameters():
    is_lora = "lora_A" in param_name or "lora_B" in param_name
    param.requires_grad = is_lora
    print(param_name, param.requires_grad)

model.embed_tokens.weight False
model.layers.0.self_attn.q_proj.base_layer.weight False
model.layers.0.self_attn.q_proj.lora_A.weight True
model.layers.0.self_attn.q_proj.lora_B.weight True
model.layers.0.self_attn.k_proj.base_layer.weight False
model.layers.0.self_attn.k_proj.lora_A.weight True
model.layers.0.self_attn.k_proj.lora_B.weight True
model.layers.0.self_attn.v_proj.base_layer.weight False
model.layers.0.self_attn.v_proj.lora_A.weight True
model.layers.0.self_attn.v_proj.lora_B.weight True
model.layers.0.self_attn.o_proj.weight False
model.layers.0.mlp.gate_proj.base_layer.weight False
model.layers.0.mlp.gate_proj.lora_A.weight True
model.layers.0.mlp.gate_proj.lora_B.weight True
model.layers.0.mlp.up_proj.base_layer.weight False
model.layers.0.mlp.up_proj.lora_A.weight True
model.layers.0.mlp.up_proj.lora_B.weight True
model.layers.0.mlp.down_proj.base_layer.weight False
model.layers.0.mlp.down_proj.lora_A.weight True
model.layers.0.mlp.down_proj.lora_B.weight True
model.la

### Test Linear4bit Memory-Efficient Loading

This section checks that every rank holds the same quantization state and quantized parameters, and compares the dequantized weights with the original checkpoint weights.

In [3]:
# Quantized q_proj base-layer params saved from each rank.
# Load onto the CPU (like the quant states) so this cell does not depend on which GPU the
# files were saved from; the payload is moved to CUDA explicitly later.
# NOTE: torch.load uses pickle — only load trusted files.
params_rank0 = torch.load("../data/summoned_lora_layer0_q_proj_base_layer_params_rank0.pt", map_location="cpu")
params_rank1 = torch.load("../data/summoned_lora_layer0_q_proj_base_layer_params_rank1.pt", map_location="cpu")

In [4]:
# The quant states are pickled Python objects (they expose .as_dict(), .to(), .quant_type),
# not plain tensors. Pass weights_only=False explicitly: newer PyTorch releases default
# torch.load to weights_only=True, which would reject these files.
# Full unpickling can execute arbitrary code — only load trusted files this way.
quant_state_rank0 = torch.load(
    "../data/summoned_lora_layer0_q_proj_quant_state_rank0.pt", map_location="cpu", weights_only=False
)
quant_state_rank1 = torch.load(
    "../data/summoned_lora_layer0_q_proj_quant_state_rank1.pt", map_location="cpu", weights_only=False
)

In [5]:
# check gathered quantized weights are bit-identical on each rank.
# The payload appears to be packed 4-bit codes held in a float dtype (see the huge/tiny
# magnitudes when displayed), so compare raw bytes rather than float values: some byte
# patterns decode as NaN, and dropping NaNs from each tensor independently could
# misalign the two tensors or hide a mismatch.
assert len(params_rank0) == len(params_rank1), "ranks hold a different number of params"
for p1, p2 in zip(params_rank0, params_rank1):
    assert p1.dtype == p2.dtype and p1.shape == p2.shape
    assert torch.equal(
        p1.data.contiguous().view(torch.uint8),
        p2.data.contiguous().view(torch.uint8),
    )

In [6]:
# Inspect rank 0's quantization state (loaded on the CPU).
quant_state_rank0.as_dict()

{'quant_type': 'nf4',
 'absmax': tensor([230, 149,  74,  ..., 194, 175, 203], dtype=torch.uint8),
 'blocksize': 64,
 'quant_map': tensor([-1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911,  0.0000,
          0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5626,  0.7230,  1.0000]),
 'dtype': 'bfloat16',
 'shape': (8192, 8192),
 'nested_absmax': tensor([0.0736, 0.0258, 0.0224,  ..., 0.0658, 0.0902, 0.0638]),
 'nested_blocksize': 256,
 'nested_quant_map': tensor([-9.9297e-01, -9.7891e-01, -9.6484e-01, -9.5078e-01, -9.3672e-01,
         -9.2266e-01, -9.0859e-01, -8.9453e-01, -8.8047e-01, -8.6641e-01,
         -8.5234e-01, -8.3828e-01, -8.2422e-01, -8.1016e-01, -7.9609e-01,
         -7.8203e-01, -7.6797e-01, -7.5391e-01, -7.3984e-01, -7.2578e-01,
         -7.1172e-01, -6.9766e-01, -6.8359e-01, -6.6953e-01, -6.5547e-01,
         -6.4141e-01, -6.2734e-01, -6.1328e-01, -5.9922e-01, -5.8516e-01,
         -5.7109e-01, -5.5703e-01, -5.4297e-01, -5.2891e-01, -5.1484e-01,
         -5.007

In [7]:
# Inspect rank 1's quantization state for a side-by-side comparison with rank 0.
quant_state_rank1.as_dict()

{'quant_type': 'nf4',
 'absmax': tensor([230, 149,  74,  ..., 194, 175, 203], dtype=torch.uint8),
 'blocksize': 64,
 'quant_map': tensor([-1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911,  0.0000,
          0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5626,  0.7230,  1.0000]),
 'dtype': 'bfloat16',
 'shape': (8192, 8192),
 'nested_absmax': tensor([0.0736, 0.0258, 0.0224,  ..., 0.0658, 0.0902, 0.0638]),
 'nested_blocksize': 256,
 'nested_quant_map': tensor([-9.9297e-01, -9.7891e-01, -9.6484e-01, -9.5078e-01, -9.3672e-01,
         -9.2266e-01, -9.0859e-01, -8.9453e-01, -8.8047e-01, -8.6641e-01,
         -8.5234e-01, -8.3828e-01, -8.2422e-01, -8.1016e-01, -7.9609e-01,
         -7.8203e-01, -7.6797e-01, -7.5391e-01, -7.3984e-01, -7.2578e-01,
         -7.1172e-01, -6.9766e-01, -6.8359e-01, -6.6953e-01, -6.5547e-01,
         -6.4141e-01, -6.2734e-01, -6.1328e-01, -5.9922e-01, -5.8516e-01,
         -5.7109e-01, -5.5703e-01, -5.4297e-01, -5.2891e-01, -5.1484e-01,
         -5.007

In [8]:
# check quant states are same in each rank.
# Build each dict once, and require identical key sets so that an extra key on rank 1
# is not silently ignored.
state_rank0 = quant_state_rank0.as_dict()
state_rank1 = quant_state_rank1.as_dict()
assert state_rank0.keys() == state_rank1.keys(), (
    f"quant state keys differ: {sorted(state_rank0.keys() ^ state_rank1.keys())}"
)
for k, v in state_rank0.items():
    print(k)
    other = state_rank1[k]
    if isinstance(v, torch.Tensor):
        assert torch.equal(v, other), k
    else:
        assert v == other, k

quant_type
absmax
blocksize
quant_map
dtype
shape
nested_absmax
nested_blocksize
nested_quant_map
nested_dtype
nested_offset


In [9]:
# NOTE(review): identical to the earlier rank-0 inspection cell; redundant, candidate for removal.
quant_state_rank0.as_dict()

{'quant_type': 'nf4',
 'absmax': tensor([230, 149,  74,  ..., 194, 175, 203], dtype=torch.uint8),
 'blocksize': 64,
 'quant_map': tensor([-1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911,  0.0000,
          0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5626,  0.7230,  1.0000]),
 'dtype': 'bfloat16',
 'shape': (8192, 8192),
 'nested_absmax': tensor([0.0736, 0.0258, 0.0224,  ..., 0.0658, 0.0902, 0.0638]),
 'nested_blocksize': 256,
 'nested_quant_map': tensor([-9.9297e-01, -9.7891e-01, -9.6484e-01, -9.5078e-01, -9.3672e-01,
         -9.2266e-01, -9.0859e-01, -8.9453e-01, -8.8047e-01, -8.6641e-01,
         -8.5234e-01, -8.3828e-01, -8.2422e-01, -8.1016e-01, -7.9609e-01,
         -7.8203e-01, -7.6797e-01, -7.5391e-01, -7.3984e-01, -7.2578e-01,
         -7.1172e-01, -6.9766e-01, -6.8359e-01, -6.6953e-01, -6.5547e-01,
         -6.4141e-01, -6.2734e-01, -6.1328e-01, -5.9922e-01, -5.8516e-01,
         -5.7109e-01, -5.5703e-01, -5.4297e-01, -5.2891e-01, -5.1484e-01,
         -5.007

In [10]:
# First saved param on rank 0: a bfloat16 column whose values look like packed 4-bit codes
# reinterpreted as floats (the printed magnitudes are not meaningful weights) — presumably
# quant_storage == this dtype; confirm.
params_rank0[0]

Parameter containing:
tensor([[ 4.9895e+33],
        [-7.9810e-25],
        [ 2.9687e+14],
        ...,
        [ 7.2876e+07],
        [-3.9808e-24],
        [-5.1300e+36]], dtype=torch.bfloat16, requires_grad=True)

In [11]:
# Re-wrap rank 0's saved payload as a Params4bit together with its quant state.
# bnb_quantized=True marks the payload as already quantized — presumably so bitsandbytes
# does not quantize it again on device moves; confirm against the bitsandbytes source.
quantized_param = Params4bit(
    data=params_rank0[0],
    requires_grad=False,
    quant_state=quant_state_rank0,
    quant_type=quant_state_rank0.quant_type,
    quant_storage=params_rank0[0].dtype,
    bnb_quantized=True,
)

In [12]:
# Same payload as a plain tensor (`.data` drops the Parameter wrapper and requires_grad).
params_rank0[0].data

tensor([[ 4.9895e+33],
        [-7.9810e-25],
        [ 2.9687e+14],
        ...,
        [ 7.2876e+07],
        [-3.9808e-24],
        [-5.1300e+36]], dtype=torch.bfloat16)

In [13]:
# Move rank 0's quant state to CUDA for dequantizing the CUDA payload; the trailing `;`
# suppresses the cell output. The call is not assigned, and the later display shows absmax
# on cuda:0, so it mutates quant_state_rank0 in place (quant_map still shows no device).
# NOTE(review): quantized_param was built with this same object, so it is presumably
# affected too — confirm that is intended.
quant_state_rank0.to("cuda");

In [14]:
# Re-inspect rank 0's quant state after the in-place move to CUDA.
quant_state_rank0.as_dict()

{'quant_type': 'nf4',
 'absmax': tensor([230, 149,  74,  ..., 194, 175, 203], device='cuda:0',
        dtype=torch.uint8),
 'blocksize': 64,
 'quant_map': tensor([-1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911,  0.0000,
          0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5626,  0.7230,  1.0000]),
 'dtype': 'bfloat16',
 'shape': (8192, 8192),
 'nested_absmax': tensor([0.0736, 0.0258, 0.0224,  ..., 0.0658, 0.0902, 0.0638], device='cuda:0'),
 'nested_blocksize': 256,
 'nested_quant_map': tensor([-9.9297e-01, -9.7891e-01, -9.6484e-01, -9.5078e-01, -9.3672e-01,
         -9.2266e-01, -9.0859e-01, -8.9453e-01, -8.8047e-01, -8.6641e-01,
         -8.5234e-01, -8.3828e-01, -8.2422e-01, -8.1016e-01, -7.9609e-01,
         -7.8203e-01, -7.6797e-01, -7.5391e-01, -7.3984e-01, -7.2578e-01,
         -7.1172e-01, -6.9766e-01, -6.8359e-01, -6.6953e-01, -6.5547e-01,
         -6.4141e-01, -6.2734e-01, -6.1328e-01, -5.9922e-01, -5.8516e-01,
         -5.7109e-01, -5.5703e-01, -5.4297e-01,

In [15]:
# Copy rank 0's packed payload to the GPU so it matches the CUDA quant state.
data = params_rank0[0].data.to("cuda")

In [16]:
# Reconstruct the full-precision weight matrix from the packed codes and the quant state.
dequantized_weight = F.dequantize_4bit(data, quant_state=quant_state_rank0)

In [17]:
# Inspect the dequantized weight (bfloat16 on cuda:0 per the saved output).
dequantized_weight

tensor([[ 0.0000, -0.0076,  0.0000,  ...,  0.0024,  0.0131,  0.0024],
        [ 0.0093, -0.0066,  0.0166,  ...,  0.0105, -0.0057,  0.0105],
        [ 0.0012,  0.0038,  0.0052,  ..., -0.0061, -0.0028, -0.0044],
        ...,
        [ 0.0000,  0.0234, -0.0118,  ..., -0.0209,  0.0000,  0.0074],
        [ 0.0124,  0.0029,  0.0206,  ..., -0.0032, -0.0347, -0.0098],
        [ 0.0116, -0.0087,  0.0037,  ...,  0.0000,  0.0525,  0.0231]],
       device='cuda:0', dtype=torch.bfloat16)

In [18]:
# Checkpoint used as ground truth for the dequantization check.
# NOTE(review): this differs from the meta-llama/Llama-2-7b-hf config used above; presumably
# the saved rank files came from this model (quant state shape (8192, 8192)) — confirm.
model_name = "codellama/CodeLlama-34b-hf"

In [19]:
# Fetch the original (unquantized) layer-0 q_proj weight from the sharded checkpoint.
TARGET_WEIGHT_NAME = "model.layers.0.self_attn.q_proj.weight"
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)
orig_weight = None
for filename in files:
    # safe_open reads only the requested tensor instead of materializing every tensor in
    # the shard; breaking here stops scanning further shards once the weight is found.
    with safetensors.safe_open(filename, framework="pt") as shard:
        if TARGET_WEIGHT_NAME in shard.keys():
            orig_weight = shard.get_tensor(TARGET_WEIGHT_NAME)
            break
assert orig_weight is not None, f"{TARGET_WEIGHT_NAME} not found in {model_name} shards"

In [20]:
# Some deviation is expected from dequantization.
# Tolerances taken from: peft/tests/.../test_4bit_merge_and_disable_lora.
# NOTE(review): rtol=10 makes this check very loose (|a - b| <= 0.01 + 10 * |b|);
# stricter tolerance values may be needed.
# torch.testing.assert_close, unlike torch.allclose, does not broadcast mismatched shapes
# and reports the largest absolute/relative differences on failure.
torch.testing.assert_close(dequantized_weight.cpu(), orig_weight, atol=0.01, rtol=10)

In [21]:
# Inspect the re-wrapped Params4bit built from rank 0's payload.
quantized_param

Parameter containing:
Parameter(Params4bit([[ 4.9895e+33],
            [-7.9810e-25],
            [ 2.9687e+14],
            ...,
            [ 7.2876e+07],
            [-3.9808e-24],
            [-5.1300e+36]], dtype=torch.bfloat16))