## Fine-tune Falcon-7b on Google Colab

Welcome to this Google Colab notebook, which shows how to fine-tune the recent Falcon-7b model in a single Google Colab session and turn it into a chatbot.

We will leverage the PEFT library from the Hugging Face ecosystem, as well as QLoRA for more memory-efficient fine-tuning.

## Setup

Run the cells below to set up and install the required libraries. For our experiment we will need `accelerate`, `peft`, `transformers`, `datasets`, and `trl` to leverage the recent [`SFTTrainer`](https://huggingface.co/docs/trl/main/en/sft_trainer). We will use `bitsandbytes` to [quantize the base model to 4-bit](https://huggingface.co/blog/4bit-transformers-bitsandbytes). We will also install `einops`, which is required to load Falcon models.

In [1]:
!pip install -q -U trl transformers accelerate git+https://github.com/huggingface/peft.git
!pip install -q datasets bitsandbytes einops wandb
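
After the install cell finishes, it can help to confirm which versions were actually picked up, since `peft` is installed from the GitHub main branch and the others float to their latest release. The sketch below uses only the standard library; the helper name `installed_versions` is illustrative and not part of any of these packages.

```python
from importlib import metadata

# Packages installed by the setup cell above.
REQUIRED = ["trl", "transformers", "accelerate", "peft", "datasets", "bitsandbytes", "einops"]


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(REQUIRED).items():
    print(f"{name:>14}: {version or 'NOT INSTALLED'}")
```

Recording these versions next to the results makes a run easier to reproduce later.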

## Dataset

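For our experiment, we will combine two datasets: the `closed_qa`, `information_extraction` and `summarization` categories of [Databricks Dolly 15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k), and the English validation split of [XL-Sum](https://huggingface.co/datasets/csebuetnlp/xlsum).

Each record is turned into a single prompt string with an instruction, an input text and a response. The combined data is then split 80/20 into train and test sets, stratified by task.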

In [2]:
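def generate_prompt(example) -> dict:
    """Add a `task` label and a formatted `text` prompt to one dataset record."""
    if "instruction" in example:
        # Dolly records carry their own instruction.
        example["task"] = "qa"
        instruction = example["instruction"]
    else:
        # XL-Sum records have no instruction, so use a fixed summarization prompt.
        example["task"] = "text_summarization"
        instruction = "Below is a text. Write a summary of the text."
    context = example["context"]
    response = example["response"]

    # The `task` label is only used to stratify the train/test split later on.
    example["text"] = f"### Instructions: {instruction} ### Input: {context} ### Response: {response}"
    return example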
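
To make the prompt layout concrete, here is a minimal standalone sketch that applies the same template to two toy records. The records and the helper name `build_text` are made up for illustration; only the template string and the default instruction come from the cell above.

```python
DEFAULT_INSTRUCTION = "Below is a text. Write a summary of the text."


def build_text(example: dict) -> str:
    """Format one record with the same template as `generate_prompt`."""
    # Records without their own instruction fall back to the summarization prompt.
    instruction = example.get("instruction", DEFAULT_INSTRUCTION)
    return (
        f"### Instructions: {instruction} "
        f"### Input: {example['context']} "
        f"### Response: {example['response']}"
    )


# Toy records (hypothetical content, not taken from either dataset).
with_instruction = {
    "instruction": "Which city is mentioned?",
    "context": "The meeting took place in Lisbon.",
    "response": "Lisbon",
}
without_instruction = {
    "context": "The meeting took place in Lisbon.",
    "response": "A meeting was held in Lisbon.",
}

print(build_text(with_instruction))
print(build_text(without_instruction))
```

Both record types end up in one flat `text` field, which is the column the trainer reads later.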

In [3]:
from datasets import load_dataset, concatenate_datasets

dolly_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
xlsum_dataset = load_dataset("csebuetnlp/xlsum", "english", split="validation")

filtered_dolly_dataset = dolly_dataset.filter(lambda example: example["category"] in ["closed_qa", "information_extraction", "summarization"])
new_dolly_dataset = filtered_dolly_dataset.map(generate_prompt).remove_columns(["instruction", "context", "response", "category"])

xlsum_dataset = xlsum_dataset.rename_columns({"text": "context", "summary": "response"})
new_xlsum_dataset = xlsum_dataset.map(generate_prompt).remove_columns(["id", "url", "title", "response", "context"])

complete_dataset = concatenate_datasets([new_dolly_dataset, new_xlsum_dataset]).class_encode_column("task")
final_dataset = complete_dataset.train_test_split(test_size=0.2, stratify_by_column="task").remove_columns(["task"])

final_dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 12801
    })
    test: Dataset({
        features: ['text'],
        num_rows: 3201
    })
})
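
The split sizes above can be checked by hand. In this run the filter kept 4467 Dolly records and XL-Sum contributed 11535, for 16002 in total. The sketch below assumes the test size is rounded up, as scikit-learn-style splitters do; that rounding rule is an assumption here, although it reproduces the row counts shown.

```python
import math

n_dolly = 4467    # filtered Dolly records in this run
n_xlsum = 11535   # XL-Sum English validation records in this run
test_size = 0.2

n_total = n_dolly + n_xlsum
n_test = math.ceil(test_size * n_total)  # assumption: test size is rounded up
n_train = n_total - n_test

print(f"total={n_total} train={n_train} test={n_test}")
```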

## Loading the model

In this section we will load the [Falcon 7B model](https://huggingface.co/tiiuae/falcon-7b), quantize it to 4-bit, and attach LoRA adapters to it. The code below loads a sharded bf16 copy of the checkpoint, `ybelkada/falcon-7b-sharded-bf16`. Let's get started!

In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "ybelkada/falcon-7b-sharded-bf16"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, trust_remote_code=True)
model.config.use_cache = False
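
A rough back-of-the-envelope estimate shows why 4-bit loading matters on a single Colab GPU. The sketch assumes a round 7 billion parameters and counts weight storage only; it ignores quantization constants, any layers kept in higher precision, activations and optimizer state, so treat the numbers as a lower bound.

```python
N_PARAMS = 7e9  # assumption: round figure for a "7B" model


def weight_memory_gb(n_params: float, bits_per_param: int) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


for label, bits in [("bf16 / fp16", 16), ("4-bit (nf4)", 4)]:
    print(f"{label:>12}: ~{weight_memory_gb(N_PARAMS, bits):.1f} GB")
```

Under these assumptions the quantized weights take about a quarter of the space of the half-precision weights.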

Let's also load the tokenizer below. We reuse the end-of-sequence token as the padding token so that batches can be padded.

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

Below we define the configuration used to create the LoRA model. According to the QLoRA paper, it is important to adapt all linear layers in the transformer block for maximum performance. We therefore add the `dense`, `dense_h_to_4h` and `dense_4h_to_h` layers to the target modules, in addition to the fused query/key/value layer (`query_key_value`).

In [6]:
from peft import LoraConfig

lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ],
)
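
To get a feel for how many parameters this configuration trains, the sketch below counts LoRA weights per target module. Each adapted linear layer adds two matrices, `A` (`in_features x r`) and `B` (`r x out_features`). The hidden size 4544, the fused QKV output size 4672 and the 32 decoder layers are taken from the model printout later in this notebook; the `dense` shape and the MLP width of `4 * 4544` are assumptions inferred from the layer names.

```python
HIDDEN = 4544      # from the model printout
QKV_OUT = 4672     # from the model printout
MLP = 4 * HIDDEN   # assumption: inferred from the name `dense_h_to_4h`
N_LAYERS = 32      # from the model printout
R = 64             # lora_r
ALPHA = 16         # lora_alpha

# (in_features, out_features) for each target module in one decoder layer.
shapes = {
    "query_key_value": (HIDDEN, QKV_OUT),
    "dense": (HIDDEN, HIDDEN),  # assumption: square attention output projection
    "dense_h_to_4h": (HIDDEN, MLP),
    "dense_4h_to_h": (MLP, HIDDEN),
}


def lora_params(in_features: int, out_features: int, r: int) -> int:
    """Parameters in the A (in x r) and B (r x out) matrices of one adapter."""
    return r * (in_features + out_features)


per_layer = sum(lora_params(i, o, R) for i, o in shapes.values())
total = per_layer * N_LAYERS

print(f"trainable LoRA parameters per layer: {per_layer:,}")
print(f"trainable LoRA parameters in total:  {total:,}")
print(f"LoRA scaling factor alpha / r:       {ALPHA / R}")
```

Under these assumptions the adapters are a small fraction of the 7B base model, which stays frozen in 4-bit.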

## Loading the trainer

Here we will use the [`SFTTrainer` from the TRL library](https://huggingface.co/docs/trl/main/en/sft_trainer), a wrapper around the transformers `Trainer` that makes it easy to fine-tune models on instruction-based datasets using PEFT adapters. Let's first define the training arguments below.

In [7]:
from transformers import TrainingArguments

output_dir = "./results"
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
optim = "paged_adamw_32bit"
save_steps = 50
logging_steps = 20
learning_rate = 1e-4
max_grad_norm = 0.3
max_steps = 100
warmup_ratio = 0.03
lr_scheduler_type = "constant"

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    fp16=True,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=True,
    lr_scheduler_type=lr_scheduler_type,
    gradient_checkpointing=True,
)
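
These settings determine how much data the run actually sees. The sketch below works out the effective batch size and the fraction of the training set covered in `max_steps` updates. It assumes a single GPU, which is the usual Colab setup, and uses the 12801 training rows from the dataset split above.

```python
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
max_steps = 100
n_gpus = 1        # assumption: one Colab GPU
n_train = 12801   # training rows from the split above

# One optimizer step consumes this many samples.
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * n_gpus
samples_seen = effective_batch * max_steps
epoch_fraction = samples_seen / n_train

print(f"effective batch size: {effective_batch}")
print(f"samples seen:         {samples_seen}")
print(f"fraction of an epoch: {epoch_fraction:.3f}")
```

With only about an eighth of the data seen, this run is a quick demonstration. Raising `max_steps` is the direct way to train for longer.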

Then finally pass everything to the trainer.

In [8]:
from trl import SFTTrainer

max_seq_length = 512

trainer = SFTTrainer(
    model=model,
    train_dataset=final_dataset["train"],
    eval_dataset=final_dataset["test"],
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
)

We will also pre-process the model by upcasting the layer norms to float32 for more stable training.

In [9]:
for name, module in trainer.model.named_modules():
    if "norm" in name:
        # `nn.Module.to` converts the module's parameters in place.
        module.to(torch.float32)

## Train the model

Now let's train the model! Simply call `trainer.train()`.

In [10]:
trainer.train()

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
The current implementation of Falcon calls `torch.scaled_dot_product_attention` directly, this will be deprecated in the future in favor of the `BetterTransformer` API. Please install the latest optimum library with `pip install -U optimum` and call `model.to_bettertransformer()` to benefit from `torch.scaled_dot_product_attention` and future performance optimizations.


| Step | Training Loss |
| --- | --- |
| 20 | 2.1322 |
| 40 | 1.9262 |
| 60 | 1.8976 |
| 80 | 1.9577 |
| 100 | 1.7285 |

TrainOutput(global_step=100, training_loss=1.9284332275390625, metrics={'train_runtime': 2601.3181, 'train_samples_per_second': 0.615, 'train_steps_per_second': 0.038, 'total_flos': 2.600092069880832e+16, 'train_loss': 1.9284332275390625, 'epoch': 0.12})
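
The numbers in `TrainOutput` can be cross-checked against the logged losses and the batch settings. The sketch below recomputes the mean loss and the throughput from the values printed above. It assumes an effective batch of 16 samples per step (4 per device times 4 accumulation steps on one GPU).

```python
logged_losses = [2.1322, 1.9262, 1.8976, 1.9577, 1.7285]  # from the loss table
train_runtime_s = 2601.3181                                # from TrainOutput
global_step = 100                                          # from TrainOutput
effective_batch = 16                                       # assumption: 4 x 4 on one GPU

mean_logged_loss = sum(logged_losses) / len(logged_losses)
samples_per_second = global_step * effective_batch / train_runtime_s
steps_per_second = global_step / train_runtime_s
seconds_per_step = train_runtime_s / global_step

print(f"mean of logged losses: {mean_logged_loss:.4f}")
print(f"samples per second:    {samples_per_second:.3f}")
print(f"steps per second:      {steps_per_second:.3f}")
print(f"seconds per step:      {seconds_per_step:.1f}")
```

The recomputed values agree with the reported `training_loss`, `train_samples_per_second` and `train_steps_per_second`, at roughly 26 seconds per optimizer step.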

In [13]:
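from peft import get_peft_model

# Note: `get_peft_model` only builds adapter layers from the saved config; it does not
# load the trained adapter weights stored in the checkpoint. Loading trained weights is
# done with `PeftModel.from_pretrained`, as in the last code cell of this notebook.
lora_config = LoraConfig.from_pretrained("/content/results/checkpoint-100")
model = get_peft_model(model, lora_config)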

In [14]:
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): FalconForCausalLM(
      (transformer): FalconModel(
        (word_embeddings): Embedding(65024, 4544)
        (h): ModuleList(
          (0-31): 32 x FalconDecoderLayer(
            (self_attention): FalconAttention(
              (maybe_rotary): FalconRotaryEmbedding()
              (query_key_value): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4544, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4672, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4544, out_features=4672, bias=False)
              )
            

In [23]:
text = '### Instructions: Below is a text. Write a summary of the text. ### Input: His account is marked as being "suspended", but the BBC has learned that the US company has decided to permanently revoke his access. Twitter declined to comment, but it is understood the decision was taken after Mr Robinson was judged to have breached its "hateful conduct" policy. It had previously temporarily suspended the activist several times. Mr Robinson, whose real name is Stephen Yaxley-Lennon, left the EDL in 2013. Most recently, Twitter had blocked his account for a week at the start of March. He said at the time that the reason was that he had posted that "90% of grooming gang convictions are Muslims". The social network had removed his blue verification tick four months earlier. Mr Robinson\'s @TRobinsonNewEra account had about 413,000 followers at the time it was blocked. In addition, Twitter has acted against another account - @tommysnewspage - which had also been associated with the 35-year-old. Mr Robinson continues to operate a Facebook page, a YouTube account and a personal website. ### Response: Tommy Robinson - ex-leader of the English Defence League - has been banned from Twitter.'
# Keep `input_ids` and `attention_mask` on the same device as the model.
inputs = tokenizer(text, return_tensors="pt").to("cuda")
outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], pad_token_id=tokenizer.eos_token_id)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

AttributeError: ignored
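
Note that the prompt in the cell above already contains the reference summary after `### Response:`. For generation, one would typically end the prompt at that marker and keep only the text the model produces after it. A minimal sketch of both steps follows. The helper names are hypothetical and the `decoded` string is a stand-in for real model output, not an actual generation.

```python
RESPONSE_MARKER = "### Response:"
DEFAULT_INSTRUCTION = "Below is a text. Write a summary of the text."


def build_inference_prompt(context: str, instruction: str = DEFAULT_INSTRUCTION) -> str:
    """Build a prompt in the training format that stops right after the response marker."""
    return f"### Instructions: {instruction} ### Input: {context} {RESPONSE_MARKER}"


def extract_response(decoded: str) -> str:
    """Return the text after the first response marker, or '' if the marker is absent."""
    _, _, tail = decoded.partition(RESPONSE_MARKER)
    return tail.strip()


prompt = build_inference_prompt("Some article text to summarize.")
decoded = prompt + " A short summary."  # stand-in for tokenizer.decode(outputs[0])
print(extract_response(decoded))
```

When calling `model.generate` with such a prompt, setting an explicit generation length (for example via `max_new_tokens`) leaves room for the summary.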

In [21]:
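from peft import PeftConfig, PeftModel

# Replace the placeholder path with the directory that holds the saved adapter.
config = PeftConfig.from_pretrained("location where new model is stored")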
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    #     load_in_8bit=True,
    #     device_map='auto',
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

model_inf = PeftModel.from_pretrained(model, "location where new model is stored")

{'input_ids': tensor([[19468, 29277,    37, 14695,   304,   241,  2288,    25, 14687,   241,
         11055,   275,   248,  2288,    25,   204, 19468, 21495,    37,  2591,
          1709,   304,  9673,   345,  1003,   204,    13,    94, 35180,  1882,
          1439,   480,   248, 15097,   504,  4784,   325,   248,  1571,  1438,
           504,  3571,   271, 19099, 64258,   545,  1565,    25,  5174, 16530,
           271,  2620,    23,   480,   334,   304,  8809,   248,  3241,   398,
          2629,   852,  3552, 18734,   398, 26008,   271,   413, 56122,   701,
           204,    13,    83,  9353,  3838,    13,  3244,    25,   605,   618,
          5903, 19301, 16613,   248, 22711,  1988,  1762,    25,  3552, 18734,
            23,  4678,  1059,  1536,   304, 12287,   551,   941,  1844,    24,
            55,  1611,   245,    23,  1786,   248, 11655,    55,   272,   204,
           626,    30,    25,  3892,  3358,    23,  5174,   618, 15018,   545,
          1709,   312,   241,  1356,  