# Fine-tune Llama 3 with QLoRA

> 🗣️ [Large Language Model Course](https://github.com/mlabonne/llm-course)

❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne).

You can run this notebook on Google Colab (I use an L4 GPU).

In [1]:
!pip install -qqq -U transformers datasets huggingface_hub accelerate peft bitsandbytes wandb trl --progress-bar off
!FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install -qqq -U flash-attn --no-build-isolation pip install flash-attn --progress-bar off

In [2]:
# Authenticate with the Hugging Face Hub; login() is called without arguments
# and renders the login widget shown in the cell output.
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
import gc
import os

import torch
import wandb
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
    pipeline,
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer

# Model
base_model = "meta-llama/Meta-Llama-3-8B"
base2_model = "Meta-Llama-3-8B-tagllm-lang-10"
new_model = "Meta-Llama-3-8B-tagllm-translation-10"

# Defined in the secrets tab in Google Colab
wb_token = '1d395c70839c926f2dce7fc9403ad88f09e490ba'
wandb.login(key=wb_token)

# Set torch dtype and attention implementation
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install -qqq flash-attn
    torch_dtype = torch.bfloat16
    attn_implementation = "flash_attention_2"
else:
    torch_dtype = torch.float16
    attn_implementation = "eager"

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkevinxli[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [4]:
# Quantization config: 4-bit NF4 weights with double quantization; compute
# runs in the dtype selected for this GPU (torch_dtype).
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# LoRA config: rank-16 adapters on the token embedding and on every
# attention / MLP projection listed in target_modules.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=[
        'embed_tokens',
        'up_proj',
        'down_proj',
        'gate_proj',
        'k_proj',
        'q_proj',
        'v_proj',
        'o_proj',
    ],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Number of marker tokens that make up one tag.
num_token_per_tag = 10
# Marker tokens for the translation task: <|TOK 100|> ... <|TOK 109|>.
translation_tokens = [f'<|TOK {i}|>' for i in range(100, 100 + num_token_per_tag)]

# Load tokenizer
# The tokenizer stored under base2_model is loaded with the translation marker
# tokens registered as additional special tokens.
tokenizer = AutoTokenizer.from_pretrained(base2_model, additional_special_tokens=translation_tokens)

# Load the base checkpoint with the 4-bit quantization settings (bnb_config)
# and the attention implementation selected for this GPU.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation,
)
# Exclude the new translation_tag, expand embedding again after loading the LoRA weights
# NOTE(review): presumably the adapter stored under base2_model was trained with
# this smaller vocabulary, so the embedding shape must match before loading it — confirm.
model.resize_token_embeddings(len(tokenizer) - len(translation_tokens))
# Load the existing adapter on top of the base model, then merge it into the
# model and drop the adapter wrapper.
model = PeftModel.from_pretrained(model, base2_model)
model = model.merge_and_unload()
# Resize the embeddings to the full tokenizer size, which now includes the
# translation marker tokens.
model.resize_token_embeddings(len(tokenizer))
model = prepare_model_for_kbit_training(model)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



In [5]:
# Wrap the prepared model with the LoRA adapters described by peft_config
# so the trainable-parameter count and module layout can be inspected.
peft_model = get_peft_model(model, peft_config)
# print_trainable_parameters() writes the summary itself and returns None,
# so wrapping it in print() would only add a stray "None" line to the output.
peft_model.print_trainable_parameters()
print(peft_model)

