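# Fine-tuning Gemma 3 4B (instruction-tuned) on PathVQA

QLoRA fine-tuning of `google/gemma-3-4b-it` for visual question answering on histopathology images (`flaviagiammarino/path-vqa`), followed by merging the trained LoRA adapter into the base model.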

In [None]:
# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard torchvision

# Install Gemma release branch from Hugging Face
%pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

# Install Hugging Face libraries
%pip install  --upgrade \
  "datasets==3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.15.2" \
  "peft==0.14.0" \
  "pillow==11.1.0" \
  protobuf \
  pip install bitsandbytes trl peft

In [None]:
import os
from pathlib import Path

import torch
from datasets import load_dataset
from datasets import load_from_disk
from huggingface_hub import login
from peft import LoraConfig
from peft import PeftModel
from peft import get_peft_model
from PIL import Image
from transformers import AutoConfig
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, TrainerCallback, TrainerControl, TrainerState
from trl import SFTConfig
from trl import SFTTrainer


In [None]:
# ---------------------------------------------------------------------------
# 1. Authentication and dataset loading
# ---------------------------------------------------------------------------
# Never hard-code the Hugging Face token in the notebook. It is read from the
# HF_TOKEN environment variable; if that is unset, token=None lets `login`
# prompt for the token interactively.
login(token=os.environ.get("HF_TOKEN"))
ds = load_dataset("flaviagiammarino/path-vqa")
dataset = ds['train']
eval_dataset = ds['validation']

In [None]:
# ---------------------------------------------------------------------------
# 2. Define the validation subset
# ---------------------------------------------------------------------------
# The file holds one integer index into the validation split per line; blank
# lines are skipped.
indices_file = Path("validation_subset_indices.txt")
with indices_file.open("r", encoding="utf-8") as f:
    selected_indices = [int(line.strip()) for line in f if line.strip()]

# Select from the untouched split in `ds`, not from `eval_dataset`: this cell
# rebinds `eval_dataset`, so reading it here would make a re-run index into an
# already filtered (or already formatted) dataset.
filtered_val_dataset = ds["validation"].select(selected_indices)

print("Anzahl der ausgewählten Einträge:", len(filtered_val_dataset))
eval_dataset = filtered_val_dataset

In [None]:
# ---------------------------------------------------------------------------
# 3. Helper functions
# ---------------------------------------------------------------------------
# Instruction placed in the system turn of every conversation.
system_message = "You are a medical pathology expert specializing in visual diagnosis. Your task is to answer questions based only on the provided histopathology image and the question. Do not use any external knowledge or assumptions. Your answers must be medically accurate, concise, and grounded in visible features of the image."

# Converts one dataset sample into the OpenAI-style chat format.
def format_data(sample):
    """Turn a PathVQA sample into a three-turn chat conversation.

    Args:
        sample: mapping with "image", "question" and "answer" entries.

    Returns:
        A dict with a single "messages" key: a system turn with
        ``system_message``, a user turn with the image followed by the question
        prompt, and an assistant turn with the answer text.
    """
    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": system_message}],
    }
    prompt = f"Question: {sample['question']}\nAnswer based only on the image."
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": sample["image"]},
            {"type": "text", "text": prompt},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["answer"]}],
    }
    return {"messages": [system_turn, user_turn, assistant_turn]}


def process_vision_info(messages: list[dict]) -> "list[Image.Image]":
    """Collect the images referenced in a chat conversation, converted to RGB.

    Args:
        messages: chat turns such as those built by ``format_data``. Each turn's
            ``"content"`` is a list of parts; a non-list value is treated as a
            single part, and non-dict parts are ignored.

    Returns:
        One ``image.convert("RGB")`` result per part that carries an
        ``"image"`` entry, in conversation order.

    Raises:
        ValueError: if a part is marked ``{"type": "image"}`` but has no
            ``"image"`` entry to load.
    """
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        for element in content:
            if not isinstance(element, dict):
                continue
            if "image" in element:
                image_inputs.append(element["image"].convert("RGB"))
            elif element.get("type") == "image":
                # A bare image marker carries no pixel data; calling .convert on
                # the dict itself would only fail with an opaque AttributeError.
                raise ValueError(
                    "Image content part has no 'image' entry: " + repr(element)
                )
    return image_inputs

In [None]:
# ---------------------------------------------------------------------------
# 4. Prepare the data
# ---------------------------------------------------------------------------
# Format from the raw training split and the selected validation subset rather
# than from `dataset` / `eval_dataset` themselves: this cell rebinds those
# names to lists of formatted samples, so reading them here would make a re-run
# pass already formatted dicts to format_data.
# NOTE(review): this materialises every sample (including its image object) in
# memory; consider lazy formatting if RAM becomes a problem.
dataset = [format_data(sample) for sample in ds["train"]]
eval_dataset = [format_data(sample) for sample in filtered_val_dataset]

In [None]:
# ---------------------------------------------------------------------------
# 5. Configure the model
# ---------------------------------------------------------------------------

# Hugging Face model id (used for config, model and processor alike)
model_id = "google/gemma-3-4b-it"
config = AutoConfig.from_pretrained(model_id)
# Turn off use_cache in the text config; the modified text config is passed to
# from_pretrained below.
config.text_config.use_cache = False

# Keyword arguments for from_pretrained
model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    text_config=config.text_config,
)

# 4-bit NF4 quantization with double quantization (QLoRA); compute and storage
# dtype follow torch_dtype.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

# Load model and processor from the same checkpoint id so they cannot diverge
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained(model_id)
# Right padding for training; collate_fn masks the pad tokens in the labels
processor.tokenizer.padding_side = 'right'
processor.tokenizer.padding_side = 'right'

In [None]:
# Configure the LoRA adapter: rank 8, alpha 16, applied to modules named
# q_proj and v_proj.
# NOTE(review): these module names may also match attention projections in the
# vision tower, not only in the language model — confirm this is intended.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=['q_proj', 'v_proj'],
    task_type="CAUSAL_LM",
)

# Wrap only once: re-running this cell must not stack a second PEFT wrapper
# around a model that is already a PeftModel.
if not isinstance(model, PeftModel):
    print(f"Before adapter parameters: {model.num_parameters()}")
    model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

In [None]:
# ---------------------------------------------------------------------------
# 6. Configure training
# ---------------------------------------------------------------------------
args = SFTConfig(
    # Local checkpoint directory; with push_to_hub=True its name is also used
    # for the Hub repository by default.
    # NOTE(review): the name looks inherited from a product-description example.
    output_dir="gemma-product-description",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=1,
    # Evaluate and checkpoint every 2000 steps; save_steps must stay a multiple
    # of eval_steps for load_best_model_at_end.
    eval_strategy='steps',
    eval_steps=2000,
    save_strategy="steps",
    save_steps=2000,
    metric_for_best_model='eval_loss',
    load_best_model_at_end=True,
    logging_steps=1000,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },
    optim="paged_adamw_32bit",
    learning_rate=2e-4,
    # The plain "constant" scheduler ignores warmup entirely; use
    # "constant_with_warmup" so warmup_ratio actually takes effect.
    lr_scheduler_type="constant_with_warmup",
    warmup_ratio=0.1,
    max_grad_norm=0.3,
    bf16=True,
    # NOTE(review): collate_fn does not truncate and dataset preparation is
    # skipped below, so this limit is presumably not applied — verify.
    max_seq_length=128,
    push_to_hub=True,
    report_to="tensorboard",
    # Batches are built by collate_fn from the "messages" column, so the
    # trainer's own text-field handling and dataset preparation are disabled.
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
)
# Keep the "messages" column; collate_fn reads it.
args.remove_unused_columns = False

In [None]:
# Builds a batch of text and image inputs for the model.
def collate_fn(examples):
    """Render, tokenize and label a batch of chat-formatted examples.

    Each example's ``"messages"`` is rendered with the processor's chat
    template (no generation prompt, surrounding whitespace stripped), its images
    are collected with ``process_vision_info``, and the processor is run over
    the whole batch with padding.

    Args:
        examples: list of dicts with a ``"messages"`` key (see ``format_data``).

    Returns:
        The processor output with an added ``"labels"`` tensor: a copy of
        ``input_ids`` in which padding, begin-of-image and image soft tokens
        are replaced by -100 so the loss ignores them.
    """
    texts = []
    images = []
    for example in examples:
        images.append(process_vision_info(example["messages"]))
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    tokenizer = processor.tokenizer
    boi_token_id = tokenizer.convert_tokens_to_ids(tokenizer.special_tokens_map["boi_token"])
    # NOTE(review): 262144 is presumably Gemma 3's <image_soft_token> id —
    # confirm if the tokenizer changes.
    image_soft_token_id = 262144

    labels = batch["input_ids"].clone()
    # Compare against plain ints. Comparing a tensor with a Python *list*
    # evaluates to the bool False, and `labels[False] = -100` silently does
    # nothing.
    for token_id in (tokenizer.pad_token_id, boi_token_id, image_soft_token_id):
        if token_id is not None:
            labels[labels == token_id] = -100

    batch["labels"] = labels
    return batch

In [None]:
# The model was placed on devices by device_map="auto" when it was loaded, so
# it is not moved with `.to(device)`: that would collapse a multi-device
# placement and is not supported for every quantized model.
# The model is already a PeftModel (get_peft_model above), so no peft_config
# is passed; that would only ask the trainer to apply LoRA a second time.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    eval_dataset=eval_dataset,
    processing_class=processor,
    data_collator=collate_fn,
)

In [None]:
# ---------------------------------------------------------------------------
# 7. Training and evaluation
# ---------------------------------------------------------------------------
# Baseline evaluation on the validation subset before any training
print("-" * 30)
print("Evaluating")
metric = trainer.evaluate()
print(metric)

print("-" * 30)
print("Training")
trainer.train()

# save_model writes to args.output_dir; with push_to_hub=True the Trainer also
# uploads the saved model to the Hub.
print("-" * 30)
print("Saving Model to Hugging Face Hub")
trainer.save_model()
print("-" * 30)
print("Congratulations you have successfully finetuned Gemma 3 4b it!")

In [None]:
# ---------------------------------------------------------------------------
# 8. Merge the LoRA weights
# ---------------------------------------------------------------------------
MERGED_MODEL_DIR = "merged_model_batchsize_2"

# Load the unquantized base model in bfloat16 (the dtype used for training
# compute). Without torch_dtype it would be loaded, and therefore saved, in
# float32, doubling the size of the merged checkpoint.
model = AutoModelForImageTextToText.from_pretrained(
    model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
)

# Attach the trained adapter saved in the output directory and merge it into
# the base weights.
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained(MERGED_MODEL_DIR, safe_serialization=True, max_shard_size="2GB")

# Store the processor next to the merged weights so the directory is
# self-contained.
processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained(MERGED_MODEL_DIR)