# Fine-tuning Gemma on local hardware
In this notebook we fine-tune the [google/gemma-2b-it](https://huggingface.co/google/gemma-2b-it) model using [Hugging Face](https://huggingface.co) infrastructure, locally on a consumer-grade graphics card with 16 GB of memory. Gemma is provided under and subject to the Gemma Terms of Use found at [ai.google.dev/gemma/terms](https://ai.google.dev/gemma/terms). This notebook was written by modifying code from [this article about fine-tuning Llama 3](https://www.datacamp.com/tutorial/llama3-fine-tuning-locally), which is highly recommended.

## Read more
* [QLoRA](https://arxiv.org/abs/2305.14314)
* [Gemma Cookbook](https://github.com/google-gemini/gemma-cookbook)
* [example sft_qlora](https://huggingface.co/google/gemma-7b/blob/main/examples/example_sft_qlora.py)

## Troubleshooting
* If you run this notebook on Windows and receive error messages mentioning that CUDA initialization failed, make sure you have `bitsandbytes` version 0.43.2 or later installed.
* If you run out of GPU memory, check that your GPU has enough memory. This notebook was developed on an RTX 3080 mobile GPU with 16 GB of memory.

In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format
from functools import partial

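First, we define the model we want to fine-tune and the name under which we will store the fine-tuned model. The cell below selects `HuggingFaceTB/SmolLM-1.7B`; the Gemma variants are kept as commented-out alternatives.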

In [2]:
base_model = "HuggingFaceTB/SmolLM-1.7B"
# Alternative base models:
# "google/codegemma-1.1-7b-it"
# "google/gemma-2b"
# "google/codegemma-2b"
# "google/gemma-2b-it"
new_model = "SmolLM-1.7B-bia-proof-of-concept"

In [3]:
eos_token = "<|endoftext|>" # depends on base model

## Configuration

In [4]:
torch_dtype = torch.float16
attn_implementation = "eager"

We use the [QLoRA](https://arxiv.org/abs/2305.14314) fine-tuning scheme to save memory: the base model is loaded with 4-bit quantized weights, and only small LoRA adapter matrices are trained.

In [5]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)
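
To give a rough idea of why 4-bit loading helps, here is a minimal back-of-the-envelope sketch. It only counts the weights and ignores quantization constants, activations, optimizer state and any layers that stay unquantized. The parameter count of 1.7 billion is an assumption taken from the model name, and `weight_memory_gb` is a hypothetical helper, not part of any library.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Approximate memory needed for the weights alone, in GB (10**9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


# Assumption: parameter count taken from the model name "SmolLM-1.7B".
n_params = 1.7e9

fp16_gb = weight_memory_gb(n_params, 16)  # 16-bit weights
nf4_gb = weight_memory_gb(n_params, 4)    # 4-bit weights, as with load_in_4bit=True

print(f"16-bit weights: {fp16_gb:.2f} GB")
print(f"4-bit weights:  {nf4_gb:.2f} GB")
```

The real footprint during training is higher than these numbers, so treat them as a lower bound.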

## Initialization of model and tokenizer
Here we download and initialize the model and the tokenizer.

In [6]:
# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

In [7]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
model, tokenizer = setup_chat_format(model, tokenizer)

In [8]:
tokenizer.padding_side = 'right'

In [9]:
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)
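
To get a feeling for how few parameters LoRA trains, here is a minimal sketch of the arithmetic. Each adapted linear layer with `d_in` inputs and `d_out` outputs receives two low-rank matrices with `r * (d_in + d_out)` parameters in total. The layer sizes below (hidden size 2048, MLP size 8192) are assumptions chosen for illustration and are not read from the model config; `lora_params` is a hypothetical helper.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


r = 16          # same rank as in the LoraConfig above
hidden = 2048   # assumed hidden size, for illustration only
mlp = 8192      # assumed MLP (intermediate) size, for illustration only

# (d_in, d_out) for the seven target modules of one transformer block
shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, hidden),
    "v_proj": (hidden, hidden),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, mlp),
    "up_proj": (hidden, mlp),
    "down_proj": (mlp, hidden),
}

per_block = sum(lora_params(d_in, d_out, r) for d_in, d_out in shapes.values())
print(f"LoRA parameters per transformer block: {per_block:,}")
```

For the actual numbers of the loaded model, `model.print_trainable_parameters()` on the PEFT-wrapped model reports the trainable and total parameter counts.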

## Dataset preparation
Next, we load the fine-tuning dataset from its [Hugging Face Hub page](https://huggingface.co/datasets/haesleinhuepf/bio-image-analysis-qa).

In [10]:
dataset_name = "haesleinhuepf/bio-image-analysis-qa"

In [11]:
dataset_raw = load_dataset(dataset_name, split="all")

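def format_chat_template(row, tokenizer):
    """Turn one question/answer row into a single chat-formatted training text."""
    row_json = [{"role": "user", "content": row["question"]},
               {"role": "assistant", "content": row["answer"]}]
    # tokenize=False returns the formatted string instead of token ids
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row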

format_chat_template_partial = partial(format_chat_template, tokenizer=tokenizer)

dataset_w_template = dataset_raw.map(
    format_chat_template_partial,
    num_proc=4,
)

print(dataset_w_template['text'][3])

<|im_start|>user
How can we use indices in Python to crop images, similar to cropping lists and tuples?<|im_end|>
<|im_start|>assistant

This code imports the necessary functions from the skimage.io module. It then reads an image called "blobs.tif" and assigns it to the variable 'image'. It crops the image, taking the first 128 rows, and assigns the result to 'cropped_image1'. The cropped image is then displayed using the 'imshow' function. Lastly, a list of numbers is created called 'mylist'.

```python
from skimage.io import imread, imshow, imshow

image = imread("../../data/blobs.tif")

cropped_image1 = image[0:128]

imshow(cropped_image1);

mylist = [1,2,2,3,4,5,78]
```
<|im_end|>
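
The printed example shows the layout the chat template produces: each message is wrapped in `<|im_start|>` and `<|im_end|>` markers and starts with its role. A minimal sketch of that layout in plain Python, without a tokenizer, could look like this. `to_chatml` is a hypothetical helper that only mimics the output seen above; the real template applied by the tokenizer may differ in details.

```python
def to_chatml(messages):
    """Render a list of {"role", "content"} dicts in a ChatML-style layout."""
    return "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )


example = to_chatml([
    {"role": "user", "content": "How do I crop an image?"},
    {"role": "assistant", "content": "Use slicing, e.g. image[0:128]."},
])
print(example)
```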



In [12]:
type(dataset_w_template)

datasets.arrow_dataset.Dataset

In [15]:
# Assuming dataset_w_template is already loaded as a Dataset with a 'text' column
def append_eos_token(example):
    example['text'] += eos_token
    return example

# Apply the append_eos_token function to each example in dataset_w_template
dataset_w_template = dataset_w_template.map(append_eos_token)

In [16]:

print(dataset_w_template['text'][3])

<|im_start|>user
How can we use indices in Python to crop images, similar to cropping lists and tuples?<|im_end|>
<|im_start|>assistant

This code imports the necessary functions from the skimage.io module. It then reads an image called "blobs.tif" and assigns it to the variable 'image'. It crops the image, taking the first 128 rows, and assigns the result to 'cropped_image1'. The cropped image is then displayed using the 'imshow' function. Lastly, a list of numbers is created called 'mylist'.

```python
from skimage.io import imread, imshow, imshow

image = imread("../../data/blobs.tif")

cropped_image1 = image[0:128]

imshow(cropped_image1);

mylist = [1,2,2,3,4,5,78]
```
<|im_end|>
<|endoftext|>


We then split the data into two sets, one for training (70%) and one for testing (30%).

In [17]:
dataset_train_test = dataset_w_template.train_test_split(test_size=0.3)
dataset_train_test

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'text'],
        num_rows: 91
    })
    test: Dataset({
        features: ['question', 'answer', 'text'],
        num_rows: 39
    })
})

In [18]:
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=10,
    eval_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="none"
)
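
The following minimal sketch shows how the settings above translate into optimizer steps. It uses the 91 training rows reported by the split and assumes that incomplete accumulation groups at the end of an epoch are dropped and that a fractional `eval_steps` is interpreted as a ratio of the total steps. This is our understanding of the trainer's behaviour, not a guaranteed description of its internals.

```python
import math

n_train = 91          # training rows reported by train_test_split above
batch_size = 1        # per_device_train_batch_size
grad_accum = 2        # gradient_accumulation_steps
epochs = 10           # num_train_epochs
eval_ratio = 0.2      # eval_steps given as a fraction of all steps

batches_per_epoch = math.ceil(n_train / batch_size)
steps_per_epoch = batches_per_epoch // grad_accum   # assumed: remainder is dropped
total_steps = steps_per_epoch * epochs
eval_every = math.ceil(total_steps * eval_ratio)

# Fraction of the data actually seen, expressed in epochs
epochs_seen = total_steps * grad_accum * batch_size / n_train

print(f"optimizer steps in total: {total_steps}")
print(f"evaluation every {eval_every} steps")
print(f"epochs completed: {epochs_seen:.2f}")
```

With these assumptions the sketch gives 450 steps, an evaluation every 90 steps and about 9.89 epochs.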

In [19]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset_train_test["train"],
    eval_dataset=dataset_train_test["test"],
    peft_config=peft_config,
    max_seq_length=512,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


In [20]:
trainer.train()

Step,Training Loss,Validation Loss
90,0.8786,1.098734
180,0.5259,1.140132
270,0.1934,1.471423
360,0.0776,1.828713
450,0.0385,1.977736


We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


TrainOutput(global_step=450, training_loss=0.5048243994928069, metrics={'train_runtime': 307.9515, 'train_samples_per_second': 2.955, 'train_steps_per_second': 1.461, 'total_flos': 1735161730867200.0, 'train_loss': 0.5048243994928069, 'epoch': 9.89010989010989})
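
In the table above the training loss keeps falling while the validation loss rises after the first evaluation, which is commonly read as a sign of overfitting on such a small dataset. A minimal sketch for finding the evaluation step with the lowest validation loss, using the logged values copied by hand from the table:

```python
# (step, training loss, validation loss), copied from the table above
log = [
    (90, 0.8786, 1.098734),
    (180, 0.5259, 1.140132),
    (270, 0.1934, 1.471423),
    (360, 0.0776, 1.828713),
    (450, 0.0385, 1.977736),
]

# Pick the row with the smallest validation loss
best_step, best_train, best_val = min(log, key=lambda row: row[2])

print(f"lowest validation loss {best_val:.4f} at step {best_step}")
```

One possible reaction, which we did not test here, is to train for fewer epochs.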

In [16]:
trainer.save_model(new_model + "_temp")

In [22]:
trainer.model.save_pretrained(new_model)
#trainer.model.push_to_hub(new_model, use_temp_dir=False)

## Testing the model
After training, we can run a first test with the model.

In [21]:
messages = [{"role": "user", "content": """
Write Python code to load the image ../11a_prompt_engineering/data/blobs.tif,
segment the nuclei in it and
show the result
"""}]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
)

outputs = pipe(prompt, max_new_tokens=250, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are [...]

<|im_start|>user

Write Python code to load the image ../11a_prompt_engineering/data/blobs.tif,
segment the nuclei in it and
show the result
<|im_end|>
<|im_start|>assistant

This code uses the `imread` function from the `skimage.io` module to load an image file named '../../11a_data/blobs.tif'. It then applies Voronoi Otsu Labeling to the image using a specified sigma value of 2.5 to adjust the sensitivity of the labeling algorithm. The resulting labeled image is then displayed using `matplotlib.pyplot.imshow` function.

```python

from skimage.io import imread

image = imread("../../11a_data/blobs.tif")

labeled, num_ Nuclei = voronoi_otsu_labeling(image, sigma=2.5)

imshow(labeled, labels=True)

```

This code imports the `imread` function from the `skimage.io` module, which is used to load an image file named '../../11a_data/blobs.tif'.

The `voronoi_otsu_labeling` function from the `pyclesperanto_prototype` library is then called with the `image` and a specified `sigma` value of 2

In [24]:
new_model + "_temp"

'SmolLM-1.7B-bia-proof-of-concept_temp'

In [25]:
def prompt_hf(request, model="SmolLM-1.7B-bia-proof-of-concept_temp"):
    """Generate text for `request`; the pipeline is built on first use and reused afterwards."""
    import transformers
    import torch

    # Create the pipeline only once; later calls reuse the cached instance.
    if prompt_hf._pipeline is None:
        prompt_hf._pipeline = transformers.pipeline(
            "text-generation", model=model, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto"
        )

    return prompt_hf._pipeline(request, max_length=1000, num_return_sequences=1)[0]['generated_text']

# Function attribute that caches the pipeline between calls
prompt_hf._pipeline = None

In [26]:
prompt_hf("""
Write Python code to load the image ../11a_prompt_engineering/data/blobs.tif,
segment the nuclei in it and
show the result
""")

Some parameters are on the meta device device because they were offloaded to the cpu.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
  attn_output = torch.nn.functional.scaled_dot_product_attention(


KeyboardInterrupt: 