trainable params: 44,061,312 || all params: 8,074,650,240 || trainable%: 0.5456745579112539
None
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): lora.Embedding(
          (base_layer): Embedding(128296, 4096)
          (lora_dropout): ModuleDict(
            (default): Dropout(p=0.05, inplace=False)
          )
          (lora_A): ModuleDict()
          (lora_B): ModuleDict()
          (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 16x128296 (cuda:0)])
          (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 4096x16 (cuda:0)])
        )
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaFlashAttention2(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                

In [6]:
from datasets import load_dataset, interleave_datasets

def get_dataset(num_existing_tokens=0):
    """Build the interleaved train and eval datasets for translation tuning.

    Each language ("eng", "yue", "cmn") is assigned a tag made of
    ``num_token_per_tag`` consecutive ``<|TOK i|>`` markers, numbered from
    ``num_existing_tokens`` onwards. Every translation pair is rendered as::

        <source tag><source text>\\n<target tag><translation tokens><target text>

    and stored in an ``inputs`` column that replaces ``translation``.

    Args:
        num_existing_tokens: Index of the first marker token to use for the
            language tags.

    Returns:
        A ``(train_dataset, eval_dataset, tag_name_dict)`` tuple, where the
        datasets interleave all language pairs and ``tag_name_dict`` maps a
        language code to its tag string.
    """
    languages = ("eng", "yue", "cmn")
    # (dataset configuration name, "source-target" translation direction)
    configs_and_directions = (
        ("eng-yue", "eng-yue"),
        ("eng-cmn", "cmn-eng"),
    )

    # Consecutive, non-overlapping marker ranges: one range per language.
    tag_name_dict = {}
    for position, language in enumerate(languages):
        first_id = num_existing_tokens + position * num_token_per_tag
        tag_name_dict[language] = "".join(
            f'<|TOK {token_id}|>'
            for token_id in range(first_id, first_id + num_token_per_tag)
        )

    train_parts = []
    test_parts = []
    for config_name, direction in configs_and_directions:
        splits = load_dataset("AlienKevin/yue-cmn-eng", config_name)
        train_split, test_split = splits["train"], splits["test"]

        source_lang, target_lang = direction.split("-")

        def preprocess_function(examples):
            # Batched map callback: render each pair as one training string.
            source_prefix = tag_name_dict[source_lang]
            target_prefix = tag_name_dict[target_lang] + ''.join(translation_tokens)
            examples["inputs"] = [
                "".join((
                    source_prefix,
                    example[source_lang],
                    '\n',
                    target_prefix,
                    example[target_lang],
                ))
                for example in examples["translation"]
            ]
            del examples['translation']
            return examples

        # The callback is applied right away, inside the same iteration, so it
        # always sees this iteration's source_lang / target_lang.
        train_parts.append(train_split.map(preprocess_function, batched=True))
        test_parts.append(test_split.map(preprocess_function, batched=True))

    return (
        interleave_datasets(train_parts),
        interleave_datasets(test_parts),
        tag_name_dict,
    )

In [7]:
# Build the datasets with the default offset, so language tag markers start at <|TOK 0|>.
train_dataset, eval_dataset, tag_name_dict = get_dataset()

In [8]:
# Preview the first ten formatted training examples.
train_dataset[:10]

{'inputs': ['<|TOK 0|><|TOK 1|><|TOK 2|><|TOK 3|><|TOK 4|><|TOK 5|><|TOK 6|><|TOK 7|><|TOK 8|><|TOK 9|>Scoop up water\n<|TOK 10|><|TOK 11|><|TOK 12|><|TOK 13|><|TOK 14|><|TOK 15|><|TOK 16|><|TOK 17|><|TOK 18|><|TOK 19|><|TOK 100|><|TOK 101|><|TOK 102|><|TOK 103|><|TOK 104|><|TOK 105|><|TOK 106|><|TOK 107|><|TOK 108|><|TOK 109|>㧾水',
  '<|TOK 20|><|TOK 21|><|TOK 22|><|TOK 23|><|TOK 24|><|TOK 25|><|TOK 26|><|TOK 27|><|TOK 28|><|TOK 29|>如果我们无法制定出识别和避免可能的新危险的方式和方法，那这种令人鼓舞的新技术的前景并不乐观。\n<|TOK 0|><|TOK 1|><|TOK 2|><|TOK 3|><|TOK 4|><|TOK 5|><|TOK 6|><|TOK 7|><|TOK 8|><|TOK 9|><|TOK 100|><|TOK 101|><|TOK 102|><|TOK 103|><|TOK 104|><|TOK 105|><|TOK 106|><|TOK 107|><|TOK 108|><|TOK 109|>If we do not develop the ways and means to spot and navigate around possible new risks, the outlook for this exciting new technology will be uncertain.',
  '<|TOK 0|><|TOK 1|><|TOK 2|><|TOK 3|><|TOK 4|><|TOK 5|><|TOK 6|

In [9]:
# Preview the first ten formatted evaluation examples.
eval_dataset[:10]

{'inputs': ['<|TOK 0|><|TOK 1|><|TOK 2|><|TOK 3|><|TOK 4|><|TOK 5|><|TOK 6|><|TOK 7|><|TOK 8|><|TOK 9|>This is really amusing, a radio controlled car that can climb on walls.\n<|TOK 10|><|TOK 11|><|TOK 12|><|TOK 13|><|TOK 14|><|TOK 15|><|TOK 16|><|TOK 17|><|TOK 18|><|TOK 19|><|TOK 100|><|TOK 101|><|TOK 102|><|TOK 103|><|TOK 104|><|TOK 105|><|TOK 106|><|TOK 107|><|TOK 108|><|TOK 109|>呢架遙控車識得爬牆，又真係幾盞鬼喎。',
  '<|TOK 20|><|TOK 21|><|TOK 22|><|TOK 23|><|TOK 24|><|TOK 25|><|TOK 26|><|TOK 27|><|TOK 28|><|TOK 29|>减速增长再创改革红利\n<|TOK 0|><|TOK 1|><|TOK 2|><|TOK 3|><|TOK 4|><|TOK 5|><|TOK 6|><|TOK 7|><|TOK 8|><|TOK 9|><|TOK 100|><|TOK 101|><|TOK 102|><|TOK 103|><|TOK 104|><|TOK 105|><|TOK 106|><|TOK 107|><|TOK 108|><|TOK 109|>China Grows Down',
  '<|TOK 0|><|TOK 1|><|TOK 2|><|TOK 3|><|TOK 4|><|TOK 5|><|TOK 6|><|TOK 7|><|TOK 8|><|TOK 9|>What a shitty phone - it broke down after being used a few times.\n<|TOK 10|><|TOK 11|><|TOK 12|><|TOK 13|><|T

In [10]:
# Reuse the end-of-sequence token as the padding token.
# Used to suppress:
# Asking to pad but the tokenizer does not have a padding token. Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Training hyperparameters for the single-epoch supervised fine-tuning run.
args = TrainingArguments(
    # Where checkpoints go and where metrics are reported
    output_dir=f"./results-{new_model}/",
    report_to="wandb",
    logging_steps=1,
    # Batching
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,
    # Optimizer and learning-rate schedule
    optim="paged_adamw_8bit",
    learning_rate=5e-05,
    lr_scheduler_type="linear",
    warmup_steps=10,
    num_train_epochs=1,
    # Step-based evaluation; eval_steps is given as a fraction rather than a step count
    evaluation_strategy="steps",
    eval_steps=0.2,
)

# Supervised fine-tuning on the "inputs" text column; the LoRA adapters are
# described by peft_config and handed to the trainer together with the model.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    peft_config=peft_config,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="inputs",
    max_seq_length=512,
)
trainer.train()
# Save the result under the new model name.
trainer.save_model(new_model)

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float16.


Step,Training Loss,Validation Loss
4714,1.5675,1.840561


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [None]:
# Flush memory
# peft_model (created by the inspection cell, if that cell was run) wraps the
# same underlying model object, so it has to be dropped as well; otherwise the
# model stays referenced and its memory cannot be reclaimed. pop() with a
# default keeps this cell working when the inspection cell was skipped.
globals().pop("peft_model", None)
del trainer, model
gc.collect()
torch.cuda.empty_cache()