# Fine-tuning Gemma using Huggingface Trainer and QLoRA
In this notebook we fine-tune the [google/gemma-2b-it](https://huggingface.co/google/gemma-2b-it) model. Gemma is provided under and subject to the Gemma Terms of Use found at [ai.google.dev/gemma/terms](https://ai.google.dev/gemma/terms). This notebook adapts code from [this article about fine-tuning Llama 3](https://www.datacamp.com/tutorial/llama3-fine-tuning-locally), which is highly recommended.

## Read more
* [QLoRA](https://arxiv.org/abs/2305.14314)
* [Gemma Cookbook](https://github.com/google-gemini/gemma-cookbook)
* [example sft_qlora](https://huggingface.co/google/gemma-7b/blob/main/examples/example_sft_qlora.py)

## Troubleshooting
* If you run this notebook on Windows and receive error messages saying that CUDA initialization failed, make sure you have `bitsandbytes` version 0.43.2 or newer installed.
* If you run out of GPU memory, check that your GPU has enough memory for this configuration. This notebook was developed on an RTX 3080 Mobile GPU with 16 GB of memory.

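To check the first point before running the rest of the notebook, the installed `bitsandbytes` version can be compared against 0.43.2. The sketch below uses only the standard library (`importlib.metadata`); the helpers `parse_version` and `meets_minimum` are hypothetical, and the parser is deliberately simplified: it reads only the leading numeric parts of strings such as `0.43.2`, `0.45.0.dev0` or `0.44.1+cu121`.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_version(text, parts=3):
    """Turn '0.43.2' (or '0.45.0.dev0', '0.44.1+cu121') into a tuple of ints."""
    numbers = []
    for piece in re.split(r"[.+]", text):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric part, e.g. 'dev0'
        numbers.append(int(match.group()))
        if len(numbers) == parts:
            break
    # pad short versions such as '1.0' to (1, 0, 0)
    return tuple(numbers + [0] * (parts - len(numbers)))


def meets_minimum(installed, required="0.43.2"):
    """Return True if the installed version is at least the required one."""
    return parse_version(installed) >= parse_version(required)


try:
    installed = version("bitsandbytes")
    print(installed, "OK" if meets_minimum(installed) else "too old, please upgrade")
except PackageNotFoundError:
    print("bitsandbytes is not installed")
```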
In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format
from functools import partial

First, we define the model we want to fine-tune and the name under which we will store the new fine-tuned model.

In [2]:
base_model = "google/gemma-2b-it"
# Other base models that can be used instead:
# "google/codegemma-1.1-7b-it"
# "google/gemma-2b"
# "google/codegemma-2b"
new_model = "haesleinhuepf/gemma-2b-it-bia-proof-of-concept"

## Configuration

In [3]:
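torch_dtype = torch.float16    # dtype used for computations on top of the quantized weights
attn_implementation = "eager"  # plain PyTorch attention implementation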

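We use the [QLoRA](https://arxiv.org/abs/2305.14314) fine-tuning scheme to save GPU memory: the frozen base model is loaded with 4-bit quantized weights, and only small low-rank adapter (LoRA) matrices are trained on top of it.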

In [4]:
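# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store the base model weights in 4 bit
    bnb_4bit_quant_type="nf4",             # 4-bit NormalFloat data type from the QLoRA paper
    bnb_4bit_compute_dtype=torch_dtype,    # de-quantize to float16 for the matrix multiplications
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
)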

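A back-of-the-envelope estimate shows why 4-bit loading matters. The sketch assumes roughly 2.5 billion parameters for Gemma 2B and uses the per-parameter overheads given in the QLoRA paper: one fp32 quantization constant per block of 64 weights costs 32/64 = 0.5 extra bits, and double quantization reduces this to 8/64 + 32/(64·256) ≈ 0.127 bits. It only counts the weights themselves; in practice some layers (for example the token embeddings) stay in 16 bit, and activations, adapter weights and optimizer states come on top.

```python
def weight_memory_gb(n_params, bits_per_param):
    """Memory needed to store n_params weights at the given bit width, in GB (1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


# Assumption: roughly 2.5 billion parameters for Gemma 2B
N_PARAMS = 2.5e9

# Bits per parameter including quantization constants, following the QLoRA paper
nf4_plain = 4 + 32 / 64                    # one fp32 constant per block of 64 weights
nf4_double = 4 + 8 / 64 + 32 / (64 * 256)  # 8-bit constants plus fp32 constants per 256 blocks

for label, bits in [("float16", 16), ("NF4", nf4_plain), ("NF4 + double quant", nf4_double)]:
    print(f"{label:>20}: {bits:6.3f} bits/param -> {weight_memory_gb(N_PARAMS, bits):.2f} GB")
```

Under these assumptions the weights shrink from about 5.0 GB in float16 to about 1.4 GB with NF4 and about 1.3 GB with double quantization.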
## Initialization of model and tokenizer
Here we download the model, load it with the 4-bit quantization configuration defined above, and initialize the tokenizer.

In [5]:
# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.



In [6]:
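# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
# Switch to the ChatML chat format (<|im_start|> / <|im_end|>): adds the special tokens,
# sets the tokenizer's chat template and resizes the model's token embeddings
model, tokenizer = setup_chat_format(model, tokenizer)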

In [7]:
tokenizer.padding_side = 'right'  # append padding tokens after the real tokens

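With `padding_side = 'right'`, sequences of different length in a batch are padded at the end, and the attention mask marks the padding positions with 0. The plain-Python sketch below only illustrates the resulting layout; the helper `pad_batch` and the token ids are made up and do not call the tokenizer. Left padding, shown for comparison, is what is typically used for batched generation with decoder-only models.

```python
def pad_batch(sequences, pad_id, side="right"):
    """Pad token-id lists to equal length and build the matching attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        padding, mask = [pad_id] * n_pad, [1] * len(seq)
        if side == "right":
            input_ids.append(seq + padding)
            attention_mask.append(mask + [0] * n_pad)
        else:  # left padding
            input_ids.append(padding + seq)
            attention_mask.append([0] * n_pad + mask)
    return input_ids, attention_mask


# hypothetical token ids; 0 plays the role of the padding token
ids, mask = pad_batch([[5, 6, 7], [8, 9]], pad_id=0)
print(ids)   # [[5, 6, 7], [8, 9, 0]]
print(mask)  # [[1, 1, 1], [1, 1, 0]]
```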
In [8]:
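# LoRA config
peft_config = LoraConfig(
    r=16,               # rank of the low-rank update matrices
    lora_alpha=32,      # the update is scaled by lora_alpha / r
    lora_dropout=0.05,  # dropout on the adapter input during training
    bias="none",        # do not train bias terms
    task_type="CAUSAL_LM",
    # attach adapters to all attention and MLP projection layers
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
# wrap the quantized model so that only the LoRA parameters are trainable
model = get_peft_model(model, peft_config)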

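LoRA learns the update of each targeted weight matrix W (shape d_out × d_in) as a product B·A of two rank-r matrices, so every adapted layer adds r·(d_in + d_out) trainable parameters, and the update is scaled by lora_alpha / r = 32 / 16 = 2. The sketch estimates the adapter size for the configuration above. The layer dimensions are assumptions taken from the published Gemma 2B configuration (hidden size 2048, MLP size 16384, 8 query heads and 1 key/value head of size 256, 18 layers); check `model.config` for the actual values and compare with `model.print_trainable_parameters()`.

```python
R = 16
LORA_ALPHA = 32

# Assumed Gemma 2B dimensions (verify against model.config)
HIDDEN = 2048
INTERMEDIATE = 16384
N_HEADS, N_KV_HEADS, HEAD_DIM = 8, 1, 256
N_LAYERS = 18

# (d_in, d_out) of every projection listed in target_modules
PROJECTIONS = {
    "q_proj": (HIDDEN, N_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "o_proj": (N_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(d_in, d_out, r):
    """A has shape (r, d_in) and B has shape (d_out, r)."""
    return r * d_in + d_out * r


per_layer = sum(lora_params(d_in, d_out, R) for d_in, d_out in PROJECTIONS.values())
total = per_layer * N_LAYERS
print(f"scaling lora_alpha / r = {LORA_ALPHA / R}")
print(f"LoRA parameters per layer: {per_layer:,}")
print(f"LoRA parameters in total:  {total:,}")
```

Under these assumptions the adapters add about 19.6 million trainable parameters, a small fraction of the roughly 2.5 billion frozen ones.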
## Dataset preparation
Next, we load the dataset for fine-tuning from its [Huggingface Hub page](https://huggingface.co/datasets/haesleinhuepf/bio-image-analysis-qa).

In [9]:
dataset_name = "haesleinhuepf/bio-image-analysis-qa"

In [10]:
dataset_raw = load_dataset(dataset_name, split="all")

def format_chat_template(row, tokenizer):
    row_json = [{"role": "user", "content": row["question"]},
               {"role": "assistant", "content": row["answer"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

format_chat_template_partial = partial(format_chat_template, tokenizer=tokenizer)

dataset_w_template = dataset_raw.map(
    format_chat_template_partial,
    num_proc=4,
)

print(dataset_w_template['text'][3])

<|im_start|>user
How can we use indices in Python to crop images, similar to cropping lists and tuples?<|im_end|>
<|im_start|>assistant

This code imports the necessary functions from the skimage.io module. It then reads an image called "blobs.tif" and assigns it to the variable 'image'. It crops the image, taking the first 128 rows, and assigns the result to 'cropped_image1'. The cropped image is then displayed using the 'imshow' function. Lastly, a list of numbers is created called 'mylist'.

```python
from skimage.io import imread, imshow, imshow

image = imread("../../data/blobs.tif")

cropped_image1 = image[0:128]

imshow(cropped_image1);

mylist = [1,2,2,3,4,5,78]
```
<|im_end|>
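The printed example shows the ChatML markup that `setup_chat_format` installed on the tokenizer: each message is wrapped in `<|im_start|>{role}` and `<|im_end|>`, and the generation prompt opens an assistant turn. The minimal pure-Python sketch below reproduces that format as seen in the output; it is not trl's actual Jinja template, and the helper `to_chatml` as well as the example question and answer are hypothetical.

```python
def to_chatml(messages, add_generation_prompt=False):
    """Render a list of {'role', 'content'} dicts in ChatML, one turn per block."""
    text = ""
    for message in messages:
        text += "<|im_start|>" + message["role"] + "\n" + message["content"] + "<|im_end|>\n"
    if add_generation_prompt:
        # open an assistant turn so that the model continues with its answer
        text += "<|im_start|>assistant\n"
    return text


# hypothetical dataset row with the same columns as the bio-image-analysis-qa dataset
row = {"question": "What is a label image?",
       "answer": "An image in which each object has its own integer value."}
print(to_chatml([
    {"role": "user", "content": row["question"]},
    {"role": "assistant", "content": row["answer"]},
]))
```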



We then split the data into a training set (70% of the rows) and a test set (30% of the rows).

In [12]:
dataset_train_test = dataset_w_template.train_test_split(test_size=0.3)
dataset_train_test

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'text'],
        num_rows: 91
    })
    test: Dataset({
        features: ['question', 'answer', 'text'],
        num_rows: 39
    })
})
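`train_test_split(test_size=0.3)` shuffles the rows and reserves 30% of them for testing, which for the 130 rows here gives the 91/39 split shown above. Because no seed is passed, a different split is drawn on every run; `Dataset.train_test_split` accepts a `seed` argument to make it reproducible. The idea in plain Python, as a sketch with a hypothetical `split_indices` helper rather than the `datasets` implementation:

```python
import math
import random


def split_indices(n_rows, test_size=0.3, seed=42):
    """Shuffle row indices with a fixed seed and cut off the test fraction."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    n_test = math.ceil(test_size * n_rows)  # round the test set size up
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(130)
print(len(train_idx), len(test_idx))
```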

In [13]:
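training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,  # update weights every 2 batches (effective batch size 2)
    optim="paged_adamw_32bit",      # paged AdamW optimizer from bitsandbytes
    num_train_epochs=10,
    eval_strategy="steps",
    eval_steps=0.2,                 # values < 1 are a fraction of the total steps: every 20%
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,                     # no mixed-precision training
    bf16=False,
    group_by_length=True,           # batch samples of similar length together
    report_to="none"                # do not send logs to external trackers
)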

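The step counts in the training log follow from these arguments. With 91 training rows, a per-device batch size of 1 and gradient accumulation over 2 batches, one optimizer step consumes 2 examples. Assuming the steps per epoch are rounded down (as the reported `global_step=450` suggests), 10 epochs give 45 × 10 = 450 steps, and `eval_steps=0.2` evaluates every 90 steps. A sketch of the arithmetic:

```python
import math

n_train = 91
per_device_batch_size = 1
gradient_accumulation_steps = 2
num_train_epochs = 10
eval_steps_fraction = 0.2

batches_per_epoch = math.ceil(n_train / per_device_batch_size)         # 91 batches
steps_per_epoch = batches_per_epoch // gradient_accumulation_steps      # 45 optimizer steps
max_steps = steps_per_epoch * num_train_epochs                          # 450 steps
eval_every = math.ceil(max_steps * eval_steps_fraction)                 # every 90 steps
effective_batch = per_device_batch_size * gradient_accumulation_steps   # 2 examples per step
epochs_seen = max_steps * effective_batch / n_train                     # 900 / 91

print(steps_per_epoch, max_steps, eval_every, round(epochs_seen, 2))
```

The last value, about 9.89, matches the `epoch` reported in the `TrainOutput`: since 45 steps cover only 90 of the 91 examples per epoch, the run stops slightly short of 10 full passes.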
In [14]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset_train_test["train"],
    eval_dataset=dataset_train_test["test"],
    peft_config=peft_config,
    max_seq_length=512,         # longer sequences are truncated
    dataset_text_field="text",  # column created by format_chat_template
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,              # one example per sequence, no concatenation
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.



In [15]:
trainer.train()

| Step | Training Loss | Validation Loss |
|-----:|--------------:|----------------:|
| 90 | 0.5641 | 1.311606 |
| 180 | 0.2688 | 1.66749 |
| 270 | 0.0844 | 1.935554 |
| 360 | 0.0537 | 2.171257 |
| 450 | 0.0371 | 2.344546 |

TrainOutput(global_step=450, training_loss=0.53101396843791, metrics={'train_runtime': 212.1879, 'train_samples_per_second': 4.289, 'train_steps_per_second': 2.121, 'total_flos': 2065676036788224.0, 'train_loss': 0.53101396843791, 'epoch': 9.89010989010989})
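The log shows the training loss falling from 0.56 to 0.04 while the validation loss rises from 1.31 to 2.34, a typical sign of overfitting on this small dataset, which may suggest training for fewer epochs. The sketch below finds the evaluation step with the lowest validation loss; the values are copied from the training log above. With a live `trainer`, the same numbers are available in `trainer.state.log_history`, where evaluation entries carry `step` and `eval_loss` keys.

```python
# (step, training loss, validation loss) as reported in the training log
log = [
    (90, 0.5641, 1.311606),
    (180, 0.2688, 1.66749),
    (270, 0.0844, 1.935554),
    (360, 0.0537, 2.171257),
    (450, 0.0371, 2.344546),
]
# With a live trainer, the validation part could be read like this instead:
# [(e["step"], e["eval_loss"]) for e in trainer.state.log_history if "eval_loss" in e]

best_step, _, best_val = min(log, key=lambda entry: entry[2])
val_losses = [val for _, _, val in log]
keeps_rising = all(a < b for a, b in zip(val_losses, val_losses[1:]))
print(f"lowest validation loss {best_val} at step {best_step}")
print("validation loss increases at every later evaluation:", keeps_rising)
```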

In [18]:
trainer.save_model(new_model + "_temp")



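## Merging
To obtain a standalone model, we reload the base model in 16-bit precision (without quantization), apply the same chat format as during training, load the trained LoRA adapter on top of it and merge the adapter weights into the base weights.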

In [19]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
import torch
from trl import setup_chat_format
# Reload tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model)

base_model_reload = AutoModelForCausalLM.from_pretrained(
        base_model,
        return_dict=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
)

base_model_reload, tokenizer = setup_chat_format(base_model_reload, tokenizer)

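# Load the trained LoRA adapter on top of the base model
model = PeftModel.from_pretrained(base_model_reload, new_model + "_temp")

# Merge the adapter weights into the base weights and remove the PEFT wrapper
merged_model = model.merge_and_unload()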


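`merge_and_unload()` folds each adapter into the weight it modifies, W_merged = W + (lora_alpha / r) · B · A, so the merged model needs no PEFT code at inference time and produces the same outputs, up to floating-point rounding, as the base model with the adapter attached. The toy check below verifies this identity in pure Python with tiny made-up matrices (d_in = 3, d_out = 2, r = 1); the helper functions are hypothetical and LoRA dropout is left out, as it is inactive at inference time.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def matvec(m, x):
    """Multiply a matrix (list of rows) with a vector."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in m]


def add(a, b, scale=1.0):
    """Return a + scale * b for two matrices of the same shape."""
    return [[a_ij + scale * b_ij for a_ij, b_ij in zip(ra, rb)] for ra, rb in zip(a, b)]


r, lora_alpha = 1, 2
scaling = lora_alpha / r

W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, -1.0]]   # frozen base weight, shape (d_out=2, d_in=3)
A = [[0.5, -1.0, 0.0]]   # LoRA down-projection, shape (r=1, d_in=3)
B = [[2.0], [1.0]]       # LoRA up-projection, shape (d_out=2, r=1)
x = [1.0, 2.0, 3.0]

# forward pass with a separate adapter: W x + scaling * B (A x)
adapter_out = [w + scaling * b for w, b in zip(matvec(W, x), matvec(B, matvec(A, x)))]

# forward pass with the merged weight: (W + scaling * B A) x
W_merged = add(W, matmul(B, A), scale=scaling)
merged_out = matvec(W_merged, x)

print(adapter_out, merged_out)
```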
## Test model

In [20]:
messages = [{"role": "user", "content": """
Write Python code to load the image ../11a_prompt_engineering/data/blobs.tif,
segment the nuclei in it and
show the result
"""}]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
pipe = pipeline(
    "text-generation",
    model=merged_model,  # use the merged model instead of the PEFT wrapper
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
)

outputs = pipe(prompt, max_new_tokens=120, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])

<|im_start|>user

Write Python code to load the image ../11a_prompt_engineering/data/blobs.tif,
segment the nuclei in it and
show the result
<|im_end|>
<|im_start|>assistant
The code imports the necessary libraries, loads an image, segments the nuclei in it, and shows the result.

```python

import pyclesperanto_prototype as cle
from skimage.io import imread
import matplotlib.pyplot as plt

image = imread('../11a_prompt_engineering/data/blobs.tif')
image.shape

cle.segment_nuclei(image)
nuclei = cle.label_segments(image)

cle.imshow(nuclei, labels=True)

```
This code will generate the following output:




In [21]:
merged_model.save_pretrained(new_model)

In [25]:
#merged_model.push_to_hub(new_model, use_temp_dir=False)