# `transformers` meets `bitsandbytes` for democratizing Large Language Models (LLMs) through 4bit quantization

<center>
<img src="https://github.com/huggingface/blog/blob/main/assets/96_hf_bitsandbytes_integration/Thumbnail_blue.png?raw=true" alt="drawing" width="700" class="center"/>
</center>

Welcome to this notebook, which walks through the recent `bitsandbytes` integration. It includes the work from XXX that introduces 4bit quantization techniques with no performance degradation, democratizing LLM inference and training.

In this notebook, we will learn together how to load a large model in 4bit (`HuggingFaceH4/zephyr-7b-beta`, a Mistral-7B-based model) and fine-tune it on Google Colab with the Hugging Face 🤗 PEFT library.

[In the general usage notebook](https://colab.research.google.com/drive/1ge2F1QSK8Q7h0hn3YKuBCOAS0bK8E0wf?usp=sharing), you can learn how to properly load a model in 4bit with all its variants.

If you liked the previous work integrating [*LLM.int8*](https://arxiv.org/abs/2208.07339), have a look at the [introduction blog post](https://huggingface.co/blog/hf-bitsandbytes-integration) to learn more about that quantization method.


In [1]:
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m26.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m24.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproje

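First, let's load the model we are going to use: `HuggingFaceH4/zephyr-7b-beta`! Note that the model itself is around 14.5GB in half precision, so we load it in 4bit using NF4 quantization with nested (double) quantization and `bfloat16` as the compute dtype.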
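As a rough check on that size, the sketch below counts the parameters from the layer shapes that `print(model)` reports further down (hidden size 4096, `k_proj`/`v_proj` outputs of 1024, MLP width 14336, 32 layers, 32000-token vocabulary) and converts them to bytes. It assumes that the embeddings, an untied `lm_head` and the RMSNorm weights stay in 16 bits while the decoder `Linear` layers go to 4 bits, and it ignores the quantization constants, so the 4bit figure is a lower bound rather than a measurement.

```python
# Shapes taken from the print(model) output of this checkpoint (Mistral-7B architecture).
HIDDEN = 4096     # model width
KV_DIM = 1024     # output size of k_proj / v_proj (grouped-query attention)
MLP_DIM = 14336   # output size of gate_proj / up_proj
N_LAYERS = 32
VOCAB = 32000


def linear_params_per_layer() -> int:
    """Weights of the 7 bias-free Linear layers in one decoder layer."""
    attn = 2 * HIDDEN * HIDDEN + 2 * HIDDEN * KV_DIM  # q_proj, o_proj, k_proj, v_proj
    mlp = 3 * HIDDEN * MLP_DIM                        # gate_proj, up_proj, down_proj
    return attn + mlp


def param_breakdown() -> dict:
    linear = linear_params_per_layer() * N_LAYERS
    # Assumption: embeddings, untied lm_head and RMSNorm weights (2 per layer + final norm)
    # are not quantized.
    other = 2 * VOCAB * HIDDEN + (2 * N_LAYERS + 1) * HIDDEN
    return {"linear": linear, "other": other, "total": linear + other}


if __name__ == "__main__":
    p = param_breakdown()
    GB = 1e9
    print(f"total params: {p['total']:,}")
    print(f"half precision (2 bytes/param): {p['total'] * 2 / GB:.1f} GB")
    # 0.5 byte per 4bit weight, 2 bytes for everything else; quantization constants ignored.
    print(f"4bit weights (lower bound): {(p['linear'] * 0.5 + p['other'] * 2) / GB:.1f} GB")
```

Swapping in the shapes of another architecture gives a quick first estimate of whether it fits on a given GPU before downloading anything.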

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

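model_id = "HuggingFaceH4/zephyr-7b-beta"

# Load the linear layers in 4bit: NF4 data type, nested (double) quantization of the
# quantization constants, and bfloat16 as the dtype used for the matmuls at compute time
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map={"": 0} places the whole model on GPU 0
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})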


Then we have to apply some preprocessing to the model to prepare it for training. For that, use the `prepare_model_for_kbit_training` function from PEFT.

In [3]:
from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [4]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [7]:
print(model)

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=2)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
  

In [8]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
print_trainable_parameters(model)

trainable params: 6815744 || all params: 3758886912 || trainable%: 0.18132346515244138
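Both printed numbers can be reproduced by hand. LoRA adds two matrices to each targeted projection, A (r × in) and B (out × r), so each adapter costs r · (in + out) parameters. The "all params" figure is smaller than the model's ~7.2B weights because the 4bit weights are stored packed, two 4-bit values per `uint8` element, so `numel()` reports half of each quantized layer. The sketch below assumes the layer shapes from the `print(model)` output above and that the embeddings, `lm_head` and RMSNorm weights stay unquantized; under those assumptions it matches the printed output.

```python
# Layer shapes from the print(model) output above (Mistral-7B architecture).
HIDDEN, KV_DIM, MLP_DIM, N_LAYERS, VOCAB = 4096, 1024, 14336, 32, 32000

# (in_features, out_features) of the projections listed in target_modules
TARGETS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
}


def lora_params(r: int) -> int:
    """A is (r x in) and B is (out x r), so each adapter adds r * (in + out) parameters."""
    per_layer = sum(r * (i + o) for i, o in TARGETS.values())
    return per_layer * N_LAYERS


def reported_total(r: int) -> int:
    """What summing param.numel() reports once the decoder Linear layers are 4bit."""
    linear = N_LAYERS * (2 * HIDDEN * HIDDEN + 2 * HIDDEN * KV_DIM + 3 * HIDDEN * MLP_DIM)
    packed_linear = linear // 2  # two 4-bit values per uint8 storage element
    unquantized = 2 * VOCAB * HIDDEN + (2 * N_LAYERS + 1) * HIDDEN  # embeddings, lm_head, RMSNorms
    return packed_linear + unquantized + lora_params(r)


trainable = lora_params(r=8)
total = reported_total(r=8)
print(f"trainable params: {trainable} || all params: {total} || trainable%: {100 * trainable / total}")
```

Since the count is linear in `r`, going from `r=8` to `r=16` doubles the adapter size to 13,631,488 parameters.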


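Let's load a dataset to fine-tune our model on: `Moreza009/Tehran_Covid19_2`, which contains patient medical histories (`patient medical hidtory` column) and in-hospital mortality labels (`Inhospital Mortality` column). We then turn each row into a single text that asks whether the patient survives or dies and ends with the label.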

In [23]:
from datasets import load_dataset

data = load_dataset("Moreza009/Tehran_Covid19_2")
data = data.map(lambda samples: tokenizer(samples["patient medical hidtory"]), batched=True)

In [24]:
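def merge_columns(example):
    # One training string per row: question + history + "----->:" marker + label
    example["prediction"] = "does the patient survive or die based on the provided medical history? patient history is : " + example["patient medical hidtory"] + " ----->: " + str(example["Inhospital Mortality"])
    return example

data['train'] = data['train'].map(merge_columns)
# Re-tokenize from the merged column so that input_ids contain the label the model should learn
data['train'] = data['train'].map(lambda samples: tokenizer(samples["prediction"]), batched=True)
data['train']["prediction"][:5]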

['does the patient survive or die based on the provided medical history? patient history is : Adult , [present symptomps are : Dyspnea, nausea/Vomit, absent symptoms are : Caugh, Fever, Chiver, Mylagia, Weakness, loss of consciousness, Sore through, Rhinorrhea, Smelling disorder, Anorexia, Diarhhea, ChestPain, Seizure, Skin Lesion, Joint pain, Headache, Abdominal Pain, Earpain, Hemorrhasia, Hemiparesia] [past medical  history : Hyperlipidemia , conditions that are not in past medical history : Pregnancy, Curremt Smoker, Alcohol user, Opium user, Hookah user, hypertension, ischemic heart disease, coronary artery bypass graft, Congestive Heart Failure, Ashtma, COPD, diabetes mellitus, Pneumonia, cerebral vascular accident, gastrointestinal disorder, Chronic kidney disease, Rheumatoid Arthritis, Cancer, Hepatitis C, Thyroid dysfunction, Immunocompromised, Chronic Seizure, Tuberculosis , Anemia, Fatty liver disease, Psychological disorder, Parkinson, Alzhimer] [symptom to referral is in no
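The training text is a fixed question, the history, a `----->:` marker and the label (`str()` of whatever the `Inhospital Mortality` column holds). At inference time, the same string has to stop right after the marker so that the model produces the outcome itself. A minimal sketch of both sides, using the hypothetical helper names `build_training_text`, `build_prompt` and `extract_label`, which mirror `merge_columns` above but are not part of any library:

```python
QUESTION = (
    "does the patient survive or die based on the provided medical history? "
    "patient history is : "
)
MARKER = " ----->: "


def build_training_text(history: str, label) -> str:
    """Same layout as merge_columns: question + history + marker + label."""
    return QUESTION + history + MARKER + str(label)


def build_prompt(history: str) -> str:
    """Inference prompt: identical prefix, cut right after the marker (no label).

    Dropping the marker's trailing space is an assumption; the best choice depends on the tokenizer.
    """
    return QUESTION + history + MARKER.rstrip()


def extract_label(generated: str) -> str:
    """Return whatever follows the last marker in a decoded generation."""
    return generated.rsplit(MARKER.strip(), 1)[-1].strip()


text = build_training_text("Adult", "dies")
prompt = build_prompt("Adult")
print(prompt)
print(extract_label(text))
```

Keeping both helpers next to each other makes it harder for the training and inference formats to drift apart.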

Run the cell below to start training! For the sake of the demo, we only run it for a few steps (`max_steps=10`) to showcase how to use this integration with existing tools in the HF ecosystem.

In [21]:
import transformers

# make sure a padding token is set: the data collator needs one to pad batches
tokenizer.pad_token = tokenizer.eos_token

trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()



| Step | Training Loss |
|-----:|--------------:|
| 1 | 0.5447 |
| 2 | 0.554 |
| 3 | 0.4565 |
| 4 | 0.559 |
| 5 | 0.2922 |
| 6 | 0.2204 |
| 7 | 0.1714 |
| 8 | 0.172 |
| 9 | 0.1704 |
| 10 | 0.1441 |


TrainOutput(global_step=10, training_loss=0.32846537679433824, metrics={'train_runtime': 128.6238, 'train_samples_per_second': 0.311, 'train_steps_per_second': 0.078, 'total_flos': 903463913226240.0, 'train_loss': 0.32846537679433824, 'epoch': 0.01})
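The numbers in this output follow from the `TrainingArguments` above. Each optimizer step consumes `per_device_train_batch_size * gradient_accumulation_steps` = 4 examples, so 10 steps cover 40 examples, consistent with 0.311 samples/s over ~128.6 s. `Trainer` uses a linear schedule by default (`lr_scheduler_type="linear"`): the learning rate ramps up over `warmup_steps` and then decays linearly to zero at `max_steps`. The sketch below re-implements that rule in plain Python; it mirrors the lambda behind `transformers.get_linear_schedule_with_warmup` but is not that function, and it assumes, as with PyTorch's `LambdaLR`, that the scheduler is stepped after each optimizer update.

```python
def linear_warmup_lr(step: int, base_lr: float = 2e-4, warmup: int = 2, total: int = 10) -> float:
    """Learning rate at scheduler step `step`: linear warmup, then linear decay to 0."""
    if step < warmup:
        return base_lr * step / max(1, warmup)
    return base_lr * max(0.0, (total - step) / max(1, total - warmup))


def examples_seen(steps: int, per_device_batch: int = 1, grad_accum: int = 4, n_devices: int = 1) -> int:
    """Training examples consumed after `steps` optimizer updates."""
    return steps * per_device_batch * grad_accum * n_devices


if __name__ == "__main__":
    print("examples seen in 10 steps:", examples_seen(10))
    # Update k (1-based) runs with the rate computed at scheduler step k - 1.
    for update in range(1, 11):
        print(f"update {update:2d}: lr = {linear_warmup_lr(update - 1):.2e}")
```

Under this schedule the very first update runs with a learning rate of 0, which is one reason a short warmup matters little for a 10-step demo but should be scaled up for a real run.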


In [25]:
# Same format as the training text, but the prompt stops at the "----->:" marker so the model predicts the outcome
text = "does the patient survive or die based on the provided medical history? patient history is : Older adult , [present symptomps are : Weakness, nausea/Vomit, Diarhhea, absent symptoms are : Caugh, Dyspnea, Fever, Chiver, Mylagia, loss of consciousness, Sore through, Rhinorrhea, Smelling disorder, Anorexia, ChestPain, Seizure, Skin Lesion, Joint pain, Headache, Abdominal Pain, Earpain, Hemorrhasia, Hemiparesia] [conditions that are not in past medical history : Pregnancy, Curremt Smoker, Alcohol user, Opium user, Hookah user, hypertension, ischemic heart disease, coronary artery bypass graft, Congestive Heart Failure, Ashtma, COPD, diabetes mellitus, Pneumonia, cerebral vascular accident, gastrointestinal disorder, Chronic kidney disease, Rheumatoid Arthritis, Cancer, Hyperlipidemia , Hepatitis C, Thyroid dysfunction, Immunocompromised, Chronic Seizure, Tuberculosis , Anemia, Fatty liver disease, Psychological disorder, Parkinson, Alzhimer] [symptom to referral is in normal range. ] and [O2 saturation without supply is lower than normal range.  and pulse rate is higher than normal range. ] and diastolic Blood pressure is higher than normal range. ] and Systolic Blood pressure is higher than normal range. ] and respiratory rate is higher than normal range. ] and [Temperature is in normal range. ] and WBC is higher than normal range. ] and [Lymphocyte count is lower than normal range.  and Neutrophils percentage is higher than normal range. ] and Platelets is higher than normal range. ] and [Hemoglobin is lower than normal range.  and [MCV is in normal range. ] and CR is higher than normal range. ] and [sodium is lower than normal range.  and potassium is higher than normal range. ] and alkaline phosphatase is higher than normal range. ] and ESR is higher than normal range. ] and [CPK is lower than normal range.  and [PTT is in normal range. ] and [PT is in normal range. ] ----->:"
device = "cuda:0"

model.eval()                   # disable LoRA dropout for inference
model.config.use_cache = True  # re-enable the KV cache that was turned off for training
inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=5)  # the label may span several tokens
# Decode only the generated continuation, not the echoed prompt
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
